
Commit 6327f2e

[mxfp8 moe training] add compile support
stack-info: PR: #2990, branch: danielvegamyhre/stack/66
1 parent 66384a9 commit 6327f2e

File tree: 5 files changed, +32 −27 lines changed


benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 2 additions & 1 deletion
@@ -27,6 +27,7 @@
 
 # Needed since changing args to function causes recompiles
 torch._dynamo.config.cache_size_limit = 1000
+torch._dynamo.config.automatic_dynamic_shapes = False
 
 
 @dataclass(frozen=True)
@@ -125,7 +126,7 @@ def run_experiment(
         B_t,
         offs,
         labels=labels,
-        use_compile=args.compile,
+        use_compile=False,  # args.compile,
         fullgraph=False,
         profile_name="bf16_profile",
     )
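
For context on the new dynamo settings: a minimal sketch (with a hypothetical scale_and_sum function, not from this benchmark) of how disabling automatic dynamic shapes keeps each input shape on its own static-shape graph, which is why the cache size limit is raised alongside it.

    # Minimal sketch, not part of this PR: with automatic_dynamic_shapes off,
    # Dynamo compiles a separate static-shape graph per distinct input shape
    # instead of switching to dynamic shapes after the first recompile, so the
    # graph cache limit must be large enough to hold all of them.
    import torch

    torch._dynamo.config.cache_size_limit = 1000
    torch._dynamo.config.automatic_dynamic_shapes = False

    @torch.compile
    def scale_and_sum(x: torch.Tensor) -> torch.Tensor:
        return (x * 2.0).sum()

    for rows in (64, 128, 256):  # each shape triggers its own compilation
        scale_and_sum(torch.randn(rows, 32))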

test/prototype/moe_training/test_training.py

Lines changed: 1 addition & 4 deletions
@@ -34,10 +34,7 @@
 
 @pytest.mark.parametrize(
     "target_fqns",
-    [
-        ["experts"],
-        ["does.not.exist"],
-    ],
+    [["experts"], ["experts,shared_expert"], ["invalid.fqns"]],
 )
 @pytest.mark.parametrize("compile", [False, True])
 @pytest.mark.parametrize(

torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py

Lines changed: 11 additions & 7 deletions
@@ -4,6 +4,7 @@
 import triton
 import triton.language as tl
 from torch import Tensor
+from torch.library import triton_op, wrap_triton
 
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import ceil_div
@@ -192,6 +193,7 @@ def compute_blocked_scale_offsets_for_K_groups(
     return group_sizes, starting_col_after_padding
 
 
+@triton_op("torchao::triton_mx_block_rearrange_2d_M_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_M_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -216,10 +218,10 @@ def triton_mx_block_rearrange_2d_M_groups(
         "Expected element size to be 1 byte (8 bits)"
     )
     rows, cols = scales_tensor.shape
-    num_groups = input_group_end_offsets.numel()
+    num_groups = input_group_end_offsets.shape[0]
 
     # Final offset is the total number of rows in the tensor
-    padded_rows = output_group_start_offsets[-1]
+    padded_rows = rows + num_groups * 128  # output_group_start_offsets[-1]
 
     num_col_blocks = ceil_div(cols, 4)
     padded_cols = num_col_blocks * 4
@@ -238,7 +240,7 @@
         num_groups,
         num_col_blocks,
     )
-    triton_scale_swizzle_M_groups[grid](
+    wrap_triton(triton_scale_swizzle_M_groups)[grid](
         # Input scales
         scales_tensor.view(torch.uint8),
         scales_tensor.stride(0),
@@ -336,6 +338,7 @@ def triton_scale_swizzle_M_groups(
         current_start_row += BLOCK_ROWS
 
 
+@triton_op("torchao::triton_mx_block_rearrange_per_group_3d", mutates_args={})
 def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
     """
     Rearranges an E8M0 tensor scale to block-scaled swizzle format.
@@ -379,7 +382,7 @@ def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
         num_col_blocks,
     )
 
-    triton_scale_swizzle_per_group_3d[grid](
+    wrap_triton(triton_scale_swizzle_per_group_3d)[grid](
         scale_tensor.view(torch.uint8),
         input_stride_dim0,
         input_stride_dim1,
@@ -454,6 +457,7 @@ def triton_scale_swizzle_per_group_3d(
     )
 
 
+@triton_op("torchao::triton_mx_block_rearrange_2d_K_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_K_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -479,12 +483,12 @@
     )
     rows, cols = scales_tensor.shape
     # Calculate blocks needed
-    num_groups = input_group_end_offsets.numel()
+    num_groups = input_group_end_offsets.shape[0]
     num_row_blocks = ceil_div(rows, 128)
     padded_rows = num_row_blocks * 128
 
     # output_group_start_offsets always starts with 0 and ends with the total number of cols
-    padded_cols = output_group_start_offsets[-1]
+    padded_cols = cols + num_groups * 4  # output_group_start_offsets[-1]
     output = scales_tensor.new_empty((padded_rows, padded_cols))
 
     # Output block stride for the rearranged format
@@ -497,7 +501,7 @@
         num_groups,
         num_row_blocks,
     )
-    triton_scale_swizzle_2d_K_groups[grid](
+    wrap_triton(triton_scale_swizzle_2d_K_groups)[grid](
         # Input scales
         scales_tensor.view(torch.uint8),
         scales_tensor.stride(0),
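
The pattern above registers each host-side wrapper as a custom op via torch.library.triton_op and routes the raw kernel launch through wrap_triton so torch.compile can trace it instead of graph-breaking on the launch. A minimal sketch of the same registration pattern, using a hypothetical mylib::add op and add_kernel that are not part of this file:

    import torch
    import triton
    import triton.language as tl
    from torch.library import triton_op, wrap_triton

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Elementwise add over a 1D grid of blocks
        pid = tl.program_id(0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)

    @triton_op("mylib::add", mutates_args={})
    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        n_elements = out.numel()
        grid = (triton.cdiv(n_elements, 1024),)
        # wrap_triton makes the kernel launch traceable by torch.compile
        wrap_triton(add_kernel)[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
        return out

The same commit also swaps the data-dependent output_group_start_offsets[-1] reads for shape-based upper bounds, presumably so the output sizes can be computed without reading tensor values while tracing.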

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def _scaled_grouped_mm(
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.info("Using mxfp8 for _scaled_grouped_mm")
+        # logger.info("Using mxfp8 for _scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,

torchao/prototype/moe_training/tensor.py

Lines changed: 17 additions & 14 deletions
@@ -27,7 +27,7 @@
     torch.ops.aten.copy_.default,
     torch.ops.aten.view.default,
     torch.ops.aten.as_strided.default,
-    torch.ops.aten._to_copy.default,
+    torch.ops.aten._to_copy.default,  # for *.to(dtype)
     torch.ops.aten._pin_memory.default,
     torch.ops.aten.split.Tensor,
     torch.ops.aten.clone.default,
@@ -94,11 +94,11 @@ def __torch_function__(cls, func, types, args, kwargs={}):
                 "B should be a ScaledGroupedMMTensor"
             )
             scaling_type = B.scaling_type
-            A_is_2d = A.dim() == 2
-            B_is_3d = B.dim() == 3
+            A_is_2d = A.ndim == 2
+            B_is_2d_or_3d = B.ndim == 2 or B.ndim == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             other_args = args[2:]
-            if A_is_2d and B_is_3d and has_offs:
+            if A_is_2d and B_is_2d_or_3d and has_offs:
                 return _scaled_grouped_mm(
                     A,
                     B,
@@ -125,28 +125,31 @@ def unwrap(t):
            assert t.scaling_type == scaling_type
            return t._data
 
-        args, kwargs = pytree.tree_map_only(
+        args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
            ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
        )
-        assert scaling_type is not None
+        assert scaling_type is not None, (
+            f"__torch_dispatch__ called on {func.__name__} without any ScaledGroupedMMTensor arguments"
+        )
 
-        # detach is special case
+        # # detach is special case
         if func == torch.ops.aten.detach.default:
             return ScaledGroupedMMTensor(args[0], scaling_type)
 
         # perform op
-        out = func(*args, **kwargs)
+        out = func(*args_unwrapped, **kwargs_unwrapped)
 
         # return regular tensors for ops that don't preserve subclass
         if func not in _ops_to_preserve_subclass:
             return out
 
-        # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
-        return pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x, scaling_type),
-            out,
-        )
+        # Only wrap tensor outputs, prevent double wrapping
+        def selective_wrap(x):
+            if isinstance(x, torch.Tensor) and type(x) == torch.Tensor:
+                return ScaledGroupedMMTensor(x, scaling_type)
+            return x
+
+        return pytree.tree_map(selective_wrap, out)
 
     def __repr__(self):
         return f"ScaledGroupedMMTensor(data={self._data}, scaling_type={self.scaling_type})"
