@@ -210,7 +210,8 @@ def __init__(self, inputs=None, outputs=None, name=None):
210
210
# check type of inputs and outputs
211
211
check_order = ['inputs' , 'outputs' ]
212
212
for co , check_argu in enumerate ([inputs , outputs ]):
213
- if isinstance (check_argu , tf_ops ._TensorLike ) or tf_ops .is_dense_tensor_like (check_argu ):
213
+ if isinstance (check_argu ,
214
+ (tf .Tensor , tf .SparseTensor , tf .Variable )) or tf_ops .is_dense_tensor_like (check_argu ):
214
215
pass
215
216
elif isinstance (check_argu , list ):
216
217
if len (check_argu ) == 0 :
@@ -219,8 +220,9 @@ def __init__(self, inputs=None, outputs=None, name=None):
219
220
"It should be either Tensor or a list of Tensor."
220
221
)
221
222
for idx in range (len (check_argu )):
222
- if not isinstance (check_argu [idx ], tf_ops ._TensorLike ) or not tf_ops .is_dense_tensor_like (
223
- check_argu [idx ]):
223
+ if not isinstance (check_argu [idx ],
224
+ (tf .Tensor , tf .SparseTensor , tf .Variable )) or not tf_ops .is_dense_tensor_like (
225
+ check_argu [idx ]):
224
226
raise TypeError (
225
227
"The argument `%s` should be either Tensor or a list of Tensor " % (check_order [co ]) +
226
228
"but the %s[%d] is detected as %s" % (check_order [co ], idx , type (check_argu [idx ]))
0 commit comments