2323
2424from keras_cv import bounding_box
2525from keras_cv .api_export import keras_cv_export
26+ from keras_cv .backend import config
2627from keras_cv .backend import keras
2728from keras_cv .backend import ops
2829from keras_cv .backend import scope
@@ -411,14 +412,16 @@ def get_random_transformation(
411412 def call (self , inputs ):
412413 # try to convert a given backend native tensor to TensorFlow tensor
413414 # before passing it over to TFDataScope
415+ is_tf_backend = config .backend () == "tensorflow"
416+ is_in_tf_graph = not tf .executing_eagerly ()
414417 contains_ragged = lambda y : any (
415418 tree .map_structure (
416419 lambda x : isinstance (x , (tf .RaggedTensor , tf .SparseTensor )),
417420 tree .flatten (y ),
418421 )
419422 )
420423 inputs_contain_ragged = contains_ragged (inputs )
421- if not inputs_contain_ragged :
424+ if not is_tf_backend and not inputs_contain_ragged :
422425 inputs = tree .map_structure (
423426 lambda x : tf .convert_to_tensor (x ), inputs
424427 )
@@ -444,13 +447,15 @@ def call(self, inputs):
444447 # backend native tensors. This is to avoid breaking TF data
445448 # pipelines that can't easily be ported to become backend
446449 # agnostic.
447- if not inputs_contain_ragged and not contains_ragged (outputs ):
448- outputs = tree .map_structure (
449- # some layers return None, handle that case when
450- # converting to tensors
451- lambda x : ops .convert_to_tensor (x ) if x is not None else x ,
452- outputs ,
453- )
450+ # Skip this step for TF backend or if in `tf.graph` like `tf.data`.
451+ if not is_tf_backend and not is_in_tf_graph :
452+ if not inputs_contain_ragged and not contains_ragged (outputs ):
453+ outputs = tree .map_structure (
454+ # some layers return None, handle that case when
455+ # converting to tensors
456+ lambda x : ops .convert_to_tensor (x ) if x is not None else x ,
457+ outputs ,
458+ )
454459 return outputs
455460
456461 def _augment (self , inputs ):
0 commit comments