Skip to content

Conversation

@carlosgmartin
Copy link
Contributor

Streamlines and simplifies the interface for optax.tree.norm (no need for 3 different ways to specify the infinity norm, and 2 different ways to specify the 2-norm).

Split into a new PR from #1365.

@vroulet
Copy link
Collaborator

vroulet commented Jul 15, 2025

The user will expect the same behavior as in jax, so we will stay as close as possible to jax.numpy.linalg.norm.
Maybe the "float" typing hint could be removed. But that could break some other functions and it's not a priority.
In general cosmetic changes like this one may actually cost us more energy than it solves problems so maybe consider the cost of these changes before doing PRs. Thank you though for all your efforts and fixing many bugs!

@vroulet vroulet closed this Jul 15, 2025
@carlosgmartin
Copy link
Contributor Author

carlosgmartin commented Jul 15, 2025

@vroulet jax.numpy.linalg.norm doesn't allow ord='inf' or ord='infinity':

$ py -c "from jax import numpy as jnp; jnp.linalg.norm(jnp.zeros(3), ord='inf')"
ValueError: Invalid order 'inf' for vector norm.Use 'jax.numpy.inf' instead.
$ py -c "from jax import numpy as jnp; jnp.linalg.norm(jnp.zeros(3), ord='infinity')"
ValueError: Invalid order 'infinity' for vector norm.

I'd also argue that the narrower type of ord (excluding float & matching jax.numpy.linalg.norm) can help prevent future bugs by catching an incorrect argument at type-check time.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants