@@ -231,7 +231,9 @@ def _is_tensor(v):
231231 return True
232232 return False
233233
234- return all (_is_tensor (v ) for v in flat_inputs )
234+ return all (
235+ _is_tensor (v ) for v in flat_inputs if v is not None
236+ ) and any (_is_tensor (v ) for v in flat_inputs )
235237
236238 def __init__ (
237239 self ,
@@ -259,7 +261,8 @@ def __init__(
259261 inputs = pack_x_y_sample_weight (x , y , sample_weights )
260262
261263 num_samples = set (
262- int (i .shape [0 ]) for i in tf .nest .flatten (inputs )
264+ int (i .shape [0 ])
265+ for i in tf .nest .flatten (inputs ) if i is not None
263266 ).pop ()
264267 _check_data_cardinality (inputs )
265268
@@ -386,7 +389,7 @@ def slice_inputs(self, indices_dataset, inputs):
386389
387390 def grab_batch (i , data ):
388391 return tf .nest .map_structure (
389- lambda d : tf .gather (d , i , axis = 0 ), data
392+ lambda d : tf .gather (d , i , axis = 0 ) if d is not None else d , data
390393 )
391394
392395 dataset = dataset .map (grab_batch , num_parallel_calls = tf .data .AUTOTUNE )
@@ -459,7 +462,9 @@ def _is_array_like(v):
459462 if not TensorLikeDataAdapter .can_handle (
460463 x , y
461464 ) and not CompositeTensorDataAdapter .can_handle (x , y ):
462- return all (_is_array_like (v ) for v in flat_inputs )
465+ return all (
466+ _is_array_like (v ) for v in flat_inputs if v is not None
467+ ) and any (v is not None for v in flat_inputs )
463468 else :
464469 return False
465470
@@ -496,7 +501,9 @@ def dynamic_shape_like(t):
496501 shape [0 ] = None
497502 return tuple (shape )
498503
499- flat_dtypes = [inp .dtype for inp in flat_inputs ]
504+ flat_dtypes = [
505+ inp .dtype for inp in flat_inputs if inp is not None
506+ ]
500507 contiguous = True
501508 if self ._shuffle and self ._shuffle != "batch" :
502509 contiguous = False
@@ -509,15 +516,24 @@ def grab_batch(indices):
509516 # to a Tensor may force it into memory..
510517 def py_method (ind ):
511518 def slice_array (data ):
519+ if data is None :
520+ return None
512521 return training_utils .slice_arrays (
513522 data , ind .numpy (), contiguous = contiguous
514523 )
515524
516- return [slice_array (inp ) for inp in flat_inputs ]
525+ return [slice_array (inp ) for inp in flat_inputs if inp is not None ]
517526
518- flat_out = tf .py_function (py_method , [indices ], flat_dtypes )
519- for v , original_inp in zip (flat_out , flat_inputs ):
520- v .set_shape (dynamic_shape_like (original_inp ))
527+ results = tf .py_function (py_method , [indices ], flat_dtypes )
528+ results_it = iter (results )
529+ flat_out = []
530+ for original_inp in flat_inputs :
531+ if original_inp is None :
532+ flat_out .append (None )
533+ else :
534+ v = next (results_it )
535+ v .set_shape (dynamic_shape_like (original_inp ))
536+ flat_out .append (v )
521537 return tf .nest .pack_sequence_as (inputs , flat_out )
522538
523539 dataset = indices_dataset .map (
@@ -608,8 +624,10 @@ def _is_tensor_or_composite(v):
608624 return True
609625 return _is_composite (v )
610626
611- return any (_is_composite (v ) for v in flat_inputs ) and all (
612- _is_tensor_or_composite (v ) for v in flat_inputs
627+ return any (
628+ _is_composite (v ) for v in flat_inputs if v is not None
629+ ) and all (
630+ _is_tensor_or_composite (v ) for v in flat_inputs if v is not None
613631 )
614632
615633 def __init__ (
@@ -1944,14 +1962,20 @@ def single_batch_iterator(
19441962
19451963
19461964def _check_data_cardinality (data ):
1947- num_samples = set (int (i .shape [0 ]) for i in tf .nest .flatten (data ))
1965+ num_samples = set (
1966+ int (i .shape [0 ])
1967+ for i in tf .nest .flatten (data )
1968+ if i is not None
1969+ )
19481970 if len (num_samples ) > 1 :
19491971 msg = "Data cardinality is ambiguous:\n "
19501972 for label , single_data in zip (["x" , "y" , "sample_weight" ], data ):
19511973 msg += " {} sizes: {}\n " .format (
19521974 label ,
19531975 ", " .join (
1954- str (i .shape [0 ]) for i in tf .nest .flatten (single_data )
1976+ str (i .shape [0 ])
1977+ for i in tf .nest .flatten (single_data )
1978+ if i is not None
19551979 ),
19561980 )
19571981 msg += "Make sure all arrays contain the same number of samples."
0 commit comments