
Commit 67d63d0

[JAX] Support for checkpointing quantizations (#2356)
* Support for checkpointing quantizations
* Add jaxpr test for quant checkpoint name
* Revert "Support for checkpointing quantizations" (reverts commit f7b7849)
* Checkpoint quantizations
* lint
* revert other files
* move checkpointing to VJPs
* fix ci failure

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: JAX Toolbox <[email protected]>
Co-authored-by: JAX Toolbox <[email protected]>
1 parent 9440b76 commit 67d63d0
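
The commit threads an optional checkpoint name from the Flax modules down to the quantizers, so quantized tensors saved for the backward pass are tagged with `jax.ad_checkpoint.checkpoint_name`; that tag is what JAX rematerialization policies match on. Below is a minimal usage sketch, not part of this commit: it mirrors the test setup further down and assumes a supported recipe (here `DelayedScaling` from `transformer_engine.common.recipe`) plus the standard `jax.checkpoint` / `jax.checkpoint_policies.save_only_these_names` APIs.

```python
from functools import partial

import jax
import jax.numpy as jnp

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common.recipe import DelayedScaling  # any supported recipe works

recipe = DelayedScaling()
x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}

with te.autocast(enabled=True, recipe=recipe, mesh_resource=te.MeshResource()):
    # quantization_checkpoint_name is the new knob added by this commit.
    model = te_flax.LayerNormMLP(
        layernorm_type="rmsnorm",
        dtype=jnp.bfloat16,
        quantization_checkpoint_name="quantization",
    )
    var_collect = model.init({"params": jax.random.PRNGKey(3), **rngs}, x)

    # Save only the tagged quantized tensors across the remat boundary;
    # everything else inside the region is recomputed in the backward pass.
    @partial(
        jax.checkpoint,
        policy=jax.checkpoint_policies.save_only_these_names("quantization"),
    )
    def loss_fn(x):
        return model.apply(var_collect, x, rngs=rngs).sum()

    loss, grads = jax.value_and_grad(loss_fn)(x)
```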

File tree: 7 files changed (+252, -62 lines)

tests/jax/test_recipe_characteristics.py

Lines changed: 84 additions & 37 deletions

@@ -263,23 +263,16 @@ def test_autocast_nvfp4_block_scaling(self):
 class TestJaxprAndHlo:
     """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""
 
-    @pytest_parametrize_wrapper(
-        "quantization_recipe",
-        [
-            quantization_recipe
-            for quantization_recipe in SUPPORTED_RECIPES
-            if isinstance(quantization_recipe, NVFP4BlockScaling)
-        ],
-    )
-    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
-        """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton."""
-
+    def _generate_jaxpr_for_layernorm_mlp_fwd_bwd(self, quantization_recipe, ln_mlp_kwargs=None):
+        """Generates the jaxpr for a forward and backward pass of LayerNormMLP under the given quantization recipe."""
+        ln_mlp_kwargs = ln_mlp_kwargs or {}
         with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
             model = te_flax.LayerNormMLP(
                 layernorm_type="rmsnorm",
                 return_layernorm_output=False,
                 intermediate_dropout_rate=0.0,
                 dtype=jnp.bfloat16,
+                **ln_mlp_kwargs,
             )
 
             var_collect = model.init(
@@ -292,29 +285,83 @@ def loss_fn(x, rngs):
 
         x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
         rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
-        jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
-
-        rht_amax_eqns = [
-            eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
-        ]
-
-        assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
-
-        def assert_param(index, tensor_name, expected_value: bool):
-            if expected_value:
-                assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
-                    f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
-                    " reuse of amax as this tensor does not have a previous operation to fuse"
-                    " with"
-                )
-            else:
-                assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
-                    f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
-                    " reuse of amax"
-                )
-
-        assert_param(0, "fwd ln+q", False)
-        assert_param(1, "fwd act+q", False)
-        # No previous op before incoming dgrad in the backward so amax is not reused
-        assert_param(2, "bwd dgrad", True)
-        assert_param(3, "bwd dact+q", False)
+        return jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
+
+    @pytest_parametrize_wrapper(
+        "quantization_recipe",
+        [
+            quantization_recipe
+            for quantization_recipe in SUPPORTED_RECIPES
+            if isinstance(quantization_recipe, NVFP4BlockScaling)
+        ],
+    )
+    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
+        """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton."""
+
+        jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe)
+
+        rht_amax_eqns = [
+            eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
+        ]
+
+        assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
+
+        def assert_param(index, tensor_name, expected_value: bool):
+            if expected_value:
+                assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
+                    f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
+                    " reuse of amax as this tensor does not have a previous operation to fuse"
+                    " with"
+                )
+            else:
+                assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
+                    f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
+                    " reuse of amax"
+                )
+
+        assert_param(0, "fwd ln+q", False)
+        assert_param(1, "fwd act+q", False)
+        # No previous op before incoming dgrad in the backward so amax is not reused
+        assert_param(2, "bwd dgrad", True)
+        assert_param(3, "bwd dact+q", False)
+
+    @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper(
+        "quantization_checkpoint_name",
+        [None, "quantization", "some_arbitrary_user_checkpoint_name"],
+    )
+    def test_recipe_supports_quantization_checkpointing(
+        self, quantization_recipe, quantization_checkpoint_name
+    ):
+        """Tests that all supported quantization recipes correctly use checkpoint_name."""
+
+        kwargs = {
+            "quantization_checkpoint_name": quantization_checkpoint_name,
+        }
+        jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe, kwargs)
+
+        checkpoint_name_eqns = [
+            eqn
+            for eqn in jaxpr.jaxpr.eqns
+            if eqn.primitive.name == "name" and eqn.params["name"] == quantization_checkpoint_name
+        ]
+
+        if quantization_checkpoint_name is None:
+            assert len(checkpoint_name_eqns) == 0, (
+                "Expected 0 checkpoint_name eqns when quantization_checkpoint_name is None, got"
+                f" {len(checkpoint_name_eqns)}"
+            )
+            return
+
+        # 12 checkpointed values:
+        # - Fwd pass:
+        #   - Input RMSNorm+Q -> 3 possible output tensors that will be used in the backward
+        #   - Kernel Q -> 3 possible output tensors that will be used in the backward
+        #   - Input Activation+Q -> 3 possible output tensors that will be used in the backward
+        #   - Kernel Q -> 3 possible output tensors that will be used in the backward
+        expected_checkpoint_eqn_count = 12
+
+        assert len(checkpoint_name_eqns) == expected_checkpoint_eqn_count, (
+            f"Expected {expected_checkpoint_eqn_count} checkpoint_name eqns when"
+            f" quantization_checkpoint_name is set, got {len(checkpoint_name_eqns)}"
        )
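
For reference, the `name` primitive the new test filters on is how `jax.ad_checkpoint.checkpoint_name` tags values in a jaxpr; remat policies such as `save_only_these_names` key off the same tag. A standalone sketch of the same inspection in plain JAX (no Transformer Engine involved):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def f(x):
    # Tag an intermediate the same way the quantizers tag checkpointed tensors.
    y = checkpoint_name(jnp.sin(x), "quantization")
    return (y * x).sum()


jaxpr = jax.make_jaxpr(jax.value_and_grad(f))(jnp.ones((4, 4)))
name_eqns = [
    eqn
    for eqn in jaxpr.jaxpr.eqns
    if eqn.primitive.name == "name" and eqn.params["name"] == "quantization"
]
print(len(name_eqns))  # tagged values show up as `name` equations in the jaxpr
```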

transformer_engine/jax/dense.py

Lines changed: 9 additions & 4 deletions

@@ -19,6 +19,7 @@
 from .cpp_extensions.amax import AmaxScope
 from .quantize import (
     ScaledTensorFactory,
+    ScaledTensor,
     ScalingMode,
     QuantizeLayout,
     QuantizerSet,
@@ -227,8 +228,8 @@ def _dense_fwd_rule(
         output += jnp.reshape(bias, bias_new_shape)
 
     ctx = (
-        casted_x.get_tensor(usage=TensorUsage.LHS_TRANS),
-        casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS),
+        casted_x.get_tensor(usage=TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
+        casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
         x.shape,
         kernel.shape,
         use_bias,
@@ -529,8 +530,12 @@ def _grouped_dense_fwd_rule(
 
     ctx = (
         group_sizes,
-        ctx_x,
-        ctx_kernel,
+        ctx_x.checkpoint(quantizer_set.x) if isinstance(ctx_x, ScaledTensor) else ctx_x,
+        (
+            ctx_kernel.checkpoint(quantizer_set.kernel)
+            if isinstance(ctx_kernel, ScaledTensor)
+            else ctx_kernel
+        ),
         x.shape,
         kernel.shape,
         use_bias,
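
The `ScaledTensor.checkpoint` helper used above lives in the quantize module and is not shown in this diff. Conceptually it is a no-op when the quantizer was created without a checkpoint name, and otherwise wraps the tensor's arrays with `jax.ad_checkpoint.checkpoint_name`. A hypothetical sketch of that behavior (the class and field names below are illustrative, not the actual Transformer Engine definitions):

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


@dataclass
class IllustrativeScaledTensor:  # stand-in for TE's ScaledTensor
    data: jnp.ndarray
    scale_inv: jnp.ndarray

    def checkpoint(self, quantizer) -> "IllustrativeScaledTensor":
        # No tag requested: return the tensor unchanged.
        name: Optional[str] = getattr(quantizer, "checkpoint_name", None)
        if name is None:
            return self
        # Tag each array so a policy like save_only_these_names can keep it
        # as a residual instead of recomputing the quantization.
        return IllustrativeScaledTensor(
            data=checkpoint_name(self.data, name),
            scale_inv=checkpoint_name(self.scale_inv, name),
        )
```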

transformer_engine/jax/flax/module.py

Lines changed: 30 additions & 7 deletions

@@ -6,7 +6,7 @@
 """
 from functools import reduce
 import operator
-from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType
+from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional
 
 import numpy as np
 import jax.numpy as jnp
@@ -345,7 +345,11 @@ class TransformerEngineBase(nn.Module):  # pylint: disable=too-few-public-method
     """
 
     def generate_quantizer_set(
-        self, postfix: str = "", variable_collection: str = None, fp8_recipe=None
+        self,
+        postfix: str = "",
+        variable_collection: str = None,
+        quantization_checkpoint_name: Optional[str] = None,
+        fp8_recipe=None,
     ):
         """
         Generate a set of FP8 meta for a GEMM.
@@ -375,7 +379,9 @@ def generate_quantizer_set(
         quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta)
 
         quantizer_set = QuantizerFactory.create_set(
-            fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set
+            fp8_recipe=fp8_recipe,
+            quantize_meta_set=quantize_meta_set,
+            checkpoint_name=quantization_checkpoint_name,
         )
         return quantizer_set
 
@@ -424,6 +430,8 @@ class DenseGeneral(TransformerEngineBase):
         The data type used to allocate the initial parameters.
     transpose_batch_sequence: bool, default = False
         Indicate whether to transpose the batch and sequence dimensions of the input tensor.
+    quantization_checkpoint_name: Optional[str], default = None
+        The name for checkpointing quantizations.
     """
 
     features: Union[Iterable[int], int]
@@ -439,6 +447,7 @@ class DenseGeneral(TransformerEngineBase):
     dtype: DType = jnp.float32
     input_axes: Tuple[str, ...] = ()
     transpose_batch_sequence: bool = False
+    quantization_checkpoint_name: Optional[str] = None
 
     def __post_init__(self):
         if self.kernel_init is None:
@@ -496,7 +505,9 @@ def __call__(self, inputs: Array) -> Array:
         else:
             bias = None
 
-        quantizer_set = self.generate_quantizer_set()
+        quantizer_set = self.generate_quantizer_set(
+            quantization_checkpoint_name=self.quantization_checkpoint_name
+        )
         contract_ind = tuple(range(0, len(axis)))
         y = dense(
             inputs,
@@ -628,6 +639,8 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         value or None. When None is set, then no scaling is applied.
     transpose_batch_sequence: bool, default = False
         Indicate whether to transpose the batch and sequence dimensions of the input tensor.
+    quantization_checkpoint_name: Optional[str], default = None
+        The name for checkpointing quantizations.
     """
 
     features: Union[Iterable[int], int]
@@ -654,6 +667,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     dot_input_axes: Tuple[str, ...] = None
    depth_scaling: float = None
     transpose_batch_sequence: bool = False
+    quantization_checkpoint_name: Optional[str] = None
 
     def __post_init__(self):
         if self.kernel_init is None:
@@ -693,7 +707,9 @@ def __call__(self, inputs: Array) -> Array:
         input_dtype = inputs.dtype
         ln_output = None
 
-        quantizer_set = self.generate_quantizer_set()
+        quantizer_set = self.generate_quantizer_set(
+            quantization_checkpoint_name=self.quantization_checkpoint_name
+        )
 
         fuse_layernorm = (
             get_quantize_config().is_fp8_enabled()
@@ -941,6 +957,8 @@ class LayerNormMLP(TransformerEngineBase):
         The data type used to allocate the initial parameters.
     transpose_batch_sequence: bool, default = False
         Indicate whether to transpose the batch and sequence dimensions of the input tensor.
+    quantization_checkpoint_name: Optional[str], default = None
+        The name for checkpointing quantizations.
     """
 
     intermediate_dim: int = 2048
@@ -976,6 +994,7 @@ class LayerNormMLP(TransformerEngineBase):
     ffn1_ckpt_name: str = "ffn1"
     ffn2_ckpt_name: str = "ffn2"
     transpose_batch_sequence: bool = False
+    quantization_checkpoint_name: Optional[str] = None
 
     def __post_init__(self):
         if self.kernel_init is None:
@@ -1010,8 +1029,12 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
         """
        assert self.axis == -1, "Only support axis == -1 at this moment"
 
-        ffn1_quantizer_set = self.generate_quantizer_set("_0")
-        ffn2_quantizer_set = self.generate_quantizer_set("_1")
+        ffn1_quantizer_set = self.generate_quantizer_set(
+            "_0", quantization_checkpoint_name=self.quantization_checkpoint_name
+        )
+        ffn2_quantizer_set = self.generate_quantizer_set(
+            "_1", quantization_checkpoint_name=self.quantization_checkpoint_name
+        )
 
         input_dtype = inputs.dtype
         ln_output = None
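
Because the name is plumbed through `generate_quantizer_set`, the same attribute works on `DenseGeneral` and `LayerNormDenseGeneral` as well as `LayerNormMLP`. A brief sketch (the recipe and input shape are placeholders, not values from this commit):

```python
import jax
import jax.numpy as jnp

import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common.recipe import DelayedScaling  # placeholder recipe

x = jnp.ones((128, 256), dtype=jnp.bfloat16)

with te.autocast(enabled=True, recipe=DelayedScaling(), mesh_resource=te.MeshResource()):
    dense = te_flax.DenseGeneral(
        features=512,
        dtype=jnp.bfloat16,
        # Forwarded to generate_quantizer_set and, from there, to the quantizers.
        quantization_checkpoint_name="quantization",
    )
    params = dense.init(jax.random.PRNGKey(0), x)
    y = dense.apply(params, x)
```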

transformer_engine/jax/layernorm_dense.py

Lines changed: 2 additions & 2 deletions

@@ -236,8 +236,8 @@ def _layernorm_dense_fwd_rule(
         output += jnp.reshape(bias, bias_new_shape)
 
     ctx = (
-        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
-        casted_kernel.get_tensor(TensorUsage.RHS_TRANS),
+        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x),
+        casted_kernel.get_tensor(TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel),
         x.shape,
         kernel.shape,
         mu,

transformer_engine/jax/layernorm_mlp.py

Lines changed: 4 additions & 4 deletions

@@ -390,11 +390,11 @@ def _layernorm_mlp_fwd_rule(
         rsigma,
         gamma,
         beta,
-        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS),
-        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS),
+        casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn1_quantizer_set.x),
+        casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn1_quantizer_set.kernel),
         dot_1_output,
-        casted_act_out.get_tensor(TensorUsage.LHS_TRANS),
-        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS),
+        casted_act_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn2_quantizer_set.x),
+        casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn2_quantizer_set.kernel),
         x_contracting_dims,
         k_contracting_dims,
         kernel_1.shape,
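
All three forward rules follow the same pattern: the quantized residuals stored in `ctx` for the custom-VJP backward pass are tagged just before being returned, which is what the "move checkpointing to VJPs" step in the commit message refers to. A minimal `jax.custom_vjp` sketch of that pattern, with illustrative names:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


@jax.custom_vjp
def toy_matmul(x, w):
    return x @ w


def _toy_matmul_fwd(x, w):
    out = x @ w
    # Tag the residuals the backward pass needs, mirroring how the TE forward
    # rules call .checkpoint(...) on the quantized tensors placed in ctx.
    ctx = (checkpoint_name(x, "quantization"), checkpoint_name(w, "quantization"))
    return out, ctx


def _toy_matmul_bwd(ctx, grad_out):
    x, w = ctx
    return grad_out @ w.T, x.T @ grad_out


toy_matmul.defvjp(_toy_matmul_fwd, _toy_matmul_bwd)

grads = jax.grad(lambda x, w: toy_matmul(x, w).sum(), argnums=(0, 1))(
    jnp.ones((4, 8)), jnp.ones((8, 2))
)
```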
