@@ -247,6 +247,7 @@ torch::Tensor gemv_forward_cuda_new(
247247 dim3 num_blocks (n / N_PER_BLOCK / K_INTERLEAVE);
248248 dim3 num_threads (BLOCK_SIZE);
249249
250+ constexpr int kSmemByteSizePerBatch = N_PER_BLOCK * K_INTERLEAVE * BLOCK_SIZE;
250251 // if (group_size == 64)
251252 // {
252253 // gemv_kernel_g64<<<num_blocks, num_threads>>>(
@@ -261,37 +262,37 @@ torch::Tensor gemv_forward_cuda_new(
261262 switch (m)
262263 {
263264 case 1 :
264- gemv_kernel<N_PER_BLOCK, 1 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads>>> (
265+ gemv_kernel<N_PER_BLOCK, 1 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads, kSmemByteSizePerBatch * 1 >>> (
265266 in_feats, kernel, scaling_factors, zeros, out_feats, k, n
266267 );
267268 break ;
268269 case 2 :
269- gemv_kernel<N_PER_BLOCK, 2 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads>>> (
270+ gemv_kernel<N_PER_BLOCK, 2 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads, kSmemByteSizePerBatch * 2 >>> (
270271 in_feats, kernel, scaling_factors, zeros, out_feats, k, n
271272 );
272273 break ;
273274 case 3 :
274- gemv_kernel<N_PER_BLOCK, 3 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads>>> (
275+ gemv_kernel<N_PER_BLOCK, 3 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads, kSmemByteSizePerBatch * 3 >>> (
275276 in_feats, kernel, scaling_factors, zeros, out_feats, k, n
276277 );
277278 break ;
278279 case 4 :
279- gemv_kernel<N_PER_BLOCK, 4 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads>>> (
280+ gemv_kernel<N_PER_BLOCK, 4 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads, kSmemByteSizePerBatch * 4 >>> (
280281 in_feats, kernel, scaling_factors, zeros, out_feats, k, n
281282 );
282283 break ;
283284 case 5 :
284- gemv_kernel<N_PER_BLOCK, 5 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads>>> (
285+ gemv_kernel<N_PER_BLOCK, 5 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads, kSmemByteSizePerBatch * 5 >>> (
285286 in_feats, kernel, scaling_factors, zeros, out_feats, k, n
286287 );
287288 break ;
288289 case 6 :
289- gemv_kernel<N_PER_BLOCK, 6 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads>>> (
290+ gemv_kernel<N_PER_BLOCK, 6 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads, kSmemByteSizePerBatch * 6 >>> (
290291 in_feats, kernel, scaling_factors, zeros, out_feats, k, n
291292 );
292293 break ;
293294 case 7 :
294- gemv_kernel<N_PER_BLOCK, 7 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads>>> (
295+ gemv_kernel<N_PER_BLOCK, 7 , BLOCK_SIZE, 128 ><<<num_blocks, num_threads, kSmemByteSizePerBatch * 7 >>> (
295296 in_feats, kernel, scaling_factors, zeros, out_feats, k, n
296297 );
297298 break ;
0 commit comments