Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")

def compute_sha256(file_path: str, chunk_size: int = 8192) -> str:
    """Return the hex-encoded SHA256 digest of the file at ``file_path``.

    The file is read in fixed-size chunks and fed to the hash incrementally,
    so the whole file is never held in memory at once.

    Args:
        file_path: Path of the file to hash; opened in binary mode.
        chunk_size: Number of bytes to read per iteration. Must be positive.

    Returns:
        The digest as a lowercase hexadecimal string.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
        OSError: If the file cannot be opened or read.
    """
    # read(0) returns b"" immediately, which would end the loop and silently
    # yield the digest of empty input; a negative size would read the whole
    # file in one go and defeat the purpose of chunking.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        # iter() with a sentinel keeps calling f.read until it returns b"" (EOF).
        for chunk in iter(lambda: f.read(chunk_size), b""):
            sha256.update(chunk)
    return sha256.hexdigest()

if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
if compute_sha256(download_target) == expected_sha256:
if in_memory:
with open(download_target, "rb") as f:
return f.read()
else:
return download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
Expand All @@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
output.write(buffer)
loop.update(len(buffer))

model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
if compute_sha256(download_target) != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)

return model_bytes if in_memory else download_target
if in_memory:
with open(download_target, "rb") as f:
return f.read()
else:
return download_target


def available_models() -> List[str]:
Expand Down Expand Up @@ -157,4 +169,4 @@ def load_model(
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)

return model.to(device)
return model.to(device)
67 changes: 50 additions & 17 deletions whisper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,31 +224,64 @@ def __init__(
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)

def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
# Optimisation: pre-compute and register the mask in CUDA if available
if torch.cuda.is_available():
self.register_buffer("mask_cuda", mask.cuda(), persistent=False)


def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
    """Compute vocabulary logits for ``tokens`` conditioned on ``audio_features``.

    Args:
        tokens: torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        audio_features: torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        kv_cache: Optional cache for key/value tensors. When it is non-empty,
            the size of dimension 1 of its first entry is used as the
            positional offset for ``tokens``.

    Returns:
        logits: float32 tensor, shape = (n_batch, n_token, n_vocab)
    """
    # An empty or missing cache means positions start at 0. Otherwise the
    # positional embedding is sliced starting after the cached length
    # (presumably the number of already-decoded positions; confirm against
    # whatever populates kv_cache).
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    x = (
        self.token_embedding(tokens)
        + self.positional_embedding[offset : offset + tokens.shape[-1]]
    )
    # Match the dtype of the audio features so both block inputs agree.
    x = x.to(audio_features.dtype)

    # The causal mask belongs inside the attention blocks, not on the logits:
    # it is (n_ctx, n_ctx) and cannot be meaningfully added to a
    # (n_batch, n_token, n_vocab) tensor. ``self.mask`` is a registered
    # buffer, so it follows the module across devices; no separate CUDA copy
    # and no forced ``.cuda()`` move of the inputs is needed (forcing one
    # breaks a model that is kept on CPU while CUDA is available).
    for block in self.blocks:
        x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)

    x = self.ln(x)

    # Project back onto the vocabulary using the token embedding weights, then
    # return float32 logits regardless of the compute dtype.
    logits = (
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()

    return logits


# The Whisper class has been moved outside of TextDecoder and is now a top-level class
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions):
super().__init__()
Expand Down