Skip to content

Commit b89569c

Browse files
Allow Input to be optional to take None inputs, similar to what keras3 has.
PiperOrigin-RevId: 819935785
1 parent 0dec184 commit b89569c

15 files changed

+244
-38
lines changed

tf_keras/api/golden/v1/tensorflow.keras.__internal__.legacy.layers.-input-spec.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_class {
44
is_instance: "<type \'object\'>"
55
member_method {
66
name: "__init__"
7-
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
7+
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
88
}
99
member_method {
1010
name: "from_config"

tf_keras/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tf_class {
129129
}
130130
member_method {
131131
name: "__init__"
132-
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
132+
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
133133
}
134134
member_method {
135135
name: "add_loss"

tf_keras/api/golden/v1/tensorflow.keras.layers.-input-spec.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_class {
44
is_instance: "<type \'object\'>"
55
member_method {
66
name: "__init__"
7-
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
7+
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
88
}
99
member_method {
1010
name: "from_config"

tf_keras/api/golden/v1/tensorflow.keras.layers.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ tf_module {
482482
}
483483
member_method {
484484
name: "Input"
485-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
485+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
486486
}
487487
member_method {
488488
name: "add"

tf_keras/api/golden/v1/tensorflow.keras.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,6 @@ tf_module {
9090
}
9191
member_method {
9292
name: "Input"
93-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
93+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
9494
}
9595
}

tf_keras/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ tf_class {
129129
}
130130
member_method {
131131
name: "__init__"
132-
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
132+
argspec: "args=[\'self\', \'input_shape\', \'batch_size\', \'dtype\', \'input_tensor\', \'sparse\', \'name\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
133133
}
134134
member_method {
135135
name: "add_loss"

tf_keras/api/golden/v2/tensorflow.keras.layers.-input-spec.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ tf_class {
44
is_instance: "<type \'object\'>"
55
member_method {
66
name: "__init__"
7-
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\'], "
7+
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\', \'allow_last_axis_squeeze\', \'name\', \'optional\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
88
}
99
member_method {
1010
name: "from_config"

tf_keras/api/golden/v2/tensorflow.keras.layers.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ tf_module {
538538
}
539539
member_method {
540540
name: "Input"
541-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
541+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
542542
}
543543
member_method {
544544
name: "add"

tf_keras/api/golden/v2/tensorflow.keras.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,6 @@ tf_module {
9595
}
9696
member_method {
9797
name: "Input"
98-
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
98+
argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\', \'ragged\', \'type_spec\', \'optional\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'False\'], "
9999
}
100100
}

tf_keras/engine/data_adapter.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ def _is_tensor(v):
231231
return True
232232
return False
233233

234-
return all(_is_tensor(v) for v in flat_inputs)
234+
return all(
235+
_is_tensor(v) for v in flat_inputs if v is not None
236+
) and any(_is_tensor(v) for v in flat_inputs)
235237

236238
def __init__(
237239
self,
@@ -259,7 +261,8 @@ def __init__(
259261
inputs = pack_x_y_sample_weight(x, y, sample_weights)
260262

261263
num_samples = set(
262-
int(i.shape[0]) for i in tf.nest.flatten(inputs)
264+
int(i.shape[0])
265+
for i in tf.nest.flatten(inputs) if i is not None
263266
).pop()
264267
_check_data_cardinality(inputs)
265268

@@ -386,7 +389,7 @@ def slice_inputs(self, indices_dataset, inputs):
386389

387390
def grab_batch(i, data):
388391
return tf.nest.map_structure(
389-
lambda d: tf.gather(d, i, axis=0), data
392+
lambda d: tf.gather(d, i, axis=0) if d is not None else d, data
390393
)
391394

392395
dataset = dataset.map(grab_batch, num_parallel_calls=tf.data.AUTOTUNE)
@@ -459,7 +462,9 @@ def _is_array_like(v):
459462
if not TensorLikeDataAdapter.can_handle(
460463
x, y
461464
) and not CompositeTensorDataAdapter.can_handle(x, y):
462-
return all(_is_array_like(v) for v in flat_inputs)
465+
return all(
466+
_is_array_like(v) for v in flat_inputs if v is not None
467+
) and any(v is not None for v in flat_inputs)
463468
else:
464469
return False
465470

@@ -496,7 +501,9 @@ def dynamic_shape_like(t):
496501
shape[0] = None
497502
return tuple(shape)
498503

499-
flat_dtypes = [inp.dtype for inp in flat_inputs]
504+
flat_dtypes = [
505+
inp.dtype for inp in flat_inputs if inp is not None
506+
]
500507
contiguous = True
501508
if self._shuffle and self._shuffle != "batch":
502509
contiguous = False
@@ -509,15 +516,24 @@ def grab_batch(indices):
509516
# to a Tensor may force it into memory..
510517
def py_method(ind):
511518
def slice_array(data):
519+
if data is None:
520+
return None
512521
return training_utils.slice_arrays(
513522
data, ind.numpy(), contiguous=contiguous
514523
)
515524

516-
return [slice_array(inp) for inp in flat_inputs]
525+
return [slice_array(inp) for inp in flat_inputs if inp is not None]
517526

518-
flat_out = tf.py_function(py_method, [indices], flat_dtypes)
519-
for v, original_inp in zip(flat_out, flat_inputs):
520-
v.set_shape(dynamic_shape_like(original_inp))
527+
results = tf.py_function(py_method, [indices], flat_dtypes)
528+
results_it = iter(results)
529+
flat_out = []
530+
for original_inp in flat_inputs:
531+
if original_inp is None:
532+
flat_out.append(None)
533+
else:
534+
v = next(results_it)
535+
v.set_shape(dynamic_shape_like(original_inp))
536+
flat_out.append(v)
521537
return tf.nest.pack_sequence_as(inputs, flat_out)
522538

523539
dataset = indices_dataset.map(
@@ -608,8 +624,10 @@ def _is_tensor_or_composite(v):
608624
return True
609625
return _is_composite(v)
610626

611-
return any(_is_composite(v) for v in flat_inputs) and all(
612-
_is_tensor_or_composite(v) for v in flat_inputs
627+
return any(
628+
_is_composite(v) for v in flat_inputs if v is not None
629+
) and all(
630+
_is_tensor_or_composite(v) for v in flat_inputs if v is not None
613631
)
614632

615633
def __init__(
@@ -1944,14 +1962,20 @@ def single_batch_iterator(
19441962

19451963

19461964
def _check_data_cardinality(data):
1947-
num_samples = set(int(i.shape[0]) for i in tf.nest.flatten(data))
1965+
num_samples = set(
1966+
int(i.shape[0])
1967+
for i in tf.nest.flatten(data)
1968+
if i is not None
1969+
)
19481970
if len(num_samples) > 1:
19491971
msg = "Data cardinality is ambiguous:\n"
19501972
for label, single_data in zip(["x", "y", "sample_weight"], data):
19511973
msg += " {} sizes: {}\n".format(
19521974
label,
19531975
", ".join(
1954-
str(i.shape[0]) for i in tf.nest.flatten(single_data)
1976+
str(i.shape[0])
1977+
for i in tf.nest.flatten(single_data)
1978+
if i is not None
19551979
),
19561980
)
19571981
msg += "Make sure all arrays contain the same number of samples."

0 commit comments

Comments
 (0)