 import torch
-from typing import Optional, Tuple, Dict, List
+from typing import Optional, Tuple, Dict, List, Any
 import logging
 from clt.config import CLTConfig
 from torch.distributed import ProcessGroup
@@ -26,9 +26,10 @@ def _compute_mask(x: torch.Tensor, k_per_token: int, x_for_ranking: Optional[tor

         if k_total_batch > 0:
             _, flat_indices = torch.topk(ranking_flat, k_total_batch, sorted=False)
-            mask_flat = torch.zeros_like(x_flat, dtype=torch.bool)
-            mask_flat[flat_indices] = True
-            mask = mask_flat.view_as(x)
+            # Optimized mask creation - avoid individual indexing
+            mask = torch.zeros(x_flat.numel(), dtype=torch.bool, device=x.device)
+            mask[flat_indices] = True
+            mask = mask.view_as(x)
         else:
             mask = torch.zeros_like(x, dtype=torch.bool)

@@ -118,6 +119,7 @@ def _compute_mask(x: torch.Tensor, k_float: float, x_for_ranking: Optional[torch

         if k_per_token > 0:
             _, topk_indices_per_row = torch.topk(ranking_tensor_to_use, k_per_token, dim=-1, sorted=False)
+            # Use scatter_ for efficient mask creation
             mask = torch.zeros_like(x, dtype=torch.bool)
             mask.scatter_(-1, topk_indices_per_row, True)
         else:
@@ -231,6 +233,7 @@ def _apply_batch_topk_helper(
     dtype: torch.dtype,
     rank: int,
     process_group: Optional[ProcessGroup],
+    profiler: Optional[Any] = None,
 ) -> Dict[int, torch.Tensor]:
     """Helper to apply BatchTopK globally across concatenated layer pre-activations."""

@@ -304,17 +307,42 @@ def _apply_batch_topk_helper(

     if world_size > 1:
         if rank == 0:
-            local_mask = BatchTopK._compute_mask(
-                concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
-            )
+            if profiler:
+                with profiler.timer("batchtopk_compute_mask") as timer:
+                    local_mask = BatchTopK._compute_mask(
+                        concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+                    )
+                if hasattr(timer, 'elapsed'):
+                    profiler.record("batchtopk_compute_mask", timer.elapsed)
+            else:
+                local_mask = BatchTopK._compute_mask(
+                    concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+                )
             mask.copy_(local_mask)
-            dist_ops.broadcast(mask, src=0, group=process_group)
+
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("batchtopk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
         else:
-            dist_ops.broadcast(mask, src=0, group=process_group)
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("batchtopk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
     else:
-        mask = BatchTopK._compute_mask(
-            concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
-        )
+        if profiler:
+            with profiler.timer("batchtopk_compute_mask") as timer:
+                mask = BatchTopK._compute_mask(
+                    concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+                )
+            if hasattr(timer, 'elapsed'):
+                profiler.record("batchtopk_compute_mask", timer.elapsed)
+        else:
+            mask = BatchTopK._compute_mask(
+                concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+            )

     activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

@@ -336,6 +364,7 @@ def _apply_token_topk_helper(
     dtype: torch.dtype,
     rank: int,
     process_group: Optional[ProcessGroup],
+    profiler: Optional[Any] = None,
 ) -> Dict[int, torch.Tensor]:
     """Helper to apply TokenTopK globally across concatenated layer pre-activations."""
     world_size = dist_ops.get_world_size(process_group)
@@ -408,19 +437,46 @@ def _apply_token_topk_helper(

     if world_size > 1:
         if rank == 0:
-            local_mask = TokenTopK._compute_mask(
-                concatenated_preactivations_original,
-                k_val_float,
-                concatenated_preactivations_normalized,
-            )
+            if profiler:
+                with profiler.timer("topk_compute_mask") as timer:
+                    local_mask = TokenTopK._compute_mask(
+                        concatenated_preactivations_original,
+                        k_val_float,
+                        concatenated_preactivations_normalized,
+                    )
+                if hasattr(timer, 'elapsed'):
+                    profiler.record("topk_compute_mask", timer.elapsed)
+            else:
+                local_mask = TokenTopK._compute_mask(
+                    concatenated_preactivations_original,
+                    k_val_float,
+                    concatenated_preactivations_normalized,
+                )
             mask.copy_(local_mask)
-            dist_ops.broadcast(mask, src=0, group=process_group)
+
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("topk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
         else:
-            dist_ops.broadcast(mask, src=0, group=process_group)
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("topk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
     else:
-        mask = TokenTopK._compute_mask(
-            concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
-        )
+        if profiler:
+            with profiler.timer("topk_compute_mask") as timer:
+                mask = TokenTopK._compute_mask(
+                    concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
+                )
+            if hasattr(timer, 'elapsed'):
+                profiler.record("topk_compute_mask", timer.elapsed)
+        else:
+            mask = TokenTopK._compute_mask(
+                concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
+            )

     activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

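For context, a minimal sketch (not part of this commit) of the duck-typed profiler interface the new profiler argument appears to assume: timer(name) is a context manager whose handle may expose elapsed, record(name, value) stores a measurement, and an optional dist_profiler provides a profile_op(name) context manager. The class and attribute names below are hypothetical.

# Hypothetical sketch only -- not part of this commit.
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional


class _TimerHandle:
    """Timer handle; `elapsed` is populated when the timer context exits."""

    def __init__(self) -> None:
        self.elapsed: float = 0.0


class SketchProfiler:
    def __init__(self, dist_profiler: Optional[object] = None) -> None:
        # dist_profiler is assumed to expose a profile_op(name) context manager.
        self.dist_profiler = dist_profiler
        self.timings: Dict[str, List[float]] = defaultdict(list)

    @contextmanager
    def timer(self, name: str) -> Iterator[_TimerHandle]:
        # Yield a handle and fill in the elapsed wall-clock time on exit.
        handle = _TimerHandle()
        start = time.perf_counter()
        try:
            yield handle
        finally:
            handle.elapsed = time.perf_counter() - start

    def record(self, name: str, value: float) -> None:
        # Accumulate recorded durations under the given key.
        self.timings[name].append(value)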