diff --git a/mlx_embeddings/utils.py b/mlx_embeddings/utils.py index b1f8380ad5..bced432e03 100644 --- a/mlx_embeddings/utils.py +++ b/mlx_embeddings/utils.py @@ -167,17 +167,20 @@ def load_model( # siglip models have a different image size if "siglip" in config["model_type"]: - # Extract the image size - image_size = re.search( - r"patch\d+-(\d+)(?:-|$)", kwargs["path_to_repo"] - ).group(1) - # Extract the patch size - patch_size = re.search(r"patch(\d+)", kwargs["path_to_repo"]).group(1) - patch_size = ( - re.search(r"\d+", patch_size).group() - if re.search(r"\d+", patch_size) - else patch_size - ) + if not isinstance(image_size := model_config.get("image_size"), int): + # Extract the image size from hf repo name if not supplied + image_size = re.search( + r"patch\d+-(\d+)(?:-|$)", kwargs["path_to_repo"] + ).group(1) + + if not isinstance(patch_size := model_config.get("patch_size"), int): + # Extract the patch size from hf repo if not supplied + patch_size = re.search(r"patch(\d+)", kwargs["path_to_repo"]).group(1) + patch_size = ( + re.search(r"\d+", patch_size).group() + if re.search(r"\d+", patch_size) + else patch_size + ) if model_args.vision_config.image_size != int(image_size): model_args.vision_config.image_size = int(image_size) if model_args.vision_config.patch_size != int(patch_size): diff --git a/requirements.txt b/requirements.txt index 4140b0d4bd..93455fddb9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ mlx>=0.16.3 mlx-vlm>=0.1.21 transformers[sentencepiece]>=4.44.0 huggingface-hub>=0.25.1 +torch