Skip to content

Commit 7cb66b0

Browse files
author
周鹤云
committed
Fix illegal memory access in the GEMV kernel (pass the dynamic shared-memory size in each kernel launch)
1 parent e25b350 commit 7cb66b0

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

awq/kernels/csrc/quantization_new/gemv/gemv_cuda.cu

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -247,6 +247,7 @@ torch::Tensor gemv_forward_cuda_new(
247247
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
248248
dim3 num_threads(BLOCK_SIZE);
249249

250+
constexpr int kSmemByteSizePerBatch = N_PER_BLOCK * K_INTERLEAVE * BLOCK_SIZE;
250251
// if (group_size == 64)
251252
// {
252253
// gemv_kernel_g64<<<num_blocks, num_threads>>>(
@@ -261,37 +262,37 @@ torch::Tensor gemv_forward_cuda_new(
261262
switch (m)
262263
{
263264
case 1:
264-
gemv_kernel<N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
265+
gemv_kernel<N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 1>>>(
265266
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
266267
);
267268
break;
268269
case 2:
269-
gemv_kernel<N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
270+
gemv_kernel<N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 2>>>(
270271
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
271272
);
272273
break;
273274
case 3:
274-
gemv_kernel<N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
275+
gemv_kernel<N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 3>>>(
275276
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
276277
);
277278
break;
278279
case 4:
279-
gemv_kernel<N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
280+
gemv_kernel<N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 4>>>(
280281
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
281282
);
282283
break;
283284
case 5:
284-
gemv_kernel<N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
285+
gemv_kernel<N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 5>>>(
285286
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
286287
);
287288
break;
288289
case 6:
289-
gemv_kernel<N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
290+
gemv_kernel<N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 6>>>(
290291
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
291292
);
292293
break;
293294
case 7:
294-
gemv_kernel<N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
295+
gemv_kernel<N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads, kSmemByteSizePerBatch * 7>>>(
295296
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
296297
);
297298
break;

0 commit comments

Comments
 (0)