45 changes: 45 additions & 0 deletions .github/workflows/python-tests.yml
@@ -0,0 +1,45 @@
name: Python Tests

on:
  push:
    branches:
      - main
      - develop  # Or your primary development branch
  pull_request:
    branches:
      - main

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11"]  # Specify Python versions

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Poetry
        run: |
          curl -sSL https://install.python-poetry.org | python3 -
          echo "$HOME/.local/bin" >> $GITHUB_PATH
        # Alternatively, if you are not using Poetry and have a requirements.txt:
        # run: pip install -r requirements.txt

      - name: Install dependencies
        run: poetry install --no-interaction --no-root
        # If your pytest dev dependencies live in [tool.poetry.group.dev.dependencies]:
        # run: poetry install --no-interaction --no-root --with dev
        # Or, if using pip with requirements files:
        # run: pip install -r requirements-dev.txt  # (if you keep a separate dev requirements file)
        # run: pip install pytest  # or ensure pytest is in your main requirements

      - name: Run tests with pytest
        run: poetry run pytest tests/
        # Or, if not using Poetry:
        # run: pytest tests/
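
For reference, a minimal test module that the `poetry run pytest tests/` step above would collect might look like the sketch below (hypothetical file tests/test_smoke.py, not part of this PR):

# tests/test_smoke.py -- hypothetical smoke test, not included in this PR
def test_package_imports():
    # Importing the top-level package is enough for a minimal CI sanity check.
    import clt  # noqa: F401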
243 changes: 240 additions & 3 deletions clt/models/activations.py
@@ -1,5 +1,9 @@
import torch
from typing import Optional, Tuple, Dict, List
import torch.distributed as dist
import logging
from clt.config import CLTConfig
from torch.distributed import ProcessGroup


class BatchTopK(torch.autograd.Function):
@@ -193,10 +197,243 @@ def backward(ctx, *grad_outputs: torch.Tensor) -> Tuple[Optional[torch.Tensor],
        grad_threshold_per_element = grad_output * local_grad_theta

        if grad_threshold_per_element.dim() > threshold.dim():
            # Handles cases like input (B,F), threshold (F) or input (F), threshold (scalar)
            dims_to_sum = tuple(range(grad_threshold_per_element.dim() - threshold.dim()))
            grad_threshold = grad_threshold_per_element.sum(dim=dims_to_sum)
            if threshold.shape != torch.Size([]):
                # Ensure final shape matches threshold, especially if sum squeezed dimensions
                if grad_threshold.shape != threshold.shape:
                    grad_threshold = grad_threshold.reshape(threshold.shape)
        elif grad_threshold_per_element.dim() == threshold.dim():
            # Handles cases like input (F), threshold (F), or input [1], threshold [1]
            grad_threshold = grad_threshold_per_element
            # Defensive reshape, though shapes should ideally match here.
            if grad_threshold.shape != threshold.shape:
                grad_threshold = grad_threshold.reshape(threshold.shape)
        else:  # grad_threshold_per_element.dim() < threshold.dim()
            # This case is less common (e.g. input scalar, threshold vector - not typical for this op).
            # Defaulting to sum and reshape, primarily for scalar threshold case.
            grad_threshold = grad_threshold_per_element.sum()
            if grad_threshold.shape != threshold.shape:
                grad_threshold = grad_threshold.reshape(threshold.shape)
        return grad_input, grad_threshold, None
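
To make the shape handling in this backward pass concrete, here is a minimal standalone sketch (toy shapes, not taken from the PR) of the leading-dimension reduction: a per-element threshold gradient of shape (batch, features) is summed over the batch dimension so it matches a per-feature threshold of shape (features,).

import torch

grad_per_element = torch.randn(8, 4)   # stands in for grad_output * local_grad_theta, shape (B, F)
threshold = torch.zeros(4)             # per-feature threshold, shape (F,)

dims_to_sum = tuple(range(grad_per_element.dim() - threshold.dim()))  # -> (0,)
grad_threshold = grad_per_element.sum(dim=dims_to_sum)                # -> shape (4,)
assert grad_threshold.shape == threshold.shape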


# --- Helper functions for applying BatchTopK/TokenTopK globally --- #
# These were previously in clt/models/encoding.py

logger_helpers = logging.getLogger(__name__ + ".helpers") # Use a sub-logger


def _apply_batch_topk_helper(
    preactivations_dict: Dict[int, torch.Tensor],
    config: CLTConfig,
    device: torch.device,
    dtype: torch.dtype,
    rank: int,
    process_group: Optional[ProcessGroup],
) -> Dict[int, torch.Tensor]:
    """Helper to apply BatchTopK globally across concatenated layer pre-activations."""

    world_size = 1
    if process_group is not None and dist.is_initialized():
        world_size = dist.get_world_size(process_group)

    if not preactivations_dict:
        logger_helpers.warning(f"Rank {rank}: _apply_batch_topk_helper received empty preactivations_dict.")
        return {}

    ordered_preactivations_original: List[torch.Tensor] = []
    ordered_preactivations_normalized: List[torch.Tensor] = []
    layer_feature_sizes: List[Tuple[int, int]] = []

    first_valid_preact = next((p for p in preactivations_dict.values() if p.numel() > 0), None)
    if first_valid_preact is None:
        logger_helpers.warning(
            f"Rank {rank}: No valid preactivations found in dict for BatchTopK. Returning empty dict."
        )
        return {
            layer_idx: torch.empty((0, config.num_features), device=device, dtype=dtype)
            for layer_idx in preactivations_dict.keys()
        }
    batch_tokens_dim = first_valid_preact.shape[0]

    for layer_idx in range(config.num_layers):
        if layer_idx in preactivations_dict:
            preact_orig = preactivations_dict[layer_idx]
            preact_orig = preact_orig.to(device=device, dtype=dtype)
            current_num_features = preact_orig.shape[1] if preact_orig.numel() > 0 else config.num_features

            if preact_orig.numel() == 0:
                if batch_tokens_dim > 0:
                    zeros_shape = (batch_tokens_dim, current_num_features)
                    ordered_preactivations_original.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
                    ordered_preactivations_normalized.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
            elif preact_orig.shape[0] != batch_tokens_dim:
                logger_helpers.warning(
                    f"Rank {rank} Layer {layer_idx}: Mismatched batch dim ({preact_orig.shape[0]} vs {batch_tokens_dim}). Using zeros."
                )
                zeros_shape = (batch_tokens_dim, current_num_features)
                ordered_preactivations_original.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
                ordered_preactivations_normalized.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
            else:
                ordered_preactivations_original.append(preact_orig)
                mean = preact_orig.mean(dim=0, keepdim=True)
                std = preact_orig.std(dim=0, keepdim=True)
                preact_norm = (preact_orig - mean) / (std + 1e-6)
                ordered_preactivations_normalized.append(preact_norm)
            layer_feature_sizes.append((layer_idx, current_num_features))

    if not ordered_preactivations_original:
        logger_helpers.warning(
            f"Rank {rank}: No tensors collected after iterating layers for BatchTopK. Returning empty activations."
        )
        return {
            layer_idx: torch.empty((batch_tokens_dim, config.num_features), device=device, dtype=dtype)
            for layer_idx in preactivations_dict.keys()
        }

    concatenated_preactivations_original = torch.cat(ordered_preactivations_original, dim=1)
    concatenated_preactivations_normalized = torch.cat(ordered_preactivations_normalized, dim=1)

    k_val: int
    if config.batchtopk_k is not None:
        k_val = int(config.batchtopk_k)
    else:
        k_val = concatenated_preactivations_original.size(1)

    mask_shape = concatenated_preactivations_original.shape
    mask = torch.empty(mask_shape, dtype=torch.bool, device=device)

    if world_size > 1:
        if rank == 0:
            local_mask = BatchTopK._compute_mask(
                concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
            )
            mask.copy_(local_mask)
            dist.broadcast(mask, src=0, group=process_group)
        else:
            dist.broadcast(mask, src=0, group=process_group)
    else:
        mask = BatchTopK._compute_mask(
            concatenated_preactivations_original, k_val, concatenated_preactivations_normalized
        )

    activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

    activations_dict: Dict[int, torch.Tensor] = {}
    current_total_feature_offset = 0
    for original_layer_idx, num_features_this_layer in layer_feature_sizes:
        activated_segment = activated_concatenated[
            :, current_total_feature_offset : current_total_feature_offset + num_features_this_layer
        ]
        activations_dict[original_layer_idx] = activated_segment
        current_total_feature_offset += num_features_this_layer
    return activations_dict
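
The helper above follows a concatenate -> mask -> split pattern. Below is a single-process sketch of that flow; the layer sizes and k are illustrative, and the flat batch-wide top-k is an assumption standing in for BatchTopK._compute_mask, whose exact selection rule is defined elsewhere in this file.

import torch

preacts = {0: torch.randn(5, 3), 1: torch.randn(5, 3)}   # {layer_idx: (tokens, features)}
concat = torch.cat([preacts[0], preacts[1]], dim=1)       # (5, 6): all layers side by side

k = 4                                                     # stand-in for config.batchtopk_k
flat = concat.flatten()
mask = torch.zeros_like(flat, dtype=torch.bool)
mask[flat.topk(k).indices] = True                         # assumed rule: keep the k largest values globally
mask = mask.view_as(concat)

activated = concat * mask.to(concat.dtype)                # zero out everything outside the mask
acts = {0: activated[:, :3], 1: activated[:, 3:]}         # split back into per-layer segments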


def _apply_token_topk_helper(
    preactivations_dict: Dict[int, torch.Tensor],
    config: CLTConfig,
    device: torch.device,
    dtype: torch.dtype,
    rank: int,
    process_group: Optional[ProcessGroup],
) -> Dict[int, torch.Tensor]:
    """Helper to apply TokenTopK globally across concatenated layer pre-activations."""
    world_size = 1
    if process_group is not None and dist.is_initialized():
        world_size = dist.get_world_size(process_group)

    if not preactivations_dict:
        logger_helpers.warning(f"Rank {rank}: _apply_token_topk_helper received empty preactivations_dict.")
        return {}

    ordered_preactivations_original: List[torch.Tensor] = []
    ordered_preactivations_normalized: List[torch.Tensor] = []
    layer_feature_sizes: List[Tuple[int, int]] = []

    first_valid_preact = next((p for p in preactivations_dict.values() if p.numel() > 0), None)
    if first_valid_preact is None:
        logger_helpers.warning(
            f"Rank {rank}: No valid preactivations found in dict for TokenTopK. Returning empty dict."
        )
        return {
            layer_idx: torch.empty((0, config.num_features), device=device, dtype=dtype)
            for layer_idx in preactivations_dict.keys()
        }
    batch_tokens_dim = first_valid_preact.shape[0]

    for layer_idx in range(config.num_layers):
        if layer_idx in preactivations_dict:
            preact_orig = preactivations_dict[layer_idx]
            preact_orig = preact_orig.to(device=device, dtype=dtype)
            current_num_features = preact_orig.shape[1] if preact_orig.numel() > 0 else config.num_features

            if preact_orig.numel() == 0:
                if batch_tokens_dim > 0:
                    zeros_shape = (batch_tokens_dim, current_num_features)
                    ordered_preactivations_original.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
                    ordered_preactivations_normalized.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
            elif preact_orig.shape[0] != batch_tokens_dim:
                logger_helpers.warning(
                    f"Rank {rank} Layer {layer_idx}: Mismatched batch dim ({preact_orig.shape[0]} vs {batch_tokens_dim}) for TokenTopK. Using zeros."
                )
                zeros_shape = (batch_tokens_dim, current_num_features)
                ordered_preactivations_original.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
                ordered_preactivations_normalized.append(torch.zeros(zeros_shape, device=device, dtype=dtype))
            else:
                ordered_preactivations_original.append(preact_orig)
                mean = preact_orig.mean(dim=0, keepdim=True)
                std = preact_orig.std(dim=0, keepdim=True)
                preact_norm = (preact_orig - mean) / (std + 1e-6)
                ordered_preactivations_normalized.append(preact_norm)
            layer_feature_sizes.append((layer_idx, current_num_features))

    if not ordered_preactivations_original:
        logger_helpers.warning(
            f"Rank {rank}: No tensors collected after iterating layers for TokenTopK. Returning empty activations."
        )
        return {
            layer_idx: torch.empty((batch_tokens_dim, config.num_features), device=device, dtype=dtype)
            for layer_idx in preactivations_dict.keys()
        }

    concatenated_preactivations_original = torch.cat(ordered_preactivations_original, dim=1)
    concatenated_preactivations_normalized = torch.cat(ordered_preactivations_normalized, dim=1)

    k_val_float: float
    if hasattr(config, "topk_k") and config.topk_k is not None:
        k_val_float = float(config.topk_k)
    else:
        k_val_float = float(concatenated_preactivations_original.size(1))

    mask_shape = concatenated_preactivations_original.shape
    mask = torch.empty(mask_shape, dtype=torch.bool, device=device)

    if world_size > 1:
        if rank == 0:
            local_mask = TokenTopK._compute_mask(
                concatenated_preactivations_original,
                k_val_float,
                concatenated_preactivations_normalized,
            )
            mask.copy_(local_mask)
            dist.broadcast(mask, src=0, group=process_group)
        else:
            dist.broadcast(mask, src=0, group=process_group)
    else:
        mask = TokenTopK._compute_mask(
            concatenated_preactivations_original, k_val_float, concatenated_preactivations_normalized
        )

    activated_concatenated = concatenated_preactivations_original * mask.to(dtype)

    activations_dict: Dict[int, torch.Tensor] = {}
    current_total_feature_offset = 0
    for original_layer_idx, num_features_this_layer in layer_feature_sizes:
        activated_segment = activated_concatenated[
            :, current_total_feature_offset : current_total_feature_offset + num_features_this_layer
        ]
        activations_dict[original_layer_idx] = activated_segment
        current_total_feature_offset += num_features_this_layer
    return activations_dict
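
A hedged single-process usage sketch of this helper follows; the CLTConfig constructor arguments shown are assumptions for illustration (only num_layers, num_features, and topk_k are referenced by the code above), and the helper's own signature is taken from this diff.

import torch
from clt.config import CLTConfig
from clt.models.activations import _apply_token_topk_helper

config = CLTConfig(num_layers=2, num_features=16, topk_k=8)   # assumed constructor signature
preacts = {i: torch.randn(4, 16) for i in range(config.num_layers)}

acts = _apply_token_topk_helper(
    preactivations_dict=preacts,
    config=config,
    device=torch.device("cpu"),
    dtype=torch.float32,
    rank=0,
    process_group=None,   # single-process path: the mask is computed locally, no broadcast
)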