

rdyro
Collaborator

@rdyro rdyro commented Jul 15, 2025

No description provided.

def mask_fn(path_param):
    path, param = path_param
    str_path = "/".join(map(str, path))
    if param.ndim == 2 and "q_proj" in str_path:

Shouldn't you also use muon_clip for "k_proj"?

Collaborator Author


oh, right, definitely
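
For reference, a minimal sketch of a mask that selects both projections, assuming parameters live in a nested dict whose key paths contain "q_proj" or "k_proj". The helper name `qk_mask` and the toy `params` tree are hypothetical and not taken from this PR:

```python
import jax
import jax.numpy as jnp


def qk_mask(params):
    """Return a pytree of booleans marking 2-D q_proj / k_proj weights.

    Hypothetical helper for illustration; not the PR's actual code.
    """
    leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(params)

    def mask_fn(path_param):
        path, param = path_param
        str_path = "/".join(map(str, path))
        # Match both query and key projections, as suggested in the review.
        is_qk = "q_proj" in str_path or "k_proj" in str_path
        return param.ndim == 2 and is_qk

    # Apply mask_fn to each (path, leaf) pair, then rebuild the original tree.
    return jax.tree_util.tree_unflatten(
        treedef, [mask_fn(pp) for pp in leaves_with_paths]
    )


# Toy parameter tree (assumed layout, for illustration only).
params = {
    "attn": {
        "q_proj": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros((4,))},
        "k_proj": {"kernel": jnp.ones((4, 4))},
        "v_proj": {"kernel": jnp.ones((4, 4))},
    }
}
mask = qk_mask(params)
```

Only the 2-D `q_proj` and `k_proj` kernels are marked here; biases and `v_proj` stay unmasked.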

return jax.tree.unflatten(treedef, jax.tree.map(mask_fn, params_with_paths))


def muon_clip(

should this also take alpha param and be used as:

Query scaling factor: eta^alpha
Key scaling factor: eta^(1 - alpha)

Collaborator Author


eta here is just "a" scalar, so perhaps the user can pass the correctly exponentiated eta
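
A small sketch of what passing "the correctly exponentiated eta" could look like on the caller's side. The helper `split_eta` is hypothetical and not part of this PR. It relies only on the identity eta^alpha * eta^(1 - alpha) = eta, so the combined rescaling of the query-key product is still eta:

```python
import math


def split_eta(eta: float, alpha: float = 0.5) -> tuple[float, float]:
    """Split a scalar eta into query and key scaling factors.

    Hypothetical helper: the caller would pass q_scale for the query
    weights and k_scale for the key weights as the scalar eta.
    """
    q_scale = eta ** alpha          # query scaling factor: eta^alpha
    k_scale = eta ** (1.0 - alpha)  # key scaling factor: eta^(1 - alpha)
    return q_scale, k_scale


q_scale, k_scale = split_eta(0.25, alpha=0.5)
# The product of the two factors recovers the original eta.
assert math.isclose(q_scale * k_scale, 0.25)
```

With alpha = 0.5 the rescaling is shared equally between queries and keys; other values of alpha shift more of it onto one side.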


2 participants