@@ -909,6 +909,31 @@ def test_model_iris(self):
909909 predictions = model .predict (tf_test )
910910 logging .info ("Predictions: %s" , predictions )
911911
912+ def test_gbt_dataset_mismatch (self ):
913+ """Ensure an error is thrown when the test dataset is missing a feature.
914+ """
915+
916+ tf_train = keras .pd_dataframe_to_tf_dataset (
917+ synthetic_pd_dataset (num_examples = 100 ,
918+ num_numerical_features = 5 ), label = "label" )
919+ tf_validation = keras .pd_dataframe_to_tf_dataset (
920+ synthetic_pd_dataset (num_examples = 50 ,
921+ num_numerical_features = 5 ), label = "label" )
922+ tf_test = keras .pd_dataframe_to_tf_dataset (
923+ synthetic_pd_dataset (num_examples = 50 , num_numerical_features = 4 ),
924+ label = "label" )
925+
926+ model = keras .GradientBoostedTreesModel (num_trees = 10 ,)
927+
928+ model .compile (metrics = ["accuracy" ])
929+
930+ model .fit (x = tf_train , validation_data = tf_validation )
931+ model .summary ()
932+ with self .assertRaises (ValueError ):
933+ model .evaluate (tf_test )
934+ with self .assertRaises (ValueError ):
935+ model .predict (tf_test )
936+
912937 def test_model_abalone (self ):
913938 """Test on the Abalone dataset.
914939
0 commit comments