@@ -41,46 +41,50 @@ def __init__(
         # Initialize decoders based on tying configuration
         if config.decoder_tying == "per_source":
             # Tied decoders: one decoder per source layer
-            self.decoders = nn.ModuleList([
-                RowParallelLinear(
-                    in_features=self.config.num_features,
-                    out_features=self.config.d_model,
-                    bias=True,
-                    process_group=self.process_group,
-                    input_is_parallel=False,
-                    d_model_for_init=self.config.d_model,
-                    num_layers_for_init=self.config.num_layers,
-                    device=self.device,
-                    dtype=self.dtype,
-                )
-                for _ in range(self.config.num_layers)
-            ])
+            self.decoders = nn.ModuleList(
+                [
+                    RowParallelLinear(
+                        in_features=self.config.num_features,
+                        out_features=self.config.d_model,
+                        bias=True,
+                        process_group=self.process_group,
+                        input_is_parallel=False,
+                        d_model_for_init=self.config.d_model,
+                        num_layers_for_init=self.config.num_layers,
+                        device=self.device,
+                        dtype=self.dtype,
+                    )
+                    for _ in range(self.config.num_layers)
+                ]
+            )
         elif config.decoder_tying == "per_target":
             # Tied decoders: one decoder per target layer (EleutherAI style)
-            self.decoders = nn.ModuleList([
-                RowParallelLinear(
-                    in_features=self.config.num_features,
-                    out_features=self.config.d_model,
-                    bias=True,
-                    process_group=self.process_group,
-                    input_is_parallel=False,
-                    d_model_for_init=self.config.d_model,
-                    num_layers_for_init=self.config.num_layers,
-                    device=self.device,
-                    dtype=self.dtype,
-                )
-                for _ in range(self.config.num_layers)
-            ])
-
+            self.decoders = nn.ModuleList(
+                [
+                    RowParallelLinear(
+                        in_features=self.config.num_features,
+                        out_features=self.config.d_model,
+                        bias=True,
+                        process_group=self.process_group,
+                        input_is_parallel=False,
+                        d_model_for_init=self.config.d_model,
+                        num_layers_for_init=self.config.num_layers,
+                        device=self.device,
+                        dtype=self.dtype,
+                    )
+                    for _ in range(self.config.num_layers)
+                ]
+            )
+
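+            # Both tied variants allocate num_layers decoders; decode() indexes them by source
+            # layer for "per_source" and by target layer for "per_target".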
         # Initialize decoder weights to zeros for tied decoders (both per_source and per_target)
         if config.decoder_tying in ["per_source", "per_target"]:
             for decoder in self.decoders:
                 nn.init.zeros_(decoder.weight)
-                if hasattr(decoder, 'bias_param') and decoder.bias_param is not None:
+                if hasattr(decoder, "bias_param") and decoder.bias_param is not None:
                     nn.init.zeros_(decoder.bias_param)
-                elif hasattr(decoder, 'bias') and decoder.bias is not None:
+                elif hasattr(decoder, "bias") and decoder.bias is not None:
                     nn.init.zeros_(decoder.bias)
-
+
         # Note: EleutherAI doesn't have per-target scale/bias parameters
         # These have been removed to match their architecture exactly
         else:
@@ -103,64 +107,75 @@ def __init__(
                 }
             )
             # Note: EleutherAI doesn't have per-target scale/bias parameters
-
+
         # Initialize skip connection weights if enabled
         if config.skip_connection:
             if config.decoder_tying in ["per_source", "per_target"]:
                 # For tied decoders, one skip connection per target layer
-                self.skip_weights = nn.ParameterList([
-                    nn.Parameter(torch.zeros(self.config.d_model, self.config.d_model,
-                                             device=self.device, dtype=self.dtype))
-                    for _ in range(self.config.num_layers)
-                ])
+                self.skip_weights = nn.ParameterList(
+                    [
+                        nn.Parameter(
+                            torch.zeros(self.config.d_model, self.config.d_model, device=self.device, dtype=self.dtype)
+                        )
+                        for _ in range(self.config.num_layers)
+                    ]
+                )
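+                # Each skip weight is a zero-initialized (d_model, d_model) matrix that decode()
+                # applies as source_input @ W_skip.T.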
             else:
                 # For untied decoders, one skip connection per src->tgt pair
-                self.skip_weights = nn.ParameterDict({
-                    f"{src_layer}->{tgt_layer}": nn.Parameter(
-                        torch.zeros(self.config.d_model, self.config.d_model,
-                                    device=self.device, dtype=self.dtype)
-                    )
-                    for src_layer in range(self.config.num_layers)
-                    for tgt_layer in range(src_layer, self.config.num_layers)
-                })
+                self.skip_weights = nn.ParameterDict(
+                    {
+                        f"{src_layer}->{tgt_layer}": nn.Parameter(
+                            torch.zeros(self.config.d_model, self.config.d_model, device=self.device, dtype=self.dtype)
+                        )
+                        for src_layer in range(self.config.num_layers)
+                        for tgt_layer in range(src_layer, self.config.num_layers)
+                    }
+                )
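+                # nn.ParameterDict only accepts string keys, hence the "src->tgt" key format.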
         else:
             self.skip_weights = None
-
+
         # Initialize feature_offset and feature_scale (indexed by target layer)
         # These match EleutherAI's post_enc and post_enc_scale
         # Note: Currently only implemented for tied decoders to match EleutherAI
         # For per_source tying, these would need to be indexed differently
         if config.decoder_tying in ["per_source", "per_target"]:
             features_per_rank = config.num_features // self.world_size if self.world_size > 1 else config.num_features
-
+
             if config.enable_feature_offset:
                 # Initialize feature_offset for each target layer
-                self.feature_offset = nn.ParameterList([
-                    nn.Parameter(torch.zeros(features_per_rank, device=self.device, dtype=self.dtype))
-                    for _ in range(config.num_layers)
-                ])
+                self.feature_offset = nn.ParameterList(
+                    [
+                        nn.Parameter(torch.zeros(features_per_rank, device=self.device, dtype=self.dtype))
+                        for _ in range(config.num_layers)
+                    ]
+                )
             else:
                 self.feature_offset = None
-
+
             if config.enable_feature_scale:
                 # Initialize feature_scale for each target layer
                 # First target layer gets ones, rest get small non-zero values to allow gradient flow
-                self.feature_scale = nn.ParameterList([
-                    nn.Parameter(
-                        torch.ones(features_per_rank, device=self.device, dtype=self.dtype) if i == 0
-                        else torch.full((features_per_rank,), 0.1, device=self.device, dtype=self.dtype)
-                    )
-                    for i in range(config.num_layers)
-                ])
+                self.feature_scale = nn.ParameterList(
+                    [
+                        nn.Parameter(
+                            torch.ones(features_per_rank, device=self.device, dtype=self.dtype)
+                            if i == 0
+                            else torch.full((features_per_rank,), 0.1, device=self.device, dtype=self.dtype)
+                        )
+                        for i in range(config.num_layers)
+                    ]
+                )
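+                # Note: under tensor parallelism these per-feature vectors cover only this rank's
+                # features_per_rank slice of the feature dimension.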
             else:
                 self.feature_scale = None
         else:
             self.feature_offset = None
             self.feature_scale = None
-
+
         self.register_buffer("_cached_decoder_norms", None, persistent=False)
 
-    def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None) -> torch.Tensor:
+    def decode(
+        self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None
+    ) -> torch.Tensor:
         """Decode the feature activations to reconstruct outputs at the specified layer.
 
         Input activations `a` are expected to be the *full* tensors.
@@ -192,8 +207,10 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
 
         if self.config.decoder_tying == "per_target":
             # EleutherAI style: sum activations first, then decode once
-            summed_activation = torch.zeros((batch_dim_size, self.config.num_features), device=self.device, dtype=self.dtype)
-
+            summed_activation = torch.zeros(
+                (batch_dim_size, self.config.num_features), device=self.device, dtype=self.dtype
+            )
+
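+            # Decoding is linear, so decoding the summed activations matches summing the
+            # per-source decodes, with the decoder bias added exactly once.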
             for src_layer in range(layer_idx + 1):
                 if src_layer in a:
                     activation_tensor = a[src_layer].to(device=self.device, dtype=self.dtype)
@@ -211,48 +228,28 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
                     if self.feature_offset is not None or self.feature_scale is not None:
                         # Get non-zero positions (selected features)
                         nonzero_mask = activation_tensor != 0
-
+
                         if nonzero_mask.any():
                             # Apply transformations only to selected features
                             activation_tensor = activation_tensor.clone()
                             batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True)
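+                            # The clone above keeps the in-place += / *= updates below from
+                            # mutating the caller's activation tensor.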
-
+
                             if self.feature_offset is not None:
                                 # Apply offset only to non-zero features
                                 offset_values = self.feature_offset[layer_idx][feature_indices]
                                 activation_tensor[batch_indices, feature_indices] += offset_values
-
+
                             if self.feature_scale is not None:
                                 # Apply scale only to non-zero features
                                 scale_values = self.feature_scale[layer_idx][feature_indices]
                                 activation_tensor[batch_indices, feature_indices] *= scale_values
-
+
                     summed_activation += activation_tensor
-
+
             # Now decode ONCE with the summed activation
             decoder = self.decoders[layer_idx]
             reconstruction = decoder(summed_activation)
-
-            # Apply skip connections from source inputs if enabled
-            if self.skip_weights is not None and source_inputs is not None:
-                skip_weight = self.skip_weights[layer_idx]
-                # Add skip connections from each source layer that contributed
-                for src_layer in range(layer_idx + 1):
-                    if src_layer in source_inputs:
-                        source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype)
-                        # Flatten if needed
-                        original_shape = source_input.shape
-                        if source_input.dim() == 3:
-                            source_input_2d = source_input.view(-1, source_input.shape[-1])
-                        else:
-                            source_input_2d = source_input
-                        # Apply skip: source @ W_skip^T
-                        skip_contribution = source_input_2d @ skip_weight.T
-                        # Reshape back if needed
-                        if source_input.dim() == 3:
-                            skip_contribution = skip_contribution.view(original_shape)
-                        reconstruction += skip_contribution
-
+
         else:
             # Original logic for per_source and untied decoders
             for src_layer in range(layer_idx + 1):
@@ -271,17 +268,17 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
                 if self.config.decoder_tying == "per_source":
                     # Get non-zero positions (selected features)
                     nonzero_mask = activation_tensor != 0
-
+
                     if nonzero_mask.any():
                         # Apply transformations only to selected features
                         activation_tensor = activation_tensor.clone()
                         batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True)
-
+
                         if self.feature_offset is not None:
                             # Apply offset indexed by target layer
                             offset_values = self.feature_offset[layer_idx][feature_indices]
                             activation_tensor[batch_indices, feature_indices] += offset_values
-
+
                         if self.feature_scale is not None:
                             # Apply scale indexed by target layer
                             scale_values = self.feature_scale[layer_idx][feature_indices]
@@ -291,10 +288,17 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
                     # Use tied decoder for the source layer
                     decoder = self.decoders[src_layer]
                     decoded = decoder(activation_tensor)
-
-                    # Apply skip connection from this source input if enabled
-                    if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs:
-                        skip_weight = self.skip_weights[layer_idx]
+
+                else:
+                    # Use untied decoder for (src, tgt) pair
+                    decoder = self.decoders[f"{src_layer}->{layer_idx}"]
+                    decoded = decoder(activation_tensor)
+
+                # Apply skip connection from this source input if enabled
+                if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs:
+                    skip_key = f"{src_layer}->{layer_idx}"
+                    if skip_key in self.skip_weights:
+                        skip_weight = self.skip_weights[skip_key]
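+                        # Only "src->tgt" keys registered by untied tying reach this branch; the
+                        # tied variants add their single skip term after the loop instead.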
                         source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype)
                         # Flatten if needed
                         original_shape = source_input.shape
@@ -308,32 +312,31 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
                         if source_input.dim() == 3:
                             skip_contribution = skip_contribution.view(original_shape)
                         decoded += skip_contribution
-                else:
-                    # Use untied decoder for (src, tgt) pair
-                    decoder = self.decoders[f"{src_layer}->{layer_idx}"]
-                    decoded = decoder(activation_tensor)
-
-                    # Apply skip connection from this source input if enabled
-                    if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs:
-                        skip_key = f"{src_layer}->{layer_idx}"
-                        if skip_key in self.skip_weights:
-                            skip_weight = self.skip_weights[skip_key]
-                            source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype)
-                            # Flatten if needed
-                            original_shape = source_input.shape
-                            if source_input.dim() == 3:
-                                source_input_2d = source_input.view(-1, source_input.shape[-1])
-                            else:
-                                source_input_2d = source_input
-                            # Apply skip: source @ W_skip^T
-                            skip_contribution = source_input_2d @ skip_weight.T
-                            # Reshape back if needed
-                            if source_input.dim() == 3:
-                                skip_contribution = skip_contribution.view(original_shape)
-                            decoded += skip_contribution
-
+
                 reconstruction += decoded
-
+
+        # For tied decoders, apply a single skip connection from the target layer's own input
+        if self.config.decoder_tying in ["per_source", "per_target"]:
+            if self.skip_weights is not None and source_inputs is not None and layer_idx in source_inputs:
+                skip_weight = self.skip_weights[layer_idx]
+                source_input = source_inputs[layer_idx].to(device=self.device, dtype=self.dtype)
+
+                # Flatten if needed
+                original_shape = source_input.shape
+                if source_input.dim() == 3:
+                    source_input_2d = source_input.view(-1, source_input.shape[-1])
+                else:
+                    source_input_2d = source_input
+
+                # Apply skip: source @ W_skip^T
+                skip_contribution = source_input_2d @ skip_weight.T
+
+                # Reshape back if needed
+                if source_input.dim() == 3:
+                    skip_contribution = skip_contribution.view(original_shape)
+
+                reconstruction += skip_contribution
+
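+        # If source_inputs is None or has no entry for layer_idx, the reconstruction is
+        # returned without any skip contribution.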
         return reconstruction
 
     def get_decoder_norms(self) -> torch.Tensor: