Skip to content

Commit 081b15d

Browse files
committed
Hot-fix: do not share tags between ModelHubMixin siblings (#2394)
* Hot-fix: do not share tags between ModelHubMixin sibligs * reference regression test
1 parent 4c7aa33 commit 081b15d

File tree

2 files changed

+43
-6
lines changed

2 files changed

+43
-6
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,12 +230,21 @@ def __init_subclass__(
230230
tags.append("model_hub_mixin")
231231

232232
# Initialize MixinInfo if not existent
233-
if not hasattr(cls, "_hub_mixin_info"):
234-
cls._hub_mixin_info = MixinInfo(
235-
model_card_template=model_card_template,
236-
model_card_data=ModelCardData(),
237-
)
238-
info = cls._hub_mixin_info
233+
info = MixinInfo(model_card_template=model_card_template, model_card_data=ModelCardData())
234+
235+
# If parent class has a MixinInfo, inherit from it as a copy
236+
if hasattr(cls, "_hub_mixin_info"):
237+
# Inherit model card template from parent class if not explicitly set
238+
if model_card_template == DEFAULT_MODEL_CARD:
239+
info.model_card_template = cls._hub_mixin_info.model_card_template
240+
241+
# Inherit from parent model card data
242+
info.model_card_data = ModelCardData(**cls._hub_mixin_info.model_card_data.to_dict())
243+
244+
# Inherit other info
245+
info.docs_url = cls._hub_mixin_info.docs_url
246+
info.repo_url = cls._hub_mixin_info.repo_url
247+
cls._hub_mixin_info = info
239248

240249
if languages is not None:
241250
warnings.warn(
@@ -269,6 +278,8 @@ def __init_subclass__(
269278
else:
270279
info.model_card_data.tags = tags
271280

281+
info.model_card_data.tags = sorted(set(info.model_card_data.tags))
282+
272283
# Handle encoders/decoders for args
273284
cls._hub_mixin_coders = coders or {}
274285
cls._hub_mixin_jsonable_custom_types = tuple(cls._hub_mixin_coders.keys())

tests/test_hub_mixin_pytorch.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,20 @@ def __init__(self, config: Namespace):
111111
super().__init__()
112112
self.config = config
113113

114+
class DummyModelWithTag1(nn.Module, PyTorchModelHubMixin, tags=["tag1"]):
115+
"""Used to test tags not shared between sibling classes (only inheritance)."""
116+
117+
class DummyModelWithTag2(nn.Module, PyTorchModelHubMixin, tags=["tag2"]):
118+
"""Used to test tags not shared between sibling classes (only inheritance)."""
119+
114120
else:
115121
DummyModel = None
116122
DummyModelWithModelCard = None
117123
DummyModelNoConfig = None
118124
DummyModelWithConfigAndKwargs = None
119125
DummyModelWithModelCardAndCustomKwargs = None
126+
DummyModelWithTag1 = None
127+
DummyModelWithTag2 = None
120128

121129

122130
@requires("torch")
@@ -451,3 +459,21 @@ def test_config_with_custom_coders(self):
451459
assert isinstance(reloaded.config, Namespace)
452460
assert reloaded.config.a == 1
453461
assert reloaded.config.b == 2
462+
463+
def test_inheritance_and_sibling_classes(self):
464+
"""
465+
Test tags are not shared between sibling classes.
466+
467+
Regression test for #2394.
468+
See https://github.com/huggingface/huggingface_hub/pull/2394.
469+
"""
470+
assert DummyModelWithTag1._hub_mixin_info.model_card_data.tags == [
471+
"model_hub_mixin",
472+
"pytorch_model_hub_mixin",
473+
"tag1",
474+
]
475+
assert DummyModelWithTag2._hub_mixin_info.model_card_data.tags == [
476+
"model_hub_mixin",
477+
"pytorch_model_hub_mixin",
478+
"tag2",
479+
]

0 commit comments

Comments
 (0)