@@ -30,6 +30,7 @@ class MXConverter(ModelConverter):
     enabled: bool
     filter_fqns: List[str]
     mx_config: Any  # MXLinearConfig type when imported
+    mxfp8_token_group_alignment_size = 32

     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         # Ensure minimum torchao versions
@@ -39,8 +40,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         )
         torchao_version = version("torchao")

-        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
-        is_nightly_build = torchao_version.startswith("0.13.0")
+        # Last torchao release was 0.13.0, so nightly build starts with 0.14.0+git...
+        is_nightly_build = torchao_version.startswith("0.14.0")
         if not is_nightly_build:
             raise ImportError(
                 f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
@@ -52,7 +53,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ), "MXFP8 is only supported on SM100 or newer architectures"

         # TP not yet supported with torch.compile
-
         model_compile_enabled = (
             job_config.compile.enable and "model" in job_config.compile.components
         )
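The assertion above gates MXFP8 on SM100-class (Blackwell) GPUs. A minimal standalone sketch of that capability check, assuming the standard torch.cuda API and that SM100 reports compute capability (10, 0); the exact assert lives outside this hunk:

import torch

def has_mxfp8_capable_gpu() -> bool:
    # SM100 (Blackwell) reports CUDA compute capability (10, 0); older GPUs return a lower tuple.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)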
@@ -61,10 +61,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ), "TP not yet supported with torch.compile for mxfp8"

         # For MoE training with mxfp8, token group sizes must be multiples of 32
-        if job_config.mx.moe_fqns_prototype:
-            mxfp8_block_size = 32
-            set_token_group_alignment_size_m(mxfp8_block_size)
-            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+        self.moe_fqns = job_config.mx.moe_fqns_prototype
+        if self.moe_fqns:
+            logger.info(
+                f"Setting token group alignment size to {self.mxfp8_token_group_alignment_size}"
+            )
+            set_token_group_alignment_size_m(self.mxfp8_token_group_alignment_size)

         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
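The multiple-of-32 requirement comes from MX block scaling: one scale is computed per 32-element block, so each expert's token group has to end on a 32-token boundary. A toy illustration of the rounding this alignment implies; a sketch only, not torchao's set_token_group_alignment_size_m, and align_token_group is a made-up helper:

def align_token_group(num_tokens: int, alignment: int = 32) -> int:
    """Round a per-expert token count up to the next multiple of `alignment`."""
    return ((num_tokens + alignment - 1) // alignment) * alignment

assert align_token_group(1) == 32
assert align_token_group(32) == 32
assert align_token_group(33) == 64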
@@ -94,6 +96,13 @@ def convert(self, model: nn.Module):
         from torchao.prototype.mx_formats.config import MXLinearConfig
         from torchao.quantization import quantize_

+        # MoE conversion must take place before MXLinear conversion, otherwise the MXLinear layers
+        # will be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
+        if self.moe_fqns:
+            self._convert_moe_layers(model)
+
         assert isinstance(self.config, MXLinearConfig)
         quantize_(
             model,
@@ -102,6 +111,36 @@ def convert(self, model: nn.Module):
         )
         logger.info("Swapped to MXLinear layers")

+    def _convert_moe_layers(self, model: nn.Module):
+        """
+        Mutates the model in place, replacing target nn.Parameter instances with ScaledGroupedMMTensor,
+        to perform dynamic MXFP8 quantization + scaled grouped GEMMs for the target MoE FQNs.
+        """
+        from torchao.quantization.quant_api import quantize_
+
+        try:
+            from torchao.prototype.moe_training.conversion_utils import (
+                MoEScalingType,
+                MoETrainingConfig,
+            )
+        except ImportError as e:
+            raise ImportError(
+                "torchao installation does not have MoE training support. Please install torchao nightly build."
+            ) from e
+
+        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+            for target_fqn in self.moe_fqns:
+                if target_fqn in cur_fqn:
+                    return True
+            return False
+
+        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+        logger.info(
+            f"Converted MoE layers matching FQNs {self.moe_fqns} "
+            "to use dynamic MXFP8 quantization with scaled grouped GEMMs"
+        )
+
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         """
         MXFP8 doesn't require any post-optimizer hooks at the moment
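moe_module_filter_fn selects modules by substring match against their fully qualified names, so a single configured fragment (e.g. "experts") covers every expert module underneath it. A self-contained sketch of that matching rule, with invented module names for illustration:

moe_fqns = ["experts"]

def matches(cur_fqn: str) -> bool:
    # Same rule as moe_module_filter_fn above: convert when any target FQN
    # fragment appears anywhere in the module's fully qualified name.
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

assert matches("layers.0.moe.experts")        # converted to scaled grouped GEMMs
assert not matches("layers.0.attention.wq")   # left for the later MXLinear swap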