11
11
import tensorflow as tf
12
12
13
13
from backend_test_base import Tf2OnnxBackendTestBase
14
- from common import unittest_main , check_opset_min_version , check_tf_min_version , check_tf_max_version , skip_tf2
14
+ from common import unittest_main , check_opset_min_version , check_tf_min_version
15
15
16
16
17
17
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -78,7 +78,6 @@ def func(x, y):
78
78
output_names_with_port = ["output:0" ]
79
79
self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port )
80
80
81
- @skip_tf2 ()
82
81
def test_nest_cond (self ):
83
82
x_val = np .array ([1 , 2 , 3 ], dtype = np .float32 )
84
83
y_val = np .array ([4 , 5 , 6 ], dtype = np .float32 )
@@ -100,7 +99,6 @@ def cond_graph2():
100
99
output_names_with_port = ["output:0" ]
101
100
self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port )
102
101
103
- @skip_tf2 ()
104
102
def test_while_loop_in_cond (self ):
105
103
x_val = np .array ([1 , 2 , 3 ], dtype = np .float32 )
106
104
y_val = np .array ([4 , 5 , 6 ], dtype = np .float32 )
@@ -178,7 +176,6 @@ def func(x, y):
178
176
output_names_with_port = ["output:0" ]
179
177
self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port )
180
178
181
- @check_tf_max_version ("1.15" , "import issue in tf-2.1, fix later" )
182
179
def test_case_without_default_branch (self ):
183
180
def func (x , y ):
184
181
x = tf .add (x , 1 , name = "add_x" )
@@ -212,7 +209,6 @@ def func(x, y):
212
209
output_names_with_port = ["output:0" ]
213
210
self .run_test_case (func , feed_dict , input_names_with_port , output_names_with_port )
214
211
215
- @skip_tf2 ()
216
212
def test_nest_case (self ):
217
213
x_val = np .array ([1 , 2 , 3 ], dtype = np .float32 )
218
214
y_val = np .array ([4 , 5 , 6 ], dtype = np .float32 )
@@ -234,7 +230,6 @@ def case_graph():
234
230
235
231
@check_tf_min_version ("1.8" , "shape inference for Reshape op screws up" )
236
232
@check_opset_min_version (9 , "ConstantOfShape" )
237
- @skip_tf2 ()
238
233
def test_cond_with_different_output_shape (self ):
239
234
input_shape = (10 , 5 , 20 )
240
235
def func (inputs , shape ):
0 commit comments