 import triton
 import triton.language as tl
 from torch import Tensor
+from torch.library import triton_op, wrap_triton

 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import ceil_div
@@ -29,7 +30,14 @@ def torch_to_blocked_2d_M_groups(

     assert x_scales.ndim == 2, "x_scales must be 2D"
     assert block_size == 32, "Only block_size=32 is supported for now"
-    blocked_scales_list = []
+    total_M, _ = x_scales.shape
+    num_groups = group_offs.shape[0]
+
+    # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating
+    # over each group, the Triton kernel will use an upper bound of adding 128 padding rows to each group.
+    # (This torch impl is used as a reference for correctness, so we must match the Triton kernel's impl.)
+    total_M_padded = total_M + num_groups * 128
+    blocked_scales = x_scales.new_zeros(total_M_padded, K // block_size)
     start_row_after_padding_list = [0]
     group_start_idx = 0
     for i, group_end_idx in enumerate(group_offs.tolist()):
@@ -42,19 +50,24 @@ def torch_to_blocked_2d_M_groups(
         # Convert group scales to blocked format
         group_scales = x_scales[group_start_idx:group_end_idx]
         group_scales_blocked = to_blocked(group_scales)
-        blocked_scales_list.append(group_scales_blocked)

         # Calculate the start row after padding
         scaling_groups_per_row = K // block_size
         rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
         new_start_row = prev_start_row_after_padding + rows_for_group
         start_row_after_padding_list.append(new_start_row)

+        # Write output to subtensor
+        group_rows_padded = ceil_div(group_size, 128) * 128
+        blocked_scales[
+            prev_start_row_after_padding : prev_start_row_after_padding
+            + group_rows_padded,
+            :,
+        ] = group_scales_blocked.reshape(-1, K // block_size)
+
         # Update next group start index
         group_start_idx = group_end_idx

-    blocked_scales = torch.cat(blocked_scales_list, dim=0).contiguous()
-    blocked_scales = blocked_scales.reshape(-1, K // 32)
     start_row_after_padding = torch.tensor(
         start_row_after_padding_list, device=x_scales.device, dtype=torch.int64
     )
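To make the row-offset arithmetic above concrete: to_blocked pads each group's rows up to a multiple of 128 (per the ceil_div(group_size, 128) * 128 in the hunk), so per-group padding never exceeds 128 rows, which is the fixed upper bound the preallocation uses. A small worked sketch, with made-up group sizes:

    from torchao.utils import ceil_div

    group_sizes = [100, 300]  # per-group row counts, i.e. diffs of group_offs
    starts = [0]
    for size in group_sizes:
        # to_blocked pads each group's rows up to a multiple of 128
        starts.append(starts[-1] + ceil_div(size, 128) * 128)
    # starts == [0, 128, 512]
    # Preallocation uses the fixed upper bound of 128 padding rows per group:
    total_M_padded = sum(group_sizes) + len(group_sizes) * 128  # 656 >= 512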
@@ -79,34 +92,44 @@ def torch_to_blocked_2d_K_groups(
     """
     assert x_scales.ndim == 2, "x_scales must be 2D"
     assert block_size == 32, "Only block_size=32 is supported for now"
-    blocked_scales_list = []
+    M, total_K = x_scales.shape
+    padded_M = ceil_div(M, 128) * 128
+    num_groups = group_offs.shape[0]
+
+    # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating
+    # over each group, the Triton kernel will use an upper bound of adding 4 padding cols to each group.
+    # (This torch impl is used as a reference for correctness, so we must match the Triton kernel's impl.)
+    total_K_padded = total_K + num_groups * 4
+    blocked_scales = x_scales.new_zeros(padded_M, total_K_padded)
+
     start_col_after_padding_list = [0]
     group_start_idx = 0
     for i, group_end_idx in enumerate(group_offs.tolist()):
         group_size = group_end_idx - group_start_idx
-        prev_start_row_after_padding = start_col_after_padding_list[i]
+        prev_start_col_after_padding = start_col_after_padding_list[i]
         if group_size == 0:
-            start_col_after_padding_list.append(prev_start_row_after_padding)
+            start_col_after_padding_list.append(prev_start_col_after_padding)
             continue

         # Convert group scales to blocked format
         group_scales = x_scales[:, group_start_idx:group_end_idx]
         group_scales_blocked = to_blocked(group_scales)
         cols_after_padding = ceil_div(group_size, 4) * 4
-        blocked_scales_list.append(group_scales_blocked)
+
+        # Write output to subtensor
+        blocked_scales[
+            :,
+            prev_start_col_after_padding : prev_start_col_after_padding
+            + cols_after_padding,
+        ] = group_scales_blocked.reshape(-1, cols_after_padding)

         # Calculate the start col after padding
-        new_start_col = prev_start_row_after_padding + cols_after_padding
+        new_start_col = prev_start_col_after_padding + cols_after_padding
         start_col_after_padding_list.append(new_start_col)

         # Update next group start index
         group_start_idx = group_end_idx

-    # blocked_scales = torch.cat(blocked_scales_list, dim=1)
-    M = x_scales.shape[0]
-    padded_M = ceil_div(M, 128) * 128
-    blocked_scales = torch.cat(blocked_scales_list)
-    blocked_scales = blocked_scales.reshape(padded_M, -1)
     start_cols_after_padding = torch.tensor(
         start_col_after_padding_list, device=x_scales.device, dtype=torch.int64
     )
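The column analog of the sketch above: each group's scale columns are padded up to a multiple of 4 (per ceil_div(group_size, 4) * 4), so per-group padding is at most 4 columns. Again with made-up sizes:

    from torchao.utils import ceil_div

    group_sizes_k = [100, 302]  # per-group scale-column counts along K
    col_starts = [0]
    for size in group_sizes_k:
        # each group's scale columns are padded up to a multiple of 4
        col_starts.append(col_starts[-1] + ceil_div(size, 4) * 4)
    # col_starts == [0, 100, 404]
    # Preallocation uses the fixed upper bound of 4 padding cols per group:
    total_K_padded = sum(group_sizes_k) + len(group_sizes_k) * 4  # 410 >= 404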
@@ -192,6 +215,7 @@ def compute_blocked_scale_offsets_for_K_groups(
     return group_sizes, starting_col_after_padding


+@triton_op("torchao::triton_mx_block_rearrange_2d_M_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_M_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -216,14 +240,16 @@ def triton_mx_block_rearrange_2d_M_groups(
         "Expected element size to be 1 byte (8 bits)"
     )
     rows, cols = scales_tensor.shape
-    num_groups = input_group_end_offsets.numel()
+    num_groups = input_group_end_offsets.shape[0]

-    # Final offset is the total number of rows in the tensor
-    padded_rows = output_group_start_offsets[-1]
+    # The final output group start offset equals the total number of rows actually used.
+    # Padding needed per group is variable/data-dependent, so we just pad each group by
+    # the upper bound of 128 rows to avoid a d2h sync caused by iterating over each group.
+    padded_rows = rows + num_groups * 128

     num_col_blocks = ceil_div(cols, 4)
     padded_cols = num_col_blocks * 4
-    output = scales_tensor.new_empty((padded_rows, padded_cols))
+    output = scales_tensor.new_zeros((padded_rows, padded_cols))

     # Output block stride for the rearranged format
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -238,7 +264,7 @@ def triton_mx_block_rearrange_2d_M_groups(
         num_groups,
         num_col_blocks,
     )
-    triton_scale_swizzle_M_groups[grid](
+    wrap_triton(triton_scale_swizzle_M_groups)[grid](
         # Input scales
         scales_tensor.view(torch.uint8),
         scales_tensor.stride(0),
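For context on the new decorator/launch pattern in this hunk: torch.library.triton_op registers the Python wrapper as a PyTorch custom op so it composes cleanly with torch.compile, and raw Triton kernels launched inside a triton_op must go through wrap_triton. A minimal self-contained sketch of the pattern, with a toy copy kernel and a made-up op name (not part of this PR):

    import torch
    import triton
    import triton.language as tl
    from torch.library import triton_op, wrap_triton

    @triton.jit
    def _copy_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

    @triton_op("mylib::copy", mutates_args={})
    def copy(x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        grid = (triton.cdiv(x.numel(), 1024),)
        # Inside a triton_op, raw Triton kernels are launched via wrap_triton
        wrap_triton(_copy_kernel)[grid](x, out, x.numel(), BLOCK=1024)
        return out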
@@ -336,6 +362,7 @@ def triton_scale_swizzle_M_groups(
         current_start_row += BLOCK_ROWS


+@triton_op("torchao::triton_mx_block_rearrange_per_group_3d", mutates_args={})
 def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
     """
     Rearranges an E8M0 tensor scale to block-scaled swizzle format.
@@ -379,7 +406,7 @@ def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.
         num_col_blocks,
     )

-    triton_scale_swizzle_per_group_3d[grid](
+    wrap_triton(triton_scale_swizzle_per_group_3d)[grid](
         scale_tensor.view(torch.uint8),
         input_stride_dim0,
         input_stride_dim1,
@@ -454,6 +481,7 @@ def triton_scale_swizzle_per_group_3d(
     )


+@triton_op("torchao::triton_mx_block_rearrange_2d_K_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_K_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -479,13 +507,14 @@ def triton_mx_block_rearrange_2d_K_groups(
     )
     rows, cols = scales_tensor.shape
     # Calculate blocks needed
-    num_groups = input_group_end_offsets.numel()
+    num_groups = input_group_end_offsets.shape[0]
     num_row_blocks = ceil_div(rows, 128)
     padded_rows = num_row_blocks * 128

-    # output_group_start_offsets always starts with 0 and ends with the total number of cols
-    padded_cols = output_group_start_offsets[-1]
-    output = scales_tensor.new_empty((padded_rows, padded_cols))
+    # Padding needed per group is variable/data-dependent, so we just pad each group by
+    # the upper bound of 4 cols to avoid a d2h sync caused by iterating over each group.
+    padded_cols = cols + num_groups * 4
+    output = scales_tensor.new_zeros((padded_rows, padded_cols))

     # Output block stride for the rearranged format
     BLOCK_ROWS, BLOCK_COLS = 128, 4
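Because the output is now sized to an upper bound rather than tightly packed (hence new_zeros, so the unused padding stays deterministic), consumers would locate each group's swizzled block via the start offsets rather than assuming the final offset equals the tensor width. A hypothetical consumer-side slice, with variable names assumed for illustration:

    # Hypothetical: pull group j's swizzled scales out of the K-groups output.
    # output has shape (padded_rows, cols + num_groups * 4); unused cols remain zero.
    start = int(output_group_start_offsets[j])
    end = int(output_group_start_offsets[j + 1])
    group_j_scales = output[:, start:end]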
@@ -497,7 +526,7 @@ def triton_mx_block_rearrange_2d_K_groups(
         num_groups,
         num_row_blocks,
     )
-    triton_scale_swizzle_2d_K_groups[grid](
+    wrap_triton(triton_scale_swizzle_2d_K_groups)[grid](
         # Input scales
         scales_tensor.view(torch.uint8),
         scales_tensor.stride(0),