4
4
import triton
5
5
import triton .language as tl
6
6
from torch import Tensor
7
+ from torch .library import triton_op , wrap_triton
7
8
8
9
from torchao .prototype .mx_formats .utils import to_blocked
9
10
from torchao .utils import ceil_div
@@ -192,6 +193,7 @@ def compute_blocked_scale_offsets_for_K_groups(
192
193
return group_sizes , starting_col_after_padding
193
194
194
195
196
+ @triton_op ("torchao::triton_mx_block_rearrange_2d_M_groups" , mutates_args = {})
195
197
def triton_mx_block_rearrange_2d_M_groups (
196
198
scales_tensor : torch .Tensor ,
197
199
input_group_end_offsets : torch .Tensor ,
@@ -216,10 +218,10 @@ def triton_mx_block_rearrange_2d_M_groups(
216
218
"Expected element size to be 1 byte (8 bits)"
217
219
)
218
220
rows , cols = scales_tensor .shape
219
- num_groups = input_group_end_offsets .numel ()
221
+ num_groups = input_group_end_offsets .shape [ 0 ]
220
222
221
223
# Final offset is the total number of rows in the tensor
222
- padded_rows = output_group_start_offsets [- 1 ]
224
+ padded_rows = rows + num_groups * 128 # output_group_start_offsets[-1]
223
225
224
226
num_col_blocks = ceil_div (cols , 4 )
225
227
padded_cols = num_col_blocks * 4
@@ -238,7 +240,7 @@ def triton_mx_block_rearrange_2d_M_groups(
238
240
num_groups ,
239
241
num_col_blocks ,
240
242
)
241
- triton_scale_swizzle_M_groups [grid ](
243
+ wrap_triton ( triton_scale_swizzle_M_groups ) [grid ](
242
244
# Input scales
243
245
scales_tensor .view (torch .uint8 ),
244
246
scales_tensor .stride (0 ),
@@ -336,6 +338,7 @@ def triton_scale_swizzle_M_groups(
336
338
current_start_row += BLOCK_ROWS
337
339
338
340
341
+ @triton_op ("torchao::triton_mx_block_rearrange_per_group_3d" , mutates_args = {})
339
342
def triton_mx_block_rearrange_per_group_3d (scale_tensor : torch .Tensor ) -> torch .Tensor :
340
343
"""
341
344
Rearranges an E8M0 tensor scale to block-scaled swizzle format.
@@ -379,7 +382,7 @@ def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.
379
382
num_col_blocks ,
380
383
)
381
384
382
- triton_scale_swizzle_per_group_3d [grid ](
385
+ wrap_triton ( triton_scale_swizzle_per_group_3d ) [grid ](
383
386
scale_tensor .view (torch .uint8 ),
384
387
input_stride_dim0 ,
385
388
input_stride_dim1 ,
@@ -454,6 +457,7 @@ def triton_scale_swizzle_per_group_3d(
454
457
)
455
458
456
459
460
+ @triton_op ("torchao::triton_mx_block_rearrange_2d_K_groups" , mutates_args = {})
457
461
def triton_mx_block_rearrange_2d_K_groups (
458
462
scales_tensor : torch .Tensor ,
459
463
input_group_end_offsets : torch .Tensor ,
@@ -479,12 +483,12 @@ def triton_mx_block_rearrange_2d_K_groups(
479
483
)
480
484
rows , cols = scales_tensor .shape
481
485
# Calculate blocks needed
482
- num_groups = input_group_end_offsets .numel ()
486
+ num_groups = input_group_end_offsets .shape [ 0 ]
483
487
num_row_blocks = ceil_div (rows , 128 )
484
488
padded_rows = num_row_blocks * 128
485
489
486
490
# output_group_start_offsets always starts with 0 and ends with the total number of cols
487
- padded_cols = output_group_start_offsets [- 1 ]
491
+ padded_cols = cols + num_groups * 4 # output_group_start_offsets[-1]
488
492
output = scales_tensor .new_empty ((padded_rows , padded_cols ))
489
493
490
494
# Output block stride for the rearranged format
@@ -497,7 +501,7 @@ def triton_mx_block_rearrange_2d_K_groups(
497
501
num_groups ,
498
502
num_row_blocks ,
499
503
)
500
- triton_scale_swizzle_2d_K_groups [grid ](
504
+ wrap_triton ( triton_scale_swizzle_2d_K_groups ) [grid ](
501
505
# Input scales
502
506
scales_tensor .view (torch .uint8 ),
503
507
scales_tensor .stride (0 ),
0 commit comments