[WIP] [mxfp8 moe training] add torchao MXFP8 MoE training integration; bump version guard #1701
First changed file (the MXConverter model converter):

@@ -30,6 +30,7 @@ class MXConverter(ModelConverter):
     enabled: bool
     filter_fqns: List[str]
     mx_config: Any  # MXLinearConfig type when imported
+    token_group_alignment_size = 32

     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         # Ensure minimum torchao versions
@@ -39,8 +40,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         )
         torchao_version = version("torchao")

-        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
-        is_nightly_build = torchao_version.startswith("0.13.0")
+        # Require latest release or nightly builds for prototype features
+        is_nightly_build = torchao_version.startswith("0.14.0")
         if not is_nightly_build:
             raise ImportError(
                 f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
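For reference, a minimal standalone sketch of the version guard above, assuming `version` is `importlib.metadata.version` (that import is outside this hunk):

```python
from importlib.metadata import version

# Nightly torchao builds report the next unreleased version (e.g. "0.14.0+git..."),
# so a prefix check distinguishes them from older stable releases.
torchao_version = version("torchao")
if not torchao_version.startswith("0.14.0"):
    raise ImportError(
        f"torchao version {torchao_version} is too old, "
        "please install torchao nightly build and try again"
    )
```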
@@ -52,7 +53,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ), "MXFP8 is only supported on SM100 or architectures"

-        # TP not yet supported with torch.compile
         model_compile_enabled = (
             job_config.compile.enable and "model" in job_config.compile.components
         )
@@ -61,10 +61,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ), "TP not yet supported with torch.compile for mxfp8"

         # For MoE training with mxfp8, token group sizes must be multiples of 32
-        if job_config.mx.moe_fqns_prototype:
-            mxfp8_block_size = 32
-            set_token_group_alignment_size_m(mxfp8_block_size)
-            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+        self.moe_fqns = job_config.mx.moe_fqns_prototype
+        if self.moe_fqns:
+            logger.info(
+                f"Setting token group alignment size to {self.token_group_alignment_size}"
+            )
+            set_token_group_alignment_size_m(self.token_group_alignment_size)

         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
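For background, MX formats share one scale per contiguous block of 32 elements, which is why each expert's token group must be padded to a multiple of 32. A minimal sketch of the alignment arithmetic (the helper name is made up for illustration and is not torchao's API):

```python
def align_token_group_size(num_tokens: int, alignment: int = 32) -> int:
    """Round a token group size up to the next multiple of the MXFP8 block size."""
    return ((num_tokens + alignment - 1) // alignment) * alignment

# e.g. a group of 65 routed tokens is padded to 96 so every scaling block is full
assert align_token_group_size(65) == 96
assert align_token_group_size(64) == 64
```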
@@ -94,6 +96,13 @@ def convert(self, model: nn.Module):
         from torchao.prototype.mx_formats.config import MXLinearConfig
         from torchao.quantization import quantize_

+        # MoE conversion must take place before MXLinear conversion, otherwise the MXLinear will
+        # be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
+        if self.moe_fqns:
+            self._convert_moe_layers(model)
+
         assert isinstance(self.config, MXLinearConfig)
         quantize_(
             model,

Review thread on the MoE-conversion ordering:

Comment: it feels like the converter registered for self.config should handle this case specifically

Reply: as in having separate converters registered for mxfp8 moe, fp8 moe, etc.? If so, I kind of agree, actually: the converters are becoming a bit of a mess, and having separate converters would also allow users to convert just dense or just MoE (or both), rather than the current state of having to convert dense in order to convert MoE.

Reply: Yeah, that is one way; another thing I was thinking is that …
@@ -102,6 +111,30 @@ def convert(self, model: nn.Module):
         )
         logger.info("Swapped to MXLinear layers")

+    def _convert_moe_layers(self, model: nn.Module):
+        """
+        Mutates the model in place, replacing instances of nn.Parameter with ScaledGroupedMMTensor
+        to perform dynamic MXFP8 quantization + scaled grouped GEMMs for the target MoE FQNs.
+        """
+        from torchao.prototype.moe_training.conversion_utils import (
+            MoEScalingType,
+            MoETrainingConfig,
+        )
+        from torchao.quantization.quant_api import quantize_
+
+        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+            for target_fqn in self.moe_fqns:
+                if target_fqn in cur_fqn:
+                    return True
+            return False
+
+        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+        logger.info(
+            f"Converted MoE layers matching FQNs {self.moe_fqns} "
+            "to use dynamic MXFP8 quantization with scaled grouped GEMMs"
+        )

     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         """
         MXFP8 doesn't require any post-optimizer hooks at the moment

Review thread on lines +125 to +129 (the moe_module_filter_fn helper):

Comment: do you think we can put this in …
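To make the filter semantics concrete, here is a small self-contained sketch of the same substring-based FQN matching; the module names below are hypothetical and used only for illustration:

```python
moe_fqns = ["experts", "shared_experts"]

def matches_moe_fqn(cur_fqn: str) -> bool:
    # A module is selected when any configured FQN fragment appears in its full name.
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

# Hypothetical module paths:
assert matches_moe_fqn("layers.0.moe.experts")             # converted
assert matches_moe_fqn("layers.3.moe.shared_experts.w1")   # converted
assert not matches_moe_fqn("layers.0.attention.wq")        # skipped
```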
Second changed file (TOML training config):
@@ -75,3 +75,7 @@ enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output", "router.gate"]
 moe_fqns_prototype = ["experts"]
+
+[mx]
+filter_fqns = ["output", "router.gate"]
+moe_fqns_prototype = ["experts,shared_expert"]

Review thread on the new [mx] section:

Comment: shared_expert is now made up of 3 …

Reply: oh, I didn't know that this had changed. Will update accordingly.
Comment (on the token_group_alignment_size class attribute): Maybe make this a MACRO if it's not really a "variable". E.g. you can keep a central map in quantization/__init__.py for both fp8 and mx.
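A sketch of the reviewer's suggestion above, purely hypothetical and not existing code: a shared constant map that both the float8 and MX converters could read from. Only the MX value of 32 comes from this PR; the float8 value here is an assumption.

```python
# torchtitan/components/quantization/__init__.py (hypothetical sketch)

# Token group alignment required by each MoE training recipe.
MOE_TOKEN_GROUP_ALIGNMENT_SIZE: dict[str, int] = {
    "float8": 16,  # assumption: alignment used by the fp8 rowwise grouped GEMM recipe
    "mx": 32,      # MXFP8 scaling block size, per this PR
}
```

Each converter would then call set_token_group_alignment_size_m(MOE_TOKEN_GROUP_ALIGNMENT_SIZE[...]) instead of keeping a per-class constant.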