Conversation

carlosgmartin (Contributor)
Deprecates the redundant function optax.global_norm in favor of optax.tree.norm.

Split into a new PR from #1365.
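
For anyone migrating, here is a minimal usage sketch of the replacement (assuming optax.tree.norm is exported by the installed optax version, as this PR intends; the pytree below is just an illustrative example):

import jax.numpy as jnp
import optax

params = {"b": jnp.zeros((2,)), "w": jnp.ones((3, 2))}

# Both calls compute the global l2 norm over all leaves of the pytree.
print(optax.global_norm(params))  # deprecated by this PR
print(optax.tree.norm(params))    # replacement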

vroulet (Collaborator) commented on Jun 27, 2025

Did you check that the HLOs match?

carlosgmartin (Contributor, Author) commented on Jun 28, 2025

$ prog="import jax, optax; print(jax.jit(optax.global_norm).lower(0.).as_text())"
$ git checkout main; py -c "$prog" > ~/desktop/main.txt
$ git checkout deprecate_optax_global_norm; py -c "$prog" > ~/desktop/deprecate_optax_global_norm.txt
$ git diff ~/desktop/main.txt ~/desktop/deprecate_optax_global_norm.txt

Diff:

 module @jit_global_norm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
   func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
     %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
-    %1 = stablehlo.convert %0 : tensor<f32>
     %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-    %2 = stablehlo.reduce(%1 init: %cst) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %1 = stablehlo.multiply %cst, %cst : tensor<f32>
+    %2 = stablehlo.add %0, %1 : tensor<f32>
+    %3 = stablehlo.convert %2 : tensor<f32>
     %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
-    %3 = stablehlo.add %cst_0, %2 : tensor<f32>
-    %4 = stablehlo.sqrt %3 : tensor<f32>
-    return %4 : tensor<f32>
+    %4 = stablehlo.reduce(%3 init: %cst_0) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+    %5 = stablehlo.add %cst_1, %4 : tensor<f32>
+    %6 = stablehlo.sqrt %5 : tensor<f32>
+    return %6 : tensor<f32>
   }
 }

Separately, here are the two lowerings in full:

Old (main)
module @jit_global_norm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    %1 = stablehlo.convert %0 : tensor<f32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %2 = stablehlo.reduce(%1 init: %cst) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %3 = stablehlo.add %cst_0, %2 : tensor<f32>
    %4 = stablehlo.sqrt %3 : tensor<f32>
    return %4 : tensor<f32>
  }
}
New (this PR)
module @jit_global_norm attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.multiply %cst, %cst : tensor<f32>
    %2 = stablehlo.add %0, %1 : tensor<f32>
    %3 = stablehlo.convert %2 : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %4 = stablehlo.reduce(%3 init: %cst_0) applies stablehlo.add across dimensions = [] : (tensor<f32>, tensor<f32>) -> tensor<f32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %5 = stablehlo.add %cst_1, %4 : tensor<f32>
    %6 = stablehlo.sqrt %5 : tensor<f32>
    return %6 : tensor<f32>
  }
}

For reference, here are the current implementations of global_norm and tree_norm.

One difference is that the current implementation of global_norm uses numerics.abs_sq:

def abs_sq(x):
  return (x.conj() * x).real

whereas tree_norm uses _tree_math._square:

def _square(leaf):
  return jnp.square(leaf.real) + jnp.square(leaf.imag)

We should probably pick the more efficient of the two and use it in both functions.

Based on HLO size, it looks like abs_sq is slightly more efficient than _square. The latter adds an extra add instruction and an extra multiply instruction.
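
As a rough check of that claim, one can lower both helpers in isolation and compare the sizes of the resulting StableHLO. A sketch (instruction counts may vary with the JAX version):

import jax
import jax.numpy as jnp

def abs_sq(x):
  return (x.conj() * x).real

def _square(x):
  return jnp.square(x.real) + jnp.square(x.imag)

x = 1.0 + 2.0j  # complex input, so conj/real/imag are all nontrivial
for fn in (abs_sq, _square):
  hlo = jax.jit(fn).lower(x).as_text()
  print(fn.__name__, len(hlo.splitlines()))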
