Commit 90582f1

[mxfp8 moe training] update torchao MXFP8 MoE training integration; bump version guard
1 parent 612e4a1 commit 90582f1

File tree

1 file changed: +46 -7 lines changed

  • torchtitan/components/quantization/mx.py

torchtitan/components/quantization/mx.py

Lines changed: 46 additions & 7 deletions
@@ -30,6 +30,7 @@ class MXConverter(ModelConverter):
     enabled: bool
     filter_fqns: List[str]
     mx_config: Any  # MXLinearConfig type when imported
+    mxfp8_token_group_alignment_size = 32
 
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         # Ensure minimum torchao versions
@@ -39,8 +40,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         )
         torchao_version = version("torchao")
 
-        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
-        is_nightly_build = torchao_version.startswith("0.13.0")
+        # Last torchao release was 0.13.0, so nightly build starts with 0.14.0+git...
+        is_nightly_build = torchao_version.startswith("0.14.0")
         if not is_nightly_build:
             raise ImportError(
                 f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
@@ -52,7 +53,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ), "MXFP8 is only supported on SM100 or newer architectures"
 
         # TP not yet supported with torch.compile
-
         model_compile_enabled = (
             job_config.compile.enable and "model" in job_config.compile.components
         )
@@ -61,10 +61,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ), "TP not yet supported with torch.compile for mxfp8"
 
         # For MoE training with mxfp8, token group sizes must be multiples of 32
-        if job_config.mx.moe_fqns_prototype:
-            mxfp8_block_size = 32
-            set_token_group_alignment_size_m(mxfp8_block_size)
-            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+        self.moe_fqns = job_config.mx.moe_fqns_prototype
+        if self.moe_fqns:
+            logger.info(
+                f"Setting token group alignment size to {self.mxfp8_token_group_alignment_size}"
+            )
+            set_token_group_alignment_size_m(self.mxfp8_token_group_alignment_size)
 
         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
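Note on the alignment requirement: MXFP8 computes one shared scale per block of 32 consecutive elements, so each expert's token group fed to the scaled grouped GEMM must be padded up to a multiple of that block size, which is what `mxfp8_token_group_alignment_size` encodes. A quick sketch of the round-up arithmetic (the helper name is illustrative, not a torchao API):

def round_up_to_alignment(num_tokens: int, alignment: int = 32) -> int:
    # Round num_tokens up to the next multiple of `alignment`.
    return ((num_tokens + alignment - 1) // alignment) * alignment

assert round_up_to_alignment(100) == 128  # 100 tokens padded up to 128
assert round_up_to_alignment(64) == 64    # already a multiple of 32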
@@ -94,6 +96,13 @@ def convert(self, model: nn.Module):
         from torchao.prototype.mx_formats.config import MXLinearConfig
         from torchao.quantization import quantize_
 
+        # MoE conversion must take place before MXLinear conversion, otherwise the MXLinear will
+        # be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
+        if self.moe_fqns:
+            self._convert_moe_layers(model)
+
         assert isinstance(self.config, MXLinearConfig)
         quantize_(
             model,
@@ -102,6 +111,36 @@ def convert(self, model: nn.Module):
         )
         logger.info("Swapped to MXLinear layers")
 
+    def _convert_moe_layers(self, model: nn.Module):
+        """
+        Mutates the model in place, replacing instances of nn.Parameter with ScaledGroupedMMTensor
+        to perform dynamic MXFP8 quantization + scaled grouped GEMMs for the target MoE FQNs.
+        """
+        from torchao.quantization.quant_api import quantize_
+
+        try:
+            from torchao.prototype.moe_training.conversion_utils import (
+                MoEScalingType,
+                MoETrainingConfig,
+            )
+        except ImportError as e:
+            raise ImportError(
+                "torchao installation does not have MoE training support. Please install torchao nightly build."
+            ) from e
+
+        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+            for target_fqn in self.moe_fqns:
+                if target_fqn in cur_fqn:
+                    return True
+            return False
+
+        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+        logger.info(
+            f"Converted MoE layers matching FQNs {self.moe_fqns} "
+            "to use dynamic MXFP8 quantization with scaled grouped GEMMs"
+        )
+
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         """
         MXFP8 doesn't require any post-optimizer hooks at the moment
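Note on FQN matching: `moe_module_filter_fn` does a substring test of each configured FQN against the module's fully qualified name, so a single entry like "experts" selects every expert module whose FQN contains it. A condensed illustration (the module argument is dropped and the FQNs below are hypothetical, not taken from a real torchtitan model):

moe_fqns = ["experts"]  # e.g. the value of job_config.mx.moe_fqns_prototype

def matches_moe_fqn(cur_fqn: str) -> bool:
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

candidates = [
    "layers.0.moe.experts",      # selected for MXFP8 grouped-GEMM conversion
    "layers.0.moe.router.gate",  # skipped
    "layers.0.attention.wq",     # skipped
]
print([fqn for fqn in candidates if matches_moe_fqn(fqn)])  # ['layers.0.moe.experts']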
