@@ -34,7 +34,14 @@ def torch_to_blocked_2d_M_groups(
34
34
35
35
assert x_scales .ndim == 2 , "x_scales must be 2D"
36
36
assert block_size == 32 , "Only block_size=32 is supported for now"
37
- blocked_scales_list = []
37
+ total_M , _ = x_scales .shape
38
+ num_groups = group_offs .shape [0 ]
39
+
40
+ # Each group will require a variable amount of padding, so to avoid d2h sync caused by iterating over each group,
41
+ # the Triton kernel will use an upper bound of adding 128 padding rows to each group.
42
+ # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl).
43
+ total_M_padded = total_M + num_groups * 128
44
+ blocked_scales = x_scales .new_zeros (total_M_padded , K // block_size )
38
45
start_row_after_padding_list = [0 ]
39
46
group_start_idx = 0
40
47
for i , group_end_idx in enumerate (group_offs .tolist ()):
@@ -47,19 +54,24 @@ def torch_to_blocked_2d_M_groups(
47
54
# Convert group scales to blocked format
48
55
group_scales = x_scales [group_start_idx :group_end_idx ]
49
56
group_scales_blocked = to_blocked (group_scales )
50
- blocked_scales_list .append (group_scales_blocked )
51
57
52
58
# Calculate the start row after padding
53
59
scaling_groups_per_row = K // block_size
54
60
rows_for_group = group_scales_blocked .numel () // scaling_groups_per_row
55
61
new_start_row = prev_start_row_after_padding + rows_for_group
56
62
start_row_after_padding_list .append (new_start_row )
57
63
64
+ # Write output to subtensor
65
+ group_rows_padded = ceil_div (group_size , 128 ) * 128
66
+ blocked_scales [
67
+ prev_start_row_after_padding : prev_start_row_after_padding
68
+ + group_rows_padded ,
69
+ :,
70
+ ] = group_scales_blocked .reshape (- 1 , K // block_size )
71
+
58
72
# Update next group start index
59
73
group_start_idx = group_end_idx
60
74
61
- blocked_scales = torch .cat (blocked_scales_list , dim = 0 ).contiguous ()
62
- blocked_scales = blocked_scales .reshape (- 1 , K // 32 )
63
75
start_row_after_padding = torch .tensor (
64
76
start_row_after_padding_list , device = x_scales .device , dtype = torch .int64
65
77
)
@@ -84,34 +96,44 @@ def torch_to_blocked_2d_K_groups(
84
96
"""
85
97
assert x_scales .ndim == 2 , "x_scales must be 2D"
86
98
assert block_size == 32 , "Only block_size=32 is supported for now"
87
- blocked_scales_list = []
99
+ M , total_K = x_scales .shape
100
+ padded_M = ceil_div (M , 128 ) * 128
101
+ num_groups = group_offs .shape [0 ]
102
+
103
+ # Each group will require a variable amount of padding, so to avoid d2h sync caused by iterating over each group,
104
+ # Triton kernel will use an upper bound of adding 4 padding cols to each group.
105
+ # (This torch impl is used as a reference for correctness, so we must match the triton kernel's impl).
106
+ total_K_padded = total_K + num_groups * 4
107
+ blocked_scales = x_scales .new_zeros (padded_M , total_K_padded )
108
+
88
109
start_col_after_padding_list = [0 ]
89
110
group_start_idx = 0
90
111
for i , group_end_idx in enumerate (group_offs .tolist ()):
91
112
group_size = group_end_idx - group_start_idx
92
- prev_start_row_after_padding = start_col_after_padding_list [i ]
113
+ prev_start_col_after_padding = start_col_after_padding_list [i ]
93
114
if group_size == 0 :
94
- start_col_after_padding_list .append (prev_start_row_after_padding )
115
+ start_col_after_padding_list .append (prev_start_col_after_padding )
95
116
continue
96
117
97
118
# Convert group scales to blocked format
98
119
group_scales = x_scales [:, group_start_idx :group_end_idx ]
99
120
group_scales_blocked = to_blocked (group_scales )
100
121
cols_after_padding = ceil_div (group_size , 4 ) * 4
101
- blocked_scales_list .append (group_scales_blocked )
122
+
123
+ # Write output to subtensor
124
+ blocked_scales [
125
+ :,
126
+ prev_start_col_after_padding : prev_start_col_after_padding
127
+ + cols_after_padding ,
128
+ ] = group_scales_blocked .reshape (- 1 , cols_after_padding )
102
129
103
130
# Calculate the start row after padding
104
- new_start_col = prev_start_row_after_padding + cols_after_padding
131
+ new_start_col = prev_start_col_after_padding + cols_after_padding
105
132
start_col_after_padding_list .append (new_start_col )
106
133
107
134
# Update next group start index
108
135
group_start_idx = group_end_idx
109
136
110
- # blocked_scales = torch.cat(blocked_scales_list, dim=1)
111
- M = x_scales .shape [0 ]
112
- padded_M = ceil_div (M , 128 ) * 128
113
- blocked_scales = torch .cat (blocked_scales_list )
114
- blocked_scales = blocked_scales .reshape (padded_M , - 1 )
115
137
start_cols_after_padding = torch .tensor (
116
138
start_col_after_padding_list , device = x_scales .device , dtype = torch .int64
117
139
)
@@ -225,11 +247,11 @@ def triton_mx_block_rearrange_2d_M_groups(
225
247
num_groups = input_group_end_offsets .shape [0 ]
226
248
227
249
# Final offset is the total number of rows in the tensor
228
- padded_rows = rows + num_groups * 128 # output_group_start_offsets[-1]
250
+ padded_rows = rows + num_groups * 128
229
251
230
252
num_col_blocks = ceil_div (cols , 4 )
231
253
padded_cols = num_col_blocks * 4
232
- output = scales_tensor .new_empty ((padded_rows , padded_cols ))
254
+ output = scales_tensor .new_zeros ((padded_rows , padded_cols ))
233
255
234
256
# Output block stride for the rearranged format
235
257
BLOCK_ROWS , BLOCK_COLS = 128 , 4
@@ -492,8 +514,8 @@ def triton_mx_block_rearrange_2d_K_groups(
492
514
padded_rows = num_row_blocks * 128
493
515
494
516
# output_group_start_offsets always starts with 0 and ends with the total number of cols
495
- padded_cols = cols + num_groups * 4 # output_group_start_offsets[-1]
496
- output = scales_tensor .new_empty ((padded_rows , padded_cols ))
517
+ padded_cols = cols + num_groups * 4
518
+ output = scales_tensor .new_zeros ((padded_rows , padded_cols ))
497
519
498
520
# Output block stride for the rearranged format
499
521
BLOCK_ROWS , BLOCK_COLS = 128 , 4
0 commit comments