diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py
index d5e302642..1b471e340 100644
--- a/optax/_src/linear_algebra.py
+++ b/optax/_src/linear_algebra.py
@@ -17,6 +17,7 @@
 from collections.abc import Callable
 import functools
 from typing import Optional, Union
+import warnings
 
 import chex
 import jax
@@ -33,10 +34,17 @@ def _normalize_tree(x):
 
 
 def global_norm(updates: base.PyTree) -> chex.Array:
-  """Compute the global norm across a nested structure of tensors."""
-  return jnp.sqrt(
-      sum(jnp.sum(numerics.abs_sq(x)) for x in jax.tree.leaves(updates))
+  """Compute the global norm across a nested structure of tensors.
+
+  .. warning::
+    Deprecated in favor of :func:`optax.tree.norm`.
+
+  """
+  warnings.warn(
+      "optax.global_norm is deprecated in favor of optax.tree.norm",
+      DeprecationWarning
   )
+  return optax.tree.norm(updates)
 
 
 def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):
diff --git a/optax/_src/linear_algebra_test.py b/optax/_src/linear_algebra_test.py
index 0c6db11b5..c0d79d8f4 100644
--- a/optax/_src/linear_algebra_test.py
+++ b/optax/_src/linear_algebra_test.py
@@ -53,7 +53,7 @@ def test_global_norm(self):
     }
     np.testing.assert_array_equal(
         jnp.sqrt(jnp.sum(flat_updates**2)),
-        linear_algebra.global_norm(nested_updates),
+        optax.tree.norm(nested_updates),
     )
 
   def test_power_iteration_cond_fun(self, dim=6):
diff --git a/optax/_src/utils.py b/optax/_src/utils.py
index 5c38e0d7d..d908ecdad 100644
--- a/optax/_src/utils.py
+++ b/optax/_src/utils.py
@@ -345,4 +345,3 @@ def _value_and_grad(
 # TODO(b/183800387): remove legacy aliases.
 safe_norm = numerics.safe_norm
 safe_int32_increment = numerics.safe_int32_increment
-global_norm = linear_algebra.global_norm
diff --git a/optax/contrib/_sam.py b/optax/contrib/_sam.py
index 9fb1f8f57..3c22545e8 100644
--- a/optax/contrib/_sam.py
+++ b/optax/contrib/_sam.py
@@ -54,7 +54,7 @@
 import jax.numpy as jnp
 from optax._src import base
 from optax._src import update
-from optax._src import utils
+import optax.tree
 
 
 # As a helper for SAM we need a gradient normalizing transformation.
@@ -74,7 +74,7 @@ def init_fn(params):
 
   def update_fn(updates, state, params=None):
     del params
-    g_norm = utils.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     updates = jax.tree.map(lambda g: g / g_norm, updates)
     return updates, state
 
diff --git a/optax/transforms/_clipping.py b/optax/transforms/_clipping.py
index 08a614a05..117272f28 100644
--- a/optax/transforms/_clipping.py
+++ b/optax/transforms/_clipping.py
@@ -24,7 +24,6 @@
 import jax
 import jax.numpy as jnp
 from optax._src import base
-from optax._src import linear_algebra
 from optax._src import numerics
 import optax.tree
 
@@ -90,7 +89,7 @@ def clip_by_global_norm(max_norm: float) -> base.GradientTransformation:
 
   def update_fn(updates, state, params=None):
     del params
-    g_norm = linear_algebra.global_norm(updates)
+    g_norm = optax.tree.norm(updates)
     # TODO(b/163995078): revert back to the following (faster) implementation
     # once analyzed how it affects backprop through update (e.g. meta-gradients)
     # g_norm = jnp.maximum(max_norm, g_norm)
@@ -154,7 +153,7 @@ def per_example_global_norm_clip(
         " `grads` to have a batch dimension in the 0th axis."
     )
 
-  global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads)
+  global_grad_norms = jax.vmap(optax.tree.norm)(grads)
   multipliers = jnp.nan_to_num(
       jnp.minimum(l2_norm_clip / global_grad_norms, 1.0), nan=1.0
   )
diff --git a/optax/transforms/_clipping_test.py b/optax/transforms/_clipping_test.py
index 2c4a621df..8f53ef2a2 100644
--- a/optax/transforms/_clipping_test.py
+++ b/optax/transforms/_clipping_test.py
@@ -19,8 +19,8 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
-from optax._src import linear_algebra
 from optax.transforms import _clipping
+import optax.tree
 
 
 STEPS = 50
@@ -70,9 +70,7 @@ def test_clip_by_global_norm(self):
      clipper = _clipping.clip_by_global_norm(1.0 / i)
      # Check that the clipper actually works and global norm is <= max_norm
      updates, _ = clipper.update(updates, None)
-      self.assertAlmostEqual(
-          linear_algebra.global_norm(updates), 1.0 / i, places=6
-      )
+      self.assertAlmostEqual(optax.tree.norm(updates), 1.0 / i, places=6)
      # Check that continuously clipping won't cause numerical issues.
      updates_step, _ = clipper.update(self.per_step_updates, None)
      chex.assert_trees_all_close(updates, updates_step)
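For reviewers: a minimal sketch of the migration this diff asks of downstream callers, assuming an optax version where the `optax.tree` namespace is available (the tree below and its values are hypothetical, chosen so the norm is easy to check by hand):

```python
import jax.numpy as jnp
import optax

# A hypothetical update tree with leaves of different shapes.
updates = {"w": jnp.array([3.0, 4.0]), "b": jnp.array([0.0])}

# Old spelling: still works after this change, but now emits a
# DeprecationWarning pointing at optax.tree.norm.
old_norm = optax.global_norm(updates)

# New spelling: the same global L2 norm across all leaves,
# sqrt(3^2 + 4^2 + 0^2) = 5.0 here.
new_norm = optax.tree.norm(updates)

assert jnp.allclose(old_norm, new_norm)
```

Since `optax.global_norm` now delegates to `optax.tree.norm`, the two spellings stay numerically identical for the deprecation window; only the warning differs.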