diff --git a/optax/projections/_projections_test.py b/optax/projections/_projections_test.py index 2f4ea6a48..00e142b48 100644 --- a/optax/projections/_projections_test.py +++ b/optax/projections/_projections_test.py @@ -50,7 +50,7 @@ def setUp(self): 'l1': (proj.projection_l1_ball, partial(optax.tree.norm, ord=1)), 'l2': (proj.projection_l2_ball, optax.tree.norm), 'linf': (proj.projection_linf_ball, - partial(optax.tree.norm, ord='inf')), + partial(optax.tree.norm, ord=float('inf'))), } def test_projection_non_negative(self): diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index ce4952f2d..96324d0ba 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -83,7 +83,7 @@ ('optax.tree_utils.tree_linf_norm is deprecated: use' ' optax.tree_utils.tree_norm(..., ord=jnp.inf)' ' (optax v0.2.5 or newer).'), - functools.partial(tree_norm, ord='inf'), + functools.partial(tree_norm, ord=float('inf')), ), } @@ -94,7 +94,7 @@ tree_add_scalar_mul = tree_add_scale tree_l1_norm = functools.partial(tree_norm, ord=1) tree_l2_norm = tree_norm - tree_linf_norm = functools.partial(tree_norm, ord='inf') + tree_linf_norm = functools.partial(tree_norm, ord=float('inf')) else: # pylint: disable=line-too-long diff --git a/optax/tree_utils/_tree_math.py b/optax/tree_utils/_tree_math.py index 9387a0976..67d2ed4f6 100644 --- a/optax/tree_utils/_tree_math.py +++ b/optax/tree_utils/_tree_math.py @@ -236,25 +236,25 @@ def _square(leaf): def tree_norm(tree: Any, - ord: int | str | float | None = None, # pylint: disable=redefined-builtin + ord: int | float = 2, # pylint: disable=redefined-builtin squared: bool = False) -> jax.Array: """Compute the vector norm of the given ord of a pytree. Args: tree: pytree. - ord: the order of the vector norm to compute from (None, 1, 2, inf). + ord: the order of the vector norm to compute, one of ``{1, 2, inf}``. squared: whether the norm should be returned squared or not. Returns: a scalar value. """ - if ord is None or ord == 2: + if ord == 2: squared_tree = jax.tree.map(_square, tree) sqnorm = tree_sum(squared_tree) return jnp.array(sqnorm if squared else jnp.sqrt(sqnorm)) elif ord == 1: ret = tree_sum(jax.tree.map(jnp.abs, tree)) - elif ord == jnp.inf or ord in ("inf", "infinity"): + elif ord == float("inf"): ret = tree_max(jax.tree.map(jnp.abs, tree)) else: raise ValueError(f"Unsupported ord: {ord}")