
Conversation

@danielvegamyhre (Contributor) commented Sep 12, 2025

WIP

  • planning to update the PR with perf numbers and loss curves before landing

Summary

  • We've recently landed prototype MXFP8 MoE training support in torchao (eager only currently - compile support in progress)
  • This PR adds support for MXFP8 MoE training through torchao and bumps the version guard (torchao 0.13.0 was just released, and we want users on nightly for now, since this is a prototype feature we're still actively extending, e.g. with compile support)
  • Add default toml configs for the MX component, filtering output and router.gate out of conversion (fixes "MXFP8 error for Llama4 from MXLinear" #1703)
  • Set the default MoE target FQNs to experts and shared_expert

Test plan

Llama4 debug model config:

    "debugmodel": TransformerModelArgs(
        dim=5120,
        n_layers=2,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        max_seq_len=10485760,
        moe_args=MoEArgs(num_experts=4),
        interleave_moe_layer_step=1,
    ),

Limitations

@meta-cla bot added the CLA Signed label Sep 12, 2025
@danielvegamyhre marked this pull request as draft September 12, 2025 00:21
@danielvegamyhre marked this pull request as ready for review September 12, 2025 00:27
@danielvegamyhre (Contributor Author)

cc @tianyu-l @drisspg for review

@@ -39,8 +40,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         )
         torchao_version = version("torchao")

-        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
-        is_nightly_build = torchao_version.startswith("0.13.0")
+        # Last torchao release was 0.13.0, so nightly build starts with 0.13.0+git...
Contributor

nit: update this comment so it isn't tied to any specific version
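
For illustration, a version-agnostic variant of the check is sketched below. It relies only on the observation in the diff comment above that nightly builds carry a "+git..." suffix; it is not the PR's actual code.

from importlib.metadata import version

torchao_version = version("torchao")
# Assumption: torchao nightly wheels carry a "+git<sha>" local version segment,
# as described in the comment above, while tagged releases do not.
is_nightly_build = "+git" in torchao_version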

        logger.info(
            f"Setting token group alignment size to {self.mxfp8_token_group_alignment_size}"
        )
        set_token_group_alignment_size_m(self.mxfp8_token_group_alignment_size)
Contributor

If we're making this alignment a variable, maybe update the name so it's not mxfp8-specific, if your intention is that if/when we add nvfp4/mxfp4 we just bump this value.

        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
        if self.moe_fqns:
            self._convert_moe_layers(model)

Contributor

it feels like the converter registered for self.config should handle this case specifically

Contributor Author

As in having separate converters registered for mxfp8 MoE, fp8 MoE, etc.? If so, I kind of agree, actually: the converters are becoming a bit of a mess, and having separate converters would also allow users to convert just dense or just MoE (or both), rather than the current state of having to convert dense in order to convert MoE.

Contributor

Yeah, that is one way. Another thing I was thinking is that https://github.com/pytorch/ao/blob/93030e750186ace1c1c2ee7a849e2818a9f0ffde/torchao/prototype/moe_training/conversion_utils.py#L50 should be able to gracefully handle the case where the module has already been converted.
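
For illustration, a minimal sketch of such an "already converted" guard. The helper name and signature are hypothetical, not torchao's actual conversion_utils API; the converter would pass ScaledGroupedMMTensor as tensor_cls and skip modules for which this returns True instead of raising.

import torch.nn as nn

def is_already_converted(module: nn.Module, tensor_cls: type) -> bool:
    # True if any parameter of this module (or its underlying .data) is already
    # an instance of the training tensor subclass, e.g. ScaledGroupedMMTensor.
    return any(
        isinstance(p, tensor_cls) or isinstance(p.data, tensor_cls)
        for p in module.parameters(recurse=False)
    )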

@tianyu-l (Contributor) left a comment

In the log

Swapped w1.weight to ScaledGroupedMMTensor

it's hard to tell which w1 is converted. Can we include the full FQN?

(eager only currently - compile support in progress)

Does it mean we shouldn't care too much about performance with this PR?
How about numerics?

I'll try to find some time to look into #1651
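
For the full-FQN logging ask above, a minimal sketch of building the message from model.named_modules(); the function name and the target-FQN matching are assumptions, not the PR's code.

import logging

import torch.nn as nn

logger = logging.getLogger(__name__)

def log_converted_weights(model: nn.Module, target_fqns: list[str]) -> None:
    # Log the full FQN of each weight that matches a target, so the message reads
    # e.g. "Swapped layers.0.moe.experts.w1 to ScaledGroupedMMTensor" instead of
    # just "Swapped w1.weight to ScaledGroupedMMTensor".
    for module_fqn, module in model.named_modules():
        if not any(target in module_fqn for target in target_fqns):
            continue
        for param_name, _ in module.named_parameters(recurse=False):
            logger.info(f"Swapped {module_fqn}.{param_name} to ScaledGroupedMMTensor")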

@@ -30,6 +30,7 @@ class MXConverter(ModelConverter):
     enabled: bool
     filter_fqns: List[str]
     mx_config: Any  # MXLinearConfig type when imported
+    token_group_alignment_size = 32
Contributor

Maybe make this a MACRO if it's not really a "variable". E.g. you can keep a central map in quantization/__init__.py for both fp8 and mx.
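
For illustration, a sketch of such a central map; the keys and the float8 value are assumptions, and only the mxfp8 block size of 32 comes from this PR.

# quantization/__init__.py (hypothetical layout per the suggestion above):
# token-group alignment required by each low-precision recipe so that
# per-expert token groups line up with the scaling granularity.
TOKEN_GROUP_ALIGNMENT_SIZE_M = {
    "float8": 16,  # assumed value for the fp8 rowwise recipe; not from this PR
    "mx": 32,      # mxfp8 scale blocks span 32 elements, matching this PR's default
}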

Contributor

Please revert the changes in this file, as I believe they were submitted accidentally.


[mx]
filter_fqns = ["output", "router.gate"]
moe_fqns_prototype = ["experts,shared_expert"]
Contributor

shared_expert is now made up of 3 nn.Linear modules instead of GroupedExperts, so maybe you don't want to include it? In your paste they are converted to MXLinear.

Contributor Author

Oh, I didn't know that had changed. Will update accordingly.

Comment on lines +125 to +129
        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
            for target_fqn in self.moe_fqns:
                if target_fqn in cur_fqn:
                    return True
            return False
Contributor

do you think we can put this in quantization/utils.py, together with the similar logic in the fp8 file?
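
For illustration, a sketch of a shared helper along these lines; the quantization/utils.py location follows the suggestion above, while the function name and signature are assumptions.

from functools import partial
from typing import Sequence

import torch.nn as nn

def module_filter_fn(mod: nn.Module, cur_fqn: str, target_fqns: Sequence[str]) -> bool:
    # Shared predicate: convert a module only if its FQN contains one of the targets.
    return any(target_fqn in cur_fqn for target_fqn in target_fqns)

Each converter could then bind its own target list, e.g. partial(module_filter_fn, target_fqns=self.moe_fqns), instead of defining its own nested closure.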

@danielvegamyhre marked this pull request as draft September 14, 2025 00:23