Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import urllib
import warnings
from typing import List, Optional, Union
from typing import List, Optional, Union, BinaryIO

import torch
from tqdm import tqdm
Expand Down Expand Up @@ -51,7 +51,25 @@
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
# hashlib.file_digest() was only added in Python 3.11. On older interpreters,
# install a small stand-in under the same name so the rest of this module can
# call hashlib.file_digest() unconditionally.
if not hasattr(hashlib, "file_digest"):

    def _file_digest(file: BinaryIO, algo, *, _bufsize: int = 65536):
        """Hash the remaining contents of a binary file object in chunks.

        The file is consumed from its current position until ``read()``
        returns an empty/falsy value, so the whole file never has to be
        held in memory at once.

        Parameters
        ----------
        file : BinaryIO
            Binary file-like object exposing ``read(size)``.
        algo : str or callable
            Either an algorithm name accepted by ``hashlib.new`` (for example
            ``"sha256"``), or a zero-argument callable returning a hash object
            (for example ``hashlib.sha256``), mirroring what the standard
            library's ``hashlib.file_digest`` accepts.
        _bufsize : int
            Number of bytes requested per ``read`` call (keyword-only).

        Returns
        -------
        The hash object after all chunks have been fed to it; call
        ``.hexdigest()`` or ``.digest()`` on it to obtain the result.
        """
        # A string names the algorithm; anything else is treated as a
        # constructor/factory for the hash object.
        hasher = hashlib.new(algo) if isinstance(algo, str) else algo()

        while True:
            chunk = file.read(_bufsize)

            # An empty (falsy) read marks end of file.
            if not chunk:
                break

            hasher.update(chunk)

        return hasher

    hashlib.file_digest = _file_digest


def _download(url: str, root: str) -> str:
os.makedirs(root, exist_ok=True)

expected_sha256 = url.split("/")[-2]
Expand All @@ -62,10 +80,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:

if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
if hashlib.file_digest(f, "sha256").hexdigest() == expected_sha256:
return download_target

warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
Expand All @@ -86,13 +103,13 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
output.write(buffer)
loop.update(len(buffer))

model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
with open(download_target, "rb") as f:
if hashlib.file_digest(f, "sha256").hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)

return model_bytes if in_memory else download_target
return download_target


def available_models() -> List[str]:
Expand Down Expand Up @@ -134,22 +151,25 @@ def load_model(
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
checkpoint_file = _download(_MODELS[name], download_root)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
checkpoint_file = name
alignment_heads = None
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)

with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
fp = open(checkpoint_file, "rb")

if in_memory:
with fp:
fp = io.BytesIO(fp.read())

with fp:
kwargs = {"weights_only": True} if torch.__version__ >= "1.13" else {}
checkpoint = torch.load(fp, map_location=device, **kwargs)
del checkpoint_file

dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims)
Expand Down