Commit e381f71

[mxfp8 moe training] add compile support
stack-info: PR: #2990, branch: danielvegamyhre/stack/66
1 parent 66384a9 commit e381f71

6 files changed: +70, -42 lines


benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 3 additions & 0 deletions
@@ -28,6 +28,9 @@
 # Needed since changing args to function causes recompiles
 torch._dynamo.config.cache_size_limit = 1000
 
+# Workaround for https://github.com/pytorch/ao/pull/2990#issuecomment-3285762681
+torch._dynamo.config.automatic_dynamic_shapes = False
+
 
 @dataclass(frozen=True)
 class ExperimentConfig:
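
These two dynamo settings matter because the benchmark sweeps many input shapes through the same compiled function. A minimal sketch of the effect (toy compiled function, not the benchmark's grouped GEMM):

import torch

# With automatic_dynamic_shapes enabled (the default), the second distinct shape would
# recompile to a dynamic-shape graph; disabling it keeps one static graph per shape,
# and the raised cache_size_limit leaves room for all of those graphs.
torch._dynamo.config.cache_size_limit = 1000
torch._dynamo.config.automatic_dynamic_shapes = False

@torch.compile
def scale_and_sum(x: torch.Tensor) -> torch.Tensor:
    return (x * 2.0).sum()

for m in (128, 256, 512):
    scale_and_sum(torch.randn(m, 64))  # each new shape compiles its own static graph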

test/prototype/moe_training/test_kernels.py

Lines changed: 5 additions & 5 deletions
@@ -213,7 +213,7 @@ def test_fp8_rowwise_3d_transpose_rhs_reduction(round_scales_to_power_of_2: bool
 @pytest.mark.parametrize(
     "m,k,n_groups", [(256, 256, 4), (16640, 5120, 16), (16640, 8192, 16)]
 )
-def test_mxfp8_per_group_blocked_scales_2d(
+def test_triton_mx_block_rearrange_2d_M_groups(
     m: int,
     k: int,
     n_groups: int,
@@ -272,10 +272,10 @@ def test_mxfp8_per_group_blocked_scales_3d(
 
 
 @skip_if_rocm("ROCm enablement in progress")
-@pytest.mark.parametrize("m", [256, 512, 1024, 5120])
-@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384])
-@pytest.mark.parametrize("n_groups", [1, 4, 8, 16])
-def test_mxfp8_per_group_blocked_scales_2d2d(
+@pytest.mark.parametrize("m", [256])
+@pytest.mark.parametrize("total_k", [512])
+@pytest.mark.parametrize("n_groups", [1])
+def test_triton_mx_block_rearrange_2d_K_groups(
     m: int,
     total_k: int,
     n_groups: int,

test/prototype/moe_training/test_training.py

Lines changed: 1 addition & 4 deletions
@@ -34,10 +34,7 @@
 
 @pytest.mark.parametrize(
     "target_fqns",
-    [
-        ["experts"],
-        ["does.not.exist"],
-    ],
+    [["experts"], ["experts,shared_expert"], ["invalid.fqns"]],
 )
 @pytest.mark.parametrize("compile", [False, True])
 @pytest.mark.parametrize(

torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py

Lines changed: 49 additions & 23 deletions
@@ -4,6 +4,7 @@
 import triton
 import triton.language as tl
 from torch import Tensor
+from torch.library import triton_op, wrap_triton
 
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import ceil_div
@@ -29,7 +30,14 @@ def torch_to_blocked_2d_M_groups(
 
     assert x_scales.ndim == 2, "x_scales must be 2D"
     assert block_size == 32, "Only block_size=32 is supported for now"
-    blocked_scales_list = []
+    total_M, _ = x_scales.shape
+    num_groups = group_offs.shape[0]
+
+    # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating over each group,
+    # the Triton kernel will use an upper bound of adding 128 padding rows to each group.
+    # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl.)
+    total_M_padded = total_M + num_groups * 128
+    blocked_scales = x_scales.new_zeros(total_M_padded, K // block_size)
     start_row_after_padding_list = [0]
     group_start_idx = 0
     for i, group_end_idx in enumerate(group_offs.tolist()):
@@ -42,19 +50,24 @@ def torch_to_blocked_2d_M_groups(
         # Convert group scales to blocked format
         group_scales = x_scales[group_start_idx:group_end_idx]
         group_scales_blocked = to_blocked(group_scales)
-        blocked_scales_list.append(group_scales_blocked)
 
         # Calculate the start row after padding
         scaling_groups_per_row = K // block_size
         rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
         new_start_row = prev_start_row_after_padding + rows_for_group
         start_row_after_padding_list.append(new_start_row)
 
+        # Write output to subtensor
+        group_rows_padded = ceil_div(group_size, 128) * 128
+        blocked_scales[
+            prev_start_row_after_padding : prev_start_row_after_padding
+            + group_rows_padded,
+            :,
+        ] = group_scales_blocked.reshape(-1, K // block_size)
+
         # Update next group start index
         group_start_idx = group_end_idx
 
-    blocked_scales = torch.cat(blocked_scales_list, dim=0).contiguous()
-    blocked_scales = blocked_scales.reshape(-1, K // 32)
     start_row_after_padding = torch.tensor(
         start_row_after_padding_list, device=x_scales.device, dtype=torch.int64
     )
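
The rewritten reference above drops the append-then-torch.cat pattern in favor of a zero-filled buffer sized to an upper bound (at most 128 extra rows per group) that each group's blocked scales are written into. A standalone sketch of that allocation pattern, with a hypothetical pack_groups_rowwise helper and no mxfp8-specific swizzling:

import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def pack_groups_rowwise(x: torch.Tensor, group_end_offs: list[int], align: int = 128) -> torch.Tensor:
    # Copy each row-group of x into a buffer where every group starts on an
    # `align`-row boundary. The buffer is sized to the worst case (one extra
    # `align`-row pad per group), so the allocation is not data dependent.
    total_rows, cols = x.shape
    out = x.new_zeros(total_rows + len(group_end_offs) * align, cols)
    out_row, start = 0, 0
    for end in group_end_offs:
        group = x[start:end]
        out[out_row : out_row + group.shape[0]] = group
        out_row += ceil_div(group.shape[0], align) * align  # next aligned start row
        start = end
    return out

packed = pack_groups_rowwise(torch.randn(300, 8), group_end_offs=[100, 300])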
@@ -79,34 +92,44 @@ def torch_to_blocked_2d_K_groups(
     """
     assert x_scales.ndim == 2, "x_scales must be 2D"
     assert block_size == 32, "Only block_size=32 is supported for now"
-    blocked_scales_list = []
+    M, total_K = x_scales.shape
+    padded_M = ceil_div(M, 128) * 128
+    num_groups = group_offs.shape[0]
+
+    # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating over each group,
+    # the Triton kernel will use an upper bound of adding 4 padding cols to each group.
+    # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl.)
+    total_K_padded = total_K + num_groups * 4
+    blocked_scales = x_scales.new_zeros(padded_M, total_K_padded)
+
     start_col_after_padding_list = [0]
     group_start_idx = 0
     for i, group_end_idx in enumerate(group_offs.tolist()):
         group_size = group_end_idx - group_start_idx
-        prev_start_row_after_padding = start_col_after_padding_list[i]
+        prev_start_col_after_padding = start_col_after_padding_list[i]
         if group_size == 0:
-            start_col_after_padding_list.append(prev_start_row_after_padding)
+            start_col_after_padding_list.append(prev_start_col_after_padding)
             continue
 
         # Convert group scales to blocked format
         group_scales = x_scales[:, group_start_idx:group_end_idx]
         group_scales_blocked = to_blocked(group_scales)
         cols_after_padding = ceil_div(group_size, 4) * 4
-        blocked_scales_list.append(group_scales_blocked)
+
+        # Write output to subtensor
+        blocked_scales[
+            :,
+            prev_start_col_after_padding : prev_start_col_after_padding
+            + cols_after_padding,
+        ] = group_scales_blocked.reshape(-1, cols_after_padding)
 
         # Calculate the start col after padding
-        new_start_col = prev_start_row_after_padding + cols_after_padding
+        new_start_col = prev_start_col_after_padding + cols_after_padding
         start_col_after_padding_list.append(new_start_col)
 
         # Update next group start index
         group_start_idx = group_end_idx
 
-    # blocked_scales = torch.cat(blocked_scales_list, dim=1)
-    M = x_scales.shape[0]
-    padded_M = ceil_div(M, 128) * 128
-    blocked_scales = torch.cat(blocked_scales_list)
-    blocked_scales = blocked_scales.reshape(padded_M, -1)
     start_cols_after_padding = torch.tensor(
         start_col_after_padding_list, device=x_scales.device, dtype=torch.int64
     )
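
The K-grouped variant does the same bookkeeping per column group: each group's width is rounded up to a multiple of 4, and the running sum gives the start column of every group in the blocked output. A toy sketch of that offset computation (hypothetical helper, simplified relative to compute_blocked_scale_offsets_for_K_groups):

def padded_start_cols(group_end_offs: list[int]) -> list[int]:
    # Start column of each group after rounding group widths up to a multiple of 4;
    # empty groups simply repeat the previous offset.
    starts, prev_end, col = [0], 0, 0
    for end in group_end_offs:
        width = end - prev_end
        col += (width + 3) // 4 * 4  # ceil to a multiple of 4
        starts.append(col)
        prev_end = end
    return starts

# group widths 3, 10, 0, 7 -> padded widths 4, 12, 0, 8
print(padded_start_cols([3, 13, 13, 20]))  # [0, 4, 16, 16, 24]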
@@ -192,6 +215,7 @@ def compute_blocked_scale_offsets_for_K_groups(
     return group_sizes, starting_col_after_padding
 
 
+@triton_op("torchao::triton_mx_block_rearrange_2d_M_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_M_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -216,14 +240,14 @@ def triton_mx_block_rearrange_2d_M_groups(
         "Expected element size to be 1 byte (8 bits)"
     )
     rows, cols = scales_tensor.shape
-    num_groups = input_group_end_offsets.numel()
+    num_groups = input_group_end_offsets.shape[0]
 
     # Final offset is the total number of rows in the tensor
-    padded_rows = output_group_start_offsets[-1]
+    padded_rows = rows + num_groups * 128
 
     num_col_blocks = ceil_div(cols, 4)
     padded_cols = num_col_blocks * 4
-    output = scales_tensor.new_empty((padded_rows, padded_cols))
+    output = scales_tensor.new_zeros((padded_rows, padded_cols))
 
     # Output block stride for the rearranged format
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -238,7 +262,7 @@ def triton_mx_block_rearrange_2d_M_groups(
         num_groups,
         num_col_blocks,
     )
-    triton_scale_swizzle_M_groups[grid](
+    wrap_triton(triton_scale_swizzle_M_groups)[grid](
         # Input scales
         scales_tensor.view(torch.uint8),
         scales_tensor.stride(0),
@@ -336,6 +360,7 @@ def triton_scale_swizzle_M_groups(
         current_start_row += BLOCK_ROWS
 
 
+@triton_op("torchao::triton_mx_block_rearrange_per_group_3d", mutates_args={})
 def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
     """
     Rearranges an E8M0 tensor scale to block-scaled swizzle format.
@@ -379,7 +404,7 @@ def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
         num_col_blocks,
     )
 
-    triton_scale_swizzle_per_group_3d[grid](
+    wrap_triton(triton_scale_swizzle_per_group_3d)[grid](
         scale_tensor.view(torch.uint8),
         input_stride_dim0,
         input_stride_dim1,
@@ -454,6 +479,7 @@ def triton_scale_swizzle_per_group_3d(
     )
 
 
+@triton_op("torchao::triton_mx_block_rearrange_2d_K_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_K_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -479,13 +505,13 @@ def triton_mx_block_rearrange_2d_K_groups(
     )
     rows, cols = scales_tensor.shape
     # Calculate blocks needed
-    num_groups = input_group_end_offsets.numel()
+    num_groups = input_group_end_offsets.shape[0]
     num_row_blocks = ceil_div(rows, 128)
    padded_rows = num_row_blocks * 128
 
     # output_group_start_offsets always starts with 0 and ends with the total number of cols
-    padded_cols = output_group_start_offsets[-1]
-    output = scales_tensor.new_empty((padded_rows, padded_cols))
+    padded_cols = cols + num_groups * 4
+    output = scales_tensor.new_zeros((padded_rows, padded_cols))
 
     # Output block stride for the rearranged format
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -497,7 +523,7 @@ def triton_mx_block_rearrange_2d_K_groups(
         num_groups,
         num_row_blocks,
     )
-    triton_scale_swizzle_2d_K_groups[grid](
+    wrap_triton(triton_scale_swizzle_2d_K_groups)[grid](
        # Input scales
        scales_tensor.view(torch.uint8),
        scales_tensor.stride(0),
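
The @triton_op registrations and wrap_triton launches added throughout this file are what give torch.compile a traceable custom op instead of a raw Triton kernel launch (which would otherwise force a graph break). A minimal, self-contained sketch of the same pattern with a toy kernel; the op name mylib::add_one is hypothetical:

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def _add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + 1, mask=mask)

@triton_op("mylib::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # Launch through wrap_triton so the call stays visible to torch.compile.
    wrap_triton(_add_one_kernel)[grid](x, out, n, BLOCK=1024)
    return out

y = add_one(torch.ones(4096, device="cuda"))  # requires a GPU; also works under torch.compile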

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 2 additions & 2 deletions
@@ -58,15 +58,15 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        logger.info("Using fp8 rowwise for _scaled_grouped_mm")
+        logger.debug("Using fp8 rowwise for _scaled_grouped_mm")
         return _Float8GroupedMM.apply(
             A,
             B_t,
             offs,
             out_dtype,
         )
     elif scaling_type == MoEScalingType.MXFP8:
-        logger.info("Using mxfp8 for _scaled_grouped_mm")
+        logger.debug("Using mxfp8 for _scaled_grouped_mm")
         block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
         return _MXFP8GroupedMM.apply(
             A,

torchao/prototype/moe_training/tensor.py

Lines changed: 10 additions & 8 deletions
@@ -27,7 +27,7 @@
     torch.ops.aten.copy_.default,
     torch.ops.aten.view.default,
     torch.ops.aten.as_strided.default,
-    torch.ops.aten._to_copy.default,
+    torch.ops.aten._to_copy.default,  # for *.to(dtype)
     torch.ops.aten._pin_memory.default,
     torch.ops.aten.split.Tensor,
     torch.ops.aten.clone.default,
@@ -94,11 +94,11 @@ def __torch_function__(cls, func, types, args, kwargs={}):
                 "B should be a ScaledGroupedMMTensor"
             )
             scaling_type = B.scaling_type
-            A_is_2d = A.dim() == 2
-            B_is_3d = B.dim() == 3
+            A_is_2d = A.ndim == 2
+            B_is_2d_or_3d = B.ndim == 2 or B.ndim == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             other_args = args[2:]
-            if A_is_2d and B_is_3d and has_offs:
+            if A_is_2d and B_is_2d_or_3d and has_offs:
                 return _scaled_grouped_mm(
                     A,
                     B,
@@ -125,17 +125,19 @@ def unwrap(t):
            assert t.scaling_type == scaling_type
            return t._data
 
-        args, kwargs = pytree.tree_map_only(
+        args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
            ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
        )
-        assert scaling_type is not None
+        assert scaling_type is not None, (
+            f"__torch_dispatch__ called on {func.__name__} without any ScaledGroupedMMTensor arguments"
+        )
 
        # detach is special case
        if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0], scaling_type)
+            return ScaledGroupedMMTensor(args_unwrapped[0], scaling_type)
 
        # perform op
-        out = func(*args, **kwargs)
+        out = func(*args_unwrapped, **kwargs_unwrapped)
 
        # return regular tensors for ops that don't preserve subclass
        if func not in _ops_to_preserve_subclass:
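
The dispatch path above unwraps subclass arguments with pytree.tree_map_only into freshly named variables, runs the op on plain tensors, and re-wraps results for ops that preserve the subclass. A minimal, self-contained sketch of that wrapper-subclass pattern, using a hypothetical Passthrough class rather than ScaledGroupedMMTensor:

import torch
import torch.utils._pytree as pytree

class Passthrough(torch.Tensor):
    # Toy wrapper subclass: stores a plain tensor and forwards every aten op to it.

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap into new names so the original wrapped args are not shadowed.
        args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
            cls, lambda t: t._data, (args, kwargs or {})
        )
        out = func(*args_unwrapped, **kwargs_unwrapped)
        # Re-wrap tensor outputs so the result stays in the subclass.
        return pytree.tree_map_only(torch.Tensor, cls, out)

x = Passthrough(torch.randn(4))
y = x + 1  # dispatches through __torch_dispatch__ and returns a Passthrough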
