diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index 70f30d817..2e9de4212 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -15,7 +15,7 @@
 """Gradient transformations."""
 
 import functools
-from typing import NamedTuple, Optional, Union
+from typing import NamedTuple, Optional
 
 import chex
 import jax
@@ -44,7 +44,8 @@ class ScaleByRssState(NamedTuple):
 
 
 def scale_by_rss(
-    initial_accumulator_value: float = 0.1, eps: float = 1e-7
+    initial_accumulator_value: float | jax.Array = 0.1,
+    eps: float | jax.Array = 1e-7,
 ) -> base.GradientTransformation:
   """Rescale updates by the root of the sum of all squared gradients to date.
@@ -93,9 +94,9 @@ class ScaleByRmsWithCountState(NamedTuple):
 
 
 def scale_by_rms(
-    decay: float = 0.9,
-    eps: float = 1e-8,
-    initial_scale: float = 0.0,
+    decay: float | jax.Array = 0.9,
+    eps: float | jax.Array = 1e-8,
+    initial_scale: float | jax.Array = 0.0,
     eps_in_sqrt: bool = True,
     bias_correction: bool = False,
 ) -> base.GradientTransformation:
@@ -168,9 +169,9 @@ class ScaleByRStdDevWithCountState(NamedTuple):
 
 
 def scale_by_stddev(
-    decay: float = 0.9,
-    eps: float = 1e-8,
-    initial_scale: float = 0.0,
+    decay: float | jax.Array = 0.9,
+    eps: float | jax.Array = 1e-8,
+    initial_scale: float | jax.Array = 0.0,
     eps_in_sqrt: bool = True,
     bias_correction: bool = False,
 ) -> base.GradientTransformation:
@@ -244,10 +245,10 @@ class ScaleByAdamState(NamedTuple):
 
 
 def scale_by_adam(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
     mu_dtype: Optional[chex.ArrayDType] = None,
     *,
     nesterov: bool = False,
@@ -317,10 +318,10 @@ class ScaleByAmsgradState(NamedTuple):
 
 
 def scale_by_amsgrad(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
     mu_dtype: Optional[chex.ArrayDType] = None,
 ) -> base.GradientTransformation:
   """Rescale updates according to the AMSGrad algorithm.
@@ -373,7 +374,9 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_adamax(
-    b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Adamax algorithm.
@@ -414,8 +417,8 @@ class ScaleByLionState(NamedTuple):
 
 
 def scale_by_lion(
-    b1: float = 0.9,
-    b2: float = 0.99,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.99,
     mu_dtype: Optional[chex.ArrayDType] = None,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Lion algorithm.
@@ -451,7 +454,9 @@ def update_fn(updates, state, params=None):
   return base.GradientTransformation(init_fn, update_fn)
 
 
-def scale(step_size: float) -> base.GradientTransformation:
+def scale(
+    step_size: float | jax.Array
+) -> base.GradientTransformation:
   """Scale updates by some fixed scalar `step_size`.
 
   Args:
@@ -470,7 +475,7 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_param_block_norm(
-    min_scale: float = 1e-3,
+    min_scale: float | jax.Array = 1e-3
 ) -> base.GradientTransformation:
   """Scale updates for each param block by the norm of that block's parameters.
@@ -496,7 +501,7 @@ def update_fn(updates, state, params):
 
 
 def scale_by_param_block_rms(
-    min_scale: float = 1e-3,
+    min_scale: float | jax.Array = 1e-3,
 ) -> base.GradientTransformation:
   """Scale updates by rms of the gradient for each param vector or matrix.
@@ -656,10 +661,10 @@ class ScaleByBeliefState(NamedTuple):
 
 
 def scale_by_belief(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-16,
-    eps_root: float = 1e-16,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-16,
+    eps_root: float | jax.Array = 1e-16,
     *,
     nesterov: bool = False,
 ) -> base.GradientTransformation:
@@ -713,11 +718,11 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_yogi(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-3,
-    eps_root: float = 0.0,
-    initial_accumulator_value: float = 1e-6,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-3,
+    eps_root: float | jax.Array = 0.0,
+    initial_accumulator_value: float | jax.Array = 1e-6,
 ) -> base.GradientTransformation:
   """Rescale updates according to the Yogi algorithm.
@@ -767,11 +772,11 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_radam(
-    b1: float = 0.9,
-    b2: float = 0.999,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
-    threshold: float = 5.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.999,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
+    threshold: float | jax.Array = 5.0,
     *,
     nesterov: bool = False,
 ) -> base.GradientTransformation:
@@ -991,9 +996,9 @@ def update_fn(updates, state, params=None):
 
 
 def scale_by_trust_ratio(
-    min_norm: float = 0.0,
-    trust_coefficient: float = 1.0,
-    eps: float = 0.0,
+    min_norm: float | jax.Array = 0.0,
+    trust_coefficient: float | jax.Array = 1.0,
+    eps: float | jax.Array = 0.0,
 ) -> base.GradientTransformation:
   """Scale updates by `trust ratio`.
@@ -1121,7 +1126,9 @@ class ScaleBySM3State(NamedTuple):
 
 
 def scale_by_sm3(
-    b1: float = 0.9, b2: float = 1.0, eps: float = 1e-8
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 1.0,
+    eps: float | jax.Array = 1e-8,
 ) -> base.GradientTransformation:
   """Scale updates by `sm3`.
@@ -1198,11 +1205,11 @@ class ScaleByNovogradState(NamedTuple):
 
 
 def scale_by_novograd(
-    b1: float = 0.9,
-    b2: float = 0.25,
-    eps: float = 1e-8,
-    eps_root: float = 0.0,
-    weight_decay: float = 0.0,
+    b1: float | jax.Array = 0.9,
+    b2: float | jax.Array = 0.25,
+    eps: float | jax.Array = 1e-8,
+    eps_root: float | jax.Array = 0.0,
+    weight_decay: float | jax.Array = 0.0,
     mu_dtype: Optional[chex.ArrayDType] = None,
 ) -> base.GradientTransformation:
   """Computes NovoGrad updates.
@@ -1277,7 +1284,8 @@ class ScaleByOptimisticGradientState(NamedTuple):
 
 
 def scale_by_optimistic_gradient(
-    alpha: float = 1.0, beta: float = 1.0
+    alpha: float | jax.Array = 1.0,
+    beta: float | jax.Array = 1.0,
 ) -> base.GradientTransformation:
   """Compute generalized optimistic gradients.
@@ -1494,8 +1502,8 @@ def _precondition_by_lbfgs(
     diff_params_memory: chex.ArrayTree,
     diff_updates_memory: chex.ArrayTree,
     weights_memory: chex.Array,
-    identity_scale: Union[float, jax.Array],
-    memory_idx: Union[int, jax.Array],
+    identity_scale: float | jax.Array,
+    memory_idx: int | jax.Array,
 ) -> base.Updates:
   r"""Multiplies updates by an approximation of the inverse Hessian.
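Context for the change (not part of the patch): the diff only widens the hyperparameter annotations from `float` to `float | jax.Array` (and replaces `typing.Union`/`Optional`-style unions with the `X | Y` syntax); runtime behaviour of the transformations is unchanged. A minimal usage sketch of what the wider type describes, assuming the public `optax.scale_by_adam` wrapper around this module, is passing hyperparameters as 0-d `jax.Array` values (e.g. values computed inside a jitted training step) rather than Python floats; the parameter tree below is purely illustrative.

```python
import jax
import jax.numpy as jnp
import optax

# Hyperparameters passed as 0-d jax.Arrays instead of Python floats;
# with the widened annotations this also satisfies static type checkers.
tx = optax.scale_by_adam(
    b1=jnp.asarray(0.9),
    b2=jnp.asarray(0.999),
    eps=jnp.asarray(1e-8),
)

# Illustrative params/grads pytree.
params = {'w': jnp.ones((3,))}
grads = {'w': jnp.full((3,), 0.1)}

state = tx.init(params)
updates, state = tx.update(grads, state, params)
print(jax.tree_util.tree_map(jnp.shape, updates))
```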