
Commit b3b545f

mxtensor: add index select support (pytorch#3079)
1 parent cf9ea8a commit b3b545f
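In plain terms, this commit teaches `MXTensor` to handle `aten.select.int`, so indexing a 3D MX-quantized tensor along dim 0 returns a 2D `MXTensor`. A minimal sketch of the usage this enables, mirroring the new test below (the import path is assumed from the repo layout; a CUDA sm90+ device is assumed):

# Hedged sketch of the usage enabled by this commit; mirrors test_index_select.
# Assumptions: MXTensor is importable from torchao.prototype.mx_formats.mx_tensor,
# and a CUDA device with sm90+ is available.
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(8, 256, 512, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)  # elem dtype fp8 e4m3, block size 32

x0 = x_mx[0]  # desugars to aten.select.int(x_mx, 0, 0) -> 2D MXTensor
print(x0.to_dtype(torch.bfloat16).shape)  # torch.Size([256, 512])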

2 files changed: +41 −0 lines changed

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 21 additions & 0 deletions
@@ -26,6 +26,7 @@
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     is_sm_at_least_89,
+    is_sm_at_least_90,
     is_sm_at_least_100,
     torch_version_at_least,
 )
@@ -556,6 +557,26 @@ def test_to_mx_inductor_single_kernel():
     FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not is_sm_at_least_90(), reason="Need sm90+")
+def test_index_select():
+    """
+    Test that `x_0 = x[0]` works when `x` is a 3D `MXTensor`. This is
+    useful when stitching checkpoints of `num_experts` 2D parameters into
+    a single 3D parameter when converting between model definitions that
+    use 2D and 3D parameters for their expert weights.
+    """
+
+    E, K, N = 128, 256, 512
+    x = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)
+    x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, 32)
+
+    x_mx_1 = x_mx[1]
+    torch.testing.assert_close(
+        x_mx.to_dtype(x.dtype)[1], x_mx_1.to_dtype(x.dtype), atol=0, rtol=0
+    )
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
     not is_sm_at_least_89(),
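The `atol=0, rtol=0` check works because selecting along dim 0 slices both the quantized data and the per-block e8m0 scales without requantizing, so the dequantized slice is bit-identical to the corresponding slice of the dequantized 3D tensor. A sketch of the expert-weight scenario the docstring describes (hypothetical shapes and names; same environment assumptions as the sketch above):

# Sketch of the checkpoint-stitching use case from the test docstring:
# quantize one stacked 3D expert weight, then recover per-expert 2D
# MXTensors via the newly supported select. Assumes CUDA sm90+ and the
# MXTensor import path used above.
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor

num_experts, N, K = 8, 256, 512
w_stacked = torch.randn(num_experts, N, K, device="cuda", dtype=torch.bfloat16)
w_mx = MXTensor.to_mx(w_stacked, torch.float8_e4m3fn, 32)

# Each select slices qdata and _scale_e8m0 along dim 0; no requantization,
# so dequantized values match the 3D dequantization exactly.
per_expert = [w_mx[e] for e in range(num_experts)]
torch.testing.assert_close(
    per_expert[3].to_dtype(torch.bfloat16),
    w_mx.to_dtype(torch.bfloat16)[3],
    atol=0,
    rtol=0,
)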

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 20 additions & 0 deletions
@@ -322,3 +322,23 @@ def mx_clone(func, types, args, kwargs):
     clone_fn = lambda x: x.clone()

     return self._apply_fn_to_data(clone_fn)
+
+
+@implements([aten.select.int])
+def mx_select(func, types, args, kwargs):
+    old_mx_tensor, dim, index = args
+    assert dim == 0, f"MXTensor aten.select.int with {dim=} is not yet supported"
+    assert len(old_mx_tensor.qdata.shape) == len(old_mx_tensor._scale_e8m0.shape), (
+        "unsupported"
+    )
+    new_mx_tensor = old_mx_tensor.__class__(
+        old_mx_tensor.qdata[index],
+        old_mx_tensor._scale_e8m0[index],
+        old_mx_tensor._elem_dtype,
+        old_mx_tensor._block_size,
+        old_mx_tensor._orig_dtype,
+        old_mx_tensor._gemm_kernel_choice,
+        old_mx_tensor._pack_fp6,
+        old_mx_tensor.act_quant_kwargs,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new_mx_tensor)
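For context, `mx_select` is wired in through the `implements` decorator, which (following the usual tensor-subclass pattern in torchao) registers the handler in an op table consulted by `MXTensor.__torch_dispatch__`; `x_mx[1]` on a 3D tensor arrives there as `aten.select.int(x_mx, 0, 1)`. A simplified, self-contained sketch of that pattern under those assumptions; the real decorator and dispatch live elsewhere in mx_ops.py / mx_tensor.py, and the real handler additionally wraps its output with return_and_correct_aliasing:

# Simplified sketch of the registration/dispatch pattern behind @implements.
# Assumption: torchao keeps a table mapping aten overloads to handlers and
# consults it in MXTensor.__torch_dispatch__; all names here are illustrative.
import torch

aten = torch.ops.aten
MX_OPS_TABLE = {}  # aten overload -> handler function


def implements(aten_ops):
    """Register `fn` as the handler for each aten overload in `aten_ops`."""

    def decorator(fn):
        for op in aten_ops:
            MX_OPS_TABLE[op] = fn
        return fn

    return decorator


class MyMXLikeTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func in MX_OPS_TABLE:
            # e.g. x[1] on an instance lands here as aten.select.int
            return MX_OPS_TABLE[func](func, types, args, kwargs)
        raise NotImplementedError(f"{func} is not supported by this subclass")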
