@@ -228,50 +228,16 @@ def __init__(
         if torch.cuda.is_available():
             self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
 
-
-    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None):
         """
         Args:
             tokens: (n_batch, n_token)
             audio_features: (n_batch, n_audio_ctx, n_audio_state)
+            kv_cache: Optional cache for key/value tensors
 
         Returns:
             logits: (n_batch, n_token, n_vocab)
         """
-        n_batch, n_token = tokens.shape
-        n_audio_ctx, n_audio_state = audio_features.shape[1:]
-
-        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
-
-        # Optimisation: Move audio_features to GPU once here.
-        if torch.cuda.is_available():
-            audio_features = audio_features.cuda()
-
-
-        for block in self.blocks:
-            x = block(x, audio_features)
-
-        x = self.ln(x)
-        logits = x @ self.token_embedding.weight.T
-
-        # Optimisation: Apply the precomputed CUDA mask if available.
-        if torch.cuda.is_available():
-            mask = self.mask_cuda[:n_token, :n_token]
-        else:
-            mask = self.mask[:n_token, :n_token]
-
-        logits = logits + mask
-
-        return logits
-
-
-    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
-        """
-        Args:
-            tokens: (n_batch, n_token) or x tensor
-            audio_features: (n_batch, n_audio_ctx, n_audio_state) or xa tensor
-            kv_cache: Optional cache for key/value tensors
-        """
         if kv_cache is not None:
             # Handle the kv_cache case
             offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
@@ -313,7 +279,6 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
 
         return logits
 
-
 # The Whisper class has been moved outside of TextDecoder and is now a top-level class
 class Whisper(nn.Module):
     def __init__(self, dims: ModelDimensions):
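Note: the hunks above only show the head of the merged forward; the code between the offset computation and the final return is unchanged context that the diff omits. As a rough sketch (not necessarily the exact body in this change), a kv_cache-aware decoder forward in this style uses the cached sequence length as an offset into the positional embedding; attribute names (token_embedding, positional_embedding, blocks, ln, mask) follow the surrounding TextDecoder:

    # Illustrative sketch only; the merged forward in this diff may differ in detail.
    def forward(self, tokens: Tensor, audio_features: Tensor, kv_cache: Optional[dict] = None):
        # With an incremental cache, `offset` positions were already decoded, so the
        # new tokens pick up the positional embedding starting at that offset.
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = self.token_embedding(tokens) + self.positional_embedding[offset : offset + tokens.shape[-1]]
        x = x.to(audio_features.dtype)

        # Each block cross-attends to the audio features and reuses/extends the cache.
        for block in self.blocks:
            x = block(x, audio_features, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        # Tied embedding: project back onto the vocabulary with the embedding matrix.
        logits = (x @ self.token_embedding.weight.to(x.dtype).T).float()
        return logits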