@@ -263,23 +263,16 @@ def test_autocast_nvfp4_block_scaling(self):
 class TestJaxprAndHlo:
     """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations."""
 
-    @pytest_parametrize_wrapper(
-        "quantization_recipe",
-        [
-            quantization_recipe
-            for quantization_recipe in SUPPORTED_RECIPES
-            if isinstance(quantization_recipe, NVFP4BlockScaling)
-        ],
-    )
-    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
-        """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton."""
-
+    def _generate_jaxpr_for_layernorm_mlp_fwd_bwd(self, quantization_recipe, ln_mlp_kwargs=None):
+        """Generates the jaxpr for a forward and backward pass of LayerNormMLP under the given quantization recipe."""
+        ln_mlp_kwargs = ln_mlp_kwargs or {}
         with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()):
             model = te_flax.LayerNormMLP(
                 layernorm_type="rmsnorm",
                 return_layernorm_output=False,
                 intermediate_dropout_rate=0.0,
                 dtype=jnp.bfloat16,
+                **ln_mlp_kwargs,
             )
 
             var_collect = model.init(
@@ -292,29 +285,88 @@ def loss_fn(x, rngs):
 
             x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16)
             rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)}
-            jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
-
-            rht_amax_eqns = [
-                eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
-            ]
-
-            assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
-
-            def assert_param(index, tensor_name, expected_value: bool):
-                if expected_value:
-                    assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
-                        f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
-                        " reuse of amax as this tensor does not have a previous operation to fuse"
-                        " with"
-                    )
-                else:
-                    assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
-                        f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
-                        " reuse of amax"
-                    )
-
-            assert_param(0, "fwd ln+q", False)
-            assert_param(1, "fwd act+q", False)
-            # No previous op before incoming dgrad in the backward so amax is not reused
-            assert_param(2, "bwd dgrad", True)
-            assert_param(3, "bwd dact+q", False)
+            return jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs)
+
+    @pytest_parametrize_wrapper(
+        "quantization_recipe",
+        [
+            quantization_recipe
+            for quantization_recipe in SUPPORTED_RECIPES
+            if isinstance(quantization_recipe, NVFP4BlockScaling)
+        ],
+    )
+    def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe):
+        """Tests that layernorm_mlp reuses the amax computed by the layernorm and the activation and does not recompute it during quantization."""
+
+        jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe)
+
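+        # Collect the fused RHT+amax quantization eqns emitted for the NVFP4 recipe, in trace order.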
+        rht_amax_eqns = [
+            eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper"
+        ]
+
+        assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}"
+
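+        # produce_regular_amax=False means the quantization reused the amax already computed by
+        # the preceding fused op (layernorm or activation); True means it computed a fresh amax.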
+        def assert_param(index, tensor_name, expected_value: bool):
+            if expected_value:
+                assert rht_amax_eqns[index].params["produce_regular_amax"] == True, (
+                    f"Expected produce_regular_amax for {tensor_name} to be True, indicating no"
+                    " reuse of amax as this tensor does not have a previous operation to fuse"
+                    " with"
+                )
+            else:
+                assert rht_amax_eqns[index].params["produce_regular_amax"] == False, (
+                    f"Expected produce_regular_amax for {tensor_name} to be False, indicating"
+                    " reuse of amax"
+                )
+
+        assert_param(0, "fwd ln+q", False)
+        assert_param(1, "fwd act+q", False)
+        # No previous op before incoming dgrad in the backward so amax is not reused
+        assert_param(2, "bwd dgrad", True)
+        assert_param(3, "bwd dact+q", False)
+
+    @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper(
+        "quantization_checkpoint_name",
+        [None, "quantization", "some_arbitrary_user_checkpoint_name"],
+    )
+    def test_recipe_supports_quantization_checkpointing(
+        self, quantization_recipe, quantization_checkpoint_name
+    ):
+        """Tests that all supported quantization recipes correctly apply the user-provided quantization_checkpoint_name."""
+
+        kwargs = {
+            "quantization_checkpoint_name": quantization_checkpoint_name,
+        }
+        jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe, kwargs)
+
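+        # checkpoint_name tags a value with a `name` eqn carrying the user-provided label, which
+        # checkpoint policies such as save_only_these_names can later match on.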
+        checkpoint_name_eqns = [
+            eqn
+            for eqn in jaxpr.jaxpr.eqns
+            if eqn.primitive.name == "name" and eqn.params["name"] == quantization_checkpoint_name
+        ]
+
+        if quantization_checkpoint_name is None:
+            assert len(checkpoint_name_eqns) == 0, (
+                "Expected 0 checkpoint_name eqns when quantization_checkpoint_name is None, got"
+                f" {len(checkpoint_name_eqns)}"
+            )
+            return
+
+        # 12 checkpointed values:
+        # - Fwd pass:
+        #     - Input RMSNorm+Q -> 3 possible output tensors that will be used in the backward
+        #     - Kernel Q -> 3 possible output tensors that will be used in the backward
+        #     - Input Activation+Q -> 3 possible output tensors that will be used in the backward
+        #     - Kernel Q -> 3 possible output tensors that will be used in the backward
+        expected_checkpoint_eqn_count = 12
+
+        assert len(checkpoint_name_eqns) == expected_checkpoint_eqn_count, (
+            f"Expected {expected_checkpoint_eqn_count} checkpoint_name eqns when"
+            f" quantization_checkpoint_name is set, got {len(checkpoint_name_eqns)}"
+        )