"""Tests for Elastic Weight Consolidation (EWC) module."""

import pytest
import torch
import torch.nn as nn
import numpy as np

from adaptive_classifier import AdaptiveClassifier
from adaptive_classifier.ewc import EWC


@pytest.fixture
def simple_model():
    """Create a simple neural network for testing."""
    class SimpleModel(nn.Module):
        def __init__(self, input_dim=10, num_classes=3):
            super().__init__()
            self.fc = nn.Linear(input_dim, num_classes)

        def forward(self, x):
            return self.fc(x)

    return SimpleModel()


@pytest.fixture
def small_dataset():
    """Create a small dataset for testing."""
    # Create embeddings and labels
    embeddings = torch.randn(33, 10)  # 33 samples to test edge case
    labels = torch.tensor([0, 1, 2] * 11)  # 3 classes repeated
    return torch.utils.data.TensorDataset(embeddings, labels)


def test_ewc_single_batch_edge_case(simple_model, small_dataset):
    """Test EWC with a dataset size that creates a single-sample batch.

    This tests the fix for the squeeze() bug that occurred when
    the last batch had only 1 sample.
    """
    device = 'cpu'

    # This should not raise an error anymore
    ewc = EWC(
        simple_model,
        small_dataset,
        device=device,
        ewc_lambda=100.0
    )

    assert ewc is not None
    assert ewc.fisher_info is not None
    assert ewc.old_params is not None


def test_ewc_various_batch_sizes():
    """Test EWC with various dataset sizes to ensure robustness."""
    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 3)

        def forward(self, x):
            return self.fc(x)

    # Test with different dataset sizes that create different batch scenarios
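    # (Assumed mapping, if the internal DataLoader uses a batch size of 32:
    # 1 is a single-sample dataset, 31 a lone partial batch, 32 exactly one
    # full batch, 33 and 65 leave a trailing one-sample batch, 64 is two
    # full batches, and 100 ends with a partial batch of 4.)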
    test_sizes = [1, 31, 32, 33, 64, 65, 100]  # Various edge cases

    for size in test_sizes:
        model = SimpleModel()
        embeddings = torch.randn(size, 10)
        labels = torch.randint(0, 3, (size,))
        dataset = torch.utils.data.TensorDataset(embeddings, labels)

        # Should not raise any errors
        ewc = EWC(model, dataset, device='cpu', ewc_lambda=100.0)

        # Verify EWC was initialized properly
        assert ewc.fisher_info is not None
        assert len(ewc.fisher_info) > 0

        # Test EWC loss computation
        loss = ewc.ewc_loss(batch_size=32)
        assert loss is not None
        assert loss.item() >= 0  # Loss should be non-negative


def test_adaptive_classifier_with_many_classes():
    """Test AdaptiveClassifier with many classes (simulates Banking77 scenario)."""
    # Set seed for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)

    # Create classifier
    classifier = AdaptiveClassifier('distilbert-base-uncased', device='cpu')

    # Simulate many classes with few examples each
    num_classes = 20
    examples_per_class = 3

    texts = []
    labels = []

    for class_id in range(num_classes):
        class_name = f"class_{class_id}"
        for example_id in range(examples_per_class):
            texts.append(f"This is example {example_id} for {class_name}")
            labels.append(class_name)

    # Add examples in batches (this should trigger EWC when new classes appear)
    batch_size = 10
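    # 20 classes x 3 examples = 60 texts, so batches of 10 keep introducing
    # unseen labels across successive add_examples calls, which should
    # repeatedly exercise the EWC path.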
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        batch_labels = labels[i:i + batch_size]

        # This should not raise any errors
        classifier.add_examples(batch_texts, batch_labels)

    # Verify classifier works
    test_text = "This is a test example"
    predictions = classifier.predict(test_text, k=3)

    assert predictions is not None
    assert len(predictions) <= 3
    assert all(isinstance(p[0], str) for p in predictions)  # Labels are strings
    assert all(isinstance(p[1], float) for p in predictions)  # Scores are floats


def test_ewc_loss_computation(simple_model, small_dataset):
    """Test that EWC loss is computed correctly."""
    device = 'cpu'

    # Initialize EWC
    ewc = EWC(
        simple_model,
        small_dataset,
        device=device,
        ewc_lambda=100.0
    )

    # Modify model parameters slightly
    for param in simple_model.parameters():
        param.data += 0.1

    # Compute EWC loss
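    # (For reference, the standard EWC penalty has the form
    # lambda / 2 * sum_i F_i * (theta_i - theta_i_old)**2, so shifting every
    # parameter by 0.1 away from the stored old_params should make it strictly
    # positive; the exact scale depends on this implementation's normalization.)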
145
+ loss = ewc .ewc_loss ()
146
+
147
+ # Loss should be positive since we changed parameters
148
+ assert loss .item () > 0
149
+
150
+ # Test with batch size normalization
151
+ loss_normalized = ewc .ewc_loss (batch_size = 32 )
152
+ assert loss_normalized .item () > 0
153
+ assert loss_normalized .item () != loss .item () # Should be different due to normalization
154
+
155
+
156
+ def test_progressive_class_addition ():
157
+ """Test adding classes progressively (triggers EWC multiple times)."""
    classifier = AdaptiveClassifier('distilbert-base-uncased', device='cpu')

    # Phase 1: Add initial classes
    phase1_texts = ["Good product", "Bad service", "Average quality"]
    phase1_labels = ["positive", "negative", "neutral"]
    classifier.add_examples(phase1_texts, phase1_labels)

    # Phase 2: Add new classes (should trigger EWC)
    phase2_texts = ["Need help", "Bug report", "Feature request"]
    phase2_labels = ["support", "bug", "feature"]
    classifier.add_examples(phase2_texts, phase2_labels)

    # Phase 3: Add more examples to existing classes
    phase3_texts = ["Excellent!", "Terrible!", "It's okay"]
    phase3_labels = ["positive", "negative", "neutral"]
    classifier.add_examples(phase3_texts, phase3_labels)

    # Phase 4: Add more new classes (should trigger EWC again)
    phase4_texts = ["Urgent issue", "Question about pricing"]
    phase4_labels = ["urgent", "inquiry"]
    classifier.add_examples(phase4_texts, phase4_labels)

    # Verify all classes are learned
    expected_classes = {"positive", "negative", "neutral", "support",
                        "bug", "feature", "urgent", "inquiry"}

    for label in expected_classes:
        assert label in classifier.label_to_id

    # Test prediction
    test_text = "This is wonderful!"
    predictions = classifier.predict(test_text, k=3)
    assert predictions is not None
    assert len(predictions) > 0


def test_ewc_with_empty_batch_edge_case():
    """Test that EWC handles the single-sample edge case gracefully."""
    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(5, 2)

        def forward(self, x):
            return self.fc(x)

    model = TinyModel()

    # Create a tiny dataset
    embeddings = torch.randn(1, 5)  # Single sample
    labels = torch.tensor([0])
    dataset = torch.utils.data.TensorDataset(embeddings, labels)

    # Should handle single sample without errors
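    # (With one example, the Fisher estimate comes from a single
    # forward/backward pass, assuming the implementation averages over
    # whatever batches it receives.)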
    ewc = EWC(model, dataset, device='cpu', ewc_lambda=50.0)

    assert ewc is not None
    loss = ewc.ewc_loss()
    assert loss.item() >= 0