diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 0aa0dc14c748..d4999dce7159 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -454,9 +454,9 @@ __global__ void concat_and_cache_ds_mla_kernel(
   }
 
   // The first two warps handle the NoPE part
-  const int8_t warp_idx = threadIdx.x >> 5;
-  const int8_t lane_idx = threadIdx.x & 31;
-  const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);
+  const int8_t warp_idx = threadIdx.x >> 5;                // 0, 1
+  const int8_t lane_idx = threadIdx.x & 31;                // 0..31
+  const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);  // 0..3
 
   // Each thread handles 8 elements of NoPE
   // Load the NoPE elements for this thread into registers
@@ -484,7 +484,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
   // The first lane of each half-warp writes the scale to kv_cache
   if ((lane_idx == 0) || (lane_idx == 16)) {
     float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
-    const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
+    const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;  // 128 + 0..3
     kv_cache_32bit[dst_idx] = tile_scale;
   }
 
@@ -505,6 +505,58 @@ __global__ void concat_and_cache_ds_mla_kernel(
       *reinterpret_cast(result);
 }
 
+
+// void indexer_k_quant_and_cache(
+//     torch::Tensor& k,             // [num_tokens, head_dim]
+//     torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
+//     torch::Tensor& slot_mapping,  // [num_tokens]
+//     int64_t quant_block_size,     // quantization block size
+//     const std::string& scale_fmt) {
+//   int num_tokens = k.size(0);
+//   int head_dim = k.size(1);
+//   int cache_block_size = kv_cache.size(1);
+//   int cache_stride = kv_cache.size(2);
+//   bool use_ue8m0 = scale_fmt == "ue8m0";
+
+//   TORCH_CHECK(k.device() == kv_cache.device(),
+//               "k and kv_cache must be on the same device");
+//   TORCH_CHECK(k.device() == slot_mapping.device(),
+//               "k and slot_mapping must be on the same device");
+//   TORCH_CHECK(head_dim % quant_block_size == 0,
+//               "head_dim must be divisible by quant_block_size");
+
+//   constexpr int vec_size = 4;
+//   head_dim = 128
+//   quant_block_size = 128
+//   vec_size = 4
+//   grid_y = (128 + 128 * 4 - 1) / (128 * 4) = 1
+//   dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
+//                             (quant_block_size * vec_size));
+//   dim3 block(32, vec_size);
+//   const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
+//   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+//   DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
+//                              CALL_INDEXER_K_QUANT_AND_CACHE);
+
+// // Macro to dispatch the kernel based on the data type.
+// #define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)   \
+//   vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
+//       <<<grid, block, 0, stream>>>(                               \
+//           reinterpret_cast<KV_T*>(k.data_ptr()),                  \
+//           reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),        \
+//           slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
+//           cache_block_size, cache_stride, use_ue8m0);
+
+
+// void indexer_k_quant_and_cache(
+//     torch::Tensor& k,             // [num_tokens, head_dim]
+//     torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
+//     torch::Tensor& slot_mapping,  // [num_tokens]
+//     int64_t quant_block_size,     // quantization block size
+//     const std::string& scale_fmt) {
+
+// __nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void indexer_k_quant_and_cache_kernel(
     const scalar_t* __restrict__ k,          // [num_tokens, head_dim]
@@ -513,11 +565,22 @@ __global__ void indexer_k_quant_and_cache_kernel(
     const int head_dim,          // dimension of each head
     const int quant_block_size,  // quantization block size
     const int cache_block_size,  // cache block size
-    const int cache_stride,      // stride for each token in kv_cache
+    const int cache_stride,  // stride for each token in kv_cache, 132
     const bool use_ue8m0         // use ue8m0 scale format
 ) {
   constexpr int VEC_SIZE = 4;
   const int64_t token_idx = blockIdx.x;
+  // For head_dim = 128, gridDim.y = 1, so blockIdx.y = 0
+  // blockDim.x = 32, blockDim.y = 4
+  // blockDim.y * blockDim.x = 128
+  // head_dim_idx = (blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x) * 4
+  // (threadIdx.y, threadIdx.x): elements handled
+  // (0, 0): 0, 1, 2, 3
+  // (0, 1): 4, 5, 6, 7
+  // (0, 2): 8, 9, 10, 11
+  // (0, 3): 12, 13, 14, 15
+  // ...
+  // (0, 31): 124, 125, 126, 127
+  // (1, 0) and above map to >= 128, i.e. >= head_dim, and return early
   const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
                                 threadIdx.y * blockDim.x + threadIdx.x) *
                                VEC_SIZE;
@@ -525,6 +588,7 @@ __global__ void indexer_k_quant_and_cache_kernel(
   const int64_t block_idx = slot_idx / cache_block_size;
   const int64_t block_offset = slot_idx % cache_block_size;
 
+  // For head_dim = 128, only threads 0..31 (threadIdx.y == 0) are active
   // NOTE: slot_idx can be -1 if the token is padded
   if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
     return;
@@ -572,6 +636,9 @@ __global__ void indexer_k_quant_and_cache_kernel(
   }
 }
 
+
+
+
 template
 __global__ void cp_gather_indexer_k_quant_cache_kernel(
     const char* __restrict__ kv_cache,  // [num_blocks, block_size,
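A minimal Python sketch of the index math annotated above, assuming head_dim = 128, quant_block_size = 128 and block = (32, 4) (so gridDim.y = 1); the constant names below are illustrative only. It checks that only the 32 threads with threadIdx.y == 0 cover elements 0..127:

VEC_SIZE = 4
BLOCK_DIM_X, BLOCK_DIM_Y = 32, 4  # dim3 block(32, vec_size)
head_dim, quant_block_size = 128, 128

# grid.y from the launcher comment above
grid_y = (head_dim + quant_block_size * VEC_SIZE - 1) // (quant_block_size * VEC_SIZE)
assert grid_y == 1

active = []
for ty in range(BLOCK_DIM_Y):
    for tx in range(BLOCK_DIM_X):
        # head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x + ty * blockDim.x + tx) * VEC_SIZE
        head_dim_idx = (0 * BLOCK_DIM_Y * BLOCK_DIM_X + ty * BLOCK_DIM_X + tx) * VEC_SIZE
        if head_dim_idx < head_dim:
            active.append((ty, tx, head_dim_idx))

# 32 active threads, all with threadIdx.y == 0, covering 0..127 in steps of 4
assert len(active) == 32 and all(ty == 0 for ty, _, _ in active)
assert [idx for _, _, idx in active] == list(range(0, head_dim, VEC_SIZE))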
diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py
index f4b4fac84015..3ea1d3622504 100644
--- a/tests/kernels/attention/test_deepgemm_attention.py
+++ b/tests/kernels/attention/test_deepgemm_attention.py
@@ -63,44 +63,57 @@ def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
 
 
 def _ref_fp8_mqa_logits(
-    q: torch.Tensor,
-    kv: torch.Tensor,
-    weights: torch.Tensor,
-    cu_seqlen_ks: torch.Tensor,
-    cu_seqlen_ke: torch.Tensor,
+    q: torch.Tensor,  # [seq_len, heads, head_dim]
+    kv: torch.Tensor,  # [seq_len_kv, head_dim]
+    weights: torch.Tensor,  # [seq_len, heads]
+    cu_seqlen_ks: torch.Tensor,  # [seq_len]
+    cu_seqlen_ke: torch.Tensor,  # [seq_len]
 ):
+    breakpoint()
     seq_len_kv = kv.shape[0]
 
     k = kv
     q = q.float()
     k = k.float()
-
+    # mask_lo: [seq_len, seq_len_kv]
+    # mask_hi: [seq_len, seq_len_kv]
     mask_lo = (
         torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
     )
     mask_hi = (
         torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
+    # mask: [seq_len, seq_len_kv]
     mask = mask_lo & mask_hi
 
+    # q: [seq_len, heads, head_dim] -> [heads, seq_len, head_dim]
+    # k: [seq_len_kv, head_dim] -> [head_dim, seq_len_kv]
+    # [heads, seq_len, head_dim] @ [head_dim, seq_len_kv] -> [heads, seq_len, seq_len_kv]
     score = torch.einsum("mhd,nd->hmn", q, k)
-    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
+    # [seq_len, heads] -> [seq_len, heads, 1] -> [heads, seq_len, 1]
+    cur_weight = weights.unsqueeze(-1).transpose(0, 1)
+    # [heads, seq_len, seq_len_kv] * [heads, seq_len, 1] -> [heads, seq_len, seq_len_kv]
+    # sum over heads: [heads, seq_len, seq_len_kv] -> [seq_len, seq_len_kv]
+    logits = (score.relu() * cur_weight).sum(dim=0)
     logits = logits.masked_fill(~mask, float("-inf"))
 
     return logits
 
 
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
-@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(
-    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
-)
+# @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
+# @pytest.mark.skipif(
+#     not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+# )
 def test_deepgemm_fp8_mqa_logits():
     torch.manual_seed(0)
     random.seed(0)
     num_heads, head_dim = 32, 128
     for seq_len in (512,):
         for seq_len_kv in (1024,):
-            for disable_cp in (False, True):
+            for disable_cp in (
+                False,
+                # True,
+            ):
                 q = torch.randn(
                     seq_len,
                     num_heads,
@@ -125,7 +138,6 @@ def test_deepgemm_fp8_mqa_logits():
 
             q_fp8 = q.to(torch.float8_e4m3fn)
             kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
-            logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
 
             ref_logits = _ref_fp8_mqa_logits(
                 q=q,
@@ -134,6 +146,8 @@
                 cu_seqlen_ks=ks,
                 cu_seqlen_ke=ke,
             )
+            breakpoint()
+            logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
 
             ref_neginf_mask = ref_logits == float("-inf")
             neginf_mask = logits == float("-inf")
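A minimal CPU-only sketch of the same masked, head-weighted ReLU logits computation as _ref_fp8_mqa_logits above; the toy shapes and the name toy_mqa_logits are assumptions for illustration, not the kernel under test:

import torch

def toy_mqa_logits(q, kv, weights, ks, ke):
    # score[h, m, n] = <q[m, h, :], kv[n, :]>
    score = torch.einsum("mhd,nd->hmn", q.float(), kv.float())
    # ReLU, weight per (head, query), then sum over heads -> [seq_len, seq_len_kv]
    logits = (score.relu() * weights.t().unsqueeze(-1)).sum(dim=0)
    # Query m may only attend to kv positions in [ks[m], ke[m])
    pos = torch.arange(kv.shape[0])
    mask = (pos[None, :] >= ks[:, None]) & (pos[None, :] < ke[:, None])
    return logits.masked_fill(~mask, float("-inf"))

q = torch.randn(2, 2, 3)        # [seq_len, heads, head_dim]
kv = torch.randn(4, 3)          # [seq_len_kv, head_dim]
weights = torch.rand(2, 2)      # [seq_len, heads]
ks = torch.tensor([0, 1])
ke = torch.tensor([2, 4])
print(toy_mqa_logits(q, kv, weights, ks, ke))  # [2, 4], -inf outside each window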
@@ -146,15 +160,40 @@
 def _ref_fp8_paged_mqa_logits(
-    q: torch.Tensor,
-    kv_cache: torch.Tensor,
-    weights: torch.Tensor,
-    context_lens: torch.Tensor,
-    block_tables: torch.Tensor,
-    max_model_len: int,
+    q: torch.Tensor,  # [batch_size, next_n, heads, head_dim], e.g. [batch_size, 1, heads, 128]
+    kv_cache: torch.Tensor,  # [num_blocks, block_size, 1, head_dim], e.g. [num_blocks, 64, 1, 128]
+    weights: torch.Tensor,  # [batch_size * next_n, heads], e.g. [batch_size * 1, heads]
+    context_lens: torch.Tensor,  # [batch_size]
+    block_tables: torch.Tensor,  # [batch_size, max_block_len]
+    max_model_len: int,
 ):
+    """
+    This is a reference implementation of the fp8_paged_mqa_logits function.
+    The pseudo code is as follows:
+        # For each sequence in the batch, update its logits
+        for batch_idx in range(batch_size):
+            cur_context_len = context_lens[batch_idx].item()
+            block_nums = (cur_context_len + block_size - 1) // block_size
+            cur_query = q[batch_idx]  # [1, heads, head_dim]
+            # For each block in the current sequence, compute the attention scores
+            for block_rk in range(block_nums):
+                block_idx = block_tables[batch_idx][block_rk]
+                cur_key_in_block = kv_cache[block_idx]  # [block_size, 1, head_dim]
+                # [heads, 1, head_dim] @ [1, head_dim, block_size] -> [heads, 1, block_size]
+                cur_attn_score_temp = cur_query.transpose(0, 1) @ cur_key_in_block.transpose(0, 1).transpose(1, 2)
+                # Apply ReLU
+                act_cur_attn_score_temp = torch.relu(cur_attn_score_temp.to(torch.float32))
+                # Apply the per-head weights
+                act_cur_attn_score_temp_weighted = act_cur_attn_score_temp * weights[batch_idx : batch_idx + 1].transpose(0, 1)[..., None]
+                # Sum over all heads to get the final logit for each key position in the block
+                cur_attn_score = act_cur_attn_score_temp_weighted.sum(dim=0)  # [1, block_size]
+                # Update the logits for the current block (this formula ignores the mask)
+                logits[batch_idx][block_rk*block_size:(block_rk+1)*block_size] = cur_attn_score
+    """
+
     batch_size, next_n, _, _ = q.size()
     _, block_size, _, _ = kv_cache.size()
+    # [batch_size * next_n, max_model_len]
     logits = torch.full(
         [batch_size * next_n, max_model_len],
         float("-inf"),
@@ -163,49 +202,91 @@ def _ref_fp8_paged_mqa_logits(
     )
     context_lens_list = context_lens.tolist()
     for i in range(batch_size):
+        # breakpoint()
         context_len = context_lens_list[i]
         q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
+        # weights: [batch_size * next_n, heads] -> [next_n, heads] -> [heads, next_n]
+        # weight_slice: [heads, 1] when next_n == 1
         weight_slice = (
             weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
         )
+        # For the current batch item, iterate over all of its blocks:
+        # there are ceil(context_len / block_size) blocks.
+        # For example, if context_len=2447 and block_size=64, there are 39 blocks.
         for block_rk in range(cdiv(context_len, block_size)):
             block_idx = block_tables[i][block_rk]
+            # qx: [1, heads, head_dim]
+            # kx: [block_size, 1, head_dim]
             qx, kx = q[i], kv_cache[block_idx]
+            # k_offsets: [block_size]
             k_offsets = torch.arange(
                 block_rk * block_size,
                 (block_rk + 1) * block_size,
                 device="cuda",
             )
+            # mask: [1, block_size]
             mask = (k_offsets[None, :] < context_len) & (
                 k_offsets[None, :] <= q_offsets[:, None]
             )
+            # qx: [1, heads, head_dim] -> [heads, 1, head_dim]
+            qx_temp = qx.transpose(0, 1).contiguous()
+            # kx: [block_size, 1, head_dim] -> [1, block_size, head_dim] -> [1, head_dim, block_size]
+            kx_temp = kx.transpose(0, 1).transpose(1, 2).contiguous()
+            # [heads, 1, head_dim] @ [1, head_dim, block_size] -> [heads, 1, block_size]
+            # For the given query, compute attention scores with all keys in the block
+            temp_s = (qx_temp @ kx_temp).to(torch.float32)  # [heads, 1, block_size]
             s = torch.where(
                 mask[None, :, :],
-                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
-                    logits.dtype
-                ),
+                temp_s,
                 float("-inf"),
             )
-            s = torch.relu(s) * weight_slice[..., None]
+            # s = torch.where(
+            #     mask[None, :, :],
+            #     (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
+            #         logits.dtype
+            #     ),
+            #     float("-inf"),
+            # )
+            #
+            # cur_weight: [heads, 1, 1]
+            cur_weight = weight_slice[..., None]
+            # [heads, 1, block_size] * [heads, 1, 1] -> [heads, 1, block_size]
+            s = torch.relu(s) * cur_weight
+            # Sum the scores over all heads
+            # [heads, 1, block_size] -> [1, block_size]
             s = s.sum(dim=0)
+            # Update the logits for the current batch item and current block.
+            # k_offsets: the token positions of the keys in the current block
+            # q_offsets: the token positions of the next_n queries of this sequence
+            cur_batch_start = i * next_n
+            cur_batch_end = (i + 1) * next_n
+            cur_block_start = block_rk * block_size
+            cur_block_end = (block_rk + 1) * block_size
+            print(f"Update logits for batch [{cur_batch_start}, {cur_batch_end}), block [{cur_block_start}, {cur_block_end})")
+            if cur_block_end == 2496:
+                breakpoint()
+            cur_logits = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
             logits[
                 i * next_n : (i + 1) * next_n,
                 block_rk * block_size : (block_rk + 1) * block_size,
-            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
+            ] = cur_logits
     return logits
 
 
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
-@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
-@pytest.mark.skipif(
-    not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
-)
+# @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
+# @pytest.mark.skipif(
+#     not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+# )
 def test_deepgemm_fp8_paged_mqa_logits():
     torch.manual_seed(0)
     random.seed(0)
     max_model_len = 4096
-    for batch_size, next_n in [(4, 1), (2, 2)]:
+    for batch_size, next_n in [
+        (4, 1),
+        (2, 2)
+    ]:
         for heads, index_dim in [(32, 128)]:
             for avg_kv in (2048,):
                 num_blocks, blocksize = max_model_len * 2, 64
@@ -252,9 +333,22 @@ def test_deepgemm_fp8_paged_mqa_logits():
 
                 q_fp8 = q.to(torch.float8_e4m3fn)
                 kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
+
+                ref_logits = _ref_fp8_paged_mqa_logits(
+                    q,
+                    kv_cache,
+                    weights,
+                    context_lens,
+                    block_tables,
+                    max_model_len,
+                )
+
                 schedule_metadata = get_paged_mqa_logits_metadata(
                     context_lens, blocksize, get_num_sms()
                 )
+
+
+
                 logits = fp8_paged_mqa_logits(
                     q_fp8,
                     kv_cache_fp8,
                     weights,
                     context_lens,
                     block_tables,
                     schedule_metadata,
                     max_model_len,
                 )
-
-                ref_logits = _ref_fp8_paged_mqa_logits(
-                    q,
-                    kv_cache,
-                    weights,
-                    context_lens,
-                    block_tables,
-                    max_model_len,
-                )
-
                 positions = (
                     torch.arange(max_model_len, device="cuda")
                     .unsqueeze(0)
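A toy CPU-only sketch of the per-block update walked through in the _ref_fp8_paged_mqa_logits pseudo code above; the shapes, the name toy_paged_mqa_logits and the tiny block/table sizes are assumptions, and fp8 quantization is omitted:

import torch

def toy_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len):
    batch_size, next_n, heads, head_dim = q.shape
    num_blocks, block_size, _, _ = kv_cache.shape
    logits = torch.full((batch_size * next_n, max_model_len), float("-inf"))
    for b in range(batch_size):
        ctx = int(context_lens[b])
        q_pos = torch.arange(ctx - next_n, ctx)          # query token positions
        w = weights[b * next_n:(b + 1) * next_n].t()     # [heads, next_n]
        for blk in range((ctx + block_size - 1) // block_size):
            k = kv_cache[block_tables[b, blk]]           # [block_size, 1, head_dim]
            k_pos = torch.arange(blk * block_size, (blk + 1) * block_size)
            # [heads, next_n, head_dim] @ [1, head_dim, block_size] -> [heads, next_n, block_size]
            s = q[b].transpose(0, 1) @ k.permute(1, 2, 0)
            s = torch.relu(s) * w[..., None]             # weight each head's scores
            s = s.sum(dim=0)                             # [next_n, block_size]
            # in-context and causal mask on key positions
            mask = (k_pos[None, :] < ctx) & (k_pos[None, :] <= q_pos[:, None])
            logits[b * next_n:(b + 1) * next_n, blk * block_size:(blk + 1) * block_size] = \
                torch.where(mask, s, float("-inf"))
    return logits

q = torch.randn(2, 1, 4, 8)                # [batch, next_n=1, heads, head_dim]
kv_cache = torch.randn(6, 16, 1, 8)        # [num_blocks, block_size, 1, head_dim]
weights = torch.rand(2, 4)                 # [batch * next_n, heads]
context_lens = torch.tensor([20, 33])
block_tables = torch.tensor([[0, 1, 2], [3, 4, 5]])
print(toy_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, 48).shape)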
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index d93136701014..8786821c4af1 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -582,6 +582,16 @@ def sparse_attn_indexer(
     total_seq_lens: int,
     topk_indices_buffer: Optional[torch.Tensor],
 ) -> torch.Tensor:
+
+    """
+    1. Quantize k and store it in kv_cache
+    2. Compute the attention logits from the quantized k in kv_cache and q_fp8:
+       for the prefill stage, the logits are [num_tokens, num_kv_tokens];
+       for the decode stage, the logits are [num_decode_tokens, max_model_len]
+    3. Select the top-k tokens from the attention logits for each query position
+       ([num_tokens, 2048]) and update topk_indices_buffer
+    """
+
+    # careful! this will be None in dummy run
     attn_metadata = get_forward_context().attn_metadata
     # assert isinstance(attn_metadata, dict)
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 93f7fd5725bd..95545ef11ad7 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -313,6 +313,13 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         if not self.scheduler.has_requests():
             return {}, False
         scheduler_output = self.scheduler.schedule()
+        logger.info((
+            f"scheduler_output.scheduled_new_reqs: {[(req.req_id, len(req.prompt_token_ids)) for req in scheduler_output.scheduled_new_reqs]} \n"
+            f"scheduler_output.scheduled_cached_reqs: {scheduler_output.scheduled_cached_reqs} \n"
+            f"scheduler_output.num_scheduled_tokens: {scheduler_output.num_scheduled_tokens} \n"
+            f"scheduler_output.total_num_scheduled_tokens: {scheduler_output.total_num_scheduled_tokens} \n"
+            f"scheduler_output.finished_req_ids: {scheduler_output.finished_req_ids} \n"
+        ))
         model_output = self.execute_model_with_error_logging(
             self.model_executor.execute_model,  # type: ignore
             scheduler_output,
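A minimal sketch of step 3 of the sparse_attn_indexer docstring above (top-k selection into topk_indices_buffer); fill_topk_indices, the toy shapes and the -1 padding convention are assumptions for illustration, not the actual vLLM implementation:

import torch

def fill_topk_indices(logits: torch.Tensor, topk: int) -> torch.Tensor:
    # logits: [num_tokens, num_kv_tokens], with -inf at positions a query may not see
    topk = min(topk, logits.shape[-1])
    values, indices = logits.topk(topk, dim=-1)
    # Entries whose logit is -inf were never visible to this query; mark them as -1
    indices = indices.masked_fill(values == float("-inf"), -1)
    return indices

logits = torch.tensor([[0.3, float("-inf"), 1.2], [0.1, 0.5, float("-inf")]])
print(fill_topk_indices(logits, topk=2))
# tensor([[2, 0],
#         [1, 0]])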