
Commit 2665913

Merge pull request #27 from curt-tigges/cleanup-stage-05
Cleanup stage 05

2 parents ccd6fad + ba9145c

22 files changed: +2125, -6120 lines

.github/workflows/python-tests.yml

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
name: Python Tests

on:
  push:
    branches:
      - main
      - develop # Or your primary development branch
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11"] # Specify python versions

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Poetry
        run: |
          curl -sSL https://install.python-poetry.org | python3 -
          echo "$HOME/.local/bin" >> $GITHUB_PATH
        # Alternatively, if not using Poetry, or have a requirements.txt:
        # run: pip install -r requirements.txt

      - name: Install dependencies
        run: poetry install --no-interaction --no-root
        # If you have dev dependencies for pytest, e.g. in a [tool.poetry.group.dev.dependencies]
        # run: poetry install --no-interaction --no-root --with dev
        # Or if using pip with requirements.txt:
        # run: pip install -r requirements-dev.txt # (if you have a separate dev requirements)
        # run: pip install pytest # or ensure pytest is in your main requirements

      - name: Run tests with pytest
        run: poetry run pytest tests/
        # Or if not using poetry:
        # run: pytest tests/
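The job simply collects whatever pytest finds under tests/ on each Python version in the matrix. As a purely hypothetical illustration (this file and its assertion are not part of the commit), a minimal test module that the workflow above would pick up:

# tests/test_smoke.py -- hypothetical placeholder, not part of this commit.
# Any file named test_*.py under tests/ is collected by `poetry run pytest tests/`.


def test_smoke() -> None:
    # Trivial check; the CI job passes when every collected test passes.
    assert 1 + 1 == 2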

clt/models/activations.py

Lines changed: 240 additions & 3 deletions
@@ -1,5 +1,9 @@
 import torch
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Dict, List
+import torch.distributed as dist
+import logging
+from clt.config import CLTConfig
+from torch.distributed import ProcessGroup


 class BatchTopK(torch.autograd.Function):
@@ -193,10 +197,243 @@ def backward(ctx, *grad_outputs: torch.Tensor) -> Tuple[Optional[torch.Tensor],
         grad_threshold_per_element = grad_output * local_grad_theta

         if grad_threshold_per_element.dim() > threshold.dim():
+            # Handles cases like input (B,F), threshold (F) or input (F), threshold (scalar)
             dims_to_sum = tuple(range(grad_threshold_per_element.dim() - threshold.dim()))
             grad_threshold = grad_threshold_per_element.sum(dim=dims_to_sum)
-            if threshold.shape != torch.Size([]):
+            # Ensure final shape matches threshold, especially if sum squeezed dimensions
+            if grad_threshold.shape != threshold.shape:
                 grad_threshold = grad_threshold.reshape(threshold.shape)
-        else:
+        elif grad_threshold_per_element.dim() == threshold.dim():
+            # Handles cases like input (F), threshold (F), or input [1], threshold [1]
+            grad_threshold = grad_threshold_per_element
+            # Defensive reshape, though shapes should ideally match here.
+            if grad_threshold.shape != threshold.shape:
+                grad_threshold = grad_threshold.reshape(threshold.shape)
+        else:  # grad_threshold_per_element.dim() < threshold.dim()
+            # This case is less common (e.g. input scalar, threshold vector - not typical for this op).
+            # Defaulting to sum and reshape, primarily for scalar threshold case.
             grad_threshold = grad_threshold_per_element.sum()
+            if grad_threshold.shape != threshold.shape:
+                grad_threshold = grad_threshold.reshape(threshold.shape)
         return grad_input, grad_threshold, None
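These branches keep grad_threshold shaped like threshold regardless of how the threshold was broadcast in the forward pass. A minimal sketch of the first branch in plain torch, with illustrative shapes that are not taken from the commit (a (4, 3) per-element gradient against a per-feature threshold of shape (3,)):

import torch

# Hypothetical shapes: per-element gradient is (B, F) = (4, 3), threshold is (F,) = (3,).
grad_threshold_per_element = torch.randn(4, 3)
threshold = torch.zeros(3)

# Sum over the leading broadcast dimensions so the gradient matches the threshold's shape.
dims_to_sum = tuple(range(grad_threshold_per_element.dim() - threshold.dim()))
grad_threshold = grad_threshold_per_element.sum(dim=dims_to_sum)
if grad_threshold.shape != threshold.shape:
    grad_threshold = grad_threshold.reshape(threshold.shape)

assert grad_threshold.shape == threshold.shape  # torch.Size([3])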

# --- Helper functions for applying BatchTopK/TokenTopK globally --- #
# These were previously in clt.models.encoding.py

logger_helpers = logging.getLogger(__name__ + ".helpers")  # Use a sub-logger


def _apply_batch_topk_helper(
    preactivations_dict: Dict[int, torch.Tensor],
    config: CLTConfig,
    device: torch.device,
    dtype: torch.dtype,
    rank: int,
    process_group: Optional[ProcessGroup],
) -> Dict[int, torch.Tensor]:
    """Helper to apply BatchTopK globally across concatenated layer pre-activations."""

    world_size = 1
    if process_group is not None and dist.is_initialized():
        world_size = dist.get_world_size(process_group)

    if not preactivations_dict:
        logger_helpers.warning(f"Rank {rank}: _apply_batch_topk_helper received empty preactivations_dict.")
        return {}

    ordered_preactivations_original: List[torch.Tensor] = []
    ordered_preactivations_normalized: List[torch.Tensor] = []
    layer_feature_sizes: List[Tuple[int, int]] = []

    first_valid_preact = next((p for p in preactivations_dict.values() if p.numel() > 0), None)
    if first_valid_preact is None:
        logger_helpers.warning(
            f"Rank {rank}: No valid preactivations found in dict for BatchTopK. Returning empty dict."
        )
        return {
            layer_idx: torch.empty((0, config.num_features), device=device, dtype=dtype)
            for layer_idx in preactivations_dict.keys()
        }
    batch_tokens_dim = first_valid_preact.shape[0]

    for layer_idx in range(config.num_layers):
        if layer_idx in preactivations_dict:
            preact_orig = preactivations_dict[layer_idx]
            preact_orig = preact_orig.to(device=device, dtype=dtype)
            current_num_features = preact_orig.shape[1] if preact_orig.numel() > 0 else config.num_features

            if preact_orig.numel() == 0:
                if batch_tokens_dim > 0:
                    zeros_shape = (batch_tokens_dim, current_num_features)
                    ordered_preactivations_original.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
                    ordered_preactivations_normalized.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
            elif preact_orig.shape[0] != batch_tokens_dim:
                logger_helpers.warning(
                    f"Rank {rank} Layer {layer_idx}: Mismatched batch dim ({preact_orig.shape[0]} vs {batch_tokens_dim}). Using zeros."
                )
                zeros_shape = (batch_tokens_dim, current_num_features)
                ordered_preactivations_original.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
                ordered_preactivations_normalized.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
            else:
                ordered_preactivations_original.append(preact_orig)
                mean = preact_orig.mean(dim=0, keepdim=True)
                std = preact_orig.std(dim=0, keepdim=True)
                preact_norm = (preact_orig - mean) / (std + 1e-6)
                ordered_preactivations_normalized.append(preact_norm)
            layer_feature_sizes.append((layer_idx, current_num_features))

    if not ordered_preactivations_original:
        logger_helpers.warning(
            f"Rank {rank}: No tensors collected after iterating layers for BatchTopK. Returning empty activations."
        )
        return {
            layer_idx: torch.empty((batch_tokens_dim, config.num_features), device=device, dtype=dtype)
            for layer_idx in preactivations_dict.keys()
        }

    concatenated_preactivations_original = torch.cat(ordered_preactivations_original, dim=1)
    concatenated_preactivations_normalized = torch.cat(ordered_preactivations_normalized, dim=1)

    k_val: int
    if config.batchtopk_k is not None:
        k_val = int(config.batchtopk_k)
    else:
        k_val = concatenated_preactivations_original.size(1)

    mask_shape = concatenated_preactivations_original.shape
    mask = torch.empty(mask_shape, dtype=torch.bool, device=device)

    if world_size > 1:
        if rank == 0:
            local_mask = BatchTopK._compute_mask(
                concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
            )
            mask.copy_(local_mask)
            dist.broadcast(mask, src=0, group=process_group)
        else:
            dist.broadcast(mask, src=0, group=process_group)
    else:
        mask = BatchTopK._compute_mask(
            concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
        )

    activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

    activations_dict: Dict[int, torch.Tensor] = {}
    current_total_feature_offset = 0
    for original_layer_idx, num_features_this_layer in layer_feature_sizes:
        activated_segment = activated_concatenated[
            :, current_total_feature_offset : current_total_feature_offset + num_features_this_layer
        ]
        activations_dict[original_layer_idx] = activated_segment
        current_total_feature_offset += num_features_this_layer
    return activations_dict


def _apply_token_topk_helper(
    preactivations_dict: Dict[int, torch.Tensor],
    config: CLTConfig,
    device: torch.device,
    dtype: torch.dtype,
    rank: int,
    process_group: Optional[ProcessGroup],
) -> Dict[int, torch.Tensor]:
    """Helper to apply TokenTopK globally across concatenated layer pre-activations."""
    world_size = 1
    if process_group is not None and dist.is_initialized():
        world_size = dist.get_world_size(process_group)

    if not preactivations_dict:
        logger_helpers.warning(f"Rank {rank}: _apply_token_topk_helper received empty preactivations_dict.")
        return {}

    ordered_preactivations_original: List[torch.Tensor] = []
    ordered_preactivations_normalized: List[torch.Tensor] = []
    layer_feature_sizes: List[Tuple[int, int]] = []

    first_valid_preact = next((p for p in preactivations_dict.values() if p.numel() > 0), None)
    if first_valid_preact is None:
        logger_helpers.warning(
            f"Rank {rank}: No valid preactivations found in dict for TokenTopK. Returning empty dict."
        )
        return {
            layer_idx: torch.empty((0, config.num_features), device=device, dtype=dtype)
            for layer_idx in preactivations_dict.keys()
        }
    batch_tokens_dim = first_valid_preact.shape[0]

    for layer_idx in range(config.num_layers):
        if layer_idx in preactivations_dict:
            preact_orig = preactivations_dict[layer_idx]
            preact_orig = preact_orig.to(device=device, dtype=dtype)
            current_num_features = preact_orig.shape[1] if preact_orig.numel() > 0 else config.num_features

            if preact_orig.numel() == 0:
                if batch_tokens_dim > 0:
                    zeros_shape = (batch_tokens_dim, current_num_features)
                    ordered_preactivations_original.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
                    ordered_preactivations_normalized.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
            elif preact_orig.shape[0] != batch_tokens_dim:
                logger_helpers.warning(
                    f"Rank {rank} Layer {layer_idx}: Mismatched batch dim ({preact_orig.shape[0]} vs {batch_tokens_dim}) for TokenTopK. Using zeros."
                )
                zeros_shape = (batch_tokens_dim, current_num_features)
                ordered_preactivations_original.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
                ordered_preactivations_normalized.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
            else:
                ordered_preactivations_original.append(preact_orig)
                mean = preact_orig.mean(dim=0, keepdim=True)
                std = preact_orig.std(dim=0, keepdim=True)
                preact_norm = (preact_orig - mean) / (std + 1e-6)
                ordered_preactivations_normalized.append(preact_norm)
            layer_feature_sizes.append((layer_idx, current_num_features))

    if not ordered_preactivations_original:
        logger_helpers.warning(
            f"Rank {rank}: No tensors collected after iterating layers for TokenTopK. Returning empty activations."
        )
        return {
            layer_idx: torch.empty((batch_tokens_dim, config.num_features), device=device, dtype=dtype)
            for layer_idx in preactivations_dict.keys()
        }

    concatenated_preactivations_original = torch.cat(ordered_preactivations_original, dim=1)
    concatenated_preactivations_normalized = torch.cat(ordered_preactivations_normalized, dim=1)

    k_val_float: float
    if hasattr(config, "topk_k") and config.topk_k is not None:
        k_val_float = float(config.topk_k)
    else:
        k_val_float = float(concatenated_preactivations_original.size(1))

    mask_shape = concatenated_preactivations_original.shape
    mask = torch.empty(mask_shape, dtype=torch.bool, device=device)

    if world_size > 1:
        if rank == 0:
            local_mask = TokenTopK._compute_mask(
                concatenated_preactivations_original,
                k_val_float,
                concatenated_preactivations_normalized,
            )
            mask.copy_(local_mask)
            dist.broadcast(mask, src=0, group=process_group)
        else:
            dist.broadcast(mask, src=0, group=process_group)
    else:
        mask = TokenTopK._compute_mask(
            concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
        )

    activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

    activations_dict: Dict[int, torch.Tensor] = {}
    current_total_feature_offset = 0
    for original_layer_idx, num_features_this_layer in layer_feature_sizes:
        activated_segment = activated_concatenated[
            :, current_total_feature_offset : current_total_feature_offset + num_features_this_layer
        ]
        activations_dict[original_layer_idx] = activated_segment
        current_total_feature_offset += num_features_this_layer
    return activations_dict
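Both helpers share the same overall flow: concatenate each layer's pre-activations along the feature axis, compute one top-k mask over the concatenated tensor (computed on rank 0 and broadcast when a process group is present), then slice the masked result back into per-layer activations. The single-process sketch below illustrates that concatenate/mask/split pattern in plain torch; it substitutes a simple per-token torch.topk for the repo's BatchTopK/TokenTopK._compute_mask internals, and the layer sizes and k are illustrative only.

from typing import Dict

import torch

# Hypothetical per-layer pre-activations: 5 tokens; layer 0 has 3 features, layer 1 has 4.
preacts: Dict[int, torch.Tensor] = {0: torch.randn(5, 3), 1: torch.randn(5, 4)}
k = 2  # keep the top-2 features per token across all layers combined

layer_sizes = [(idx, t.shape[1]) for idx, t in sorted(preacts.items())]
concat = torch.cat([preacts[idx] for idx, _ in layer_sizes], dim=1)  # shape (5, 7)

# Per-token top-k over the concatenated feature axis (stand-in for _compute_mask).
topk_idx = concat.topk(k, dim=1).indices
mask = torch.zeros_like(concat, dtype=torch.bool).scatter_(1, topk_idx, True)
activated = concat * mask.to(concat.dtype)

# Split the masked tensor back into per-layer activations, as the helpers do.
out: Dict[int, torch.Tensor] = {}
offset = 0
for layer_idx, n_feat in layer_sizes:
    out[layer_idx] = activated[:, offset : offset + n_feat]
    offset += n_feat

assert out[0].shape == (5, 3) and out[1].shape == (5, 4)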
