
Commit 424717d

Merge pull request #36 from curt-tigges/feature/merge_dist_files
Feature/merge dist files

2 parents 9c9aa38 + b0ee682 · commit 424717d

19 files changed: +1679 / -377 lines

.gitignore
Lines changed: 2 additions & 0 deletions

@@ -207,6 +207,8 @@ clt_test_pythia_70m_jumprelu/
 clt_smoke_output_local_wandb_batchtopk/
 clt_smoke_output_remote_wandb/
 wandb/
+scripts/debug
+scripts/optimization
 
 # models
 *.pt

clt/config/clt_config.py
Lines changed: 3 additions & 0 deletions

@@ -158,6 +158,9 @@ class TrainingConfig:
 
     # Optional diagnostic metrics (can be slow)
    compute_sparsity_diagnostics: bool = False  # Whether to compute detailed sparsity diagnostics during eval
+
+    # Performance profiling
+    enable_profiling: bool = False  # Whether to enable detailed performance profiling
 
     # Dead feature tracking
     dead_feature_window: int = 1000  # Steps until a feature is considered dead
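
For context, a minimal sketch of how a flag like this is typically consumed. TrainingConfigSketch and maybe_build_profiler below are illustrative stand-ins, not the repository's actual class or API; only the three field names and defaults come from this diff.

from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingConfigSketch:
    # Trimmed-down stand-in for TrainingConfig, showing where the new flag sits.
    compute_sparsity_diagnostics: bool = False  # existing: detailed sparsity diagnostics during eval
    enable_profiling: bool = False              # new in this commit: detailed performance profiling
    dead_feature_window: int = 1000             # existing: steps until a feature is considered dead


def maybe_build_profiler(cfg: TrainingConfigSketch) -> Optional[object]:
    # Hypothetical helper: only construct a profiler when the flag is set,
    # so the default (flag off) training path pays no overhead.
    if not cfg.enable_profiling:
        return None
    return object()  # placeholder for a real profiler instance


cfg = TrainingConfigSketch(enable_profiling=True)
print(maybe_build_profiler(cfg) is not None)  # True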

clt/models/activations.py
Lines changed: 78 additions & 22 deletions

@@ -1,5 +1,5 @@
 import torch
-from typing import Optional, Tuple, Dict, List
+from typing import Optional, Tuple, Dict, List, Any
 import logging
 from clt.config import CLTConfig
 from torch.distributed import ProcessGroup

@@ -26,9 +26,10 @@ def _compute_mask(x: torch.Tensor, k_per_token: int, x_for_ranking: Optional[tor
 
         if k_total_batch > 0:
             _, flat_indices = torch.topk(ranking_flat, k_total_batch, sorted=False)
-            mask_flat = torch.zeros_like(x_flat, dtype=torch.bool)
-            mask_flat[flat_indices] = True
-            mask = mask_flat.view_as(x)
+            # Optimized mask creation - avoid individual indexing
+            mask = torch.zeros(x_flat.numel(), dtype=torch.bool, device=x.device)
+            mask[flat_indices] = True
+            mask = mask.view_as(x)
         else:
             mask = torch.zeros_like(x, dtype=torch.bool)
 
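
A standalone sketch of the batch-level top-k masking pattern used in the changed lines: rank every pre-activation in a flattened view, keep the top k_total_batch entries, and build one flat boolean mask that is reshaped back to the input. Variable names mirror the diff; the toy shape and k value are assumptions for illustration only.

import torch

x = torch.randn(4, 8)              # [tokens, features] pre-activations (toy shape)
x_flat = x.reshape(-1)
ranking_flat = x_flat              # the diff allows ranking on a separate normalized tensor
k_total_batch = 5                  # total units to keep across the whole batch (toy value)

if k_total_batch > 0:
    _, flat_indices = torch.topk(ranking_flat, k_total_batch, sorted=False)
    mask = torch.zeros(x_flat.numel(), dtype=torch.bool, device=x.device)
    mask[flat_indices] = True      # single vectorized index assignment
    mask = mask.view_as(x)
else:
    mask = torch.zeros_like(x, dtype=torch.bool)

assert mask.sum().item() == k_total_batch

The changed lines keep this same logic but allocate the flat mask with an explicit torch.zeros of the flattened size instead of torch.zeros_like on the flattened input.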
@@ -118,6 +119,7 @@ def _compute_mask(x: torch.Tensor, k_float: float, x_for_ranking: Optional[torch
 
         if k_per_token > 0:
             _, topk_indices_per_row = torch.topk(ranking_tensor_to_use, k_per_token, dim=-1, sorted=False)
+            # Use scatter_ for efficient mask creation
             mask = torch.zeros_like(x, dtype=torch.bool)
             mask.scatter_(-1, topk_indices_per_row, True)
         else:
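
For comparison, a standalone sketch of the per-token (row-wise) top-k masking that the new scatter_ comment refers to: select k indices per row, then flip them to True in a single scatter_ call rather than looping over rows. Names mirror the diff; the toy shape and k value are assumptions.

import torch

x = torch.randn(4, 8)                     # [tokens, features] (toy shape)
ranking_tensor_to_use = x
k_per_token = 3                           # units kept per token (toy value)

if k_per_token > 0:
    _, topk_indices_per_row = torch.topk(ranking_tensor_to_use, k_per_token, dim=-1, sorted=False)
    mask = torch.zeros_like(x, dtype=torch.bool)
    mask.scatter_(-1, topk_indices_per_row, True)  # set exactly k entries per row
else:
    mask = torch.zeros_like(x, dtype=torch.bool)

assert (mask.sum(dim=-1) == k_per_token).all()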
@@ -231,6 +233,7 @@ _apply_batch_topk_helper(
     dtype: torch.dtype,
     rank: int,
     process_group: Optional[ProcessGroup],
+    profiler: Optional[Any] = None,
 ) -> Dict[int, torch.Tensor]:
     """Helper to apply BatchTopK globally across concatenated layer pre-activations."""
 

@@ -304,17 +307,42 @@
 
     if world_size > 1:
         if rank == 0:
-            local_mask = BatchTopK._compute_mask(
-                concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
-            )
+            if profiler:
+                with profiler.timer("batchtopk_compute_mask") as timer:
+                    local_mask = BatchTopK._compute_mask(
+                        concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+                    )
+                if hasattr(timer, 'elapsed'):
+                    profiler.record("batchtopk_compute_mask", timer.elapsed)
+            else:
+                local_mask = BatchTopK._compute_mask(
+                    concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+                )
             mask.copy_(local_mask)
-            dist_ops.broadcast(mask, src=0, group=process_group)
+
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("batchtopk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
         else:
-            dist_ops.broadcast(mask, src=0, group=process_group)
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("batchtopk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
     else:
-        mask = BatchTopK._compute_mask(
-            concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
-        )
+        if profiler:
+            with profiler.timer("batchtopk_compute_mask") as timer:
+                mask = BatchTopK._compute_mask(
+                    concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+                )
+            if hasattr(timer, 'elapsed'):
+                profiler.record("batchtopk_compute_mask", timer.elapsed)
+        else:
+            mask = BatchTopK._compute_mask(
+                concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
+            )
 
     activated_concatenated = concatenated_preactivations_original * mask.to(dtype)
 
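
The helpers only call profiler.timer(name), profiler.record(name, elapsed), and profiler.dist_profiler.profile_op(name); the profiler class itself is defined elsewhere in this PR. The sketch below is a hypothetical minimal object satisfying that interface (SketchProfiler and _Timer are illustrative names, not the repository's), just to make the expected shape concrete.

import time
from collections import defaultdict
from contextlib import contextmanager


class _Timer:
    def __init__(self):
        self.elapsed = 0.0


class SketchProfiler:
    def __init__(self, dist_profiler=None):
        self.timings = defaultdict(list)
        self.dist_profiler = dist_profiler  # optional wrapper around collective ops

    @contextmanager
    def timer(self, name):
        # Yields an object whose .elapsed is filled in when the block exits.
        t = _Timer()
        start = time.perf_counter()
        try:
            yield t
        finally:
            t.elapsed = time.perf_counter() - start

    def record(self, name, elapsed):
        self.timings[name].append(elapsed)


# Usage mirroring the guarded pattern in the diff:
profiler = SketchProfiler()
with profiler.timer("batchtopk_compute_mask") as timer:
    pass  # mask computation would go here
if hasattr(timer, 'elapsed'):
    profiler.record("batchtopk_compute_mask", timer.elapsed)

Note that the diff's hasattr(profiler, 'dist_profiler') guard also quietly covers profiler=None, since hasattr(None, 'dist_profiler') is False and the broadcast then runs unprofiled without a separate None check.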
@@ -336,6 +364,7 @@ def _apply_token_topk_helper(
     dtype: torch.dtype,
     rank: int,
     process_group: Optional[ProcessGroup],
+    profiler: Optional[Any] = None,
 ) -> Dict[int, torch.Tensor]:
     """Helper to apply TokenTopK globally across concatenated layer pre-activations."""
     world_size = dist_ops.get_world_size(process_group)

@@ -408,19 +437,46 @@
 
     if world_size > 1:
         if rank == 0:
-            local_mask = TokenTopK._compute_mask(
-                concatenated_preactivations_original,
-                k_val_float,
-                concatenated_preactivations_normalized,
-            )
+            if profiler:
+                with profiler.timer("topk_compute_mask") as timer:
+                    local_mask = TokenTopK._compute_mask(
+                        concatenated_preactivations_original,
+                        k_val_float,
+                        concatenated_preactivations_normalized,
+                    )
+                if hasattr(timer, 'elapsed'):
+                    profiler.record("topk_compute_mask", timer.elapsed)
+            else:
+                local_mask = TokenTopK._compute_mask(
+                    concatenated_preactivations_original,
+                    k_val_float,
+                    concatenated_preactivations_normalized,
+                )
             mask.copy_(local_mask)
-            dist_ops.broadcast(mask, src=0, group=process_group)
+
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("topk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
         else:
-            dist_ops.broadcast(mask, src=0, group=process_group)
+            if hasattr(profiler, 'dist_profiler') and profiler.dist_profiler:
+                with profiler.dist_profiler.profile_op("topk_broadcast"):
+                    dist_ops.broadcast(mask, src=0, group=process_group)
+            else:
+                dist_ops.broadcast(mask, src=0, group=process_group)
     else:
-        mask = TokenTopK._compute_mask(
-            concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
-        )
+        if profiler:
+            with profiler.timer("topk_compute_mask") as timer:
+                mask = TokenTopK._compute_mask(
+                    concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
+                )
+            if hasattr(timer, 'elapsed'):
+                profiler.record("topk_compute_mask", timer.elapsed)
+        else:
+            mask = TokenTopK._compute_mask(
+                concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
+            )
 
     activated_concatenated = concatenated_preactivations_original * mask.to(dtype)
 
