@@ -520,7 +520,7 @@ __global__ void __launch_bounds__(1024) allreduce_fusion_kernel_oneshot_lamport(
}

template <AllReduceFusionPattern Pattern, typename DType, int NRanks, bool Fp32Acc>
__global__ void allreduce_fusion_kernel_twoshot_sync(
__global__ void __launch_bounds__(1024) allreduce_fusion_kernel_twoshot_sync(
AllReduceFusionParams params, std::array<int, NRanks> begin_tokens, std::array<int, NRanks> token_num_per_ranks)
{
IndexHelper<DType> index_helper(params);
2 changes: 1 addition & 1 deletion examples/models/core/llama4/README.md
@@ -134,7 +134,7 @@ python -m tensorrt_llm.serve.scripts.benchmark_serving \
- `max_batch_size` and `max_num_tokens` can significantly affect performance. The default values are carefully chosen and should deliver good performance in most cases; however, you may still need to tune them for peak performance.
- `max_batch_size` should not be set so low that it bottlenecks throughput. Note that with Attention DP, the whole system's max batch size will be `max_batch_size*dp_size`.
- The CUDA graph `max_batch_size` should be the same value as the TensorRT-LLM server's `max_batch_size`.
- For more details on `max_batch_size` and `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md).
- For more details on `max_batch_size` and `max_num_tokens`, refer to [Tuning Max Batch Size and Max Num Tokens](../../../../docs/source/performance/performance-tuning-guide/tuning-max-batch-size-and-max-num-tokens.md).

### Troubleshooting

233 changes: 202 additions & 31 deletions tensorrt_llm/_torch/models/modeling_llama.py

Large diffs are not rendered by default.

28 changes: 20 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -810,8 +810,11 @@ def _set_up_spec_metadata(
is_draft_model=self.is_draft_model)
return self.spec_metadata

def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
kv_cache_manager) -> int:
def _get_padded_batch(
self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
spec_resource_manager: Optional[BaseResourceManager] = None) -> int:
can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
batch_size = scheduled_requests.batch_size
# The number of sequences in the batch is the number of prompts times the beam width.
@@ -847,22 +850,29 @@ def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
if available_blocks < 1:
return 0

cuda_graph_dummy_request_ids = [MAX_UINT64 - 1]
self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
[MAX_UINT64 - 1],
cuda_graph_dummy_request_ids,
is_gen=True,
max_num_draft_tokens=self.max_draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width)[0]
self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(
request_ids=cuda_graph_dummy_request_ids)

scheduled_requests.generation_requests.extend(
[self.cuda_graph_dummy_request] * padding_size)

return padding_size

@contextlib.contextmanager
def _maybe_pad_batch(self, scheduled_requests: ScheduledRequests,
kv_cache_manager):
def _maybe_pad_batch(
self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
spec_resource_manager: Optional[BaseResourceManager] = None):
"""
CUDA graphs can only be used for specific batch sizes.

@@ -871,7 +881,8 @@ def _maybe_pad_batch(self, scheduled_requests: ScheduledRequests,
because the padded requests will be removed from scheduled requests.
"""
padding_size = self._get_padded_batch(scheduled_requests,
kv_cache_manager)
kv_cache_manager,
spec_resource_manager)
try:
yield scheduled_requests
finally:
@@ -2072,6 +2083,7 @@ def forward(
spec_metadata.is_spec_dec_dynamic_tree,
spec_metadata.max_draft_len)
else:
spec_resource_manager = None
spec_metadata = None

moe_load_balancer = None
@@ -2090,8 +2102,8 @@ def forward(
with MoeLoadBalancerIterContext(moe_load_balancer):
return self._forward_step(inputs, gather_ids,
gather_context_logits)
with self._maybe_pad_batch(scheduled_requests,
kv_cache_manager) as scheduled_requests:
with self._maybe_pad_batch(scheduled_requests, kv_cache_manager,
spec_resource_manager) as scheduled_requests:
maybe_graph = self._maybe_get_cuda_graph(
scheduled_requests, spec_config=self.spec_config)
if maybe_graph is not None:
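The model_engine.py changes above thread an optional `spec_resource_manager` through `_get_padded_batch` and `_maybe_pad_batch`, so the CUDA-graph dummy requests used for batch padding are also registered with the speculative-decoding resource manager. Below is a minimal, self-contained sketch of that pattern, not the real implementation; `SimpleKVCacheManager` and `SimpleSpecResourceManager` are hypothetical stand-ins for the TensorRT-LLM managers.

```python
import contextlib
from typing import List, Optional

MAX_UINT64 = (1 << 64) - 1


class SimpleKVCacheManager:
    """Hypothetical stand-in: hands out placeholder generation requests."""

    def add_dummy_requests(self, request_ids: List[int]) -> List[dict]:
        return [{"request_id": rid, "is_cuda_graph_dummy": True} for rid in request_ids]


class SimpleSpecResourceManager:
    """Hypothetical stand-in: tracks which request ids hold speculative-decoding slots."""

    def __init__(self):
        self.slots = set()

    def add_dummy_requests(self, request_ids: List[int]) -> None:
        self.slots.update(request_ids)


@contextlib.contextmanager
def maybe_pad_batch(generation_requests: List[dict],
                    target_batch_size: int,
                    kv_cache_manager: SimpleKVCacheManager,
                    spec_resource_manager: Optional[SimpleSpecResourceManager] = None):
    """Pad the batch to a CUDA-graph-compatible size, then restore it afterwards."""
    padding_size = max(0, target_batch_size - len(generation_requests))
    if padding_size > 0:
        dummy_ids = [MAX_UINT64 - 1]
        dummy = kv_cache_manager.add_dummy_requests(dummy_ids)[0]
        # Mirror the diff: when speculative decoding is active, the dummy
        # request must also be registered with the spec resource manager.
        if spec_resource_manager is not None:
            spec_resource_manager.add_dummy_requests(request_ids=dummy_ids)
        generation_requests.extend([dummy] * padding_size)
    try:
        yield generation_requests
    finally:
        if padding_size > 0:
            del generation_requests[-padding_size:]


if __name__ == "__main__":
    reqs = [{"request_id": 0}, {"request_id": 1}]
    spec_mgr = SimpleSpecResourceManager()
    with maybe_pad_batch(reqs, target_batch_size=4,
                         kv_cache_manager=SimpleKVCacheManager(),
                         spec_resource_manager=spec_mgr) as padded:
        assert len(padded) == 4          # batch now matches the captured graph size
    assert len(reqs) == 2                # padding removed once the step is done
    assert MAX_UINT64 - 1 in spec_mgr.slots
```

The context manager keeps the padding transparent to callers: the dummy requests exist only for the duration of the forward step and are removed afterwards, as in the real `_maybe_pad_batch`.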
131 changes: 50 additions & 81 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -800,6 +800,50 @@ def _executor_loop_pp(self):
self.active_requests,
previous_batch)

def _prepare_and_schedule_batch(self):
new_requests = self._fetch_new_requests()
if self.should_stop_processing:
return None, None

if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()

iter_stats = None
if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
self.executor_request_queue.
get_new_active_requests_queue_latency())

self._pad_attention_dp_dummy_request()

if self.drafter is not None:
self._prepare_draft_requests(self.active_requests)

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)

if self.kv_cache_transceiver:
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
self._prepare_disagg_gen_init(fitting_disagg_gen_init_requests)

if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self.kv_cache_transceiver.check_context_transfer_status(1)
else:
assert scheduled_batch.batch_size > 0, (
"fail to schedule any pending request, "
"probably run out of resource.")

self.num_scheduled_requests = scheduled_batch.batch_size
logger.debug(
f'has {len(self.active_requests)} active_request, '
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
f'{len(scheduled_batch.generation_requests)} generation requests')
return scheduled_batch, iter_stats

def _executor_loop(self):
torch.cuda.set_device(self.device_id)
with self._profiler() as profile_step:
@@ -810,48 +854,10 @@ def _executor_loop(self):
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
if self.should_stop_processing:
break

if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()

if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
self.executor_request_queue.
get_new_active_requests_queue_latency())

self._pad_attention_dp_dummy_request()

if self.drafter is not None:
self._prepare_draft_requests(self.active_requests)

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)

if self.kv_cache_transceiver:
# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
self._prepare_disagg_gen_init(
fitting_disagg_gen_init_requests)
if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self.kv_cache_transceiver.check_context_transfer_status(
1)
else:
assert scheduled_batch.batch_size > 0, (
"fail to schedule any pending request, "
"probably run out of resource.")

self.num_scheduled_requests = scheduled_batch.batch_size
logger.debug(
f'has {len(self.active_requests)} active_request, '
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
f'{len(scheduled_batch.generation_requests)} generation requests'
)
scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
if scheduled_batch is None:
break

self._pause_requests(scheduled_batch.paused_requests)

@@ -954,47 +960,10 @@ def _executor_loop_overlap(self):
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
if self.should_stop_processing:
break

if self.kv_cache_transceiver:
self._check_disagg_gen_transfer_status()

if self.enable_iter_perf_stats:
iter_stats = self._get_init_iter_stats(
len(new_requests),
self.executor_request_queue.
get_new_active_requests_queue_latency())

self._pad_attention_dp_dummy_request()

scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)

if self.kv_cache_transceiver:

# For requests that are fitting disagg gen init, also prepare resources for KV cache manager
self._prepare_disagg_gen_init(
fitting_disagg_gen_init_requests)

if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
logger.warning(
"num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
)
self.kv_cache_transceiver.check_context_transfer_status(
1)
else:
assert scheduled_batch.batch_size > 0, (
"fail to schedule any pending request, "
"probably run out of resource.")

self.num_scheduled_requests = scheduled_batch.batch_size
logger.debug(
f'has {len(self.active_requests)} active_request, '
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
f'{len(scheduled_batch.generation_requests)} generation requests'
)
scheduled_batch, iter_stats = self._prepare_and_schedule_batch()
if scheduled_batch is None:
break

self._pause_requests(scheduled_batch.paused_requests)

6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/speculative/mtp.py
@@ -67,15 +67,17 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
if req.is_first_context_chunk:
slot_id = self.slot_manager.add_slot(req.request_id)
if self.use_relaxed_acceptance_for_thinking:
self.mtp_relaxed_delta_pool[slot_id] = 0.
self.mtp_relaxed_delta_pool[slot_id].copy_(
0, non_blocking=True)

def update_resources(self, scheduled_batch: ScheduledRequests):
pass

def free_resources(self, request: LlmRequest):
free_slot_id = self.slot_manager.get_slot(request.request_id)
if self.use_relaxed_acceptance_for_thinking:
self.mtp_relaxed_delta_pool[free_slot_id] = 0.
self.mtp_relaxed_delta_pool[free_slot_id].copy_(0,
non_blocking=True)
self.slot_manager.remove_slot(request.request_id)

def add_dummy_requests(self, request_ids: List[int]):
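The mtp.py change replaces Python scalar assignment into the relaxed-delta pool with an explicit in-place `copy_(0, non_blocking=True)`. The sketch below contrasts the two forms on a hypothetical slot-indexed pool tensor; `delta_pool` and `slot_id` are illustrative names, not the real members, and `non_blocking` only changes behaviour for copies that cross devices, so for a same-device fill it mainly documents the asynchronous intent.

```python
import torch

# Hypothetical stand-in for the per-slot delta pool; the real tensor lives on the GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
num_slots, hidden = 8, 4
delta_pool = torch.rand(num_slots, hidden, device=device)

slot_id = 3

# Plain indexed assignment: zeroes the row in place via a fill kernel.
delta_pool[slot_id] = 0.0

# Explicit in-place copy, as in the diff (which passes the scalar 0 directly).
# The copy writes through the row view into the pool without reallocating it.
delta_pool[slot_id].copy_(torch.zeros(hidden, device=device), non_blocking=True)

assert torch.count_nonzero(delta_pool[slot_id]) == 0
```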
2 changes: 1 addition & 1 deletion tests/integration/defs/test_e2e.py
@@ -551,7 +551,7 @@ def run_bench(self):
if self.use_pytorch_backend:
benchmark_cmd += " --backend pytorch"
else:
benchmark_cmd += " --backend trt"
benchmark_cmd += " --backend tensorrt"

if self.extra_llm_api_options:
benchmark_cmd += f" --extra_llm_api_options {self.extra_llm_api_options}"
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
@@ -396,7 +396,6 @@ examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoder] SKIP (http
examples/test_gpt.py::test_starcoder_fp8_quantization_2gpu[starcoderplus] SKIP (https://nvbugs/5355128)
examples/test_multimodal.py::test_llm_multimodal_general[fuyu-8b-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
examples/test_multimodal.py::test_llm_multimodal_general[llava-1.5-7b-hf-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)
test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b] SKIP (https://nvbugs/5320234)
examples/test_granite.py::test_granite_bf16_lora[granite-3.0-1b-a400m-instruct] SKIP (https://nvbugs/5374145)
examples/test_multimodal.py::test_llm_multimodal_general[VILA1.5-3b-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5373451)
examples/test_multimodal.py::test_llm_multimodal_general[llava-1.5-7b-hf-pp:1-tp:1-float16-bs:1-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5360086)