diff --git a/whisper/__init__.py b/whisper/__init__.py
index e210718f3..bfe0a8984 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
     if os.path.exists(download_target) and not os.path.isfile(download_target):
         raise RuntimeError(f"{download_target} exists and is not a regular file")
 
+    def compute_sha256(file_path: str) -> str:
+        sha256 = hashlib.sha256()
+        with open(file_path, "rb") as f:
+            for chunk in iter(lambda: f.read(8192), b""):
+                sha256.update(chunk)
+        return sha256.hexdigest()
+
     if os.path.isfile(download_target):
-        with open(download_target, "rb") as f:
-            model_bytes = f.read()
-        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
-            return model_bytes if in_memory else download_target
+        if compute_sha256(download_target) == expected_sha256:
+            if in_memory:
+                with open(download_target, "rb") as f:
+                    return f.read()
+            else:
+                return download_target
         else:
             warnings.warn(
                 f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
@@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
                 output.write(buffer)
                 loop.update(len(buffer))
 
-    model_bytes = open(download_target, "rb").read()
-    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+    if compute_sha256(download_target) != expected_sha256:
         raise RuntimeError(
             "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
         )
 
-    return model_bytes if in_memory else download_target
+    if in_memory:
+        with open(download_target, "rb") as f:
+            return f.read()
+    else:
+        return download_target
 
 
 def available_models() -> List[str]:
@@ -157,4 +169,4 @@ def load_model(
     if alignment_heads is not None:
         model.set_alignment_heads(alignment_heads)
 
-    return model.to(device)
+    return model.to(device)
\ No newline at end of file
diff --git a/whisper/model.py b/whisper/model.py
index e53744738..072f0a4f5 100644
--- a/whisper/model.py
+++ b/whisper/model.py
@@ -224,31 +224,64 @@ def __init__(
         mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
         self.register_buffer("mask", mask, persistent=False)
 
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
+        # Optimisation: pre-compute and register the mask in CUDA if available
+        if torch.cuda.is_available():
+            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
+
+
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None) -> Tensor:
         """
-        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
-            the text tokens
-        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
-            the encoded audio features to be attended on
+        Args:
+            tokens: (n_batch, n_token)
+            audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache for key/value tensors
+
+        Returns:
+            logits: (n_batch, n_token, n_vocab)
         """
-        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
-        x = (
-            self.token_embedding(x)
-            + self.positional_embedding[offset : offset + x.shape[-1]]
-        )
-        x = x.to(xa.dtype)
+        n_batch, n_token = tokens.shape
+
+        # Get the dtype of audio_features to ensure consistency
+        dtype = audio_features.dtype
+
+        # Handle kv_cache for token embedding offset
+        if kv_cache is not None:
+            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+            x = self.token_embedding(tokens) + self.positional_embedding[offset:offset + tokens.shape[1]]
+        else:
+            x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Convert to the same dtype as audio_features
+        x = x.to(dtype)
 
+        # Optimisation: Move audio_features to GPU once here.
+        if torch.cuda.is_available():
+            audio_features = audio_features.cuda()
+
+        # Process through attention blocks
         for block in self.blocks:
-            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
+            x = block(x, audio_features, kv_cache=kv_cache)
 
         x = self.ln(x)
-        logits = (
-            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+
+        # Ensure consistent dtype for matrix multiplication
+        # Convert token_embedding weight to the same dtype as x
+        embedding_weights = self.token_embedding.weight.to(x.dtype)
+        logits = x @ embedding_weights.T
+
+        # Apply mask if not using kv_cache (inference)
+        if kv_cache is None:
+            # Optimisation: Apply the precomputed CUDA mask if available.
+            if torch.cuda.is_available():
+                mask = self.mask_cuda[:n_token, :n_token]
+            else:
+                mask = self.mask[:n_token, :n_token]
+
+            logits = logits + mask
 
         return logits
-
-
+
+# The Whisper class has been moved outside of TextDecoder and is now a top-level class
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
         super().__init__()
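The chunked compute_sha256 helper introduced in whisper/__init__.py replaces the previous whole-file read, so checksum verification no longer needs to hold a multi-gigabyte checkpoint in memory at once. A minimal self-contained sanity sketch of that streaming pattern (not part of the diff; the temp file and payload size are illustrative):

import hashlib
import os
import tempfile

def compute_sha256(file_path: str) -> str:
    # Same chunked pattern as the diff's helper: feed 8192-byte reads
    # to an incremental hasher instead of hashing one large bytes object.
    sha256 = hashlib.sha256()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            sha256.update(chunk)
    return sha256.hexdigest()

# Illustrative payload, deliberately not a multiple of 8192 so the
# final short chunk is exercised.
data = os.urandom(1_000_003)
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(data)
    path = f.name
try:
    assert compute_sha256(path) == hashlib.sha256(data).hexdigest()
finally:
    os.unlink(path)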
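Both the removed and the added decoder forward paths derive the positional-embedding offset from the number of tokens already stored in kv_cache. A toy sketch of that indexing with made-up shapes (one hypothetical cached layer holding three previously decoded tokens, one new token this step):

import torch

n_ctx, n_state = 8, 4
positional_embedding = torch.randn(n_ctx, n_state)

# Hypothetical cache entry shaped (n_batch, n_cached_tokens, n_state);
# three tokens have already been decoded.
kv_cache = {"decoder.blocks.0.attn.key": torch.zeros(1, 3, n_state)}

tokens = torch.tensor([[42]])  # a single new token for this step
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
pos = positional_embedding[offset : offset + tokens.shape[-1]]
assert offset == 3
assert pos.shape == (1, n_state)  # the slice covers only the new position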