@@ -297,3 +297,151 @@ def muon(
          lambda x: 'muon' if x.ndim == 2 else 'adam', params
      ),
  )
+
+
+class MuonClipState(NamedTuple):
+  # Per-parameter scale factors from the previous update; None before the
+  # first update.
+  eta: list[jax.Array] | None = None
+
+
+def scale_by_clip_muon() -> base.GradientTransformationExtraArgs:
+  """Rescale ("clip") the weight after gradient updates from a factor of eta
+  to a factor of eta_bar.
+
+  Optax uses additive updates, and the weight is assumed to be stored in
+  scaled form, so the update is:
+
+    g = d(eta * W)  # gradient of the scaled parameter
+    eta * W + updates = eta_bar * W_bar = eta_bar * (W + d(W))
+    updates = (eta_bar - eta) * W + eta_bar * d(W)
+    updates = (eta_bar - eta) * W + eta_bar / eta * d(eta * W)
+  """
+
+  def init_fn(params):
+    del params
+    return MuonClipState()
+
+  def update_fn(updates, state, params=None, eta_bar=None):
+    assert eta_bar is not None
+    flat_updates, update_struct = jax.tree.flatten(updates)
+    flat_params = jax.tree.leaves(params)
+    if not len(flat_updates) == len(flat_params) == len(eta_bar):
+      raise ValueError("In MuonClip, the lengths of updates, params and"
+                       f" eta_bar must be equal, but got {len(flat_updates)=},"
+                       f" {len(flat_params)=}, {len(eta_bar)=}.")
+    # Before the first update there is no stored eta, so fall back to eta_bar
+    # (MuonClipState is a NamedTuple and cannot be mutated in place).
+    eta = state.eta if state.eta is not None else eta_bar
+    if len(eta) != len(eta_bar):
+      raise ValueError("In MuonClip, the lengths of eta_bar and eta must be"
+                       f" equal, but got {len(eta_bar)=}, {len(eta)=}.")
+    eps = jnp.finfo(flat_params[0].dtype).eps if flat_params else 1e-7
+
+    # Loop variables are named (eb, e) to avoid shadowing the eta_bar
+    # argument.
+    flat_updates = [
+        (eb - e) * param + eb / jnp.maximum(e, eps) * update
+        for eb, e, param, update in zip(eta_bar, eta, flat_params,
+                                        flat_updates)]
+    return (
+        jax.tree.unflatten(update_struct, flat_updates),
+        MuonClipState(eta=eta_bar),
+    )
+
+  return base.GradientTransformationExtraArgs(init_fn, update_fn)
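+
+# A minimal usage sketch for scale_by_clip_muon on its own (hypothetical
+# values; eta_bar is supplied per update as an extra argument, one entry per
+# parameter leaf):
+#
+#   tx = scale_by_clip_muon()
+#   state = tx.init(params)
+#   updates, state = tx.update(
+#       updates, state, params,
+#       eta_bar=[jnp.asarray(0.9) for _ in jax.tree.leaves(params)])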
+
+
+def _example_qk_proj_param_label_fn(params: base.Params):
+  params_with_paths, treedef = jax.tree.flatten_with_path(params)
+
+  def mask_fn(path_param):
+    path, param = path_param
+    str_path = "/".join(map(str, path))
+    # Label both q_proj and k_proj matrices for clipping, matching the q/k
+    # naming of this example.
+    if param.ndim == 2 and ("q_proj" in str_path or "k_proj" in str_path):
+      return "muon_clip"
+    elif param.ndim == 2:
+      return "muon"
+    else:
+      return "adam"
+
+  # jax.tree.map would recurse into the (path, param) tuples themselves, so
+  # map mask_fn over the flat list directly before unflattening.
+  return jax.tree.unflatten(treedef, list(map(mask_fn, params_with_paths)))
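+
+# For illustration, under an assumed parameter tree (names hypothetical, not
+# part of optax):
+#   {'attn': {'q_proj': f32[64, 64], 'k_proj': f32[64, 64]},
+#    'mlp': f32[64, 256], 'bias': f32[64]}
+# the label fn above returns
+#   {'attn': {'q_proj': 'muon_clip', 'k_proj': 'muon_clip'},
+#    'mlp': 'muon', 'bias': 'adam'}.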
+
+
+def muon_clip(
+    learning_rate: base.ScalarOrSchedule,
+    ns_coeffs: Union[
+        tuple[float, float, float],
+        tuple[tuple[float, float, float], ...],
+    ] = (3.4445, -4.7750, 2.0315),
+    ns_steps: int = 5,
+    beta: float = 0.95,
+    eps: float = 1e-8,
+    weight_decay: float = 0.0,
+    weight_decay_mask: Optional[
+        Union[Any, Callable[[base.Params], Any]]
+    ] = None,
+    mu_dtype: Optional[chex.ArrayDType] = None,
+    *,
+    nesterov: bool = True,
+    adaptive: bool = False,
+    adam_b1: float = 0.9,
+    adam_b2: float = 0.999,
+    adam_eps_root: float = 0.0,
+    adam_weight_decay: float = 0.0,
+    qk_proj_param_label_fn: Callable[[base.Params], Any],
+) -> base.GradientTransformationExtraArgs:
+  r"""MuonClip: Muon optimizer with q_proj and k_proj clipping.
+
+  TODO
+  """
+
+  def param_label_fn(params: base.Params) -> Any:
+    mask = qk_proj_param_label_fn(params)
+    labels = set(jax.tree.leaves(mask))
+    if not labels <= {"muon", "muon_clip", "adam"}:
+      raise ValueError(
+          "qk_proj_param_label_fn must return a mask with labels in"
+          f" {{'muon', 'muon_clip', 'adam'}}, but got {labels=}."
+      )
+    return mask
+
+  # The 'muon' key must match the 'muon' label produced by param_label_fn.
+  return combine.partition(
+      transforms={
+          'muon': combine.chain(
+              scale_by_muon(
+                  ns_coeffs=ns_coeffs,
+                  ns_steps=ns_steps,
+                  beta=beta,
+                  eps=eps,
+                  mu_dtype=mu_dtype,
+                  nesterov=nesterov,
+                  adaptive=adaptive,
+              ),
+              transform.add_decayed_weights(weight_decay, weight_decay_mask),
+              transform.scale_by_learning_rate(learning_rate),
+          ),
+          'muon_clip': combine.chain(
+              scale_by_muon(
+                  ns_coeffs=ns_coeffs,
+                  ns_steps=ns_steps,
+                  beta=beta,
+                  eps=eps,
+                  mu_dtype=mu_dtype,
+                  nesterov=nesterov,
+                  adaptive=adaptive,
+              ),
+              transform.add_decayed_weights(weight_decay, weight_decay_mask),
+              transform.scale_by_learning_rate(learning_rate),
+              scale_by_clip_muon(),
+          ),
+          'adam': alias.adamw(
+              learning_rate=learning_rate,
+              b1=adam_b1,
+              b2=adam_b2,
+              eps=eps,
+              eps_root=adam_eps_root,
+              weight_decay=adam_weight_decay,
+              mu_dtype=mu_dtype,
+              nesterov=nesterov,
+          ),
+      },
+      param_labels=param_label_fn,
+  )
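+
+# A minimal end-to-end usage sketch (hypothetical names and shapes; assumes
+# combine.partition forwards the eta_bar extra argument through to
+# scale_by_clip_muon, with one eta_bar entry per 'muon_clip'-labeled leaf):
+#
+#   opt = muon_clip(
+#       learning_rate=1e-3,
+#       qk_proj_param_label_fn=_example_qk_proj_param_label_fn,
+#   )
+#   opt_state = opt.init(params)
+#   updates, opt_state = opt.update(grads, opt_state, params,
+#                                   eta_bar=eta_bars)
+#   params = optax.apply_updates(params, updates)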