
Commit c9fabf9

Curt Tigges authored and committed
corrected inputs to skip layers
1 parent 0909d36 commit c9fabf9

File tree

  clt/config/clt_config.py
  clt/models/decoder.py

2 files changed (+127, -122 lines)

clt/config/clt_config.py

Lines changed: 2 additions & 0 deletions
@@ -182,6 +182,8 @@ class TrainingConfig:
     optimizer: Literal["adam", "adamw"] = "adamw"
     optimizer_beta1: Optional[float] = None  # Beta1 for Adam/AdamW (default: 0.9)
     optimizer_beta2: Optional[float] = None  # Beta2 for Adam/AdamW (default: 0.999)
+    optimizer_states_dtype: Literal["fp32", "model_dtype"] = "model_dtype"  # Dtype for optimizer states
+    enable_stochastic_rounding: bool = False  # Enable stochastic rounding for bf16 (requires optimizer_states_dtype="fp32")
     # Learning rate scheduler type. "linear_final20" keeps LR constant for the first 80% of
     # training and then linearly decays it to 0 for the final 20% (configurable via lr_scheduler_params).
     lr_scheduler: Optional[Literal["linear", "cosine", "linear_final20"]] = "linear"
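
For reference, a minimal sketch of how the two new TrainingConfig fields might be used together. This is illustrative, not part of the commit: it assumes the remaining TrainingConfig fields have defaults and that the class imports from the module path shown (inferred from the file path, not confirmed by the diff).

    from clt.config.clt_config import TrainingConfig  # assumed import path

    # Per the field comment above, stochastic rounding for bf16 training is
    # only valid when the optimizer states are kept in fp32.
    config = TrainingConfig(
        optimizer="adamw",
        optimizer_states_dtype="fp32",
        enable_stochastic_rounding=True,
    )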

clt/models/decoder.py

Lines changed: 125 additions & 122 deletions
@@ -41,46 +41,50 @@ def __init__(
         # Initialize decoders based on tying configuration
         if config.decoder_tying == "per_source":
             # Tied decoders: one decoder per source layer
-            self.decoders = nn.ModuleList([
-                RowParallelLinear(
-                    in_features=self.config.num_features,
-                    out_features=self.config.d_model,
-                    bias=True,
-                    process_group=self.process_group,
-                    input_is_parallel=False,
-                    d_model_for_init=self.config.d_model,
-                    num_layers_for_init=self.config.num_layers,
-                    device=self.device,
-                    dtype=self.dtype,
-                )
-                for _ in range(self.config.num_layers)
-            ])
+            self.decoders = nn.ModuleList(
+                [
+                    RowParallelLinear(
+                        in_features=self.config.num_features,
+                        out_features=self.config.d_model,
+                        bias=True,
+                        process_group=self.process_group,
+                        input_is_parallel=False,
+                        d_model_for_init=self.config.d_model,
+                        num_layers_for_init=self.config.num_layers,
+                        device=self.device,
+                        dtype=self.dtype,
+                    )
+                    for _ in range(self.config.num_layers)
+                ]
+            )
         elif config.decoder_tying == "per_target":
             # Tied decoders: one decoder per target layer (EleutherAI style)
-            self.decoders = nn.ModuleList([
-                RowParallelLinear(
-                    in_features=self.config.num_features,
-                    out_features=self.config.d_model,
-                    bias=True,
-                    process_group=self.process_group,
-                    input_is_parallel=False,
-                    d_model_for_init=self.config.d_model,
-                    num_layers_for_init=self.config.num_layers,
-                    device=self.device,
-                    dtype=self.dtype,
-                )
-                for _ in range(self.config.num_layers)
-            ])
-
+            self.decoders = nn.ModuleList(
+                [
+                    RowParallelLinear(
+                        in_features=self.config.num_features,
+                        out_features=self.config.d_model,
+                        bias=True,
+                        process_group=self.process_group,
+                        input_is_parallel=False,
+                        d_model_for_init=self.config.d_model,
+                        num_layers_for_init=self.config.num_layers,
+                        device=self.device,
+                        dtype=self.dtype,
+                    )
+                    for _ in range(self.config.num_layers)
+                ]
+            )
+
         # Initialize decoder weights to zeros for tied decoders (both per_source and per_target)
         if config.decoder_tying in ["per_source", "per_target"]:
             for decoder in self.decoders:
                 nn.init.zeros_(decoder.weight)
-                if hasattr(decoder, 'bias_param') and decoder.bias_param is not None:
+                if hasattr(decoder, "bias_param") and decoder.bias_param is not None:
                     nn.init.zeros_(decoder.bias_param)
-                elif hasattr(decoder, 'bias') and decoder.bias is not None:
+                elif hasattr(decoder, "bias") and decoder.bias is not None:
                     nn.init.zeros_(decoder.bias)
-
+
         # Note: EleutherAI doesn't have per-target scale/bias parameters
         # These have been removed to match their architecture exactly
         else:
@@ -103,64 +107,75 @@ def __init__(
                 }
             )
             # Note: EleutherAI doesn't have per-target scale/bias parameters
-
+
         # Initialize skip connection weights if enabled
         if config.skip_connection:
             if config.decoder_tying in ["per_source", "per_target"]:
                 # For tied decoders, one skip connection per target layer
-                self.skip_weights = nn.ParameterList([
-                    nn.Parameter(torch.zeros(self.config.d_model, self.config.d_model,
-                                             device=self.device, dtype=self.dtype))
-                    for _ in range(self.config.num_layers)
-                ])
+                self.skip_weights = nn.ParameterList(
+                    [
+                        nn.Parameter(
+                            torch.zeros(self.config.d_model, self.config.d_model, device=self.device, dtype=self.dtype)
+                        )
+                        for _ in range(self.config.num_layers)
+                    ]
+                )
             else:
                 # For untied decoders, one skip connection per src->tgt pair
-                self.skip_weights = nn.ParameterDict({
-                    f"{src_layer}->{tgt_layer}": nn.Parameter(
-                        torch.zeros(self.config.d_model, self.config.d_model,
-                                    device=self.device, dtype=self.dtype)
-                    )
-                    for src_layer in range(self.config.num_layers)
-                    for tgt_layer in range(src_layer, self.config.num_layers)
-                })
+                self.skip_weights = nn.ParameterDict(
+                    {
+                        f"{src_layer}->{tgt_layer}": nn.Parameter(
+                            torch.zeros(self.config.d_model, self.config.d_model, device=self.device, dtype=self.dtype)
+                        )
+                        for src_layer in range(self.config.num_layers)
+                        for tgt_layer in range(src_layer, self.config.num_layers)
+                    }
+                )
         else:
             self.skip_weights = None
-
+
         # Initialize feature_offset and feature_scale (indexed by target layer)
         # These match EleutherAI's post_enc and post_enc_scale
         # Note: Currently only implemented for tied decoders to match EleutherAI
         # For per_source tying, these would need to be indexed differently
         if config.decoder_tying in ["per_source", "per_target"]:
             features_per_rank = config.num_features // self.world_size if self.world_size > 1 else config.num_features
-
+
             if config.enable_feature_offset:
                 # Initialize feature_offset for each target layer
-                self.feature_offset = nn.ParameterList([
-                    nn.Parameter(torch.zeros(features_per_rank, device=self.device, dtype=self.dtype))
-                    for _ in range(config.num_layers)
-                ])
+                self.feature_offset = nn.ParameterList(
+                    [
+                        nn.Parameter(torch.zeros(features_per_rank, device=self.device, dtype=self.dtype))
+                        for _ in range(config.num_layers)
+                    ]
+                )
             else:
                 self.feature_offset = None
-
+
             if config.enable_feature_scale:
                 # Initialize feature_scale for each target layer
                 # First target layer gets ones, rest get small non-zero values to allow gradient flow
-                self.feature_scale = nn.ParameterList([
-                    nn.Parameter(
-                        torch.ones(features_per_rank, device=self.device, dtype=self.dtype) if i == 0
-                        else torch.full((features_per_rank,), 0.1, device=self.device, dtype=self.dtype)
-                    )
-                    for i in range(config.num_layers)
-                ])
+                self.feature_scale = nn.ParameterList(
+                    [
+                        nn.Parameter(
+                            torch.ones(features_per_rank, device=self.device, dtype=self.dtype)
+                            if i == 0
+                            else torch.full((features_per_rank,), 0.1, device=self.device, dtype=self.dtype)
+                        )
+                        for i in range(config.num_layers)
+                    ]
+                )
             else:
                 self.feature_scale = None
         else:
             self.feature_offset = None
             self.feature_scale = None
-
+
         self.register_buffer("_cached_decoder_norms", None, persistent=False)
 
-    def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None) -> torch.Tensor:
+    def decode(
+        self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Optional[Dict[int, torch.Tensor]] = None
+    ) -> torch.Tensor:
         """Decode the feature activations to reconstruct outputs at the specified layer.
 
         Input activations `a` are expected to be the *full* tensors.
@@ -192,8 +207,10 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
 
         if self.config.decoder_tying == "per_target":
             # EleutherAI style: sum activations first, then decode once
-            summed_activation = torch.zeros((batch_dim_size, self.config.num_features), device=self.device, dtype=self.dtype)
-
+            summed_activation = torch.zeros(
+                (batch_dim_size, self.config.num_features), device=self.device, dtype=self.dtype
+            )
+
             for src_layer in range(layer_idx + 1):
                 if src_layer in a:
                     activation_tensor = a[src_layer].to(device=self.device, dtype=self.dtype)
@@ -211,48 +228,28 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
                     if self.feature_offset is not None or self.feature_scale is not None:
                         # Get non-zero positions (selected features)
                         nonzero_mask = activation_tensor != 0
-
+
                         if nonzero_mask.any():
                             # Apply transformations only to selected features
                             activation_tensor = activation_tensor.clone()
                             batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True)
-
+
                             if self.feature_offset is not None:
                                 # Apply offset only to non-zero features
                                 offset_values = self.feature_offset[layer_idx][feature_indices]
                                 activation_tensor[batch_indices, feature_indices] += offset_values
-
+
                             if self.feature_scale is not None:
                                 # Apply scale only to non-zero features
                                 scale_values = self.feature_scale[layer_idx][feature_indices]
                                 activation_tensor[batch_indices, feature_indices] *= scale_values
-
+
                     summed_activation += activation_tensor
-
+
             # Now decode ONCE with the summed activation
             decoder = self.decoders[layer_idx]
             reconstruction = decoder(summed_activation)
-
-            # Apply skip connections from source inputs if enabled
-            if self.skip_weights is not None and source_inputs is not None:
-                skip_weight = self.skip_weights[layer_idx]
-                # Add skip connections from each source layer that contributed
-                for src_layer in range(layer_idx + 1):
-                    if src_layer in source_inputs:
-                        source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype)
-                        # Flatten if needed
-                        original_shape = source_input.shape
-                        if source_input.dim() == 3:
-                            source_input_2d = source_input.view(-1, source_input.shape[-1])
-                        else:
-                            source_input_2d = source_input
-                        # Apply skip: source @ W_skip^T
-                        skip_contribution = source_input_2d @ skip_weight.T
-                        # Reshape back if needed
-                        if source_input.dim() == 3:
-                            skip_contribution = skip_contribution.view(original_shape)
-                        reconstruction += skip_contribution
-
+
         else:
             # Original logic for per_source and untied decoders
             for src_layer in range(layer_idx + 1):
@@ -271,17 +268,17 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
                 if self.config.decoder_tying == "per_source":
                     # Get non-zero positions (selected features)
                     nonzero_mask = activation_tensor != 0
-
+
                     if nonzero_mask.any():
                         # Apply transformations only to selected features
                         activation_tensor = activation_tensor.clone()
                         batch_indices, feature_indices = nonzero_mask.nonzero(as_tuple=True)
-
+
                         if self.feature_offset is not None:
                             # Apply offset indexed by target layer
                             offset_values = self.feature_offset[layer_idx][feature_indices]
                             activation_tensor[batch_indices, feature_indices] += offset_values
-
+
                         if self.feature_scale is not None:
                             # Apply scale indexed by target layer
                             scale_values = self.feature_scale[layer_idx][feature_indices]
@@ -291,10 +288,17 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
                     # Use tied decoder for the source layer
                     decoder = self.decoders[src_layer]
                     decoded = decoder(activation_tensor)
-
-                    # Apply skip connection from this source input if enabled
-                    if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs:
-                        skip_weight = self.skip_weights[layer_idx]
+
+                else:
+                    # Use untied decoder for (src, tgt) pair
+                    decoder = self.decoders[f"{src_layer}->{layer_idx}"]
+                    decoded = decoder(activation_tensor)
+
+                # Apply skip connection from this source input if enabled
+                if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs:
+                    skip_key = f"{src_layer}->{layer_idx}"
+                    if skip_key in self.skip_weights:
+                        skip_weight = self.skip_weights[skip_key]
                         source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype)
                         # Flatten if needed
                         original_shape = source_input.shape
@@ -308,32 +312,31 @@ def decode(self, a: Dict[int, torch.Tensor], layer_idx: int, source_inputs: Opti
                         if source_input.dim() == 3:
                             skip_contribution = skip_contribution.view(original_shape)
                         decoded += skip_contribution
-                else:
-                    # Use untied decoder for (src, tgt) pair
-                    decoder = self.decoders[f"{src_layer}->{layer_idx}"]
-                    decoded = decoder(activation_tensor)
-
-                    # Apply skip connection from this source input if enabled
-                    if self.skip_weights is not None and source_inputs is not None and src_layer in source_inputs:
-                        skip_key = f"{src_layer}->{layer_idx}"
-                        if skip_key in self.skip_weights:
-                            skip_weight = self.skip_weights[skip_key]
-                            source_input = source_inputs[src_layer].to(device=self.device, dtype=self.dtype)
-                            # Flatten if needed
-                            original_shape = source_input.shape
-                            if source_input.dim() == 3:
-                                source_input_2d = source_input.view(-1, source_input.shape[-1])
-                            else:
-                                source_input_2d = source_input
-                            # Apply skip: source @ W_skip^T
-                            skip_contribution = source_input_2d @ skip_weight.T
-                            # Reshape back if needed
-                            if source_input.dim() == 3:
-                                skip_contribution = skip_contribution.view(original_shape)
-                            decoded += skip_contribution
-
+
                 reconstruction += decoded
-
+
+        # For tied decoders, apply a single skip connection from the target layer's own input
+        if self.config.decoder_tying in ["per_source", "per_target"]:
+            if self.skip_weights is not None and source_inputs is not None and layer_idx in source_inputs:
+                skip_weight = self.skip_weights[layer_idx]
+                source_input = source_inputs[layer_idx].to(device=self.device, dtype=self.dtype)
+
+                # Flatten if needed
+                original_shape = source_input.shape
+                if source_input.dim() == 3:
+                    source_input_2d = source_input.view(-1, source_input.shape[-1])
+                else:
+                    source_input_2d = source_input
+
+                # Apply skip: source @ W_skip^T
+                skip_contribution = source_input_2d @ skip_weight.T
+
+                # Reshape back if needed
+                if source_input.dim() == 3:
+                    skip_contribution = skip_contribution.view(original_shape)
+
+                reconstruction += skip_contribution
+
         return reconstruction
 
     def get_decoder_norms(self) -> torch.Tensor:
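
The behavioral change in decode is in the skip-connection handling: for tied decoders (per_source or per_target), a skip term was previously added inside the decode loop for every contributing source layer, all using the target layer's skip weight; after this commit, exactly one skip term is applied, from the target layer's own input. A runnable sketch contrasting the two semantics with toy sizes (W_skip and source_inputs are illustrative stand-ins for self.skip_weights and the decode argument; the shapes are not from the commit):

    import torch

    batch, d_model, num_layers, layer_idx = 4, 8, 3, 2
    W_skip = [torch.randn(d_model, d_model) for _ in range(num_layers)]  # one skip weight per layer
    source_inputs = {i: torch.randn(batch, d_model) for i in range(num_layers)}

    # Before: a skip term was summed for every source layer s <= layer_idx,
    # all through the target layer's skip weight.
    old_skip = sum(source_inputs[s] @ W_skip[layer_idx].T for s in range(layer_idx + 1))

    # After this commit: a single skip term, from the target layer's own input.
    new_skip = source_inputs[layer_idx] @ W_skip[layer_idx].T

Untied decoders are unchanged in effect: each (src, tgt) pair still applies its own skip via the f"{src_layer}->{layer_idx}" key, now in a branch shared with per_source.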
