104 changes: 56 additions & 48 deletions optax/_src/transform.py
@@ -15,7 +15,7 @@
"""Gradient transformations."""

import functools
from typing import NamedTuple, Optional, Union
from typing import NamedTuple, Optional

import chex
import jax
@@ -44,7 +44,8 @@ class ScaleByRssState(NamedTuple):


def scale_by_rss(
initial_accumulator_value: float = 0.1, eps: float = 1e-7
initial_accumulator_value: float | jax.Array = 0.1,
eps: float | jax.Array = 1e-7,
) -> base.GradientTransformation:
"""Rescale updates by the root of the sum of all squared gradients to date.

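For orientation, here is a minimal usage sketch (not part of this diff) of what the widened annotation allows: the scale_by_rss hyperparameters may be supplied as jax.Array scalars instead of Python floats. The parameter tree and values below are illustrative only.

import jax.numpy as jnp
import optax

params = {"w": jnp.zeros(3)}
grads = {"w": jnp.ones(3)}

# Hyperparameters passed as 0-d jax.Array values rather than Python floats.
tx = optax.scale_by_rss(
    initial_accumulator_value=jnp.asarray(0.1),
    eps=jnp.asarray(1e-7),
)
opt_state = tx.init(params)
updates, opt_state = tx.update(grads, opt_state, params)
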
@@ -93,9 +94,9 @@ class ScaleByRmsWithCountState(NamedTuple):


def scale_by_rms(
decay: float = 0.9,
eps: float = 1e-8,
initial_scale: float = 0.0,
decay: float | jax.Array = 0.9,
eps: float | jax.Array = 1e-8,
initial_scale: float | jax.Array = 0.0,
eps_in_sqrt: bool = True,
bias_correction: bool = False,
) -> base.GradientTransformation:
@@ -168,9 +169,9 @@ class ScaleByRStdDevWithCountState(NamedTuple):


def scale_by_stddev(
decay: float = 0.9,
eps: float = 1e-8,
initial_scale: float = 0.0,
decay: float | jax.Array = 0.9,
eps: float | jax.Array = 1e-8,
initial_scale: float | jax.Array = 0.0,
eps_in_sqrt: bool = True,
bias_correction: bool = False,
) -> base.GradientTransformation:
@@ -244,10 +245,10 @@ class ScaleByAdamState(NamedTuple):


def scale_by_adam(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
b1: float | jax.Array = 0.9,
b2: float | jax.Array = 0.999,
eps: float | jax.Array = 1e-8,
eps_root: float | jax.Array = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
*,
nesterov: bool = False,
@@ -317,10 +318,10 @@ class ScaleByAmsgradState(NamedTuple):


def scale_by_amsgrad(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
b1: float | jax.Array = 0.9,
b2: float | jax.Array = 0.999,
eps: float | jax.Array = 1e-8,
eps_root: float | jax.Array = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
"""Rescale updates according to the AMSGrad algorithm.
@@ -373,7 +374,9 @@ def update_fn(updates, state, params=None):


def scale_by_adamax(
b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8
b1: float | jax.Array = 0.9,
b2: float | jax.Array = 0.999,
eps: float | jax.Array = 1e-8,
) -> base.GradientTransformation:
"""Rescale updates according to the Adamax algorithm.

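As a further hedged sketch (not from this PR), array-valued b1/b2 for scale_by_adam or scale_by_adamax can be created once on device and used inside a jit-compiled training step; the toy shapes below are made up.

import jax
import jax.numpy as jnp
import optax

tx = optax.scale_by_adam(b1=jnp.asarray(0.9), b2=jnp.asarray(0.999))
params = jnp.zeros(4)
opt_state = tx.init(params)

@jax.jit
def step(grads, opt_state):
    # The array hyperparameters are captured by the transformation's closure.
    updates, new_state = tx.update(grads, opt_state)
    return updates, new_state

updates, opt_state = step(jnp.ones(4), opt_state)
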
@@ -414,8 +417,8 @@ class ScaleByLionState(NamedTuple):


def scale_by_lion(
b1: float = 0.9,
b2: float = 0.99,
b1: float | jax.Array = 0.9,
b2: float | jax.Array = 0.99,
mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
"""Rescale updates according to the Lion algorithm.
@@ -451,7 +454,9 @@ def update_fn(updates, state, params=None):
return base.GradientTransformation(init_fn, update_fn)


def scale(step_size: float) -> base.GradientTransformation:
def scale(
step_size: float | jax.Array
) -> base.GradientTransformation:
"""Scale updates by some fixed scalar `step_size`.

Args:
@@ -470,7 +475,7 @@ def update_fn(updates, state, params=None):


def scale_by_param_block_norm(
min_scale: float = 1e-3,
min_scale: float | jax.Array = 1e-3
) -> base.GradientTransformation:
"""Scale updates for each param block by the norm of that block's parameters.

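Similarly, the widened scale and scale_by_param_block_norm signatures admit a step size computed on the fly as a jax.Array, for example the output of an optax schedule. The following is an illustrative sketch, not code from the PR; the schedule, shapes, and values are assumptions.

import jax.numpy as jnp
import optax

step = jnp.asarray(100)
# A schedule evaluated at an array-valued step returns a jax.Array, which the
# annotation above now documents as a valid step_size for scale().
lr = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1_000)(step)
tx = optax.chain(
    optax.scale_by_param_block_norm(min_scale=jnp.asarray(1e-3)),
    optax.scale(-lr),
)
params = {"w": jnp.zeros((2, 2))}
opt_state = tx.init(params)
updates, opt_state = tx.update({"w": jnp.ones((2, 2))}, opt_state, params)
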
@@ -496,7 +501,7 @@ def update_fn(updates, state, params):


def scale_by_param_block_rms(
min_scale: float = 1e-3,
min_scale: float | jax.Array = 1e-3,
) -> base.GradientTransformation:
"""Scale updates by rms of the gradient for each param vector or matrix.

@@ -656,10 +661,10 @@ class ScaleByBeliefState(NamedTuple):


def scale_by_belief(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-16,
eps_root: float = 1e-16,
b1: float | jax.Array = 0.9,
b2: float | jax.Array = 0.999,
eps: float | jax.Array = 1e-16,
eps_root: float | jax.Array = 1e-16,
*,
nesterov: bool = False,
) -> base.GradientTransformation:
@@ -713,11 +718,11 @@ def update_fn(updates, state, params=None):


def scale_by_yogi(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-3,
eps_root: float = 0.0,
initial_accumulator_value: float = 1e-6,
b1: float | jax.Array = 0.9,
b2: float | jax.Array = 0.999,
eps: float | jax.Array = 1e-3,
eps_root: float | jax.Array = 0.0,
initial_accumulator_value: float | jax.Array = 1e-6,
) -> base.GradientTransformation:
"""Rescale updates according to the Yogi algorithm.

@@ -767,11 +772,11 @@ def update_fn(updates, state, params=None):


def scale_by_radam(
b1: float = 0.9,
b2: float = 0.999,
eps: float = 1e-8,
eps_root: float = 0.0,
threshold: float = 5.0,
b1: float | jax.Array = 0.9,
b2: float | jax.Array = 0.999,
eps: float | jax.Array = 1e-8,
eps_root: float | jax.Array = 0.0,
threshold: float | jax.Array = 5.0,
*,
nesterov: bool = False,
) -> base.GradientTransformation:
@@ -991,9 +996,9 @@ def update_fn(updates, state, params=None):


def scale_by_trust_ratio(
min_norm: float = 0.0,
trust_coefficient: float = 1.0,
eps: float = 0.0,
min_norm: float | jax.Array = 0.0,
trust_coefficient: float | jax.Array = 1.0,
eps: float | jax.Array = 0.0,
) -> base.GradientTransformation:
"""Scale updates by `trust ratio`.

@@ -1121,7 +1126,9 @@ class ScaleBySM3State(NamedTuple):


def scale_by_sm3(
b1: float = 0.9, b2: float = 1.0, eps: float = 1e-8
b1: float | jax.Array = 0.9,
b2: float | jax.Array = 1.0,
eps: float | jax.Array = 1e-8,
) -> base.GradientTransformation:
"""Scale updates by `sm3`.

@@ -1198,11 +1205,11 @@ class ScaleByNovogradState(NamedTuple):


def scale_by_novograd(
b1: float = 0.9,
b2: float = 0.25,
eps: float = 1e-8,
eps_root: float = 0.0,
weight_decay: float = 0.0,
b1: float | jax.Array = 0.9,
b2: float | jax.Array = 0.25,
eps: float | jax.Array = 1e-8,
eps_root: float | jax.Array = 0.0,
weight_decay: float | jax.Array = 0.0,
mu_dtype: Optional[chex.ArrayDType] = None,
) -> base.GradientTransformation:
"""Computes NovoGrad updates.
@@ -1277,7 +1284,8 @@ class ScaleByOptimisticGradientState(NamedTuple):


def scale_by_optimistic_gradient(
alpha: float = 1.0, beta: float = 1.0
alpha: float | jax.Array = 1.0,
beta: float | jax.Array = 1.0,
) -> base.GradientTransformation:
"""Compute generalized optimistic gradients.

@@ -1494,8 +1502,8 @@ def _precondition_by_lbfgs(
diff_params_memory: chex.ArrayTree,
diff_updates_memory: chex.ArrayTree,
weights_memory: chex.Array,
identity_scale: Union[float, jax.Array],
memory_idx: Union[int, jax.Array],
identity_scale: float | jax.Array,
memory_idx: int | jax.Array,
) -> base.Updates:
r"""Multiplies updates by an approximation of the inverse Hessian.
