@@ -28,6 +28,9 @@
# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000

+ # Workaround for https://github.com/pytorch/ao/pull/2990#issuecomment-3285762681
+ torch._dynamo.config.automatic_dynamic_shapes = False


@dataclass(frozen=True)
class ExperimentConfig:
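Context for the benchmark tweak above: with automatic dynamic shapes enabled, dynamo promotes sizes to dynamic after the first recompile, which reportedly caused problems here (see the linked PR comment), so shapes are kept static while the recompile budget stays high. A minimal sketch of the same configuration around a hypothetical compiled function (the function and shapes are illustrative, not from the benchmark):

import torch

# High recompile budget: each new argument signature triggers a fresh compile.
torch._dynamo.config.cache_size_limit = 1000
# Keep shapes static rather than letting dynamo switch to dynamic shapes
# after the first recompile (the workaround adopted in this PR).
torch._dynamo.config.automatic_dynamic_shapes = False

@torch.compile
def toy_kernel(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the benchmarked grouped-GEMM path.
    return (x.float() * 2.0).sum()

for rows in (128, 256, 512):
    # Each distinct shape compiles its own static-shape graph under this config.
    toy_kernel(torch.randn(rows, 64))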
5 changes: 1 addition & 4 deletions test/prototype/moe_training/test_training.py
@@ -34,10 +34,7 @@

@pytest.mark.parametrize(
"target_fqns",
- [
- ["experts"],
- ["does.not.exist"],
- ],
+ [["experts"], ["experts,shared_expert"], ["invalid.fqns"]],
)
@pytest.mark.parametrize("compile", [False, True])
@pytest.mark.parametrize(
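For readers unfamiliar with the target_fqns knob: each entry is matched against module fully-qualified names during conversion, so the new "invalid.fqns" case exercises the path where nothing matches. A rough sketch of that matching with a hypothetical helper (the real test converts an actual MoE model and asserts on its numerics):

def matches_target_fqns(module_fqn: str, target_fqns: list[str]) -> bool:
    # Hypothetical helper: convert a module if its FQN contains any target FQN.
    return any(target in module_fqn for target in target_fqns)

assert matches_target_fqns("layers.0.moe.experts", ["experts"])
assert not matches_target_fqns("layers.0.moe.experts", ["invalid.fqns"])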
18 changes: 11 additions & 7 deletions torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py
@@ -4,6 +4,7 @@
import triton
import triton.language as tl
from torch import Tensor
+ from torch.library import triton_op, wrap_triton

from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import ceil_div
@@ -192,6 +193,7 @@ def compute_blocked_scale_offsets_for_K_groups(
return group_sizes, starting_col_after_padding


@triton_op("torchao::triton_mx_block_rearrange_2d_M_groups", mutates_args={})
def triton_mx_block_rearrange_2d_M_groups(
scales_tensor: torch.Tensor,
input_group_end_offsets: torch.Tensor,
@@ -216,10 +218,10 @@ def triton_mx_block_rearrange_2d_M_groups(
"Expected element size to be 1 byte (8 bits)"
)
rows, cols = scales_tensor.shape
- num_groups = input_group_end_offsets.numel()
+ num_groups = input_group_end_offsets.shape[0]

# Final offset is the total number of rows in the tensor
- padded_rows = output_group_start_offsets[-1]
+ padded_rows = rows + num_groups * 128 # output_group_start_offsets[-1]

num_col_blocks = ceil_div(cols, 4)
padded_cols = num_col_blocks * 4
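Two changes in this hunk appear aimed at keeping the launcher friendly to torch.compile: num_groups now comes from shape[0] rather than numel(), and the output row count is a static upper bound instead of reading output_group_start_offsets[-1], which would pull a value off the device and make the allocation data-dependent. The bound is safe because each of the num_groups row groups is padded up to a multiple of 128, adding at most 128 rows of padding per group; a small sanity-check sketch of that arithmetic (assumed reasoning, not code from the PR):

def padded_rows_upper_bound(group_rows: list[int]) -> int:
    # Each group is padded up to a multiple of 128 rows, so the padded total
    # can exceed the raw row count by at most 128 rows per group.
    rows = sum(group_rows)
    num_groups = len(group_rows)
    exact = sum((r + 127) // 128 * 128 for r in group_rows)
    bound = rows + num_groups * 128
    assert exact <= bound
    return bound

padded_rows_upper_bound([100, 300, 256])  # exact padded total is 768, bound is 1040

The over-allocation only costs some unused rows at the end of the buffer, since downstream consumers presumably index by the real group offsets.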
@@ -238,7 +240,7 @@
num_groups,
num_col_blocks,
)
- triton_scale_swizzle_M_groups[grid](
+ wrap_triton(triton_scale_swizzle_M_groups)[grid](
# Input scales
scales_tensor.view(torch.uint8),
scales_tensor.stride(0),
@@ -336,6 +338,7 @@ def triton_scale_swizzle_M_groups(
current_start_row += BLOCK_ROWS


@triton_op("torchao::triton_mx_block_rearrange_per_group_3d", mutates_args={})
def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
"""
Rearranges an E8M0 tensor scale to block-scaled swizzle format.
@@ -379,7 +382,7 @@ def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.
num_col_blocks,
)

- triton_scale_swizzle_per_group_3d[grid](
+ wrap_triton(triton_scale_swizzle_per_group_3d)[grid](
scale_tensor.view(torch.uint8),
input_stride_dim0,
input_stride_dim1,
@@ -454,6 +457,7 @@ def triton_scale_swizzle_per_group_3d(
)


@triton_op("torchao::triton_mx_block_rearrange_2d_K_groups", mutates_args={})
def triton_mx_block_rearrange_2d_K_groups(
scales_tensor: torch.Tensor,
input_group_end_offsets: torch.Tensor,
@@ -479,12 +483,12 @@ def triton_mx_block_rearrange_2d_K_groups(
)
rows, cols = scales_tensor.shape
# Calculate blocks needed
- num_groups = input_group_end_offsets.numel()
+ num_groups = input_group_end_offsets.shape[0]
num_row_blocks = ceil_div(rows, 128)
padded_rows = num_row_blocks * 128

# output_group_start_offsets always starts with 0 and ends with the total number of cols
- padded_cols = output_group_start_offsets[-1]
+ padded_cols = cols + num_groups * 4 # output_group_start_offsets[-1]
output = scales_tensor.new_empty((padded_rows, padded_cols))

# Output block stride for the rearranged format
@@ -497,7 +501,7 @@
num_groups,
num_row_blocks,
)
- triton_scale_swizzle_2d_K_groups[grid](
+ wrap_triton(triton_scale_swizzle_2d_K_groups)[grid](
# Input scales
scales_tensor.view(torch.uint8),
scales_tensor.stride(0),
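The pattern repeated across this file wraps each Triton launcher in torch.library.triton_op and routes the actual launch through wrap_triton, so torch.compile can treat the kernel as a custom op instead of graph-breaking on the raw Triton call. A self-contained sketch of that pattern with a toy copy kernel (the op name and kernel are illustrative, not from torchao):

import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def copy_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program copies one BLOCK-sized chunk, masked at the tail.
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    tl.store(out_ptr + offsets, tl.load(x_ptr + offsets, mask=mask), mask=mask)

@triton_op("mylib::copy", mutates_args={})
def copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    # wrap_triton makes the raw kernel launch visible to torch.compile.
    wrap_triton(copy_kernel)[grid](x, out, x.numel(), BLOCK=1024)
    return out

# x = torch.randn(4096, device="cuda"); torch.testing.assert_close(copy(x), x)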
4 changes: 2 additions & 2 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -58,15 +58,15 @@ def _scaled_grouped_mm(
"""
# TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
if scaling_type == MoEScalingType.FP8_ROWWISE:
logger.info("Using fp8 rowwise for _scaled_grouped_mm")
logger.debug("Using fp8 rowwise for _scaled_grouped_mm")
return _Float8GroupedMM.apply(
A,
B_t,
offs,
out_dtype,
)
elif scaling_type == MoEScalingType.MXFP8:
logger.info("Using mxfp8 for _scaled_grouped_mm")
logger.debug("Using mxfp8 for _scaled_grouped_mm")
block_size = 32 # TODO: should we make this configurable? plumb it through in a config somehow?
return _MXFP8GroupedMM.apply(
A,
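Demoting these messages from info to debug keeps training-loop logs quiet by default; they can still be surfaced with standard logging configuration, assuming the module follows the usual logging.getLogger(__name__) convention (the logger name below is that assumption, not taken from the file):

import logging

logging.basicConfig(level=logging.DEBUG)
# Assumed logger name, following the __name__ convention for this module path.
logging.getLogger("torchao.prototype.moe_training.scaled_grouped_mm").setLevel(logging.DEBUG)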
29 changes: 16 additions & 13 deletions torchao/prototype/moe_training/tensor.py
@@ -27,7 +27,7 @@
torch.ops.aten.copy_.default,
torch.ops.aten.view.default,
torch.ops.aten.as_strided.default,
- torch.ops.aten._to_copy.default,
+ torch.ops.aten._to_copy.default, # for *.to(dtype)
torch.ops.aten._pin_memory.default,
torch.ops.aten.split.Tensor,
torch.ops.aten.clone.default,
@@ -94,11 +94,11 @@ def __torch_function__(cls, func, types, args, kwargs={}):
"B should be a ScaledGroupedMMTensor"
)
scaling_type = B.scaling_type
- A_is_2d = A.dim() == 2
- B_is_3d = B.dim() == 3
+ A_is_2d = A.ndim == 2
+ B_is_2d_or_3d = B.ndim == 2 or B.ndim == 3
has_offs = kwargs.get(cls.offs_arg_name) is not None
other_args = args[2:]
- if A_is_2d and B_is_3d and has_offs:
+ if A_is_2d and B_is_2d_or_3d and has_offs:
return _scaled_grouped_mm(
A,
B,
@@ -125,28 +125,31 @@ def unwrap(t):
assert t.scaling_type == scaling_type
return t._data

- args, kwargs = pytree.tree_map_only(
+ args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
ScaledGroupedMMTensor, unwrap, (args, kwargs or {})
)
- assert scaling_type is not None
+ assert scaling_type is not None, (
+ f"__torch_dispatch__ called on {func.__name__} without any ScaledGroupedMMTensor arguments"
+ )

# detach is special case
if func == torch.ops.aten.detach.default:
return ScaledGroupedMMTensor(args[0], scaling_type)

# perform op
- out = func(*args, **kwargs)
+ out = func(*args_unwrapped, **kwargs_unwrapped)

# return regular tensors for ops that don't preserve subclass
if func not in _ops_to_preserve_subclass:
return out

# wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
- return pytree.tree_map_only(
- torch.Tensor,
- lambda x: ScaledGroupedMMTensor(x, scaling_type),
- out,
- )
+ # Only wrap tensor outputs, prevent double wrapping
+ def selective_wrap(x):
+ if isinstance(x, torch.Tensor) and type(x) == torch.Tensor:
+ return ScaledGroupedMMTensor(x, scaling_type)
+ return x

+ return pytree.tree_map(selective_wrap, out)

def __repr__(self):
return f"ScaledGroupedMMTensor(data={self._data}, scaling_type={self.scaling_type})"
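On the tensor.py changes above: the grouped-mm override now also routes 2D x 2D inputs (with offsets) to _scaled_grouped_mm, and __torch_dispatch__ no longer re-wraps every tensor output via tree_map_only; the new selective_wrap only wraps values whose exact type is torch.Tensor, which is the "prevent double wrapping" noted in the added comment. A stripped-down sketch of that wrapping rule outside the subclass (names are illustrative; the real logic lives in ScaledGroupedMMTensor.__torch_dispatch__):

import torch
from torch.utils import _pytree as pytree

def wrap_outputs(out, wrap):
    # Wrap only exact torch.Tensor leaves: values that are already a tensor
    # subclass, and non-tensor leaves (ints, None, ...), pass through as-is.
    def selective_wrap(x):
        if isinstance(x, torch.Tensor) and type(x) is torch.Tensor:
            return wrap(x)
        return x
    # tree_map visits every leaf of whatever structure the op returned.
    return pytree.tree_map(selective_wrap, out)

# e.g. for a tuple output (tensor, int), only the plain tensor would be wrapped:
# wrap_outputs((torch.randn(2), 3), wrap=SomeTensorSubclass)  # hypothetical subclass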