
@kahyunnam
Contributor

@kahyunnam kahyunnam commented Nov 4, 2025

📌 Description

Add flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache, which runs a fused RoPE + quantization (16-bit -> FP8) + append-to-paged-KV-cache kernel.

Note that quantization is not optional: there is no fused "RoPE + append KV cache" variant without it.
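For intuition, here is a minimal plain-PyTorch reference of the per-token math the fused kernel performs (illustration only, not the FlashInfer API or the kernel code; shapes and the NeoX-style rotation are assumptions — see tests/attention/test_rope.py for actual usage):

```python
# Per-token reference of "RoPE -> FP8 quantize -> append to paged cache".
# Illustration only: not the FlashInfer API; shapes/rotation style are assumptions.
import torch

def rope_quantize_append_reference(k, cos, sin, k_cache, page_table, pos,
                                   page_size, scale):
    # k: [rope_dim] 16-bit key for one token at integer sequence position `pos`
    # cos, sin: [rope_dim // 2] RoPE table entries for this position
    # k_cache: [num_pages, page_size, rope_dim] FP8 paged cache (mutated in place)
    half = k.shape[-1] // 2
    k1, k2 = k[:half], k[half:]
    # 1) RoPE, NeoX-style (rotate halves)
    k_rot = torch.cat([k1 * cos - k2 * sin, k2 * cos + k1 * sin])
    # 2) Quantize 16-bit -> FP8 with a per-tensor scale
    k_fp8 = (k_rot * scale).to(torch.float8_e4m3fn)
    # 3) Append into the paged KV cache (logical page -> physical page via page_table)
    page, slot = divmod(pos, page_size)
    k_cache[page_table[page], slot] = k_fp8
    return k_fp8
```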

Tested on an NVIDIA H100 NVL with flashinfer/flashinfer-ci-cu130:latest across MLA/MHA/GQA problem sizes for both decode and prefill cases.

🔍 Related Issues

"[Model Optimization] Add RoPE, RoPE+Q, RoPE+Q+KVCacheUpdate fused kernels for MLA/GQA/MHA" item from Q4 roadmap: #1770.

This PR is part 2 of the earlier RoPE + Q PR: #1924.

FW Stakeholders: @nvpohanh @pavanimajety

🧪 Test results

$ pytest tests/attention/test_rope.py::test_rope_quantize_fp8_append_paged_kv_cache_decode -s
======================================================== test session starts =========================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 384 items

tests/attention/test_rope.py ................................................................................................................................................................................................................................................................................................................................................................................................

======================================================== 384 passed in 35.22s ========================================================
$ pytest tests/attention/test_rope.py::test_generalized_rope_quantize_append_kv_cache -s
======================================================== test session starts =========================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 1248 items

tests/attention/test_rope.py .........................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
.......................................................................

================================================== 1248 passed in 63.07s (0:01:03) ===================================================
$ python benchmarks/bench_rope_quantize_fp8_append_cache.py

Detected GPU: NVIDIA GB200
Theoretical Peak Memory Bandwidth: 7928.06 GB/s


====================================================================================================
  MLA: 128 Q heads, 1 K head, 64+512 dims (DeepSeek-style)
====================================================================================================
Tokens     Time (ms)    BW (GB/s)    BW% (Peak)     TFLOPs
----------------------------------------------------------------------
1          0.00258      86.53        1.1            0.010
32         0.00381      1873.82      23.6           0.208
128        0.00763      3744.50      47.2           0.416
384        0.01848      4637.34      58.5           0.515
768        0.03694      4639.75      58.5           0.515
1024       0.04879      4683.57      59.1           0.520
2048       0.09590      4766.09      60.1           0.529
4096       0.19031      4803.27      60.6           0.533
8192       0.38523      4745.78      59.9           0.527

====================================================================================================
  GQA: 32 Q heads, 8 K heads, 64+64 dims (Llama-style)
====================================================================================================
Tokens     Time (ms)    BW (GB/s)    BW% (Peak)     TFLOPs
----------------------------------------------------------------------
1          0.00294      6.36         0.1            0.003
32         0.00316      189.48       2.4            0.078
128        0.00317      755.23       9.5            0.310
384        0.00398      1803.09      22.7           0.741
768        0.00522      2750.51      34.7           1.130
1024       0.00617      3100.80      39.1           1.274
2048       0.00927      4130.83      52.1           1.697
4096       0.01631      4695.01      59.2           1.929
8192       0.03466      4418.01      55.7           1.815

====================================================================================================
  MHA: 32 Q heads, 32 K heads, 64+64 dims (Standard)
====================================================================================================
Tokens     Time (ms)    BW (GB/s)    BW% (Peak)     TFLOPs
----------------------------------------------------------------------
1          0.00293      12.68        0.2            0.004
32         0.00313      379.98       4.8            0.126
128        0.00357      1331.80      16.8           0.441
384        0.00517      2756.73      34.8           0.912
768        0.00742      3840.41      48.4           1.271
1024       0.00887      4287.15      54.1           1.419
2048       0.01504      5055.18      63.8           1.673
4096       0.03343      4548.12      57.4           1.505
8192       0.06410      4744.76      59.8           1.571

====================================================================================================
Configuration details:
  Page size: 32, Batch size: 4
  Token range: 1 (single decode) → 8192 (large prefill)
  GPU: NVIDIA GB200
  Theoretical Peak Memory Bandwidth: 7928.06 GB/s
  BW% calculated as: (achieved_bandwidth / peak_bandwidth) * 100
====================================================================================================
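For example, the 2048-token MLA row: 4766.09 / 7928.06 × 100 ≈ 60.1% of peak.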

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Fused RoPE + FP8 quantize-and-append for paged KV caches (MLA, GQA/MHA) with layout, page-size, interleave and PDL options; returns quantized Q outputs and writes K/V into paged caches; public ops and high-level API added.
  • Tests

    • Deterministic, parameterized tests for append and decode/continuation across attention types, layouts, dtypes and quant settings with reference validation.
  • Benchmarks

    • New benchmark script for performance, bandwidth and Nsight profiling of the paged-KV quantize+append path.
  • Chores

    • Added cached GPU memory-bandwidth utility for benchmarks.

@coderabbitai
Contributor

coderabbitai bot commented Nov 4, 2025

Walkthrough

Adds a fused RoPE + FP8 quantize + paged-KV append flow: new unified CUDA kernel and host wrappers, new TVM FFI binding and C++ entry, Python API and op registrations, tests and benchmark, a GPU bandwidth helper, and new dispatch macros for interleave/rope-dim specialization.

Changes

Cohort / File(s) Summary
CUDA kernel & host headers
include/flashinfer/pos_enc.cuh
Added RopeQuantizeAppendPagedKVCacheParams, unified templated RopeQuantizeAppendPagedKVCacheKernel (MLA vs GQA/MHA via CacheT traits), host wrappers RopeQuantizeAppendPagedKVCache / RopeQuantizeAppendPagedMLACache, and #include "page.cuh". Replaced previous DISPATCH_ROPE_DIM path with ROPE_DIM-based dispatch.
FFI binding declaration
csrc/flashinfer_rope_binding.cu
Declared and exported TVM FFI symbol rope_quantize_append_paged_kv_cache(...) with full tensor args and flags; updated export macro.
FFI binding implementation
csrc/rope.cu
Implemented rope_quantize_append_paged_kv_cache(...): input validation, MLA vs GQA/MHA detection, paged cache struct construction, kernel dispatch to MLA or GQA/MHA host wrapper, and CUDA status checking.
Python API & ops
flashinfer/rope.py
Added public API rope_quantize_fp8_append_paged_kv_cache(...), internal ops and fake ops, op registrations, layout/kv_layout_code inference, output allocation, paged-cache unpacking, interleave/is_neox handling, and dispatch to the new FFI op.
Tests
tests/attention/test_rope.py
Added deterministic seeds and parameterized tests for FP8 quantize+append with paged KV caches (MLA/GQA/MHA), decode/pre-population scenarios, multiple KV layouts, dtypes, page sizes, and PDL flag.
Benchmark
benchmarks/bench_rope_quantize_fp8_append_cache.py
New benchmark script for FP8 paged-KV rope quantize+append across MLA/GQA/MHA with Nsight Compute profiling support, token sweep timing, bandwidth/TFLOPs reporting, and CLI options.
Utils (NVML)
flashinfer/utils.py
Added cached get_gpu_memory_bandwidth(device: torch.device) -> float using pynvml to compute device memory bandwidth (GB/s).
Header macros
include/flashinfer/utils.cuh
Added DISPATCH_INTERLEAVE and DISPATCH_ROPE_DIM macros to convert runtime flags into compile-time constants; note duplicated definitions (redefinition risk).

Sequence Diagram(s)

sequenceDiagram
    participant User as Python caller
    participant PyAPI as rope_quantize_fp8_append_paged_kv_cache()
    participant Validate as Validation & Dispatch
    participant Op as Registered Op
    participant FFI as rope_quantize_append_paged_kv_cache (FFI)
    participant Kernel as RopeQuantizeAppendPagedKVCacheKernel

    User->>PyAPI: call with Q/K/V, cos_sin, paged-kv metadata, flags
    PyAPI->>Validate: validate shapes/dtypes, infer layout/is_neox, allocate q outputs
    Validate->>Op: select MLA vs GQA/MHA, compute kv_layout_code, interleave, scales, page_size
    Op->>FFI: invoke FFI with tensors, strides, layout code, page size, flags
    FFI->>Kernel: construct paged_kv_t / paged_kv_mla_t and launch kernel
    rect rgb(240,248,255)
    Note over Kernel: apply RoPE → FP8 quantize Q/K/V → append K/V to paged cache (MLA vs GQA/MHA branch)
    end
    Kernel-->>FFI: return CUDA status
    FFI-->>Op: return
    Op-->>PyAPI: complete
    PyAPI-->>User: return quantized Q outputs (K/V stored in paged caches)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Pay special attention to:
    • Kernel template instantiation and launch parameters (interleave, vec_size, block dims).
    • MLA (2D K) vs GQA/MHA (3D K) shape/stride/layout handling.
    • Construction and lifetime/alignment of paged_kv_t / paged_kv_mla_t and page metadata.
    • FFI boundary checks (contiguity, dtypes, stride calculations) and error paths.
    • Duplicated macro definitions in include/flashinfer/utils.cuh.
    • NVML usage and error-handling in get_gpu_memory_bandwidth.
    • Test/benchmark determinism and FP8 tolerance assumptions.

Possibly related PRs

Suggested reviewers

  • djmmoss
  • yongwww
  • cyx-6
  • wenscarl
  • bkryu
  • joker-eph
  • nvmbreughe

Poem

🐇
I hopped through kernels, tiny and spry,
Twisting RoPE loops in FP8 sky,
Pages filled with keys, values tucked tight,
MLA or GQA — appended just right,
A rabbit's hop — the cache hums bright! 🎉

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 57.14% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately and concisely describes the main change: addition of a fused RoPE + quantization + paged KV cache operation supporting MLA/GQA/MHA architectures.
Description check ✅ Passed The description fully adheres to the template with complete sections: detailed description of the fused operation, clear link to related Q4 roadmap issue and part 2 PR reference, comprehensive test results showing 384+1248 passing tests, and confirmed completion of all pre-commit and testing checkboxes.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 72bd59e and 0f87afd.

📒 Files selected for processing (1)
  • include/flashinfer/utils.cuh (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
include/flashinfer/utils.cuh (2)

204-212: LGTM! Previous concern has been addressed.

The macro now correctly uses constexpr bool INTERLEAVE in both branches, enabling full compile-time evaluation and static specialization. This is consistent with other dispatch macros in the file.


214-248: LGTM! Well-structured dispatch macro.

The macro follows established patterns in the file (similar to DISPATCH_HEAD_DIM) and correctly uses constexpr uint32_t ROPE_DIM for compile-time dispatch. The error message is comprehensive, listing all supported values (16, 32, 64, 128, 256) which cover typical RoPE dimensions for MLA/GQA/MHA configurations.
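For readers unfamiliar with this pattern, a minimal sketch of such a runtime-to-compile-time dispatch macro is below. It is illustrative only: the actual macro in include/flashinfer/utils.cuh covers 16/32/64/128/256 and uses the project's error-reporting path, and this sketch assumes the enclosing host function returns cudaError_t.

```cpp
// Illustrative sketch of the dispatch pattern (not the actual DISPATCH_ROPE_DIM):
// a runtime rope_dim is switched into a constexpr ROPE_DIM so the kernel
// template can be specialized at compile time.
#include <cstdint>
#include <cuda_runtime.h>

#define DISPATCH_ROPE_DIM_SKETCH(rope_dim, ROPE_DIM, ...) \
  switch (rope_dim) {                                     \
    case 64: {                                            \
      constexpr uint32_t ROPE_DIM = 64;                   \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    case 128: {                                           \
      constexpr uint32_t ROPE_DIM = 128;                  \
      __VA_ARGS__                                         \
      break;                                              \
    }                                                     \
    default: {                                            \
      /* the real macro reports all supported dims */     \
      return cudaErrorInvalidValue;                       \
    }                                                     \
  }
```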



@kahyunnam kahyunnam changed the title [WIP, need to test more] (RoPE + Q fp8 + append kv_cache) fused kernel for MLA/GQA/MHA [WIP, need to test more] (RoPE + Q fp8 + append KV cache) fused kernel for MLA/GQA/MHA Nov 4, 2025
@kahyunnam kahyunnam changed the title [WIP, need to test more] (RoPE + Q fp8 + append KV cache) fused kernel for MLA/GQA/MHA Add flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache (fused RoPE + Q + KV cache, supports MLA/GQA/MHA) Nov 4, 2025
@kahyunnam kahyunnam marked this pull request as ready for review November 4, 2025 19:44
@kahyunnam kahyunnam self-assigned this Nov 4, 2025
@kahyunnam kahyunnam mentioned this pull request Nov 4, 2025
28 tasks
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2d68a6b and 4040e9c.

📒 Files selected for processing (5)
  • csrc/flashinfer_rope_binding.cu (1 hunks)
  • csrc/rope.cu (1 hunks)
  • flashinfer/rope.py (2 hunks)
  • include/flashinfer/pos_enc.cuh (3 hunks)
  • tests/attention/test_rope.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
csrc/flashinfer_rope_binding.cu (1)
csrc/rope.cu (4)
  • rope_quantize (271-422)
  • rope_quantize (271-275)
  • rope_quantize_append_paged_kv_cache (430-622)
  • rope_quantize_append_paged_kv_cache (430-438)
csrc/rope.cu (2)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
tests/attention/test_rope.py (4)
benchmarks/bench_rope_quantize_fp8.py (1)
  • FlashInferRotaryEmbedding (19-88)
flashinfer/page.py (2)
  • get_seq_lens (212-235)
  • get_batch_indices_positions (157-209)
flashinfer/rope.py (1)
  • rope_quantize_fp8_append_paged_kv_cache (1407-1670)
tests/test_helpers/rope_reference.py (1)
  • forward_native (194-232)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
  • register_custom_op (272-281)
  • register_custom_op (291-310)
  • register_fake_op (283-287)
  • register_fake_op (312-317)
  • TensorLayout (43-45)
csrc/flashinfer_rope_binding.cu (1)
  • rope_quantize_append_paged_kv_cache (48-54)
csrc/rope.cu (2)
  • rope_quantize_append_paged_kv_cache (430-622)
  • rope_quantize_append_paged_kv_cache (430-438)
🪛 Ruff (0.14.3)
flashinfer/rope.py

1571-1571: Avoid specifying long messages outside the exception class

(TRY003)


1596-1596: Avoid specifying long messages outside the exception class

(TRY003)


1600-1600: Avoid specifying long messages outside the exception class

(TRY003)


1609-1612: Avoid specifying long messages outside the exception class

(TRY003)


1614-1617: Avoid specifying long messages outside the exception class

(TRY003)


1626-1629: Avoid specifying long messages outside the exception class

(TRY003)


1631-1633: Avoid specifying long messages outside the exception class

(TRY003)

@kahyunnam
Contributor Author

kahyunnam commented Nov 4, 2025

@yzh119 @pavanimajety do you think a benchmarking script similar to https://github.com/flashinfer-ai/flashinfer/blob/main/benchmarks/bench_rope_quantize_fp8.py is needed here? AFAIK there is no direct PyTorch equivalent of append-kv-cache. However, we could still have a benchmark comparing flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache (fused FlashInfer) vs. flashinfer.rope.rope_quantize_fp8 + flashinfer.page.append_paged_kv_cache / flashinfer.page.append_paged_mla_kv_cache (un-fused FlashInfer), if that would be useful.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
flashinfer/rope.py (1)

1591-1596: Missing validation: GQA/MHA requires non-None V tensor.

The code validates that MLA does not have a V input (correctly raises an error if V is not None for MLA). However, there is no corresponding validation for the GQA/MHA path to ensure V is not None. If V is None for GQA/MHA, the kernel will receive an invalid tensor, likely causing a crash or incorrect results.

Add validation before line 1591:

+    # Validate V input based on architecture
+    if not is_mla and v is None:
+        raise ValueError("GQA/MHA requires V input (cannot be None)")
+
     # Handle V input for MLA (create empty dummy tensor, not used)
     if is_mla:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4040e9c and 20e123b.

📒 Files selected for processing (1)
  • flashinfer/rope.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
  • register_custom_op (272-281)
  • register_custom_op (291-310)
  • register_fake_op (283-287)
  • register_fake_op (312-317)
  • TensorLayout (43-45)
csrc/flashinfer_rope_binding.cu (1)
  • rope_quantize_append_paged_kv_cache (48-54)
csrc/rope.cu (2)
  • rope_quantize_append_paged_kv_cache (430-622)
  • rope_quantize_append_paged_kv_cache (430-438)
🪛 Ruff (0.14.3)
flashinfer/rope.py

1571-1571: Avoid specifying long messages outside the exception class

(TRY003)


1596-1596: Avoid specifying long messages outside the exception class

(TRY003)


1600-1600: Avoid specifying long messages outside the exception class

(TRY003)


1609-1612: Avoid specifying long messages outside the exception class

(TRY003)


1614-1617: Avoid specifying long messages outside the exception class

(TRY003)


1626-1629: Avoid specifying long messages outside the exception class

(TRY003)


1631-1633: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/rope.py (2)

229-325: LGTM! Custom op and fake op names are now aligned.

The custom op and fake op are both correctly registered under the same symbol "flashinfer::rope_quantize_append_paged_kv_cache", resolving the past review concern about FakeTensor/meta tracing. The mutates_args specification correctly includes all output tensors that will be modified in-place.


1407-1670: Well-structured function with comprehensive documentation.

The overall implementation is well-organized with clear separation of concerns:

  • Architecture detection and input validation
  • Output tensor allocation
  • Cache unpacking and validation
  • Custom op dispatch

The docstring is exceptionally thorough with detailed parameter descriptions and complete working examples for both MLA and GQA/MHA use cases. The function correctly returns only the quantized Q tensors since K/V are written directly to the paged cache.

@kahyunnam
Contributor Author

kahyunnam commented Nov 4, 2025

Also some other questions for reviewers:

  1. Would defining/using a struct for the params passed into the new kernel help? I'm not sure whether it would simplify things or overly complicate them.

  2. Is there any benefit in trying to route mla_rope_quantize_fp8, rope_quantize_fp8, and rope_quantize_fp8_append_paged_kv_cache into the same backend kernel? I did make a separate kernel for rope_quantize_fp8_append_paged_kv_cache, even though much of the RoPE/quantization logic is the same, simply because of the number of optional KV-cache arguments this would create. (But perhaps some clever param-struct design could handle this cleanly; a rough sketch of what I mean is below.)
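Hypothetical params struct (illustration only, not the current FlashInfer code):

```cpp
// Hypothetical illustration only: group the always-present RoPE/quantization
// arguments with optional paged-cache fields, so one backend kernel could serve
// both the "RoPE + Q" and "RoPE + Q + append" paths.
#include <cstdint>

template <typename DType, typename QuantType, typename IdType>
struct RopeQuantizeParams {
  // RoPE + FP8 quantization (always used)
  DType* q_in = nullptr;
  DType* k_in = nullptr;
  QuantType* q_out = nullptr;
  float quant_scale_q = 1.f;
  float quant_scale_kv = 1.f;
  // Paged-cache append (ignored when append_to_cache is false)
  bool append_to_cache = false;
  QuantType* k_cache = nullptr;   // ckv cache for MLA, K cache for GQA/MHA
  QuantType* v_cache = nullptr;   // unused for MLA
  IdType* kv_indices = nullptr;
  IdType* kv_indptr = nullptr;
  uint32_t page_size = 0;
};
```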

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
flashinfer/rope.py (2)

1639-1639: Consider moving import to module level.

Importing TensorLayout inside the function works but is unconventional. If circular imports are not a concern, consider moving this to the top-level imports for consistency with Python style conventions.


1571-1571: Consider extracting long error messages to reduce TRY003 violations.

Ruff flags multiple long error messages defined inline (TRY003). While functional, consider extracting repeated error message patterns to constants or helper functions to improve maintainability and comply with the style guide.

Also applies to: 1596-1596, 1609-1612, 1614-1617, 1626-1629, 1631-1633

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 20e123b and 58a96f2.

📒 Files selected for processing (1)
  • flashinfer/rope.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
  • register_custom_op (272-281)
  • register_custom_op (291-310)
  • register_fake_op (283-287)
  • register_fake_op (312-317)
  • TensorLayout (43-45)
csrc/flashinfer_rope_binding.cu (1)
  • rope_quantize_append_paged_kv_cache (48-54)
csrc/rope.cu (2)
  • rope_quantize_append_paged_kv_cache (430-622)
  • rope_quantize_append_paged_kv_cache (430-438)
🪛 Ruff (0.14.3)
flashinfer/rope.py

1571-1571: Avoid specifying long messages outside the exception class

(TRY003)


1596-1596: Avoid specifying long messages outside the exception class

(TRY003)


1600-1600: Avoid specifying long messages outside the exception class

(TRY003)


1609-1612: Avoid specifying long messages outside the exception class

(TRY003)


1614-1617: Avoid specifying long messages outside the exception class

(TRY003)


1626-1629: Avoid specifying long messages outside the exception class

(TRY003)


1631-1633: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
flashinfer/rope.py (1)

1407-1670: Excellent comprehensive documentation and architecture detection.

The function provides thorough documentation with clear examples for both MLA and GQA/MHA use cases. The automatic architecture detection based on tensor shapes (line 1574) is elegant, and the cache validation logic correctly distinguishes between the two architectures with appropriate dummy tensor creation for unused cache types.

Also, get rid of the docstring examples; they are too long and too complicated. Maybe just reference the test code as an example instead?
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
flashinfer/rope.py (1)

1587-1617: Consider validating kv_layout for clearer error messages.

At line 1590, accessing TensorLayout[kv_layout].value will raise a KeyError if kv_layout is not "NHD" or "HND". While the default value prevents most issues, an explicit validation would provide a clearer error message.

Consider adding validation before line 1590:

+    # Validate kv_layout
+    if kv_layout not in ("NHD", "HND"):
+        raise ValueError(
+            f"kv_layout must be 'NHD' or 'HND', got '{kv_layout}'"
+        )
+
     # Import TensorLayout enum
     from .utils import TensorLayout
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 58a96f2 and 034be6a.

📒 Files selected for processing (1)
  • flashinfer/rope.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
  • register_custom_op (272-281)
  • register_custom_op (291-310)
  • register_fake_op (283-287)
  • register_fake_op (312-317)
  • TensorLayout (43-45)
csrc/flashinfer_rope_binding.cu (1)
  • rope_quantize_append_paged_kv_cache (48-54)
csrc/rope.cu (2)
  • rope_quantize_append_paged_kv_cache (430-622)
  • rope_quantize_append_paged_kv_cache (430-438)
🪛 Ruff (0.14.3)
flashinfer/rope.py

1514-1514: Avoid specifying long messages outside the exception class

(TRY003)


1539-1539: Avoid specifying long messages outside the exception class

(TRY003)


1543-1543: Avoid specifying long messages outside the exception class

(TRY003)


1552-1555: Avoid specifying long messages outside the exception class

(TRY003)


1557-1560: Avoid specifying long messages outside the exception class

(TRY003)


1570-1573: Avoid specifying long messages outside the exception class

(TRY003)


1575-1578: Avoid specifying long messages outside the exception class

(TRY003)


1580-1582: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (3)
flashinfer/rope.py (3)

229-296: LGTM! Custom operator registration is correct.

The registration properly declares all mutated cache tensors and routes to the CUDA kernel. The signature matches the binding in csrc/flashinfer_rope_binding.cu.


299-325: LGTM! Fake operator registration aligns with custom op.

The fake op is correctly registered under the same name as the custom op, enabling proper FakeTensor/meta tracing and torch.compile support.


1535-1586: Well-structured architecture detection and cache validation.

The function correctly distinguishes MLA (2D K) from GQA/MHA (3D K) and validates that:

  • MLA requires V=None and uses 3D ckv_cache/kpe_cache
  • GQA/MHA requires V tensor and uses 4D k_cache/v_cache

The V validation at lines 1569-1573 properly addresses the concern from previous reviews.
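A condensed sketch of that detection/validation logic (illustrative; function and variable names are assumptions, not the actual flashinfer/rope.py code):

```python
import torch

def detect_arch(k_rope: torch.Tensor, v) -> bool:
    """Return True for MLA, False for GQA/MHA, mirroring the checks above."""
    is_mla = k_rope.dim() == 2  # MLA: [nnz, rope_dim]; GQA/MHA: [nnz, heads, rope_dim]
    if is_mla and v is not None:
        raise ValueError("MLA path takes no V tensor")
    if not is_mla and v is None:
        raise ValueError("GQA/MHA requires a V tensor")
    return is_mla
```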

@kahyunnam kahyunnam changed the title Add flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache (fused RoPE + Q + KV cache, supports MLA/GQA/MHA) feat: Add flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache (fused RoPE + Q + KV cache, supports MLA/GQA/MHA) Nov 4, 2025
@yzh119
Collaborator

yzh119 commented Nov 4, 2025

/bot run

@flashinfer-bot
Collaborator

GitLab MR !111 has been created, and the CI pipeline #37900083 is currently running. I'll report back once the pipeline job completes.

Collaborator

@yzh119 yzh119 left a comment

Overall LGTM, would you mind adding benchmark suites for this operator?

uint32_t batch_size = kv_indptr.size(0) - 1;
QKVLayout kv_layout = QKVLayout(kv_layout_code);

if (is_mla) {
Collaborator

Can we unify the two branches?

On the C++ side, we assume the K tensor is 3D.
On the Python side, if we find the K tensor is 2D, we unsqueeze its dimension 1.
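Something like the following (illustrative; names assumed):

```python
import torch

def normalize_k_to_3d(k_rope: torch.Tensor) -> torch.Tensor:
    # MLA passes K as [nnz, rope_dim]; unsqueeze so the C++ binding always sees
    # the 3D GQA/MHA layout [nnz, num_kv_heads, rope_dim] with num_kv_heads == 1.
    return k_rope.unsqueeze(1) if k_rope.dim() == 2 else k_rope
```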

csrc/rope.cu Outdated
uint32_t k_rope_in_stride_h, k_nope_in_stride_h;
uint32_t v_in_stride = 0, v_in_stride_h = 0;

if (is_mla) {
Collaborator

ditto

Contributor

@pavanimajety pavanimajety left a comment

Thanks Kahyun, LGTM overall! A few comments.

@nvpohanh
Contributor

nvpohanh commented Nov 5, 2025

@elvischenv could you also help to review this? thanks!

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #37900083: 13/17 passed

@nvpohanh
Contributor

nvpohanh commented Nov 6, 2025

@kahyunnam could we get some basic perf benchmarking results? I just want to make sure we are achieving reasonably good DRAM bandwidth and are not too much worse than TRT-LLM's kernel. Thanks!

@kahyunnam kahyunnam requested a review from jiahanc as a code owner November 13, 2025 02:37
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
include/flashinfer/pos_enc.cuh (1)

756-769: Drop redundant IS_MLA template parameter

IS_MLA duplicates CacheT. Prefer deriving it:

-template <bool interleave, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType,
-          typename QuantType, bool IS_MLA, typename CacheT>
+template <bool interleave, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType,
+          typename QuantType, typename CacheT>
 __global__ void RopeQuantizeAppendPagedKVCacheKernel( ... ) {
+  constexpr bool IS_MLA =
+      std::is_same<CacheT, paged_kv_mla_t<QuantType, IdType>>::value;

Cleaner API and fewer instantiations to maintain.

🧹 Nitpick comments (4)
include/flashinfer/pos_enc.cuh (1)

1056-1121: Host wrapper duplication and tailoring

RopeQuantizeAppendPagedKVCache and RopeQuantizeAppendPagedMLACache largely duplicate launch setup. Consider a single templated host wrapper on CacheT with a small traits struct for MLA vs GQA (presence of V, number of KV heads, strides). Reduces surface and keeps dispatch logic in one place.
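For instance (illustrative only; type names are taken from this PR's summary and the sketch would require flashinfer's page.cuh to compile):

```cpp
// Illustrative traits sketch for collapsing the two host wrappers into one
// templated on the cache type; not the actual header.
template <typename CacheT>
struct PagedCacheTraits;

template <typename QuantT, typename IdT>
struct PagedCacheTraits<paged_kv_mla_t<QuantT, IdT>> {
  static constexpr bool kIsMLA = true;   // no V tensor, single compressed KV "head"
};

template <typename QuantT, typename IdT>
struct PagedCacheTraits<paged_kv_t<QuantT, IdT>> {
  static constexpr bool kIsMLA = false;  // separate K/V caches with per-head strides
};
```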

tests/attention/test_rope.py (2)

488-517: Add one tail-dimension case to catch V tail bugs

All current head_dim cases are multiples of 16. Add a case where head_dim = rope_dim + no_rope_dim isn’t divisible by vec_size (e.g., rope_dim=64, no_rope_dim=48 for bf16 → 112) to exercise V tail handling.


590-658: Decode metadata setup clarity

Nice: using get_seq_lens/get_batch_indices_positions mirrors runtime flow. Consider factoring metadata build into a helper to reuse across tests.

flashinfer/rope.py (1)

1625-1656: Layout code derivation and dispatch

Using TensorLayout enum and passing code down is clean. Consider small doc note that only "NHD"/"HND" are supported.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 034be6a and 401bab3.

📒 Files selected for processing (4)
  • csrc/rope.cu (1 hunks)
  • flashinfer/rope.py (5 hunks)
  • include/flashinfer/pos_enc.cuh (4 hunks)
  • tests/attention/test_rope.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
csrc/rope.cu (2)
flashinfer/comm/cuda_ipc.py (2)
  • cudaSetDevice (149-150)
  • cudaGetErrorString (146-147)
csrc/tvm_ffi_utils.h (1)
  • get_stream (272-274)
tests/attention/test_rope.py (4)
benchmarks/bench_rope_quantize_fp8.py (1)
  • FlashInferRotaryEmbedding (19-88)
flashinfer/page.py (2)
  • get_seq_lens (212-235)
  • get_batch_indices_positions (157-209)
flashinfer/rope.py (1)
  • rope_quantize_fp8_append_paged_kv_cache (1424-1657)
tests/test_helpers/rope_reference.py (1)
  • forward_native (194-232)
flashinfer/rope.py (3)
flashinfer/utils.py (5)
  • register_custom_op (272-281)
  • register_custom_op (291-310)
  • register_fake_op (283-287)
  • register_fake_op (312-317)
  • TensorLayout (43-45)
csrc/rope.cu (2)
  • rope_quantize_append_paged_kv_cache (430-617)
  • rope_quantize_append_paged_kv_cache (430-438)
csrc/flashinfer_rope_binding.cu (1)
  • rope_quantize_append_paged_kv_cache (48-54)
🪛 Ruff (0.14.4)
flashinfer/rope.py

1531-1531: Avoid specifying long messages outside the exception class

(TRY003)


1577-1577: Avoid specifying long messages outside the exception class

(TRY003)


1581-1581: Avoid specifying long messages outside the exception class

(TRY003)


1590-1593: Avoid specifying long messages outside the exception class

(TRY003)


1595-1598: Avoid specifying long messages outside the exception class

(TRY003)


1608-1611: Avoid specifying long messages outside the exception class

(TRY003)


1613-1616: Avoid specifying long messages outside the exception class

(TRY003)


1618-1620: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (11)
include/flashinfer/pos_enc.cuh (2)

388-398: Good: avoid unnecessary cos/sin loads

Gating cos/sin loads with (by < k_rope_end) matches the Q/K‑RoPE blocks and skips non‑RoPE paths. Looks correct.


792-809: Verify page offset computation and page_size usage

paged_kv_like.page_size.divmod(...) suggests page_size is a struct, but the expression multiplies by paged_kv_like.page_size as if it were a scalar. Please confirm page_size supports both .divmod() and scalar multiplication, or split into:

  • compute base_pages = paged_kv_like.indptr[batch_indices[idx]],
  • compute delta_pages, entry_idx = divmod(positions[idx], page_size),
  • set page_iter = base_pages + delta_pages.
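In code, that split might look like this (illustrative; assumes page_size is a plain integer rather than a struct, and names mirror the kernel's variables without being the actual implementation):

```cpp
#include <cstdint>

// Illustrative helper for the split described above.
__host__ __device__ inline void compute_page_slot(const int32_t* indptr,
                                                  const int32_t* batch_indices,
                                                  const int32_t* positions,
                                                  uint32_t idx, uint32_t page_size,
                                                  uint32_t& page_iter,
                                                  uint32_t& entry_idx) {
  uint32_t base_pages  = indptr[batch_indices[idx]];  // first page of this request
  uint32_t delta_pages = positions[idx] / page_size;  // whole pages into the request
  entry_idx            = positions[idx] % page_size;  // slot within that page
  page_iter            = base_pages + delta_pages;    // page to write into
}
```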
csrc/rope.cu (3)

489-505: Require K as 3D here: doc callout

C++ binding assumes 3D K for both paths and checks head_dim==1 for MLA. Good. Ensure Python always unsqueezes MLA K tensors to 3D before calling; you already do this in flashinfer/rope.py. Please add a brief comment here to prevent future regressions.


514-519: Cache dims: enforce 4D only (GQA/MHA)

Checks match the kernel (CHECK_DIM(4, k_cache)/v_cache). Good alignment with Python validation.


552-579: MLA path construction LGTM; minor: pass null for V explicitly

Setting v_in to an empty tensor on the Python side and validating here is fine; alternatively, you can document that v_in is unused in MLA and may be null. No action needed.

tests/attention/test_rope.py (2)

397-401: Determinism: good seeding

Setting both CPU and CUDA seeds ensures reproducible FP8 rounding paths. LGTM.


1245-1313: Cache non-regression checks in decode

Comparing pre/post snapshots at exact indices is solid. Good guard against accidental overwrites.

flashinfer/rope.py (4)

229-239: Custom/fake op names aligned

Registration now matches flashinfer::rope_quantize_append_paged_kv_cache. This unblocks FakeTensor/meta tracing.


1286-1422: Optional nope tensors: good normalization

Creating zero-width nope tensors avoids special-casing downstream and matches C++ assumptions. LGTM.
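For example (illustrative; shapes and names assumed, with q_rope taken to be [nnz, num_heads, rope_dim]):

```python
import torch

def normalize_nope(q_nope, q_rope: torch.Tensor) -> torch.Tensor:
    # Pass a zero-width tensor instead of None when there is no "nope" part,
    # so downstream code never needs a special case.
    if q_nope is not None:
        return q_nope
    return q_rope.new_empty(q_rope.shape[0], q_rope.shape[1], 0)
```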


1530-1560: Quant dtype inference and Q output allocation

Infers from provided outputs or defaults to e4m3fn; allocates only when needed. Simple and correct.


1606-1620: GQA/MHA path: V required validation

Raising a clear error when v is None prevents obscure kernel failures. Good.

Comment on lines +915 to +929
        for (uint32_t j = 0; j < v_chunks; ++j) {
          uint32_t v_elem_offset = j * rope_chunk_size;
          if (v_elem_offset + tx * vec_size < head_dim_total) {
            vec_t<float, vec_size> v_vec;
            v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
#pragma unroll
            for (uint32_t i = 0; i < vec_size; ++i) {
              v_vec[i] = v_vec[i] * quant_scale_kv;
            }
            QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
                                                       v_elem_offset + tx * vec_size);
            v_vec.cast_store(v_ptr);
          }
        }
      }
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Potential out-of-bounds on V tail vector loads

In the V path you do vectorized cast_load with a simple bound if (v_elem_offset + tx*vec_size < head_dim_total). If head_dim_total isn’t a multiple of vec_size, the last lane may overread. Either:

  • enforce head_dim_total % vec_size == 0 at dispatch and document it, or
  • add a masked/scalar tail for the last partial vector.

Suggested fix (scalar tail fallback):

-          if (v_elem_offset + tx * vec_size < head_dim_total) {
-            vec_t<float, vec_size> v_vec;
-            v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
+          uint32_t start = v_elem_offset + tx * vec_size;
+          if (start < head_dim_total) {
+            vec_t<float, vec_size> v_vec;
+            uint32_t remaining = head_dim_total - start;
+            if (remaining >= vec_size) {
+              v_vec.cast_load(v_in_ptr + start);
+            } else {
+              #pragma unroll
+              for (uint32_t t = 0; t < remaining; ++t) v_vec[t] = v_in_ptr[start + t];
+              #pragma unroll
+              for (uint32_t t = remaining; t < vec_size; ++t) v_vec[t] = 0.f;
+            }
@@
-            QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                                       v_elem_offset + tx * vec_size);
-            v_vec.cast_store(v_ptr);
+            QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, start);
+            if (remaining >= vec_size) {
+              v_vec.cast_store(v_ptr);
+            } else {
+              #pragma unroll
+              for (uint32_t t = 0; t < remaining; ++t) v_ptr[t] = static_cast<QuantType>(v_vec[t]);
+            }
           }

I'll verify this concern by examining the code context and understanding the constraints. Let me gather information about the vectorization logic, dispatch constraints, and similar patterns in the codebase.


🏁 Script executed:

#!/bin/bash
# Examine the function containing the flagged code and surrounding context
cat include/flashinfer/pos_enc.cuh | sed -n '890,950p'

Length of output: 2838


🏁 Script executed:

#!/bin/bash
# Search for the vec_t template definition and cast_load semantics
rg -n "template.*vec_t" -A 10 --type=cpp --type=cuh include/

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Check for head_dim_total constraints or alignment requirements in dispatch/initialization
rg -n "head_dim_total" -B 3 -A 3 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -80

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Search for K path (kv_in_ptr or similar) to see if same pattern exists
rg -n "cast_load.*k_in_ptr|cast_load.*kv_in_ptr" -B 5 -A 5 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Look for existing partial-vector or tail handling patterns in the file
rg -n "remaining|partial|tail" -B 2 -A 2 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -60

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Search for rope_chunk_size and vec_size alignment patterns
rg -n "rope_chunk_size|vec_size" -B 2 -A 2 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -100

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Check for dispatch-time template constraints or requirements
rg -n "static_assert|rope_chunk_size.*%|vec_size.*%" --type=cpp --type=cuh include/

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Look at the full function signature and template parameters
rg -n "void.*rope.*page.*kv|kernel.*rope.*page.*kv" -A 30 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -120

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Search for similar bounds checks and how they handle partial vectors
rg -n "if.*\+ tx \* vec_size.*head_dim" -B 3 -A 10 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Check if there's documentation or comments about alignment requirements
rg -n "multiple|divisible|align|constraint" -B 1 -A 1 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -80

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Look for quantization and constraint information
fd ".*\.cuh$|.*\.cu$" include/flashinfer/ -x grep -l "head_dim_total" | head -5

Length of output: 30


🏁 Script executed:

#!/bin/bash
# Check the K path more carefully - look for k_vec or similar load patterns
rg -n "k_vec.*cast_load|kv_vec.*cast_load" -B 3 -A 8 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Look at function header/docstring to understand parameter constraints
sed -n '800,920p' include/flashinfer/pos_enc.cuh | grep -A 100 "^template"

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Search for cast_load implementation to understand its behavior
rg -n "cast_load" -B 2 -A 5 --type=cpp --type=cuh include/ | head -100

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Check if this is a paged attention kernel and whether alignment is guaranteed
rg -n "paged.*rope\|rope.*paged" -B 5 -A 15 --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -150

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Look for the function name context by searching around the snippet
sed -n '850,930p' include/flashinfer/pos_enc.cuh

Length of output: 3473


🏁 Script executed:

#!/bin/bash
# Check for any comment or specification about head_dim constraints
grep -n "head_dim" include/flashinfer/pos_enc.cuh | head -30

Length of output: 1878


🏁 Script executed:

#!/bin/bash
# Look at where rope_chunk_size is defined or constrained
rg -n "rope_chunk_size\s*=" --type=cpp --type=cuh include/flashinfer/pos_enc.cuh

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Check the struct/class definition for paged_kv_like to understand data layout
rg -n "paged_kv_like\|PagedAttentionKVCacheBatched" -B 5 -A 20 --type=cpp --type=cuh include/flashinfer/ | head -150

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Examine the vec_t template from namespace/header files
find include/flashinfer -name "*.cuh" -o -name "*.h" | xargs grep -l "struct vec_t\|template.*vec_t" | head -3

Length of output: 33


🏁 Script executed:

#!/bin/bash
# Look at vec_t implementation
rg -n "struct vec_t\|template.*class vec_t" -A 50 --type=cpp --type=cuh include/flashinfer/ | head -200

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Check rope_chunk_size initialization relative to vec_size in the kernel launch
rg -n "RoPEApplyPagedKVWithQuant" --type=cpp --type=cuh include/flashinfer/pos_enc.cuh | head -5

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Search entire file for function definition containing the lines in question
grep -n "kernel\|__global__\|__device__" include/flashinfer/pos_enc.cuh | grep -B 5 "91[0-9]"

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Find the kernel/function that contains lines 915-929
awk 'NR >= 900 && NR <= 930 { if (/^[a-zA-Z_].*[({]|__global__|__device__/) print NR": "$0 }' include/flashinfer/pos_enc.cuh

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Get broader function context - find the containing function
sed -n '700,930p' include/flashinfer/pos_enc.cuh | tail -230

Length of output: 11018


🏁 Script executed:

#!/bin/bash
# Search for cast_load to understand whether it can safely read beyond bounds
rg -n "cast_load" include/ --type=cpp --type=cuh -A 3 -B 1 | grep -E "cast_load|^--$" | head -40

Length of output: 27


I need to examine the vec_t implementation and understand the dispatch constraints. Let me continue verification.


🏁 Script executed:

#!/bin/bash
# Look at vec_t implementation
cat include/flashinfer/vec_dtypes.cuh | head -200

Length of output: 5962


🏁 Script executed:

#!/bin/bash
# Search for cast_load specifically
grep -n "cast_load" include/flashinfer/vec_dtypes.cuh -A 5 -B 2 | head -80

Length of output: 2381


🏁 Script executed:

#!/bin/bash
# Check dispatch for head_dim to see if alignment is enforced
grep -n "DISPATCH_HEAD_DIM\|HEAD_DIM" include/flashinfer/pos_enc.cuh -B 2 -A 5

Length of output: 5141


🏁 Script executed:

#!/bin/bash
# Look for buffer allocation and padding in KV cache
fd -e cuh -e cu -e h include/flashinfer/ -x grep -l "paged_kv\|PagedKVCache" | head -5

Length of output: 381


🏁 Script executed:

#!/bin/bash
# Check if rope_dim has constraints
grep -n "rope_dim" include/flashinfer/pos_enc.cuh | grep -E "constrain|assert|static|require|dispatch"

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Look for K nope unconditional cast_load
sed -n '875,895p' include/flashinfer/pos_enc.cuh

Length of output: 864


🏁 Script executed:

#!/bin/bash
# Check cast_store in the same locations
grep -n "cast_store" include/flashinfer/pos_enc.cuh -B 1 -A 1

Length of output: 2135


🏁 Script executed:

#!/bin/bash
# Look at DISPATCH_ROPE_DIM to see constraints
sed -n '715,750p' include/flashinfer/pos_enc.cuh

Length of output: 2085


🏁 Script executed:

#!/bin/bash
# Check Q rope and Q nope paths for similar patterns
sed -n '795,850p' include/flashinfer/pos_enc.cuh

Length of output: 2371


🏁 Script executed:

#!/bin/bash
# Look for vec_t definition structure
sed -n '1,100p' include/flashinfer/vec_dtypes.cuh

Length of output: 3263


🏁 Script executed:

#!/bin/bash
# Check if there's unroll pragma impact on bounds
grep -n "pragma unroll\|#pragma" include/flashinfer/pos_enc.cuh | grep -A 1 -B 1 "91[2-9]\|92[0-5]"

Length of output: 77


🏁 Script executed:

#!/bin/bash
# Search for any buffer overallocation or padding
grep -n "padding\|overallocate\|extra.*alloc\|buffer.*size" include/flashinfer/ -r --include="*.cuh" --include="*.cu" | head -20

Length of output: 2607


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 401bab3 and 8a6d1dc.

📒 Files selected for processing (1)
  • benchmarks/bench_rope_quantize_fp8_append_cache.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_rope_quantize_fp8_append_cache.py (4)
flashinfer/testing/utils.py (1)
  • bench_gpu_time_with_cudagraph (855-969)
tests/test_helpers/rope_reference.py (1)
  • RotaryEmbedding (117-232)
flashinfer/page.py (2)
  • get_seq_lens (212-235)
  • get_batch_indices_positions (157-209)
flashinfer/rope.py (1)
  • rope_quantize_fp8_append_paged_kv_cache (1424-1657)
🪛 Ruff (0.14.4)
benchmarks/bench_rope_quantize_fp8_append_cache.py

59-59: Avoid specifying long messages outside the exception class

(TRY003)


198-198: Useless if-else condition

(RUF034)


302-302: subprocess call: check for execution of untrusted input

(S603)


313-313: Ambiguous variable name: l

(E741)


365-365: Do not catch blind exception: Exception

(BLE001)


370-370: Do not catch blind exception: Exception

(BLE001)


431-431: Do not catch blind exception: Exception

(BLE001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

@kahyunnam kahyunnam force-pushed the knam/RoPE-fusion-part2 branch from 8a6d1dc to 49b70ec Compare November 13, 2025 04:42
Collaborator

@yzh119 yzh119 left a comment

Hi @kahyunnam, thanks for the great work. I ran the benchmark on Hopper, and the achieved bandwidth reaches the 80%+ region for large numbers of tokens, which meets our previous standard for memory-bound kernels, so I think this is ready to merge.

In the future we need a better understanding of Blackwell's expected performance on memory-bound kernels.

I added one commit to your branch (3a97554) to replace ncu with pynvml for obtaining GPU memory bandwidth, since not all user environments have ncu installed; let me know if you think it's not correct.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
flashinfer/utils.py (1)

258-296: Consider module-level NVML initialization for better resource management.

The function initializes and shuts down NVML on each unique device query. While NVML's reference-counted init/shutdown makes this safe, and the @functools.cache decorator ensures it only happens once per device, there are cleaner alternatives:

Optional improvement: Module-level NVML lifecycle management

_nvml_initialized = False

def _ensure_nvml_init():
    global _nvml_initialized
    if not _nvml_initialized:
        pynvml.nvmlInit()
        _nvml_initialized = True

@functools.cache
def get_gpu_memory_bandwidth(device: torch.device) -> float:
    """..."""
    if isinstance(device, str):
        device = torch.device(device)
    
    if device.type != "cuda":
        raise ValueError(f"Device must be a CUDA device, got {device}")
    
    device_index = device.index if device.index is not None else 0
    
    _ensure_nvml_init()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    bus_width = pynvml.nvmlDeviceGetMemoryBusWidth(handle)
    mem_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_MEM)
    
    # Calculate theoretical peak bandwidth (GB/s)
    bandwidth = (mem_clock * bus_width * 2) / 8 / 1000
    
    return bandwidth

This eliminates the try-finally overhead and avoids repeatedly incrementing/decrementing NVML's reference count.

benchmarks/bench_rope_quantize_fp8_append_cache.py (1)

197-197: Simplify useless conditional expression.

Both branches of the ternary operator return the same value "NHD", making the condition pointless.

Apply this diff:

-            kv_layout="NHD" if config_name != "mla" else "NHD",
+            kv_layout="NHD",
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 49b70ec and 3a97554.

📒 Files selected for processing (2)
  • benchmarks/bench_rope_quantize_fp8_append_cache.py (1 hunks)
  • flashinfer/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/utils.py (1)
include/flashinfer/trtllm/common.h (1)
  • device (83-90)
benchmarks/bench_rope_quantize_fp8_append_cache.py (5)
flashinfer/testing/utils.py (1)
  • bench_gpu_time_with_cudagraph (855-969)
flashinfer/utils.py (1)
  • get_gpu_memory_bandwidth (259-295)
tests/test_helpers/rope_reference.py (1)
  • RotaryEmbedding (117-232)
flashinfer/page.py (2)
  • get_seq_lens (212-235)
  • get_batch_indices_positions (157-209)
flashinfer/rope.py (1)
  • rope_quantize_fp8_append_paged_kv_cache (1424-1657)
🪛 Ruff (0.14.4)
flashinfer/utils.py

278-278: Avoid specifying long messages outside the exception class

(TRY003)

benchmarks/bench_rope_quantize_fp8_append_cache.py

58-58: Avoid specifying long messages outside the exception class

(TRY003)


197-197: Useless if-else condition

(RUF034)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs


import torch
import torch.version
import pynvml
Contributor

⚠️ Potential issue | 🟠 Major

Add pynvml to project dependencies or use conditional import.

The pynvml library is imported at the module level but does not appear in the project's dependency manifest. This will cause an ImportError for users who don't have pynvml installed.

Choose one of these solutions:

Solution 1 (recommended): Add pynvml to dependencies

If pynvml is a required dependency, add it to your setup.py, pyproject.toml, or requirements.txt.

Solution 2: Use conditional import with graceful degradation

If the bandwidth utility is optional, use a try-except import pattern:

-import pynvml
+try:
+    import pynvml
+    HAS_PYNVML = True
+except ImportError:
+    HAS_PYNVML = False

Then guard the function:

 @functools.cache
 def get_gpu_memory_bandwidth(device: torch.device) -> float:
+    if not HAS_PYNVML:
+        raise ImportError(
+            "pynvml is required for get_gpu_memory_bandwidth. "
+            "Install it with: pip install nvidia-ml-py"
+        )
     ...

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In flashinfer/utils.py around line 24 the module imports pynvml at top-level
which will raise ImportError for users without that package; either add pynvml
to the project dependencies (setup.py / pyproject.toml / requirements.txt) if it
is required, or change the file to perform a conditional import: wrap import
pynvml in try/except ImportError, set a fallback flag/None when unavailable, and
update any functions that use pynvml to check the flag and raise a clear error
or gracefully no-op with documentation so the package can work without pynvml
installed.

@kahyunnam
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !111 has been updated with latest changes, and the CI pipeline #38402763 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (2)
include/flashinfer/pos_enc.cuh (2)

922-947: 🔴 CRITICAL: Out-of-bounds V tail vector loads (unresolved from previous review).

This critical issue was flagged by coderabbitai[bot] in a previous review but remains unaddressed. The bounds check at line 934 validates only the START position, not the full vec_size-element read extent.

Vulnerability:
When head_dim_total % vec_size != 0, the final threads in the last chunk read beyond the allocated buffer via the vectorized cast_load at line 936.

Example:

  • head_dim_total = 100, vec_size = 8, v_elem_offset = 96, tx = 0
  • Check: 96 + 0*8 = 96 < 100 ✓ (passes)
  • Read: indices 96–103 (overruns by 4 elements)

Required fix: Check the full vector extent and use scalar tail for partial vectors:

         for (uint32_t j = 0; j < v_chunks; ++j) {
           uint32_t v_elem_offset = j * rope_chunk_size;
-          if (v_elem_offset + tx * vec_size < head_dim_total) {
+          uint32_t start = v_elem_offset + tx * vec_size;
+          uint32_t remaining = (start < head_dim_total) ? (head_dim_total - start) : 0;
+          if (remaining >= vec_size) {
             vec_t<float, vec_size> v_vec;
-            v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
+            v_vec.cast_load(v_in_ptr + start);
+#pragma unroll
+            for (uint32_t i = 0; i < vec_size; ++i) {
+              v_vec[i] = v_vec[i] * quant_scale_kv;
+            }
+            QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, start);
+            v_vec.cast_store(v_ptr);
+          } else if (remaining > 0) {
+            // Scalar tail for partial vector
+            vec_t<float, vec_size> v_vec;
+            #pragma unroll
+            for (uint32_t t = 0; t < remaining; ++t) v_vec[t] = v_in_ptr[start + t];
+            #pragma unroll
+            for (uint32_t t = remaining; t < vec_size; ++t) v_vec[t] = 0.f;
 #pragma unroll
             for (uint32_t i = 0; i < vec_size; ++i) {
               v_vec[i] = v_vec[i] * quant_scale_kv;
             }
-            QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                                       v_elem_offset + tx * vec_size);
-            v_vec.cast_store(v_ptr);
+            QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx, start);
+            #pragma unroll
+            for (uint32_t t = 0; t < remaining; ++t) v_ptr[t] = static_cast<QuantType>(v_vec[t]);
           }
         }

Based on learnings (past review comments).
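
To sanity-check the index math above, here is a small standalone Python sketch (not part of the PR) that models the per-chunk, per-thread access pattern; the chunk-to-thread mapping it assumes is a simplification of the kernel's, and only the bounds arithmetic is meant to be illustrative.

```
# Illustrative model of the vectorized access pattern discussed above (not kernel code).
# It assumes one chunk is covered by rope_chunk_size // vec_size threads, each doing
# a single vec_size-wide load; only the bounds arithmetic is being checked here.

def find_partial_vectors(head_dim_total, vec_size, rope_chunk_size):
    """Report threads whose unguarded vec_size-wide load would overrun the buffer."""
    v_chunks = (head_dim_total + rope_chunk_size - 1) // rope_chunk_size
    threads_per_chunk = rope_chunk_size // vec_size
    for j in range(v_chunks):
        v_elem_offset = j * rope_chunk_size
        for tx in range(threads_per_chunk):
            start = v_elem_offset + tx * vec_size
            if start >= head_dim_total:
                continue  # a start-only check already skips this thread
            remaining = head_dim_total - start
            if remaining < vec_size:
                print(f"chunk {j}, tx {tx}: loads [{start}, {start + vec_size}) "
                      f"but only {remaining} elements are valid -> needs a scalar tail")

# The reviewer's example: head_dim_total = 100, vec_size = 8 -> the load starting at
# element 96 touches indices 96..103 and overruns the valid range by 4 elements.
find_partial_vectors(head_dim_total=100, vec_size=8, rope_chunk_size=32)
```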


891-921: 🔴 CRITICAL: Out-of-bounds K non-RoPE vector loads/stores.

Lines 906 and 919 perform unconditional vectorized loads/stores without bounds checking. When no_rope_dim % rope_chunk_size != 0, the last chunk is partial and threads can read/write beyond allocated memory.

Example scenario:

  • no_rope_dim = 100, rope_chunk_size = 64, vec_size = 8
  • Last chunk processes elem_offset = 64, but only indices 64-99 are valid
  • Thread tx = 5 attempts to load from index 64 + 5*8 = 104, which exceeds 100

Required fix: Add bounds checking similar to the V path, but validate the FULL vector extent:

     vec_t<float, vec_size> k_nope_vec;
-    k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
+    uint32_t start = elem_offset + tx * vec_size;
+    uint32_t remaining = (start < no_rope_dim) ? (no_rope_dim - start) : 0;
+    if (remaining >= vec_size) {
+      k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
+    } else if (remaining > 0) {
+      // Scalar tail for partial vector
+      #pragma unroll
+      for (uint32_t t = 0; t < remaining; ++t) k_nope_vec[t] = k_nope_in_ptr[tx * vec_size + t];
+      #pragma unroll
+      for (uint32_t t = remaining; t < vec_size; ++t) k_nope_vec[t] = 0.f;
+    } else {
+      return; // Out of bounds, skip processing
+    }
+
 #pragma unroll
     for (uint32_t i = 0; i < vec_size; ++i) {
       k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
     }

     if constexpr (IS_MLA) {
       QuantType* ckv_ptr =
           paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
-      k_nope_vec.cast_store(ckv_ptr);
+      if (remaining >= vec_size) {
+        k_nope_vec.cast_store(ckv_ptr);
+      } else if (remaining > 0) {
+        #pragma unroll
+        for (uint32_t t = 0; t < remaining; ++t) ckv_ptr[t] = static_cast<QuantType>(k_nope_vec[t]);
+      }
     } else {
       QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx,
                                                  rope_dim + elem_offset + tx * vec_size);
-      k_nope_vec.cast_store(k_ptr);
+      if (remaining >= vec_size) {
+        k_nope_vec.cast_store(k_ptr);
+      } else if (remaining > 0) {
+        #pragma unroll
+        for (uint32_t t = 0; t < remaining; ++t) k_ptr[t] = static_cast<QuantType>(k_nope_vec[t]);
+      }
     }
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3a97554 and ec566bc.

📒 Files selected for processing (1)
  • include/flashinfer/pos_enc.cuh (5 hunks)
🔇 Additional comments (5)
include/flashinfer/pos_enc.cuh (5)

23-23: LGTM: Necessary includes for new functionality.

The <type_traits> include is used for std::is_same at line 801, and "page.cuh" provides the paged KV cache types required by the new kernel.

Also applies to: 27-27


33-55: LGTM: Well-structured parameter encapsulation.

The parameter struct effectively consolidates kernel arguments, improving maintainability and reducing the parameter list complexity.


801-801: IS_MLA is a readability alias, not true duplication.

Responding to nvpohanh's earlier question: While IS_MLA is technically derivable from CacheT, it serves as a clarity-enhancing alias that makes the multiple constexpr if branches (lines 861, 882, 898, 912, 924) more readable than repeated std::is_same<CacheT, paged_kv_mla_t<...>>::value checks. This is a standard C++ pattern for type trait aliases.

Based on learnings (past review comments).


1070-1154: LGTM: Well-structured host wrapper for GQA/MHA path.

The dispatch logic, grid/block calculations (including the V section at lines 1096-1098), parameter packing, and kernel launch configuration are all correctly implemented.


1156-1247: LGTM: Correct MLA-specific host wrapper.

The MLA specializations are properly implemented:

  • num_kv_heads = 1 (lines 1180 and 1211)
  • Grid excludes V section (line 1182)
  • v_in set to nullptr (line 1203) with zero V strides (lines 1207, 1226-1227)
  • 2D K stride duplication (lines 1205-1206, 1223-1225) correctly handles MLA's flattened K layout

Comment on lines +948 to +971
  } else {
    // ============ Q Non-RoPE processing ============
    // MLA has no V section, so Q-nope starts immediately after K-nope.
    // GQA/MHA has a V section of length num_kv_heads blocks.
    uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
    uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
    uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
    uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;

    DType* q_nope_in_ptr =
        q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
                                         q_nope_in_stride_h);
    QuantType* q_nope_out_ptr =
        q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
                                          q_nope_out_stride_h);

    vec_t<float, vec_size> q_nope_vec;
    q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
    }
    q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
  }
Contributor

⚠️ Potential issue | 🔴 Critical

🔴 CRITICAL: Out-of-bounds Q non-RoPE vector loads/stores.

Lines 965 and 970 have the same critical vulnerability as the K non-RoPE path: unconditional vectorized loads/stores without bounds checking when no_rope_dim % rope_chunk_size != 0.

Required fix: Apply the same bounds-checking pattern as suggested for the K non-RoPE path:

     vec_t<float, vec_size> q_nope_vec;
-    q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
+    uint32_t start = elem_offset + tx * vec_size;
+    uint32_t remaining = (start < no_rope_dim) ? (no_rope_dim - start) : 0;
+    if (remaining >= vec_size) {
+      q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
+    } else if (remaining > 0) {
+      #pragma unroll
+      for (uint32_t t = 0; t < remaining; ++t) q_nope_vec[t] = q_nope_in_ptr[tx * vec_size + t];
+      #pragma unroll
+      for (uint32_t t = remaining; t < vec_size; ++t) q_nope_vec[t] = 0.f;
+    } else {
+      return;
+    }
+
 #pragma unroll
     for (uint32_t i = 0; i < vec_size; ++i) {
       q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
     }
-    q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+    if (remaining >= vec_size) {
+      q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
+    } else if (remaining > 0) {
+      #pragma unroll
+      for (uint32_t t = 0; t < remaining; ++t) q_nope_out_ptr[tx * vec_size + t] = static_cast<QuantType>(q_nope_vec[t]);
+    }

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In include/flashinfer/pos_enc.cuh around lines 948 to 971, the Q non-RoPE path
does unconditional vectorized loads/stores (lines corresponding to the
q_nope_vec.cast_load and cast_store) which can read/write past the valid
no_rope_dim when no_rope_dim % rope_chunk_size != 0; replicate the K non-RoPE
fix: compute the number of valid elements remaining in this chunk, perform a
guarded/tail-aware load (fill remaining lanes with zeros or use masked load),
apply the scalar multiply only to valid lanes, and perform a guarded/tail-aware
store writing back only the valid elements (or use a masked store). Ensure the
same index arithmetic/offsets are used (tx * vec_size + lane) and mirror the
exact bounds-checking pattern used in the K path so no out-of-bounds memory
accesses occur.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ec566bc and 72bd59e.

📒 Files selected for processing (1)
  • include/flashinfer/utils.cuh (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
include/flashinfer/utils.cuh (1)

214-248: No duplicate macro definitions exist—code is correct.

The verification confirms that DISPATCH_ROPE_DIM has only one definition (line 214) and DISPATCH_INTERLEAVE has only one definition (line 205) in the file. The AI summary's claim about duplicate definitions is incorrect. The macro implementations are sound and follow the established pattern.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #38402763: 1/17 passed

@kahyunnam
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !111 has been updated with latest changes, and the CI pipeline #38527670 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #38527670: 14/17 passed

@yzh119
Collaborator

yzh119 commented Nov 18, 2025

Hi @kahyunnam let's merge this one first and fix the (potential) performance issues on A100 later.

@yzh119 yzh119 merged commit 3b07247 into flashinfer-ai:main Nov 18, 2025
4 checks passed
qsang-nv pushed a commit to qsang-nv/flashinfer that referenced this pull request Nov 18, 2025
…sed RoPE + Q + KV cache, supports MLA/GQA/MHA) (flashinfer-ai#2037)

<!-- .github/pull_request_template.md -->

## 📌 Description

Add `flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache`, which
runs a fused RoPE + Quantization (16 -> 8) + append KV Cache operation
kernel.

Note that this does not support optional quantization (there is no "RoPE
+ append KV Cache" fused operation available).

Tested on NVIDIA H100 NVL + flashinfer/flashinfer-ci-cu130:latest for
MLA/MHA/GQA problem sizes for decode and prefill cases.

## 🔍 Related Issues

"[Model Optimization] Add RoPE, RoPE+Q, RoPE+Q+KVCacheUpdate fused
kernels for MLA/GQA/MHA" item from Q4 roadmap:
flashinfer-ai#1770.

This PR is part 2 to earlier PR for RoPE + Q:
flashinfer-ai#1924

FW Stakeholders: @nvpohanh @pavanimajety 

## 🧪 Test results

```
$ pytest tests/attention/test_rope.py::test_rope_quantize_fp8_append_paged_kv_cache_decode -s
======================================================== test session starts =========================================================platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 384 items

tests/attention/test_rope.py ................................................................................................................................................................................................................................................................................................................................................................................................

======================================================== 384 passed in 35.22s ========================================================
```

```
$ pytest tests/attention/test_rope.py::test_generalized_rope_quantize_append_kv_cache -s
======================================================== test session starts =========================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /workspace/flashinfer
configfile: pytest.ini
collected 1248 items

tests/attention/test_rope.py .........................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
......................................................................................................................................
.......................................................................

================================================== 1248 passed in 63.07s (0:01:03) ===================================================
```

```
$ python benchmarks/bench_rope_quantize_fp8_append_cache.py

Detected GPU: NVIDIA GB200
Theoretical Peak Memory Bandwidth: 7928.06 GB/s


====================================================================================================
  MLA: 128 Q heads, 1 K head, 64+512 dims (DeepSeek-style)
====================================================================================================
Tokens     Time (ms)    BW (GB/s)    BW% (Peak)     TFLOPs
----------------------------------------------------------------------
1          0.00258      86.53        1.1            0.010
32         0.00381      1873.82      23.6           0.208
128        0.00763      3744.50      47.2           0.416
384        0.01848      4637.34      58.5           0.515
768        0.03694      4639.75      58.5           0.515
1024       0.04879      4683.57      59.1           0.520
2048       0.09590      4766.09      60.1           0.529
4096       0.19031      4803.27      60.6           0.533
8192       0.38523      4745.78      59.9           0.527

====================================================================================================
  GQA: 32 Q heads, 8 K heads, 64+64 dims (Llama-style)
====================================================================================================
Tokens     Time (ms)    BW (GB/s)    BW% (Peak)     TFLOPs
----------------------------------------------------------------------
1          0.00294      6.36         0.1            0.003
32         0.00316      189.48       2.4            0.078
128        0.00317      755.23       9.5            0.310
384        0.00398      1803.09      22.7           0.741
768        0.00522      2750.51      34.7           1.130
1024       0.00617      3100.80      39.1           1.274
2048       0.00927      4130.83      52.1           1.697
4096       0.01631      4695.01      59.2           1.929
8192       0.03466      4418.01      55.7           1.815

====================================================================================================
  MHA: 32 Q heads, 32 K heads, 64+64 dims (Standard)
====================================================================================================
Tokens     Time (ms)    BW (GB/s)    BW% (Peak)     TFLOPs
----------------------------------------------------------------------
1          0.00293      12.68        0.2            0.004
32         0.00313      379.98       4.8            0.126
128        0.00357      1331.80      16.8           0.441
384        0.00517      2756.73      34.8           0.912
768        0.00742      3840.41      48.4           1.271
1024       0.00887      4287.15      54.1           1.419
2048       0.01504      5055.18      63.8           1.673
4096       0.03343      4548.12      57.4           1.505
8192       0.06410      4744.76      59.8           1.571

====================================================================================================
Configuration details:
  Page size: 32, Batch size: 4
  Token range: 1 (single decode) → 8192 (large prefill)
  GPU: NVIDIA GB200
  Theoretical Peak Memory Bandwidth: 7928.06 GB/s
  BW% calculated as: (achieved_bandwidth / peak_bandwidth) * 100
====================================================================================================

```

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Fused RoPE + FP8 quantize-and-append for paged KV caches (MLA,
GQA/MHA) with layout, page-size, interleave and PDL options; returns
quantized Q outputs and writes K/V into paged caches; public ops and
high-level API added.

* **Tests**
* Deterministic, parameterized tests for append and decode/continuation
across attention types, layouts, dtypes and quant settings with
reference validation.

* **Benchmarks**
* New benchmark script for performance, bandwidth and Nsight profiling
of the paged-KV quantize+append path.

* **Chores**
  * Added cached GPU memory-bandwidth utility for benchmarks.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Zihao Ye <[email protected]>