35 commits
99fc347
move weights loading related logic to ModelLoader
QiJune Sep 6, 2025
018b022
fix
QiJune Sep 6, 2025
0242a6e
rebase
QiJune Sep 8, 2025
e4a542c
Merge branch 'main' into model_loader
QiJune Sep 8, 2025
cf84a58
clean
QiJune Sep 8, 2025
746e486
fix
QiJune Sep 8, 2025
2c92b6e
Merge branch 'main' into model_loader
QiJune Sep 9, 2025
38505f5
Merge branch 'main' into model_loader
QiJune Sep 9, 2025
c64d320
fix ci
QiJune Sep 9, 2025
e0e5bf8
rebase
QiJune Sep 9, 2025
18e79fe
Merge branch 'main' into model_loader
QiJune Sep 16, 2025
96e700e
Merge branch 'main' into model_loader
QiJune Sep 16, 2025
1a9b420
Merge branch 'main' into model_loader
QiJune Sep 17, 2025
91d79d6
Merge branch 'main' into model_loader
QiJune Sep 17, 2025
9656901
Merge branch 'main' into model_loader
QiJune Sep 18, 2025
133a9eb
Merge branch 'main' into model_loader
QiJune Sep 18, 2025
9556a94
Merge branch 'main' into model_loader
QiJune Sep 18, 2025
1ae8cfd
rebase
QiJune Sep 19, 2025
ea079fa
[None][infra] Waive failed tests in post-merge (#7859)
EmmaQiaoCh Sep 19, 2025
451475e
[None][ci] Waive llama3 auto dtype test bug in https://nvbugs/5527956…
dominicshanshan Sep 19, 2025
6b33bcc
[None][test] Add accuracy benchmark in stress test (#7561)
crazydemo Sep 19, 2025
0ac5148
[None][chore] remove cli cases for rtx6k (#7833)
crazydemo Sep 19, 2025
0e72e8f
[None][feat] Support EPLB in Qwen3 MoE (#7443)
lucifer1004 Sep 19, 2025
efb7634
[None][chore] Add failed cases into waives.txt (#7841)
xinhe-nv Sep 19, 2025
18095a7
[https://nvbugs/5503440][fix] Fix potential hang due to wrong type of…
lancelly Sep 19, 2025
c8cc16d
[None][doc] Tech blog: Combining Guided Decoding and Speculative Deco…
syuoni Sep 19, 2025
7d28acd
[https://nvbugs/5522332][fix] Pin numpy version for Gemma. (cherry-pi…
yuxianq Sep 19, 2025
1be7fae
[TRTLLM-5966][feat] Helix: add custom position ids to MLA kernels (#6…
MatthiasKohl Sep 19, 2025
fbe325c
[https://nvbugs/5471108][chore] Unwaiving disagg acc test (#7686)
pcastonguay Sep 19, 2025
fd2a93e
Merge branch 'main' into model_loader
QiJune Sep 19, 2025
8030b54
[https://nvbugs/5522462][fix] Fix FP8 scout illegal memory access (#7…
mikeiovine Sep 19, 2025
8fcd115
[#7704][chore] Enable MathJax to fix formulas in documentation (#7744)
karljang Sep 19, 2025
8adaf0b
[TRTLLM-6342][feat] Support for partial sharding from factory (#7393)
greg-kwasniewski1 Sep 19, 2025
2e317a7
[https://nvbugs/5520490][fix] Fix intermittent test failures by avoid…
chang-l Sep 20, 2025
41291db
Merge branch 'main' into model_loader
QiJune Sep 20, 2025
3 changes: 3 additions & 0 deletions README.md
@@ -18,6 +18,9 @@ TensorRT-LLM
<div align="left">

## Tech Blogs
* [09/19] Combining Guided Decoding and Speculative Decoding: Making CPU and GPU Cooperate Seamlessly
[➡️ link](./docs/source/blogs/tech_blog/blog12_Combining_Guided_Decoding_and_Speculative_Decoding.md)

* [08/29] ADP Balance Strategy
[➡️ link](./docs/source/blogs/tech_blog/blog10_ADP_Balance_Strategy.md)

@@ -1363,7 +1363,8 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
#ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS
if (sizeof(Tk) != 4)
{
auto const max_timesteps = min(timestep, static_cast<unsigned>(cyclic_kv_cache_len));
auto const max_timesteps
= min(timestep, min(chunked_attention_size, static_cast<unsigned>(cyclic_kv_cache_len)));
logits_smem_ += divUp(max_timesteps + 1, 4u) * 16;
}
Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_);
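
A rough numeric sketch (not part of the diff; all values are assumed) of what the extra `min` changes: with chunked attention, the logits shared-memory offset is bounded by the chunk size rather than the full cyclic KV cache length.

```cpp
#include <algorithm>
#include <cstdio>

int main()
{
    // Assumed values for illustration only.
    unsigned timestep = 10000;              // current generation step
    unsigned chunked_attention_size = 2048; // attention chunk length
    unsigned cyclic_kv_cache_len = 4096;    // cyclic KV cache length

    // Mirrors the kernel: clamp by both the chunk size and the cyclic cache length.
    unsigned max_timesteps = std::min(timestep, std::min(chunked_attention_size, cyclic_kv_cache_len));
    // divUp(max_timesteps + 1, 4u) * 16, as in the hunk above.
    unsigned logits_smem_offset = ((max_timesteps + 1 + 3) / 4) * 16;
    std::printf("max_timesteps=%u, logits_smem offset=%u\n", max_timesteps, logits_smem_offset);
    return 0;
}
```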
8 changes: 5 additions & 3 deletions cpp/tensorrt_llm/kernels/mlaKernels.cu
@@ -335,7 +335,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
int* seqQOffset, uint32_t* fmha_tile_counter, int32_t const* kv_cache_lengths, int* seqKVOffsets, int q_pe_ld,
int q_pe_stride, KvCacheDataType cache_type, float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o,
float const* quant_scale_q, float const* quant_scale_kv, float const* dequant_scale_q,
float const* dequant_scale_kv, float host_bmm1_scale)
float const* dequant_scale_kv, float host_bmm1_scale, int32_t const* helix_position_offsets)
{

// Constants.
@@ -409,7 +409,9 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
if (valid_token)
{

auto const position_id = kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
auto const position_id
= (helix_position_offsets != nullptr ? helix_position_offsets[global_token_idx]
: kv_cache_lengths[batch_idx] - seq_len + local_token_idx);
float2 const* rotary_coef_cache_buffer
= cos_sin_cache + static_cast<size_t>(ROPE_DIM) * position_id + (head_dim_idx / 2);

@@ -992,7 +994,7 @@ void invokeMLARopeGeneration(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer
params.seqQOffset, params.fmha_tile_counter, params.cache_seq_lens, params.cu_kv_seqlens, params.q_pe_ld,
params.q_pe_stride, params.cache_type, params.bmm1_scale, params.bmm2_scale, params.quant_scale_o,
params.quant_scale_q, params.quant_scale_kv, params.dequant_scale_q, params.dequant_scale_kv,
params.host_bmm1_scale);
params.host_bmm1_scale, params.helix_position_offsets);
}

template <typename T, typename TCache>
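
A minimal host-side paraphrase of the position-id selection added above (a sketch, not kernel code; the index names follow the hunk but the surrounding layout is simplified): when `helix_position_offsets` is supplied, the rotary position for each token comes from the caller; otherwise it is derived from the cached sequence length.

```cpp
#include <cstdint>

// Sketch of the branch added in applyMLARopeAndAssignQKVKernelGeneration.
inline int32_t rotaryPositionId(int32_t const* helix_position_offsets, // optional, per-token offsets
    int32_t const* kv_cache_lengths,                                   // per-sequence cached lengths
    int batch_idx, int seq_len, int local_token_idx, int global_token_idx)
{
    if (helix_position_offsets != nullptr)
    {
        // Helix parallelism: the caller provides the rotary position explicitly.
        return helix_position_offsets[global_token_idx];
    }
    // Default path: position counted from the end of the cached sequence.
    return kv_cache_lengths[batch_idx] - seq_len + local_token_idx;
}
```

The result indexes `cos_sin_cache` exactly as before, so the offsets only change which rotary angle a given token receives.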
3 changes: 3 additions & 0 deletions cpp/tensorrt_llm/kernels/mlaKernels.h
@@ -101,6 +101,9 @@ struct MlaParams

// For FP8 context qkv quantization
float const* quant_scale_qkv = nullptr;

// for Helix parallelism: the rotary position offsets [b]
int32_t const* helix_position_offsets{nullptr};
};

template <typename T, typename KVCacheBuffer>
5 changes: 3 additions & 2 deletions cpp/tensorrt_llm/nanobind/thop/bindings.cpp
@@ -53,8 +53,9 @@ void initBindings(nb::module_& m)
nb::arg("kv_lora_rank") = std::nullopt, nb::arg("qk_nope_head_dim") = std::nullopt,
nb::arg("qk_rope_head_dim") = std::nullopt, nb::arg("v_head_dim") = std::nullopt,
nb::arg("mrope_rotary_cos_sin") = std::nullopt, nb::arg("mrope_position_deltas") = std::nullopt,
nb::arg("attention_chunk_size") = std::nullopt, nb::arg("softmax_stats_tensor") = std::nullopt,
nb::arg("spec_decoding_bool_params"), nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation",
nb::arg("mla_tensor_params"), nb::arg("attention_chunk_size") = std::nullopt,
nb::arg("softmax_stats_tensor") = std::nullopt, nb::arg("spec_decoding_bool_params"),
nb::arg("spec_decoding_tensor_params"), "Multi-head attention operation",
nb::call_guard<nb::gil_scoped_release>());
}
} // namespace tensorrt_llm::nanobind::thop
5 changes: 3 additions & 2 deletions cpp/tensorrt_llm/pybind/thop/bindings.cpp
@@ -53,8 +53,9 @@ void initBindings(pybind11::module_& m)
py::arg("kv_lora_rank") = std::nullopt, py::arg("qk_nope_head_dim") = std::nullopt,
py::arg("qk_rope_head_dim") = std::nullopt, py::arg("v_head_dim") = std::nullopt,
py::arg("mrope_rotary_cos_sin") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
py::arg("attention_chunk_size") = std::nullopt, py::arg("softmax_stats_tensor") = std::nullopt,
py::arg("spec_decoding_bool_params"), py::arg("spec_decoding_tensor_params"), "Multi-head attention operation",
py::arg("mla_tensor_params"), py::arg("attention_chunk_size") = std::nullopt,
py::arg("softmax_stats_tensor") = std::nullopt, py::arg("spec_decoding_bool_params"),
py::arg("spec_decoding_tensor_params"), "Multi-head attention operation",
py::call_guard<py::gil_scoped_release>());
}
} // namespace tensorrt_llm::pybind::thop
20 changes: 16 additions & 4 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -79,6 +79,7 @@ class RunnerBase
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::vector<std::optional<torch::Tensor>> mla_tensor_params,
torch::optional<torch::Tensor> softmax_stats_tensor,
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks) const
@@ -133,6 +134,7 @@ class Runner : public RunnerBase
torch::optional<torch::Tensor> rotary_cos_sin, torch::optional<torch::Tensor> latent_cache,
torch::optional<torch::Tensor> q_pe, torch::optional<torch::Tensor> block_ids_per_seq,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::vector<std::optional<torch::Tensor>> mla_tensor_params,
torch::optional<torch::Tensor> softmax_stats_tensor,
c10::ArrayRef<std::optional<torch::Tensor>> spec_decoding_tensor_params,
torch::optional<torch::Tensor> attention_sinks) const override
@@ -165,6 +167,8 @@ class Runner : public RunnerBase
[[maybe_unused]] MlaParams<T> mla_params;
if (op.isMLAEnabled())
{
TORCH_CHECK(mla_tensor_params.size() == 1,
"Expecting 1 tensor for custom MLA tensor params: helix_position_offsets.");
if (is_context)
{
if (latent_cache.has_value())
@@ -210,6 +214,11 @@ class Runner : public RunnerBase
mla_params.meta = op.mMLAParams;

mla_params.workspace = workspace_ptr;
auto& mla_helix_position_offsets = mla_tensor_params[0];
if (mla_helix_position_offsets.has_value())
{
mla_params.helix_position_offsets = mla_helix_position_offsets->data_ptr<int32_t>();
}
}

int const* context_lengths_ptr = context_lengths.slice(0, seq_offset).data_ptr<int>();
@@ -507,8 +516,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
std::optional<torch::Tensor> mrope_rotary_cos_sin, std::optional<torch::Tensor> mrope_position_deltas,
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params)
std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params)
{
TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
// Use these tensors to infer if the attention is using KV cache
@@ -748,7 +758,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks);
}

if ((num_generations > 0) && (attn_input_type != AttentionInputType::ContextOnly))
@@ -764,7 +775,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
kv_cache_block_offsets, host_kv_cache_block_offsets, host_kv_cache_pool_pointers,
host_kv_cache_pool_mapping, cache_indirection, kv_scale_orig_quant, kv_scale_quant_orig, out_scale,
rotary_inv_freq, rotary_cos_sin, latent_cache, q_pe, block_ids_per_seq, mrope_rotary_cos_sin,
mrope_position_deltas, softmax_stats_tensor, spec_decoding_tensor_params, attention_sinks);
mrope_position_deltas, mla_tensor_params, softmax_stats_tensor, spec_decoding_tensor_params,
attention_sinks);
}

TLLM_LOG_TRACE("Attention op stops at layer %d", layer_idx);
5 changes: 3 additions & 2 deletions cpp/tensorrt_llm/thop/attentionOp.h
@@ -58,7 +58,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
std::optional<int64_t> kv_lora_rank, std::optional<int64_t> qk_nope_head_dim,
std::optional<int64_t> qk_rope_head_dim, std::optional<int64_t> v_head_dim,
torch::optional<torch::Tensor> mrope_rotary_cos_sin, torch::optional<torch::Tensor> mrope_position_deltas,
std::optional<int64_t> attention_chunk_size, std::optional<torch::Tensor> softmax_stats_tensor,
std::vector<bool> spec_decoding_bool_params, std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params);
std::vector<std::optional<torch::Tensor>> mla_tensor_params, std::optional<int64_t> attention_chunk_size,
std::optional<torch::Tensor> softmax_stats_tensor, std::vector<bool> spec_decoding_bool_params,
std::vector<std::optional<torch::Tensor>> spec_decoding_tensor_params);

} // namespace torch_ext
36 changes: 18 additions & 18 deletions docs/source/blogs/tech_blog/blog10_ADP_Balance_Strategy.md
@@ -45,44 +45,44 @@ To address this critical performance limitation, we introduce the **ADP (Attenti

We model and quantify the performance impact of load imbalance in Attention DP. Since workloads across ranks can be heterogeneous, the execution time for the Attention module in any given iteration is bounded by the rank with the highest workload:

```math
$$
time_i = \max_{0 \leq m < N} time_{i,m}
```
$$

where $time_{i,m}$ represents the execution time of rank $m$ in iteration $i$, and $N$ is the data parallel size.

To quantify load balance and theoretical performance bounds, we define two key metrics:

#### 1. Balance Ratio
The $balance\\_ratio$ measures the load distribution across ranks within the Attention module for each iteration:
The balance ratio measures the load distribution across ranks within the Attention module for each iteration:

```math
balance\_ratio = \frac{avg\_tokens}{max\_tokens}
```
$$
balance = \frac{tokens_{avg}}{tokens_{max}}
$$

where:
- $avg\\_tokens$ represents the average number of tokens across all ranks
- $max\\_tokens$ represents the maximum number of tokens across all ranks
- $tokens_{avg}$ represents the average number of tokens across all ranks
- $tokens_{max}$ represents the maximum number of tokens across all ranks
- $tokens_i$ represents the number of tokens processed by rank $i$

Note: MoE module load balancing is handled separately by the Expert Parallel Load Balancer (EPLB) module and is not considered during the early scheduling phase.

#### 2. Speed-of-Light Throughput (SOL TPS)
The $sol\\_tps$ represents the theoretical upper-bound throughput achievable with perfect load balancing:
The Speed-of-Light throughput represents the theoretical upper-bound throughput achievable with perfect load balancing:

```math
sol\_time = \sum_{i=0}^{\infty} time_i * balance\_ratio_i
```
$$
time_{sol} = \sum_{i=0}^{\infty} time_i \times balance
$$

```math
sol\_tps = \frac{elapsed\_time}{sol\_time} \times actual\_tps
```
$$
tps_{sol} = \frac{time_{elapsed}}{time_{sol}} \times tps_{actual}
$$

where:
- $time_i$: Measured execution time of iteration $i$
- $elapsed\\_time$: Total empirically measured end-to-end execution time
- $actual\\_tps$: Observed throughput in tokens per second
- $sol\\_tps$: Theoretical maximum throughput under perfect load balance
- $time_{elapsed}$: Total empirically measured end-to-end execution time
- $tps_{actual}$: Observed throughput in tokens per second
- $tps_{sol}$: Theoretical maximum throughput under perfect load balance

This theoretical framework enables us to quantify the performance gap between current and optimal system utilization, providing clear targets for optimization.
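
To make the two metrics concrete, a small offline sketch (not from the blog; inputs are assumed) that computes the per-iteration balance ratio and the resulting speed-of-light throughput from per-rank token counts, iteration times, and the observed throughput.

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    // Assumed measurements: tokens per rank for each iteration, and each iteration's time (ms).
    std::vector<std::vector<double>> tokens_per_rank = {{100, 80, 60, 100}, {90, 90, 85, 95}};
    std::vector<double> iter_time_ms = {12.0, 10.0};
    double tps_actual = 5000.0; // observed throughput (assumed)
    double elapsed_ms = 22.0;   // total measured end-to-end time (assumed)

    double time_sol_ms = 0.0;
    for (size_t i = 0; i < tokens_per_rank.size(); ++i)
    {
        auto const& t = tokens_per_rank[i];
        double tokens_avg = std::accumulate(t.begin(), t.end(), 0.0) / t.size();
        double tokens_max = *std::max_element(t.begin(), t.end());
        double balance = tokens_avg / tokens_max;    // balance = tokens_avg / tokens_max
        time_sol_ms += iter_time_ms[i] * balance;    // time_sol = sum_i time_i * balance_i
    }
    double tps_sol = elapsed_ms / time_sol_ms * tps_actual; // tps_sol = time_elapsed / time_sol * tps_actual
    std::printf("time_sol=%.2f ms, tps_sol=%.1f tokens/s\n", time_sol_ms, tps_sol);
    return 0;
}
```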
