44import triton
55import triton .language as tl
66from torch import Tensor
7+ from torch .library import triton_op , wrap_triton
78
89from torchao .prototype .mx_formats .utils import to_blocked
910from torchao .utils import ceil_div
@@ -29,7 +30,14 @@ def torch_to_blocked_2d_M_groups(
2930
3031 assert x_scales .ndim == 2 , "x_scales must be 2D"
3132 assert block_size == 32 , "Only block_size=32 is supported for now"
32- blocked_scales_list = []
33+ total_M , _ = x_scales .shape
34+ num_groups = group_offs .shape [0 ]
35+
36+ # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating over each group,
37+ # the Triton kernel will use an upper bound of adding 128 padding rows to each group.
38+ # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl).
39+ total_M_padded = total_M + num_groups * 128
40+ blocked_scales = x_scales .new_zeros (total_M_padded , K // block_size )
3341 start_row_after_padding_list = [0 ]
3442 group_start_idx = 0
3543 for i , group_end_idx in enumerate (group_offs .tolist ()):
@@ -42,19 +50,24 @@ def torch_to_blocked_2d_M_groups(
4250 # Convert group scales to blocked format
4351 group_scales = x_scales [group_start_idx :group_end_idx ]
4452 group_scales_blocked = to_blocked (group_scales )
45- blocked_scales_list .append (group_scales_blocked )
4653
4754 # Calculate the start row after padding
4855 scaling_groups_per_row = K // block_size
4956 rows_for_group = group_scales_blocked .numel () // scaling_groups_per_row
5057 new_start_row = prev_start_row_after_padding + rows_for_group
5158 start_row_after_padding_list .append (new_start_row )
5259
60+ # Write output to subtensor
61+ group_rows_padded = ceil_div (group_size , 128 ) * 128
62+ blocked_scales [
63+ prev_start_row_after_padding : prev_start_row_after_padding
64+ + group_rows_padded ,
65+ :,
66+ ] = group_scales_blocked .reshape (- 1 , K // block_size )
67+
5368 # Update next group start index
5469 group_start_idx = group_end_idx
5570
56- blocked_scales = torch .cat (blocked_scales_list , dim = 0 ).contiguous ()
57- blocked_scales = blocked_scales .reshape (- 1 , K // 32 )
5871 start_row_after_padding = torch .tensor (
5972 start_row_after_padding_list , device = x_scales .device , dtype = torch .int64
6073 )
@@ -79,34 +92,44 @@ def torch_to_blocked_2d_K_groups(
7992 """
8093 assert x_scales .ndim == 2 , "x_scales must be 2D"
8194 assert block_size == 32 , "Only block_size=32 is supported for now"
82- blocked_scales_list = []
95+ M , total_K = x_scales .shape
96+ padded_M = ceil_div (M , 128 ) * 128
97+ num_groups = group_offs .shape [0 ]
98+
99+ # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating over each group,
100+ # Triton kernel will use an upper bound of adding 4 padding cols to each group.
101+ # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl).
102+ total_K_padded = total_K + num_groups * 4
103+ blocked_scales = x_scales .new_zeros (padded_M , total_K_padded )
104+
83105 start_col_after_padding_list = [0 ]
84106 group_start_idx = 0
85107 for i , group_end_idx in enumerate (group_offs .tolist ()):
86108 group_size = group_end_idx - group_start_idx
87- prev_start_row_after_padding = start_col_after_padding_list [i ]
109+ prev_start_col_after_padding = start_col_after_padding_list [i ]
88110 if group_size == 0 :
89- start_col_after_padding_list .append (prev_start_row_after_padding )
111+ start_col_after_padding_list .append (prev_start_col_after_padding )
90112 continue
91113
92114 # Convert group scales to blocked format
93115 group_scales = x_scales [:, group_start_idx :group_end_idx ]
94116 group_scales_blocked = to_blocked (group_scales )
95117 cols_after_padding = ceil_div (group_size , 4 ) * 4
96- blocked_scales_list .append (group_scales_blocked )
118+
119+ # Write output to subtensor
120+ blocked_scales [
121+ :,
122+ prev_start_col_after_padding : prev_start_col_after_padding
123+ + cols_after_padding ,
124+ ] = group_scales_blocked .reshape (- 1 , cols_after_padding )
97125
98126 # Calculate the start row after padding
99- new_start_col = prev_start_row_after_padding + cols_after_padding
127+ new_start_col = prev_start_col_after_padding + cols_after_padding
100128 start_col_after_padding_list .append (new_start_col )
101129
102130 # Update next group start index
103131 group_start_idx = group_end_idx
104132
105- # blocked_scales = torch.cat(blocked_scales_list, dim=1)
106- M = x_scales .shape [0 ]
107- padded_M = ceil_div (M , 128 ) * 128
108- blocked_scales = torch .cat (blocked_scales_list )
109- blocked_scales = blocked_scales .reshape (padded_M , - 1 )
110133 start_cols_after_padding = torch .tensor (
111134 start_col_after_padding_list , device = x_scales .device , dtype = torch .int64
112135 )
@@ -192,6 +215,7 @@ def compute_blocked_scale_offsets_for_K_groups(
192215 return group_sizes , starting_col_after_padding
193216
194217
218+ @triton_op ("torchao::triton_mx_block_rearrange_2d_M_groups" , mutates_args = {})
195219def triton_mx_block_rearrange_2d_M_groups (
196220 scales_tensor : torch .Tensor ,
197221 input_group_end_offsets : torch .Tensor ,
@@ -216,14 +240,14 @@ def triton_mx_block_rearrange_2d_M_groups(
216240 "Expected element size to be 1 byte (8 bits)"
217241 )
218242 rows , cols = scales_tensor .shape
219- num_groups = input_group_end_offsets .numel ()
243+ num_groups = input_group_end_offsets .shape [ 0 ]
220244
221245 # Final offset is the total number of rows in the tensor
222- padded_rows = output_group_start_offsets [ - 1 ]
246+ padded_rows = rows + num_groups * 128
223247
224248 num_col_blocks = ceil_div (cols , 4 )
225249 padded_cols = num_col_blocks * 4
226- output = scales_tensor .new_empty ((padded_rows , padded_cols ))
250+ output = scales_tensor .new_zeros ((padded_rows , padded_cols ))
227251
228252 # Output block stride for the rearranged format
229253 BLOCK_ROWS , BLOCK_COLS = 128 , 4
@@ -238,7 +262,7 @@ def triton_mx_block_rearrange_2d_M_groups(
238262 num_groups ,
239263 num_col_blocks ,
240264 )
241- triton_scale_swizzle_M_groups [grid ](
265+ wrap_triton ( triton_scale_swizzle_M_groups ) [grid ](
242266 # Input scales
243267 scales_tensor .view (torch .uint8 ),
244268 scales_tensor .stride (0 ),
@@ -336,6 +360,7 @@ def triton_scale_swizzle_M_groups(
336360 current_start_row += BLOCK_ROWS
337361
338362
363+ @triton_op ("torchao::triton_mx_block_rearrange_per_group_3d" , mutates_args = {})
339364def triton_mx_block_rearrange_per_group_3d (scale_tensor : torch .Tensor ) -> torch .Tensor :
340365 """
341366 Rearranges an E8M0 tensor scale to block-scaled swizzle format.
@@ -379,7 +404,7 @@ def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.
379404 num_col_blocks ,
380405 )
381406
382- triton_scale_swizzle_per_group_3d [grid ](
407+ wrap_triton ( triton_scale_swizzle_per_group_3d ) [grid ](
383408 scale_tensor .view (torch .uint8 ),
384409 input_stride_dim0 ,
385410 input_stride_dim1 ,
@@ -454,6 +479,7 @@ def triton_scale_swizzle_per_group_3d(
454479 )
455480
456481
482+ @triton_op ("torchao::triton_mx_block_rearrange_2d_K_groups" , mutates_args = {})
457483def triton_mx_block_rearrange_2d_K_groups (
458484 scales_tensor : torch .Tensor ,
459485 input_group_end_offsets : torch .Tensor ,
@@ -479,13 +505,13 @@ def triton_mx_block_rearrange_2d_K_groups(
479505 )
480506 rows , cols = scales_tensor .shape
481507 # Calculate blocks needed
482- num_groups = input_group_end_offsets .numel ()
508+ num_groups = input_group_end_offsets .shape [ 0 ]
483509 num_row_blocks = ceil_div (rows , 128 )
484510 padded_rows = num_row_blocks * 128
485511
486512 # output_group_start_offsets always starts with 0 and ends with the total number of cols
487- padded_cols = output_group_start_offsets [ - 1 ]
488- output = scales_tensor .new_empty ((padded_rows , padded_cols ))
513+ padded_cols = cols + num_groups * 4
514+ output = scales_tensor .new_zeros ((padded_rows , padded_cols ))
489515
490516 # Output block stride for the rearranged format
491517 BLOCK_ROWS , BLOCK_COLS = 128 , 4
@@ -497,7 +523,7 @@ def triton_mx_block_rearrange_2d_K_groups(
497523 num_groups ,
498524 num_row_blocks ,
499525 )
500- triton_scale_swizzle_2d_K_groups [grid ](
526+ wrap_triton ( triton_scale_swizzle_2d_K_groups ) [grid ](
501527 # Input scales
502528 scales_tensor .view (torch .uint8 ),
503529 scales_tensor .stride (0 ),
0 commit comments