@@ -224,6 +224,47 @@ def __init__(
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

+        # Optimisation: pre-compute and register the mask in CUDA if available
+        if torch.cuda.is_available():
+            self.register_buffer("mask_cuda", mask.cuda(), persistent=False)
+
+
+    def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
+        """
+        Args:
+            tokens: (n_batch, n_token)
+            audio_features: (n_batch, n_audio_ctx, n_audio_state)
+
+        Returns:
+            logits: (n_batch, n_token, n_vocab)
+        """
+        n_batch, n_token = tokens.shape
+        n_audio_ctx, n_audio_state = audio_features.shape[1:]
+
+        x = self.token_embedding(tokens) + self.positional_embedding[:n_token]
+
+        # Optimisation: Move audio_features to GPU once here.
+        if torch.cuda.is_available():
+            audio_features = audio_features.cuda()
+
+
+        for block in self.blocks:
+            x = block(x, audio_features)
+
+        x = self.ln(x)
+        logits = x @ self.token_embedding.weight.T
+
+        # Optimisation: Apply the precomputed CUDA mask if available.
+        if torch.cuda.is_available():
+            mask = self.mask_cuda[:n_token, :n_token]
+        else:
+            mask = self.mask[:n_token, :n_token]
+
+        logits = logits + mask
+
+        return logits
+
+
    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
@@ -342,4 +383,4 @@ def install_hooks(layer: nn.Module):

    detect_language = detect_language_function
    transcribe = transcribe_function
-    decode = decode_function
+    decode = decode_function
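
For context, here is a minimal standalone sketch of the causal mask that this change pre-computes and caches on the GPU. The construction mirrors the `mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)` line in the hunk above; the snippet itself (including the toy `n_ctx` value and the `mask_cuda` variable) is illustrative only and not part of the patch.

```python
import numpy as np
import torch

n_ctx = 4  # toy context length, just for illustration

# Same construction as in TextDecoder.__init__: -inf strictly above the
# diagonal, zeros on and below it (the standard causal-mask pattern).
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

# The patch additionally registers a CUDA copy up front when a GPU is present,
# so later slices such as mask[:n_token, :n_token] stay on the device instead
# of requiring a host-to-device transfer at decode time.
if torch.cuda.is_available():
    mask_cuda = mask.cuda()
```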