Commit bcf5cbf

Allow gradient transform parameters to be dynamic
1 parent cd1e045 commit bcf5cbf

File tree

1 file changed: +56 -48 lines changed


optax/_src/transform.py

Lines changed: 56 additions & 48 deletions
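The diff below widens a number of hyperparameter annotations from `float` to `float | jax.Array`, so values such as b1, b2, or eps can be supplied as JAX arrays (for example, values computed or scheduled on the fly) rather than only as static Python floats. The following is a minimal usage sketch, not part of the commit; it assumes the public optax API (optax.scale_by_adam and the GradientTransformation init/update methods), and the toy parameters and gradients are purely illustrative.

import jax.numpy as jnp
import optax

# Toy parameters and gradients, for illustration only.
params = {'w': jnp.ones((3,))}
grads = {'w': jnp.full((3,), 0.5)}

# Hyperparameters passed as jax.Array values instead of Python floats,
# which is the kind of "dynamic" parameter the widened annotations cover.
b1 = jnp.asarray(0.9)
eps = jnp.asarray(1e-8)

tx = optax.scale_by_adam(b1=b1, eps=eps)
state = tx.init(params)
updates, state = tx.update(grads, state, params)

As far as the visible hunks go, the changes are confined to these annotations and the accompanying signature reflows.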
@@ -15,7 +15,7 @@
 """Gradient transformations."""
 
 import functools
-from typing import NamedTuple, Optional, Union
+from typing import NamedTuple, Optional
 
 import chex
 import jax
@@ -44,7 +44,8 @@ class ScaleByRssState(NamedTuple):
 
 
 def scale_by_rss(
-    initial_accumulator_value: float = 0.1, eps: float = 1e-7
+    initial_accumulator_value: float | jax.Array = 0.1,
+    eps: float | jax.Array = 1e-7,
 ) -> base.GradientTransformation:
   """Rescale updates by the root of the sum of all squared gradients to date.
 
@@ -93,9 +94,9 @@ class ScaleByRmsWithCountState(NamedTuple):
 
 
 def scale_by_rms(
-    decay: float = 0.9,
-    eps: float = 1e-8,
-    initial_scale: float = 0.0,
+    decay: float | jax.Array = 0.9,
+    eps: float | jax.Array = 1e-8,
+    initial_scale: float | jax.Array = 0.,
     eps_in_sqrt: bool = True,
     bias_correction: bool = False,
 ) -> base.GradientTransformation:
@@ -168,9 +169,9 @@ class ScaleByRStdDevWithCountState(NamedTuple):
 
 
 def scale_by_stddev(
-    decay: float = 0.9,
-    eps: float = 1e-8,
-    initial_scale: float = 0.0,
+    decay: float | jax.Array = 0.9,
+    eps: float | jax.Array = 1e-8,
+    initial_scale: float | jax.Array = 0.,
     eps_in_sqrt: bool = True,
     bias_correction: bool = False,
 ) -> base.GradientTransformation:
@@ -244,10 +245,10 @@ class ScaleByAdamState(NamedTuple):
 
 
 def scale_by_adam(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
     mu_dtype: Optional[chex.ArrayDType] = None,
     *,
     nesterov: bool = False,
@@ -317,10 +318,10 @@ class ScaleByAmsgradState(NamedTuple):
 
 
 def scale_by_amsgrad(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
     mu_dtype: Optional[chex.ArrayDType] = None,
 ) -> base.GradientTransformation:
   """Rescale updates according to the AMSGrad algorithm.
@@ -373,7 +374,9 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_adamax(
-    b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Adamax algorithm.
 
@@ -414,8 +417,8 @@ class ScaleByLionState(NamedTuple):
 
 
 def scale_by_lion(
-    b1: float = 0.9,
-    b2: float = 0.99,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.99,
     mu_dtype: Optional[chex.ArrayDType] = None,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Lion algorithm.
@@ -451,7 +454,9 @@ def update_fn(updates, state, params=None):
   return base.GradientTransformation(init_fn, update_fn)
 
 
-def scale(step_size: float) -> base.GradientTransformation:
+def scale(
+    step_size: float | jax.Array
+) -> base.GradientTransformation:
   """Scale updates by some fixed scalar `step_size`.
 
   Args:
@@ -470,7 +475,7 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_param_block_norm(
-    min_scale: float = 1e-3,
+    min_scale: float | jax.Array = 1e-3
 ) -> base.GradientTransformation:
   """Scale updates for each param block by the norm of that block's parameters.
 
@@ -496,7 +501,7 @@ def update_fn(updates, state, params):
 
 
 def scale_by_param_block_rms(
-    min_scale: float = 1e-3,
+    min_scale: float | jax.Array = 1e-3,
 ) -> base.GradientTransformation:
   """Scale updates by rms of the gradient for each param vector or matrix.
 
@@ -656,10 +661,10 @@ class ScaleByBeliefState(NamedTuple):
 
 
 def scale_by_belief(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-16,
-    eps_root: float = 1e-16,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-16,
+    eps_root: float | jax.Array = 1e-16,
     *,
     nesterov: bool = False,
 ) -> base.GradientTransformation:
@@ -713,11 +718,11 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_yogi(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-3,
-    eps_root: float = 0.0,
-    initial_accumulator_value: float = 1e-6,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-3,
+    eps_root: float | jax.Array = 0.0,
+    initial_accumulator_value: float | jax.Array = 1e-6,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Yogi algorithm.
 
@@ -767,11 +772,11 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_radam(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
-    threshold: float = 5.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
+    threshold: float | jax.Array = 5.0,
     *,
     nesterov: bool = False,
 ) -> base.GradientTransformation:
@@ -991,9 +996,9 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_trust_ratio(
-    min_norm: float = 0.0,
-    trust_coefficient: float = 1.0,
-    eps: float = 0.0,
+    min_norm: float | jax.Array = 0.0,
+    trust_coefficient: float | jax.Array = 1.,
+    eps: float | jax.Array = 0.,
 ) -> base.GradientTransformation:
   """Scale updates by `trust ratio`.
 
@@ -1121,7 +1126,9 @@ class ScaleBySM3State(NamedTuple):
 
 
 def scale_by_sm3(
-    b1: float = 0.9, b2: float = 1.0, eps: float = 1e-8
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 1.0,
+    eps: float | jax.Array = 1e-8,
 ) -> base.GradientTransformation:
   """Scale updates by `sm3`.
 
@@ -1198,11 +1205,11 @@ class ScaleByNovogradState(NamedTuple):
 
 
 def scale_by_novograd(
-    b1: float = 0.9,
-    b2: float = 0.25,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
-    weight_decay: float = 0.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.25,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
+    weight_decay: float | jax.Array = 0.0,
     mu_dtype: Optional[chex.ArrayDType] = None,
 ) -> base.GradientTransformation:
   """Computes NovoGrad updates.
@@ -1277,7 +1284,8 @@ class ScaleByOptimisticGradientState(NamedTuple):
 
 
 def scale_by_optimistic_gradient(
-    alpha: float = 1.0, beta: float = 1.0
+    alpha: float | jax.Array = 1.0,
+    beta: float | jax.Array = 1.0,
 ) -> base.GradientTransformation:
   """Compute generalized optimistic gradients.
 
@@ -1494,8 +1502,8 @@ def _precondition_by_lbfgs(
     diff_params_memory: chex.ArrayTree,
     diff_updates_memory: chex.ArrayTree,
     weights_memory: chex.Array,
-    identity_scale: Union[float, jax.Array],
-    memory_idx: Union[int, jax.Array],
+    identity_scale: float | jax.Array,
+    memory_idx: int | jax.Array,
 ) -> base.Updates:
   r"""Multiplies updates by an approximation of the inverse Hessian.
 
0 commit comments