77 changes: 72 additions & 5 deletions csrc/cache_kernels.cu
@@ -454,9 +454,9 @@ __global__ void concat_and_cache_ds_mla_kernel(
}

// The first two warps handle the NoPE part
const int8_t warp_idx = threadIdx.x >> 5;
const int8_t lane_idx = threadIdx.x & 31;
const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);
const int8_t warp_idx = threadIdx.x >> 5; // warp within the block: 0 or 1
const int8_t lane_idx = threadIdx.x & 31; // lane within the warp: 0..31
const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4); // one tile per half-warp: 0..3

// Each thread handles 8 elements of NoPE
// Load the NoPE elements for this thread into registers
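The warp/lane/tile arithmetic above can be sanity-checked with a short standalone sketch (Python). It assumes kv_lora_rank = 512, consistent with the "128 + 0..3" scale-offset note in the next hunk, and 8 NoPE elements per thread as stated in the comments; everything else is illustrative.

from collections import Counter

def nope_tile(thread_idx: int) -> tuple[int, int, int]:
    warp_idx = thread_idx >> 5                 # warp within the block: 0 or 1
    lane_idx = thread_idx & 31                 # lane within the warp: 0..31
    tile_idx = warp_idx * 2 + (lane_idx >> 4)  # one tile per half-warp: 0..3
    return warp_idx, lane_idx, tile_idx

# The first two warps are threads 0..63; each 16-lane half-warp owns one tile.
tiles = Counter(nope_tile(t)[2] for t in range(64))
assert tiles == {0: 16, 1: 16, 2: 16, 3: 16}
# 4 tiles * 16 lanes * 8 elements per thread = 512 NoPE elements (assumed kv_lora_rank).
assert 4 * 16 * 8 == 512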
@@ -484,7 +484,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
// The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) {
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx; // float index 128..131 (kv_lora_rank = 512)
kv_cache_32bit[dst_idx] = tile_scale;
}

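Reinterpreting the token's cache line as float32 for the scale write means the four tile scales land immediately after the quantized NoPE bytes. A byte-offset sketch (Python), assuming kv_lora_rank = 512 and a 1-byte fp8 cache element type; the exact layout of what follows the scales is not shown in this hunk.

KV_LORA_RANK = 512          # assumed NoPE width, matching the "128 + 0..3" comment
FP8_BYTES, FLOAT_BYTES = 1, 4

nope_bytes = KV_LORA_RANK * FP8_BYTES          # bytes 0..511: quantized NoPE values
for tile_idx in range(4):
    dst_idx = KV_LORA_RANK // 4 + tile_idx     # float index 128..131, as in the kernel
    scale_byte_offset = dst_idx * FLOAT_BYTES  # 512, 516, 520, 524
    assert scale_byte_offset == nope_bytes + tile_idx * FLOAT_BYTES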
@@ -505,6 +505,58 @@ __global__ void concat_and_cache_ds_mla_kernel(
*reinterpret_cast<const uint64_t*>(result);
}


// void indexer_k_quant_and_cache(
// torch::Tensor& k, // [num_tokens, head_dim]
// torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
// torch::Tensor& slot_mapping, // [num_tokens]
// int64_t quant_block_size, // quantization block size
// const std::string& scale_fmt) {
// int num_tokens = k.size(0);
// int head_dim = k.size(1);
// int cache_block_size = kv_cache.size(1);
// int cache_stride = kv_cache.size(2);
// bool use_ue8m0 = scale_fmt == "ue8m0";

// TORCH_CHECK(k.device() == kv_cache.device(),
// "k and kv_cache must be on the same device");
// TORCH_CHECK(k.device() == slot_mapping.device(),
// "k and slot_mapping must be on the same device");
// TORCH_CHECK(head_dim % quant_block_size == 0,
// "head_dim must be divisible by quant_block_size");

// constexpr int vec_size = 4;
// Example: head_dim = 128, quant_block_size = 128, vec_size = 4
// => grid.y = (128 + 128 * 4 - 1) / (128 * 4) = 1
// (worked through in the sketch after the dispatch macro below)
// dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
// (quant_block_size * vec_size));
// dim3 block(32, vec_size);
// const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
// CALL_INDEXER_K_QUANT_AND_CACHE);

// // Macro to dispatch the kernel based on the data type.
// #define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
// vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
// <<<grid, block, 0, stream>>>( \
// reinterpret_cast<KV_T*>(k.data_ptr()), \
// reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
// slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
// cache_block_size, cache_stride, use_ue8m0);
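A standalone sketch (Python) of the launch geometry worked out in the notes above, for the head_dim = 128, quant_block_size = 128, vec_size = 4 case. The cache_stride = 132 figure used later is consistent with head_dim fp8 bytes plus one 4-byte scale per quant block; that layout is an assumption here, not something spelled out in this hunk.

head_dim, quant_block_size, vec_size = 128, 128, 4
block_dim_x, block_dim_y = 32, vec_size  # dim3 block(32, vec_size)
grid_y = (head_dim + quant_block_size * vec_size - 1) // (quant_block_size * vec_size)
assert grid_y == 1  # a single y-block covers the whole head_dim

active = []
for block_idx_y in range(grid_y):
    for ty in range(block_dim_y):
        for tx in range(block_dim_x):
            head_dim_idx = (block_idx_y * block_dim_y * block_dim_x
                            + ty * block_dim_x + tx) * vec_size
            if head_dim_idx < head_dim:  # threads past head_dim return early
                active.append((ty, tx, head_dim_idx))

# Only the 32 threads with threadIdx.y == 0 are active; each covers 4 elements.
assert len(active) == 32 and all(ty == 0 for ty, _, _ in active)
assert active[:3] == [(0, 0, 0), (0, 1, 4), (0, 2, 8)]

# Assumed per-token layout: head_dim fp8 bytes + one float32 scale per quant block.
scales_per_token = head_dim // quant_block_size   # 1
cache_stride = head_dim + scales_per_token * 4    # 128 + 4 = 132
assert cache_stride == 132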



// e.g. instantiated as <__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3>
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void indexer_k_quant_and_cache_kernel(
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
@@ -513,18 +565,30 @@ __global__ void indexer_k_quant_and_cache_kernel(
const int head_dim, // dimension of each head
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
const int cache_stride, // stride for each token in kv_cache; 132 here = 128 fp8 bytes + 4 scale bytes
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
// For head_dim = 128 there is a single y-block, so blockIdx.y = 0.
// blockDim.x = 32 and blockDim.y = 4, so blockDim.y * blockDim.x = 128 and
// head_dim_idx = (blockIdx.y * 128 + threadIdx.y * 32 + threadIdx.x) * 4.
// (threadIdx.y, threadIdx.x) -> elements handled:
// (0, 0): 0, 1, 2, 3
// (0, 1): 4, 5, 6, 7
// (0, 2): 8, 9, 10, 11
// ...
// (0, 31): 124, 125, 126, 127
// (1, 0): 128..131 -> past head_dim = 128, so the thread exits early
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) *
VEC_SIZE;
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / cache_block_size;
const int64_t block_offset = slot_idx % cache_block_size;

// For head_dim = 128, only the 32 threads with threadIdx.y == 0 are active
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
return;
Expand Down Expand Up @@ -572,6 +636,9 @@ __global__ void indexer_k_quant_and_cache_kernel(
}
}




template <int BLOCK_Y_SIZE>
__global__ void cp_gather_indexer_k_quant_cache_kernel(
const char* __restrict__ kv_cache, // [num_blocks, block_size,

162 changes: 123 additions & 39 deletions tests/kernels/attention/test_deepgemm_attention.py
@@ -63,44 +63,57 @@ def _generate_cp_test_data(seq_len: int, seq_len_kv: int):


def _ref_fp8_mqa_logits(
q: torch.Tensor,
kv: torch.Tensor,
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
q: torch.Tensor, # [seq_len, heads, head_dim]
kv: torch.Tensor, # [seq_len_kv, head_dim]
weights: torch.Tensor, # [seq_len, heads]
cu_seqlen_ks: torch.Tensor, # [seq_len]
cu_seqlen_ke: torch.Tensor, # [seq_len]
):
seq_len_kv = kv.shape[0]

k = kv
q = q.float()
k = k.float()

# mask_lo: [seq_len, seq_len_kv]
# mask_hi: [seq_len, seq_len_kv]
mask_lo = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
)
mask_hi = (
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
)
# mask: [seq_len, seq_len_kv]
mask = mask_lo & mask_hi
# q: [seq_len, heads, head_dim] -> [heads, seq_len, head_dim]
# k: [seq_len_kv, head_dim] -> [head_dim, seq_len_kv]
# [heads, seq_len, head_dim] @ [head_dim, seq_len_kv] -> [heads, seq_len, seq_len_kv]
score = torch.einsum("mhd,nd->hmn", q, k)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
# [seq_len, heads] -> [seq_len, heads, 1] -> [heads, seq_len, 1]
cur_weight = weights.unsqueeze(-1).transpose(0, 1)
# [heads, seq_len, seq_len_kv] * [heads, seq_len, 1] -> [heads, seq_len, seq_len_kv]
# [heads, seq_len, seq_len_kv] -> [seq_len, seq_len_kv]
logits = (score.relu() * cur_weight).sum(dim=0)
logits = logits.masked_fill(~mask, float("-inf"))

return logits

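A tiny CPU check of the masked, weighted-ReLU reduction that _ref_fp8_mqa_logits performs. The shapes (seq_len = 3, seq_len_kv = 5, heads = 2, head_dim = 4) and the ks/ke ranges are made up purely to spell the semantics out; it only verifies that the vectorized form matches an explicit loop.

import torch

seq_len, seq_len_kv, heads, head_dim = 3, 5, 2, 4
q = torch.randn(seq_len, heads, head_dim)
k = torch.randn(seq_len_kv, head_dim)
w = torch.randn(seq_len, heads)
ks = torch.tensor([0, 1, 2])  # per-query start of the visible kv range
ke = torch.tensor([3, 4, 5])  # per-query end (exclusive) of the visible kv range

score = torch.einsum("mhd,nd->hmn", q, k)                 # [heads, seq_len, seq_len_kv]
logits = (score.relu() * w.t().unsqueeze(-1)).sum(dim=0)  # [seq_len, seq_len_kv]
j = torch.arange(seq_len_kv)
mask = (j[None, :] >= ks[:, None]) & (j[None, :] < ke[:, None])
logits = logits.masked_fill(~mask, float("-inf"))

# The same reduction written as an explicit loop over queries and visible keys.
expected = torch.full((seq_len, seq_len_kv), float("-inf"))
for m in range(seq_len):
    for n in range(int(ks[m]), int(ke[m])):
        expected[m, n] = sum(
            torch.relu(q[m, h] @ k[n]) * w[m, h] for h in range(heads)
        )
torch.testing.assert_close(logits, expected)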

@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
# @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
# @pytest.mark.skipif(
# not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
# )
def test_deepgemm_fp8_mqa_logits():
torch.manual_seed(0)
random.seed(0)
num_heads, head_dim = 32, 128
for seq_len in (512,):
for seq_len_kv in (1024,):
for disable_cp in (False, True):
for disable_cp in (
False,
# True,
):
q = torch.randn(
seq_len,
num_heads,
@@ -125,7 +138,6 @@ def test_deepgemm_fp8_mqa_logits():

q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)

ref_logits = _ref_fp8_mqa_logits(
q=q,
Expand All @@ -134,6 +146,8 @@ def test_deepgemm_fp8_mqa_logits():
cu_seqlen_ks=ks,
cu_seqlen_ke=ke,
)
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)

ref_neginf_mask = ref_logits == float("-inf")
neginf_mask = logits == float("-inf")
@@ -146,15 +160,40 @@


def _ref_fp8_paged_mqa_logits(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
max_model_len: int,
q: torch.Tensor, # [batch_size, next_n, heads, head_dim], [batch_size, 1, heads, 128]
kv_cache: torch.Tensor, # [num_blocks, block_size, 1, head_dim], [num_blocks, 64, 1, 128]
weights: torch.Tensor, # [batch_size * next_n, heads], [batch_size * 1, heads]
context_lens: torch.Tensor, # [batch_size]
block_tables: torch.Tensor, # [batch_size, max_block_len]
max_model_len: int,
):
"""
This is a reference implementation of the fp8_paged_mqa_logits function.
Pseudocode (for the next_n = 1 case; masking is omitted for clarity):
# For each sequence in the batch, update its logits
for batch_idx in range(batch_size):
cur_context_len = context_lens[batch_idx].item()
block_nums = (cur_context_len + block_size - 1) // block_size
cur_query = q[batch_idx] # [1, heads, head_dim]
# For each block in the current sequence, compute the attention scores
for block_rk in range(block_nums):
block_idx = block_tables[batch_idx][block_rk]
cur_key_in_block = kv_cache[block_idx] # [block_size, 1, head_dim]
# [heads, 1, head_dim] @ [1, head_dim, block_size] -> [heads, 1, block_size]
cur_attn_score_temp = cur_query.transpose(0, 1) @ cur_key_in_block.transpose(0, 1).transpose(1, 2)
# Apply ReLU
act_cur_attn_score_temp = torch.relu(cur_attn_score_temp.to(torch.float32))
# Apply the per-head weights (weights[batch_idx] has shape [heads])
act_cur_attn_score_temp_weighted = act_cur_attn_score_temp * weights[batch_idx][:, None, None]
# Sum over all heads to get the final logits for the each key position in the block
cur_attn_score = act_cur_attn_score_temp_weighted.sum(dim=0) # [1, block_size]
# Update the logits for the current block (mask handling omitted)
logits[batch_idx][block_rk*block_size:(block_rk+1)*block_size] = cur_attn_score
"""

batch_size, next_n, _, _ = q.size()
_, block_size, _, _ = kv_cache.size()
# [batch_size * next_n, max_model_len]
logits = torch.full(
[batch_size * next_n, max_model_len],
float("-inf"),
@@ -163,49 +202,91 @@ def _ref_fp8_paged_mqa_logits(
)
context_lens_list = context_lens.tolist()
for i in range(batch_size):
context_len = context_lens_list[i]
q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
# weights: [batch_size * next_n, heads] -> [next_n, heads] -> [heads, next_n]
# weight_slice: [heads, next_n]
weight_slice = (
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
)
# For the current batch item, iterate over its ceil(context_len / block_size) KV blocks.
# For example, context_len = 2447 with block_size = 64 gives 39 blocks.
for block_rk in range(cdiv(context_len, block_size)):
block_idx = block_tables[i][block_rk]
# qx: [next_n, heads, head_dim]
# kx: [block_size, 1, head_dim]
qx, kx = q[i], kv_cache[block_idx]
# k_offsets: [block_size]
k_offsets = torch.arange(
block_rk * block_size,
(block_rk + 1) * block_size,
device="cuda",
)
# mask: [next_n, block_size]
mask = (k_offsets[None, :] < context_len) & (
k_offsets[None, :] <= q_offsets[:, None]
)
# qx: [next_n, heads, head_dim] -> [heads, next_n, head_dim]
qx_temp = qx.transpose(0, 1).contiguous()
# kx: [block_size, 1, head_dim] -> [1, block_size, head_dim] -> [1, head_dim, block_size]
kx_temp = kx.transpose(0, 1).transpose(1, 2).contiguous()
# [heads, next_n, head_dim] @ [1, head_dim, block_size] -> [heads, next_n, block_size]
# For each query, compute attention scores against every key in the block
temp_s = (qx_temp @ kx_temp).to(torch.float32) # [heads, next_n, block_size]
s = torch.where(
mask[None, :, :],
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
logits.dtype
),
temp_s,
float("-inf"),
)
s = torch.relu(s) * weight_slice[..., None]
# s = torch.where(
# mask[None, :, :],
# (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
# logits.dtype
# ),
# float("-inf"),
# )
#
# cur_weight: [heads, next_n, 1]
cur_weight = weight_slice[..., None]
# [heads, next_n, block_size] * [heads, next_n, 1] -> [heads, next_n, block_size]
s = torch.relu(s) * cur_weight
# Sum the scores over all heads:
# [heads, next_n, block_size] -> [next_n, block_size]
s = s.sum(dim=0)
# Update the logits for the current batch item and current block
# k_offsets: token positions of the keys in the current block
# q_offsets: token positions of the next_n queries of this sequence
cur_logits = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
logits[
i * next_n : (i + 1) * next_n,
block_rk * block_size : (block_rk + 1) * block_size,
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
] = cur_logits
return logits

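The paged lookup that both the docstring pseudocode and the loop above rely on is plain integer division: key position t of sequence i lives in cache block block_tables[i][t // block_size] at row t % block_size, so block block_rk covers key positions [block_rk * block_size, (block_rk + 1) * block_size). A small Python sketch with the same made-up numbers as the comment above (block_size = 64, context_len = 2447):

block_size, context_len = 64, 2447
num_blocks_needed = (context_len + block_size - 1) // block_size  # cdiv -> 39

def paged_location(t: int) -> tuple[int, int]:
    """Map key position t to (logical block index, offset inside the block)."""
    return t // block_size, t % block_size

assert num_blocks_needed == 39
assert paged_location(0) == (0, 0)
assert paged_location(2446) == (38, 14)  # last valid key of a 2447-token context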

@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
# @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
# @pytest.mark.skipif(
# not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
# )
def test_deepgemm_fp8_paged_mqa_logits():
torch.manual_seed(0)
random.seed(0)

max_model_len = 4096
for batch_size, next_n in [(4, 1), (2, 2)]:
for batch_size, next_n in [
(4, 1),
(2, 2),
]:
for heads, index_dim in [(32, 128)]:
for avg_kv in (2048,):
num_blocks, blocksize = max_model_len * 2, 64
Expand Down Expand Up @@ -252,9 +333,22 @@ def test_deepgemm_fp8_paged_mqa_logits():
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)


ref_logits = _ref_fp8_paged_mqa_logits(
q,
kv_cache,
weights,
context_lens,
block_tables,
max_model_len,
)

schedule_metadata = get_paged_mqa_logits_metadata(
context_lens, blocksize, get_num_sms()
)



logits = fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
Expand All @@ -264,16 +358,6 @@ def test_deepgemm_fp8_paged_mqa_logits():
schedule_metadata,
max_model_len,
)

ref_logits = _ref_fp8_paged_mqa_logits(
q,
kv_cache,
weights,
context_lens,
block_tables,
max_model_len,
)

positions = (
torch.arange(max_model_len, device="cuda")
.unsqueeze(0)