
Commit 5a623bf

Author: Curt Tigges
Commit message: corrected activation generation and added tests
Parent: 0ab3bf2

File tree: 12 files changed, +1296 / -13 lines

clt/activation_generation/generator.py

Lines changed: 31 additions & 12 deletions
@@ -798,30 +798,40 @@ def _write_chunk(
                 chunks=(min(rows, 16384), d_model),
             )
 
+            # --- Use a SINGLE permutation shared across all layers --- #
+            if rows > 0:
+                shared_perm = torch.randperm(rows, device=next(iter(buf_inp_gpu.values()))[0].device)
+            else:
+                # Degenerate case – zero-row chunk (should not normally happen)
+                shared_perm = None
+
             layer_data_to_write = []
             for layer_id in layer_ids:
                 with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_data_prep"):
                     with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_concat"):
                         layer_inp_gpu = torch.cat(buf_inp_gpu[layer_id], dim=0)
                         layer_tgt_gpu = torch.cat(buf_tgt_gpu[layer_id], dim=0)
 
-                    with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_permute"):
-                        perm = torch.randperm(rows, device=layer_inp_gpu.device)
-                        layer_inp_gpu_perm = layer_inp_gpu[perm]
-                        layer_tgt_gpu_perm = layer_tgt_gpu[perm]
+                    if shared_perm is not None:
+                        with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_permute"):
+                            layer_inp_gpu_perm = layer_inp_gpu[shared_perm]
+                            layer_tgt_gpu_perm = layer_tgt_gpu[shared_perm]
+                    else:
+                        layer_inp_gpu_perm = layer_inp_gpu
+                        layer_tgt_gpu_perm = layer_tgt_gpu
 
                     with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_cpu_transfer"):
                         layer_inp_cpu = layer_inp_gpu_perm.cpu()
                         layer_tgt_cpu = layer_tgt_gpu_perm.cpu()
 
                     with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_convert_numpy"):
                         inputs_np = (
-                            layer_inp_cpu.numpy().view(np.uint16)
+                            layer_inp_cpu.view(torch.int16).numpy()
                             if self.torch_dtype == torch.bfloat16
                             else layer_inp_cpu.numpy()
                         )
                         targets_np = (
-                            layer_tgt_cpu.numpy().view(np.uint16)
+                            layer_tgt_cpu.view(torch.int16).numpy()
                             if self.torch_dtype == torch.bfloat16
                             else layer_tgt_cpu.numpy()
                         )
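
The comment in the hunk above points at the intent of the fix: the per-layer buffers are row-aligned (row i of every layer comes from the same token), and the old code drew an independent torch.randperm for each layer inside the loop, which scrambles that alignment. A toy illustration of why one shared permutation preserves the pairing (a sketch, not code from the repo):

import torch

# Stand-ins for two row-aligned layer buffers: row i of both refers to the same token.
rows = 4
layer0 = torch.arange(rows, dtype=torch.float32)
layer1 = torch.arange(rows, dtype=torch.float32) + 100.0

# One permutation applied to every layer keeps rows paired across layers.
shared_perm = torch.randperm(rows)
assert torch.equal(layer0[shared_perm] + 100.0, layer1[shared_perm])

# Independent permutations (the old behaviour) generally break the pairing:
# layer0[perm_a] and layer1[perm_b] only stay aligned if perm_a happens to equal perm_b.
perm_a, perm_b = torch.randperm(rows), torch.randperm(rows)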
@@ -855,29 +865,38 @@ def write_layer_data(layer_id_arg: int, inputs_data: np.ndarray, targets_data: n
 
         elif self.cfg.output_format == "npz":
             npz_save_dict = {}
+            # --- Use a SINGLE permutation shared across all layers (same as HDF5 path) --- #
+            if rows > 0:
+                shared_perm = torch.randperm(rows, device=next(iter(buf_inp_gpu.values()))[0].device)
+            else:
+                shared_perm = None
+
             for layer_id in layer_ids:
                 with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_data_prep_npz"):
                     with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_concat_npz"):
                         layer_inp_gpu = torch.cat(buf_inp_gpu[layer_id], dim=0)
                         layer_tgt_gpu = torch.cat(buf_tgt_gpu[layer_id], dim=0)
 
-                    with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_permute_npz"):
-                        perm = torch.randperm(rows, device=layer_inp_gpu.device)
-                        layer_inp_gpu_perm = layer_inp_gpu[perm]
-                        layer_tgt_gpu_perm = layer_tgt_gpu[perm]
+                    if shared_perm is not None:
+                        with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_permute_npz"):
+                            layer_inp_gpu_perm = layer_inp_gpu[shared_perm]
+                            layer_tgt_gpu_perm = layer_tgt_gpu[shared_perm]
+                    else:
+                        layer_inp_gpu_perm = layer_inp_gpu
+                        layer_tgt_gpu_perm = layer_tgt_gpu
 
                     with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_cpu_transfer_npz"):
                         layer_inp_cpu = layer_inp_gpu_perm.cpu()
                         layer_tgt_cpu = layer_tgt_gpu_perm.cpu()
 
                     with self._conditional_measure(f"chunk_{chunk_idx}_layer_{layer_id}_convert_numpy_npz"):
                         inputs_np = (
-                            layer_inp_cpu.numpy().view(np.uint16)
+                            layer_inp_cpu.view(torch.int16).numpy()
                             if self.torch_dtype == torch.bfloat16
                             else layer_inp_cpu.numpy()
                         )
                         targets_np = (
-                            layer_tgt_cpu.numpy().view(np.uint16)
+                            layer_tgt_cpu.view(torch.int16).numpy()
                             if self.torch_dtype == torch.bfloat16
                             else layer_tgt_cpu.numpy()
                         )
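
Context for the _convert_numpy change on both paths: NumPy has no bfloat16 dtype, so the old order (.numpy() first, then .view(np.uint16)) cannot work for bfloat16 tensors, while the new order reinterprets the 2-byte storage as int16 inside torch and only then converts. A minimal sketch of the round-trip, illustrative rather than taken from the repo:

import numpy as np
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)

# x.numpy() would raise, because NumPy cannot represent bfloat16.
# Reinterpreting the storage as int16 first gives a bit-for-bit copy with the same shape.
as_int16 = x.view(torch.int16).numpy()
assert as_int16.dtype == np.int16

# A reader of the saved array can reverse the reinterpretation to recover the values.
restored = torch.from_numpy(as_int16).view(torch.bfloat16)
assert torch.equal(restored, x)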

pytest.ini

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
 [pytest]
 markers =
-    integration: marks tests as integration tests that verify multiple components working together \
+    integration: marks tests as integration tests that verify multiple components working together \
+    require_gpu: marks tests that require a GPU (CUDA or MPS) to run \
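
The new require_gpu marker only labels tests; it does not skip anything by itself. A hypothetical example of a test carrying the marker (not part of this commit):

import pytest
import torch


@pytest.mark.require_gpu
def test_tensor_on_accelerator():
    # The marker is just metadata, so the test still skips itself explicitly
    # when neither CUDA nor MPS is available.
    if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
        pytest.skip("No GPU (CUDA or MPS) available")
    device = "cuda" if torch.cuda.is_available() else "mps"
    x = torch.ones(8, device=device)
    assert x.device.type in ("cuda", "mps")

On CPU-only machines such tests can be deselected with pytest -m "not require_gpu".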

tests/conftest.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
import pytest
import torch


def get_available_devices():
    """Returns available devices, including cpu, mps, and cuda if available."""
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if torch.backends.mps.is_available():
        devices.append("mps")
    return devices


DEVICES = get_available_devices()


@pytest.fixture(params=DEVICES)
def device(request):
    """Fixture to iterate over all available devices."""
    return torch.device(request.param)
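
Because the fixture is parametrized over DEVICES, any test that accepts a device argument runs once per available backend. A minimal hypothetical test using it (not part of the commit):

import torch


def test_moves_tensor_to_device(device):
    # Collected once per entry in DEVICES: cpu, plus cuda and/or mps when available.
    x = torch.ones(4, device=device)
    assert x.device.type == device.type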
Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
from typing import cast

from clt.config import CLTConfig
from clt.models.clt import CrossLayerTranscoder


def setup_distributed_environment(rank, world_size, port="12356"):
    """Initializes the distributed process group."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)


def cleanup_distributed_environment():
    """Cleans up the distributed process group."""
    dist.destroy_process_group()


def distributed_test_runner(rank, world_size, test_fn, *args):
    """A wrapper to run a distributed test function."""
    setup_distributed_environment(rank, world_size)
    try:
        test_fn(rank, world_size, *args)
    finally:
        cleanup_distributed_environment()


# --- Test Functions (to be run in separate processes) ---


def _test_forward_pass_distributed(rank, world_size):
    """
    Tests that the forward pass produces identical results on all ranks.
    """
    device = torch.device("cpu")
    torch.manual_seed(42)  # Ensure same model initialization

    clt_config = CLTConfig(num_layers=2, d_model=8, num_features=16, activation_fn="relu")
    model = CrossLayerTranscoder(config=clt_config, process_group=dist.group.WORLD, device=device)

    # All ranks get the same input
    torch.manual_seed(123)
    sample_inputs = {
        0: torch.randn(20, clt_config.d_model, device=device),
        1: torch.randn(20, clt_config.d_model, device=device),
    }

    reconstructions = model.forward(sample_inputs)
    loss = torch.mean(reconstructions[0])  # A simple, deterministic loss

    # Gather the loss from all ranks
    loss_list = [torch.zeros_like(loss) for _ in range(world_size)]
    dist.all_gather(loss_list, loss)

    # The loss, and therefore the forward pass result, should be identical on all ranks
    for other_loss in loss_list:
        assert torch.allclose(loss, other_loss), "Forward pass results (losses) differ across ranks"


def _test_sharded_gradient(rank, world_size):
    """
    Tests that sharded parameters receive different gradients on each rank.
    """
    device = torch.device("cpu")
    # Use rank-specific seed for weight initialization to ensure different weights
    torch.manual_seed(42 + rank)

    clt_config = CLTConfig(num_layers=2, d_model=8, num_features=16, activation_fn="relu")
    model = CrossLayerTranscoder(config=clt_config, process_group=dist.group.WORLD, device=device)

    # All ranks get the same input
    torch.manual_seed(123)
    sample_inputs = {0: torch.randn(5, clt_config.d_model, device=device)}

    # Forward pass
    reconstructions = model.forward(sample_inputs)

    # Create a loss that depends on the actual output values
    # This will produce different gradients for different weight values
    target = torch.randn_like(reconstructions[0])
    loss = torch.nn.functional.mse_loss(reconstructions[0], target)

    # Backward pass
    loss.backward()

    # Test gradients of a SHARDED parameter (e.g., Encoder weights)
    sharded_grad_optional = model.encoder_module.encoders[0].weight.grad
    assert sharded_grad_optional is not None, "Gradient for sharded parameter should exist"
    sharded_grad = cast(torch.Tensor, sharded_grad_optional)

    # Gather all gradients to compare
    grad_list = [torch.zeros_like(sharded_grad) for _ in range(world_size)]
    dist.all_gather(grad_list, sharded_grad)

    # The gradients for a sharded parameter should be DIFFERENT on each rank
    # because each rank has different weights and computes different outputs
    assert not torch.allclose(
        grad_list[0], grad_list[1], rtol=1e-5, atol=1e-8
    ), "Gradients for sharded parameters should be different across ranks"


# --- Pytest Test Class ---


@pytest.mark.integration
@pytest.mark.distributed
@pytest.mark.skipif(not dist.is_available(), reason="torch.distributed not available")
class TestCLTDistributed:
    def test_forward_pass(self):
        world_size = 2
        mp.spawn(  # type: ignore[attr-defined]
            distributed_test_runner,
            args=(world_size, _test_forward_pass_distributed),
            nprocs=world_size,
            join=True,
        )

    def test_gradient_sharding(self):
        world_size = 2
        mp.spawn(  # type: ignore[attr-defined]
            distributed_test_runner,
            args=(world_size, _test_sharded_gradient),
            nprocs=world_size,
            join=True,
        )
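
A note on how these tests execute: torch.multiprocessing.spawn injects the process index as the first positional argument, so each call above runs distributed_test_runner(rank, world_size, test_fn) for ranks 0 and 1, using the gloo backend on CPU. A standalone sketch of that calling convention with a toy worker (not part of the commit):

import torch.multiprocessing as mp


def _toy_worker(rank: int, world_size: int, message: str) -> None:
    # mp.spawn supplies rank automatically; only (world_size, message) come from args=.
    print(f"[rank {rank}/{world_size}] {message}")


if __name__ == "__main__":
    mp.spawn(_toy_worker, args=(2, "hello"), nprocs=2, join=True)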
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
import pytest
import torch

from clt.config import CLTConfig
from clt.models.clt import CrossLayerTranscoder


def get_available_devices():
    """Returns available devices, including cpu, mps, and cuda if available."""
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if torch.backends.mps.is_available():
        devices.append("mps")
    return devices


DEVICES = get_available_devices()


@pytest.fixture(params=DEVICES)
def device(request):
    """Fixture to iterate over all available devices."""
    return torch.device(request.param)


@pytest.fixture
def clt_config():
    """Provides a basic CLTConfig for end-to-end testing."""
    return CLTConfig(
        num_layers=2,
        d_model=8,
        num_features=16,
        activation_fn="relu",  # Use simple ReLU for gradient checking
    )


@pytest.fixture
def clt_model(clt_config, device):
    """Provides a CrossLayerTranscoder instance for integration tests."""
    model = CrossLayerTranscoder(
        config=clt_config,
        process_group=None,
        device=device,
    )
    # Ensure all parameters have requires_grad=True for the backward pass test
    for param in model.parameters():
        param.requires_grad = True
    return model.to(device)


@pytest.fixture
def sample_inputs(clt_config, device):
    """Provides a sample input dictionary with consistent token counts."""
    total_tokens = 20
    return {
        0: torch.randn(total_tokens, clt_config.d_model, device=device),
        1: torch.randn(total_tokens, clt_config.d_model, device=device),
    }


class TestCLTEndToEnd:
    def test_forward_backward_pass(self, clt_model, sample_inputs):
        """
        Tests a full forward and backward pass to ensure gradients are computed.
        """
        # --- Forward Pass ---
        reconstructions = clt_model.forward(sample_inputs)

        # --- Loss Calculation ---
        # A simple MSE loss between the reconstructions and the original inputs
        loss = torch.tensor(0.0, device=clt_model.device, dtype=torch.float32)
        for layer_idx, recon_tensor in reconstructions.items():
            original_tensor = sample_inputs[layer_idx]
            loss += torch.mean((recon_tensor - original_tensor) ** 2)

        # --- Backward Pass ---
        try:
            loss.backward()
        except Exception as e:
            pytest.fail(f"Backward pass failed with exception: {e}")

        # --- Gradient Check ---
        # Check that some gradients have been computed. We check a few key parameters.
        # Encoder weights for layer 0
        assert clt_model.encoder_module.encoders[0].weight.grad is not None
        assert torch.all(torch.isfinite(clt_model.encoder_module.encoders[0].weight.grad))
        assert not torch.all(clt_model.encoder_module.encoders[0].weight.grad == 0)

        # Decoder weights for 0->1
        decoder_key = "0->1"
        assert clt_model.decoder_module.decoders[decoder_key].weight.grad is not None
        assert torch.all(torch.isfinite(clt_model.decoder_module.decoders[decoder_key].weight.grad))
        assert not torch.all(clt_model.decoder_module.decoders[decoder_key].weight.grad == 0)

        # Decoder bias for 1->1
        decoder_key = "1->1"
        if clt_model.decoder_module.decoders[decoder_key].bias_param is not None:
            assert clt_model.decoder_module.decoders[decoder_key].bias_param.grad is not None
            assert torch.all(torch.isfinite(clt_model.decoder_module.decoders[decoder_key].bias_param.grad))
            # Note: Bias gradients can sometimes be zero in simple cases, so we don't assert non-zero
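
The decoder keys in that test follow a "src->tgt" naming. Assuming the cross-layer transcoder keeps one decoder for every source layer decoding into that layer and each later layer (an assumption, not stated in this diff), the full key set for num_layers=2 could be enumerated like this:

# Assumed convention: decoders exist for every src <= tgt pair.
num_layers = 2
decoder_keys = [f"{src}->{tgt}" for src in range(num_layers) for tgt in range(src, num_layers)]
# -> ["0->0", "0->1", "1->1"]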
