Skip to content

Commit ba9145c

Browse files
author
Curt Tigges
committed
added new unit tests for activations
1 parent 39148c1 commit ba9145c

17 files changed

+824
-4610
lines changed

.github/workflows/python-tests.yml

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
name: Python Tests
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
- develop # Or your primary development branch
8+
pull_request:
9+
branches:
10+
- main
11+
12+
jobs:
13+
test:
14+
runs-on: ubuntu-latest
15+
strategy:
16+
matrix:
17+
python-version: ["3.9", "3.10", "3.11"] # Specify python versions
18+
19+
steps:
20+
- uses: actions/checkout@v4
21+
22+
- name: Set up Python ${{ matrix.python-version }}
23+
uses: actions/setup-python@v4
24+
with:
25+
python-version: ${{ matrix.python-version }}
26+
27+
- name: Install Poetry
28+
run: |
29+
curl -sSL https://install.python-poetry.org | python3 -
30+
echo "$HOME/.local/bin" >> $GITHUB_PATH
31+
# Alternatively, if not using Poetry, or have a requirements.txt:
32+
# run: pip install -r requirements.txt
33+
34+
- name: Install dependencies
35+
run: poetry install --no-interaction --no-root
36+
# If you have dev dependencies for pytest, e.g. in a [tool.poetry.group.dev.dependencies]
37+
# run: poetry install --no-interaction --no-root --with dev
38+
# Or if using pip with requirements.txt:
39+
# run: pip install -r requirements-dev.txt # (if you have a separate dev requirements)
40+
# run: pip install pytest # or ensure pytest is in your main requirements
41+
42+
- name: Run tests with pytest
43+
run: poetry run pytest tests/
44+
# Or if not using poetry:
45+
# run: pytest tests/

clt/models/activations.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,12 +197,24 @@ def backward(ctx, *grad_outputs: torch.Tensor) -> Tuple[Optional[torch.Tensor],
197197
grad_threshold_per_element = grad_output * local_grad_theta
198198

199199
if grad_threshold_per_element.dim() > threshold.dim():
200+
# Handles cases like input (B,F), threshold (F) or input (F), threshold (scalar)
200201
dims_to_sum = tuple(range(grad_threshold_per_element.dim() - threshold.dim()))
201202
grad_threshold = grad_threshold_per_element.sum(dim=dims_to_sum)
202-
if threshold.shape != torch.Size([]):
203+
# Ensure final shape matches threshold, especially if sum squeezed dimensions
204+
if grad_threshold.shape != threshold.shape:
203205
grad_threshold = grad_threshold.reshape(threshold.shape)
204-
else:
206+
elif grad_threshold_per_element.dim() == threshold.dim():
207+
# Handles cases like input (F), threshold (F), or input [1], threshold [1]
208+
grad_threshold = grad_threshold_per_element
209+
# Defensive reshape, though shapes should ideally match here.
210+
if grad_threshold.shape != threshold.shape:
211+
grad_threshold = grad_threshold.reshape(threshold.shape)
212+
else: # grad_threshold_per_element.dim() < threshold.dim()
213+
# This case is less common (e.g. input scalar, threshold vector - not typical for this op).
214+
# Defaulting to sum and reshape, primarily for scalar threshold case.
205215
grad_threshold = grad_threshold_per_element.sum()
216+
if grad_threshold.shape != threshold.shape:
217+
grad_threshold = grad_threshold.reshape(threshold.shape)
206218
return grad_input, grad_threshold, None
207219

208220

clt/training/trainer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,8 +580,6 @@ def train(self, eval_every: int = 1000) -> CrossLayerTranscoder:
580580
tok_cnt_t = torch.tensor([tok_cnt], device=self.device)
581581
gathered = [torch.zeros_like(tok_cnt_t) for _ in range(self.world_size)]
582582
dist.all_gather(gathered, tok_cnt_t)
583-
if self.rank == 0:
584-
print("Batch token-count per rank:", [int(x.item()) for x in gathered])
585583

586584
except StopIteration:
587585
# Rank 0 prints message

tests/integration/test_activation_store.py

Lines changed: 0 additions & 200 deletions
This file was deleted.

0 commit comments

Comments
 (0)