[mxfp8 moe training] add torchao MXFP8 MoE training integration; bump version guard #1701
base: main
Conversation
Force-pushed from d416130 to 960840d, then from 960840d to 90582f1 (…ump version guard).
```diff
@@ -39,8 +40,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         )
         torchao_version = version("torchao")

-        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
-        is_nightly_build = torchao_version.startswith("0.13.0")
+        # Last torchao release was 0.13.0, so nightly build starts with 0.13.0+git...
```
nit: update this comment so it isn't tied to a specific version
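A minimal sketch of a version-agnostic guard, assuming it is acceptable to detect nightly builds from the local version segment rather than a hard-coded release prefix (this is an illustration, not the PR's actual change):

```python
from importlib.metadata import version

torchao_version = version("torchao")
# Nightly/dev builds carry a local version segment such as "0.13.0+gitabc1234",
# so checking for "+git" avoids hard-coding whichever release number comes next.
is_nightly_build = "+git" in torchao_version
```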
```python
logger.info(
    f"Setting token group alignment size to {self.mxfp8_token_group_alignment_size}"
)
set_token_group_alignment_size_m(self.mxfp8_token_group_alignment_size)
```
If we're making this alignment a variable, maybe update the name so it's not mxfp8-specific, if your intention is that when we add nvfp4/mxfp4 we just bump this value.
```python
# TODO: add warning in torchao when this happens, or find a better way to avoid this.
if self.moe_fqns:
    self._convert_moe_layers(model)
```
it feels like the converter registered for `self.config` should handle this case specifically
As in having separate converters registered for mxfp8 MoE, fp8 MoE, etc.? If so, I kind of agree: the converters are becoming a bit of a mess, and separate converters would also let users convert just dense or just MoE (or both), rather than the current state where dense must be converted in order to convert MoE.
Yeah, that is one way. Another thing I was thinking is that https://github.com/pytorch/ao/blob/93030e750186ace1c1c2ee7a849e2818a9f0ffde/torchao/prototype/moe_training/conversion_utils.py#L50 should be able to gracefully handle the case where the module has already been converted.
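A minimal sketch of such a guard, assuming the converted expert weights show up as torchao's `ScaledGroupedMMTensor` subclass (the helper name is hypothetical, not torchao's actual API):

```python
import torch.nn as nn

def _already_converted(module: nn.Module) -> bool:
    # Treat a module as converted if any of its own parameters were already
    # swapped to the grouped-GEMM tensor subclass, making re-conversion a no-op.
    return any(
        type(param.data).__name__ == "ScaledGroupedMMTensor"
        for param in module.parameters(recurse=False)
    )
```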
In the log `Swapped w1.weight to ScaledGroupedMMTensor`, it's hard to tell which `w1` is converted. Can we include the full FQN?
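A minimal sketch of how the swap log could carry the full FQN, assuming the converter iterates `model.named_modules()`; `should_convert` is a hypothetical stand-in for the existing filter:

```python
for module_fqn, module in model.named_modules():
    if not should_convert(module, module_fqn):  # hypothetical filter
        continue
    for param_name, _ in module.named_parameters(recurse=False):
        # Log the parent module's full FQN plus the parameter name.
        logger.info(f"Swapped {module_fqn}.{param_name} to ScaledGroupedMMTensor")
```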
> (eager only currently - compile support in progress)

Does this mean we shouldn't care too much about performance with this PR? How about numerics?
I'll try to find some time to look into #1651
```diff
@@ -30,6 +30,7 @@ class MXConverter(ModelConverter):
     enabled: bool
     filter_fqns: List[str]
     mx_config: Any  # MXLinearConfig type when imported
+    token_group_alignment_size = 32
```
Maybe make this a MACRO if it's not really a "variable". E.g. you can keep a central map in `quantization/__init__.py` for both fp8 and mx.
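A minimal sketch of such a central map, assuming it lives in `quantization/__init__.py`; the constant name and the fp8 value are illustrative, and only the mxfp8 value of 32 comes from this PR:

```python
# Central token-group alignment constants shared by the fp8 and mx converters,
# so per-recipe values live in one place instead of per-converter attributes.
TOKEN_GROUP_ALIGNMENT_SIZE_M = {
    "fp8": 16,    # illustrative value, not taken from this PR
    "mxfp8": 32,  # mxfp8 scaling blocks are 32 elements wide
}
```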
Please revert the changes in this file, as I believe they were accidentally submitted.
```toml
[mx]
filter_fqns = ["output", "router.gate"]
moe_fqns_prototype = ["experts,shared_expert"]
```
`shared_expert` is now made up of 3 `nn.Linear` instead of `GroupedExperts`, so maybe you don't want to include them? In your paste they are converted to `MXLinear`.
oh, i didn't know that this had changed. will update accordingly
```python
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    for target_fqn in self.moe_fqns:
        if target_fqn in cur_fqn:
            return True
    return False
```
Do you think we can put this in `quantization/utils.py`, together with the similar thing in the fp8 file?
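A minimal sketch of a shared helper that could live in `quantization/utils.py` and be reused by both the fp8 and mx converters; the function name is hypothetical:

```python
from typing import Callable, List

import torch.nn as nn


def build_moe_module_filter(target_fqns: List[str]) -> Callable[[nn.Module, str], bool]:
    """Return a filter matching modules whose FQN contains any of the target substrings."""

    def _filter(mod: nn.Module, cur_fqn: str) -> bool:
        return any(target_fqn in cur_fqn for target_fqn in target_fqns)

    return _filter
```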
WIP
Summary
- Exclude `output,router.gate` from conversion (fixes MXFP8 error for Llama4 from MXLinear #1703)
- Convert MoE layers matching `experts,shared_expert`
Test plan
Llama4 debug model config:
Limitations