@@ -30,6 +30,7 @@ class MXConverter(ModelConverter):
     enabled: bool
     filter_fqns: List[str]
     mx_config: Any  # MXLinearConfig type when imported
+    mxfp8_token_group_alignment_size = 32

     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         # Ensure minimum torchao versions
@@ -39,8 +40,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         )
         torchao_version = version("torchao")

-        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
-        is_nightly_build = torchao_version.startswith("0.13.0")
+        # Last torchao release was 0.13.0, so nightly build starts with 0.14.0+git...
+        is_nightly_build = torchao_version.startswith("0.14.0")
         if not is_nightly_build:
             raise ImportError(
                 f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
@@ -62,9 +63,11 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

         # For MoE training with mxfp8, token group sizes must be multiples of 32
         if job_config.mx.moe_fqns_prototype:
-            mxfp8_block_size = 32
-            set_token_group_alignment_size_m(mxfp8_block_size)
-            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+            logger.info(
+                f"Setting token group alignment size to {self.mxfp8_token_group_alignment_size}"
+            )
+            set_token_group_alignment_size_m(self.mxfp8_token_group_alignment_size)
+        self.moe_fqns = job_config.mx.moe_fqns_prototype

         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
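As a rough illustration of the alignment requirement enforced above (a sketch only, not torchtitan or torchao code): MXFP8 shares one scale per block of 32 elements, so the per-expert token counts fed to the grouped GEMM are rounded up to a multiple of the alignment size. The helper name below is hypothetical; `set_token_group_alignment_size_m(32)` from torchao is what actually configures this in the diff.

```python
import torch


def pad_token_group_sizes(group_sizes: torch.Tensor, alignment: int = 32) -> torch.Tensor:
    """Round each per-expert token count up to a multiple of `alignment`.

    Hypothetical helper for illustration; it only shows the arithmetic implied
    by mxfp8_token_group_alignment_size = 32 in the class above.
    """
    return ((group_sizes + alignment - 1) // alignment) * alignment


# Example: per-expert token counts before and after alignment.
sizes = torch.tensor([5, 32, 45, 0])
print(pad_token_group_sizes(sizes))  # tensor([32, 32, 64,  0])
```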
@@ -94,6 +97,13 @@ def convert(self, model: nn.Module):
         from torchao.prototype.mx_formats.config import MXLinearConfig
         from torchao.quantization import quantize_

+        # MoE conversion must take place before MXLinear conversion, otherwise the MXLinear will
+        # be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
+        if self.moe_fqns:
+            self._convert_moe_layers(model)
+
         assert isinstance(self.config, MXLinearConfig)
         quantize_(
             model,
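Because the ordering bug described in the comment manifests as MXLinear modules being silently swapped back to nn.Linear, a cheap post-conversion audit of module types can confirm the two passes did not undo each other. This is an illustrative helper, not part of the PR:

```python
from collections import Counter

import torch.nn as nn


def summarize_module_types(model: nn.Module) -> Counter:
    """Count module classes by name after conversion.

    Useful for eyeballing that the MoE pass and the MXLinear pass did not undo
    each other (e.g. an unexpected absence of "MXLinear" entries would suggest
    those modules were converted back to nn.Linear).
    """
    return Counter(type(m).__name__ for m in model.modules())


# Usage sketch (after converter.convert(model)):
# logger.info(f"Module type counts after conversion: {summarize_module_types(model)}")
```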
@@ -102,6 +112,36 @@ def convert(self, model: nn.Module):
         )
         logger.info("Swapped to MXLinear layers")

+    def _convert_moe_layers(self, model: nn.Module):
+        """
+        Mutates the model in place, replacing instances of nn.Parameter with ScaledGroupedMMTensor
+        to perform dynamic MXFP8 quantization + scaled grouped GEMMs for the target MoE FQNs.
+        """
+        from torchao.quantization.quant_api import quantize_
+
+        try:
+            from torchao.prototype.moe_training.conversion_utils import (
+                MoEScalingType,
+                MoETrainingConfig,
+            )
+        except ImportError as e:
+            raise ImportError(
+                "torchao installation does not have MoE training support. Please install torchao nightly build."
+            ) from e
+
+        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+            for target_fqn in self.moe_fqns:
+                if target_fqn in cur_fqn:
+                    return True
+            return False
+
+        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+        logger.info(
+            f"Converted MoE layers matching FQNS {self.moe_fqns} "
+            "to use dynamic MXFP8 quantization with scaled grouped GEMMs"
+        )
+
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         """
         MXFP8 doesn't require any post-optimizer hooks at the moment
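The filter used by `_convert_moe_layers` above is a plain substring match against each module's fully qualified name, so a configured fragment such as `"experts"` matches every expert submodule beneath it. A small self-contained example of that behaviour (the FQN strings and the `"experts"` fragment below are made up for illustration):

```python
import torch.nn as nn

# Substring-based FQN filter, mirroring moe_module_filter_fn in the diff above.
moe_fqns = ["experts"]


def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)


print(moe_module_filter_fn(nn.Linear(8, 8), "layers.0.moe.experts.w1"))  # True
print(moe_module_filter_fn(nn.Linear(8, 8), "layers.0.attention.wq"))    # False
```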