Skip to content

Commit aafc833

Browse files
authored
Merge pull request #962 from onnx/gs/enable-ut
enable all cond ut for tf-2.x
2 parents 9fedcc2 + bf1ad39 commit aafc833

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

tests/test_cond.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import tensorflow as tf
1212

1313
from backend_test_base import Tf2OnnxBackendTestBase
14-
from common import unittest_main, check_opset_min_version, check_tf_min_version, check_tf_max_version, skip_tf2
14+
from common import unittest_main, check_opset_min_version, check_tf_min_version
1515

1616

1717
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -78,7 +78,6 @@ def func(x, y):
7878
output_names_with_port = ["output:0"]
7979
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
8080

81-
@skip_tf2()
8281
def test_nest_cond(self):
8382
x_val = np.array([1, 2, 3], dtype=np.float32)
8483
y_val = np.array([4, 5, 6], dtype=np.float32)
@@ -100,7 +99,6 @@ def cond_graph2():
10099
output_names_with_port = ["output:0"]
101100
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
102101

103-
@skip_tf2()
104102
def test_while_loop_in_cond(self):
105103
x_val = np.array([1, 2, 3], dtype=np.float32)
106104
y_val = np.array([4, 5, 6], dtype=np.float32)
@@ -178,7 +176,6 @@ def func(x, y):
178176
output_names_with_port = ["output:0"]
179177
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
180178

181-
@check_tf_max_version("1.15", "import issue in tf-2.1, fix later")
182179
def test_case_without_default_branch(self):
183180
def func(x, y):
184181
x = tf.add(x, 1, name="add_x")
@@ -212,7 +209,6 @@ def func(x, y):
212209
output_names_with_port = ["output:0"]
213210
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port)
214211

215-
@skip_tf2()
216212
def test_nest_case(self):
217213
x_val = np.array([1, 2, 3], dtype=np.float32)
218214
y_val = np.array([4, 5, 6], dtype=np.float32)
@@ -234,7 +230,6 @@ def case_graph():
234230

235231
@check_tf_min_version("1.8", "shape inference for Reshape op screws up")
236232
@check_opset_min_version(9, "ConstantOfShape")
237-
@skip_tf2()
238233
def test_cond_with_different_output_shape(self):
239234
input_shape = (10, 5, 20)
240235
def func(inputs, shape):

0 commit comments

Comments
 (0)