feat: Add flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache (fused RoPE + Q + KV cache, supports MLA/GQA/MHA) #2037
Conversation
Walkthrough
Adds a fused RoPE + FP8 quantize + paged-KV append flow: new unified CUDA kernel and host wrappers, new TVM FFI binding and C++ entry, Python API and op registrations, tests and benchmark, a GPU bandwidth helper, and new dispatch macros for interleave/rope-dim specialization.
Sequence Diagram(s)
sequenceDiagram
participant User as Python caller
participant PyAPI as rope_quantize_fp8_append_paged_kv_cache()
participant Validate as Validation & Dispatch
participant Op as Registered Op
participant FFI as rope_quantize_append_paged_kv_cache (FFI)
participant Kernel as RopeQuantizeAppendPagedKVCacheKernel
User->>PyAPI: call with Q/K/V, cos_sin, paged-kv metadata, flags
PyAPI->>Validate: validate shapes/dtypes, infer layout/is_neox, allocate q outputs
Validate->>Op: select MLA vs GQA/MHA, compute kv_layout_code, interleave, scales, page_size
Op->>FFI: invoke FFI with tensors, strides, layout code, page size, flags
FFI->>Kernel: construct paged_kv_t / paged_kv_mla_t and launch kernel
rect rgb(240,248,255)
Note over Kernel: apply RoPE → FP8 quantize Q/K/V → append K/V to paged cache (MLA vs GQA/MHA branch)
end
Kernel-->>FFI: return CUDA status
FFI-->>Op: return
Op-->>PyAPI: complete
PyAPI-->>User: return quantized Q outputs (K/V stored in paged caches)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
🔇 Additional comments (2)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
csrc/flashinfer_rope_binding.cu (1 hunks)
csrc/rope.cu (1 hunks)
flashinfer/rope.py (2 hunks)
include/flashinfer/pos_enc.cuh (3 hunks)
tests/attention/test_rope.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
csrc/flashinfer_rope_binding.cu (1)
csrc/rope.cu (4)
rope_quantize(271-422)
rope_quantize(271-275)
rope_quantize_append_paged_kv_cache(430-622)
rope_quantize_append_paged_kv_cache(430-438)
csrc/rope.cu (2)
flashinfer/comm/cuda_ipc.py (2)
cudaSetDevice(149-150)
cudaGetErrorString(146-147)
csrc/tvm_ffi_utils.h (1)
get_stream(272-274)
tests/attention/test_rope.py (4)
benchmarks/bench_rope_quantize_fp8.py (1)
FlashInferRotaryEmbedding(19-88)
flashinfer/page.py (2)
get_seq_lens(212-235)
get_batch_indices_positions(157-209)
flashinfer/rope.py (1)
rope_quantize_fp8_append_paged_kv_cache(1407-1670)
tests/test_helpers/rope_reference.py (1)
forward_native(194-232)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
register_custom_op(272-281)
register_custom_op(291-310)
register_fake_op(283-287)
register_fake_op(312-317)
TensorLayout(43-45)
csrc/flashinfer_rope_binding.cu (1)
rope_quantize_append_paged_kv_cache(48-54)
csrc/rope.cu (2)
rope_quantize_append_paged_kv_cache(430-622)
rope_quantize_append_paged_kv_cache(430-438)
🪛 Ruff (0.14.3)
flashinfer/rope.py
1571-1571: Avoid specifying long messages outside the exception class
(TRY003)
1596-1596: Avoid specifying long messages outside the exception class
(TRY003)
1600-1600: Avoid specifying long messages outside the exception class
(TRY003)
1609-1612: Avoid specifying long messages outside the exception class
(TRY003)
1614-1617: Avoid specifying long messages outside the exception class
(TRY003)
1626-1629: Avoid specifying long messages outside the exception class
(TRY003)
1631-1633: Avoid specifying long messages outside the exception class
(TRY003)
@yzh119 @pavanimajety do you think a benchmarking script similar to https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_rope_quantize_fp8.py is needed here? AFAIK there is not really a direct pytorch equivalent of append-kv-cache. However, we could still have a benchmark comparing
Actionable comments posted: 1
♻️ Duplicate comments (1)
flashinfer/rope.py (1)
1591-1596: Missing validation: GQA/MHA requires non-None V tensor. The code validates that MLA does not have a V input (correctly raises an error if V is not None for MLA). However, there is no corresponding validation for the GQA/MHA path to ensure V is not None. If V is None for GQA/MHA, the kernel will receive an invalid tensor, likely causing a crash or incorrect results.
Add validation before line 1591:
+ # Validate V input based on architecture
+ if not is_mla and v is None:
+     raise ValueError("GQA/MHA requires V input (cannot be None)")
+
  # Handle V input for MLA (create empty dummy tensor, not used)
  if is_mla:
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/rope.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
register_custom_op(272-281)
register_custom_op(291-310)
register_fake_op(283-287)
register_fake_op(312-317)
TensorLayout(43-45)
csrc/flashinfer_rope_binding.cu (1)
rope_quantize_append_paged_kv_cache(48-54)
csrc/rope.cu (2)
rope_quantize_append_paged_kv_cache(430-622)
rope_quantize_append_paged_kv_cache(430-438)
🪛 Ruff (0.14.3)
flashinfer/rope.py
1571-1571: Avoid specifying long messages outside the exception class
(TRY003)
1596-1596: Avoid specifying long messages outside the exception class
(TRY003)
1600-1600: Avoid specifying long messages outside the exception class
(TRY003)
1609-1612: Avoid specifying long messages outside the exception class
(TRY003)
1614-1617: Avoid specifying long messages outside the exception class
(TRY003)
1626-1629: Avoid specifying long messages outside the exception class
(TRY003)
1631-1633: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/rope.py (2)
229-325: LGTM! Custom op and fake op names are now aligned. The custom op and fake op are both correctly registered under the same symbol "flashinfer::rope_quantize_append_paged_kv_cache", resolving the past review concern about FakeTensor/meta tracing. The mutates_args specification correctly includes all output tensors that will be modified in-place.
1407-1670: Well-structured function with comprehensive documentation. The overall implementation is well-organized with clear separation of concerns:
- Architecture detection and input validation
- Output tensor allocation
- Cache unpacking and validation
- Custom op dispatch
The docstring is exceptionally thorough with detailed parameter descriptions and complete working examples for both MLA and GQA/MHA use cases. The function correctly returns only the quantized Q tensors since K/V are written directly to the paged cache.
Also some other questions for reviewers:
Actionable comments posted: 2
🧹 Nitpick comments (2)
flashinfer/rope.py (2)
1639-1639: Consider moving import to module level. Importing TensorLayout inside the function works but is unconventional. If circular imports are not a concern, consider moving this to the top-level imports for consistency with Python style conventions.
1571-1571: Consider extracting long error messages to reduce TRY003 violations. Ruff flags multiple long error messages defined inline (TRY003). While functional, consider extracting repeated error message patterns to constants or helper functions to improve maintainability and comply with the style guide.
Also applies to: 1596-1596, 1609-1612, 1614-1617, 1626-1629, 1631-1633
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/rope.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
register_custom_op(272-281)
register_custom_op(291-310)
register_fake_op(283-287)
register_fake_op(312-317)
TensorLayout(43-45)
csrc/flashinfer_rope_binding.cu (1)
rope_quantize_append_paged_kv_cache(48-54)
csrc/rope.cu (2)
rope_quantize_append_paged_kv_cache(430-622)
rope_quantize_append_paged_kv_cache(430-438)
🪛 Ruff (0.14.3)
flashinfer/rope.py
1571-1571: Avoid specifying long messages outside the exception class
(TRY003)
1596-1596: Avoid specifying long messages outside the exception class
(TRY003)
1600-1600: Avoid specifying long messages outside the exception class
(TRY003)
1609-1612: Avoid specifying long messages outside the exception class
(TRY003)
1614-1617: Avoid specifying long messages outside the exception class
(TRY003)
1626-1629: Avoid specifying long messages outside the exception class
(TRY003)
1631-1633: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/rope.py (1)
1407-1670: Excellent comprehensive documentation and architecture detection. The function provides thorough documentation with clear examples for both MLA and GQA/MHA use cases. The automatic architecture detection based on tensor shapes (line 1574) is elegant, and the cache validation logic correctly distinguishes between the two architectures with appropriate dummy tensor creation for unused cache types.
also get rid of docstring examples; too long and too complicated. Maybe just reference test code instead as an example?
Actionable comments posted: 0
🧹 Nitpick comments (1)
flashinfer/rope.py (1)
1587-1617: Consider validating kv_layout for clearer error messages. At line 1590, accessing TensorLayout[kv_layout].value will raise a KeyError if kv_layout is not "NHD" or "HND". While the default value prevents most issues, an explicit validation would provide a clearer error message. Consider adding validation before line 1590:
+ # Validate kv_layout
+ if kv_layout not in ("NHD", "HND"):
+     raise ValueError(
+         f"kv_layout must be 'NHD' or 'HND', got '{kv_layout}'"
+     )
+
  # Import TensorLayout enum
  from .utils import TensorLayout
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
flashinfer/rope.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
register_custom_op(272-281)
register_custom_op(291-310)
register_fake_op(283-287)
register_fake_op(312-317)
TensorLayout(43-45)
csrc/flashinfer_rope_binding.cu (1)
rope_quantize_append_paged_kv_cache(48-54)
csrc/rope.cu (2)
rope_quantize_append_paged_kv_cache(430-622)
rope_quantize_append_paged_kv_cache(430-438)
🪛 Ruff (0.14.3)
flashinfer/rope.py
1514-1514: Avoid specifying long messages outside the exception class
(TRY003)
1539-1539: Avoid specifying long messages outside the exception class
(TRY003)
1543-1543: Avoid specifying long messages outside the exception class
(TRY003)
1552-1555: Avoid specifying long messages outside the exception class
(TRY003)
1557-1560: Avoid specifying long messages outside the exception class
(TRY003)
1570-1573: Avoid specifying long messages outside the exception class
(TRY003)
1575-1578: Avoid specifying long messages outside the exception class
(TRY003)
1580-1582: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (3)
flashinfer/rope.py (3)
229-296: LGTM! Custom operator registration is correct. The registration properly declares all mutated cache tensors and routes to the CUDA kernel. The signature matches the binding in csrc/flashinfer_rope_binding.cu.
299-325: LGTM! Fake operator registration aligns with custom op. The fake op is correctly registered under the same name as the custom op, enabling proper FakeTensor/meta tracing and torch.compile support.
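For background, the same-name requirement can be seen with PyTorch's generic torch.library API. This is an illustrative sketch only, not FlashInfer's register_custom_op/register_fake_op wrappers; the demo::scale_inplace op and its bodies are hypothetical:

```python
import torch

@torch.library.custom_op("demo::scale_inplace", mutates_args=("out",))
def scale_inplace(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    # Real implementation: write the scaled input into the preallocated output tensor.
    out.copy_(x * scale)

@scale_inplace.register_fake
def _(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    # Fake implementation: no computation; it only describes the metadata effects.
    # FakeTensor/meta tracing finds it because it is attached to the same
    # "demo::scale_inplace" symbol as the real op.
    return None
```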
1535-1586: Well-structured architecture detection and cache validation. The function correctly distinguishes MLA (2D K) from GQA/MHA (3D K) and validates that:
- MLA requires V=None and uses 3D ckv_cache/kpe_cache
- GQA/MHA requires V tensor and uses 4D k_cache/v_cache
The V validation at lines 1569-1573 properly addresses the concern from previous reviews.
/bot run
yzh119
left a comment
Overall LGTM, would you mind adding benchmark suites for this operator?
uint32_t batch_size = kv_indptr.size(0) - 1;
QKVLayout kv_layout = QKVLayout(kv_layout_code);
...
if (is_mla) {
Can we unify the two branches?
In C++ side, we assume K tensor is 3D.
At python side, if we found K tensor is 2D, we unsqueeze its dimension 1.
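A minimal sketch of that Python-side normalization, assuming k is the MLA key tensor passed to the public API (illustrative only, not the actual implementation):

```python
import torch

def normalize_mla_k(k: torch.Tensor) -> torch.Tensor:
    # MLA passes K as [num_tokens, head_dim]; add a singleton head axis so the
    # C++ entry point always sees the same 3D [num_tokens, num_heads, head_dim]
    # layout as GQA/MHA.
    if k.dim() == 2:
        k = k.unsqueeze(1)
    return k
```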
csrc/rope.cu (Outdated)
uint32_t k_rope_in_stride_h, k_nope_in_stride_h;
uint32_t v_in_stride = 0, v_in_stride_h = 0;
...
if (is_mla) {
ditto
pavanimajety
left a comment
Thanks Kahyun, LGTM overall! A few comments.
@elvischenv could you also help to review this? thanks!

[SUCCESS] Pipeline #37900083: 13/17 passed

@kahyunnam could we get some basic perf benchmarking results? Just want to make sure we are achieving reasonably good DRAM bw or not too much worse than TRT-LLM's kernel. Thanks!
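For reference, achieved bandwidth in a benchmark like this is typically derived from the bytes the kernel must read and write divided by the measured time, then compared to the device peak. A rough illustrative sketch, not the benchmark's actual code:

```python
def achieved_bandwidth_gbps(bytes_read: int, bytes_written: int, time_ms: float) -> float:
    # Total bytes moved divided by elapsed time, reported in GB/s.
    return (bytes_read + bytes_written) / (time_ms * 1e-3) / 1e9

def bandwidth_utilization_pct(achieved_gbps: float, peak_gbps: float) -> float:
    # Fraction of the theoretical peak DRAM bandwidth that the kernel reaches.
    return achieved_gbps / peak_gbps * 100.0
```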
Actionable comments posted: 2
♻️ Duplicate comments (1)
include/flashinfer/pos_enc.cuh (1)
756-769: Drop redundant IS_MLA template parameter
IS_MLA duplicates CacheT. Prefer deriving it:

-template <bool interleave, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType,
-          typename QuantType, bool IS_MLA, typename CacheT>
+template <bool interleave, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType,
+          typename QuantType, typename CacheT>
 __global__ void RopeQuantizeAppendPagedKVCacheKernel( ... ) {
+  constexpr bool IS_MLA =
+      std::is_same<CacheT, paged_kv_mla_t<QuantType, IdType>>::value;

Cleaner API and fewer instantiations to maintain.
🧹 Nitpick comments (4)
include/flashinfer/pos_enc.cuh (1)
1056-1121: Host wrapper duplication and tailoring
RopeQuantizeAppendPagedKVCache and RopeQuantizeAppendPagedMLACache largely duplicate launch setup. Consider a single templated host wrapper on CacheT with a small traits struct for MLA vs GQA (presence of V, number of KV heads, strides). Reduces surface and keeps dispatch logic in one place.
tests/attention/test_rope.py (2)
488-517: Add one tail-dimension case to catch V tail bugs
All current head_dim cases are multiples of 16. Add a case where head_dim = rope_dim + no_rope_dim isn't divisible by vec_size (e.g., rope_dim=64, no_rope_dim=48 for bf16 → 112) to exercise V tail handling.
590-658: Decode metadata setup clarity
Nice: using get_seq_lens/get_batch_indices_positions mirrors runtime flow. Consider factoring metadata build into a helper to reuse across tests.
flashinfer/rope.py (1)
1625-1656: Layout code derivation and dispatch
Using TensorLayout enum and passing code down is clean. Consider small doc note that only "NHD"/"HND" are supported.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
csrc/rope.cu (1 hunks)
flashinfer/rope.py (5 hunks)
include/flashinfer/pos_enc.cuh (4 hunks)
tests/attention/test_rope.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/rope.cu (2)
flashinfer/comm/cuda_ipc.py (2)
cudaSetDevice(149-150)
cudaGetErrorString(146-147)
csrc/tvm_ffi_utils.h (1)
get_stream(272-274)
tests/attention/test_rope.py (4)
benchmarks/bench_rope_quantize_fp8.py (1)
FlashInferRotaryEmbedding(19-88)
flashinfer/page.py (2)
get_seq_lens(212-235)
get_batch_indices_positions(157-209)
flashinfer/rope.py (1)
rope_quantize_fp8_append_paged_kv_cache(1424-1657)
tests/test_helpers/rope_reference.py (1)
forward_native(194-232)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
register_custom_op(272-281)
register_custom_op(291-310)
register_fake_op(283-287)
register_fake_op(312-317)
TensorLayout(43-45)
csrc/rope.cu (2)
rope_quantize_append_paged_kv_cache(430-617)
rope_quantize_append_paged_kv_cache(430-438)
csrc/flashinfer_rope_binding.cu (1)
rope_quantize_append_paged_kv_cache(48-54)
🪛 Ruff (0.14.4)
flashinfer/rope.py
1531-1531: Avoid specifying long messages outside the exception class
(TRY003)
1577-1577: Avoid specifying long messages outside the exception class
(TRY003)
1581-1581: Avoid specifying long messages outside the exception class
(TRY003)
1590-1593: Avoid specifying long messages outside the exception class
(TRY003)
1595-1598: Avoid specifying long messages outside the exception class
(TRY003)
1608-1611: Avoid specifying long messages outside the exception class
(TRY003)
1613-1616: Avoid specifying long messages outside the exception class
(TRY003)
1618-1620: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (11)
include/flashinfer/pos_enc.cuh (2)
388-398: Good: avoid unnecessary cos/sin loads
Gating cos/sin loads with (by < k_rope_end) matches the Q/K-RoPE blocks and skips non-RoPE paths. Looks correct.
792-809: Verify page offset computation and page_size usage
paged_kv_like.page_size.divmod(...) suggests page_size is a struct, but the expression multiplies by paged_kv_like.page_size as if it were a scalar. Please confirm page_size supports both .divmod() and scalar multiplication, or split into:
- compute base_pages = paged_kv_like.indptr[batch_indices[idx]],
- compute delta_pages, entry_idx = divmod(positions[idx], page_size),
- set page_iter = base_pages + delta_pages (see the sketch below).
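A small Python rendering of the split proposed above, assuming page_size is a plain integer; the names mirror the kernel variables, but this is only an illustrative sketch, not the CUDA code:

```python
def locate_cache_slot(indptr, batch_indices, positions, page_size, idx):
    base_pages = indptr[batch_indices[idx]]              # first page of this request
    delta_pages, entry_idx = divmod(positions[idx], page_size)
    page_iter = base_pages + delta_pages                  # page holding this token
    return page_iter, entry_idx                           # (page index, slot within the page)
```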
csrc/rope.cu (3)
489-505: Require K as 3D here: doc callout
C++ binding assumes 3D K for both paths and checks head_dim==1 for MLA. Good. Ensure Python always unsqueezes MLA K tensors to 3D before calling; you already do this in flashinfer/rope.py. Please add a brief comment here to prevent future regressions.
514-519: Cache dims: enforce 4D only (GQA/MHA)
Checks match the kernel (CHECK_DIM(4, k_cache)/v_cache). Good alignment with Python validation.
552-579: MLA path construction LGTM; minor: pass null for V explicitly
Setting v_in to an empty tensor on the Python side and validating here is fine; alternatively, you can document that v_in is unused in MLA and may be null. No action needed.
tests/attention/test_rope.py (2)
397-401: Determinism: good seeding
Setting both CPU and CUDA seeds ensures reproducible FP8 rounding paths. LGTM.
1245-1313: Cache non-regression checks in decode
Comparing pre/post snapshots at exact indices is solid. Good guard against accidental overwrites.
flashinfer/rope.py (4)
229-239: Custom/fake op names aligned
Registration now matches flashinfer::rope_quantize_append_paged_kv_cache. This unblocks FakeTensor/meta tracing.
1286-1422: Optional nope tensors: good normalization
Creating zero-width nope tensors avoids special-casing downstream and matches C++ assumptions. LGTM.
1530-1560: Quant dtype inference and Q output allocation
Infers from provided outputs or defaults to e4m3fn; allocates only when needed. Simple and correct.
1606-1620: GQA/MHA path: V required validation
Raising a clear error when v is None prevents obscure kernel failures. Good.
for (uint32_t j = 0; j < v_chunks; ++j) {
  uint32_t v_elem_offset = j * rope_chunk_size;
  if (v_elem_offset + tx * vec_size < head_dim_total) {
    vec_t<float, vec_size> v_vec;
    v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      v_vec[i] = v_vec[i] * quant_scale_kv;
    }
    QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
                                               v_elem_offset + tx * vec_size);
    v_vec.cast_store(v_ptr);
  }
}
}
🧩 Analysis chain
Potential out-of-bounds on V tail vector loads
In the V path you do vectorized cast_load with a simple bound if (v_elem_offset + tx*vec_size < head_dim_total). If head_dim_total isn’t a multiple of vec_size, the last lane may overread. Either:
- enforce head_dim_total % vec_size == 0 at dispatch and document it, or
- add a masked/scalar tail for the last partial vector.
Suggested fix (scalar tail fallback):
- if (v_elem_offset + tx * vec_size < head_dim_total) {
- vec_t<float, vec_size> v_vec;
- v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
+ uint32_t start = v_elem_offset + tx * vec_size;
+ if (start < head_dim_total) {
+ vec_t<float, vec_size> v_vec;
+ uint32_t remaining = head_dim_total - start;
+ if (remaining >= vec_size) {
+ v_vec.cast_load(v_in_ptr + start);
+ } else {
+ #pragma unroll
+ for (uint32_t t = 0; t < remaining; ++t) v_vec[t] = v_in_ptr[start + t];
+ #pragma unroll
+ for (uint32_t t = remaining; t < vec_size; ++t) v_vec[t] = 0.f;
+ }
@@
- QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
- v_elem_offset + tx * vec_size);
- v_vec.cast_store(v_ptr);
+ QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, start);
+ if (remaining >= vec_size) {
+ v_vec.cast_store(v_ptr);
+ } else {
+ #pragma unroll
+ for (uint32_t t = 0; t < remaining; ++t) v_ptr[t] = static_cast<QuantType>(v_vec[t]);
+ }
}

I'll verify this concern by examining the code context and understanding the constraints. Let me gather information about the vectorization logic, dispatch constraints, and similar patterns in the codebase.
🏁 Script executed:
#!/bin/bash
# Examine the function containing the flagged code and surrounding context
cat include/flashinfer/pos_enc.cuh | sed -n '890,950p'Length of output: 2838
🏁 Script executed:
#!/bin/bash
# Search for the vec_t template definition and cast_load semantics
rg -n "template.*vec_t" -A 10 --type=cpp --type=cuh include/Length of output: 27
🏁 Script executed:
#!/bin/bash
# Check for head_dim_total constraints or alignment requirements in dispatch/initialization
rg -n "head_dim_total" -B 3 -A 3 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -80Length of output: 27
🏁 Script executed:
#!/bin/bash
# Search for K path (kv_in_ptr or similar) to see if same pattern exists
rg -n "cast_load.*k_in_ptr|cast_load.*kv_in_ptr" -B 5 -A 5 --type=cpp --type=cuh include/flashinfer/pos_enc.cuhLength of output: 27
🏁 Script executed:
#!/bin/bash
# Look for existing partial-vector or tail handling patterns in the file
rg -n "remaining|partial|tail" -B 2 -A 2 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -60Length of output: 27
🏁 Script executed:
#!/bin/bash
# Search for rope_chunk_size and vec_size alignment patterns
rg -n "rope_chunk_size|vec_size" -B 2 -A 2 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -100Length of output: 27
🏁 Script executed:
#!/bin/bash
# Check for dispatch-time template constraints or requirements
rg -n "static_assert|rope_chunk_size.*%|vec_size.*%" --type=cpp --type=cuh include/Length of output: 27
🏁 Script executed:
#!/bin/bash
# Look at the full function signature and template parameters
rg -n "void.*rope.*page.*kv|kernel.*rope.*page.*kv" -A 30 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -120Length of output: 27
🏁 Script executed:
#!/bin/bash
# Search for similar bounds checks and how they handle partial vectors
rg -n "if.*\+ tx \* vec_size.*head_dim" -B 3 -A 10 --type=cpp --type=cuh include/flashinfer/pos_enc.cuhLength of output: 27
🏁 Script executed:
#!/bin/bash
# Check if there's documentation or comments about alignment requirements
rg -n "multiple|divisible|align|constraint" -B 1 -A 1 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -80Length of output: 27
🏁 Script executed:
#!/bin/bash
# Look for quantization and constraint information
fd ".*\.cuh$|.*\.cu$" include/flashinfer/ -x grep -l "head_dim_total" | head -5Length of output: 30
🏁 Script executed:
#!/bin/bash
# Check the K path more carefully - look for k_vec or similar load patterns
rg -n "k_vec.*cast_load|kv_vec.*cast_load" -B 3 -A 8 --type=cpp --type=cuh include/flashinfer/pos_enc.cuhLength of output: 27
🏁 Script executed:
#!/bin/bash
# Look at function header/docstring to understand parameter constraints
sed -n '800,920p' include/flashinfer/pos_enc.cuh | grep -A 100 "^template"Length of output: 0
🏁 Script executed:
#!/bin/bash
# Search for cast_load implementation to understand its behavior
rg -n "cast_load" -B 2 -A 5 --type=cpp --type=cuh include/ | head -100Length of output: 27
🏁 Script executed:
#!/bin/bash
# Check if this is a paged attention kernel and whether alignment is guaranteed
rg -n "paged.*rope\|rope.*paged" -B 5 -A 15 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -150Length of output: 27
🏁 Script executed:
#!/bin/bash
# Look for the function name context by searching around the snippet
sed -n '850,930p' include/flashinfer/pos_enc.cuhLength of output: 3473
🏁 Script executed:
#!/bin/bash
# Check for any comment or specification about head_dim constraints
grep -n "head_dim" include/flashinfer/pos_enc.cuh | head -30Length of output: 1878
🏁 Script executed:
#!/bin/bash
# Look at where rope_chunk_size is defined or constrained
rg -n "rope_chunk_size\s*=" --type=cpp --type=cuh include/flashinfer/pos_enc.cuhLength of output: 27
🏁 Script executed:
#!/bin/bash
# Check the struct/class definition for paged_kv_like to understand data layout
rg -n "paged_kv_like\|PagedAttentionKVCacheBatched" -B 5 -A 20 --type=cpp --type=cuh include/flashinfer/ | head -150Length of output: 27
🏁 Script executed:
#!/bin/bash
# Examine the vec_t template from namespace/header files
find include/flashinfer -name "*.cuh" -o -name "*.h" | xargs grep -l "struct vec_t\|template.*vec_t" | head -3Length of output: 33
🏁 Script executed:
#!/bin/bash
# Look at vec_t implementation
rg -n "struct vec_t\|template.*class vec_t" -A 50 --type=cpp --type=cuh include/flashinfer/ | head -200Length of output: 27
🏁 Script executed:
#!/bin/bash
# Check rope_chunk_size initialization relative to vec_size in the kernel launch
rg -n "RoPEApplyPagedKVWithQuant" --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -5Length of output: 27
🏁 Script executed:
#!/bin/bash
# Search entire file for function definition containing the lines in question
grep -n "kernel\|__global__\|__device__" include/flashinfer/pos_enc.cuh | grep -B 5 "91[0-9]"Length of output: 0
🏁 Script executed:
#!/bin/bash
# Find the kernel/function that contains lines 915-929
awk 'NR >= 900 && NR <= 930 { if (/^[a-zA-Z_].*[({]|__global__|__device__/) print NR": "$0 }' include/flashinfer/pos_enc.cuhLength of output: 0
🏁 Script executed:
#!/bin/bash
# Get broader function context - find the containing function
sed -n '700,930p' include/flashinfer/pos_enc.cuh | tail -230Length of output: 11018
🏁 Script executed:
#!/bin/bash
# Search for cast_load to understand whether it can safely read beyond bounds
rg -n "cast_load" include/ --type=cpp --type=cuh -A 3 -B 1 | grep -E "cast_load|^--$" | head -40Length of output: 27
I need to examine the vec_t implementation and understand the dispatch constraints. Let me continue verification.
🏁 Script executed:
#!/bin/bash
# Look at vec_t implementation
cat include/flashinfer/vec_dtypes.cuh | head -200Length of output: 5962
🏁 Script executed:
#!/bin/bash
# Search for cast_load specifically
grep -n "cast_load" include/flashinfer/vec_dtypes.cuh -A 5 -B 2 | head -80Length of output: 2381
🏁 Script executed:
#!/bin/bash
# Check dispatch for head_dim to see if alignment is enforced
grep -n "DISPATCH_HEAD_DIM\|HEAD_DIM" include/flashinfer/pos_enc.cuh -B 2 -A 5Length of output: 5141
🏁 Script executed:
#!/bin/bash
# Look for buffer allocation and padding in KV cache
fd -e cuh -e cu -e h include/flashinfer/ -x grep -l "paged_kv\|PagedKVCache" | head -5Length of output: 381
🏁 Script executed:
#!/bin/bash
# Check if rope_dim has constraints
grep -n "rope_dim" include/flashinfer/pos_enc.cuh | grep -E "constrain|assert|static|require|dispatch"Length of output: 0
🏁 Script executed:
#!/bin/bash
# Look for K nope unconditional cast_load
sed -n '875,895p' include/flashinfer/pos_enc.cuhLength of output: 864
🏁 Script executed:
#!/bin/bash
# Check cast_store in the same locations
grep -n "cast_store" include/flashinfer/pos_enc.cuh -B 1 -A 1Length of output: 2135
🏁 Script executed:
#!/bin/bash
# Look at DISPATCH_ROPE_DIM to see constraints
sed -n '715,750p' include/flashinfer/pos_enc.cuhLength of output: 2085
🏁 Script executed:
#!/bin/bash
# Check Q rope and Q nope paths for similar patterns
sed -n '795,850p' include/flashinfer/pos_enc.cuhLength of output: 2371
🏁 Script executed:
#!/bin/bash
# Look for vec_t definition structure
sed -n '1,100p' include/flashinfer/vec_dtypes.cuhLength of output: 3263
🏁 Script executed:
#!/bin/bash
# Check if there's unroll pragma impact on bounds
grep -n "pragma unroll\|#pragma" include/flashinfer/pos_enc.cuh | grep -A 1 -B 1 "91[2-9]\|92[0-5]"Length of output: 77
🏁 Script executed:
#!/bin/bash
# Search for any buffer overallocation or padding
grep -n "padding\|overallocate\|extra.*alloc\|buffer.*size" include/flashinfer/ -r --include="*.cuh" --include="*.cu" | head -20Length of output: 2607
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
benchmarks/bench_rope_quantize_fp8_append_cache.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_rope_quantize_fp8_append_cache.py (4)
flashinfer/testing/utils.py (1)
bench_gpu_time_with_cudagraph(855-969)
tests/test_helpers/rope_reference.py (1)
RotaryEmbedding(117-232)
flashinfer/page.py (2)
get_seq_lens(212-235)
get_batch_indices_positions(157-209)
flashinfer/rope.py (1)
rope_quantize_fp8_append_paged_kv_cache(1424-1657)
🪛 Ruff (0.14.4)
benchmarks/bench_rope_quantize_fp8_append_cache.py
59-59: Avoid specifying long messages outside the exception class
(TRY003)
198-198: Useless if-else condition
(RUF034)
302-302: subprocess call: check for execution of untrusted input
(S603)
313-313: Ambiguous variable name: l
(E741)
365-365: Do not catch blind exception: Exception
(BLE001)
370-370: Do not catch blind exception: Exception
(BLE001)
431-431: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
8a6d1dc to 49b70ec
Hi @kahyunnam, thanks for the great work. I ran the benchmark on Hopper and the achieved bandwidth reaches the 80%+ region for large numbers of tokens, which meets our previous standard for memory-bound kernels, so I suppose it should be ready to merge.
In the future we need a better understanding of Blackwell's expected performance on memory-bound kernels.
I added one commit to your branch, 3a97554, to replace ncu (not all user environments have ncu installed) with pynvml for obtaining GPU memory bandwidth; let me know if you think it's not correct.
Actionable comments posted: 1
🧹 Nitpick comments (2)
flashinfer/utils.py (1)
258-296: Consider module-level NVML initialization for better resource management. The function initializes and shuts down NVML on each unique device query. While NVML's reference-counted init/shutdown makes this safe, and the @functools.cache decorator ensures it only happens once per device, there are cleaner alternatives.
Optional improvement: Module-level NVML lifecycle management
_nvml_initialized = False

def _ensure_nvml_init():
    global _nvml_initialized
    if not _nvml_initialized:
        pynvml.nvmlInit()
        _nvml_initialized = True

@functools.cache
def get_gpu_memory_bandwidth(device: torch.device) -> float:
    """..."""
    if isinstance(device, str):
        device = torch.device(device)
    if device.type != "cuda":
        raise ValueError(f"Device must be a CUDA device, got {device}")
    device_index = device.index if device.index is not None else 0
    _ensure_nvml_init()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
    mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_MEM)
    # Calculate theoretical peak bandwidth (GB/s)
    bandwidth = (mem_clock * bus_width * 2) / 8 / 1000
    return bandwidth

This eliminates the try-finally overhead and avoids repeatedly incrementing/decrementing NVML's reference count.
benchmarks/bench_rope_quantize_fp8_append_cache.py (1)
197-197: Simplify useless conditional expression. Both branches of the ternary operator return the same value "NHD", making the condition pointless. Apply this diff:
- kv_layout="NHD" if config_name != "mla" else "NHD",
+ kv_layout="NHD",
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
benchmarks/bench_rope_quantize_fp8_append_cache.py (1 hunks)
flashinfer/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/utils.py (1)
include/flashinfer/trtllm/common.h (1)
device(83-90)
benchmarks/bench_rope_quantize_fp8_append_cache.py (5)
flashinfer/testing/utils.py (1)
bench_gpu_time_with_cudagraph(855-969)
flashinfer/utils.py (1)
get_gpu_memory_bandwidth(259-295)
tests/test_helpers/rope_reference.py (1)
RotaryEmbedding(117-232)
flashinfer/page.py (2)
get_seq_lens(212-235)
get_batch_indices_positions(157-209)
flashinfer/rope.py (1)
rope_quantize_fp8_append_paged_kv_cache(1424-1657)
🪛 Ruff (0.14.4)
flashinfer/utils.py
278-278: Avoid specifying long messages outside the exception class
(TRY003)
benchmarks/bench_rope_quantize_fp8_append_cache.py
58-58: Avoid specifying long messages outside the exception class
(TRY003)
197-197: Useless if-else condition
(RUF034)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
import torch
import torch.version
import pynvml
Add pynvml to project dependencies or use conditional import.
The pynvml library is imported at the module level but does not appear in the project's dependency manifest. This will cause an ImportError for users who don't have pynvml installed.
Choose one of these solutions:
Solution 1 (recommended): Add pynvml to dependencies
If pynvml is a required dependency, add it to your setup.py, pyproject.toml, or requirements.txt.
Solution 2: Use conditional import with graceful degradation
If the bandwidth utility is optional, use a try-except import pattern:
-import pynvml
+try:
+    import pynvml
+    HAS_PYNVML = True
+except ImportError:
+    HAS_PYNVML = False

Then guard the function:
@functools.cache
def get_gpu_memory_bandwidth(device: torch.device) -> float:
+    if not HAS_PYNVML:
+        raise ImportError(
+            "pynvml is required for get_gpu_memory_bandwidth. "
+            "Install it with: pip install nvidia-ml-py"
+        )
     ...

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In flashinfer/utils.py around line 24 the module imports pynvml at top-level
which will raise ImportError for users without that package; either add pynvml
to the project dependencies (setup.py / pyproject.toml / requirements.txt) if it
is required, or change the file to perform a conditional import: wrap import
pynvml in try/except ImportError, set a fallback flag/None when unavailable, and
update any functions that use pynvml to check the flag and raise a clear error
or gracefully no-op with documentation so the package can work without pynvml
installed.
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (2)
include/flashinfer/pos_enc.cuh (2)
922-947: 🔴 CRITICAL: Out-of-bounds V tail vector loads (unresolved from previous review).
This critical issue was flagged by coderabbitai[bot] in a previous review but remains unaddressed. The bounds check at line 934 validates only the START position, not the full vec_size-element read extent.
Vulnerability: When head_dim_total % vec_size != 0, the final threads in the last chunk read beyond the allocated buffer via the vectorized cast_load at line 936.
Example: head_dim_total = 100, vec_size = 8, v_elem_offset = 96, tx = 0
- Check: 96 + 0*8 = 96 < 100 ✓ (passes)
- Read: indices 96–103 (overruns by 4 elements)
Required fix: Check the full vector extent and use scalar tail for partial vectors:

  for (uint32_t j = 0; j < v_chunks; ++j) {
    uint32_t v_elem_offset = j * rope_chunk_size;
-   if (v_elem_offset + tx * vec_size < head_dim_total) {
+   uint32_t start = v_elem_offset + tx * vec_size;
+   uint32_t remaining = (start < head_dim_total) ? (head_dim_total - start) : 0;
+   if (remaining >= vec_size) {
      vec_t<float, vec_size> v_vec;
-     v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
+     v_vec.cast_load(v_in_ptr + start);
+#pragma unroll
+     for (uint32_t i = 0; i < vec_size; ++i) {
+       v_vec[i] = v_vec[i] * quant_scale_kv;
+     }
+     QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, start);
+     v_vec.cast_store(v_ptr);
+   } else if (remaining > 0) {
+     // Scalar tail for partial vector
+     vec_t<float, vec_size> v_vec;
+     #pragma unroll
+     for (uint32_t t = 0; t < remaining; ++t) v_vec[t] = v_in_ptr[start + t];
+     #pragma unroll
+     for (uint32_t t = remaining; t < vec_size; ++t) v_vec[t] = 0.f;
#pragma unroll
      for (uint32_t i = 0; i < vec_size; ++i) {
        v_vec[i] = v_vec[i] * quant_scale_kv;
      }
-     QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                                v_elem_offset + tx * vec_size);
-     v_vec.cast_store(v_ptr);
+     QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, start);
+     #pragma unroll
+     for (uint32_t t = 0; t < remaining; ++t) v_ptr[t] = static_cast<QuantType>(v_vec[t]);
    }
  }

Based on learnings (past review comments).
891-921: 🔴 CRITICAL: Out-of-bounds K non-RoPE vector loads/stores.
Lines 906 and 919 perform unconditional vectorized loads/stores without bounds checking. When no_rope_dim % rope_chunk_size != 0, the last chunk is partial and threads can read/write beyond allocated memory.
Example scenario: no_rope_dim = 100, rope_chunk_size = 64, vec_size = 8
- Last chunk processes elem_offset = 64, but only indices 64-99 are valid
- Thread tx = 5 attempts to load from index 64 + 5*8 = 104, which exceeds 100
Required fix: Add bounds checking similar to the V path, but validate the FULL vector extent:

  vec_t<float, vec_size> k_nope_vec;
- k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
+ uint32_t start = elem_offset + tx * vec_size;
+ uint32_t remaining = (start < no_rope_dim) ? (no_rope_dim - start) : 0;
+ if (remaining >= vec_size) {
+   k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
+ } else if (remaining > 0) {
+   // Scalar tail for partial vector
+   #pragma unroll
+   for (uint32_t t = 0; t < remaining; ++t) k_nope_vec[t] = k_nope_in_ptr[tx * vec_size + t];
+   #pragma unroll
+   for (uint32_t t = remaining; t < vec_size; ++t) k_nope_vec[t] = 0.f;
+ } else {
+   return;  // Out of bounds, skip processing
+ }
+
  #pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) {
    k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
  }
  if constexpr (IS_MLA) {
    QuantType* ckv_ptr = paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
-   k_nope_vec.cast_store(ckv_ptr);
+   if (remaining >= vec_size) {
+     k_nope_vec.cast_store(ckv_ptr);
+   } else if (remaining > 0) {
+     #pragma unroll
+     for (uint32_t t = 0; t < remaining; ++t) ckv_ptr[t] = static_cast<QuantType>(k_nope_vec[t]);
+   }
  } else {
    QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, rope_dim + elem_offset + tx * vec_size);
-   k_nope_vec.cast_store(k_ptr);
+   if (remaining >= vec_size) {
+     k_nope_vec.cast_store(k_ptr);
+   } else if (remaining > 0) {
+     #pragma unroll
+     for (uint32_t t = 0; t < remaining; ++t) k_ptr[t] = static_cast<QuantType>(k_nope_vec[t]);
+   }
  }
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/pos_enc.cuh(5 hunks)
🔇 Additional comments (5)
include/flashinfer/pos_enc.cuh (5)
23-23: LGTM: Necessary includes for new functionality.
The <type_traits> include is used for std::is_same at line 801, and "page.cuh" provides the paged KV cache types required by the new kernel.
Also applies to: 27-27
33-55: LGTM: Well-structured parameter encapsulation. The parameter struct effectively consolidates kernel arguments, improving maintainability and reducing the parameter list complexity.
801-801: IS_MLA is a readability alias, not true duplication.
Responding to nvpohanh's earlier question: While IS_MLA is technically derivable from CacheT, it serves as a clarity-enhancing alias that makes the multiple constexpr if branches (lines 861, 882, 898, 912, 924) more readable than repeated std::is_same<CacheT, paged_kv_mla_t<...>>::value checks. This is a standard C++ pattern for type trait aliases.
Based on learnings (past review comments).
1070-1154: LGTM: Well-structured host wrapper for GQA/MHA path. The dispatch logic, grid/block calculations (including the V section at lines 1096-1098), parameter packing, and kernel launch configuration are all correctly implemented.
1156-1247: LGTM: Correct MLA-specific host wrapper. The MLA specializations are properly implemented:
num_kv_heads = 1(line 1180, 1211)- Grid excludes V section (line 1182)
v_inset tonullptr(line 1203) with zero V strides (lines 1207, 1226-1227)- 2D K stride duplication (lines 1205-1206, 1223-1225) correctly handles MLA's flattened K layout
} else {
  // ============ Q Non-RoPE processing ============
  // MLA has no V section, so Q-nope starts immediately after K-nope.
  // GQA/MHA has a V section of length num_kv_heads blocks.
  uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
  uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
  uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
  uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;

  DType* q_nope_in_ptr =
      q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
                                       q_nope_in_stride_h);
  QuantType* q_nope_out_ptr =
      q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
                                        q_nope_out_stride_h);

  vec_t<float, vec_size> q_nope_vec;
  q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
#pragma unroll
  for (uint32_t i = 0; i < vec_size; ++i) {
    q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
  }
  q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
}
🔴 CRITICAL: Out-of-bounds Q non-RoPE vector loads/stores.
Lines 965 and 970 have the same critical vulnerability as the K non-RoPE path: unconditional vectorized loads/stores without bounds checking when no_rope_dim % rope_chunk_size != 0.
Required fix: Apply the same bounds-checking pattern as suggested for the K non-RoPE path:
vec_t<float, vec_size> q_nope_vec;
- q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
+ uint32_t start = elem_offset + tx * vec_size;
+ uint32_t remaining = (start < no_rope_dim) ? (no_rope_dim - start) : 0;
+ if (remaining >= vec_size) {
+ q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
+ } else if (remaining > 0) {
+ #pragma unroll
+ for (uint32_t t = 0; t < remaining; ++t) q_nope_vec[t] = q_nope_in_ptr[tx * vec_size + t];
+ #pragma unroll
+ for (uint32_t t = remaining; t < vec_size; ++t) q_nope_vec[t] = 0.f;
+ } else {
+ return;
+ }
+
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
}
- q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+ if (remaining >= vec_size) {
+ q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+ } else if (remaining > 0) {
+ #pragma unroll
+ for (uint32_t t = 0; t < remaining; ++t) q_nope_out_ptr[tx * vec_size + t] = static_cast<QuantType>(q_nope_vec[t]);
+ }Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In include/flashinfer/pos_enc.cuh around lines 948 to 971, the Q non-RoPE path
does unconditional vectorized loads/stores (lines corresponding to the
q_nope_vec.cast_load and cast_store) which can read/write past the valid
no_rope_dim when no_rope_dim % rope_chunk_size != 0; replicate the K non-RoPE
fix: compute the number of valid elements remaining in this chunk, perform a
guarded/tail-aware load (fill remaining lanes with zeros or use masked load),
apply the scalar multiply only to valid lanes, and perform a guarded/tail-aware
store writing back only the valid elements (or use a masked store). Ensure the
same index arithmetic/offsets are used (tx * vec_size + lane) and mirror the
exact bounds-checking pattern used in the K path so no out-of-bounds memory
accesses occur.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/utils.cuh(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
include/flashinfer/utils.cuh (1)
214-248: No duplicate macro definitions exist—code is correct.
The verification confirms that DISPATCH_ROPE_DIM has only one definition (line 214) and DISPATCH_INTERLEAVE has only one definition (line 205) in the file. The AI summary's claim about duplicate definitions is incorrect. The macro implementations are sound and follow the established pattern.
[FAILED] Pipeline #38402763: 1/17 passed

/bot run

[FAILED] Pipeline #38527670: 14/17 passed

Hi @kahyunnam let's merge this one first and fix the (potential) performance issues on A100 later.
…sed RoPE + Q + KV cache, supports MLA/GQA/MHA) (flashinfer-ai#2037)
Co-authored-by: Zihao Ye <[email protected]>
📌 Description
Add flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache, which runs a fused RoPE + Quantization (16 -> 8) + append KV Cache operation kernel. Note that this does not support optional quantization (there is no "RoPE + append KV Cache" fused operation available).
Tested on NVIDIA H100 NVL + flashinfer/flashinfer-ci-cu130:latest for MLA/MHA/GQA problem sizes for decode and prefill cases.
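To make the fusion concrete, the sketch below walks through the three steps the kernel combines, written as unfused PyTorch. The tensor layout, parameter names, paged-cache indexing, and the rotate_half RoPE over the full head dimension are illustrative assumptions, not the actual FlashInfer API (which also handles the rope/no-rope split and MLA caches):

```python
import torch

def reference_rope_quantize_append(q, k, v, cos, sin, k_cache, v_cache,
                                   page_ids, entry_ids, q_scale=1.0, kv_scale=1.0,
                                   quant_dtype=torch.float8_e4m3fn):
    # Unfused reference of the fused steps: RoPE -> FP8 quantize -> paged-cache append.
    def rope(x):
        # rotate_half-style RoPE; cos/sin are assumed to broadcast against x
        x1, x2 = x.chunk(2, dim=-1)
        return x * cos + torch.cat((-x2, x1), dim=-1) * sin

    q_out = (rope(q) * q_scale).to(quant_dtype)   # quantized Q is returned to the caller
    k_new = (rope(k) * kv_scale).to(quant_dtype)  # quantized K/V go into the paged cache
    v_new = (v * kv_scale).to(quant_dtype)
    for i in range(q.shape[0]):                   # scatter token i into (page, slot-in-page)
        p, e = int(page_ids[i]), int(entry_ids[i])
        k_cache[p, e] = k_new[i]
        v_cache[p, e] = v_new[i]
    return q_out
```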
🔍 Related Issues
"[Model Optimization] Add RoPE, RoPE+Q, RoPE+Q+KVCacheUpdate fused kernels for MLA/GQA/MHA" item from Q4 roadmap: #1770.
This PR is part 2 to earlier PR for RoPE + Q: #1924
FW Stakeholders: @nvpohanh @pavanimajety
🧪 Test results

$ pytest tests/attention/test_rope.py::test_rope_quantize_fp8_append_paged_kv_cache_decode -s
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 384 items
384 passed in 35.22s

$ pytest tests/attention/test_rope.py::test_generalized_rope_quantize_append_kv_cache -s
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 1248 items
1248 passed in 63.07s (0:01:03)

$ python benchmarks/bench_rope_quantize_fp8_append_cache.py
Detected GPU: NVIDIA GB200
Theoretical Peak Memory Bandwidth: 7928.06 GB/s

MLA: 128 Q heads, 1 K head, 64+512 dims (DeepSeek-style)
Tokens    Time (ms)    BW (GB/s)    BW% (Peak)    TFLOPs
1         0.00258      86.53        1.1           0.010
32        0.00381      1873.82      23.6          0.208
128       0.00763      3744.50      47.2          0.416
384       0.01848      4637.34      58.5          0.515
768       0.03694      4639.75      58.5          0.515
1024      0.04879      4683.57      59.1          0.520
2048      0.09590      4766.09      60.1          0.529
4096      0.19031      4803.27      60.6          0.533
8192      0.38523      4745.78      59.9          0.527

GQA: 32 Q heads, 8 K heads, 64+64 dims (Llama-style)
Tokens    Time (ms)    BW (GB/s)    BW% (Peak)    TFLOPs
1         0.00294      6.36         0.1           0.003
32        0.00316      189.48       2.4           0.078
128       0.00317      755.23       9.5           0.310
384       0.00398      1803.09      22.7          0.741
768       0.00522      2750.51      34.7          1.130
1024      0.00617      3100.80      39.1          1.274
2048      0.00927      4130.83      52.1          1.697
4096      0.01631      4695.01      59.2          1.929
8192      0.03466      4418.01      55.7          1.815

MHA: 32 Q heads, 32 K heads, 64+64 dims (Standard)
Tokens    Time (ms)    BW (GB/s)    BW% (Peak)    TFLOPs
1         0.00293      12.68        0.2           0.004
32        0.00313      379.98       4.8           0.126
128       0.00357      1331.80      16.8          0.441
384       0.00517      2756.73      34.8          0.912
768       0.00742      3840.41      48.4          1.271
1024      0.00887      4287.15      54.1          1.419
2048      0.01504      5055.18      63.8          1.673
4096      0.03343      4548.12      57.4          1.505
8192      0.06410      4744.76      59.8          1.571

Configuration details:
Page size: 32, Batch size: 4
Token range: 1 (single decode) → 8192 (large prefill)
GPU: NVIDIA GB200
Theoretical Peak Memory Bandwidth: 7928.06 GB/s
BW% calculated as: (achieved_bandwidth / peak_bandwidth) * 100
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- [x] I have installed the hooks with pre-commit install.
- [x] I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
- Fused RoPE + FP8 quantize-and-append for paged KV caches (MLA, GQA/MHA) with layout, page-size, interleave and PDL options; returns quantized Q outputs and writes K/V into paged caches; public ops and high-level API added.
Tests
- Deterministic, parameterized tests for append and decode/continuation across attention types, layouts, dtypes and quant settings with reference validation.
Benchmarks
- New benchmark script for performance, bandwidth and Nsight profiling of the paged-KV quantize+append path.
Chores
- Added cached GPU memory-bandwidth utility for benchmarks.