Commit b7eaf44

Fix forward method overload in TextDecoder
1 parent 7a552cb
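For context on the bug: a Python class body that defines forward twice keeps only the last definition, so the earlier kv_cache-less overload in TextDecoder was silently unreachable. A minimal illustration of that shadowing behaviour (the Demo class below is hypothetical, not code from this repository):

class Demo:
    def forward(self, x):
        return "first"

    # Re-defining the same name rebinds it; the first forward above is discarded.
    def forward(self, x, cache=None):
        return "second"

print(Demo().forward(1))  # prints "second" -- only the last definition survives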

1 file changed

whisper/model.py

Lines changed: 2 additions & 37 deletions
@@ -228,50 +228,16 @@ def __init__(
         if torch.cuda.is_available():
             self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
 
-
-    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None):
         """
         Args:
             tokens: (n_batch, n_token)
             audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache for key/value tensors
 
         Returns:
             logits: (n_batch, n_token, n_vocab)
         """
-        n_batch, n_token = tokens.shape
-        n_audio_ctx, n_audio_state = audio_features.shape[1:]
-
-        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
-
-        # Optimisation: Move audio_features to GPU once here.
-        if torch.cuda.is_available():
-            audio_features = audio_features.cuda()
-
-
-        for block in self.blocks:
-            x = block(x, audio_features)
-
-        x = self.ln(x)
-        logits = x @ self.token_embedding.weight.T
-
-        # Optimisation: Apply the precomputed CUDA mask if available.
-        if torch.cuda.is_available():
-            mask = self.mask_cuda[:n_token, :n_token]
-        else:
-            mask = self.mask[:n_token, :n_token]
-
-        logits = logits + mask
-
-        return logits
-
-
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        Args:
-            tokens: (n_batch, n_token) or x tensor
-            audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
-            kv_cache: Optional cache for key/value tensors
-        """
         if kv_cache is not None:
             # Handle the kv_cache case
             offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
@@ -313,7 +279,6 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
 
         return logits
 
-
 # The Whisper class has been moved outside of TextDecoder and is now a top-level class
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
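The hunk above shows only the first lines of the surviving kv_cache-aware body. As a rough sketch of what the consolidated TextDecoder.forward plausibly looks like after this commit (everything past the offset computation is assumed from the upstream openai/whisper decoder, not taken from this diff, and the block signature here may differ from this fork's):

from typing import Optional

import torch
from torch import Tensor

# Sketch only: this method lives inside TextDecoder, which sets up
# token_embedding, positional_embedding, blocks, ln, and mask in __init__.
def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None):
    # With a kv_cache, earlier positions are already cached, so only the new
    # tokens are embedded, shifted into the positional embedding by `offset`.
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    x = self.token_embedding(tokens) + self.positional_embedding[offset : offset + tokens.shape[-1]]
    x = x.to(audio_features.dtype)

    for block in self.blocks:
        x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)

    x = self.ln(x)
    # Output projection is tied to the token embedding weights.
    logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
    return logits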
