
Commit 89c903d

Merge pull request #32 from ilyalasy/jumprelu-load
Remove jumprelu_threshold validation rule
2 parents: aa73277 + 8470968

2 files changed (+4, -1 lines)

clt/config/clt_config.py
Lines changed: 0 additions & 1 deletion

@@ -40,7 +40,6 @@ def __post_init__(self):
         assert self.num_features > 0, "Number of features must be positive"
         assert self.num_layers > 0, "Number of layers must be positive"
         assert self.d_model > 0, "Model dimension must be positive"
-        assert self.jumprelu_threshold > 0, "JumpReLU threshold must be positive"
         valid_norm_methods = ["auto", "estimated_mean_std", "none"]
         assert (
             self.normalization_method in valid_norm_methods

clt/models/theta.py
Lines changed: 4 additions & 0 deletions

@@ -48,6 +48,10 @@ def __init__(
             self.rank = dist.get_rank(process_group)

         if self.config.activation_fn == "jumprelu":
+            if self.config.jumprelu_threshold == 0:
+                logger.warning(
+                    f"Rank {self.rank}: jumprelu_threshold is 0, expecting to load log_threshold from checkpoint."
+                )
             initial_threshold_val = torch.ones(
                 config.num_layers, config.num_features, device=self.device, dtype=self.dtype
             ) * torch.log(torch.tensor(config.jumprelu_threshold, device=self.device, dtype=self.dtype))
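For context, a minimal, self-contained sketch of why the validation rule was removed and what the new warning refers to. The tensor sizes and the commented checkpoint call are illustrative assumptions, not code from this repository; only the torch calls mirror the initialization shown in the theta.py hunk above.

import torch

# With the assert removed from clt_config.py, jumprelu_threshold may now be 0.
# theta.py then initializes the threshold as ones * log(0) == -inf, which is
# only a placeholder value:
num_layers, num_features = 12, 1024  # illustrative sizes, not from this commit
jumprelu_threshold = 0.0
initial_threshold_val = torch.ones(num_layers, num_features) * torch.log(
    torch.tensor(jumprelu_threshold)
)
print(initial_threshold_val[0, 0])  # tensor(-inf)

# The warning added above flags exactly this case: the -inf placeholder is
# expected to be overwritten by the log_threshold tensor stored in a
# checkpoint, e.g. via an ordinary state-dict load (the file name below is
# an assumption):
# module.load_state_dict(torch.load("checkpoint.pt"))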
