
Commit 7660825

[mxfp8 moe training] add compile support
1 parent 66384a9 commit 7660825

File tree

4 files changed (+35, -21 lines)

test/prototype/moe_training/test_training.py
Lines changed: 1 addition & 4 deletions

@@ -34,10 +34,7 @@
 
 @pytest.mark.parametrize(
     "target_fqns",
-    [
-        ["experts"],
-        ["does.not.exist"],
-    ],
+    [["experts"], ["experts,shared_expert"], ["invalid.fqns"]],
 )
 @pytest.mark.parametrize("compile", [False, True])
 @pytest.mark.parametrize(

torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py
Lines changed: 7 additions & 3 deletions

@@ -4,6 +4,7 @@
 import triton
 import triton.language as tl
 from torch import Tensor
+from torch.library import triton_op, wrap_triton
 
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import ceil_div
@@ -192,6 +193,7 @@ def compute_blocked_scale_offsets_for_K_groups(
     return group_sizes, starting_col_after_padding
 
 
+@triton_op("torchao::triton_mx_block_rearrange_2d_M_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_M_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -238,7 +240,7 @@ def triton_mx_block_rearrange_2d_M_groups(
         num_groups,
         num_col_blocks,
     )
-    triton_scale_swizzle_M_groups[grid](
+    wrap_triton(triton_scale_swizzle_M_groups)[grid](
         # Input scales
         scales_tensor.view(torch.uint8),
         scales_tensor.stride(0),
@@ -336,6 +338,7 @@ def triton_scale_swizzle_M_groups(
         current_start_row += BLOCK_ROWS
 
 
+@triton_op("torchao::triton_mx_block_rearrange_per_group_3d", mutates_args={})
 def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
     """
     Rearranges an E8M0 tensor scale to block-scaled swizzle format.
@@ -379,7 +382,7 @@ def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
         num_col_blocks,
     )
 
-    triton_scale_swizzle_per_group_3d[grid](
+    wrap_triton(triton_scale_swizzle_per_group_3d)[grid](
         scale_tensor.view(torch.uint8),
         input_stride_dim0,
         input_stride_dim1,
@@ -454,6 +457,7 @@ def triton_scale_swizzle_per_group_3d(
     )
 
 
+@triton_op("torchao::triton_mx_block_rearrange_2d_K_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_K_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -497,7 +501,7 @@ def triton_mx_block_rearrange_2d_K_groups(
         num_groups,
         num_row_blocks,
     )
-    triton_scale_swizzle_2d_K_groups[grid](
+    wrap_triton(triton_scale_swizzle_2d_K_groups)[grid](
         # Input scales
         scales_tensor.view(torch.uint8),
         scales_tensor.stride(0),
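
Note on the pattern applied in this file: wrapping each launcher function in torch.library.triton_op and routing the raw kernel launch through wrap_triton registers the launcher as a custom op, so torch.compile can trace through it instead of graph-breaking at the Triton call. Below is a minimal, self-contained sketch of that registration pattern; the op name "mylib::add_one" and the kernel are made up for illustration, and running it assumes a GPU plus a PyTorch version that ships torch.library.triton_op.

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program handles one BLOCK_SIZE-wide slice of the flat input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1, mask=mask)

@triton_op("mylib::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # wrap_triton makes the kernel launch visible to torch.compile's tracer.
    wrap_triton(add_one_kernel)[grid](x, out, n, BLOCK_SIZE=1024)
    return out

x = torch.randn(4096, device="cuda")
y = torch.compile(lambda t: add_one(t) * 2)(x)  # no graph break at the kernel call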

torchao/prototype/moe_training/scaled_grouped_mm.py
Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ def _scaled_grouped_mm(
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.info("Using mxfp8 for _scaled_grouped_mm")
+        # logger.info("Using mxfp8 for _scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,
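
The logger.info call above is commented out rather than removed; one plausible reason (not stated in the commit) is that calls into the logging module inside a compiled region can trigger graph breaks under torch.compile, depending on the PyTorch version and dynamo configuration. A small sketch for checking this with torch._dynamo.explain; the function f and its log message are hypothetical stand-ins, not code from this repo.

import logging
import torch

logger = logging.getLogger(__name__)

def f(x):
    logger.info("Using mxfp8 path")  # call into the logging module mid-graph
    return x @ x.t()

# explain() reports how many graph breaks dynamo hit while tracing f.
report = torch._dynamo.explain(f)(torch.randn(8, 8))
print(report.graph_break_count, report.break_reasons)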

torchao/prototype/moe_training/tensor.py
Lines changed: 26 additions & 13 deletions

@@ -27,7 +27,7 @@
     torch.ops.aten.copy_.default,
     torch.ops.aten.view.default,
     torch.ops.aten.as_strided.default,
-    torch.ops.aten._to_copy.default,
+    torch.ops.aten._to_copy.default,  # for *.to(dtype)
     torch.ops.aten._pin_memory.default,
     torch.ops.aten.split.Tensor,
     torch.ops.aten.clone.default,
@@ -94,11 +94,11 @@ def __torch_function__(cls, func, types, args, kwargs={}):
                 "B should be a ScaledGroupedMMTensor"
             )
             scaling_type = B.scaling_type
-            A_is_2d = A.dim() == 2
-            B_is_3d = B.dim() == 3
+            A_is_2d = A.ndim == 2
+            B_is_2d_or_3d = B.ndim == 2 or B.ndim == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             other_args = args[2:]
-            if A_is_2d and B_is_3d and has_offs:
+            if A_is_2d and B_is_2d_or_3d and has_offs:
                 return _scaled_grouped_mm(
                     A,
                     B,
@@ -125,28 +125,41 @@ def unwrap(t):
             assert t.scaling_type == scaling_type
             return t._data
 
-        args, kwargs = pytree.tree_map_only(
+        args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
             ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
         )
-        assert scaling_type is not None
+        assert scaling_type is not None, (
+            f"__torch_dispatch__ called on {func.__name__} without any ScaledGroupedMMTensor arguments"
+        )
 
-        # detach is special case
+        # # detach is special case
         if func == torch.ops.aten.detach.default:
             return ScaledGroupedMMTensor(args[0], scaling_type)
 
         # perform op
-        out = func(*args, **kwargs)
+        with torch._C.DisableTorchFunctionSubclass():
+            out = func(*args_unwrapped, **kwargs_unwrapped)
 
         # return regular tensors for ops that don't preserve subclass
        if func not in _ops_to_preserve_subclass:
             return out
 
-        # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
-        return pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, scaling_type),
-            out,
+        # For ops that do preserve subclass, only preserve it if the primary input
+        # was a ScaledGroupedMMTensor. Most ops in _ops_to_preserve_subclass are
+        # unary operations where the first argument determines the output type.
+        unary_op_on_subclass = len(args) > 0 and isinstance(
+            args[0], ScaledGroupedMMTensor
         )
+        if not unary_op_on_subclass:
+            return out
+
+        # Only wrap tensor outputs, preserving any existing subclasses
+        def selective_wrap(x):
+            if isinstance(x, torch.Tensor) and type(x) == torch.Tensor:
+                return ScaledGroupedMMTensor(x, scaling_type)
+            return x
+
+        return pytree.tree_map(selective_wrap, out)
 
     def __repr__(self):
         return f"ScaledGroupedMMTensor(data={self._data}, scaling_type={self.scaling_type})"
