@@ -94,7 +94,7 @@ def get_available_datasets() -> list[str]:
9494
9595
9696def get_available_detectors (dataset : str ) -> list [str ]:
97- """ Only for PyTorch models.
97+ """Only for PyTorch models.
9898
9999 Returns:
100100 The detectors available for the dataset.
@@ -103,7 +103,7 @@ def get_available_detectors(dataset: str) -> list[str]:
103103
104104
105105def get_available_models (dataset : str ) -> list [str ]:
106- """ Only for PyTorch models.
106+ """Only for PyTorch models.
107107
108108 Returns:
109109 The pose models available for the dataset.
@@ -139,19 +139,39 @@ def download_huggingface_model(
139139 model_name : str ,
140140 target_dir : str = "." ,
141141 remove_hf_folder : bool = True ,
142- rename_mapping : dict | None = None ,
142+ rename_mapping : str | dict | None = None ,
143143):
144144 """
145145 Downloads a DeepLabCut Model Zoo Project from Hugging Face.
146146
147147 Args:
148- model_name (str): Name of the ModelZoo model.
148+ model_name (str):
149+ Name of the ModelZoo model.
149150 For visualizations, see http://www.mackenziemathislab.org/dlc-modelzoo.
150- target_dir (str): Directory where the model weights and pose_cfg.yaml file will be stored.
151- remove_hf_folder (bool, optional): Whether to remove the directory structure provided by HuggingFace
152- after downloading and decompressing the data into DeepLabCut format. Defaults to True.
153- rename_mapping (dict, optional): A dictionary to rename the downloaded file.
154- If None, the original filename is used. Defaults to None.
151+ target_dir (str, optional):
152+ Target directory where the model weights will be stored.
153+ Defaults to the current directory.
154+ remove_hf_folder (bool, optional):
155+ Whether to remove the directory structure created by HuggingFace
156+ after downloading and decompressing the data into DeepLabCut format.
157+ Defaults to True.
158+ rename_mapping (dict | str | None, optional):
159+ - If a dictionary, it should map the original Hugging Face filenames
160+ to new filenames (e.g. {"snapshot-12345.tar.gz": "mymodel.tar.gz"}).
161+ - If a string, it is interpreted as the new name for the downloaded file
162+ - If None, the original filename is used.
163+ Defaults to None.
164+
165+ Examples:
166+ >>> # Download without renaming, keep original filename
167+ download_huggingface_model("superanimal_bird_resnet_50", remove_hf_folder=False)
168+
169+ >>> # Download and rename by specifying the new name directly
170+ download_huggingface_model(
171+ model_name="superanimal_humanbody_rtmpose_x",
172+ target_dir="/path/to/,y/checkpoints",
173+ rename_mapping="superanimal_humanbody_rtmpose_x.pt"
174+ )
155175 """
156176 net_urls = _load_model_names ()
157177 if model_name not in net_urls :
@@ -180,6 +200,10 @@ def download_huggingface_model(
180200 path_ = os .path .join (target_dir , hf_folder , "snapshots" )
181201 commit = os .listdir (path_ )[0 ]
182202 file_name = os .path .join (path_ , commit , targzfn )
203+
204+ if isinstance (rename_mapping , str ):
205+ rename_mapping = {targzfn : rename_mapping }
206+
183207 _handle_downloaded_file (file_name , target_dir , rename_mapping )
184208
185209 if remove_hf_folder :
0 commit comments