Skip to content

Commit 02e1845

Browse files
authored
Merge pull request #42 from DeepLabCut/maxim/rename_humanbody_shapshot
`download_huggingface_model()`: allow passing `rename_mapping` as `str`
2 parents 3c0cb18 + 1c78f82 commit 02e1845

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

dlclibrary/dlcmodelzoo/modelzoo_download.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def get_available_datasets() -> list[str]:
9494

9595

9696
def get_available_detectors(dataset: str) -> list[str]:
97-
""" Only for PyTorch models.
97+
"""Only for PyTorch models.
9898
9999
Returns:
100100
The detectors available for the dataset.
@@ -103,7 +103,7 @@ def get_available_detectors(dataset: str) -> list[str]:
103103

104104

105105
def get_available_models(dataset: str) -> list[str]:
106-
""" Only for PyTorch models.
106+
"""Only for PyTorch models.
107107
108108
Returns:
109109
The pose models available for the dataset.
@@ -139,19 +139,39 @@ def download_huggingface_model(
139139
model_name: str,
140140
target_dir: str = ".",
141141
remove_hf_folder: bool = True,
142-
rename_mapping: dict | None = None,
142+
rename_mapping: str | dict | None = None,
143143
):
144144
"""
145145
Downloads a DeepLabCut Model Zoo Project from Hugging Face.
146146
147147
Args:
148-
model_name (str): Name of the ModelZoo model.
148+
model_name (str):
149+
Name of the ModelZoo model.
149150
For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo.
150-
target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored.
151-
remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace
152-
after downloading and decompressing the data into DeepLabCut format. Defaults to True.
153-
rename_mapping (dict, optional): A dictionary to rename the downloaded file.
154-
If None, the original filename is used. Defaults to None.
151+
target_dir (str, optional):
152+
Target directory where the model weights will be stored.
153+
Defaults to the current directory.
154+
remove_hf_folder (bool, optional):
155+
Whether to remove the directory structure created by HuggingFace
156+
after downloading and decompressing the data into DeepLabCut format.
157+
Defaults to True.
158+
rename_mapping (dict | str | None, optional):
159+
- If a dictionary, it should map the original Hugging Face filenames
160+
to new filenames (e.g. {"snapshot-12345.tar.gz": "mymodel.tar.gz"}).
161+
- If a string, it is interpreted as the new name for the downloaded file
162+
- If None, the original filename is used.
163+
Defaults to None.
164+
165+
Examples:
166+
>>> # Download without renaming, keep original filename
167+
download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False)
168+
169+
>>> # Download and rename by specifying the new name directly
170+
download_huggingface_model(
171+
model_name="superanimal_humanbody_rtmpose_x",
172+
target_dir="/path/to/,y/checkpoints",
173+
rename_mapping="superanimal_humanbody_rtmpose_x.pt"
174+
)
155175
"""
156176
net_urls = _load_model_names()
157177
if model_name not in net_urls:
@@ -180,6 +200,10 @@ def download_huggingface_model(
180200
path_ = os.path.join(target_dir, hf_folder, "snapshots")
181201
commit = os.listdir(path_)[0]
182202
file_name = os.path.join(path_, commit, targzfn)
203+
204+
if isinstance(rename_mapping, str):
205+
rename_mapping = {targzfn: rename_mapping}
206+
183207
_handle_downloaded_file(file_name, target_dir, rename_mapping)
184208

185209
if remove_hf_folder:

0 commit comments

Comments
 (0)