 import triton
 import triton.language as tl
 from torch import Tensor
+from torch.library import triton_op, wrap_triton

 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import ceil_div
@@ -29,7 +30,14 @@ def torch_to_blocked_2d_M_groups(
     assert x_scales.ndim == 2, "x_scales must be 2D"
     assert block_size == 32, "Only block_size=32 is supported for now"
-    blocked_scales_list = []
+    total_M, _ = x_scales.shape
+    num_groups = group_offs.shape[0]
+
+    # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating
+    # over each group, the Triton kernel will use an upper bound of adding 128 padding rows per group.
+    # (This torch impl is a reference for correctness, so it must match the Triton kernel's impl.)
+    total_M_padded = total_M + num_groups * 128
+    blocked_scales = x_scales.new_zeros(total_M_padded, K // block_size)
     start_row_after_padding_list = [0]
     group_start_idx = 0
     for i, group_end_idx in enumerate(group_offs.tolist()):
@@ -42,19 +50,24 @@ def torch_to_blocked_2d_M_groups(
         # Convert group scales to blocked format
         group_scales = x_scales[group_start_idx:group_end_idx]
         group_scales_blocked = to_blocked(group_scales)
-        blocked_scales_list.append(group_scales_blocked)

         # Calculate the start row after padding
         scaling_groups_per_row = K // block_size
         rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
         new_start_row = prev_start_row_after_padding + rows_for_group
         start_row_after_padding_list.append(new_start_row)

+        # Write output to subtensor
+        group_rows_padded = ceil_div(group_size, 128) * 128
+        blocked_scales[
+            prev_start_row_after_padding : prev_start_row_after_padding
+            + group_rows_padded,
+            :,
+        ] = group_scales_blocked.reshape(-1, K // block_size)
+
         # Update next group start index
         group_start_idx = group_end_idx

-    blocked_scales = torch.cat(blocked_scales_list, dim=0).contiguous()
-    blocked_scales = blocked_scales.reshape(-1, K // 32)
     start_row_after_padding = torch.tensor(
         start_row_after_padding_list, device=x_scales.device, dtype=torch.int64
     )
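Why 128 rows per group is a safe upper bound: `to_blocked` pads each group's rows up to a multiple of 128, so a group of `g` rows occupies `ceil_div(g, 128) * 128 <= g + 128` output rows, and an empty group occupies none. A minimal standalone sketch of the offset arithmetic (the `group_sizes` values below are made-up example data, not from this diff):

```python
def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division, same semantics as torchao.utils.ceil_div.
    return (a + b - 1) // b

# Hypothetical per-group row counts along M.
group_sizes = [100, 0, 300, 17]
total_M = sum(group_sizes)

# Upper bound used above: at most 128 padding rows per group.
total_M_padded = total_M + len(group_sizes) * 128

# Padded start row of each group, mirroring start_row_after_padding_list.
starts = [0]
for g in group_sizes:
    starts.append(starts[-1] + ceil_div(g, 128) * 128)

print(starts)                        # [0, 128, 128, 512, 640]
print(starts[-1] <= total_M_padded)  # True: 640 <= 417 + 4 * 128
```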
@@ -79,34 +92,44 @@ def torch_to_blocked_2d_K_groups(
     """
     assert x_scales.ndim == 2, "x_scales must be 2D"
     assert block_size == 32, "Only block_size=32 is supported for now"
-    blocked_scales_list = []
+    M, total_K = x_scales.shape
+    padded_M = ceil_div(M, 128) * 128
+    num_groups = group_offs.shape[0]
+
+    # Each group will require a variable amount of padding, so to avoid a d2h sync caused by iterating
+    # over each group, the Triton kernel will use an upper bound of adding 4 padding cols per group.
+    # (This torch impl is a reference for correctness, so it must match the Triton kernel's impl.)
+    total_K_padded = total_K + num_groups * 4
+    blocked_scales = x_scales.new_zeros(padded_M, total_K_padded)
+
     start_col_after_padding_list = [0]
     group_start_idx = 0
     for i, group_end_idx in enumerate(group_offs.tolist()):
         group_size = group_end_idx - group_start_idx
-        prev_start_row_after_padding = start_col_after_padding_list[i]
+        prev_start_col_after_padding = start_col_after_padding_list[i]
         if group_size == 0:
-            start_col_after_padding_list.append(prev_start_row_after_padding)
+            start_col_after_padding_list.append(prev_start_col_after_padding)
             continue

         # Convert group scales to blocked format
         group_scales = x_scales[:, group_start_idx:group_end_idx]
         group_scales_blocked = to_blocked(group_scales)
         cols_after_padding = ceil_div(group_size, 4) * 4
-        blocked_scales_list.append(group_scales_blocked)
+
+        # Write output to subtensor
+        blocked_scales[
+            :,
+            prev_start_col_after_padding : prev_start_col_after_padding
+            + cols_after_padding,
+        ] = group_scales_blocked.reshape(-1, cols_after_padding)

         # Calculate the start col after padding
-        new_start_col = prev_start_row_after_padding + cols_after_padding
+        new_start_col = prev_start_col_after_padding + cols_after_padding
         start_col_after_padding_list.append(new_start_col)

         # Update next group start index
         group_start_idx = group_end_idx

-    # blocked_scales = torch.cat(blocked_scales_list, dim=1)
-    M = x_scales.shape[0]
-    padded_M = ceil_div(M, 128) * 128
-    blocked_scales = torch.cat(blocked_scales_list)
-    blocked_scales = blocked_scales.reshape(padded_M, -1)
     start_cols_after_padding = torch.tensor(
         start_col_after_padding_list, device=x_scales.device, dtype=torch.int64
     )
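The K-groups reference applies the same refactor: instead of collecting per-group blocked scales in a Python list and `torch.cat`-ing them (which makes the output shape data-dependent), the output is preallocated at a static upper bound and each group is written into its column slice. A minimal sketch of that preallocate-and-slice-write pattern, using made-up shapes rather than this diff's exact code:

```python
import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

M, group_widths = 128, [6, 9]  # hypothetical example shapes
num_groups = len(group_widths)

# Static upper bound: each group pads its cols to a multiple of 4,
# adding at most 4 extra columns.
out = torch.zeros(M, sum(group_widths) + num_groups * 4)

col = 0
for w in group_widths:
    w_padded = ceil_div(w, 4) * 4
    # Slice-assign the group's (already padded) block; the remaining
    # over-allocated columns stay zero.
    out[:, col : col + w_padded] = torch.randn(M, w_padded)
    col += w_padded
```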
@@ -192,6 +215,7 @@ def compute_blocked_scale_offsets_for_K_groups(
     return group_sizes, starting_col_after_padding


+@triton_op("torchao::triton_mx_block_rearrange_2d_M_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_M_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -216,14 +240,14 @@ def triton_mx_block_rearrange_2d_M_groups(
         "Expected element size to be 1 byte (8 bits)"
     )
     rows, cols = scales_tensor.shape
-    num_groups = input_group_end_offsets.numel()
+    num_groups = input_group_end_offsets.shape[0]

     # Upper bound on output rows: each group adds at most 128 padding rows
-    padded_rows = output_group_start_offsets[-1]
+    padded_rows = rows + num_groups * 128

     num_col_blocks = ceil_div(cols, 4)
     padded_cols = num_col_blocks * 4
-    output = scales_tensor.new_empty((padded_rows, padded_cols))
+    output = scales_tensor.new_zeros((padded_rows, padded_cols))

     # Output block stride for the rearranged format
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -238,7 +262,7 @@ def triton_mx_block_rearrange_2d_M_groups(
         num_groups,
         num_col_blocks,
     )
-    triton_scale_swizzle_M_groups[grid](
+    wrap_triton(triton_scale_swizzle_M_groups)[grid](
         # Input scales
         scales_tensor.view(torch.uint8),
         scales_tensor.stride(0),
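The `@triton_op` / `wrap_triton` changes follow the `torch.library` pattern that makes Triton kernels compatible with `torch.compile`: `triton_op` registers the Python wrapper as a custom op, and `wrap_triton` lets the compiler trace into the kernel launch instead of graph-breaking on it. A self-contained toy example of the same pattern (an elementwise copy kernel, not the swizzle kernel from this PR):

```python
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def _copy_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    # Each program copies one BLOCK-sized chunk, masked at the tail.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

@triton_op("mylib::toy_copy", mutates_args={})
def toy_copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # wrap_triton(...) keeps the launch traceable under torch.compile.
    wrap_triton(_copy_kernel)[grid](x, out, n, BLOCK=1024)
    return out
```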
@@ -336,6 +360,7 @@ def triton_scale_swizzle_M_groups(
         current_start_row += BLOCK_ROWS


+@triton_op("torchao::triton_mx_block_rearrange_per_group_3d", mutates_args={})
 def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
     """
     Rearranges an E8M0 tensor scale to block-scaled swizzle format.
@@ -379,7 +404,7 @@ def triton_mx_block_rearrange_per_group_3d(scale_tensor: torch.Tensor) -> torch.Tensor:
         num_col_blocks,
     )

-    triton_scale_swizzle_per_group_3d[grid](
+    wrap_triton(triton_scale_swizzle_per_group_3d)[grid](
         scale_tensor.view(torch.uint8),
         input_stride_dim0,
         input_stride_dim1,
@@ -454,6 +479,7 @@ def triton_scale_swizzle_per_group_3d(
     )


+@triton_op("torchao::triton_mx_block_rearrange_2d_K_groups", mutates_args={})
 def triton_mx_block_rearrange_2d_K_groups(
     scales_tensor: torch.Tensor,
     input_group_end_offsets: torch.Tensor,
@@ -479,13 +505,13 @@ def triton_mx_block_rearrange_2d_K_groups(
     )
     rows, cols = scales_tensor.shape
     # Calculate blocks needed
-    num_groups = input_group_end_offsets.numel()
+    num_groups = input_group_end_offsets.shape[0]
     num_row_blocks = ceil_div(rows, 128)
     padded_rows = num_row_blocks * 128

     # Upper bound on padded cols: each group adds at most 4 padding cols
-    padded_cols = output_group_start_offsets[-1]
-    output = scales_tensor.new_empty((padded_rows, padded_cols))
+    padded_cols = cols + num_groups * 4
+    output = scales_tensor.new_zeros((padded_rows, padded_cols))

     # Output block stride for the rearranged format
     BLOCK_ROWS, BLOCK_COLS = 128, 4
@@ -497,7 +523,7 @@ def triton_mx_block_rearrange_2d_K_groups(
         num_groups,
         num_row_blocks,
     )
-    triton_scale_swizzle_2d_K_groups[grid](
+    wrap_triton(triton_scale_swizzle_2d_K_groups)[grid](
         # Input scales
         scales_tensor.view(torch.uint8),
         scales_tensor.stride(0),
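The removed `output_group_start_offsets[-1]` reads also explain the switch from `new_empty` to `new_zeros`: extracting a scalar from a device tensor forces a device-to-host sync and makes the output shape data-dependent, whereas `rows + num_groups * 128` (or `cols + num_groups * 4`) is host-side metadata arithmetic. Because the bound over-allocates, the slack must be deterministic zeros rather than uninitialized memory. A small illustration of the difference (assumes a CUDA device is available):

```python
import torch

offsets = torch.tensor([0, 8, 12, 24], device="cuda")

# Before: data-dependent shape. Reading the scalar blocks the host
# until the GPU catches up (an implicit d2h sync).
padded_cols_dynamic = int(offsets[-1])

# After: a static upper bound from host-side metadata only; no sync,
# and the output shape no longer depends on tensor data.
cols, num_groups = 19, 4
padded_cols_static = cols + num_groups * 4

# zeros (not empty): the over-allocated padding region is never
# written by the kernel, so it must be zero-initialized.
out = torch.zeros(128, padded_cols_static, device="cuda")
```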