Changes from all commits

models/hrm/hrm_act_v1.py (13 changes: 9 additions & 4 deletions)
@@ -110,8 +110,8 @@ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
         embed_init_std = 1.0 / self.embed_scale
 
         self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, init_std=embed_init_std, cast_to=self.forward_dtype)
-        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
-        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
+        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
+        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)
 
         self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) # ceil div
         if self.config.puzzle_emb_ndim > 0:

@@ -133,9 +133,14 @@ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None:
         self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)])
         self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)])
 
+        # --- CORRECTED CODE BLOCK ---
         # Initial states
-        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
-        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
+        h_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
+        self.register_buffer('H_init', h_init_tensor)
+
+        l_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1)
+        self.register_buffer('L_init', l_init_tensor)
+        # --- END OF CORRECTION ---
 
         # Q head special init
         # Init Q to (almost) zero for faster learning during bootstrapping
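The change here, repeated in the other two files below, swaps buffers created with the `nn.Buffer` wrapper (an attribute-assignment API that only newer PyTorch releases ship) for the long-standing `Module.register_buffer` call, which older PyTorch versions also support. The two are interchangeable for what matters in this code: inclusion in `state_dict` (controlled by `persistent`) and automatic device/dtype moves with the module. A minimal sketch of those semantics, using a hypothetical `InitStates` module rather than the HRM classes:

```python
# Hypothetical module (not from this PR) illustrating register_buffer semantics.
import torch
import torch.nn as nn


class InitStates(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        # Persistent buffer: saved in state_dict and moved by .to()/.cuda(), but not a Parameter.
        self.register_buffer("H_init", torch.zeros(hidden_size), persistent=True)
        # Non-persistent buffer: still follows device/dtype moves, excluded from state_dict.
        self.register_buffer("cos_cached", torch.zeros(hidden_size), persistent=False)


m = InitStates(8)
assert "H_init" in m.state_dict()
assert "cos_cached" not in m.state_dict()
```

Because the buffer names and `persistent` flags are unchanged, checkpoints written by the original `nn.Buffer` version should load into the `register_buffer` version without any key remapping.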

models/layers.py (9 changes: 6 additions & 3 deletions)
@@ -88,8 +88,11 @@ def __init__(self, dim, max_position_embeddings, base, device=None):

         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.cos_cached = nn.Buffer(emb.cos(), persistent=False)
-        self.sin_cached = nn.Buffer(emb.sin(), persistent=False)
+
+        # --- CORRECTED CODE BLOCK ---
+        self.register_buffer('cos_cached', emb.cos(), persistent=False)
+        self.register_buffer('sin_cached', emb.sin(), persistent=False)
+        # --- END OF CORRECTION ---
 
     def forward(self):
         return self.cos_cached, self.sin_cached
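For context on how these cached tables are consumed: the comment about "a different permutation" refers to the usual rotate-half formulation of rotary embeddings, which splits the head dimension into two halves instead of interleaving even/odd pairs as in the paper. A hedged sketch, assuming a `[batch, seq, heads, head_dim]` layout and helper names that are not taken from `models/layers.py`:

```python
# Illustrative rotate-half RoPE application; names and tensor layout are assumptions.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in half, negate the second half, and swap the halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: [batch, seq, heads, head_dim]; cos/sin: [seq, head_dim], as cached by the layer above.
    return q * cos[None, :, None, :] + rotate_half(q) * sin[None, :, None, :]


q = torch.randn(2, 16, 4, 64)
cos, sin = torch.randn(16, 64), torch.randn(16, 64)  # stand-ins for cos_cached / sin_cached
q_rot = apply_rotary(q, cos, sin)
```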

@@ -142,7 +145,7 @@ def __init__(self, hidden_size: int, expansion: float):
         inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
 
         self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
-        self.down_proj = CastedLinear(inter, hidden_size, bias=False)
+        self.down_proj = CastedLinear(inter, hidden_size, bias=False)
 
     def forward(self, x):
         gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
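This gated-MLP block is apparently touched only by a whitespace-level change, but its sizing line is worth a note: `round(expansion * hidden_size * 2 / 3)` keeps the parameter count of a gated MLP comparable to an ungated one with the same expansion, and `_find_multiple(..., 256)` rounds that width up to a hardware-friendly matmul shape. A self-contained sketch of the same pattern, with plain `nn.Linear` standing in for `CastedLinear` and a SiLU gate assumed for the `forward` that the diff view truncates:

```python
# Standalone gated-MLP (SwiGLU-style) sketch mirroring the gate_up_proj / down_proj pattern.
# The SiLU gating in forward() is an assumption; the diff truncates the original forward.
import torch
import torch.nn as nn
import torch.nn.functional as F


def _find_multiple(n: int, k: int) -> int:
    # Round n up to the nearest multiple of k (assumed behaviour of the repo's helper).
    return ((n + k - 1) // k) * k


class SwiGLUSketch(nn.Module):
    def __init__(self, hidden_size: int, expansion: float):
        super().__init__()
        inter = _find_multiple(round(expansion * hidden_size * 2 / 3), 256)
        self.gate_up_proj = nn.Linear(hidden_size, inter * 2, bias=False)
        self.down_proj = nn.Linear(inter, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
```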

models/sparse_embedding.py (23 changes: 12 additions & 11 deletions)
@@ -13,17 +13,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int, ini
         super().__init__()
         self.cast_to = cast_to
 
+        # --- CORRECTED CODE BLOCK ---
         # Real Weights
         # Truncated LeCun normal init
-        self.weights = nn.Buffer(
-            trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std), persistent=True
-        )
+        weights_tensor = trunc_normal_init_(torch.empty((num_embeddings, embedding_dim)), std=init_std)
+        self.register_buffer('weights', weights_tensor)
 
-        # Local weights and IDs
-        # Local embeddings, with gradient, not persistent
-        self.local_weights = nn.Buffer(torch.zeros(batch_size, embedding_dim, requires_grad=True), persistent=False)
-        # Local embedding IDs, not persistent
-        self.local_ids = nn.Buffer(torch.zeros(batch_size, dtype=torch.int32), persistent=False)
+        local_weights_tensor = torch.zeros(batch_size, embedding_dim, requires_grad=True)
+        self.register_buffer('local_weights', local_weights_tensor, persistent=False)
+
+        local_ids_tensor = torch.zeros(batch_size, dtype=torch.int32)
+        self.register_buffer('local_ids', local_ids_tensor, persistent=False)
+        # --- END OF CORRECTION ---
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         if not self.training:
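In this sparse-embedding module the three buffers play different roles: `weights` is the full, persistent embedding table, while `local_weights` (created with `requires_grad=True`) and `local_ids` are per-batch scratch buffers that the sparse optimizer reads back. A hedged illustration of the gather-into-local-buffer pattern these enable during training; this is a simplified stand-in, not the repository's actual `forward`:

```python
# Simplified, hypothetical sparse-embedding module showing the local-buffer pattern.
import torch
import torch.nn as nn


class SparseEmbeddingSketch(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, batch_size: int):
        super().__init__()
        self.register_buffer("weights", torch.randn(num_embeddings, embedding_dim))
        self.register_buffer("local_weights",
                             torch.zeros(batch_size, embedding_dim, requires_grad=True),
                             persistent=False)
        self.register_buffer("local_ids",
                             torch.zeros(batch_size, dtype=torch.int32),
                             persistent=False)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        if not self.training:
            # Inference: plain lookup, no gradient machinery needed.
            return self.weights[ids]
        # Training: copy the batch's rows into the differentiable local buffer and remember
        # their ids, so an optimizer can later scatter per-row updates back into `weights`.
        with torch.no_grad():
            self.local_weights.copy_(self.weights[ids])
            self.local_ids.copy_(ids)
        return self.local_weights
```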

@@ -81,7 +82,7 @@ def step(self, closure=None): # type: ignore
         assert local_weights_grad is not None
         assert local_ids is not None
         assert weights is not None
-
+
         # Apply SignSGD
         # Adam ≈ SignSGD if gradient is very sparse
         _sparse_emb_signsgd_dist(

@@ -112,10 +113,10 @@ def _sparse_emb_signsgd_dist(

     if world_size > 1:
         all_weights_grad = torch.empty((world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device)
-        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
+        all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
 
         dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
-        dist.all_gather_into_tensor(all_ids, local_ids)
+        dist.all_gather_into_tensor(all_ids, local_ids)
 
     # Unique
     grad_ids, inv = all_ids.unique(return_inverse=True)
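Finally, the SignSGD path: with very sparse gradients each embedding row is touched only occasionally, so Adam's first-moment/second-moment ratio degenerates to roughly the sign of the current gradient, which is why the comment treats Adam and SignSGD as interchangeable here. A hedged, single-process sketch of the per-row update that follows the `all_gather`/`unique` step shown above (the exact learning-rate handling and any weight decay in the real function are not shown in this diff, so details beyond it are assumptions):

```python
# Single-process sketch of a sparse SignSGD row update; details beyond the diff are assumptions.
import torch


def sparse_sign_sgd_(weights: torch.Tensor,
                     local_weights_grad: torch.Tensor,
                     local_ids: torch.Tensor,
                     lr: float) -> None:
    # Deduplicate the gathered ids and sum the gradients of rows that appear more than once.
    grad_ids, inv = local_ids.unique(return_inverse=True)
    grad_sum = torch.zeros(grad_ids.numel(), weights.shape[1],
                           dtype=local_weights_grad.dtype, device=weights.device)
    grad_sum.index_add_(0, inv, local_weights_grad)
    # SignSGD: step each touched row by lr against the sign of its summed gradient.
    weights[grad_ids.to(torch.long)] -= lr * grad_sum.sign()


# Example: 4 gradient rows hitting 3 distinct embedding ids.
weights = torch.zeros(10, 8)
grads = torch.randn(4, 8)
ids = torch.tensor([2, 7, 2, 5], dtype=torch.int32)
sparse_sign_sgd_(weights, grads, ids, lr=1e-2)
```

In the multi-GPU branch above, `all_gather_into_tensor` first concatenates every rank's `local_weights_grad` and `local_ids`, so the same deduplicate-and-sign step runs over the union of all ranks' rows.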