Commit 70bafeb

Merge pull request #50 from codelion/fix-ewc-last-batch-train-bug

Fix EWC last batch train bug

2 parents: cbabf82 + eb080c3

File tree

3 files changed: +218 −2 lines


setup.py

1 addition, 1 deletion

@@ -15,7 +15,7 @@

 setup(
     name="adaptive-classifier",
-    version="0.0.15",
+    version="0.0.16",
     author="codelion",
     author_email="[email protected]",
     description="A flexible, adaptive classification system for dynamic text classification",

src/adaptive_classifier/ewc.py

1 addition, 1 deletion

@@ -78,7 +78,7 @@ def _compute_fisher(
         log_probs = F.log_softmax(outputs, dim=1)

         # Sample from output distribution
-        sampled_labels = torch.multinomial(probs, 1).squeeze()
+        sampled_labels = torch.multinomial(probs, 1).squeeze(-1)

         # Compute loss with sampled labels
         loss = F.nll_loss(log_probs, sampled_labels)
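The one-character change above is the actual bug fix. torch.multinomial(probs, 1) returns a tensor of shape [batch_size, 1], and a bare squeeze() removes every size-1 dimension, so when the last batch of the Fisher-information pass contained a single sample, the [1, 1] result collapsed to a 0-dim scalar and F.nll_loss, which expects a 1-D target for 2-D input, raised an error. squeeze(-1) removes only the sampling dimension. A minimal standalone sketch of the difference (illustration only, not the library code itself):

import torch

probs = torch.softmax(torch.randn(1, 3), dim=1)  # a single-sample batch: shape [1, 3]

# Old behavior: squeeze() drops every size-1 dimension,
# collapsing the [1, 1] sampled-label tensor to a 0-dim scalar.
old = torch.multinomial(probs, 1).squeeze()
print(old.shape)   # torch.Size([])

# Fixed behavior: squeeze(-1) removes only the last (sampling) dimension,
# so the batch dimension survives even when the batch has one sample.
new = torch.multinomial(probs, 1).squeeze(-1)
print(new.shape)   # torch.Size([1])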

tests/test_ewc.py

216 additions, 0 deletions (new file)

@@ -0,0 +1,216 @@

"""Tests for Elastic Weight Consolidation (EWC) module."""

import pytest
import torch
import numpy as np
from adaptive_classifier import AdaptiveClassifier
from adaptive_classifier.ewc import EWC
import torch.nn as nn


@pytest.fixture
def simple_model():
    """Create a simple neural network for testing."""
    class SimpleModel(nn.Module):
        def __init__(self, input_dim=10, num_classes=3):
            super().__init__()
            self.fc = nn.Linear(input_dim, num_classes)

        def forward(self, x):
            return self.fc(x)

    return SimpleModel()


@pytest.fixture
def small_dataset():
    """Create a small dataset for testing."""
    # Create embeddings and labels
    embeddings = torch.randn(33, 10)  # 33 samples to test edge case
    labels = torch.tensor([0, 1, 2] * 11)  # 3 classes repeated
    return torch.utils.data.TensorDataset(embeddings, labels)


def test_ewc_single_batch_edge_case(simple_model, small_dataset):
    """Test EWC with dataset size that creates a single-sample batch.

    This tests the fix for the squeeze() bug that occurred when
    the last batch had only 1 sample.
    """
    device = 'cpu'

    # This should not raise an error anymore
    ewc = EWC(
        simple_model,
        small_dataset,
        device=device,
        ewc_lambda=100.0
    )

    assert ewc is not None
    assert ewc.fisher_info is not None
    assert ewc.old_params is not None


def test_ewc_various_batch_sizes():
    """Test EWC with various dataset sizes to ensure robustness."""
    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 3)

        def forward(self, x):
            return self.fc(x)

    # Test with different dataset sizes that create different batch scenarios
    test_sizes = [1, 31, 32, 33, 64, 65, 100]  # Various edge cases

    for size in test_sizes:
        model = SimpleModel()
        embeddings = torch.randn(size, 10)
        labels = torch.randint(0, 3, (size,))
        dataset = torch.utils.data.TensorDataset(embeddings, labels)

        # Should not raise any errors
        ewc = EWC(model, dataset, device='cpu', ewc_lambda=100.0)

        # Verify EWC was initialized properly
        assert ewc.fisher_info is not None
        assert len(ewc.fisher_info) > 0

        # Test EWC loss computation
        loss = ewc.ewc_loss(batch_size=32)
        assert loss is not None
        assert loss.item() >= 0  # Loss should be non-negative


def test_adaptive_classifier_with_many_classes():
    """Test AdaptiveClassifier with many classes (simulates Banking77 scenario)."""
    # Set seed for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)

    # Create classifier
    classifier = AdaptiveClassifier('distilbert-base-uncased', device='cpu')

    # Simulate many classes with few examples each
    num_classes = 20
    examples_per_class = 3

    texts = []
    labels = []

    for class_id in range(num_classes):
        class_name = f"class_{class_id}"
        for example_id in range(examples_per_class):
            texts.append(f"This is example {example_id} for {class_name}")
            labels.append(class_name)

    # Add examples in batches (this should trigger EWC when new classes appear)
    batch_size = 10
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        batch_labels = labels[i:i+batch_size]

        # This should not raise any errors
        classifier.add_examples(batch_texts, batch_labels)

    # Verify classifier works
    test_text = "This is a test example"
    predictions = classifier.predict(test_text, k=3)

    assert predictions is not None
    assert len(predictions) <= 3
    assert all(isinstance(p[0], str) for p in predictions)  # Labels are strings
    assert all(isinstance(p[1], float) for p in predictions)  # Scores are floats


def test_ewc_loss_computation(simple_model, small_dataset):
    """Test that EWC loss is computed correctly."""
    device = 'cpu'

    # Initialize EWC
    ewc = EWC(
        simple_model,
        small_dataset,
        device=device,
        ewc_lambda=100.0
    )

    # Modify model parameters slightly
    for param in simple_model.parameters():
        param.data += 0.1

    # Compute EWC loss
    loss = ewc.ewc_loss()

    # Loss should be positive since we changed parameters
    assert loss.item() > 0

    # Test with batch size normalization
    loss_normalized = ewc.ewc_loss(batch_size=32)
    assert loss_normalized.item() > 0
    assert loss_normalized.item() != loss.item()  # Should differ due to normalization


def test_progressive_class_addition():
    """Test adding classes progressively (triggers EWC multiple times)."""
    classifier = AdaptiveClassifier('distilbert-base-uncased', device='cpu')

    # Phase 1: Add initial classes
    phase1_texts = ["Good product", "Bad service", "Average quality"]
    phase1_labels = ["positive", "negative", "neutral"]
    classifier.add_examples(phase1_texts, phase1_labels)

    # Phase 2: Add new classes (should trigger EWC)
    phase2_texts = ["Need help", "Bug report", "Feature request"]
    phase2_labels = ["support", "bug", "feature"]
    classifier.add_examples(phase2_texts, phase2_labels)

    # Phase 3: Add more examples to existing classes
    phase3_texts = ["Excellent!", "Terrible!", "It's okay"]
    phase3_labels = ["positive", "negative", "neutral"]
    classifier.add_examples(phase3_texts, phase3_labels)

    # Phase 4: Add more new classes (should trigger EWC again)
    phase4_texts = ["Urgent issue", "Question about pricing"]
    phase4_labels = ["urgent", "inquiry"]
    classifier.add_examples(phase4_texts, phase4_labels)

    # Verify all classes are learned
    expected_classes = {"positive", "negative", "neutral", "support",
                        "bug", "feature", "urgent", "inquiry"}

    for label in expected_classes:
        assert label in classifier.label_to_id

    # Test prediction
    test_text = "This is wonderful!"
    predictions = classifier.predict(test_text, k=3)
    assert predictions is not None
    assert len(predictions) > 0


def test_ewc_with_empty_batch_edge_case():
    """Test EWC handles edge cases gracefully."""
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(5, 2)

        def forward(self, x):
            return self.fc(x)

    model = TinyModel()

    # Create a tiny dataset
    embeddings = torch.randn(1, 5)  # Single sample
    labels = torch.tensor([0])
    dataset = torch.utils.data.TensorDataset(embeddings, labels)

    # Should handle single sample without errors
    ewc = EWC(model, dataset, device='cpu', ewc_lambda=50.0)

    assert ewc is not None
    loss = ewc.ewc_loss()
    assert loss.item() >= 0
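A note on the magic number in these fixtures: 33 samples matter because, assuming the Fisher pass iterates with a DataLoader batch size of 32 (the EWC constructor shown above does not expose this value, so treat the exact size as an assumption), 33 samples split into one full batch of 32 plus a trailing batch of exactly 1, which is the shape that triggered the old squeeze() failure. A quick illustration of that split:

import torch

dataset = torch.utils.data.TensorDataset(
    torch.randn(33, 10),          # 33 embeddings, as in the small_dataset fixture
    torch.randint(0, 3, (33,))    # 33 labels over 3 classes
)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)  # assumed internal batch size
print([x.shape[0] for x, _ in loader])  # [32, 1]: the trailing single-sample batch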
