@@ -15,7 +15,7 @@
 """Gradient transformations."""

 import functools
-from typing import NamedTuple, Optional, Union
+from typing import NamedTuple, Optional

 import chex
 import jax
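
> Note: dropping `Union` is safe because the annotations below switch to PEP 604 union syntax (`float | jax.Array`), available from Python 3.10. A minimal sketch of the equivalence:

```python
from typing import Union

import jax

# PEP 604 unions compare equal to their typing.Union spelling (Python >= 3.10).
assert (float | jax.Array) == Union[float, jax.Array]
```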
@@ -44,7 +44,8 @@ class ScaleByRssState(NamedTuple):


 def scale_by_rss(
-    initial_accumulator_value: float = 0.1, eps: float = 1e-7
+    initial_accumulator_value: float | jax.Array = 0.1,
+    eps: float | jax.Array = 1e-7,
 ) -> base.GradientTransformation:
   """Rescale updates by the root of the sum of all squared gradients to date.

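
> Note: widening the annotation from `float` to `float | jax.Array` documents that these hyperparameters may be traced values rather than Python constants baked into the compiled computation. A minimal sketch of what this enables (the `rescale` helper and toy pytrees are illustrative, not part of this change):

```python
import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3)}
grads = {'w': jnp.full((3,), 0.5)}

@jax.jit
def rescale(eps):
  # eps arrives as a traced jax.Array; calling with a new value later
  # reuses the compiled function instead of forcing a retrace.
  tx = optax.scale_by_rss(eps=eps)
  state = tx.init(params)
  updates, _ = tx.update(grads, state)
  return updates

updates = rescale(jnp.asarray(1e-7))
```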
@@ -93,9 +94,9 @@ class ScaleByRmsWithCountState(NamedTuple):


 def scale_by_rms(
-    decay: float = 0.9,
-    eps: float = 1e-8,
-    initial_scale: float = 0.0,
+    decay: float | jax.Array = 0.9,
+    eps: float | jax.Array = 1e-8,
+    initial_scale: float | jax.Array = 0.,
     eps_in_sqrt: bool = True,
     bias_correction: bool = False,
 ) -> base.GradientTransformation:
@@ -168,9 +169,9 @@ class ScaleByRStdDevWithCountState(NamedTuple):


 def scale_by_stddev(
-    decay: float = 0.9,
-    eps: float = 1e-8,
-    initial_scale: float = 0.0,
+    decay: float | jax.Array = 0.9,
+    eps: float | jax.Array = 1e-8,
+    initial_scale: float | jax.Array = 0.,
     eps_in_sqrt: bool = True,
     bias_correction: bool = False,
 ) -> base.GradientTransformation:
@@ -244,10 +245,10 @@ class ScaleByAdamState(NamedTuple):


 def scale_by_adam(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
     mu_dtype: Optional[chex.ArrayDType] = None,
     *,
     nesterov: bool = False,
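
> Note: array-valued hyperparameters pair naturally with `optax.inject_hyperparams`, which keeps the live values in the optimizer state. A minimal sketch (the toy parameter tree is illustrative):

```python
import jax.numpy as jnp
import optax

opt = optax.inject_hyperparams(optax.scale_by_adam)(
    b1=jnp.asarray(0.9), b2=jnp.asarray(0.999))
state = opt.init({'w': jnp.zeros(2)})

# inject_hyperparams stores the current values in state.hyperparams, so
# they can be overwritten between steps without recompiling an update.
state.hyperparams['b1'] = jnp.asarray(0.85)
```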
@@ -317,10 +318,10 @@ class ScaleByAmsgradState(NamedTuple):


 def scale_by_amsgrad(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
     mu_dtype: Optional[chex.ArrayDType] = None,
 ) -> base.GradientTransformation:
   """Rescale updates according to the AMSGrad algorithm.
@@ -373,7 +374,9 @@ def update_fn(updates, state, params=None):


 def scale_by_adamax(
-    b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Adamax algorithm.

@@ -414,8 +417,8 @@ class ScaleByLionState(NamedTuple):


 def scale_by_lion(
-    b1: float = 0.9,
-    b2: float = 0.99,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.99,
     mu_dtype: Optional[chex.ArrayDType] = None,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Lion algorithm.
@@ -451,7 +454,9 @@ def update_fn(updates, state, params=None):
   return base.GradientTransformation(init_fn, update_fn)


-def scale(step_size: float) -> base.GradientTransformation:
+def scale(
+    step_size: float | jax.Array
+) -> base.GradientTransformation:
   """Scale updates by some fixed scalar `step_size`.

   Args:
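
> Note: with `step_size: float | jax.Array`, the scalar can come from a schedule evaluated at a traced step count. A minimal sketch (the schedule choice is illustrative):

```python
import jax.numpy as jnp
import optax

schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1000)
step = jnp.asarray(100)  # could be a traced counter inside a jitted step

# The minus sign turns the scaled gradients into a descent direction.
tx = optax.scale(-schedule(step))
```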
@@ -470,7 +475,7 @@ def update_fn(updates, state, params=None):


 def scale_by_param_block_norm(
-    min_scale: float = 1e-3,
+    min_scale: float | jax.Array = 1e-3
 ) -> base.GradientTransformation:
   """Scale updates for each param block by the norm of that block's parameters.

@@ -496,7 +501,7 @@ def update_fn(updates, state, params):


 def scale_by_param_block_rms(
-    min_scale: float = 1e-3,
+    min_scale: float | jax.Array = 1e-3,
 ) -> base.GradientTransformation:
   """Scale updates by rms of the gradient for each param vector or matrix.

@@ -656,10 +661,10 @@ class ScaleByBeliefState(NamedTuple):


 def scale_by_belief(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-16,
-    eps_root: float = 1e-16,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-16,
+    eps_root: float | jax.Array = 1e-16,
     *,
     nesterov: bool = False,
 ) -> base.GradientTransformation:
@@ -713,11 +718,11 @@ def update_fn(updates, state, params=None):


 def scale_by_yogi(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-3,
-    eps_root: float = 0.0,
-    initial_accumulator_value: float = 1e-6,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-3,
+    eps_root: float | jax.Array = 0.0,
+    initial_accumulator_value: float | jax.Array = 1e-6,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Yogi algorithm.

@@ -767,11 +772,11 @@ def update_fn(updates, state, params=None):


 def scale_by_radam(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
-    threshold: float = 5.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
+    threshold: float | jax.Array = 5.0,
     *,
     nesterov: bool = False,
 ) -> base.GradientTransformation:
@@ -991,9 +996,9 @@ def update_fn(updates, state, params=None):


 def scale_by_trust_ratio(
-    min_norm: float = 0.0,
-    trust_coefficient: float = 1.0,
-    eps: float = 0.0,
+    min_norm: float | jax.Array = 0.0,
+    trust_coefficient: float | jax.Array = 1.,
+    eps: float | jax.Array = 0.,
 ) -> base.GradientTransformation:
   """Scale updates by `trust ratio`.

@@ -1121,7 +1126,9 @@ class ScaleBySM3State(NamedTuple):


 def scale_by_sm3(
-    b1: float = 0.9, b2: float = 1.0, eps: float = 1e-8
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 1.0,
+    eps: float | jax.Array = 1e-8,
 ) -> base.GradientTransformation:
   """Scale updates by `sm3`.

@@ -1198,11 +1205,11 @@ class ScaleByNovogradState(NamedTuple):


 def scale_by_novograd(
-    b1: float = 0.9,
-    b2: float = 0.25,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
-    weight_decay: float = 0.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.25,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
+    weight_decay: float | jax.Array = 0.0,
     mu_dtype: Optional[chex.ArrayDType] = None,
 ) -> base.GradientTransformation:
   """Computes NovoGrad updates.
@@ -1277,7 +1284,8 @@ class ScaleByOptimisticGradientState(NamedTuple):


 def scale_by_optimistic_gradient(
-    alpha: float = 1.0, beta: float = 1.0
+    alpha: float | jax.Array = 1.0,
+    beta: float | jax.Array = 1.0,
 ) -> base.GradientTransformation:
   """Compute generalized optimistic gradients.

@@ -1494,8 +1502,8 @@ def _precondition_by_lbfgs(
     diff_params_memory: chex.ArrayTree,
     diff_updates_memory: chex.ArrayTree,
     weights_memory: chex.Array,
-    identity_scale: Union[float, jax.Array],
-    memory_idx: Union[int, jax.Array],
+    identity_scale: float | jax.Array,
+    memory_idx: int | jax.Array,
 ) -> base.Updates:
   r"""Multiplies updates by an approximation of the inverse Hessian.

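
> Note: `memory_idx` is annotated `int | jax.Array` because loop indices under `jax.lax` control flow arrive as traced 0-d int32 arrays, not Python ints. A minimal sketch of that behavior (the `body` helper is illustrative):

```python
import jax
import jax.numpy as jnp

def body(idx, carry):
  # Under fori_loop, idx is a 0-d int32 jax.Array rather than a Python int.
  return carry + idx

total = jax.lax.fori_loop(0, 5, body, jnp.asarray(0))  # 0+1+2+3+4 == 10
```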