From f30ed4988163957e225bd09012f8b748bddfbcdc Mon Sep 17 00:00:00 2001 From: rib-2 Date: Wed, 17 Jan 2024 20:31:36 +0000 Subject: [PATCH 01/47] marlin --- csrc/ops.h | 7 +++++++ csrc/pybind.cpp | 3 ++- setup.py | 2 ++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/csrc/ops.h b/csrc/ops.h index 9340a60da141..71ea700c11eb 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -70,6 +70,13 @@ torch::Tensor awq_gemm( torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters); + +void marlin_gemm( + const torch::Tensor& input, + const torch::Tensor& weights, + torch::Tensor& output, + const torch::Tensor& scales, + torch::Tensor& workspace); #endif void squeezellm_gemm( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 95f557686f33..f9a1c6fe7369 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -51,11 +51,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #ifndef USE_ROCM // Quantization ops ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); + ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); #endif ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - + // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def( diff --git a/setup.py b/setup.py index 811d494e7a01..1d40f28834c4 100644 --- a/setup.py +++ b/setup.py @@ -226,6 +226,8 @@ def get_torch_arch_list() -> Set[str]: if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") + print("\n\n HERE \n\n") + vllm_extension_sources.append("csrc/quantization/marlin/marlin_cuda_kernel.cu") vllm_extension = CUDAExtension( name="vllm._C", From 837d34481f1b52349758da344dfecf1200e1455e Mon Sep 17 00:00:00 2001 From: rib-2 Date: Wed, 17 Jan 2024 20:33:38 +0000 Subject: [PATCH 02/47] added marlin --- .../quantization/marlin/marlin_cuda_kernel.cu | 820 ++++++++++++++++++ 1 file changed, 820 insertions(+) create mode 100644 csrc/quantization/marlin/marlin_cuda_kernel.cu diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu new file mode 100644 index 000000000000..61856abb8906 --- /dev/null +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -0,0 +1,820 @@ +#ifndef MARLIN_CUDA_KERNEL_CUH +#define MARLIN_CUDA_KERNEL_CUH + +#include +#include +#include +#include + +namespace vllm { +namespace marlin { + +constexpr int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +// Instances of `Vec` are used to organize groups of >>registers<<, as are for instance needed as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that +// are not multiples of 16. 
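// Each call below moves one 16-byte `int4`; when `pred` is false the predicated PTX copy is simply
// skipped, which is how rows of A beyond `prob_m` are masked out. `__cvta_generic_to_shared`
// converts the generic shared-memory pointer into the 32-bit shared-state-space address that
// `cp.async` expects.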
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Asynchronous global->shared copy with a chache hint indicating that the values may be evicted immediately; used for +// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need +// for inputs A and outputs C. +__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) + ); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]) + ); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) + ); +} + +// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to +// automatically recognize it in all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) + ); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. +// We mostly follow the strategy in the link below, with some smaller changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. 
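  // The constants below are fp16 bit patterns: 0x6400 is 1024.0, so OR-ing a nibble into the low
  // mantissa bits yields (1024 + q); 0x6408 is 1032.0 = 1024 + 8, 0x2c00 is 1/16, and 0xd480 is
  // -72.0 = -(1024/16 + 8). The low nibble thus becomes (1024 + q) - (1024 + 8) = q - 8 via
  // __hsub2, and the high nibble (1024 + 16*q) / 16 - 72 = q - 8 via __hfma2.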
+ const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2( + *reinterpret_cast(&lo), + *reinterpret_cast(&SUB) + ); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), + *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) + ); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used only in group mode. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible globally. + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock) { + __syncthreads(); + if (threadIdx.x == 0) { + int val = 1; + // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + + +template < + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale +> +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple + // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs + // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as + // possible. + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case + // where a stripe starts in the middle of group. 
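  // For example, with groupsize 128 (group_blocks = 8) and thread_k = 64 (thread_k_blocks = 4),
  // iters is rounded up to the next multiple of 2 so that every stripe covers whole scale groups.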
+ if (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col = (iters * blockIdx.x) / k_tiles; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to top + + // Compute all information about the current slice which is required for synchronization. + auto init_slice = [&] () { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col + slice_row); + if (slice_iters < 0 || slice_col >= n_tiles) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col, iters); + if (col_first <= k_tiles * (slice_col + 1)) { + int col_off = col_first - k_tiles * slice_col; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + }; + init_slice(); + + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time constant + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_sh_stage = s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. 
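  // Each of the 32 lanes hands `ldsm4` (ldmatrix.x4) the address of one 8x8 sub-tile row: lanes
  // 0-15 address the first eight fp16 columns of rows 0-15 and lanes 16-31 the next eight,
  // covering the four 8x8 sub-tiles of a 16x16 tile; the XOR swizzle below is applied on top of
  // this index.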
+ int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + int s_sh_wr = threadIdx.x; + int s_sh_rd; + // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major + // layout in the former and in row-major in the latter case. + if (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than + // required for a certain tilesize or when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank + // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of + // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based + // on NSight-Compute) that each warp must also write a consecutive memory segment? + auto transform_a = [&] (int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory + // accesses are static, we simply precompute both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between + // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; + + // Zero accumulators. 
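  // Each thread holds thread_m_blocks * 4 * 2 FragC accumulators of 4 floats each (`frag_c`
  // above), so the flat loop below clears thread_m_blocks * 32 floats.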
+ auto zero_accums = [&] () { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. + auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i] + ); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) + cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + s_gl_rd += s_gl_rd_delta; + } + } + // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&] () { + // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when + // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. + auto fetch_to_registers = [&] (int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a + // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the + // compiler and correspondingly a noticable drop in performance. + if (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&] (int k) { + // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + FragB frag_b0 = dequant(b_quant); + // If there are no groups, we can just scale the final output once and can avoid doing so for each weight. 
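      // (For per-column quantization, group_blocks == -1, the single scale per output column is
      // applied later in write_result(); scaling here is only needed when scales vary along k.)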
+ if (group_blocks != -1) + scale(frag_b0, frag_s[k % 2][j], 0); + FragB frag_b1 = dequant(b_quant_shift); + if (group_blocks != -1) + scale(frag_b1, frag_s[k % 2][j], 1); + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n + // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&] () { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, + // e.g., for two warps we write only once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over + // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&] (bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. + // To do this, we write out results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, + // hence we also use async-copies even though these fetches are not actually asynchronous. 
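        // The predicate passed to cp_async4_pred below also masks rows of the (possibly partial)
        // last 16-row block that lie beyond prob_m, mirroring the masking used when loading A.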
+ #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m + ); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float( + reinterpret_cast<__half*>(&c_red)[j] + ); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = __float2half( + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] + ); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step, + // the reduction above is performed in fragment layout. + auto write_result = [&] () { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final global write patterns + auto write = [&] (int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + if (group_blocks == -1) // for per-column quantization we finally apply the scale here + res = __hmul2(res, s[0]); + ((half2*) sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&] () { + #pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. 
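  // start_pipes() has already prefetched `stages - 1` tiles and loaded the first sub-tile into
  // registers; the loop below then interleaves register loads (fetch_to_registers), the next
  // global prefetch (fetch_to_shared) and tensor-core work (matmul) so that memory transfers
  // overlap with compute.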
+ while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are + // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) + break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most + // readable, other ways of writing the loop seemed to noticeably worse performance after compliation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before write-out + if (group_blocks == -1 && last) { + if (s_sh_wr_pred) + cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async_fence(); + } + thread_block_reduce(); + if (group_blocks == -1 && last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col]); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + + +// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more +// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles. 
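// 256 threads are exactly 8 warps; together with the 4-stage pipeline this keeps the A/B/scale
// buffers within the 96 KB of dynamic shared memory requested below.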
+const int THREADS = 256; +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ + else if ( \ + thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS \ + ) { \ + cudaMemset(locks, 0, 4 * cols); \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM \ + ); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, \ + prob_m, prob_n, prob_k, \ + locks \ + ); \ + } + +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; + +int marlin_cuda( + const void* A, + const void* B, + void* C, + void* s, + int prob_m, + int prob_n, + int prob_k, + void* workspace, + int groupsize = -1, + int dev = 0, + int thread_k = -1, + int thread_n = -1, + int sms = -1 +) { + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + if (thread_k == -1 || thread_n == -1) { + if (prob_m <= 16) { + // For small batchizes, better partioning is slightly more important than better compute utilization + thread_k = 128; + thread_n = 128; + } else { + thread_k = 64; + thread_n = 256; + } + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0)) + return ERR_PROB_SHAPE; + if (prob_m == 0 || prob_n == 0 || prob_k == 0) + return 0; + + const int4* A_ptr = (const int4*) A; + const int4* B_ptr = (const int4*) B; + int4* C_ptr = (int4*) C; + const int4* s_ptr = (const int4*) s; + + int cols = prob_n / thread_n; + int* locks = (int*) workspace; + + int ret = 0; + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + if (thread_m_blocks > 4) { + thread_m_blocks = 4; + prob_m = 64; + } + + // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance) + // in our testing, however many more are, in principle, possible. 
+ if (false) {} + CALL_IF(1, 8, 8, -1) + CALL_IF(1, 8, 8, 8) + CALL_IF(1, 16, 4, -1) + CALL_IF(1, 16, 4, 8) + CALL_IF(2, 16, 4, -1) + CALL_IF(2, 16, 4, 8) + CALL_IF(3, 16, 4, -1) + CALL_IF(3, 16, 4, 8) + CALL_IF(4, 16, 4, -1) + CALL_IF(4, 16, 4, 8) + else + ret = ERR_KERN_SHAPE; + + A_ptr += 16 * thread_m_blocks * (prob_k / 8); + C_ptr += 16 * thread_m_blocks * (prob_n / 8); + } + + return ret; +} + +#endif + +} // namespace marlin +} // namespace vllm + +const int ERR_PROB_SHAPE = 1; +const int ERR_KERN_SHAPE = 2; + +// input: `torch.half` input matrix of shape `(m, k)` in standard row-major layout +// weights: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()` +// output: `torch.half` out matrix of shape `(m, n)` in standard row-major layout +// scales: `torch.half` scales of shape `(m / groupsize, n)` +// workspace: `torch.int` tensor with at least as many entries as there a GPU SMs (256 is usually safe) + +void marlin_gemm( + const torch::Tensor& input, + const torch::Tensor& weights, + torch::Tensor& output, + const torch::Tensor& scales, + torch::Tensor& workspace +) { + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + int prob_m = input.size(0); + int prob_n = output.size(1); + int prob_k = weights.size(1); + int groupsize = (scales.size(0) == 1) ? -1 : prob_k / scales.size(0); + if (groupsize != -1 && groupsize * scales.size(0) != prob_k) + AT_ERROR("k=", prob_k, " not compatible with ", scales.size(0), " groups."); + int err = vllm::marlin::marlin_cuda( + input.data_ptr(), + weights.data_ptr(), + output.data_ptr(), + scales.data_ptr(), + prob_m, prob_n, prob_k, + workspace.data_ptr(), + groupsize, + input.get_device(), + thread_k, + thread_n, + sms + ); + if (err == ERR_PROB_SHAPE) { + AT_ERROR( + "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", + " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "." + ); + } else if (err == ERR_KERN_SHAPE) { + AT_ERROR( + "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "." + ); + } +} \ No newline at end of file From 7a43b29c07ed3fe7abdc0acd734e26f7570414f0 Mon Sep 17 00:00:00 2001 From: rib-2 Date: Thu, 18 Jan 2024 15:39:43 +0000 Subject: [PATCH 03/47] trying to load packed weights turning out to be tricky --- .../quantization/marlin/marlin_cuda_kernel.cu | 33 ++++++++++++++----- vllm/config.py | 11 ++++--- .../layers/quantization/__init__.py | 2 ++ vllm/model_executor/models/llama.py | 1 + 4 files changed, 34 insertions(+), 13 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index 61856abb8906..a27415486b9f 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -1,3 +1,20 @@ +/* + * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #ifndef MARLIN_CUDA_KERNEL_CUH #define MARLIN_CUDA_KERNEL_CUH @@ -13,8 +30,8 @@ constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } -// Instances of `Vec` are used to organize groups of >>registers<<, as are for instance needed as inputs to tensor core -// operations. Consequently, all corresponding index accesses must be compile time constants, which is why we +// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core +// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we // extensively use `#pragma unroll` throughout the kernel code to guarantee this. template struct Vec { @@ -110,7 +127,7 @@ __device__ inline int lop3(int a, int b, int c) { } // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. -// We mostly follow the strategy in the link below, with some smaller changes: +// We mostly follow the strategy in the link below, with some small changes: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h __device__ inline FragB dequant(int q) { const int LO = 0x000f000f; @@ -135,7 +152,7 @@ __device__ inline FragB dequant(int q) { return frag_b; } -// Multiply dequantized values by the corresponding quantization scale; used only in group mode. +// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization. 
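// (Grouped mode keeps one fp16 scale per `groupsize` rows of k for every output column; in
// per-column mode the single scale is applied once at output write-out instead.)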
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); @@ -768,10 +785,10 @@ int marlin_cuda( const int ERR_PROB_SHAPE = 1; const int ERR_KERN_SHAPE = 2; -// input: `torch.half` input matrix of shape `(m, k)` in standard row-major layout -// weights: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()` -// output: `torch.half` out matrix of shape `(m, n)` in standard row-major layout -// scales: `torch.half` scales of shape `(m / groupsize, n)` +// input: `torch.half` input matrix of shape `(m, k)` in standard row-major layout +// weights: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()` +// output: `torch.half` out matrix of shape `(m, n)` in standard row-major layout +// scales: `torch.half` scales of shape `(m / groupsize, n)` // workspace: `torch.int` tensor with at least as many entries as there a GPU SMs (256 is usually safe) void marlin_gemm( diff --git a/vllm/config.py b/vllm/config.py index f1efcc66e909..fe8da0138f4b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -144,8 +144,8 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "squeezellm"] - rocm_not_supported_quantization = ["awq"] + supported_quantization = ["awq", "gptq", "squeezellm", "marlin"] + rocm_not_supported_quantization = ["awq", "marlin"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -172,9 +172,10 @@ def _verify_quantization(self) -> None: raise ValueError( f"{self.quantization} quantization is currently not supported " f"in ROCm.") - logger.warning(f"{self.quantization} quantization is not fully " - "optimized yet. The speed can be slower than " - "non-quantized models.") + if self.quantization != "marlin": + logger.warning(f"{self.quantization} quantization is not fully " + "optimized yet. 
The speed can be slower than " + "non-quantized models.") def _verify_cuda_graph(self) -> None: if self.max_context_len_to_capture is None: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index b3449eaff0e3..dc54641878c6 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -4,11 +4,13 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantization.marlin import MarlinConfig _QUANTIZATION_CONFIG_REGISTRY = { "awq": AWQConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, + "marlin": MarlinConfig, } diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 3791aa893893..b448321083ff 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -310,6 +310,7 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: From e0346404fd880f5ee1ab999ff4b9ad501643fa61 Mon Sep 17 00:00:00 2001 From: rib-2 Date: Thu, 18 Jan 2024 15:40:16 +0000 Subject: [PATCH 04/47] trying to load packed weights turning out to be tricky due to qkv --- .../layers/quantization/marlin.py | 193 ++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 vllm/model_executor/layers/quantization/marlin.py diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py new file mode 100644 index 000000000000..0f4a3f6cd817 --- /dev/null +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -0,0 +1,193 @@ +import numpy as np +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +# Essentially all reasonable GPUs have less than 256 SMs so this should be safe for now +MAX_SMS = 256 +# Tile size used by Marlin Kernels +TILE_SIZE = 16 +# 4 Bits Packed Into 32 Bit Dtype +PACK_FACTOR = 32 // 4 + +class MarlinConfig(QuantizationConfig): + """Config class for Marlin. + + Reference: https://github.com/IST-DASLab/marlin/tree/master + """ + + def __init__( + self, + group_size: int, + ) -> None: + self.group_size = group_size + # 4Bits packed into Int32. + self.pack_factor = 32 // 4 + # Tile size of 16 used by Marlin. + self.tile_size = 16 + + # todo(rib-2): add channelwise support (-1). 
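        # (In the kernel, group_size == -1 selects one scale per output column, i.e. channelwise
        # quantization; this config currently only accepts group_size == 128.)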
+ if self.group_size != 128: + raise ValueError( + "Currently, only group size 128 is supported for Marlin " + f"but got {self.group_size} bits.") + + def __repr__(self) -> str: + return (f"MarlinConfig(group_size={self.group_size}") + + @classmethod + def get_name(cls) -> str: + return "marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(group_size) + + def get_linear_method(self) -> "MarlinLinearMethod": + return MarlinLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + +class MarlinLinearMethod(LinearMethodBase): + """Linear method for Marlin. + + Args: + quant_config: The Marlin quantization config. + """ + + def __init__(self, quant_config: MarlinConfig): + self.quant_config = quant_config + self._perm_len = 1024 + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + del output_size # Unused. + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + if input_size_per_partition % 128 != 0: + raise ValueError( + "The input_size_per_partition must be divisible by 128, " + f"but got {input_size_per_partition}") + + if output_size_per_partition % 256 != 0: + raise ValueError( + "The output_size_per_partition must be divisible by 256, " + f"but got {output_size_per_partition}") + + # check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self._perm_len // (self.quant_config.tile_size ** 2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu" + ) + + # Quantized 4Bit weights packed into Int32. + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32 + ), + requires_grad=False, + ) + + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + + # Scales in Float16. + scales = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": None if input_size == input_size_per_partition else 0, + "output_dim": 1, + }) + + # Workspace for the marlin kernels. 
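        # (The workspace backs the kernel's `locks` array, one counter per column slice, used by
        # its global-reduction barriers; 256 entries covers any current GPU per the kernel's own
        # comment on SM counts.)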
+ self.workspace = torch.empty(MAX_SMS, dtype=torch.int) + + return { + "B": qweight, + "s": scales, + } + + def apply_weights(self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = weights["B"] + scales = weights["s"] + + output = torch.empty( + x.shape[:-1] + (scales.shape[1],), + dtype=x.dtype, + device=x.device + ) + + print(scales.shape) + print(qweight.shape) + print(x.shape) + + ops.marlin_gemm( + x.view(-1, x.shape[-1]), + qweight, + output.view(-1, output.shape[-1]), + scales, + self.workspace + ) + + if bias is not None: + output = output + bias + return output + \ No newline at end of file From 15e8f9cd6ae59df3674ead8393850445a1f17a63 Mon Sep 17 00:00:00 2001 From: rib-2 Date: Thu, 18 Jan 2024 19:23:25 +0000 Subject: [PATCH 05/47] integrated marlin for single gpu --- .../quantization/marlin/marlin_cuda_kernel.cu | 2 +- vllm/model_executor/layers/linear.py | 35 +++++++++++++++++++ .../layers/quantization/marlin.py | 7 +--- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index a27415486b9f..dd98fc6ff87d 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -807,7 +807,7 @@ void marlin_gemm( int prob_m = input.size(0); int prob_n = output.size(1); - int prob_k = weights.size(1); + int prob_k = input.size(1); int groupsize = (scales.size(0) == 1) ? -1 : prob_k / scales.size(0); if (groupsize != -1 && groupsize * scales.size(0) != prob_k) AT_ERROR("k=", prob_k, " not compatible with ", scales.size(0), " groups."); diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5e1d63a6a62e..3b954fb8ffa7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -280,6 +280,14 @@ def weight_loader(self, if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor + + # If marlin, we need to adjust the offset and size to account + # for the tiling. + marlin_tile_size = getattr(param, "tile_size", None) + if marlin_tile_size is not None: + shard_size = shard_size * marlin_tile_size + shard_offset = shard_offset * marlin_tile_size + loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -297,6 +305,14 @@ def weight_loader(self, if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor + + # If marlin, we need to adjust the offset and size to account + # for the tiling. + marlin_tile_size = getattr(param, "tile_size", None) + if marlin_tile_size is not None: + shard_size = shard_size * marlin_tile_size + shard_offset = shard_offset * marlin_tile_size + param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = tp_rank * shard_size @@ -376,7 +392,10 @@ def weight_loader(self, loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + if loaded_shard_id is None: + print("--------- HERE 2") + # Loaded weight is already packed. 
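            # (With no shard id the checkpoint holds the fused QKV tensor; when the param has an
            # output dim it is split into q/k/v shards below, with the same pack_factor and Marlin
            # tile_size scaling applied to each shard's offset and size.)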
if output_dim is None: assert param_data.shape == loaded_weight.shape @@ -397,6 +416,14 @@ def weight_loader(self, if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor + + # If marlin, we need to adjust the offset and size to account + # for the tiling. + marlin_tile_size = getattr(param, "tile_size", None) + if marlin_tile_size is not None: + shard_size = shard_size * marlin_tile_size + shard_offset = shard_offset * marlin_tile_size + loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -421,6 +448,14 @@ def weight_loader(self, if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor + + # If marlin, we need to adjust the offset and size to account + # for the tiling + marlin_tile_size = getattr(param, "tile_size", None) + if marlin_tile_size is not None: + shard_size = shard_size * marlin_tile_size + shard_offset = shard_offset * marlin_tile_size + param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 0f4a3f6cd817..5b0d4c75d23c 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -130,13 +130,13 @@ def create_weights( ), requires_grad=False, ) - set_weight_attrs( qweight, { "input_dim": 0, "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, + "tile_size": TILE_SIZE, }) # Scales in Float16. @@ -174,11 +174,6 @@ def apply_weights(self, dtype=x.dtype, device=x.device ) - - print(scales.shape) - print(qweight.shape) - print(x.shape) - ops.marlin_gemm( x.view(-1, x.shape[-1]), qweight, From d8286fb3048adc47c1e8e362c8c27285b864a2a8 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Fri, 19 Jan 2024 09:58:46 -0500 Subject: [PATCH 06/47] Update llama.py --- vllm/model_executor/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b448321083ff..3791aa893893 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -310,7 +310,6 @@ def load_weights(self, ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: From 8bc625f0941c7314afd86bab90584747429c323b Mon Sep 17 00:00:00 2001 From: alexm Date: Fri, 19 Jan 2024 14:32:43 -0500 Subject: [PATCH 07/47] Fixes to Marlin quantization to allow execution via CUDA graphs capture (eager_force=False) --- .../quantization/marlin/marlin_cuda_kernel.cu | 30 +++++++++- setup.py | 3 +- vllm/config.py | 7 ++- vllm/model_executor/layers/linear.py | 10 ++-- .../layers/quantization/marlin.py | 59 +++++++++---------- 5 files changed, 66 insertions(+), 43 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index dd98fc6ff87d..a1526a9e5d7c 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -19,6 +19,7 @@ #define MARLIN_CUDA_KERNEL_CUH #include +#include #include #include #include @@ -182,6 +183,25 @@ __device__ inline void 
barrier_release(int* lock) { } } +__global__ void zero_data(int* data, int len, int blocks, int threads) { + const int idx = blockIdx.x * threads + threadIdx.x; + long total_size = sizeof(int) * len; + long chunk_size = total_size / (blocks * threads); + + long thread_offset = idx * chunk_size; + if (thread_offset > total_size) { + return; + } + + if (thread_offset + chunk_size > total_size) { + chunk_size = total_size - thread_offset; + } + + void* start_addr = reinterpret_cast(reinterpret_cast(data) + thread_offset); + memset((void *) start_addr, 0, chunk_size); + + __syncthreads(); +} template < const int threads, // number of threads in a threadblock @@ -675,18 +695,22 @@ const int THREADS = 256; const int STAGES = 4; // 4 pipeline stages fit into shared memory const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) +// Essentially all reasonable GPUs have less than 256 SMs so this should be safe for now +const int MAX_SMS = 256; + #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ else if ( \ thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ group_blocks == GROUP_BLOCKS \ ) { \ - cudaMemset(locks, 0, 4 * cols); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + zero_data<<<1, MAX_SMS, 0, stream>>>(locks, MAX_SMS, 1, MAX_SMS); \ cudaFuncSetAttribute( \ Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, \ SHARED_MEM \ - ); \ - Marlin<<>>( \ + ); \ + Marlin<<>>( \ A_ptr, B_ptr, C_ptr, s_ptr, \ prob_m, prob_n, prob_k, \ locks \ diff --git a/setup.py b/setup.py index f54d824f4236..94ae06020bc4 100644 --- a/setup.py +++ b/setup.py @@ -227,7 +227,8 @@ def get_torch_arch_list() -> Set[str]: if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") print("\n\n HERE \n\n") - vllm_extension_sources.append("csrc/quantization/marlin/marlin_cuda_kernel.cu") + vllm_extension_sources.append( + "csrc/quantization/marlin/marlin_cuda_kernel.cu") vllm_extension = CUDAExtension( name="vllm._C", diff --git a/vllm/config.py b/vllm/config.py index fe8da0138f4b..8cd96dfec382 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -173,9 +173,10 @@ def _verify_quantization(self) -> None: f"{self.quantization} quantization is currently not supported " f"in ROCm.") if self.quantization != "marlin": - logger.warning(f"{self.quantization} quantization is not fully " - "optimized yet. The speed can be slower than " - "non-quantized models.") + logger.warning( + f"{self.quantization} quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models.") def _verify_cuda_graph(self) -> None: if self.max_context_len_to_capture is None: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3b954fb8ffa7..e326ebe517bd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -305,7 +305,7 @@ def weight_loader(self, if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - + # If marlin, we need to adjust the offset and size to account # for the tiling. 
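            # (Marlin stores qweight as (k / tile_size, n * tile_size / pack_factor): output
            # columns are expanded by the 16x16 tiling, so shard offsets/sizes along the output
            # dim, already divided by pack_factor above, must additionally be scaled by tile_size.)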
marlin_tile_size = getattr(param, "tile_size", None) @@ -423,7 +423,7 @@ def weight_loader(self, if marlin_tile_size is not None: shard_size = shard_size * marlin_tile_size shard_offset = shard_offset * marlin_tile_size - + loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) @@ -448,14 +448,14 @@ def weight_loader(self, if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - - # If marlin, we need to adjust the offset and size to account + + # If marlin, we need to adjust the offset and size to account # for the tiling marlin_tile_size = getattr(param, "tile_size", None) if marlin_tile_size is not None: shard_size = shard_size * marlin_tile_size shard_offset = shard_offset * marlin_tile_size - + param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 5b0d4c75d23c..a54eedf3c40f 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -1,4 +1,3 @@ -import numpy as np from typing import Any, Dict, List, Optional import torch @@ -17,6 +16,7 @@ # 4 Bits Packed Into 32 Bit Dtype PACK_FACTOR = 32 // 4 + class MarlinConfig(QuantizationConfig): """Config class for Marlin. @@ -32,7 +32,7 @@ def __init__( self.pack_factor = 32 // 4 # Tile size of 16 used by Marlin. self.tile_size = 16 - + # todo(rib-2): add channelwise support (-1). if self.group_size != 128: raise ValueError( @@ -70,6 +70,7 @@ def get_linear_method(self) -> "MarlinLinearMethod": def get_scaled_act_names(self) -> List[str]: return [] + class MarlinLinearMethod(LinearMethodBase): """Linear method for Marlin. @@ -107,27 +108,26 @@ def create_weights( raise ValueError( "The input_size_per_partition must be divisible by 128, " f"but got {input_size_per_partition}") - + if output_size_per_partition % 256 != 0: raise ValueError( "The output_size_per_partition must be divisible by 256, " f"but got {output_size_per_partition}") # check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self._perm_len // (self.quant_config.tile_size ** 2) + num_tiles_per_perm = self._perm_len // (self.quant_config.tile_size**2) if output_size_per_partition % num_tiles_per_perm != 0: raise ValueError( - "Each permutation group must reside on the same gpu" - ) + "Each permutation group must reside on the same gpu") # Quantized 4Bit weights packed into Int32. 
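        # (Shape is (input_size / 16, output_size * 16 / 8): rows count 16x16 tiles along k, and
        # each int32 packs eight 4-bit values in Marlin's interleaved layout; see the
        # `Layer.pack()` reference in the kernel's header comment.)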
qweight = Parameter( torch.empty( - input_size_per_partition // self.quant_config.tile_size, - output_size_per_partition * self.quant_config.tile_size // self.quant_config.pack_factor, + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, device="cuda", - dtype=torch.int32 - ), + dtype=torch.int32), requires_grad=False, ) set_weight_attrs( @@ -143,19 +143,21 @@ def create_weights( scales = Parameter( torch.empty( input_size_per_partition // self.quant_config.group_size, - output_size_per_partition, + output_size_per_partition, device="cuda", dtype=params_dtype, ), requires_grad=False, ) - set_weight_attrs(scales, { - "input_dim": None if input_size == input_size_per_partition else 0, - "output_dim": 1, - }) + set_weight_attrs( + scales, { + "input_dim": + None if input_size == input_size_per_partition else 0, + "output_dim": 1, + }) # Workspace for the marlin kernels. - self.workspace = torch.empty(MAX_SMS, dtype=torch.int) + self.workspace = torch.empty(MAX_SMS, dtype=torch.int, device="cuda") return { "B": qweight, @@ -169,20 +171,15 @@ def apply_weights(self, qweight = weights["B"] scales = weights["s"] - output = torch.empty( - x.shape[:-1] + (scales.shape[1],), - dtype=x.dtype, - device=x.device - ) - ops.marlin_gemm( - x.view(-1, x.shape[-1]), - qweight, - output.view(-1, output.shape[-1]), - scales, - self.workspace - ) - + output = torch.empty(x.shape[:-1] + (scales.shape[1], ), + dtype=x.dtype, + device=x.device) + + ops.marlin_gemm(x.view(-1, x.shape[-1]), qweight, + output.view(-1, output.shape[-1]), scales, + self.workspace) + if bias is not None: - output = output + bias + output.add_(bias) + return output - \ No newline at end of file From 2691e89510e6fd1a293c08af139ae57c16cd9c14 Mon Sep 17 00:00:00 2001 From: alexm Date: Fri, 19 Jan 2024 16:24:33 -0500 Subject: [PATCH 08/47] Integrate @efrantar's changes for CUDA graphs --- .../quantization/marlin/marlin_cuda_kernel.cu | 40 +++++-------------- .../layers/quantization/marlin.py | 9 ++++- 2 files changed, 18 insertions(+), 31 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index a1526a9e5d7c..36a10acd52df 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -173,9 +173,13 @@ __device__ inline void barrier_acquire(int* lock, int count) { } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock) { +__device__ inline void barrier_release(int* lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } int val = 1; // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. 
asm volatile ("fence.acq_rel.gpu;\n"); @@ -183,26 +187,6 @@ __device__ inline void barrier_release(int* lock) { } } -__global__ void zero_data(int* data, int len, int blocks, int threads) { - const int idx = blockIdx.x * threads + threadIdx.x; - long total_size = sizeof(int) * len; - long chunk_size = total_size / (blocks * threads); - - long thread_offset = idx * chunk_size; - if (thread_offset > total_size) { - return; - } - - if (thread_offset + chunk_size > total_size) { - chunk_size = total_size - thread_offset; - } - - void* start_addr = reinterpret_cast(reinterpret_cast(data) + thread_offset); - memset((void *) start_addr, 0, chunk_size); - - __syncthreads(); -} - template < const int threads, // number of threads in a threadblock const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock @@ -669,7 +653,7 @@ __global__ void Marlin( if (slice_count > 1) { // only globally reduce if there is more than one block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col]); + barrier_release(&locks[slice_col], last); } if (last) // only the last block in a slice actually writes the result write_result(); @@ -695,16 +679,11 @@ const int THREADS = 256; const int STAGES = 4; // 4 pipeline stages fit into shared memory const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) -// Essentially all reasonable GPUs have less than 256 SMs so this should be safe for now -const int MAX_SMS = 256; - #define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ else if ( \ thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ group_blocks == GROUP_BLOCKS \ ) { \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - zero_data<<<1, MAX_SMS, 0, stream>>>(locks, MAX_SMS, 1, MAX_SMS); \ cudaFuncSetAttribute( \ Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, \ @@ -731,6 +710,7 @@ int marlin_cuda( void* workspace, int groupsize = -1, int dev = 0, + cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, int sms = -1 @@ -813,7 +793,7 @@ const int ERR_KERN_SHAPE = 2; // weights: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()` // output: `torch.half` out matrix of shape `(m, n)` in standard row-major layout // scales: `torch.half` scales of shape `(m / groupsize, n)` -// workspace: `torch.int` tensor with at least as many entries as there a GPU SMs (256 is usually safe) +// workspace: `torch.int` tensor with at least `n / 128` entries that are all zero void marlin_gemm( const torch::Tensor& input, @@ -835,6 +815,7 @@ void marlin_gemm( int groupsize = (scales.size(0) == 1) ? 
-1 : prob_k / scales.size(0); if (groupsize != -1 && groupsize * scales.size(0) != prob_k) AT_ERROR("k=", prob_k, " not compatible with ", scales.size(0), " groups."); + int dev = input.get_device(); int err = vllm::marlin::marlin_cuda( input.data_ptr(), weights.data_ptr(), @@ -843,7 +824,8 @@ void marlin_gemm( prob_m, prob_n, prob_k, workspace.data_ptr(), groupsize, - input.get_device(), + dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index a54eedf3c40f..328dabaff485 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -33,6 +33,9 @@ def __init__( # Tile size of 16 used by Marlin. self.tile_size = 16 + # Maximum workspace (>= than the number of GPU SMs => so 512 is safe) + self.max_workspace_size = 512 + # todo(rib-2): add channelwise support (-1). if self.group_size != 128: raise ValueError( @@ -156,8 +159,10 @@ def create_weights( "output_dim": 1, }) - # Workspace for the marlin kernels. - self.workspace = torch.empty(MAX_SMS, dtype=torch.int, device="cuda") + # Alloc workspace (shared across invocations of this layer) + self.workspace = torch.zeros(self.quant_config.max_workspace_size, + dtype=torch.int, + device="cuda") return { "B": qweight, From 92f7290fb8ea3ea99f423f81aba8f7b37cb0069b Mon Sep 17 00:00:00 2001 From: alexm Date: Fri, 19 Jan 2024 16:52:06 -0500 Subject: [PATCH 09/47] review comments based on zhyncs --- setup.py | 1 - vllm/model_executor/layers/linear.py | 44 +++++++++---------- .../layers/quantization/marlin.py | 2 +- 3 files changed, 21 insertions(+), 26 deletions(-) diff --git a/setup.py b/setup.py index 94ae06020bc4..b6fe614ab10c 100644 --- a/setup.py +++ b/setup.py @@ -226,7 +226,6 @@ def get_torch_arch_list() -> Set[str]: if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") - print("\n\n HERE \n\n") vllm_extension_sources.append( "csrc/quantization/marlin/marlin_cuda_kernel.cu") diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index e326ebe517bd..69d710472c4f 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -17,6 +17,14 @@ logger = init_logger(__name__) +def adjust_marlin_shard(param, shard_size, shard_offset): + marlin_tile_size = getattr(param, "marlin_tile_size", None) + if marlin_tile_size is None: + return shard_size, shard_offset + + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @@ -281,12 +289,9 @@ def weight_loader(self, shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to account - # for the tiling. - marlin_tile_size = getattr(param, "tile_size", None) - if marlin_tile_size is not None: - shard_size = shard_size * marlin_tile_size - shard_offset = shard_offset * marlin_tile_size + # If marlin, we need to adjust the offset and size to account for the tiling. 
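For reference, the tile-size adjustment that adjust_marlin_shard performs on top of the pack-factor division can be sketched in plain Python as follows; the helper name comes from this patch, but the concrete numbers are only illustrative.

# Sketch of the shard arithmetic for a Marlin-packed parameter: int4 values are
# packed 8-to-an-int32 (pack_factor = 8) and the output dim is stored in 16-wide
# tiles (marlin_tile_size = 16), so a logical shard first shrinks by the pack
# factor and then grows back by the tile size.
def marlin_shard(shard_size, shard_offset, pack_factor=8, tile_size=16):
    shard_size //= pack_factor
    shard_offset //= pack_factor
    return shard_size * tile_size, shard_offset * tile_size

# e.g. a 4096-wide shard starting at column 8192:
# marlin_shard(4096, 8192) -> (8192, 16384)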
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) @@ -306,12 +311,9 @@ def weight_loader(self, shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to account - # for the tiling. - marlin_tile_size = getattr(param, "tile_size", None) - if marlin_tile_size is not None: - shard_size = shard_size * marlin_tile_size - shard_offset = shard_offset * marlin_tile_size + # If marlin, we need to adjust the offset and size to account for the tiling. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) param_data = param_data.narrow(output_dim, shard_offset, shard_size) @@ -417,12 +419,9 @@ def weight_loader(self, shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to account - # for the tiling. - marlin_tile_size = getattr(param, "tile_size", None) - if marlin_tile_size is not None: - shard_size = shard_size * marlin_tile_size - shard_offset = shard_offset * marlin_tile_size + # If marlin, we need to adjust the offset and size to account for the tiling. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) @@ -449,12 +448,9 @@ def weight_loader(self, shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to account - # for the tiling - marlin_tile_size = getattr(param, "tile_size", None) - if marlin_tile_size is not None: - shard_size = shard_size * marlin_tile_size - shard_offset = shard_offset * marlin_tile_size + # If marlin, we need to adjust the offset and size to account for the tiling. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) param_data = param_data.narrow(output_dim, shard_offset, shard_size) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 328dabaff485..3b65f1f44106 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -139,7 +139,7 @@ def create_weights( "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - "tile_size": TILE_SIZE, + "marlin_tile_size": TILE_SIZE, }) # Scales in Float16. From bc10e4b922eea17c04e5dbf2997d528518896737 Mon Sep 17 00:00:00 2001 From: alexm Date: Tue, 30 Jan 2024 12:30:37 -0500 Subject: [PATCH 10/47] (1) Integrate the latest changes from Elias that improve large batch size by running multiple parallel problems of size 64. 
(2) Refactor the workspace to be dynamic per layer --- .../quantization/marlin/marlin_cuda_kernel.cu | 60 ++++++++++++---- vllm/model_executor/layers/linear.py | 2 - .../layers/quantization/marlin.py | 69 ++++++++++++------- 3 files changed, 93 insertions(+), 38 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index 36a10acd52df..ed124c4b9f94 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -65,7 +65,7 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool ); } -// Asynchronous global->shared copy with a chache hint indicating that the values may be evicted immediately; used for +// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for // quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need // for inputs A and outputs C. __device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { @@ -214,24 +214,38 @@ __global__ void Marlin( // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as // possible. + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles, gridDim.x); + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case // where a stripe starts in the middle of group. if (group_blocks != -1) iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col = (iters * blockIdx.x) / k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; int slice_iters; // number of threadblock tiles in the current slice int slice_count = 0; // total number of active threadblocks in the current slice int slice_idx; // index of threadblock in current slice; numbered bottom to top + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + // Compute all information about the current slice which is required for synchronization. 
auto init_slice = [&] () { - slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col + slice_row); - if (slice_iters < 0 || slice_col >= n_tiles) + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; if (slice_iters == 0) return; @@ -239,9 +253,9 @@ __global__ void Marlin( slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col, iters); - if (col_first <= k_tiles * (slice_col + 1)) { - int col_off = col_first - k_tiles * slice_col; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); if (col_off > 0) slice_count++; @@ -254,6 +268,12 @@ __global__ void Marlin( slice_idx--; } } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } }; init_slice(); @@ -658,6 +678,7 @@ __global__ void Marlin( if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; + slice_col_par++; slice_col++; init_slice(); if (slice_iters) { @@ -713,10 +734,12 @@ int marlin_cuda( cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, - int sms = -1 + int sms = -1, + int max_par = 8 ) { int tot_m = prob_m; int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; if (sms == -1) cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); @@ -753,9 +776,15 @@ int marlin_cuda( for (int i = 0; i < tot_m_blocks; i += 4) { int thread_m_blocks = tot_m_blocks - i; prob_m = tot_m - 16 * i; + int par = 1; if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); thread_m_blocks = 4; - prob_m = 64; } // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance) @@ -774,8 +803,8 @@ int marlin_cuda( else ret = ERR_KERN_SHAPE; - A_ptr += 16 * thread_m_blocks * (prob_k / 8); - C_ptr += 16 * thread_m_blocks * (prob_n / 8); + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; } return ret; @@ -808,6 +837,8 @@ void marlin_gemm( int thread_n = -1; // sms: number of SMs to use for the kernel (can usually be left as auto -1) int sms = -1; + // number of parallel problems to solve (helps with large batch sizes) + int max_par = 8; int prob_m = input.size(0); int prob_n = output.size(1); @@ -815,6 +846,8 @@ void marlin_gemm( int groupsize = (scales.size(0) == 1) ? 
-1 : prob_k / scales.size(0); if (groupsize != -1 && groupsize * scales.size(0) != prob_k) AT_ERROR("k=", prob_k, " not compatible with ", scales.size(0), " groups."); + if (workspace.numel() < (prob_n / 128) * max_par) + AT_ERROR("workspace must be of size at least ", (prob_n / 128) * max_par, "."); int dev = input.get_device(); int err = vllm::marlin::marlin_cuda( input.data_ptr(), @@ -828,7 +861,8 @@ void marlin_gemm( at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, - sms + sms, + max_par ); if (err == ERR_PROB_SHAPE) { AT_ERROR( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 69d710472c4f..b0c7e4f30e94 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -396,8 +396,6 @@ def weight_loader(self, output_dim = getattr(param, "output_dim", None) if loaded_shard_id is None: - print("--------- HERE 2") - # Loaded weight is already packed. if output_dim is None: assert param_data.shape == loaded_weight.shape diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 3b65f1f44106..79c2aeca48f4 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -4,10 +4,8 @@ from torch.nn.parameter import Parameter from vllm._C import ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig # Essentially all reasonable GPUs have less than 256 SMs so this should be safe for now MAX_SMS = 256 @@ -43,7 +41,7 @@ def __init__( f"but got {self.group_size} bits.") def __repr__(self) -> str: - return (f"MarlinConfig(group_size={self.group_size}") + return f"MarlinConfig(group_size={self.group_size}" @classmethod def get_name(cls) -> str: @@ -74,6 +72,23 @@ def get_scaled_act_names(self) -> List[str]: return [] +class MarlinWorkspace: + + def __init__(self, out_features): + max_parallel = 8 + min_n_threads = 128 + + assert ( + out_features % min_n_threads == 0 + ), "out_features = {out_features} is not divisible by min_n_threads = {min_n_threads}" + + max_workspace_size = (out_features // min_n_threads) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") + + class MarlinLinearMethod(LinearMethodBase): """Linear method for Marlin. @@ -93,7 +108,7 @@ def create_weights( output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - del output_size # Unused. + # del output_size # Unused. if params_dtype != torch.float16: raise ValueError( f"The params dtype must be float16, but got {params_dtype}") @@ -130,17 +145,20 @@ def create_weights( output_size_per_partition * self.quant_config.tile_size // self.quant_config.pack_factor, device="cuda", - dtype=torch.int32), + dtype=torch.int32, + ), requires_grad=False, ) set_weight_attrs( - qweight, { + qweight, + { "input_dim": 0, "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, "marlin_tile_size": TILE_SIZE, - }) + }, + ) # Scales in Float16. 
scales = Parameter( @@ -153,36 +171,41 @@ def create_weights( requires_grad=False, ) set_weight_attrs( - scales, { + scales, + { "input_dim": None if input_size == input_size_per_partition else 0, "output_dim": 1, - }) - - # Alloc workspace (shared across invocations of this layer) - self.workspace = torch.zeros(self.quant_config.max_workspace_size, - dtype=torch.int, - device="cuda") + }, + ) return { "B": qweight, "s": scales, + "workspace": MarlinWorkspace(output_size_per_partition), } - def apply_weights(self, - weights: Dict[str, Any], - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply_weights( + self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: qweight = weights["B"] scales = weights["s"] + workspace = weights["workspace"] output = torch.empty(x.shape[:-1] + (scales.shape[1], ), dtype=x.dtype, device=x.device) - ops.marlin_gemm(x.view(-1, x.shape[-1]), qweight, - output.view(-1, output.shape[-1]), scales, - self.workspace) + ops.marlin_gemm( + x.view(-1, x.shape[-1]), + qweight, + output.view(-1, output.shape[-1]), + scales, + workspace.scratch, + ) if bias is not None: output.add_(bias) From 47987da0ed971205f3c2ee4333561e9a642d80d8 Mon Sep 17 00:00:00 2001 From: alexm Date: Tue, 30 Jan 2024 13:42:02 -0500 Subject: [PATCH 11/47] add bug fix --- csrc/quantization/marlin/marlin_cuda_kernel.cu | 11 +++++++++-- vllm/model_executor/layers/quantization/marlin.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index ed124c4b9f94..b145d2deee5d 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -214,6 +214,7 @@ __global__ void Marlin( // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as // possible. 
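As a usage sketch of the op wired up above (reflecting the signature at this point in the series, where the output tensor is passed in rather than returned), the Python side roughly does the following; the shapes in the comments assume a hypothetical 4096x4096 layer with group_size 128.

import torch
from vllm._C import ops

def marlin_matmul(x, qweight, scales, workspace):
    # x:       (..., k)               fp16 activations
    # qweight: (k // 16, n * 16 // 8) int32, Marlin-tiled int4 weights
    # scales:  (k // group_size, n)   fp16 per-group scales
    # for k = n = 4096, group_size = 128: (256, 8192) and (32, 4096)
    out = torch.empty(x.shape[:-1] + (scales.shape[1],),
                      dtype=x.dtype, device=x.device)
    ops.marlin_gemm(x.view(-1, x.shape[-1]), qweight,
                    out.view(-1, out.shape[-1]), scales, workspace)
    return out

# The workspace is a zero-initialized int tensor sized
# (n // min_n_threads) * max_parallel and is reused across calls to the layer.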
+ // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions int parallel = 1; if (prob_m > 16 * thread_m_blocks) { parallel = prob_m / (16 * thread_m_blocks); @@ -235,6 +236,7 @@ __global__ void Marlin( int slice_count = 0; // total number of active threadblocks in the current slice int slice_idx; // index of threadblock in current slice; numbered bottom to top + // We can easily implement parallel problem execution by just remapping indices and advancing global pointers if (slice_col_par >= n_tiles) { A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; @@ -686,6 +688,11 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); } @@ -735,7 +742,7 @@ int marlin_cuda( int thread_k = -1, int thread_n = -1, int sms = -1, - int max_par = 8 + int max_par = 16 ) { int tot_m = prob_m; int tot_m_blocks = ceildiv(tot_m, 16); @@ -838,7 +845,7 @@ void marlin_gemm( // sms: number of SMs to use for the kernel (can usually be left as auto -1) int sms = -1; // number of parallel problems to solve (helps with large batch sizes) - int max_par = 8; + int max_par = 16; int prob_m = input.size(0); int prob_n = output.size(1); diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 79c2aeca48f4..2cd9058d1660 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -75,7 +75,7 @@ def get_scaled_act_names(self) -> List[str]: class MarlinWorkspace: def __init__(self, out_features): - max_parallel = 8 + max_parallel = 16 min_n_threads = 128 assert ( From 43aa818aa43c13bbb425c34e61579c31d141923d Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Thu, 1 Feb 2024 21:51:27 +0000 Subject: [PATCH 12/47] refactored some of alex's work to be consistent with the gptq config --- .../layers/quantization/marlin.py | 82 +++++++++---------- 1 file changed, 37 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 2cd9058d1660..451640740490 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -7,14 +7,6 @@ from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -# Essentially all reasonable GPUs have less than 256 SMs so this should be safe for now -MAX_SMS = 256 -# Tile size used by Marlin Kernels -TILE_SIZE = 16 -# 4 Bits Packed Into 32 Bit Dtype -PACK_FACTOR = 32 // 4 - - class MarlinConfig(QuantizationConfig): """Config class for Marlin. @@ -25,20 +17,22 @@ def __init__( self, group_size: int, ) -> None: + # Group size for the quantization. self.group_size = group_size - # 4Bits packed into Int32. + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) is supported for " + f"Marlin, but got group_size of {self.group_size}") + + # 4 Bits packed into 32 bit datatype. self.pack_factor = 32 // 4 - # Tile size of 16 used by Marlin. + # Tile size used by marlin kernels. 
self.tile_size = 16 - - # Maximum workspace (>= than the number of GPU SMs => so 512 is safe) - self.max_workspace_size = 512 - - # todo(rib-2): add channelwise support (-1). - if self.group_size != 128: - raise ValueError( - "Currently, only group size 128 is supported for Marlin " - f"but got {self.group_size} bits.") + # Data for workspace. + self.max_parallel = 16 + self.min_n_threads = 128 + # Permutation length use by the marlin kernels. + self.perm_len = 1024 def __repr__(self) -> str: return f"MarlinConfig(group_size={self.group_size}" @@ -54,7 +48,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod # Need to figure it out def get_min_capability(cls) -> int: - return 60 + return 80 @classmethod def get_config_filenames(cls) -> List[str]: @@ -71,24 +65,6 @@ def get_linear_method(self) -> "MarlinLinearMethod": def get_scaled_act_names(self) -> List[str]: return [] - -class MarlinWorkspace: - - def __init__(self, out_features): - max_parallel = 16 - min_n_threads = 128 - - assert ( - out_features % min_n_threads == 0 - ), "out_features = {out_features} is not divisible by min_n_threads = {min_n_threads}" - - max_workspace_size = (out_features // min_n_threads) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda") - - class MarlinLinearMethod(LinearMethodBase): """Linear method for Marlin. @@ -98,7 +74,6 @@ class MarlinLinearMethod(LinearMethodBase): def __init__(self, quant_config: MarlinConfig): self.quant_config = quant_config - self._perm_len = 1024 def create_weights( self, @@ -126,14 +101,18 @@ def create_weights( raise ValueError( "The input_size_per_partition must be divisible by 128, " f"but got {input_size_per_partition}") - if output_size_per_partition % 256 != 0: raise ValueError( "The output_size_per_partition must be divisible by 256, " f"but got {output_size_per_partition}") + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + "The output_size per partition must be divisible by the minimum " + f"number of threads {self.quant_config.min_n_threads}, but got {output_size_per_partition}" + ) # check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self._perm_len // (self.quant_config.tile_size**2) + num_tiles_per_perm = self.quant_config.perm_len // (self.quant_config.tile_size**2) if output_size_per_partition % num_tiles_per_perm != 0: raise ValueError( "Each permutation group must reside on the same gpu") @@ -156,14 +135,17 @@ def create_weights( "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - "marlin_tile_size": TILE_SIZE, + "marlin_tile_size": self.quant_config.tile_size, }, ) # Scales in Float16. 
+ group_size = self.quant_config.group_size + if group_size == -1: + group_size = input_size scales = Parameter( torch.empty( - input_size_per_partition // self.quant_config.group_size, + input_size_per_partition // group_size, output_size_per_partition, device="cuda", dtype=params_dtype, @@ -179,10 +161,20 @@ def create_weights( }, ) + max_workspace_size = (output_size_per_partition // self.quant_config.min_n_threads) * self.quant_config.max_parallel + workspace = Parameter( + torch.zeros( + max_workspace_size, + device="cuda", + dtype=torch.int + ), + requires_grad=False + ) + return { "B": qweight, "s": scales, - "workspace": MarlinWorkspace(output_size_per_partition), + "workspace": workspace, } def apply_weights( @@ -204,7 +196,7 @@ def apply_weights( qweight, output.view(-1, output.shape[-1]), scales, - workspace.scratch, + workspace, ) if bias is not None: From 5906a60ce4f36df22fe93211f84f0d8b5981bf73 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Thu, 1 Feb 2024 23:09:44 +0000 Subject: [PATCH 13/47] updated to load model based on hf_config from AutoGPTQ --- vllm/config.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 8cd96dfec382..ac6ade63f54b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -152,7 +152,15 @@ def _verify_quantization(self) -> None: # Parse quantization method from the HF model config, if available. hf_quant_config = getattr(self.hf_config, "quantization_config", None) if hf_quant_config is not None: + hf_quant_method = str(hf_quant_config["quant_method"]).lower() + # If the GPTQ model is serialized in marlin format, use marlin. + if ( + hf_quant_method == "gptq" and + "is_marlin_format" in hf_quant_config and + hf_quant_config["is_marlin_format"] + ): + hf_quant_method = "marlin" if self.quantization is None: self.quantization = hf_quant_method elif self.quantization != hf_quant_method: From 8dfeaa23690860a1e88ad687f7985f2d9b717dde Mon Sep 17 00:00:00 2001 From: alexm Date: Fri, 2 Feb 2024 13:43:05 -0500 Subject: [PATCH 14/47] Reduce Marlin's kernel limitation of thread_n from 256 to 64 (to avoid issues with tensor parallel runs) --- csrc/ops.h | 118 +- .../quantization/marlin/marlin_cuda_kernel.cu | 1066 ++++++++++------- .../layers/quantization/marlin.py | 39 +- 3 files changed, 713 insertions(+), 510 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 71ea700c11eb..7ba3c46acb4b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -2,97 +2,55 @@ #include -void paged_attention_v1( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes); +void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, + torch::Tensor &key_cache, torch::Tensor &value_cache, + int num_kv_heads, float scale, + torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes); -void paged_attention_v2( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - int num_kv_heads, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int block_size, - int max_context_len, - const c10::optional& alibi_slopes); +void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, + torch::Tensor 
&max_logits, torch::Tensor &tmp_out, + torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, + float scale, torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes); -void rms_norm( - torch::Tensor& out, - torch::Tensor& input, - torch::Tensor& weight, - float epsilon); +void rms_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, + float epsilon); -void fused_add_rms_norm( - torch::Tensor& input, - torch::Tensor& residual, - torch::Tensor& weight, - float epsilon); +void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, + torch::Tensor &weight, float epsilon); -void rotary_embedding( - torch::Tensor& positions, - torch::Tensor& query, - torch::Tensor& key, - int head_size, - torch::Tensor& cos_sin_cache, - bool is_neox); +void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox); -void silu_and_mul( - torch::Tensor& out, - torch::Tensor& input); +void silu_and_mul(torch::Tensor &out, torch::Tensor &input); -void gelu_new( - torch::Tensor& out, - torch::Tensor& input); +void gelu_new(torch::Tensor &out, torch::Tensor &input); -void gelu_fast( - torch::Tensor& out, - torch::Tensor& input); +void gelu_fast(torch::Tensor &out, torch::Tensor &input); #ifndef USE_ROCM -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); +torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters); -void marlin_gemm( - const torch::Tensor& input, - const torch::Tensor& weights, - torch::Tensor& output, - const torch::Tensor& scales, - torch::Tensor& workspace); +torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &workspace, + int64_t size_m, int64_t size_n, int64_t size_k); #endif -void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, - torch::Tensor mul, - torch::Tensor lookup_table); +void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, + torch::Tensor lookup_table); -torch::Tensor gptq_gemm( - torch::Tensor a, - torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, - torch::Tensor b_g_idx, - bool use_exllama); +torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, + bool use_exllama); -void gptq_shuffle( - torch::Tensor q_weight, - torch::Tensor q_perm); +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm); diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index b145d2deee5d..87a1b4a94759 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -14,70 +14,70 @@ * limitations under the License. 
*/ - -#ifndef MARLIN_CUDA_KERNEL_CUH -#define MARLIN_CUDA_KERNEL_CUH - #include -#include + +#include +#include #include #include #include -namespace vllm { +#include + +template inline std::string str(T x) { return std::to_string(x); } + namespace marlin { -constexpr int ceildiv(int a, int b) { - return (a + b - 1) / b; -} +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } -// Instances of `Vec` are used to organize groups of >>registers<<, as needed for instance as inputs to tensor core -// operations. Consequently, all corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee this. -template -struct Vec { +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template struct Vec { T elems[n]; - __device__ T& operator[](int i) { - return elems[i]; - } + __device__ T &operator[](int i) { return elems[i]; } }; using I4 = Vec; -// Matrix fragments for tensor core instructions; their precise layout is documented here: +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type using FragA = Vec; using FragB = Vec; using FragC = Vec; using FragS = Vec; // quantization scales -// Predicated asynchronous global->shared copy; used for inputs A where we apply predication to handle batchsizes that -// are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(BYTES) - ); + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); } -// Asynchronous global->shared copy with a cache hint indicating that the values may be evicted immediately; used for -// quantized weights B, which are only accessed precisely once and should thus not pollute the L2 cache which we need -// for inputs A and outputs C. -__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { +// Asynchronous global->shared copy with a cache hint indicating that the values +// may be evicted immediately; used for quantized weights B, which are only +// accessed precisely once and should thus not pollute the L2 cache which we +// need for inputs A and outputs C. 
+__device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( - "{\n" - " .reg .b64 p;\n" - " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" - " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" - "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES) - ); + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. @@ -86,49 +86,48 @@ __device__ inline void cp_async_fence() { } // Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +template __device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]) - ); +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, + FragC &frag_c) { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + float *c = reinterpret_cast(&frag_c); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } -// Instruction for loading a full 16x16 matrix fragment of operand A from shared memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) - ); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); } -// Lookup-table based 3-input logical operation; explicitly used for dequantization as the compiler does not seem to -// automatically recognize it in all cases. -template -__device__ inline int lop3(int a, int b, int c) { +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. 
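The dequant routine that follows implements this with LOP3s and fp16 magic constants; as a plain-Python model of what it computes (ignoring the interleaved nibble ordering the tensor-core layout uses), each packed 32-bit word expands to eight 4-bit values with a symmetric zero point of 8.

# Reference model only: the kernel does this with lop3 + fp16 bit tricks and a
# different (interleaved) ordering of the eight nibbles.
def dequant_word(word, group_scale):
    vals = []
    for i in range(8):
        q = (word >> (4 * i)) & 0xF          # one int4 value
        vals.append((q - 8) * group_scale)   # fused "-8" zero point, then scale
    return vals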
+template __device__ inline int lop3(int a, int b, int c) { int res; - asm volatile( - "lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut) - ); + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); return res; } -// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 values. -// We mostly follow the strategy in the link below, with some small changes: +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h __device__ inline FragB dequant(int q) { const int LO = 0x000f000f; @@ -137,43 +136,45 @@ __device__ inline FragB dequant(int q) { // Guarantee that the `(a & b) | c` operations are LOP3s. int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point directly into `SUB` and `ADD`. + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. const int SUB = 0x64086408; const int MUL = 0x2c002c00; const int ADD = 0xd480d480; FragB frag_b; - frag_b[0] = __hsub2( - *reinterpret_cast(&lo), - *reinterpret_cast(&SUB) - ); - frag_b[1] = __hfma2( - *reinterpret_cast(&hi), - *reinterpret_cast(&MUL), *reinterpret_cast(&ADD) - ); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } -// Multiply dequantized values by the corresponding quantization scale; used only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } // Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { +__device__ inline void barrier_acquire(int *lock, int count) { if (threadIdx.x == 0) { int state = -1; do - // Guarantee that subsequent writes by this threadblock will be visible globally. - asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); while (state != count); } __syncthreads(); } // Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { +__device__ inline void barrier_release(int *lock, bool reset = false) { __syncthreads(); if (threadIdx.x == 0) { if (reset) { @@ -181,40 +182,50 @@ __device__ inline void barrier_release(int* lock, bool reset = false) { return; } int val = 1; - // Make sure that all writes since acquiring this barrier are visible globally, while releasing the barrier. 
- asm volatile ("fence.acq_rel.gpu;\n"); - asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); } } -template < - const int threads, // number of threads in a threadblock - const int thread_m_blocks, // number of 16x16 blocks in the m dimension (batchsize) of the threadblock - const int thread_n_blocks, // same for n dimension (output) - const int thread_k_blocks, // same for k dimension (reduction) - const int stages, // number of stages for the async global->shared fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks with a separate quantization scale -> -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void +Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 + *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization ) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the same size, which might involve multiple - // column "slices" (of width 16 * `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM example: - // 0 1 3 + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 // 0 2 3 // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it ensures good utilization of all SMs - // for many kinds of shape and GPU configurations, while requiring as few slow global cross-threadblock reductions as - // possible. + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. 
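Putting barrier_acquire and barrier_release together, the per-slice reduction protocol can be modelled in Python roughly as below; threads stand in for threadblocks, and the reset path is the behaviour added earlier in this series so the workspace counters return to zero once the last block of a slice finishes.

import threading

# usage sketch: cond = threading.Condition(); locks = [0] * n_tiles; run one
# worker per block with its (slice_col, slice_idx, slice_count).
def slice_block(locks, cond, slice_col, slice_idx, slice_count,
                global_reduce, write_result):
    # barrier_acquire: wait until `slice_idx` earlier blocks have released
    with cond:
        cond.wait_for(lambda: locks[slice_col] == slice_idx)
    global_reduce()                      # accumulate this block's partial C
    last = (slice_idx == slice_count - 1)
    # barrier_release(reset=last): bump the counter, or zero it again so the
    # workspace is immediately reusable by the next launch
    with cond:
        locks[slice_col] = 0 if last else locks[slice_col] + 1
        cond.notify_all()
    if last:
        write_result()                   # only the last block writes C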
- // For larger GEMMs we run multiple batchsize 64 versions in parallel for a better partitioning with less reductions + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions int parallel = 1; if (prob_m > 16 * thread_m_blocks) { parallel = prob_m / (16 * thread_m_blocks); @@ -224,19 +235,24 @@ __global__ void Marlin( int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - // Ensure that the number of tiles in each stripe is a multiple of the groupsize; this avoids an annoying special case - // where a stripe starts in the middle of group. + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. if (group_blocks != -1) - iters = (group_blocks / thread_k_blocks) * ceildiv(iters, (group_blocks / thread_k_blocks)); + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; int slice_iters; // number of threadblock tiles in the current slice - int slice_count = 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to top + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top - // We can easily implement parallel problem execution by just remapping indices and advancing global pointers + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers if (slice_col_par >= n_tiles) { A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; @@ -244,14 +260,16 @@ __global__ void Marlin( slice_col = slice_col_par % n_tiles; } - // Compute all information about the current slice which is required for synchronization. - auto init_slice = [&] () { - slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + // Compute all information about the current slice which is required for + // synchronization. 
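Numerically, the stripe assignment above works out as in the following sketch; the formulas mirror the kernel, while the tile-shape defaults and problem sizes are made up and the group-size rounding of iters is left out for brevity.

def ceildiv(a, b):
    return (a + b - 1) // b

def stripe_for_block(block_id, num_sms, prob_m, prob_n, prob_k,
                     thread_m_blocks=4, thread_n_blocks=16, thread_k_blocks=4):
    # Large batches are split into independent batch-64 sub-problems.
    parallel = max(prob_m // (16 * thread_m_blocks), 1)
    k_tiles = prob_k // 16 // thread_k_blocks
    n_tiles = prob_n // 16 // thread_n_blocks
    iters = ceildiv(k_tiles * n_tiles * parallel, num_sms)
    slice_row = (iters * block_id) % k_tiles
    slice_col_par = (iters * block_id) // k_tiles
    # slice_col_par // n_tiles selects the sub-problem, the remainder selects
    # the output column slice within it.
    return iters, slice_row, slice_col_par % n_tiles, slice_col_par // n_tiles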
+ auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 0; @@ -280,15 +298,28 @@ __global__ void Marlin( init_slice(); int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory - // We typically use `constexpr` to indicate that this value is a compile-time constant - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; // delta between subsequent A tiles in global memory - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); // between shared memory tile reads - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; // within a shared memory tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); // overall size of a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); // number of shared write iterations for a tile + // We typically use `constexpr` to indicate that this value is a compile-time + // constant + constexpr int a_sh_stride = + 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = + 16 * thread_k_blocks / + 8; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = + a_gl_stride * + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = + a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = + 2 * ((threads / 32) / + (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = + a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = + ceildiv(a_sh_stage, + a_sh_wr_delta); // number of shared write iterations for a tile int b_gl_stride = 16 * prob_n / 32; constexpr int b_sh_stride = 32 * thread_n_blocks / 4; @@ -305,155 +336,180 @@ __global__ void Marlin( int s_gl_rd_delta = s_gl_stride; // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); a_gl_rd += a_gl_rd_delta_o * slice_row; // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); // Shared read index. 
- int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; int b_sh_wr = threadIdx.x; int b_sh_rd = threadIdx.x; - int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; int s_sh_wr = threadIdx.x; int s_sh_rd; - // We use a different scale layout for grouped and column-wise quantization as we scale a `half2` tile in column-major - // layout in the former and in row-major in the latter case. + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. if (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is needed if there are more threads than - // required for a certain tilesize or when the batchsize is not a multiple of 16. + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - // To ensure that writing and reading A tiles to/from shared memory, the latter in fragment format, is fully bank - // conflict free, we need to use a rather fancy XOR-based layout. The key here is that neither reads nor writes of - // the 16-byte `int4` blocks of 8 consecutive threads involve the same shared memory banks. Further, it seems (based - // on NSight-Compute) that each warp must also write a consecutive memory segment? - auto transform_a = [&] (int i) { + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { int row = i / a_gl_rd_delta_o; return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; }; - // Since the computation of this remapping is non-trivial and, due to our main loop unrolls, all shared memory - // accesses are static, we simply precompute both transformed reads and writes. 
+ // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); } - // Since B-accesses have non-constant stride they have to be computed at runtime; we break dependicies between - // subsequent accesses with a tile by maintining multiple pointers (we have enough registers), a tiny optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependicies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); - // Register storage for double buffer of shared memory reads. + // Shared memory storage for global fetch pipelines. + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_s = sh_b + (stages * b_sh_stage); + // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2]; FragC frag_c[thread_m_blocks][4][2]; FragS frag_s[2][4]; // Zero accumulators. - auto zero_accums = [&] () { - #pragma unroll + auto zero_accums = [&]() { +#pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; - // Asynchronously fetch the next A, B and s tile from global to the next shared memory pipeline location. - auto fetch_to_shared = [&] (int pipe, int a_off, bool pred = true) { + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i] - ); + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); B_ptr[i] += b_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; if (s_sh_wr_pred) cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } - // Insert a fence even when we are winding down the pipeline to ensure that waiting is also correct at this point. + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. cp_async_fence(); }; // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&] () { - // We only have `stages - 2` active fetches since we are double buffering and can only issue the next fetch when - // it is guaranteed that the previous shared memory load is fully complete (as it may otherwise be overwritten). + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). cp_async_wait(); __syncthreads(); }; - // Load the next sub-tile from the current location in the shared memory pipe into the current register buffer. - auto fetch_to_registers = [&] (int k, int pipe) { - // It may seem inefficient that we reload the groups for every sub-tile; however, this does not seem to be a - // significant bottleneck, while some theoretically better attempts have lead to bad instruction ordering by the - // compiler and correspondingly a noticable drop in performance. + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticable drop in performance. 
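  // A concrete illustration of the cadence, assuming thread_k = 64 and a
  // groupsize of 128 (i.e. thread_k_blocks = 4 and group_blocks = 8): one
  // k-tile then covers only half a quantization group, so fetch_to_shared
  // brings in fresh scales on every other pipeline stage only, and the stage
  // in between re-reads the same shared-memory copy here.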
if (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + int4 *sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); }; - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&] (int k) { - // We have the m dimension as the inner loop in order to encourage overlapping dequantization and matmul operations. - #pragma unroll + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll for (int j = 0; j < 4; j++) { int b_quant = frag_b_quant[k % 2][j]; int b_quant_shift = b_quant >> 8; FragB frag_b0 = dequant(b_quant); - // If there are no groups, we can just scale the final output once and can avoid doing so for each weight. + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); FragB frag_b1 = dequant(b_quant_shift); if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -461,47 +517,56 @@ __global__ void Marlin( } }; - // Since we slice across the k dimension of a tile in order to increase the number of warps while keeping the n - // dimension of a tile reasonable, we have multiple warps that accumulate their partial sums of the same output + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&] () { + auto thread_block_reduce = [&]() { constexpr int red_off = threads / b_sh_stride / 2; if (red_off >= 1) { int red_idx = threadIdx.x / b_sh_stride; constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); - // Parallel logarithmic shared memory reduction. We make sure to avoid any unnecessary read or write iterations, - // e.g., for two warps we write only once by warp 1 and read only once by warp 0. + // Parallel logarithmic shared memory reduction. 
We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. - #pragma unroll +#pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll +#pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll +#pragma unroll for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float* c_rd = reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; } - sh[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < 4 * 2; i++) { - float* c_rd = reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; } } __syncthreads(); @@ -509,18 +574,21 @@ __global__ void Marlin( } }; - // Since multiple threadblocks may process parts of the same column slice, we finally have to globally reduce over - // the results. As the striped partioning minimizes the number of such reductions and our outputs are usually rather + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped partioning + // minimizes the number of such reductions and our outputs are usually rather // small, we perform this reduction serially in L2 cache. - auto global_reduce = [&] (bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to maximize L2 cache utilization in this step. - // To do this, we write out results in FP16 (but still reduce with FP32 compute). + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). 
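  // Put differently (matching the lock handling in the main loop below):
  // thread blocks that share a column slice take turns via `locks`; each
  // block except the first reads the fp16 partials already in C and
  // accumulates them in fp32, each block except the last writes its updated
  // partials back, and the last block keeps the fully reduced sums in
  // registers for the final scaled write-out in write_result().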
constexpr int active_threads = 32 * thread_n_blocks / 4; if (threadIdx.x < active_threads) { int c_gl_stride = prob_n / 8; int c_gl_wr_delta_o = 8 * c_gl_stride; int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; c_gl_wr += (2 * thread_n_blocks) * slice_col; constexpr int c_sh_wr_delta = active_threads; int c_sh_wr = threadIdx.x; @@ -528,88 +596,103 @@ __global__ void Marlin( int row = (threadIdx.x % 32) / 4; if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up the compiler and lead to slowdowns, - // hence we also use async-copies even though these fetches are not actually asynchronous. - #pragma unroll +// Interestingly, doing direct global accesses here really seems to mess up the +// compiler and lead to slowdowns, hence we also use async-copies even though +// these fetches are not actually asynchronous. +#pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m - ); + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll +#pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float( - reinterpret_cast<__half*>(&c_red)[j] - ); + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half *>(&c_red)[j]); } } if (!last) { int4 c; - #pragma unroll +#pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = __float2half( - reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] - ); + reinterpret_cast<__half *>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; } } } } }; - // Write out the reduce final result in the correct layout. We only actually reshuffle matrix fragments in this step, - // the reduction above is performed in fragment layout. - auto write_result = [&] () { + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
+ auto write_result = [&]() { int c_gl_stride = prob_n / 8; constexpr int c_sh_stride = 2 * thread_n_blocks + 1; int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); int c_gl_wr_end = c_gl_stride * prob_m; - // We first reorder in shared memory to guarantee the most efficient final global write patterns - auto write = [&] (int idx, float c0, float c1, FragS& s) { + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS &s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - if (group_blocks == -1) // for per-column quantization we finally apply the scale here + if (group_blocks == + -1) // for per-column quantization we finally apply the scale here res = __hmul2(res, s[0]); - ((half2*) sh)[idx] = res; + ((half2 *)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); } c_sh_wr += 16 * (4 * c_sh_stride); } } __syncthreads(); - #pragma unroll - for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { +#pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { if (c_gl_wr < c_gl_wr_end) { C[c_gl_wr] = sh[c_sh_rd]; c_gl_wr += c_gl_wr_delta; @@ -618,9 +701,9 @@ __global__ void Marlin( } }; - // Start global fetch and register load pipelines. - auto start_pipes = [&] () { - #pragma unroll + // Start global fetch and register load pipelines. 
+ auto start_pipes = [&]() { +#pragma unroll for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); zero_accums(); @@ -632,15 +715,17 @@ __global__ void Marlin( // Main loop. while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to ensure all shared memory accesses are - // static. Note that both pipelines have even length meaning that the next iteration will always start at index 0. - #pragma unroll +// We unroll over both the global fetch and the register load pipeline to ensure +// all shared memory accesses are static. Note that both pipelines have even +// length meaning that the next iteration will always start at index 0. +#pragma unroll for (int pipe = 0; pipe < stages;) { - #pragma unroll +#pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); pipe++; wait_for_stage(); } @@ -652,12 +737,14 @@ __global__ void Marlin( } a_gl_rd += a_gl_rd_delta_o * stages; - // Process results and, if necessary, proceed to the next column slice. While this pattern may not be the most - // readable, other ways of writing the loop seemed to noticeably worse performance after compliation. + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compliation. if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before write-out + // For per-column scales, we only fetch them here in the final step before + // write-out if (group_blocks == -1 && last) { if (s_sh_wr_pred) cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); @@ -668,11 +755,12 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } - if (slice_count > 1) { // only globally reduce if there is more than one block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); @@ -684,12 +772,13 @@ __global__ void Marlin( slice_col++; init_slice(); if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; } @@ -700,92 +789,201 @@ __global__ void Marlin( } } - -// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per schedule allows some more -// latency hiding. At the same time, we want relatively few warps to have many registers per warp and small tiles. 
-const int THREADS = 256; +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -#define CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, GROUP_BLOCKS) \ - else if ( \ - thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS \ - ) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM \ - ); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, \ - prob_m, prob_n, prob_k, \ - locks \ - ); \ +const int SHARED_MEM = + 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +static constexpr int pack_factor_4bit = + 8; // We have 8 4-bit vals inside a 32 bit + +#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + SHARED_MEM); \ + Marlin<<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ } -const int ERR_PROB_SHAPE = 1; -const int ERR_KERN_SHAPE = 2; - -int marlin_cuda( - const void* A, - const void* B, - void* C, - void* s, - int prob_m, - int prob_n, - int prob_k, - void* workspace, - int groupsize = -1, - int dev = 0, - cudaStream_t stream = 0, - int thread_k = -1, - int thread_n = -1, - int sms = -1, - int max_par = 16 -) { +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if 
(th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) + +void marlin_cuda(const void *A, const void *B, void *C, void *s, int prob_m, + int prob_n, int prob_k, void *workspace, int groupsize = -1, + int dev = 0, cudaStream_t stream = 0, int thread_k = -1, + int thread_n = -1, int sms = -1, int max_par = 16) { int tot_m = prob_m; int tot_m_blocks = ceildiv(tot_m, 16); int pad = 16 * tot_m_blocks - tot_m; if (sms == -1) cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - if (thread_k == -1 || thread_n == -1) { - if (prob_m <= 16) { - // For small batchizes, better partioning is slightly more important than better compute utilization - thread_k = 128; - thread_n = 128; - } else { - thread_k = 64; - thread_n = 256; - } + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { + throw std::runtime_error( + "Invalid thread config: thread_k = " + str(th_config.thread_k) + + ", thread_n = " + str(th_config.thread_n) + + ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + + str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); } + // Uncomment for debug + // std::cout << "Using thread_config: thread_k = " + str(th_config.thread_k) + + // ", thread_n = " + str(th_config.thread_n) + + // ", num_threads = " + str(th_config.num_threads) + " for + // MKN = [" + str(prob_m) + + // ", " + str(prob_k) + ", " + str(prob_n) + "]\n"; + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; int group_blocks = (groupsize == -1) ? 
-1 : groupsize / 16; int blocks = sms; - if (prob_n % thread_n != 0 || prob_k % thread_k != 0 || (group_blocks != -1 && prob_k % group_blocks != 0)) - return ERR_PROB_SHAPE; - if (prob_m == 0 || prob_n == 0 || prob_k == 0) - return 0; + if (prob_m == 0 || prob_n == 0 || prob_k == 0) { + return; + } + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + if (group_blocks != -1) { + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } - const int4* A_ptr = (const int4*) A; - const int4* B_ptr = (const int4*) B; - int4* C_ptr = (int4*) C; - const int4* s_ptr = (const int4*) s; + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + const int4 *s_ptr = (const int4 *)s; - int cols = prob_n / thread_n; - int* locks = (int*) workspace; + int *locks = (int *)workspace; - int ret = 0; for (int i = 0; i < tot_m_blocks; i += 4) { int thread_m_blocks = tot_m_blocks - i; prob_m = tot_m - 16 * i; int par = 1; if (thread_m_blocks > 4) { - // Note that parallel > 1 currently only works for inputs without any padding + // Note that parallel > 1 currently only works for inputs without any + // padding par = (16 * thread_m_blocks - pad) / 64; if (par > max_par) par = max_par; @@ -793,92 +991,122 @@ int marlin_cuda( i += 4 * (par - 1); thread_m_blocks = 4; } - - // For compilation speed, we only define the kernel configurations that have seemed useful (in terms of performance) - // in our testing, however many more are, in principle, possible. - if (false) {} - CALL_IF(1, 8, 8, -1) - CALL_IF(1, 8, 8, 8) - CALL_IF(1, 16, 4, -1) - CALL_IF(1, 16, 4, 8) - CALL_IF(2, 16, 4, -1) - CALL_IF(2, 16, 4, 8) - CALL_IF(3, 16, 4, -1) - CALL_IF(3, 16, 4, 8) - CALL_IF(4, 16, 4, -1) - CALL_IF(4, 16, 4, 8) - else - ret = ERR_KERN_SHAPE; + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. 
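  // To make the dispatch below concrete: thread_n_blocks and thread_k_blocks
  // are simply thread_n / 16 and thread_k / 16, so with the auto-selected
  // configs above the small-batch default {thread_k = 128, thread_n = 128,
  // 256 threads} resolves to CALL_IF(8, 8, 256) and the large-batch default
  // {thread_k = 64, thread_n = 256, 256 threads} resolves to
  // CALL_IF(16, 4, 256); group_blocks is 8 for groupsize = 128 and -1 for
  // channelwise scales, matching the two __CALL_IF variants expanded for each
  // thread_m_blocks.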
+ if (false) { + } + CALL_IF(8, 8, 256) + CALL_IF(16, 4, 256) + CALL_IF(8, 4, 128) + CALL_IF(4, 8, 128) + else { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + + ", " + str(prob_k) + ", " + str(prob_n) + "]" + + ", groupsize = " + str(groupsize) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; } - - return ret; } -#endif - } // namespace marlin -} // namespace vllm - -const int ERR_PROB_SHAPE = 1; -const int ERR_KERN_SHAPE = 2; - -// input: `torch.half` input matrix of shape `(m, k)` in standard row-major layout -// weights: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()` -// output: `torch.half` out matrix of shape `(m, n)` in standard row-major layout -// scales: `torch.half` scales of shape `(m / groupsize, n)` -// workspace: `torch.int` tensor with at least `n / 128` entries that are all zero - -void marlin_gemm( - const torch::Tensor& input, - const torch::Tensor& weights, - torch::Tensor& output, - const torch::Tensor& scales, - torch::Tensor& workspace -) { - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as auto -1) + +torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &workspace, + int64_t size_m, int64_t size_n, int64_t size_k) { + + // Verify M + TORCH_CHECK(size_m == a.size(0), + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + + // Verify K + TORCH_CHECK(size_k == a.size(1), + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + TORCH_CHECK(size_k % marlin::tile_size == 0, + "size_k = " + str(size_k) + + " is not divisible by tile_size = " + str(marlin::tile_size)); + TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + + ", tile_size = " + str(marlin::tile_size)); + + // Verify N + TORCH_CHECK(b_scales.size(1) == size_n, + "b_scales.size(1) = " + str(b_scales.size(1)) + + ", size_n = " + str(size_n)); + TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(marlin::tile_size)); + + int actual_size_n = + (b_q_weight.size(1) / marlin::tile_size) * marlin::pack_factor_4bit; + TORCH_CHECK(size_n == actual_size_n, + "size_n = " + str(size_n) + + ", actual_size_n = " + str(actual_size_n)); + + // Verify A device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + // Verify B device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + // Verify scales device and strides + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + // Alloc C matrix + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) int thread_k = -1; - // thread_n: `n` size of 
a thread_tile in `weights` (can usually be left as auto -1) + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) int thread_n = -1; // sms: number of SMs to use for the kernel (can usually be left as auto -1) int sms = -1; - // number of parallel problems to solve (helps with large batch sizes) - int max_par = 16; - - int prob_m = input.size(0); - int prob_n = output.size(1); - int prob_k = input.size(1); - int groupsize = (scales.size(0) == 1) ? -1 : prob_k / scales.size(0); - if (groupsize != -1 && groupsize * scales.size(0) != prob_k) - AT_ERROR("k=", prob_k, " not compatible with ", scales.size(0), " groups."); - if (workspace.numel() < (prob_n / 128) * max_par) - AT_ERROR("workspace must be of size at least ", (prob_n / 128) * max_par, "."); - int dev = input.get_device(); - int err = vllm::marlin::marlin_cuda( - input.data_ptr(), - weights.data_ptr(), - output.data_ptr(), - scales.data_ptr(), - prob_m, prob_n, prob_k, - workspace.data_ptr(), - groupsize, - dev, - at::cuda::getCurrentCUDAStream(dev), - thread_k, - thread_n, - sms, - max_par - ); - if (err == ERR_PROB_SHAPE) { - AT_ERROR( - "Problem (m=", prob_m, ", n=", prob_n, ", k=", prob_k, ")", - " not compatible with thread_k=", thread_k, ", thread_n=", thread_n, "." - ); - } else if (err == ERR_KERN_SHAPE) { - AT_ERROR( - "No kernel implementation for thread_k=", thread_k, ", thread_n=", thread_n, ", groupsize=", groupsize, "." - ); + + // Detect groupsize + if (b_scales.size(0) != 1) { + TORCH_CHECK(size_k % b_scales.size(0) == 0, + "size_k = " + str(size_k) + + ", is not divisible by b_scales.size(0) = " + + str(b_scales.size(0))); } -} \ No newline at end of file + int groupsize = b_scales.size(0) == 1 ? -1 : size_k / b_scales.size(0); + + // Verify groupsize + TORCH_CHECK(groupsize == -1 || groupsize == 128, + "Unexpected groupsize = " + str(groupsize)); + + // Verify workspace size + TORCH_CHECK( + size_n % marlin::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + str(marlin::min_thread_n)); + int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + int dev = a.get_device(); + marlin::marlin_cuda(a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), + b_scales.data_ptr(), size_m, size_n, size_k, + workspace.data_ptr(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, + sms, marlin::max_par); + + return c; +} diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 451640740490..45ee51a31089 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -65,6 +65,24 @@ def get_linear_method(self) -> "MarlinLinearMethod": def get_scaled_act_names(self) -> List[str]: return [] + +class MarlinWorkspace: + + def __init__(self, out_features): + max_parallel = 16 + min_n_threads = 64 + + assert ( + out_features % min_n_threads == 0 + ), "out_features = {out_features} is not divisible by min_n_threads = {min_n_threads}" + + max_workspace_size = (out_features // min_n_threads) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") + + class MarlinLinearMethod(LinearMethodBase): """Linear method for Marlin. 
@@ -187,19 +205,18 @@ def apply_weights( scales = weights["s"] workspace = weights["workspace"] - output = torch.empty(x.shape[:-1] + (scales.shape[1], ), - dtype=x.dtype, - device=x.device) + x_2d = x.view(-1, x.shape[-1]) - ops.marlin_gemm( - x.view(-1, x.shape[-1]), - qweight, - output.view(-1, output.shape[-1]), - scales, - workspace, - ) + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace.scratch, + size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) if bias is not None: - output.add_(bias) + output.add_(bias) # In-place add return output From c7fb928053d7b1d4865d3dbdffa0871d30ce4d00 Mon Sep 17 00:00:00 2001 From: alexm Date: Fri, 2 Feb 2024 14:10:12 -0500 Subject: [PATCH 15/47] Update checks related to MarlinConfig --- .../layers/quantization/marlin.py | 99 +++++++++---------- 1 file changed, 45 insertions(+), 54 deletions(-) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 45ee51a31089..452c13a2d4b3 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + class MarlinConfig(QuantizationConfig): """Config class for Marlin. @@ -21,17 +22,25 @@ def __init__( self.group_size = group_size if self.group_size != 128 and self.group_size != -1: raise ValueError( - "Currently, only group size 128 and -1 (channelwise) is supported for " + "Currently, only group size 128 and -1 (channelwise) is supported for " f"Marlin, but got group_size of {self.group_size}") # 4 Bits packed into 32 bit datatype. self.pack_factor = 32 // 4 + # Tile size used by marlin kernels. self.tile_size = 16 - # Data for workspace. + + # Min out_features dim + self.min_n_threads = 64 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large batch performance) self.max_parallel = 16 - self.min_n_threads = 128 - # Permutation length use by the marlin kernels. + + # Permutation length used by the marlin kernels. self.perm_len = 1024 def __repr__(self) -> str: @@ -66,23 +75,6 @@ def get_scaled_act_names(self) -> List[str]: return [] -class MarlinWorkspace: - - def __init__(self, out_features): - max_parallel = 16 - min_n_threads = 64 - - assert ( - out_features % min_n_threads == 0 - ), "out_features = {out_features} is not divisible by min_n_threads = {min_n_threads}" - - max_workspace_size = (out_features // min_n_threads) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda") - - class MarlinLinearMethod(LinearMethodBase): """Linear method for Marlin. @@ -101,36 +93,36 @@ def create_weights( output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: - # del output_size # Unused. + del output_size # Unused. + if params_dtype != torch.float16: raise ValueError( f"The params dtype must be float16, but got {params_dtype}") - if input_size_per_partition % self.quant_config.group_size != 0: + + # Validate output_size_per_partition + if output_size_per_partition % self.quant_config.min_n_threads != 0: raise ValueError( - "The input size is not aligned with the quantized " - "weight shape. 
This can be caused by too large " - "tensor parallel size.") + f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}." + ) if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( - "The output size is not aligned with the quantized " - "weight shape. This can be caused by too large " - "tensor parallel size.") - if input_size_per_partition % 128 != 0: - raise ValueError( - "The input_size_per_partition must be divisible by 128, " - f"but got {input_size_per_partition}") - if output_size_per_partition % 256 != 0: - raise ValueError( - "The output_size_per_partition must be divisible by 256, " - f"but got {output_size_per_partition}") - if output_size_per_partition % self.quant_config.min_n_threads != 0: - raise ValueError( - "The output_size per partition must be divisible by the minimum " - f"number of threads {self.quant_config.min_n_threads}, but got {output_size_per_partition}" + f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}." ) - # check that we have at least 4 tiles horizontally in the shard - num_tiles_per_perm = self.quant_config.perm_len // (self.quant_config.tile_size**2) + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}." + ) + if self.quant_config.group_size != -1: + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}." 
+ ) + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) if output_size_per_partition % num_tiles_per_perm != 0: raise ValueError( "Each permutation group must reside on the same gpu") @@ -179,15 +171,14 @@ def create_weights( }, ) - max_workspace_size = (output_size_per_partition // self.quant_config.min_n_threads) * self.quant_config.max_parallel - workspace = Parameter( - torch.zeros( - max_workspace_size, - device="cuda", - dtype=torch.int - ), - requires_grad=False - ) + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + workspace = Parameter(torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + requires_grad=False) return { "B": qweight, @@ -211,8 +202,8 @@ def apply_weights( size_k = x_2d.shape[1] size_n = scales.shape[1] - output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace.scratch, - size_m, size_n, size_k) + output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, + size_n, size_k) output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) From 1ea85f35006465a7ca0889480498704f4137f9c0 Mon Sep 17 00:00:00 2001 From: alexm Date: Fri, 2 Feb 2024 14:10:36 -0500 Subject: [PATCH 16/47] formatting --- vllm/config.py | 10 ++++------ vllm/model_executor/layers/quantization/marlin.py | 9 ++++----- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ac6ade63f54b..075d9555f124 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -152,14 +152,12 @@ def _verify_quantization(self) -> None: # Parse quantization method from the HF model config, if available. hf_quant_config = getattr(self.hf_config, "quantization_config", None) if hf_quant_config is not None: - + hf_quant_method = str(hf_quant_config["quant_method"]).lower() # If the GPTQ model is serialized in marlin format, use marlin. - if ( - hf_quant_method == "gptq" and - "is_marlin_format" in hf_quant_config and - hf_quant_config["is_marlin_format"] - ): + if (hf_quant_method == "gptq" + and "is_marlin_format" in hf_quant_config + and hf_quant_config["is_marlin_format"]): hf_quant_method = "marlin" if self.quantization is None: self.quantization = hf_quant_method diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 452c13a2d4b3..1b4f811d8f4a 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -114,11 +114,10 @@ def create_weights( raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}." ) - if self.quant_config.group_size != -1: - if input_size_per_partition % self.quant_config.group_size != 0: - raise ValueError( - f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}." - ) + if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}." 
+ ) # Check that we have at least 4 tiles horizontally in the shard num_tiles_per_perm = self.quant_config.perm_len // ( From a435c97f8ad176c524b1cf876baacd3fa09fa9d1 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Tue, 6 Feb 2024 21:04:00 -0500 Subject: [PATCH 17/47] Update pybind.cpp --- csrc/pybind.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 626d1d9dd9fb..0ea949feb86b 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -53,6 +53,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); +#endif + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); From 90e8b8fe6bb084d0127f4170162c078d345917e5 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Tue, 6 Feb 2024 21:07:23 -0500 Subject: [PATCH 18/47] Update ops.h cleanup to undo autoformatting --- csrc/ops.h | 63 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 82cb84fabd33..90804e6f1f1a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -33,21 +33,34 @@ void paged_attention_v2( const c10::optional& alibi_slopes, const std::string& kv_cache_dtype); -void rms_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, - float epsilon); - -void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, - torch::Tensor &weight, float epsilon); - -void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, - torch::Tensor &key, int head_size, - torch::Tensor &cos_sin_cache, bool is_neox); - -void silu_and_mul(torch::Tensor &out, torch::Tensor &input); - -void gelu_new(torch::Tensor &out, torch::Tensor &input); - -void gelu_fast(torch::Tensor &out, torch::Tensor &input); +void rms_norm( + torch::Tensor &out, + torch::Tensor &input, + torch::Tensor &weight, + float epsilon); + +void fused_add_rms_norm( + torch::Tensor &input, + torch::Tensor &residual, + torch::Tensor &weight, + float epsilon); + +void rotary_embedding( + torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox); + +void silu_and_mul( + torch::Tensor &out, + torch::Tensor &input); + +void gelu_new( + torch::Tensor &out, + torch::Tensor &input); + +void gelu_fast( + torch::Tensor &out, + torch::Tensor &input); #ifndef USE_ROCM torch::Tensor awq_gemm( @@ -65,13 +78,19 @@ torch::Tensor awq_dequantize( int thx, int thy); -void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor lookup_table); - -torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama); +void squeezellm_gemm( + torch::Tensor vec, + torch::Tensor mat, + torch::Tensor mul, + torch::Tensor lookup_table); + +torch::Tensor gptq_gemm( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama); void gptq_shuffle( torch::Tensor q_weight, From b03af7d24718aff3fddfc161ee6b759b5dcddad2 Mon Sep 17 00:00:00 2001 
From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Tue, 6 Feb 2024 21:09:26 -0500 Subject: [PATCH 19/47] Update ops.h cleanup formatting --- csrc/ops.h | 41 ++++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 90804e6f1f1a..1b48d94721d9 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -34,33 +34,36 @@ void paged_attention_v2( const std::string& kv_cache_dtype); void rms_norm( - torch::Tensor &out, - torch::Tensor &input, - torch::Tensor &weight, + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& weight, float epsilon); void fused_add_rms_norm( - torch::Tensor &input, - torch::Tensor &residual, - torch::Tensor &weight, + torch::Tensor& input, + torch::Tensor& residual, + torch::Tensor& weight, float epsilon); void rotary_embedding( - torch::Tensor &positions, torch::Tensor &query, - torch::Tensor &key, int head_size, - torch::Tensor &cos_sin_cache, bool is_neox); + torch::Tensor& positions, + torch::Tensor& query, + torch::Tensor& key, + int head_size, + torch::Tensor &cos_sin_cache, + bool is_neox); void silu_and_mul( - torch::Tensor &out, - torch::Tensor &input); + torch::Tensor& out, + torch::Tensor& input); void gelu_new( - torch::Tensor &out, - torch::Tensor &input); + torch::Tensor& out, + torch::Tensor& input); void gelu_fast( - torch::Tensor &out, - torch::Tensor &input); + torch::Tensor& out, + torch::Tensor& input); #ifndef USE_ROCM torch::Tensor awq_gemm( @@ -79,16 +82,16 @@ torch::Tensor awq_dequantize( int thy); void squeezellm_gemm( - torch::Tensor vec, - torch::Tensor mat, + torch::Tensor vec, + torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); torch::Tensor gptq_gemm( - torch::Tensor a, + torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, + torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, bool use_exllama); From 91922877b9234b44ddbe1f4c1e90ffbd5b675672 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Wed, 7 Feb 2024 02:23:44 +0000 Subject: [PATCH 20/47] readded marlin --- csrc/ops.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/csrc/ops.h b/csrc/ops.h index 1b48d94721d9..3c0e6459b26a 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -81,6 +81,16 @@ torch::Tensor awq_dequantize( int thx, int thy); +torch::Tensor marlin_gemm( + torch::Tensor& a, + torch::Tensor& b_q_weight, + torch::Tensor& b_scales, + torch::Tensor& workspace, + int64_t size_m, + int64_t size_n, + int64_t size_k); +#endif + void squeezellm_gemm( torch::Tensor vec, torch::Tensor mat, From ce50dd4733b291ae775e4c6141dc6a8208a94b88 Mon Sep 17 00:00:00 2001 From: alexm Date: Wed, 7 Feb 2024 20:02:42 -0500 Subject: [PATCH 21/47] Bug fix for determination of the scales size in marlin layer --- vllm/model_executor/layers/quantization/marlin.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 1b4f811d8f4a..7566d78a8aba 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -148,13 +148,12 @@ def create_weights( }, ) - # Scales in Float16. 
- group_size = self.quant_config.group_size - if group_size == -1: - group_size = input_size + # Determine if channelwise or not + input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size + scales = Parameter( torch.empty( - input_size_per_partition // group_size, + input_groups, output_size_per_partition, device="cuda", dtype=params_dtype, @@ -164,8 +163,7 @@ def create_weights( set_weight_attrs( scales, { - "input_dim": - None if input_size == input_size_per_partition else 0, + "input_dim": None if input_groups == 1 else 0, "output_dim": 1, }, ) From 5a305d39161851867ae38b4bb48367ee5b3e8a71 Mon Sep 17 00:00:00 2001 From: alexm-nm Date: Thu, 8 Feb 2024 14:14:54 +0000 Subject: [PATCH 22/47] Ensure marlin only compiles for GPU compute capability >= 8.0 --- .../quantization/marlin/marlin_cuda_kernel.cu | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index 87a1b4a94759..edbd3997fb94 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -24,6 +24,8 @@ #include +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + template inline std::string str(T x) { return std::to_string(x); } namespace marlin { @@ -1110,3 +1112,22 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, return c; } + +#else + +torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &workspace, + int64_t size_m, int64_t size_n, int64_t size_k) { + + throw std::runtime_error( + "Marlin quantization kernel requires compute capability at least 8.0"); + + // Dummy code + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + + return c; +} + +#endif From b1773aae13787c0a3b3d31187b4d1cbfba5287b8 Mon Sep 17 00:00:00 2001 From: alexm Date: Thu, 8 Feb 2024 10:02:05 -0500 Subject: [PATCH 23/47] fix marlin compilation again --- .../quantization/marlin/marlin_cuda_kernel.cu | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index edbd3997fb94..3bc4b5576f79 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -24,14 +24,14 @@ #include -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - template inline std::string str(T x) { return std::to_string(x); } namespace marlin { constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // Instances of `Vec` are used to organize groups of >>registers<<, as needed // for instance as inputs to tensor core operations. 
Consequently, all // corresponding index accesses must be compile-time constants, which is why we @@ -791,6 +791,36 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } +#else + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void +Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 + *__restrict__ s, // fp16 quantization scales of shape (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. @@ -1112,22 +1142,3 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, return c; } - -#else - -torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, - torch::Tensor &b_scales, torch::Tensor &workspace, - int64_t size_m, int64_t size_n, int64_t size_k) { - - throw std::runtime_error( - "Marlin quantization kernel requires compute capability at least 8.0"); - - // Dummy code - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - return c; -} - -#endif From d63627e6b4ab3b3e77a7bb7d05184b237ac1e4ae Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 16:33:19 +0000 Subject: [PATCH 24/47] added marlin test --- tests/conftest.py | 37 ++++++++++++- tests/models/test_marlin.py | 106 ++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 2 deletions(-) create mode 100644 tests/models/test_marlin.py diff --git a/tests/conftest.py b/tests/conftest.py index 8d6afdbd0035..015d8448d160 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,8 +16,11 @@ def _read_prompts(filename: str) -> str: prompts = [] with open(filename, "r") as f: - prompt = f.readline() - prompts.append(prompt) + while True: + prompt = f.readline() + if prompt == "": + break + prompts.append(prompt) return prompts @@ -195,6 +198,24 @@ def generate( outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs + def generate_w_logprobs( + self, + prompts: List[str], + sampling_params: SamplingParams, + ) -> List[Tuple[List[int], str]]: + assert sampling_params.logprobs is not None + + req_outputs = self.model.generate(prompts, + sampling_params=sampling_params) + outputs = [] + for req_output in req_outputs: + for sample in req_output.outputs: + output_str = sample.text + output_ids = sample.token_ids + output_logprobs = sample.logprobs + outputs.append((output_ids, output_str, output_logprobs)) + return outputs + def generate_greedy( self, prompts: List[str], @@ -205,6 +226,18 @@ def generate_greedy( return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] + def generate_greedy_logprobs( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + ) -> List[Tuple[List[int], str]]: + greedy_logprobs_params = 
SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) + outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params) + + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + def generate_beam_search( self, prompts: List[str], diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py new file mode 100644 index 000000000000..8917bf1de15c --- /dev/null +++ b/tests/models/test_marlin.py @@ -0,0 +1,106 @@ +"""Compare the outputs of a GPTQ model to a Marlin model. + +Note: GPTQ and Marlin do not have bitwise correctness. + +As a result, in this test, we just confirm that the top 5 selected tokens of the +Marlin model are in the top 5 selected tokens of the GPTQ model. + +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. + +Run `pytest tests/models/test_marlin.py --forked`. +""" + +import pytest +import torch +from dataclasses import dataclass +from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY + +gptq_config_cls = _QUANTIZATION_CONFIG_REGISTRY["gptq"] +marlin_config_cls = _QUANTIZATION_CONFIG_REGISTRY["marlin"] + +@dataclass +class ModelPair: + model_marlin: str + model_gptq: str + +MODEL_PAIRS = [ + ModelPair( + model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", + model_gptq="nm-testing/zephyr-beta-7b-gptq-g128" + ), + ModelPair( + model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", + model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq" + ) +] + +@pytest.mark.parametrize("model_pair", MODEL_PAIRS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [3]) +@pytest.mark.parametrize("failure_tolerance", [3]) +def test_models( + vllm_runner, + example_prompts, + model_pair: ModelPair, + dtype: str, + max_tokens: int, + num_logprobs: int, + failure_tolerance: int, +) -> None: + + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if (capability < gptq_config_cls.get_min_capability() or + capability < marlin_config_cls.get_min_capability()): + print(f"Skipping test_marlin due to device capability: {capability} (requires >8.0)") + return + + # Run the experiment failure_tolerance times + for retry_idx in range(failure_tolerance): + gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + del gptq_model + + marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype) + marlin_outputs = marlin_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + del marlin_model + + # index of the failed_prompt + failed_prompt_idx = -1 + failed_input_idx = -1 + + # loop through the prompts + for prompt_idx in range(len(example_prompts)): + gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[prompt_idx] + marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[prompt_idx] + + for idx, (gptq_output_id, marlin_output_id) in enumerate(zip(gptq_output_ids, marlin_output_ids)): + # If sequence is not an exact match, + if marlin_output_id != gptq_output_id: + # Each predicted token must be in top 3 of the other's or iteration is a failure + if ( + gptq_output_id not in marlin_logprobs[idx] or + marlin_output_id not in gptq_logprobs[idx] + ): + 
failed_prompt_idx = prompt_idx + failed_input_idx = idx + break + + # Break out of this retry + if failed_prompt_idx != -1: + print(f"Found failure on retry idx {retry_idx}") + break + + # Return if we + if failed_prompt_idx == -1: + return + + assert gptq_output_id in marlin_logprobs[failed_input_idx], ( + f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}") + assert marlin_output_id in gptq_logprobs[failed_input_idx], ( + f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}") \ No newline at end of file From 18981b1c26a1a556d37371f4aebc798fcab068c1 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 16:36:47 +0000 Subject: [PATCH 25/47] added marlin test --- tests/models/test_marlin.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 8917bf1de15c..4dbbda6e9df3 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -17,15 +17,18 @@ from dataclasses import dataclass from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY -gptq_config_cls = _QUANTIZATION_CONFIG_REGISTRY["gptq"] -marlin_config_cls = _QUANTIZATION_CONFIG_REGISTRY["marlin"] +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +is_marlin_supported = ( + capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability() +) @dataclass class ModelPair: model_marlin: str model_gptq: str -MODEL_PAIRS = [ +model_pairs = [ ModelPair( model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", model_gptq="nm-testing/zephyr-beta-7b-gptq-g128" @@ -36,7 +39,9 @@ class ModelPair: ) ] -@pytest.mark.parametrize("model_pair", MODEL_PAIRS) +@pytest.mark.skipif(not is_marlin_supported, + reason="Marlin is not supported on this GPU type.") +@pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [3]) @@ -50,13 +55,6 @@ def test_models( num_logprobs: int, failure_tolerance: int, ) -> None: - - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - if (capability < gptq_config_cls.get_min_capability() or - capability < marlin_config_cls.get_min_capability()): - print(f"Skipping test_marlin due to device capability: {capability} (requires >8.0)") - return # Run the experiment failure_tolerance times for retry_idx in range(failure_tolerance): From 828c6213774702637d3ed78a15481cc0972ca7a8 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 16:37:55 +0000 Subject: [PATCH 26/47] updated skipping logic --- tests/models/test_marlin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 4dbbda6e9df3..058b00a134bd 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -19,7 +19,7 @@ capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] -is_marlin_supported = ( +marlin_not_supported = ( capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability() ) @@ -39,7 +39,7 @@ class ModelPair: ) ] -@pytest.mark.skipif(not is_marlin_supported, +@pytest.mark.skipif(marlin_not_supported, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("dtype", ["half"]) From 
4f1759b818cfdc291b80db2ccd3d349611adabdf Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 17:04:08 +0000 Subject: [PATCH 27/47] updated skipping logic --- tests/models/test_marlin.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 058b00a134bd..eec36a46e2fc 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -29,10 +29,10 @@ class ModelPair: model_gptq: str model_pairs = [ - ModelPair( - model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", - model_gptq="nm-testing/zephyr-beta-7b-gptq-g128" - ), + # ModelPair( + # model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", + # model_gptq="nm-testing/zephyr-beta-7b-gptq-g128" + # ), ModelPair( model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq" @@ -58,16 +58,16 @@ def test_models( # Run the experiment failure_tolerance times for retry_idx in range(failure_tolerance): - gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) - gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - del gptq_model - marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype) marlin_outputs = marlin_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) del marlin_model + gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + del gptq_model + # index of the failed_prompt failed_prompt_idx = -1 failed_input_idx = -1 From f1714e90141e5d07414edba2c2a2adf2bca625fe Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 17:06:34 +0000 Subject: [PATCH 28/47] added memory profiling --- tests/models/test_marlin.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index eec36a46e2fc..c20b60556b23 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -63,6 +63,10 @@ def test_models( example_prompts, max_tokens, num_logprobs) del marlin_model + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + print(free_gpu_memory) + print(total_gpu_memory) + gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) gptq_outputs = gptq_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) From e3a4706fc64c616189055b27a40a4486a764e422 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 17:08:42 +0000 Subject: [PATCH 29/47] added memory profiling --- tests/conftest.py | 6 ++++++ tests/models/test_marlin.py | 4 ---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 015d8448d160..bc6b2ee510e8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,6 +169,12 @@ def __init__( tokenizer_name: Optional[str] = None, dtype: str = "half", ) -> None: + + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + print(f"free: {free_gpu_memory // 1024 // 1024}") + print(f"total: {total_gpu_memory // 1024 // 1024}") + print("\n") + self.model = LLM( model=model_name, tokenizer=tokenizer_name, diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index c20b60556b23..eec36a46e2fc 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -63,10 +63,6 @@ def test_models( example_prompts, max_tokens, num_logprobs) del marlin_model - free_gpu_memory, total_gpu_memory = 
torch.cuda.mem_get_info() - print(free_gpu_memory) - print(total_gpu_memory) - gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) gptq_outputs = gptq_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) From efd886cfec84efb8bd636eb40a5eefcdee3b2c07 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 17:10:03 +0000 Subject: [PATCH 30/47] test wout memory utilization --- tests/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bc6b2ee510e8..2bfad8452061 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -171,8 +171,8 @@ def __init__( ) -> None: free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - print(f"free: {free_gpu_memory // 1024 // 1024}") - print(f"total: {total_gpu_memory // 1024 // 1024}") + print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") + print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") print("\n") self.model = LLM( @@ -181,6 +181,7 @@ def __init__( trust_remote_code=True, dtype=dtype, swap_space=0, + gpu_memory_utilization=0.8, ) def generate( From 70f5850c5e51ba9417ff7b4d8da6fc78084db747 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 17:12:38 +0000 Subject: [PATCH 31/47] updating memory profiling --- vllm/entrypoints/llm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fc82018d18eb..892f18c876f6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -106,6 +106,12 @@ def __init__( disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) + + print("In LLM:") + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") + print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") + print("\n") self.llm_engine = LLMEngine.from_engine_args(engine_args) self.request_counter = Counter() From 567fe38ef92c9c5769fc8b591e5749114785d33d Mon Sep 17 00:00:00 2001 From: rib-2 Date: Sun, 18 Feb 2024 17:21:21 +0000 Subject: [PATCH 32/47] adding more profiling --- vllm/engine/llm_engine.py | 8 ++++++++ vllm/entrypoints/llm.py | 2 ++ 2 files changed, 10 insertions(+) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 03a2b1157652..af236390a19d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,3 +1,5 @@ +import torch + import copy from collections import defaultdict import os @@ -76,6 +78,12 @@ def __init__( placement_group: Optional["PlacementGroup"], log_stats: bool, ) -> None: + print("In LLMEngine:") + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") + print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") + print("\n") + logger.info( "Initializing an LLM engine with config: " f"model={model_config.model!r}, " diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 892f18c876f6..b156c9679abb 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,7 @@ from typing import List, Optional, Union +import torch + from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast From 01f5e403d85ca7b623cf4a0e726484b3771cfb25 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 17:26:34 +0000 Subject: [PATCH 33/47] updating memory profiling --- tests/conftest.py | 1 - vllm/engine/llm_engine.py | 13 +++++++++++++ vllm/worker/worker.py | 11 +++++++++++ 3 files changed, 24 insertions(+), 1 
deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 2bfad8452061..bd6041edb748 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -181,7 +181,6 @@ def __init__( trust_remote_code=True, dtype=dtype, swap_space=0, - gpu_memory_utilization=0.8, ) def generate( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index af236390a19d..fbd4a6810568 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -168,6 +168,13 @@ def _init_workers(self): kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) + print("In Init Workers:") + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") + print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") + print("\n") + + self._run_workers("init_model") self._run_workers("load_model") @@ -327,6 +334,12 @@ def _init_cache(self) -> None: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameters. """ + print("In Init Cache:") + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") + print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") + print("\n") + # Get the maximum number of blocks that can be allocated on GPU and CPU. num_blocks = self._run_workers( "profile_num_available_blocks", diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c97e82a55a1e..4ec48e0920e5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -115,7 +115,18 @@ def profile_num_available_blocks( """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. + print("In profile_num_available_blocks before empty_cache():") + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") + print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") + print("\n") + torch.cuda.empty_cache() + print("In profile_num_available_blocks after empty_cache():") + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") + print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") + print("\n") # Execute a forward pass with dummy inputs to profile the memory usage # of the model. 
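
The profiling commits above (PATCH 28 through 33) all insert the same probe inline: read the CUDA driver's free/total device memory before and after a suspected leak point. A minimal, self-contained sketch of that pattern follows; the helper name and the GiB formatting are illustrative and not part of the patches themselves.

import torch

def log_gpu_memory(tag: str) -> None:
    # The same probe the debugging commits insert inline: free and total
    # device memory as reported by the CUDA driver, printed in GiB.
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    gib = 1024 ** 3
    print(f"[{tag}] free: {free_bytes / gib:.2f} GiB, total: {total_bytes / gib:.2f} GiB")

# Typical use while hunting the leak: bracket the suspect statement,
# e.g. log_gpu_memory("before del"); del model; log_gpu_memory("after del").
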
From fc5310c94ed61fa14281ce1b8e91a9a8354a819e Mon Sep 17 00:00:00 2001 From: rib-2 Date: Sun, 18 Feb 2024 17:57:42 +0000 Subject: [PATCH 34/47] removed memory profiling --- tests/conftest.py | 5 ----- tests/models/test_marlin.py | 18 ++++++++++++++---- vllm/engine/llm_engine.py | 17 ----------------- vllm/entrypoints/llm.py | 5 ----- vllm/worker/worker.py | 11 ----------- 5 files changed, 14 insertions(+), 42 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bd6041edb748..fb3968a638f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -170,11 +170,6 @@ def __init__( dtype: str = "half", ) -> None: - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") - print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") - print("\n") - self.model = LLM( model=model_name, tokenizer=tokenizer_name, diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index eec36a46e2fc..afe3dc4eb8c8 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -29,10 +29,10 @@ class ModelPair: model_gptq: str model_pairs = [ - # ModelPair( - # model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", - # model_gptq="nm-testing/zephyr-beta-7b-gptq-g128" - # ), + ModelPair( + model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", + model_gptq="nm-testing/zephyr-beta-7b-gptq-g128" + ), ModelPair( model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq" @@ -61,11 +61,21 @@ def test_models( marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype) marlin_outputs = marlin_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + + # Note: not sure why, but deleting just the model on Ada Lovelace + # does not free the GPU memory. On Ampere, deleting the just model + # frees the memory. + del marlin_model.model.llm_engine.driver_worker del marlin_model gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) gptq_outputs = gptq_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + + # Note: not sure why, but deleting just the model on Ada Lovelace + # does not free the GPU memory. On Ampere, deleting the just model + # frees the memory. + del gptq_model.model.llm_engine.driver_worker del gptq_model # index of the failed_prompt diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fbd4a6810568..4f173f10a3b4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -78,12 +78,6 @@ def __init__( placement_group: Optional["PlacementGroup"], log_stats: bool, ) -> None: - print("In LLMEngine:") - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") - print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") - print("\n") - logger.info( "Initializing an LLM engine with config: " f"model={model_config.model!r}, " @@ -168,12 +162,6 @@ def _init_workers(self): kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) - print("In Init Workers:") - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") - print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") - print("\n") - self._run_workers("init_model") self._run_workers("load_model") @@ -334,11 +322,6 @@ def _init_cache(self) -> None: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameters. 
""" - print("In Init Cache:") - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") - print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") - print("\n") # Get the maximum number of blocks that can be allocated on GPU and CPU. num_blocks = self._run_workers( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b156c9679abb..fdd01f0a3088 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -109,11 +109,6 @@ def __init__( **kwargs, ) - print("In LLM:") - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") - print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") - print("\n") self.llm_engine = LLMEngine.from_engine_args(engine_args) self.request_counter = Counter() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 4ec48e0920e5..c97e82a55a1e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -115,18 +115,7 @@ def profile_num_available_blocks( """ # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. - print("In profile_num_available_blocks before empty_cache():") - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") - print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") - print("\n") - torch.cuda.empty_cache() - print("In profile_num_available_blocks after empty_cache():") - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - print(f"free: {free_gpu_memory / 1024 / 1024 // 1024}GB") - print(f"total: {total_gpu_memory / 1024 / 1024 // 1024}GB") - print("\n") # Execute a forward pass with dummy inputs to profile the memory usage # of the model. 
From 99ab19d97be1fd33e54d3d3e1b41b29635eb8cb5 Mon Sep 17 00:00:00 2001 From: rib-2 Date: Sun, 18 Feb 2024 18:03:30 +0000 Subject: [PATCH 35/47] cleaned up --- csrc/ops.h | 10 +++++----- tests/conftest.py | 1 - vllm/engine/llm_engine.py | 4 ---- vllm/entrypoints/llm.py | 3 --- 4 files changed, 5 insertions(+), 13 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 3c0e6459b26a..55cbf03ae829 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -34,15 +34,15 @@ void paged_attention_v2( const std::string& kv_cache_dtype); void rms_norm( - torch::Tensor& out, - torch::Tensor& input, + torch::Tensor& out, + torch::Tensor& input, torch::Tensor& weight, float epsilon); void fused_add_rms_norm( - torch::Tensor& input, + torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, + torch::Tensor& weight, float epsilon); void rotary_embedding( @@ -50,7 +50,7 @@ void rotary_embedding( torch::Tensor& query, torch::Tensor& key, int head_size, - torch::Tensor &cos_sin_cache, + torch::Tensor& cos_sin_cache, bool is_neox); void silu_and_mul( diff --git a/tests/conftest.py b/tests/conftest.py index fb3968a638f2..015d8448d160 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,7 +169,6 @@ def __init__( tokenizer_name: Optional[str] = None, dtype: str = "half", ) -> None: - self.model = LLM( model=model_name, tokenizer=tokenizer_name, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4f173f10a3b4..03a2b1157652 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,5 +1,3 @@ -import torch - import copy from collections import defaultdict import os @@ -162,7 +160,6 @@ def _init_workers(self): kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) - self._run_workers("init_model") self._run_workers("load_model") @@ -322,7 +319,6 @@ def _init_cache(self) -> None: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameters. """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. 
num_blocks = self._run_workers( "profile_num_available_blocks", diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fdd01f0a3088..fc82018d18eb 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,7 +1,5 @@ from typing import List, Optional, Union -import torch - from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -108,7 +106,6 @@ def __init__( disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) - self.llm_engine = LLMEngine.from_engine_args(engine_args) self.request_counter = Counter() From eabeea61ba84df15c180c0893ea705756cdfc6f4 Mon Sep 17 00:00:00 2001 From: rib-2 Date: Sun, 18 Feb 2024 18:04:09 +0000 Subject: [PATCH 36/47] added newline --- tests/models/test_marlin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index afe3dc4eb8c8..91ff3e00f069 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -111,4 +111,4 @@ def test_models( assert gptq_output_id in marlin_logprobs[failed_input_idx], ( f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}") assert marlin_output_id in gptq_logprobs[failed_input_idx], ( - f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}") \ No newline at end of file + f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}") From d06459590df392a52f4d0bdd736ade78b11c1731 Mon Sep 17 00:00:00 2001 From: rib-2 Date: Sun, 18 Feb 2024 18:27:03 +0000 Subject: [PATCH 37/47] ran ./format.sh --- tests/conftest.py | 4 ++- tests/models/test_marlin.py | 55 +++++++++++++++++++------------------ 2 files changed, 31 insertions(+), 28 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 015d8448d160..63de7fe418ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -232,7 +232,9 @@ def generate_greedy_logprobs( max_tokens: int, num_logprobs: int, ) -> List[Tuple[List[int], str]]: - greedy_logprobs_params = SamplingParams(temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs) + greedy_logprobs_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens, + logprobs=num_logprobs) outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params) return [(output_ids, output_str, output_logprobs) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 91ff3e00f069..90b7fc3b340d 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -20,25 +20,23 @@ capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] marlin_not_supported = ( - capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability() -) + capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability()) + @dataclass class ModelPair: model_marlin: str model_gptq: str + model_pairs = [ - ModelPair( - model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", - model_gptq="nm-testing/zephyr-beta-7b-gptq-g128" - ), - ModelPair( - model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", - model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq" - ) + ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", + model_gptq="nm-testing/zephyr-beta-7b-gptq-g128"), + ModelPair(model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", + model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq") ] + @pytest.mark.skipif(marlin_not_supported, reason="Marlin is not supported on this GPU 
type.") @pytest.mark.parametrize("model_pair", model_pairs) @@ -61,7 +59,7 @@ def test_models( marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype) marlin_outputs = marlin_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - + # Note: not sure why, but deleting just the model on Ada Lovelace # does not free the GPU memory. On Ampere, deleting the just model # frees the memory. @@ -71,7 +69,7 @@ def test_models( gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) gptq_outputs = gptq_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - + # Note: not sure why, but deleting just the model on Ada Lovelace # does not free the GPU memory. On Ampere, deleting the just model # frees the memory. @@ -84,31 +82,34 @@ def test_models( # loop through the prompts for prompt_idx in range(len(example_prompts)): - gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[prompt_idx] - marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[prompt_idx] + gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[ + prompt_idx] + marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[ + prompt_idx] - for idx, (gptq_output_id, marlin_output_id) in enumerate(zip(gptq_output_ids, marlin_output_ids)): - # If sequence is not an exact match, + for idx, (gptq_output_id, marlin_output_id) in enumerate( + zip(gptq_output_ids, marlin_output_ids)): + # If sequence is not an exact match, if marlin_output_id != gptq_output_id: # Each predicted token must be in top 3 of the other's or iteration is a failure - if ( - gptq_output_id not in marlin_logprobs[idx] or - marlin_output_id not in gptq_logprobs[idx] - ): - failed_prompt_idx = prompt_idx - failed_input_idx = idx + if (gptq_output_id not in marlin_logprobs[idx] + or marlin_output_id not in gptq_logprobs[idx]): + failed_prompt_idx = prompt_idx + failed_input_idx = idx break - + # Break out of this retry if failed_prompt_idx != -1: print(f"Found failure on retry idx {retry_idx}") break - - # Return if we + + # Return if we if failed_prompt_idx == -1: return assert gptq_output_id in marlin_logprobs[failed_input_idx], ( - f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}") + f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}" + ) assert marlin_output_id in gptq_logprobs[failed_input_idx], ( - f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}") + f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}" + ) From 9b1bc5fde4140199f7db2ee116f336e909a570e5 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Sun, 18 Feb 2024 19:22:50 +0000 Subject: [PATCH 38/47] merged into upstream main --- tests/models/test_marlin.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 90b7fc3b340d..aee02d7361fd 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -32,6 +32,8 @@ class ModelPair: model_pairs = [ ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", model_gptq="nm-testing/zephyr-beta-7b-gptq-g128"), + ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-channelwise", + model_gptq="nm-testing/zephyr-beta-7b-gptq-channelwise"), ModelPair(model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq") ] From 013f10f7568954418f7698a43f034cf70d370791 Mon Sep 17 00:00:00 2001 From: 
Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Sun, 18 Feb 2024 16:51:21 -0500 Subject: [PATCH 39/47] Update test_marlin.py --- tests/models/test_marlin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index aee02d7361fd..1f6b8500be0c 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -32,8 +32,8 @@ class ModelPair: model_pairs = [ ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", model_gptq="nm-testing/zephyr-beta-7b-gptq-g128"), - ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-channelwise", - model_gptq="nm-testing/zephyr-beta-7b-gptq-channelwise"), + ModelPair(model_marlin="robertgshaw2/zephyr-beta-7b-marlin-channelwise", + model_gptq="robertgshaw2/zephyr-beta-7b-gptq-channelwise"), ModelPair(model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq") ] From 7f2165e1c0143547b7e8c23126ca20bfc571f798 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+rib-2@users.noreply.github.com> Date: Sun, 18 Feb 2024 16:53:50 -0500 Subject: [PATCH 40/47] Update test_marlin.py --- tests/models/test_marlin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 1f6b8500be0c..c3a1491f5874 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -32,8 +32,8 @@ class ModelPair: model_pairs = [ ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-g128", model_gptq="nm-testing/zephyr-beta-7b-gptq-g128"), - ModelPair(model_marlin="robertgshaw2/zephyr-beta-7b-marlin-channelwise", - model_gptq="robertgshaw2/zephyr-beta-7b-gptq-channelwise"), + ModelPair(model_marlin="robertgshaw2/zephyr-7b-beta-channelwise-marlin", + model_gptq="robertgshaw2/zephyr-7b-beta-channelwise-gptq"), ModelPair(model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq") ] From 7a9b828c44654604e464fc2fac7ac2b12fdc646a Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Mon, 19 Feb 2024 15:20:55 +0000 Subject: [PATCH 41/47] updated retry testing to use pytest-flaky rather than implementing this logic myself. --- requirements-dev.txt | 1 + tests/models/test_marlin.py | 114 ++++++++++++++++-------------------- 2 files changed, 50 insertions(+), 65 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index f8126008d079..81c59c014943 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,6 +13,7 @@ types-setuptools pytest pytest-forked pytest-asyncio +pytest-rerunfailures httpx einops # required for MPT flash_attn # required for HuggingFace's llama implementation diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index c3a1491f5874..6dc61f5fab26 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -1,9 +1,8 @@ """Compare the outputs of a GPTQ model to a Marlin model. Note: GPTQ and Marlin do not have bitwise correctness. - -As a result, in this test, we just confirm that the top 5 selected tokens of the -Marlin model are in the top 5 selected tokens of the GPTQ model. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of eachother. Note: Marlin internally uses locks to synchronize the threads. This can result in very slight nondeterminism for Marlin. 
As a result, we re-run the test @@ -39,13 +38,13 @@ class ModelPair: ] +@pytest.mark.flaky(reruns=2) @pytest.mark.skipif(marlin_not_supported, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("model_pair", model_pairs) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [3]) -@pytest.mark.parametrize("failure_tolerance", [3]) def test_models( vllm_runner, example_prompts, @@ -53,65 +52,50 @@ def test_models( dtype: str, max_tokens: int, num_logprobs: int, - failure_tolerance: int, ) -> None: - - # Run the experiment failure_tolerance times - for retry_idx in range(failure_tolerance): - marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype) - marlin_outputs = marlin_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - # Note: not sure why, but deleting just the model on Ada Lovelace - # does not free the GPU memory. On Ampere, deleting the just model - # frees the memory. - del marlin_model.model.llm_engine.driver_worker - del marlin_model - - gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) - gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - - # Note: not sure why, but deleting just the model on Ada Lovelace - # does not free the GPU memory. On Ampere, deleting the just model - # frees the memory. - del gptq_model.model.llm_engine.driver_worker - del gptq_model - - # index of the failed_prompt - failed_prompt_idx = -1 - failed_input_idx = -1 - - # loop through the prompts - for prompt_idx in range(len(example_prompts)): - gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[ - prompt_idx] - marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[ - prompt_idx] - - for idx, (gptq_output_id, marlin_output_id) in enumerate( - zip(gptq_output_ids, marlin_output_ids)): - # If sequence is not an exact match, - if marlin_output_id != gptq_output_id: - # Each predicted token must be in top 3 of the other's or iteration is a failure - if (gptq_output_id not in marlin_logprobs[idx] - or marlin_output_id not in gptq_logprobs[idx]): - failed_prompt_idx = prompt_idx - failed_input_idx = idx - break - - # Break out of this retry - if failed_prompt_idx != -1: - print(f"Found failure on retry idx {retry_idx}") - break - - # Return if we - if failed_prompt_idx == -1: - return - - assert gptq_output_id in marlin_logprobs[failed_input_idx], ( - f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}" - ) - assert marlin_output_id in gptq_logprobs[failed_input_idx], ( - f"Test{failed_prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}" - ) + marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype) + marlin_outputs = marlin_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + # Note: not sure why, but deleting just the model on Ada Lovelace + # does not free the GPU memory. On Ampere, deleting the just model + # frees the memory. + del marlin_model.model.llm_engine.driver_worker + del marlin_model + + gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + # Note: not sure why, but deleting just the model on Ada Lovelace + # does not free the GPU memory. On Ampere, deleting the just model + # frees the memory. 
+ del gptq_model.model.llm_engine.driver_worker + del gptq_model + + # loop through the prompts + for prompt_idx in range(len(example_prompts)): + gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[ + prompt_idx] + marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[ + prompt_idx] + + for idx, (gptq_output_id, marlin_output_id) in enumerate( + zip(gptq_output_ids, marlin_output_ids)): + # If sequence is not an exact match, + if marlin_output_id != gptq_output_id: + if gptq_output_id not in marlin_logprobs[idx]: + print("\n\n\nOMG IM HERE\n\n\n") + if marlin_output_id not in gptq_logprobs[idx]: + print("\n\n\nOMG IM HERE\n\n\n") + + # Each predicted token must be in top 5 of the other's + assert gptq_output_id in marlin_logprobs[idx], ( + f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}" + ) + assert marlin_output_id in gptq_logprobs[idx], ( + f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}" + ) + + # Break out since sequences will now diverge. + break \ No newline at end of file From c23902f22ca5f056f608c07dbb63a580f1231eb1 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Mon, 19 Feb 2024 15:21:37 +0000 Subject: [PATCH 42/47] missed newline --- tests/models/test_marlin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 6dc61f5fab26..c740f9ef0892 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -98,4 +98,4 @@ def test_models( ) # Break out since sequences will now diverge. - break \ No newline at end of file + break From e7aba66fc11f4de681b24b21d3d7a5f1b3a03e58 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Mon, 19 Feb 2024 15:27:01 +0000 Subject: [PATCH 43/47] formatting --- tests/models/test_marlin.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index c740f9ef0892..49ed3c652ac4 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -64,8 +64,9 @@ def test_models( del marlin_model gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) - gptq_outputs = gptq_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) + gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) # Note: not sure why, but deleting just the model on Ada Lovelace # does not free the GPU memory. 
On Ampere, deleting the just model From 2403f7defe13cbe84ea38c42fa566aac1b46428e Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Mon, 19 Feb 2024 15:29:10 +0000 Subject: [PATCH 44/47] removed silly print --- tests/models/test_marlin.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 49ed3c652ac4..edf6c8384926 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -85,11 +85,6 @@ def test_models( zip(gptq_output_ids, marlin_output_ids)): # If sequence is not an exact match, if marlin_output_id != gptq_output_id: - if gptq_output_id not in marlin_logprobs[idx]: - print("\n\n\nOMG IM HERE\n\n\n") - if marlin_output_id not in gptq_logprobs[idx]: - print("\n\n\nOMG IM HERE\n\n\n") - # Each predicted token must be in top 5 of the other's assert gptq_output_id in marlin_logprobs[idx], ( f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}" From aabaed274ed542ab6d7cb20eb0854bed87b063a7 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Thu, 29 Feb 2024 22:33:46 +0000 Subject: [PATCH 45/47] added license --- csrc/quantization/marlin/LICENSE | 209 ++++++++++++++++++ .../quantization/marlin/marlin_cuda_kernel.cu | 1 + 2 files changed, 210 insertions(+) create mode 100644 csrc/quantization/marlin/LICENSE diff --git a/csrc/quantization/marlin/LICENSE b/csrc/quantization/marlin/LICENSE new file mode 100644 index 000000000000..1d1e4cf9c823 --- /dev/null +++ b/csrc/quantization/marlin/LICENSE @@ -0,0 +1,209 @@ +Contains code from https://github.com/IST-DASLab/marlin + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ + +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. 
diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index 3bc4b5576f79..efc443b518cf 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -1,4 +1,5 @@ /* + * Modified by Neural Magic * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) * * Licensed under the Apache License, Version 2.0 (the "License"); From a67dc8dac09f649ab6ea864a7c8cf1b58e7def27 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Thu, 29 Feb 2024 22:38:06 +0000 Subject: [PATCH 46/47] format --- csrc/quantization/marlin/marlin_cuda_kernel.cu | 10 +++++----- tests/models/test_marlin.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index efc443b518cf..0bbd0975bfd4 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -1,6 +1,6 @@ /* * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) + * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) # noqa * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -407,7 +407,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependicies between subsequent accesses with a tile by + // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. const int4 *B_ptr[b_sh_wr_iters]; @@ -480,7 +480,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // It may seem inefficient that we reload the groups for every sub-tile; // however, this does not seem to be a significant bottleneck, while some // theoretically better attempts have lead to bad instruction ordering by - // the compiler and correspondingly a noticable drop in performance. + // the compiler and correspondingly a noticeable drop in performance. if (group_blocks != -1) { int4 *sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * @@ -578,7 +578,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped partioning + // finally have to globally reduce over the results. As the striped partitioning // minimizes the number of such reductions and our outputs are usually rather // small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { @@ -742,7 +742,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // Process results and, if necessary, proceed to the next column slice. // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compliation. + // the loop seemed to noticeably worse performance after compilation. 
if (slice_iters == 0) { cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index edf6c8384926..f3cc517364f0 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -2,7 +2,7 @@ Note: GPTQ and Marlin do not have bitwise correctness. As a result, in this test, we just confirm that the top selected tokens of the -Marlin/GPTQ models are in the top 3 selections of eachother. +Marlin/GPTQ models are in the top 3 selections of each other. Note: Marlin internally uses locks to synchronize the threads. This can result in very slight nondeterminism for Marlin. As a result, we re-run the test From 8ff42c0ca67c5ff4a81fc219d8e9d0977731fd13 Mon Sep 17 00:00:00 2001 From: rsnm2 Date: Thu, 29 Feb 2024 22:40:52 +0000 Subject: [PATCH 47/47] minor change for ruff --- csrc/quantization/marlin/marlin_cuda_kernel.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index 0bbd0975bfd4..cf1b0afdec8b 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -1,6 +1,6 @@ /* * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at) # noqa + * Copyright (C) Marlin.2024 Elias Frantar * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.