10 changes: 9 additions & 1 deletion optax/_src/linear_algebra.py
@@ -17,6 +17,7 @@
 from collections.abc import Callable
 import functools
 from typing import Optional, Union
+import warnings

 import chex
 import jax
@@ -33,10 +34,17 @@ def _normalize_tree(x):


 def global_norm(updates: base.PyTree) -> chex.Array:
-  """Compute the global norm across a nested structure of tensors."""
-  return jnp.sqrt(
-      sum(jnp.sum(numerics.abs_sq(x)) for x in jax.tree.leaves(updates))
+  """Compute the global norm across a nested structure of tensors.
+
+  .. warning::
+    Deprecated in favor of :func:`optax.tree.norm`.
+  """
+  warnings.warn(
+      "optax.global_norm is deprecated in favor of optax.tree.norm",
+      DeprecationWarning
   )
+  return optax.tree.norm(updates)


 def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):
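Note: with this change the deprecated optax.global_norm simply forwards to optax.tree.norm. A minimal usage sketch of the replacement, assuming a recent optax release that exposes the optax.tree namespace (the update values here are hypothetical):

import jax
import jax.numpy as jnp
import optax

# Hypothetical nested update structure.
updates = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([12.0])}

# optax.tree.norm computes the global l2 norm over all leaves,
# i.e. sqrt(3**2 + 4**2 + 12**2) = 13, matching the deprecated optax.global_norm.
print(optax.tree.norm(updates))
print(jnp.sqrt(sum(jnp.sum(x**2) for x in jax.tree.leaves(updates))))  # same value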
2 changes: 1 addition & 1 deletion optax/_src/linear_algebra_test.py
@@ -53,7 +53,7 @@ def test_global_norm(self):
     }
     np.testing.assert_array_equal(
         jnp.sqrt(jnp.sum(flat_updates**2)),
-        linear_algebra.global_norm(nested_updates),
+        optax.tree.norm(nested_updates),
     )

   def test_power_iteration_cond_fun(self, dim=6):
1 change: 0 additions & 1 deletion optax/_src/utils.py
@@ -345,4 +345,3 @@ def _value_and_grad(
 # TODO(b/183800387): remove legacy aliases.
 safe_norm = numerics.safe_norm
 safe_int32_increment = numerics.safe_int32_increment
-global_norm = linear_algebra.global_norm
4 changes: 2 additions & 2 deletions optax/contrib/_sam.py
@@ -54,7 +54,7 @@
 import jax.numpy as jnp
 from optax._src import base
 from optax._src import update
-from optax._src import utils
+import optax.tree

 # As a helper for SAM we need a gradient normalizing transformation.

@@ -74,7 +74,7 @@ def init_fn(params):

   def update_fn(updates, state, params=None):
     del params
-    g_norm = utils.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     updates = jax.tree.map(lambda g: g / g_norm, updates)
     return updates, state
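For context, the SAM helper above rescales every leaf of the update tree by its global norm, so the transformed updates have unit norm. A minimal sketch of that behaviour, assuming the helper is exported as optax.contrib.normalize (gradient values are hypothetical):

import jax.numpy as jnp
import optax
import optax.contrib

# Hypothetical gradients; their global norm is 5.
grads = {"w": jnp.array([3.0, 4.0])}

normalizer = optax.contrib.normalize()
state = normalizer.init(grads)
unit_grads, _ = normalizer.update(grads, state)

# Each leaf was divided by optax.tree.norm(grads), so the result has norm 1.
print(optax.tree.norm(unit_grads))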
5 changes: 2 additions & 3 deletions optax/transforms/_clipping.py
@@ -24,7 +24,6 @@
 import jax
 import jax.numpy as jnp
 from optax._src import base
-from optax._src import linear_algebra
 from optax._src import numerics
 import optax.tree

@@ -90,7 +89,7 @@ def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:

   def update_fn(updates, state, params=None):
     del params
-    g_norm = linear_algebra.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     # TODO(b/163995078): revert back to the following (faster) implementation
     # once analyzed how it affects backprop through update (e.g. meta-gradients)
     # g_norm = jnp.maximum(max_norm, g_norm)
@@ -154,7 +153,7 @@ def per_example_global_norm_clip(
         " `grads` to have a batch dimension in the 0th axis."
     )

-  global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads)
+  global_grad_norms = jax.vmap(optax.tree.norm)(grads)
   multipliers = jnp.nan_to_num(
       jnp.minimum(l2_norm_clip / global_grad_norms, 1.0), nan=1.0
   )
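For reference, clip_by_global_norm rescales the whole update tree whenever its global norm exceeds max_norm, which is why only the norm helper needs swapping here. A minimal usage sketch (values are hypothetical):

import jax.numpy as jnp
import optax

# Hypothetical updates with global norm 5, clipped to a maximum norm of 1.
updates = {"w": jnp.array([3.0, 4.0])}

clipper = optax.clip_by_global_norm(1.0)
state = clipper.init(updates)
clipped, _ = clipper.update(updates, state)

# Leaves are scaled by max_norm / max(global_norm, max_norm), here 1/5.
print(optax.tree.norm(clipped))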
6 changes: 2 additions & 4 deletions optax/transforms/_clipping_test.py
@@ -19,8 +19,8 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from optax._src import linear_algebra
 from optax.transforms import _clipping
+import optax.tree


 STEPS = 50
@@ -70,9 +70,7 @@ def test_clip_by_global_norm(self):
       clipper = _clipping.clip_by_global_norm(1.0 / i)
       # Check that the clipper actually works and global norm is <= max_norm
       updates, _ = clipper.update(updates, None)
-      self.assertAlmostEqual(
-          linear_algebra.global_norm(updates), 1.0 / i, places=6
-      )
+      self.assertAlmostEqual(optax.tree.norm(updates), 1.0 / i, places=6)
       # Check that continuously clipping won't cause numerical issues.
       updates_step, _ = clipper.update(self.per_step_updates, None)
       chex.assert_trees_all_close(updates, updates_step)