diff --git a/keras_hub/src/models/clip/clip_layers.py b/keras_hub/src/models/clip/clip_layers.py index 38a30119e2..f49a83fc4f 100644 --- a/keras_hub/src/models/clip/clip_layers.py +++ b/keras_hub/src/models/clip/clip_layers.py @@ -52,15 +52,12 @@ def build(self, input_shape): self.position_ids = self.add_weight( shape=(1, self.num_positions), initializer="zeros", - # Let the backend determine the int dtype. For example, tf - # requires int64 for correct device placement, whereas jax and torch - # don't. dtype=int, trainable=False, name="position_ids", ) self.patch_embedding.build(input_shape) - self.position_embedding.build(self.position_ids.shape) + self.position_embedding.build((1, self.num_positions)) def call(self, inputs, training=None): x = inputs diff --git a/keras_hub/src/models/dinov2/dinov2_backbone_test.py b/keras_hub/src/models/dinov2/dinov2_backbone_test.py index 05fe7c3241..877238c5b1 100644 --- a/keras_hub/src/models/dinov2/dinov2_backbone_test.py +++ b/keras_hub/src/models/dinov2/dinov2_backbone_test.py @@ -20,6 +20,7 @@ def setUp(self): "num_register_tokens": 0, "use_swiglu_ffn": False, "image_shape": (64, 64, 3), + "name": "dinov2_backbone", } self.input_data = { "images": ops.ones((2, 64, 64, 3)), @@ -35,7 +36,7 @@ def test_backbone_basics(self): init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, sequence_length, hidden_dim), - run_quantization_check=False, + run_quantization_check=False, # TODO: Fix weight count mismatch ) @pytest.mark.large @@ -127,7 +128,7 @@ def test_backbone_basics(self): init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, sequence_length, hidden_dim), - run_quantization_check=False, + run_quantization_check=False, # TODO: Fix weight count mismatch ) @pytest.mark.large diff --git a/keras_hub/src/tests/test_case.py b/keras_hub/src/tests/test_case.py index f70ab78840..4a96ab8a30 100644 --- a/keras_hub/src/tests/test_case.py +++ b/keras_hub/src/tests/test_case.py @@ -381,7 +381,6 @@ def _get_supported_layers(mode): ) # Ensure the correct `dtype` is set for sublayers or submodels in # `init_kwargs`. - original_init_kwargs = init_kwargs.copy() for k, v in init_kwargs.items(): if isinstance(v, keras.Layer): config = v.get_config() @@ -408,8 +407,6 @@ def _get_supported_layers(mode): # Check weights loading. weights = model.get_weights() revived_model.set_weights(weights) - # Restore `init_kwargs`. - init_kwargs = original_init_kwargs def run_model_saving_test( self,