Skip to content

Commit 146023f

Browse files
committed
Fix DINOV2 test.
1 parent 95e8d84 commit 146023f

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

keras_hub/src/models/dinov2/dinov2_backbone.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ class DINOV2Backbone(FeaturePyramidBackbone):
1919
DINOV2 model with any number of layers, heads, and embedding dimensions. To
2020
load preset architectures and weights, use the `from_preset` constructor.
2121
22+
Note that this backbone is a Feature Pyramid Backbone that can output
23+
intermediate feature maps from different stages of the model. See the
24+
example below for how to access these feature pyramid outputs.
25+
2226
Note that this backbone supports interpolation of the position embeddings
2327
to the input image shape. This is useful when the input image shape is
2428
different from the shape used to train the position embeddings. The
@@ -97,6 +101,16 @@ class DINOV2Backbone(FeaturePyramidBackbone):
97101
position_embedding_shape=(518, 518),
98102
)
99103
model(input_data)
104+
105+
# Accessing feature pyramid outputs.
106+
backbone = keras_hub.models.DINOV2Backbone.from_preset(
107+
"dinov2_base", image_shape=(224, 224, 3)
108+
)
109+
model = keras.Model(
110+
inputs=backbone.inputs,
111+
outputs=backbone.pyramid_outputs,
112+
)
113+
features = model(input_data)
100114
```
101115
"""
102116

keras_hub/src/models/dinov2/dinov2_backbone_test.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,26 @@ def setUp(self):
1919
"layer_scale_init_value": 1.0,
2020
"num_register_tokens": 0,
2121
"use_swiglu_ffn": False,
22-
"image_shape": (64, 64, 3),
22+
"image_shape": (70, 70, 3),
23+
"name": "dinov2_backbone",
2324
}
2425
self.input_data = {
25-
"images": ops.ones((2, 64, 64, 3)),
26+
"images": ops.ones((2, 70, 70, 3)),
2627
}
2728

2829
def test_backbone_basics(self):
2930
patch_size = self.init_kwargs["patch_size"]
3031
image_size = self.init_kwargs["image_shape"][0]
3132
hidden_dim = self.init_kwargs["hidden_dim"]
3233
sequence_length = (image_size // patch_size) ** 2 + 1
33-
self.run_backbone_test(
34+
self.run_vision_backbone_test(
3435
cls=DINOV2Backbone,
3536
init_kwargs=self.init_kwargs,
3637
input_data=self.input_data,
3738
expected_output_shape=(2, sequence_length, hidden_dim),
39+
expected_pyramid_output_keys=["Stem", "Stage1", "Stage2"],
40+
expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3,
41+
run_data_format_check=False,
3842
)
3943

4044
@pytest.mark.large
@@ -107,10 +111,11 @@ def setUp(self):
107111
"layer_scale_init_value": 1.0,
108112
"num_register_tokens": 4,
109113
"use_swiglu_ffn": True,
110-
"image_shape": (64, 64, 3),
114+
"image_shape": (70, 70, 3),
115+
"name": "dinov2_backbone",
111116
}
112117
self.input_data = {
113-
"images": ops.ones((2, 64, 64, 3)),
118+
"images": ops.ones((2, 70, 70, 3)),
114119
}
115120

116121
def test_backbone_basics(self):
@@ -121,11 +126,14 @@ def test_backbone_basics(self):
121126
sequence_length = (
122127
(image_size // patch_size) ** 2 + 1 + num_register_tokens
123128
)
124-
self.run_backbone_test(
129+
self.run_vision_backbone_test(
125130
cls=DINOV2Backbone,
126131
init_kwargs=self.init_kwargs,
127132
input_data=self.input_data,
128133
expected_output_shape=(2, sequence_length, hidden_dim),
134+
expected_pyramid_output_keys=["Stem", "Stage1", "Stage2"],
135+
expected_pyramid_image_sizes=[(sequence_length, hidden_dim)] * 3,
136+
run_data_format_check=False,
129137
)
130138

131139
@pytest.mark.large

keras_hub/src/tests/test_case.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,10 +538,11 @@ def run_vision_backbone_test(
538538

539539
self.assertIsInstance(output_data, dict)
540540
self.assertEqual(
541-
list(output_data.keys()), list(backbone.pyramid_outputs.keys())
541+
sorted(output_data.keys()),
542+
sorted(backbone.pyramid_outputs.keys()),
542543
)
543544
self.assertEqual(
544-
list(output_data.keys()), expected_pyramid_output_keys
545+
sorted(output_data.keys()), sorted(expected_pyramid_output_keys)
545546
)
546547
# check height and width of each level.
547548
for i, (k, v) in enumerate(output_data.items()):

0 commit comments

Comments
 (0)