Skip to content

Commit 12f70f5

Browse files
SimpleML Teamcopybara-github
authored andcommitted
[TF-DF] Add a test to ensure an error is thrown when the train & test datasets are mismatched.
PiperOrigin-RevId: 627641579
1 parent bffa547 commit 12f70f5

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tensorflow_decision_forests/keras/keras_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,31 @@ def test_model_iris(self):
909909
predictions = model.predict(tf_test)
910910
logging.info("Predictions: %s", predictions)
911911

912+
def test_gbt_dataset_mismatch(self):
913+
"""Ensure an error is thrown when the test dataset is missing a feature.
914+
"""
915+
916+
tf_train = keras.pd_dataframe_to_tf_dataset(
917+
synthetic_pd_dataset(num_examples=100,
918+
num_numerical_features=5), label="label")
919+
tf_validation = keras.pd_dataframe_to_tf_dataset(
920+
synthetic_pd_dataset(num_examples=50,
921+
num_numerical_features=5), label="label")
922+
tf_test = keras.pd_dataframe_to_tf_dataset(
923+
synthetic_pd_dataset(num_examples=50, num_numerical_features=4),
924+
label="label")
925+
926+
model = keras.GradientBoostedTreesModel(num_trees=10,)
927+
928+
model.compile(metrics=["accuracy"])
929+
930+
model.fit(x=tf_train, validation_data=tf_validation)
931+
model.summary()
932+
with self.assertRaises(ValueError):
933+
model.evaluate(tf_test)
934+
with self.assertRaises(ValueError):
935+
model.predict(tf_test)
936+
912937
def test_model_abalone(self):
913938
"""Test on the Abalone dataset.
914939

0 commit comments

Comments
 (0)