Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optax/projections/_projections_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def setUp(self):
'l1': (proj.projection_l1_ball, partial(optax.tree.norm, ord=1)),
'l2': (proj.projection_l2_ball, optax.tree.norm),
'linf': (proj.projection_linf_ball,
partial(optax.tree.norm, ord='inf')),
partial(optax.tree.norm, ord=float('inf'))),
}

def test_projection_non_negative(self):
Expand Down
4 changes: 2 additions & 2 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
('optax.tree_utils.tree_linf_norm is deprecated: use'
' optax.tree_utils.tree_norm(..., ord=jnp.inf)'
' (optax v0.2.5 or newer).'),
functools.partial(tree_norm, ord='inf'),
functools.partial(tree_norm, ord=float('inf')),
),
}

Expand All @@ -94,7 +94,7 @@
tree_add_scalar_mul = tree_add_scale
tree_l1_norm = functools.partial(tree_norm, ord=1)
tree_l2_norm = tree_norm
tree_linf_norm = functools.partial(tree_norm, ord='inf')
tree_linf_norm = functools.partial(tree_norm, ord=float('inf'))

else:
# pylint: disable=line-too-long
Expand Down
8 changes: 4 additions & 4 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,25 +236,25 @@ def _square(leaf):


def tree_norm(tree: Any,
ord: int | str | float | None = None, # pylint: disable=redefined-builtin
ord: int | float = 2, # pylint: disable=redefined-builtin
squared: bool = False) -> jax.Array:
"""Compute the vector norm of the given ord of a pytree.

Args:
tree: pytree.
ord: the order of the vector norm to compute from (None, 1, 2, inf).
ord: the order of the vector norm to compute, one of ``{1, 2, inf}``.
squared: whether the norm should be returned squared or not.

Returns:
a scalar value.
"""
if ord is None or ord == 2:
if ord == 2:
squared_tree = jax.tree.map(_square, tree)
sqnorm = tree_sum(squared_tree)
return jnp.array(sqnorm if squared else jnp.sqrt(sqnorm))
elif ord == 1:
ret = tree_sum(jax.tree.map(jnp.abs, tree))
elif ord == jnp.inf or ord in ("inf", "infinity"):
elif ord == float("inf"):
ret = tree_max(jax.tree.map(jnp.abs, tree))
else:
raise ValueError(f"Unsupported ord: {ord}")
Expand Down
Loading