
Commit d9e8a4e

MuonClip draft in optax

1 parent 20cd38d commit d9e8a4e

File tree

1 file changed: +148 −0 lines changed

optax/contrib/_muon.py

Lines changed: 148 additions & 0 deletions
@@ -297,3 +297,151 @@ def muon(
          lambda x: 'muon' if x.ndim == 2 else 'adam', params
      ),
  )


class MuonClipState(NamedTuple):
  eta: list[jax.Array] | None = None


def scale_by_clip_muon() -> base.GradientTransformationExtraArgs:
  """Rescales ("clips") the weight after the gradient update, changing its
  scale factor from eta to eta_bar.

  Optax uses additive updates and the weight is assumed to be stored scaled,
  so the update is:

    g = d(eta * W)  # gradient of the scaled parameter
    eta * W + updates = eta_bar * W_bar = eta_bar * (W + d(W))

    updates = (eta_bar - eta) * W + eta_bar * d(W)

    updates = (eta_bar - eta) * W + eta_bar / eta * d(eta * W)
  """

  def init_fn(params):
    del params
    return MuonClipState()

  def update_fn(updates, state, params=None, eta_bar=None):
    assert eta_bar is not None
    flat_updates, update_struct = jax.tree.flatten(updates)
    flat_params = jax.tree.leaves(params)
    if not len(flat_updates) == len(flat_params) == len(eta_bar):
      raise ValueError("In MuonClip, the lengths of updates, params and"
                       f" eta_bar must be equal, but got {len(flat_updates)=},"
                       f" {len(flat_params)=}, {len(eta_bar)=}.")
    # MuonClipState is an immutable NamedTuple, so on the first step use
    # eta_bar as the previous scale factors instead of assigning to state.eta.
    eta = state.eta if state.eta is not None else eta_bar
    if len(eta) != len(eta_bar):
      raise ValueError("In MuonClip, the lengths of eta_bar and eta must be"
                       f" equal, but got {len(eta_bar)=}, {len(eta)=}.")
    eps = jnp.finfo(flat_params[0].dtype).eps if flat_params else 1e-7

    flat_updates = [
        (e_bar - e) * param + e_bar / jnp.maximum(e, eps) * update
        for e_bar, e, param, update in zip(eta_bar, eta, flat_params,
                                           flat_updates)
    ]
    return (
        jax.tree.unflatten(update_struct, flat_updates),
        MuonClipState(eta=eta_bar),
    )

  return base.GradientTransformationExtraArgs(init_fn, update_fn)
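A minimal sketch of driving scale_by_clip_muon on its own (the parameter
values and the constant eta_bar factor below are made up for illustration; in
training, eta_bar would be recomputed each step and threaded through update
as an extra argument):

    import jax.numpy as jnp
    import optax

    tx = scale_by_clip_muon()
    params = {'q_proj': jnp.ones((4, 4))}
    updates = {'q_proj': jnp.full((4, 4), -0.01)}
    state = tx.init(params)

    # One rescaling factor per flattened parameter leaf.
    updates, state = tx.update(
        updates, state, params, eta_bar=[jnp.asarray(0.9)]
    )
    new_params = optax.apply_updates(params, updates)

On the first step state.eta is None, so eta defaults to eta_bar and the
transform passes the updates through unchanged; the stored eta then tracks
the factor applied on the previous step.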


def _example_qk_proj_param_label_fn(params: base.Params):
  params_with_paths, treedef = jax.tree.flatten_with_path(params)

  def mask_fn(path_param):
    path, param = path_param
    str_path = "/".join(map(str, path))
    if param.ndim == 2 and ("q_proj" in str_path or "k_proj" in str_path):
      return "muon_clip"
    elif param.ndim == 2:
      return "muon"
    else:
      return "adam"

  # mask_fn consumes a (path, leaf) pair, so map over the flat list rather
  # than with jax.tree.map, which would descend into the pairs themselves.
  return jax.tree.unflatten(treedef, [mask_fn(p) for p in params_with_paths])
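For a toy parameter tree, the example labeler would produce the following
(a sketch; the module names here are hypothetical):

    params = {
        'attn': {'q_proj': jnp.zeros((8, 8)), 'k_proj': jnp.zeros((8, 8))},
        'mlp': {'w': jnp.zeros((8, 32)), 'b': jnp.zeros((32,))},
    }
    labels = _example_qk_proj_param_label_fn(params)
    # {'attn': {'q_proj': 'muon_clip', 'k_proj': 'muon_clip'},
    #  'mlp': {'w': 'muon', 'b': 'adam'}}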


def muon_clip(
    learning_rate: base.ScalarOrSchedule,
    ns_coeffs: Union[
        tuple[float, float, float],
        tuple[tuple[float, float, float], ...],
    ] = (3.4445, -4.7750, 2.0315),
    ns_steps: int = 5,
    beta: float = 0.95,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
    weight_decay_mask: Optional[
        Union[Any, Callable[[base.Params], Any]]
    ] = None,
    mu_dtype: Optional[chex.ArrayDType] = None,
    *,
    nesterov: bool = True,
    adaptive: bool = False,
    adam_b1: float = 0.9,
    adam_b2: float = 0.999,
    adam_eps_root: float = 0.0,
    adam_weight_decay: float = 0.0,
    qk_proj_param_label_fn: Callable[[base.Params], Any],
) -> base.GradientTransformationExtraArgs:
  r"""MuonClip: Muon optimizer with q_proj and k_proj clipping.

  TODO
  """

  def param_label_fn(params: base.Params) -> Any:
    mask = qk_proj_param_label_fn(params)
    labels = set(jax.tree.leaves(mask))
    if not all(label in {"muon", "muon_clip", "adam"} for label in labels):
      raise ValueError(
          "qk_proj_param_label_fn must return a mask with labels in"
          " {'muon', 'muon_clip', 'adam'}, but got"
          f" {labels=}."
      )
    return mask

  return combine.partition(
      transforms={
          'muon': combine.chain(
              scale_by_muon(
                  ns_coeffs=ns_coeffs,
                  ns_steps=ns_steps,
                  beta=beta,
                  eps=eps,
                  mu_dtype=mu_dtype,
                  nesterov=nesterov,
                  adaptive=adaptive,
              ),
              transform.add_decayed_weights(weight_decay, weight_decay_mask),
              transform.scale_by_learning_rate(learning_rate),
          ),
          'muon_clip': combine.chain(
              scale_by_muon(
                  ns_coeffs=ns_coeffs,
                  ns_steps=ns_steps,
                  beta=beta,
                  eps=eps,
                  mu_dtype=mu_dtype,
                  nesterov=nesterov,
                  adaptive=adaptive,
              ),
              transform.add_decayed_weights(weight_decay, weight_decay_mask),
              transform.scale_by_learning_rate(learning_rate),
              scale_by_clip_muon(),
          ),
          'adam': alias.adamw(
              learning_rate=learning_rate,
              b1=adam_b1,
              b2=adam_b2,
              eps=eps,
              eps_root=adam_eps_root,
              weight_decay=adam_weight_decay,
              mu_dtype=mu_dtype,
              nesterov=nesterov,
          ),
      },
      param_labels=param_label_fn,
  )
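A hedged end-to-end sketch of the draft API (the loss and parameters are
toys, and the constant eta_bar stands in for a factor that would normally be
derived from the attention logits each step; this also assumes that
combine.partition and combine.chain forward extra update arguments to the
inner transforms):

    import jax
    import jax.numpy as jnp
    import optax

    params = {'q_proj': jnp.ones((4, 4)), 'bias': jnp.zeros((4,))}

    def loss_fn(p):
      return jnp.sum(p['q_proj'] ** 2) + jnp.sum(p['bias'] ** 2)

    opt = muon_clip(
        learning_rate=1e-3,
        qk_proj_param_label_fn=_example_qk_proj_param_label_fn,
    )
    opt_state = opt.init(params)
    grads = jax.grad(loss_fn)(params)
    # One eta_bar entry per parameter labeled 'muon_clip' (here just q_proj).
    updates, opt_state = opt.update(
        grads, opt_state, params, eta_bar=[jnp.asarray(1.0)]
    )
    params = optax.apply_updates(params, updates)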
