Commit 960840d

[mxfp8 moe training] update torchao MXFP8 MoE training integration; bump version guard
1 parent 612e4a1 commit 960840d

1 file changed: 45 additions, 5 deletions

torchtitan/components/quantization/mx.py
@@ -30,6 +30,7 @@ class MXConverter(ModelConverter):
     enabled: bool
     filter_fqns: List[str]
     mx_config: Any  # MXLinearConfig type when imported
+    mxfp8_token_group_alignment_size = 32
 
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         # Ensure minimum torchao versions
@@ -39,8 +40,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         )
         torchao_version = version("torchao")
 
-        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
-        is_nightly_build = torchao_version.startswith("0.13.0")
+        # Last torchao release was 0.13.0, so nightly build starts with 0.14.0+git...
+        is_nightly_build = torchao_version.startswith("0.14.0")
         if not is_nightly_build:
             raise ImportError(
                 f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
@@ -62,9 +63,11 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
 
         # For MoE training with mxfp8, token group sizes must be multiples of 32
         if job_config.mx.moe_fqns_prototype:
-            mxfp8_block_size = 32
-            set_token_group_alignment_size_m(mxfp8_block_size)
-            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+            logger.info(
+                f"Setting token group alignment size to {self.mxfp8_token_group_alignment_size}"
+            )
+            set_token_group_alignment_size_m(self.mxfp8_token_group_alignment_size)
+            self.moe_fqns = job_config.mx.moe_fqns_prototype
 
         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
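The alignment requirement comes from MXFP8's scaling granularity: one scale is computed per 32-element block (the value the old code called mxfp8_block_size), so each expert's token group must be sized to a multiple of 32. A small sketch of that rounding arithmetic; the helper name and the example group sizes are illustrative, not torchao API:

def round_up_to_token_group_alignment(num_tokens: int, alignment: int = 32) -> int:
    # Round an expert's token count up to the next multiple of the
    # mxfp8 token group alignment size configured above.
    return ((num_tokens + alignment - 1) // alignment) * alignment

# Hypothetical per-expert token counts produced by the router.
token_group_sizes = [5, 32, 70]
assert [round_up_to_token_group_alignment(n) for n in token_group_sizes] == [32, 32, 96]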
@@ -94,6 +97,13 @@ def convert(self, model: nn.Module):
         from torchao.prototype.mx_formats.config import MXLinearConfig
         from torchao.quantization import quantize_
 
+        # MoE conversion must take place before MXLinear conversion, otherwise the MXLinear will
+        # be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
+        if self.moe_fqns:
+            self._convert_moe_layers(model)
+
         assert isinstance(self.config, MXLinearConfig)
         quantize_(
             model,
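One hedged way to sanity-check the ordering note above after convert() runs: confirm that parameters under the target MoE FQNs still carry the ScaledGroupedMMTensor subclass (the name comes from the docstring in the next hunk) rather than having been swapped back by the later MXLinear pass. Where exactly torchao attaches the subclass is an assumption here, so the sketch looks at both the parameter and its data tensor:

import torch.nn as nn

def check_moe_params_still_converted(model: nn.Module, moe_fqns: list[str]) -> None:
    # Illustrative post-conversion sanity check, not part of the integration:
    # every parameter under a target MoE FQN should still be backed by
    # ScaledGroupedMMTensor after the MXLinear conversion has also run.
    for name, param in model.named_parameters():
        if any(target_fqn in name for target_fqn in moe_fqns):
            subclass_names = {type(param).__name__, type(param.data).__name__}
            assert "ScaledGroupedMMTensor" in subclass_names, (
                f"{name} appears to have been converted back to a plain tensor"
            )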
@@ -102,6 +112,36 @@ def convert(self, model: nn.Module):
         )
         logger.info("Swapped to MXLinear layers")
 
+    def _convert_moe_layers(self, model: nn.Module):
+        """
+        Mutates the model in place, replacing instances of nn.Parameter with ScaledGroupedMMTensor
+        to perform dynamic MXFP8 quantization + scaled grouped GEMMs for the target MoE FQNs.
+        """
+        from torchao.quantization.quant_api import quantize_
+
+        try:
+            from torchao.prototype.moe_training.conversion_utils import (
+                MoEScalingType,
+                MoETrainingConfig,
+            )
+        except ImportError as e:
+            raise ImportError(
+                "torchao installation does not have MoE training support. Please install torchao nightly build."
+            ) from e
+
+        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+            for target_fqn in self.moe_fqns:
+                if target_fqn in cur_fqn:
+                    return True
+            return False
+
+        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+        logger.info(
+            f"Converted MoE layers matching FQNs {self.moe_fqns} "
+            "to use dynamic MXFP8 quantization with scaled grouped GEMMs"
+        )
+
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         """
         MXFP8 doesn't require any post-optimizer hooks at the moment
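The filter in _convert_moe_layers selects modules by substring match on the fully qualified name, so a single configured fragment covers the experts in every layer. A quick illustration with hypothetical FQNs (the module names below are made up for the example, and the module argument is dropped since only the FQN is inspected):

moe_fqns = ["experts"]

def matches_moe_fqn(cur_fqn: str) -> bool:
    # Same substring rule as moe_module_filter_fn above.
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

assert matches_moe_fqn("layers.0.moe.experts")        # converted by _convert_moe_layers
assert matches_moe_fqn("layers.7.moe.experts.w1")     # nested submodules match too
assert not matches_moe_fqn("layers.0.attention.wq")   # left for the MXLinear pass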
