diff --git a/README.md b/README.md index d141529791f..747fe0278df 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,9 @@ TensorRT-LLM
Figure 1: Top: guided decoding timeline without overlapping. Bottom: guided decoding timeline with overlapping. (This figure is from the XGrammar paper.)
+ +### Speculative Decoding + +Speculative decoding is a crucial feature in low-latency or throughput@latency LLM inference scenarios. For each request, a lightweight drafter proposes several draft tokens, and then the target model verifies the draft tokens in parallel. Ideally, most draft tokens are accepted, and thus multiple tokens are generated in a single target model forward. Compared with normal LLM inference, where each model forward generates a single token, speculative decoding offers the potential to generate more tokens per iteration by leveraging more computation. This improves the arithmetic intensity and reduces the required number of iterations. + +TensorRT LLM has two kinds of speculative decoding implementations: one-model and two-model. The one-model implementation launches a single CUDA graph for a target model forward together with multiple draft model forwards. This is more difficult to implement and is coupled with the modeling code, but it offers the best performance. The two-model implementation decouples the target and draft models into separate CUDA graphs, which is much more flexible and offers better feature coverage. There are ongoing efforts to close the gap between the two implementations. + +Figure 2: Top: GPU timeline of one-model speculative decoding. Bottom: GPU timeline of two-model speculative decoding.
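+
+As a concrete reference for the verification step described above, here is a minimal greedy-verification sketch (the function and variable names are illustrative, not TensorRT LLM APIs): draft tokens are accepted as long as they match the target model's greedy choice, and one extra token always comes from the target model.
+
+```python
+import torch
+
+def greedy_verify(draft_tokens: torch.Tensor,
+                  target_logits: torch.Tensor) -> torch.Tensor:
+    """draft_tokens:  (num_draft,) token IDs proposed by the drafter.
+    target_logits: (num_draft + 1, vocab_size), one row per draft token
+                   plus one row for the bonus position.
+    Returns the accepted draft tokens plus one token from the target model.
+    """
+    target_choices = target_logits.argmax(dim=-1).to(draft_tokens.dtype)
+    num_accepted = 0
+    for draft, target in zip(draft_tokens.tolist(), target_choices.tolist()):
+        if draft != target:
+            break
+        num_accepted += 1
+    # The token at the first mismatching (or bonus) position comes from the target model.
+    return torch.cat([draft_tokens[:num_accepted],
+                      target_choices[num_accepted:num_accepted + 1]])
+```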
+ +### Two Challenges + +When combining guided decoding and speculative decoding, two challenges arise. First, at each generation iteration, speculative decoding proposes multiple draft tokens, some of which might be rejected in the verification step. The draft token proposal and rejection are not transparent to guided decoding. Specifically, this can be broken down into two views: + +* For the target model, guided decoding should advance the grammar state and generate the mask for every draft token. If some draft tokens are rejected, guided decoding should roll back the grammar state to the last accepted token. +* For the draft model, without grammar constraints, some draft tokens may violate the grammar and thus be forcefully rejected in the verification step. Clearly, this hurts the acceptance rate. Hence, guided decoding should also intervene in the logits of every draft token generation whenever possible. + * Some speculative algorithms propose draft tokens recurrently by computing logits and sampling (e.g., the standard draft-target model, EAGLE, or MTP), similar to a standard LLM. In that case, guided decoding can apply grammar constraints in the same mask generation and application manner. + * Some drafting algorithms work without logits sampling and thus require other ways to apply the grammar constraints. + +Second, specific to one-model speculative decoding, where a single CUDA graph contains multiple (draft and target) model forwards, the CPU-GPU synchronization becomes challenging. Note that for every step $i$, there are two event waits: + +* The host waits for the *token event* that indicates the readiness of CPU tokens from step $i-1$. +* The model forward stream waits for the *mask event* that indicates the readiness of GPU masks from step $i$. + +Figure 3: The CPU-GPU synchronization for multiple model forwards.
+ +Note that in the two-model implementation, the sampling is excluded from the CUDA graphs for better flexibility (Figure 2). From the CPU perspective, this leaves a window for the grammar computation. In particular, the mask event wait can be inserted between the CUDA graph replay and sampling, effectively making the GPU wait for the GPU masks asynchronously copied from the CPU. + +However, the CUDA graph of the one-model implementation contains multiple forwards, inevitably including the sampling operations. Hence, there is no such window for the grammar computation. The most prominent problem is that when replaying the CUDA graph, the mask event wait cannot be inserted before sampling. An alternative is capturing the events and waits in the CUDA graph, but it is still ineffective because the grammar computation runs on the CPU and thus is not capturable. Once such a CUDA graph is replayed, the GPU does not wait for any newly recorded events, so it is impossible to block the GPU on the readiness of the masks. + +## Trace Grammar State for Draft Token Proposal and Rejection + +### Target Model + +For a target model forward, a request has one new token (from the last verification step) and multiple draft tokens (from the drafter). For each token in the sequence, guided decoding should advance the grammar state and fill the mask tensor. Before sampling, the masks should be applied to the corresponding logits. After verification, the grammar state should be rolled back by the number of rejected tokens. + +Compared to guided decoding with non-speculative decoding, the rollback operation is newly introduced. Thankfully, it is natively supported by grammar backends like [XGrammar](https://github.com/mlc-ai/xgrammar/blob/v0.1.21/python/xgrammar/matcher.py#L341-L350) and [LLGuidance](https://github.com/guidance-ai/llguidance/blob/v1.1.1/python/llguidance/_lib.pyi#L363-L366). + +Before proceeding to the draft model view, note that the LLM can generate correct outputs as long as we apply grammar constraints on the target model, because any draft tokens violating the grammar will be forcefully rejected by the verification step. However, leaving the drafter unconstrained hurts the acceptance rate. + +### Draft Model + +As mentioned above, for speculative algorithms based on recurrent logits sampling, we can apply grammar constraints to draft tokens in the same mask generation and application manner. Specifically, for the first drafting step, guided decoding advances the grammar state using the last new token. For the following drafting steps, the grammar state is advanced using the last draft token. Each step should fill and apply the mask to the corresponding draft model logits before sampling. + +After the drafting process, the grammar state should be rolled back to the original state, so that the subsequent target model forward can start from the correct grammar state. If the draft and target models share the same vocabulary, the grammar computation is exactly the same, so the masks can be reused. + +One special case is EAGLE3, whose draft model has a [pruned vocabulary](https://github.com/SafeAILab/EAGLE/blob/58d1de099fe315645a82fe002e46586d54efe405/eagle/traineagle3/config.json#L22-L23) compared to the target model. For instance, LLaMA 3.1 has a 128k vocabulary size, while the corresponding EAGLE3 drafter has a vocabulary containing the most frequent 32k tokens. This saves some computation in the lm_head GEMM. Note that the grammar is built on the target model's vocabulary, so the produced mask cannot be directly applied to the draft model's logits. EAGLE3 provides a special [d2t](https://github.com/SafeAILab/EAGLE/blob/d7161f9f94aaa345654d9b4045931145811d4d03/eagle/traineagle3/cnets.py#L673-L681) tensor that maps draft token IDs to target token IDs. [PR 7481](https://github.com/NVIDIA/TensorRT-LLM/pull/7481) fuses this d2t mapping into the mask applying kernel.
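+
+To make the d2t-based mask application concrete, below is a plain PyTorch sketch (not the fused TensorRT LLM kernel; the function and tensor names are illustrative). It assumes `d2t[i]` directly gives the target token ID of draft token `i`; some implementations instead store an offset to be added to `i`.
+
+```python
+import torch
+
+def apply_target_mask_to_draft_logits(draft_logits: torch.Tensor,
+                                      target_mask: torch.Tensor,
+                                      d2t: torch.Tensor) -> torch.Tensor:
+    """draft_logits: (draft_vocab,) logits from the EAGLE3 drafter.
+    target_mask:  (target_vocab,) boolean mask produced by the grammar backend.
+    d2t:          (draft_vocab,) mapping from draft token IDs to target token IDs.
+    """
+    allowed = target_mask[d2t.long()]  # gather the mask entries for the draft vocabulary
+    return draft_logits.masked_fill(~allowed, float("-inf"))
+```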
+ +> **Note:** Here we focus on the chain-based speculative algorithms. A tree-based algorithm would further complicate the implementation; in particular, guided decoding should traverse the drafting tree, advancing and rolling back grammar states accordingly. + +## Make Grammar Computation Capturable by CUDA Graph + +### CUDA Callback + +CUDA graph is an important technique in LLM inference systems for eliminating CPU overhead, especially in the generation phase. As mentioned above, the one-model speculative decoding implementation launches a single CUDA graph to compute multiple draft and target model forwards. This makes the CPU-GPU synchronization challenging: the sampling operation depends on masks computed on the CPU, but the GPU cannot wait for the readiness of any CPU computation once the CUDA graph is launched. + +The CUDA callback API [`cudaLaunchHostFunc`](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EXECUTION.html#group__CUDART__EXECUTION_1g05841eaa5f90f27124241baafb3e856f) can launch a host function onto a CUDA stream. (The host function must not call any CUDA API.) This has two crucial implications: + +* CUDA events and event waits can be inserted before and after the host functions, which can be used to synchronize the CPU and GPU computation. +* The host functions can be captured and replayed by CUDA graph. + +Hence, we can launch the grammar computation, along with other auxiliary host functions, as CUDA callbacks onto a CUDA stream. The CUDA graph then captures and replays the multiple model forwards and the corresponding grammar computation all together. To achieve CPU-GPU overlapping, the grammar computation should be placed on a dedicated CUDA stream. Specifically, for every step $i$: + +* The grammar stream: + * waits for the *token event* that indicates the readiness of CPU tokens from step $i-1$; + * performs grammar advance and mask generation (as a CUDA callback); + * asynchronously copies the CPU masks to GPU; + * records the *mask event*. +* The model forward stream: + * computes the model forward using the last GPU tokens; + * waits for the *mask event* that indicates the readiness of GPU masks; + * applies the masks to the logits and then samples new tokens; + * asynchronously copies the GPU tokens to CPU; + * records the *token event*. + +Figure 4: The CPU-GPU synchronization for multiple model forwards by CUDA callback.
+ +### Integration to TensorRT LLM Python Runtime + +We surveyed some off-the-shelf Python binding implementations of `cudaLaunchHostFunc`, but it turned out that they do not work well with CUDA graph (e.g., CUDA-Python [Issue 790](https://github.com/NVIDIA/cuda-python/issues/790), cupy [Issue 9274](https://github.com/cupy/cupy/issues/9274)). The probable reason is that the intermediate wrapper data structures are released once the callback is executed; hence, even though the callback is captured by CUDA graph, it cannot be replayed multiple times. + +We implemented our own binding to `cudaLaunchHostFunc`: [`launch_hostfunc`](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/cpp/tensorrt_llm/nanobind/runtime/hostfunc.cpp#L76). Specifically, `launch_hostfunc` packs the Python function and arguments into an [intermediate data structure](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/cpp/tensorrt_llm/nanobind/runtime/hostfunc.cpp#L33) and calls `cudaLaunchHostFunc` to launch a [trampoline function](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/cpp/tensorrt_llm/nanobind/runtime/hostfunc.cpp#L49) onto a CUDA stream. The trampoline function unpacks the intermediate data structure and invokes the Python function with the arguments. Note that `launch_hostfunc` offers great flexibility: it can launch an arbitrary Python function (without any CUDA API calls) as a CUDA callback. Hence, the grammar computation logic can still be implemented in Python. + +During CUDA graph capture, `launch_hostfunc` does not release the intermediate data structure, so it remains accessible during CUDA graph replay. The intermediate data structures can be manually released via [`free_hostfunc_user_data`](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/cpp/tensorrt_llm/nanobind/runtime/hostfunc.cpp#L97); otherwise, they are automatically cleaned up when the Python interpreter exits. If CUDA graph is disabled (e.g., in the prefill phase), the intermediate data structure should be released promptly to avoid memory leaks. Specifically, the trampoline function automatically releases it once the callback finishes execution. + +In Python, we provide a decorator `hostfunc` which converts an arbitrary Python function into a CUDA callback. For example, consider the code snippet below: + +```python +import torch +from tensorrt_llm._torch.hostfunc import hostfunc + +@hostfunc +def increase(x: torch.Tensor): + x.add_(1) + +x = torch.zeros(10, dtype=torch.int32) + +stream = torch.cuda.Stream() +g = torch.cuda.CUDAGraph() +with torch.cuda.graph(g, stream=stream): + increase(x) + increase(x) +torch.cuda.synchronize() + +with torch.cuda.stream(stream): + for _ in range(10): + g.replay() + +torch.cuda.synchronize() +print(x) +``` + +The output would look like: + +```txt +tensor([20, 20, 20, 20, 20, 20, 20, 20, 20, 20], dtype=torch.int32) +``` + +Note that the CUDA graph increments the tensor twice and is replayed ten times, so the tensor should be incremented by 20 in total. The output validates that the CUDA graph capture and replay are successful. + +As the final step, we implemented a variant of `GuidedDecoder`: [`CapturableGuidedDecoder`](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/pyexecutor/guided_decoder.py#L405). It reuses most of the logic from `GuidedDecoder`, but the grammar computation and some auxiliary methods are decorated with `hostfunc`, making them capturable by CUDA graph.
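+
+To make the captured synchronization scheme more tangible, below is a minimal sketch of a single step from Figure 4, written with the `hostfunc` decorator, PyTorch streams, and events. The buffer names and the trivial grammar callback are illustrative placeholders (the real logic lives in `CapturableGuidedDecoder`), and for simplicity the step is written as eager code rather than being captured into a CUDA graph.
+
+```python
+import torch
+from tensorrt_llm._torch.hostfunc import hostfunc
+
+VOCAB_SIZE = 128
+# Fixed-address buffers, as required for CUDA graph compatibility (illustrative).
+new_tokens_cpu = torch.zeros(1, dtype=torch.int32, pin_memory=True)
+mask_cpu = torch.ones(1, VOCAB_SIZE, dtype=torch.bool, pin_memory=True)
+mask_gpu = torch.ones(1, VOCAB_SIZE, dtype=torch.bool, device="cuda")
+
+token_event = torch.cuda.Event()
+mask_event = torch.cuda.Event()
+grammar_stream = torch.cuda.Stream()
+forward_stream = torch.cuda.Stream()
+
+@hostfunc
+def advance_grammar_and_fill_mask(tokens: torch.Tensor, mask: torch.Tensor):
+    # Placeholder for the CPU-side grammar advance and mask fill,
+    # driven by the tokens produced in the previous step.
+    mask.fill_(True)
+
+def one_step(logits_gpu: torch.Tensor):
+    with torch.cuda.stream(grammar_stream):
+        grammar_stream.wait_event(token_event)                   # CPU tokens of step i-1 are ready
+        advance_grammar_and_fill_mask(new_tokens_cpu, mask_cpu)  # CUDA callback
+        mask_gpu.copy_(mask_cpu, non_blocking=True)              # async H2D mask copy
+        mask_event.record(grammar_stream)
+    with torch.cuda.stream(forward_stream):
+        # ... the model forward producing logits_gpu would run here ...
+        forward_stream.wait_event(mask_event)                    # GPU masks of step i are ready
+        logits_gpu.masked_fill_(~mask_gpu, float("-inf"))        # apply the mask before sampling
+        tokens_gpu = logits_gpu.argmax(dim=-1).to(torch.int32)
+        new_tokens_cpu.copy_(tokens_gpu, non_blocking=True)      # async D2H token copy
+        token_event.record(forward_stream)
+```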
+ +### CUDA Graph Compatibility: Grammar Computation + +Once captured, a CUDA graph can be replayed to run the same GPU kernels as many times as needed. Note that the replayed kernels are always executed using fixed input and output memory addresses. By filling the input buffers with new data, we can run the same work on new data. This pattern also applies to CUDA callbacks, except that the input and output buffers are on the CPU. + +The guided decoder manages the following buffers and resources: + +* [Request states](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/pyexecutor/guided_decoder.py#L20): All the necessary request information affecting grammar computation, including the user-specified grammar, the last new token, and the draft tokens. +* [Grammar states](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/pyexecutor/guided_decoder.py#L167-L168): The grammar states managed by grammar backends. By leveraging the grammar backends, the guided decoder advances grammar states and fills mask tensors. +* [New tokens tensor](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/pyexecutor/guided_decoder.py#L419-L422): The tensor values are copied from the newly computed GPU tokens and used to update the last new token or draft tokens of the request states. +* [Mask tensor](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/pyexecutor/guided_decoder.py#L175-L177): The tensor values are filled according to the grammar states and then copied to the GPU masks, which are then applied to the logits. + +The buffers are stored at fixed memory addresses, and the resources are accessed via fixed pointers. This makes the grammar computation compatible with CUDA graph. The buffers and resources are connected via slot IDs. In the runtime, each request is assigned an exclusive slot ID (0 <= slot ID < `max_batch_size`) upon its first scheduling. The slot ID stays occupied until the request is finished and removed from the scheduler. + +When the runtime schedules a new batch of requests, the guided decoder updates the request states on the host. After that, all the other operations (grammar compilation/advance, mask generation, buffer copying, etc.) happen on CUDA streams and should be capturable by CUDA graph. More specifically, buffer copying should be asynchronous, and the other CPU computation should be issued as CUDA callbacks. + +### CUDA Graph Compatibility: Mask Applying Kernel + +The mask applying kernel takes a batch of logits and masks as the input and modifies the logits in place. Specifically, the masked-out (disallowed by grammar) token logits are assigned a value of negative infinity, so that they can never be sampled as the next tokens. + +Note that currently CUDA graph is enabled for the generation phase only, and the draft length is fixed for all requests. This greatly simplifies the effort for CUDA graph compatibility. Given a `batch_size` and `max_num_draft_tokens`, the logits tensor is of shape `(batch_size * (1 + max_num_draft_tokens), vocab_size)`. Hence, we can fill the first `(batch_size * (1 + max_num_draft_tokens))` rows of the mask tensor accordingly, and pass the mask tensor address to the kernel. + +Some requests may have no grammar constraints. For such requests, we can fill the corresponding masks with all ones (allowed by grammar) so the logits will not be modified by the kernel, but this causes unnecessary computation. To resolve this, a token-level mask tensor is introduced: its values are set to zero for requests without grammar constraints, and the kernel skips the rows whose mask values are zero.
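+
+Below is a plain PyTorch reference of this logic (the actual implementation is a fused CUDA kernel; the function and tensor names here are illustrative). It applies the grammar masks to the logits in place and skips the rows whose token-level mask is zero.
+
+```python
+import torch
+
+def apply_masks_reference(logits: torch.Tensor,
+                          masks: torch.Tensor,
+                          token_mask: torch.Tensor) -> None:
+    """logits:     (batch_size * (1 + max_num_draft_tokens), vocab_size), modified in place.
+    masks:      boolean tensor of the same shape; False marks tokens disallowed by the grammar.
+    token_mask: (batch_size * (1 + max_num_draft_tokens),); zero means the row has no
+                grammar constraint and is skipped.
+    """
+    rows = token_mask.nonzero(as_tuple=True)[0]
+    logits[rows] = logits[rows].masked_fill(~masks[rows], float("-inf"))
+```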
+ +### Troubleshooting: Data Race between Host and CUDA Callback + +Similar to GPU kernels, CUDA callbacks are asynchronously executed on CUDA streams. Note that both normal host functions and CUDA callbacks can access the same CPU memory addresses, so data races can easily occur. + +In the initial implementation, `CapturableGuidedDecoder` directly read request states from [`ScheduledRequests`](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/pyexecutor/scheduler.py#L18). However, `ScheduledRequests` is shared throughout an executor iteration and thus may be modified by other executor components. This creates a potential data race scenario: + +* The guided decoder launches a CUDA callback, which will read some request states from `ScheduledRequests`; +* Some other executor components modify `ScheduledRequests` in place; +* The CUDA callback is executed, reading the modified request states from `ScheduledRequests`. + +Clearly, the CUDA callback may read unexpected data. This data race motivates a dedicated request states class: [`GuidedRequest`](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/pyexecutor/guided_decoder.py#L20). It is a request snapshot created for the guided decoder only, so it will never be modified by other components. Because the guided decoder itself may also access request states via both normal host functions and CUDA callbacks, we adopt the protocol that request snapshots are created on the host and then accessed only via CUDA callbacks. This prevents potential data races within an executor iteration. + +When the overlap scheduler is enabled, another data race scenario exists between executor iterations: + +* Iteration $i$ launches CUDA callbacks, which will read request states from a fixed address; +* Iteration $i+1$ updates the request states; +* Iteration $i$'s CUDA callbacks are executed, reading request states updated by iteration $i+1$. + +Again, the CUDA callbacks may read unexpected data. A straightforward solution is to let the request state update wait for the CUDA callback execution, but this effectively disables overlap scheduling. To resolve this issue and also unblock overlap scheduling, a [queue](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/tensorrt_llm/_torch/pyexecutor/guided_decoder.py#L417) is introduced. For each iteration, a new batch of request states is put into the queue; then, a CUDA callback is launched to fetch a new batch of request states from the queue, and all the subsequent CUDA callbacks access the newly fetched request states. This allows the coexistence of request snapshots from two (or even more) iterations, which prevents potential data races between iterations. + +### Troubleshooting: Deadlock by GIL and CUDA Mutex + +After the first version was implemented, the program intermittently hung when `CapturableGuidedDecoder` was enabled. By inspecting the call stack, we found that it was hanging on seemingly unrelated kernel launches or other CUDA API calls. With further investigation, we discovered that the hang was caused by a deadlock between the Python GIL and a CUDA mutex. + +As documented, a CUDA callback must not make any CUDA API calls. This implies that CUDA callback execution and CUDA API calls compete for the same internal mutex.
Note that the trampoline function needs to [acquire the GIL](https://github.com/NVIDIA/TensorRT-LLM/blob/v1.1.0rc5/cpp/tensorrt_llm/nanobind/runtime/hostfunc.cpp#L52) before calling the Python code. Hence, when a CUDA callback executes Python code, it acquires the CUDA mutex and then the GIL. Meanwhile, the Python main thread may hold the GIL while making CUDA API calls, so it acquires the GIL and then the CUDA mutex. The two threads acquire the two locks in opposite orders, which creates a deadlock pattern. + +This deadlock can be avoided if the Python main thread releases the GIL for CUDA API calls. The TensorRT LLM Python runtime is built on PyTorch. Thankfully, PyTorch releases the GIL for most CUDA API calls, even including PyTorch custom operators. However, we found two exceptions in PyTorch 2.8. First, creating a device tensor with a shape that depends on the data of another device tensor triggers an implicit, synchronous D2H copy, and this copy is executed with the GIL held ([Issue 163062](https://github.com/pytorch/pytorch/issues/163062)). This can be reproduced with the code snippet below: + +```python +import torch + +x = torch.randint(0, 100, (100,), dtype=torch.int64, device='cuda') +y = torch.zeros(100, x.max(), dtype=torch.int64, device='cuda') +``` + +Second, `torch.compile` kernels are called with the GIL held ([Issue 163061](https://github.com/pytorch/pytorch/issues/163061)), although Triton kernels are called with the GIL released. Hence, we have to avoid such problematic operations and disable `torch.compile` when using CUDA callbacks into Python code, until these issues are fixed in PyTorch. + +Another source of risk comes from some runtime components that are implemented in C++ and exposed as Python bindings; they may make CUDA API calls as well. By default, Python bindings do not release the GIL. Hence, we swept these Python bindings and released the GIL properly in [PR 6948](https://github.com/NVIDIA/TensorRT-LLM/pull/6948). + +After all these efforts, the hang issue disappeared. It is generally recommended to release the GIL when calling C++ code from Python; even outside the context of CUDA callbacks, this is beneficial for multi-threading performance. However, we acknowledge the limitation that it is difficult to ensure that every such place has been properly handled and that future code changes will not introduce new risks. + +> **Note:** Theoretically, GIL-free Python ([PEP 703](https://peps.python.org/pep-0703)) could be another remedy. + +## Performance and Analysis + +We benchmark the performance of guided decoding on two datasets: [JSON Mode Eval](https://huggingface.co/datasets/NousResearch/json-mode-eval) and [JSON Schema Bench](https://huggingface.co/datasets/epfl-dlab/JSONSchemaBench). The models are [LLaMA 3.1 8B](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) and [LLaMA 3.3 70B](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct), the GPUs are H200, and the grammar backend is XGrammar. + +Figure 5: Pareto curve on LLaMA 3.1 8B TP1 on H200, JSON Mode Eval. The concurrency ranges from 1 to 128.
+ + +Figure 6: Pareto curve on LLaMA 3.3 70B TP4 on H200, JSON Mode Eval. The concurrency ranges from 1 to 128.
+ +Figures 5 and 6 present the Pareto curves on JSON Mode Eval for LLaMA 3.1 8B and LLaMA 3.3 70B, respectively. Speculative decoding achieves significant speedup for low-latency or throughput@latency scenarios. In particular, the speedup can be up to ~2x for batch size 1. The one-model EAGLE3 implementation is more performant than the two-model EAGLE3, and this performance gap is amplified for smaller models. This is reasonable, because the one-model implementation captures more work into a single CUDA graph, which results in less (if any) exposed CPU overhead. + +Note that although NGram is a two-model implementation, it performs surprisingly well. This is because JSON Mode Eval is an information extraction task: each prompt contains the JSON schema and all the information required by the response, so NGram has a high acceptance rate on this dataset. + +Figure 7: Pareto curve on LLaMA 3.1 8B TP1 on H200, JSON Schema Bench. The concurrency ranges from 1 to 128.
+ +Figure 8: Pareto curve on LLaMA 3.3 70B TP4 on H200, JSON Schema Bench. The concurrency ranges from 1 to 128.
+ +Figures 7 and 8 show the results on JSON Schema Bench. The one-model EAGLE3 achieves the best performance across almost all scenarios. Note that NGram becomes less performant because the task is no longer a pure information extraction task, although the JSON schemas are still present in the prompts. + +| Dataset | Model | EAGLE3 | EAGLE3 w/o draft constraints | NGram | +| :-----: | :---: | :----: | :--------------------------: | :---: | +| JSON Mode Eval | LLaMA 3.1 8B | 2.86 | 2.65 | 2.59 | +| JSON Mode Eval | LLaMA 3.3 70B | 2.72 | 2.60 | 2.44 | +| JSON Schema Bench | LLaMA 3.1 8B | 2.55 | 2.33 | 1.89 | +| JSON Schema Bench | LLaMA 3.3 70B | 2.50 | 2.30 | 1.87 | + +Table 1: Average acceptance lengths per iteration for EAGLE3 and NGram. The acceptance length includes the golden token. The draft length is 3. "EAGLE3 w/o draft constraints" means grammar constraints are not applied to the draft model.
+ +Table 1 lists the average acceptance lengths per iteration. We perform an ablation experiment where the draft model does not apply grammar constraints. As presented, this does decrease the acceptance rates, but by a smaller margin than expected. Note that applying grammar constraints on the draft model introduces extra overhead: + +* In the drafting loop, the extra mask applying kernels slightly contribute to the GPU time. +* If the drafting forwards are too fast to hide the grammar computation, the exposed CPU time will cause bubbles in the GPU timeline. + +These extra overheads could partially offset the benefits from the improved acceptance rate. + +## Acknowledgements + +This work demonstrates an outstanding example of cross-team collaboration between the TensorRT LLM and XGrammar teams. We sincerely appreciate the support from everyone who contributed to making this happen. + +We acknowledge that this work is built on top of tremendous existing foundations from the community. In particular, some designs were inspired by vLLM [PR 14702](https://github.com/vllm-project/vllm/pull/14702) and SGLang [PR 6499](https://github.com/sgl-project/sglang/pull/6499). In addition, special thanks go to the authors who proposed speculative algorithms like EAGLE/MTP, and to the grammar backend projects like XGrammar/LLGuidance. diff --git a/docs/source/conf.py b/docs/source/conf.py index def277aba43..c531f1c06e2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -50,6 +50,7 @@ 'sphinx.ext.autosummary', 'sphinx.ext.viewcode', 'sphinx.ext.napoleon', + 'sphinx.ext.mathjax', 'myst_parser', # for markdown support "breathe", 'sphinx.ext.todo', @@ -86,6 +87,8 @@ myst_enable_extensions = [ "deflist", "substitution", + "dollarmath", + "amsmath", ] myst_substitutions = { @@ -167,8 +170,11 @@ def tag_role(name, rawtext, text, lineno, inliner, options=None, content=None): def setup(app): from helper import generate_examples, generate_llmapi - from tensorrt_llm.llmapi.utils import tag_llm_params - tag_llm_params() + try: + from tensorrt_llm.llmapi.utils import tag_llm_params + tag_llm_params() + except ImportError: + print("Warning: tensorrt_llm not available, skipping tag_llm_params") app.add_role('tag', tag_role) diff --git a/examples/models/core/gemma/requirements.txt b/examples/models/core/gemma/requirements.txt index a1bbed25b68..20f8719a379 100644 --- a/examples/models/core/gemma/requirements.txt +++ b/examples/models/core/gemma/requirements.txt @@ -5,6 +5,7 @@ nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64" tensorrt_llm>=0.0.0.dev0 flax~=0.8.0 +numpy<2 # jax[cuda12_pip]~=0.4.19 safetensors~=0.4.1 sentencepiece>=0.1.99 diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index a317149b1b6..b65d9366d0f 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -181,6 +181,7 @@ def plan( q_pe: Optional[torch.Tensor] = None, mrope_config: Optional[dict] = None, softmax_stats_tensor: Optional[torch.Tensor] = None, + helix_position_offsets: Optional[torch.Tensor] = None, is_spec_decoding_enabled: bool = False, use_spec_decoding: bool = False, is_spec_dec_tree: bool = False, @@ -225,6 +226,7 @@ def plan( use_paged_context_fmha (bool): Sets the mPagedContextFMHA attribute in the op runner. mrope_config (dict): The dictionary containing the mRope configuration.
softmax_stats_tensor (torch.Tensor): The tensor to store the softmax statistics (max/sum) + helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU. attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU. chunked_prefill_buffer_batch_size (int): used for malloc buffer for k and v in fp8 context mla. the max input kv length is not max_num_tokens in this case. It is chunked_prefill_buffer_batch_size * max_num_tokens. """ @@ -262,6 +264,7 @@ def plan( 'mrope_position_deltas') if mrope_config is not None else None self.block_ids_per_seq = block_ids_per_seq self.softmax_stats_tensor = softmax_stats_tensor + self.helix_position_offsets = helix_position_offsets self.attention_sinks = attention_sinks if max_sequence_length > self.rope_params.max_positions: @@ -420,6 +423,7 @@ def run( self.spec_decoding_generation_lengths, self.spec_decoding_position_offsets, self.spec_decoding_packed_mask ] + mla_tensor_params = [self.helix_position_offsets] thop.attention( q, @@ -482,6 +486,7 @@ def run( self.v_head_dim, self.mrope_rotary_cos_sin, self.mrope_position_deltas, + mla_tensor_params, self.attention_chunk_size, self.softmax_stats_tensor, spec_decoding_bool_params, @@ -1214,6 +1219,7 @@ def forward( mrope_config: Optional[dict] = None, attention_window_size: Optional[int] = None, softmax_stats_tensor: Optional[torch.Tensor] = None, + helix_position_offsets: Optional[torch.Tensor] = None, enable_attn_nvfp4_output: bool = True, output: Optional[torch.Tensor] = None, output_sf: Optional[torch.Tensor] = None, @@ -1284,6 +1290,7 @@ def forward( q_pe=q_pe, mrope_config=mrope_config, softmax_stats_tensor=softmax_stats_tensor, + helix_position_offsets=helix_position_offsets, is_spec_decoding_enabled=metadata.is_spec_decoding_enabled, use_spec_decoding=metadata.use_spec_decoding, is_spec_dec_tree=metadata.is_spec_dec_tree, diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 49f8b67c8ee..d57b3fd40d8 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -69,6 +69,7 @@ transforms: stage: sharding simple_shard_only: false use_sharding_from_factory: false + support_partial_config: false sharding_dims: ['tp', 'ep', 'bmm'] # TODO: (hg) need to ensure run_shape_prop after sharding. 
sharding_transform_executor: diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index faf5134d8cc..2ff564b17e3 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -164,7 +164,7 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): ) sharding_dims: List[str] = Field( - default=["tp", "ep", "bmm"], + default=["tp", "ep", "dp"], description="The sharding methods to apply by the heuristic sharding stage.", ) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index a85389c94a7..dc4b62caafa 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -127,6 +127,7 @@ class ShardingTransformConfig(TransformConfig): simple_shard_only: bool = Field(default=False) use_sharding_from_factory: bool = Field(default=False) + support_partial_config: bool = Field(default=False) # Which sharding families to run: any subset of {"tp", "ep", "bmm"} sharding_dims: List[str] = Field(default_factory=lambda: ["tp", "ep", "bmm"]) @@ -185,6 +186,9 @@ def _apply( else ShardingConfigSource.UNKNOWN ) shared_config.sharding_config.simple_shard_only = self.config.simple_shard_only + shared_config.sharding_config.support_partial_config = self.config.support_partial_config + shared_config.sharding_config.sharding_dims = self.config.sharding_dims + shared_config.sharding_config.use_sharding_from_factory = ( self.config.use_sharding_from_factory ) @@ -200,8 +204,6 @@ def _apply( factory_info = detect_sharding_from_factory_config(gm, sharding_config) return gm, factory_info - shared_config.sharding_config.sharding_dims = self.config.sharding_dims - ad_logger.info( f"Running autodeploy sharding heuristics: {shared_config.sharding_config.sharding_dims}" ) @@ -338,8 +340,39 @@ def detect_sharding_from_factory_config( # TODO: Sequence parallelism is not supported yet. ad_logger.warning("Sequence parallelism is not supported yet. Skipping.") elif "local" in config: - # TODO: local refers to hybrid EP+TP parallelism. Not supported yet. - ad_logger.warning("Local EP+TP sharding is not supported yet. Skipping.") + # Check if this applies to shared experts in EP parallelism. + # If yes, apply the TP col-row shard. + if "shared" in module_name: + col_row_action = config.replace("local_", "") + if col_row_action == "colwise": + sharding_config.tp_transforms.append( + TPShardingInfo( + target_node=lin_node.name, + split_dim=SplitDimension.COLUMN, + rank=rank, + world_size=world_size, + dist_op=None, + min_local_shape=min_local_shape, + ) + ) + elif col_row_action == "rowwise": + sharding_config.tp_transforms.append( + TPShardingInfo( + target_node=lin_node.name, + split_dim=SplitDimension.ROW, + rank=rank, + world_size=world_size, + dist_op="all_reduce", + min_local_shape=min_local_shape, + ) + ) + num_row_col_shards += 1 + else: + ad_logger.warning("Invalid sharding config. Skipping.") + else: + # TODO: local refers to hybrid EP+TP parallelism. Not supported yet. + ad_logger.warning("Local EP+TP sharding is not supported yet. 
Skipping.") + elif "gather" in config: # Simple shard (row + all_gather) sharding_config.tp_transforms.append( @@ -362,9 +395,35 @@ def detect_sharding_from_factory_config( f"Applied {num_shards} TP shards (simple: {num_simple_shards}, " f"row-col pattern: {num_row_col_shards})" ) + + num_matches = len(sharding_config.tp_transforms) + + if sharding_config.support_partial_config: + ad_logger.info( + f"Partial factory config applied only for TP. " + f"Applying heuristics for {sharding_config.sharding_dims}." + ) + + # run EP sharding across ranks + if "ep" in sharding_config.sharding_dims: + ep_info = detect_ep_shard(gm, sharding_config) + else: + ep_info = TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + # run BMM sharding across ranks + if "bmm" in sharding_config.sharding_dims: + dp_bmm_info = detect_dp_bmm_shard(gm, sharding_config) + else: + dp_bmm_info = TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + num_matches += ep_info.num_matches + dp_bmm_info.num_matches + return TransformInfo( skipped=False, - num_matches=len(sharding_config.tp_transforms), + num_matches=num_matches, is_clean=False, has_valid_shapes=False, ) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 40680ada291..3862d03e6e5 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -737,6 +737,7 @@ class ShardingConfig(BaseModel): predefined_config: Optional[Dict[str, Any]] = None simple_shard_only: bool = Field(default=False) use_sharding_from_factory: bool = False + support_partial_config: bool = False sharding_dims: List[str] = Field(default_factory=list) tp_transforms: List[TPShardingInfo] = Field(default_factory=list) bmm_transforms: List[BMMShardingInfo] = Field(default_factory=list) @@ -781,7 +782,7 @@ def validate_config(self) -> bool: tp_plan = self.predefined_config["tp_plan"] values = set(tp_plan.values()) - allowed_values = { + supported_modes = { "colwise", # row split and no collective "rowwise", # column split and all-reduce "gather", # simple shard (row + all_gather) @@ -793,7 +794,7 @@ def validate_config(self) -> bool: # "local_packed_rowwise", # "local", } - if not values.issubset(allowed_values): + if not self.support_partial_config and not values.issubset(supported_modes): ad_logger.warning("Sharding config contains invalid values. 
Skipping.") # invalidate the config self.predefined_config = {} diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py index fa9be6afcf8..221478edf8e 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @@ -78,7 +78,7 @@ class Qwen3MoE(nn.Module): def __init__( self, model_config: ModelConfig[Qwen3MoeConfig], - aux_stream: torch.cuda.Stream, + aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], layer_idx: Optional[int] = None, ): super().__init__() @@ -108,7 +108,7 @@ def __init__( routing_method=self.gate.routing_method, hidden_size=self.hidden_dim, intermediate_size=self.moe_intermediate_size, - aux_stream_dict={AuxStreamType.MoeChunkingOverlap: aux_stream}, + aux_stream_dict=aux_stream_dict, dtype=config.torch_dtype, reduce_results=False, model_config=model_config, @@ -160,7 +160,8 @@ def forward( class Qwen3MoEDecoderLayer(DecoderLayer): def __init__(self, model_config: ModelConfig[Qwen3MoeConfig], - layer_idx: int, aux_stream: torch.cuda.Stream): + layer_idx: int, aux_stream_dict: Dict[AuxStreamType, + torch.cuda.Stream]): super().__init__() self.model_config = model_config config = model_config.pretrained_config @@ -171,7 +172,7 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig], self.mapping = model_config.mapping self.enable_attention_dp = self.mapping.enable_attention_dp - self.mlp = Qwen3MoE(model_config, aux_stream, layer_idx=layer_idx) + self.mlp = Qwen3MoE(model_config, aux_stream_dict, layer_idx=layer_idx) self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, @@ -302,7 +303,10 @@ class Qwen3MoEModel(DecoderModel): def __init__(self, model_config: ModelConfig[Qwen3MoeConfig]): super().__init__(model_config) config = self.model_config - self.aux_stream = torch.cuda.Stream() + self.aux_stream_dict = { + AuxStreamType.MoeChunkingOverlap: torch.cuda.Stream(), + AuxStreamType.MoeBalancer: torch.cuda.Stream(), + } self.preload_weight_modules = [] if config.moe_backend == "TRTLLM": self.preload_weight_modules = [ @@ -332,7 +336,7 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig]): Qwen3MoEDecoderLayer( model_config, layer_idx, - self.aux_stream, + self.aux_stream_dict, ) for layer_idx in range(config.pretrained_config.num_hidden_layers) ]) self.norm = RMSNorm( diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 5281a7de888..ccf10538b00 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1,12 +1,9 @@ import bisect import contextlib -import copy import functools import gc import inspect import math -import os -import traceback import weakref from abc import ABC, abstractmethod from contextlib import contextmanager @@ -17,16 +14,13 @@ import tensorrt_llm.bindings.internal.userbuffers as ub from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc, - str_dtype_to_torch, torch_dtype_to_str, - trace_func) + torch_dtype_to_str, trace_func) from tensorrt_llm.inputs.multimodal import (MultimodalParams, MultimodalRuntimeData) from tensorrt_llm.logger import logger from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraModelConfig from tensorrt_llm.mapping import CpType, Mapping -from tensorrt_llm.models.modeling_utils import QuantAlgo -from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2 from 
..attention_backend.interface import (AttentionMetadata, AttentionRuntimeFeatures) @@ -40,14 +34,11 @@ from ..distributed.communicator import init_pp_comm from ..expert_statistic import ExpertStatistic from ..metadata import KVCacheParams -from ..model_config import ModelConfig, MoeLoadBalancerConfig -from ..models import AutoModelForCausalLM from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids -from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, - timing) -from ..modules.fused_moe.moe_load_balancer import ( - MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer) +from ..models.modeling_utils import DecoderModelForCausalLM +from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer, + MoeLoadBalancerIterContext) from ..speculative import (SpecMetadata, get_num_extra_kv_tokens, get_spec_metadata, update_spec_config_from_model_config) @@ -56,12 +47,13 @@ from ..utils import (get_model_extra_attrs, set_per_request_piecewise_cuda_graph_flag, set_torch_compiling, with_model_extra_attrs) -from .config import LoadFormat, PyTorchConfig +from .config import PyTorchConfig from .config_utils import is_mla from .cuda_graph_runner import CUDAGraphRunner from .guided_decoder import CapturableGuidedDecoder from .layerwise_nvtx_marker import LayerwiseNvtxMarker from .llm_request import get_draft_token_length +from .model_loader import ModelLoader from .resource_manager import (BaseResourceManager, KVCacheManager, ResourceManager, ResourceManagerType) from .sampler import SampleStateTensors @@ -96,137 +88,6 @@ def warmup(self, resource_manager: ResourceManager) -> None: return -_KV_CACHE_MAP = { - "fp8": QuantAlgo.FP8.value, - "nvfp4": QuantAlgo.NVFP4.value, - "auto": "auto" -} -_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto") - - -def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig, - mamba_ssm_cache_dtype: str) -> None: - if mamba_ssm_cache_dtype == "auto": - mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype - else: - mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype) - - config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype - - -def validate_and_set_kv_cache_quant(model_config: ModelConfig, - pyt_kv_cache_dtype: str) -> QuantAlgo: - logger.info( - f'Validating KV Cache config against kv_cache_dtype="{pyt_kv_cache_dtype}"' - ) - # Quantization from hf_quant_config.json - kv_cache_quant = model_config.quant_config.kv_cache_quant_algo - # PyTorch configuration quantization - valid_pyt_quant = bool(pyt_kv_cache_dtype in _VALID_KV_CACHE_DTYPES) - mapped_pyt_quant = _KV_CACHE_MAP.get(pyt_kv_cache_dtype, None) - - # If we're letting the checkpoint dictate the quant with auto, simply - # return and do not modify the checkpoint. - if pyt_kv_cache_dtype == "auto": - logger.info( - f'KV cache quantization set to "{pyt_kv_cache_dtype}". Using ' - "checkpoint KV quantization.") - return - - # If we have an invalid quantization, simply raise an exception. - if not valid_pyt_quant: - raise ValueError( - "Overriding KV cache quantization with an invalid type " - f'"PyTorchConfig.kv_cache_dtype="{pyt_kv_cache_dtype}" ' - f'Accepted types are "{_VALID_KV_CACHE_DTYPES}".') - - # If we get to this point we have a valid quantization setting, but if - # we have an existing setting and it doesn't match we shouldn't proceed. 
- if kv_cache_quant is not None and mapped_pyt_quant != kv_cache_quant: - raise RuntimeError( - "Attempting to override KV cache quantization " - f'"{kv_cache_quant}" with PyTorchConfig.kv_cache_dtype=' - f'"{pyt_kv_cache_dtype}". You cannot override a checkpoint with a ' - "pre-quantized KV cache that doesn't match.") - - # We have an open ended KV cache in the checkpoint - # and we have a specified override. - model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant - - -def initialize_dummy_weights( - model: torch.nn.Module, - low: float = -1e-3, - high: float = 1e-3, - seed: int = 0, -) -> None: - """ - This is similar to this function in SGLang with a few changes: - https://github.com/sgl-project/sglang/blob/e074e76b31d4fff13e87a455dbc3acdaa92c537a/python/sglang/srt/model_loader/weight_utils.py#L577 - - This method is used to initialize weights with dummy values for testing - models without checkpoints. Unquantized (FP16/BF16/etc) values are generated - from a uniform distribution over the interval (low, high). - - For some quantized types (FP8/NVFP4), torch has no built-in way to generate random values. - We simply generate values uniformly across an interval that has been empirically verified - to not generate NaNs/inf for these. - """ - - def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]: - # These values are not necessarily the largest possible min/max, - # they need to be small enough to avoid NaNs. - if dtype in (torch.float8_e4m3fn, torch.int8): - return (-3.0, 3.0) - - elif dtype == float4_e2m1x2: - # These correspond to bits of 2 packed FP4 values. - # Because we only go up to 64, the high 4 bits will - # always be 0. But this is fine - we just need values - # that won't generate NaNs. - return (0, 64) - - else: - raise NotImplementedError(f"Unknown quantized type: {dtype}.") - - for param in model.state_dict().values(): - generator = torch.Generator(device=param.data.device) - generator.manual_seed(seed) - dtype = param.data.dtype - - if param.data.element_size() < 2: - # We need to do a cast/round since torch doesn't have uniform_ - # support for these dtypes. - tmp_param = torch.empty(param.data.shape, - dtype=torch.float16, - device=param.data.device) - - quant_min, quant_max = _get_random_min_max(dtype) - tmp_param = tmp_param.uniform_(quant_min, - quant_max, - generator=generator) - - param.data.copy_(tmp_param.to(dtype)) - - # Note: no need to to mess with int32 params, these are probably - # constants and not weights. 
- elif torch.is_floating_point(param): - param.uniform_(low, high, generator=generator) - - -def get_rank_model_storage(model): - total_bytes = 0 - for _, param in model.named_parameters(): - if param.device.type == 'cuda' and param.device.index == torch.cuda.current_device( - ): - total_bytes += param.element_size() * param.nelement() - for _, buf in model.named_buffers(): - if buf.device.type == 'cuda' and buf.device.index == torch.cuda.current_device( - ): - total_bytes += buf.element_size() * buf.nelement() - return total_bytes - - def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int], max_batch_size: int, max_num_tokens: int, max_draft_len: int, @@ -280,6 +141,7 @@ def __init__( is_draft_model: bool = False, drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None, + model: Optional[torch.nn.Module] = None, ): self.ub_buffers = None self.batch_size = batch_size @@ -302,21 +164,24 @@ def __init__( self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures( ) - attn_backend = pytorch_backend_config.attn_backend - self.model = self._load_model( - model_path, - mapping=self.mapping, - checkpoint_loader=checkpoint_loader, - attn_backend=attn_backend, - moe_backend=pytorch_backend_config.moe_backend, - moe_disable_finalize_fusion=pytorch_backend_config. - moe_disable_finalize_fusion, - load_format=pytorch_backend_config.load_format, - max_num_tokens=max_num_tokens, - moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens, - moe_load_balancer=pytorch_backend_config.moe_load_balancer, - lora_config=lora_config, - drafting_loop_wrapper=drafting_loop_wrapper) + if model is None: + loader = ModelLoader( + pytorch_backend_config=pytorch_backend_config, + mapping=self.mapping, + spec_config=self.spec_config, + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + lora_config=lora_config, + ) + self.model = loader.load(checkpoint_dir=model_path, + checkpoint_loader=checkpoint_loader) + else: + self.model = model + if drafting_loop_wrapper is not None: + self.model = drafting_loop_wrapper(self.model) + self.model_is_wrapped = True + else: + self.model_is_wrapped = False # In case that some tests use stub models and override `_load_model`. if not hasattr(self.model, 'extra_attrs'): self.model.extra_attrs = {} @@ -387,7 +252,8 @@ def __init__( self.is_warmup = False - self.attn_backend = get_attention_backend(attn_backend) + self.attn_backend = get_attention_backend( + pytorch_backend_config.attn_backend) if self.is_spec_decode: self.spec_metadata = None @@ -938,141 +804,6 @@ def __del__(self) -> None: # Release model weights. release_gc() - def _load_model(self, - checkpoint_dir: str, - checkpoint_loader: BaseCheckpointLoader, - load_format: LoadFormat, - max_num_tokens: int, - moe_max_num_tokens: Optional[int] = None, - moe_load_balancer: Optional[MoeLoadBalancerConfig] = None, - lora_config: Optional[LoraConfig] = None, - drafting_loop_wrapper: Optional[Callable[ - [torch.nn.Module], torch.nn.Module]] = None, - **kwargs) -> DecoderModelForCausalLM: - config = checkpoint_loader.load_config( - checkpoint_dir, - trust_remote_code=True, - enable_min_latency=self.pytorch_backend_config.enable_min_latency, - use_cuda_graph=self.pytorch_backend_config.use_cuda_graph, - force_dynamic_quantization=self.pytorch_backend_config. 
- force_dynamic_quantization, - spec_config=self.spec_config, - max_num_tokens=max_num_tokens, - max_seq_len=self.max_seq_len, - moe_max_num_tokens=moe_max_num_tokens, - moe_load_balancer=moe_load_balancer, - lora_config=lora_config, - allreduce_strategy=self.pytorch_backend_config.allreduce_strategy, - mm_encoder_only=self.pytorch_backend_config.mm_encoder_only, - **kwargs) - - validate_and_set_kv_cache_quant( - config, self.pytorch_backend_config.kv_cache_dtype) - validate_and_set_mamba_ssm_cache_dtype( - config, self.pytorch_backend_config.mamba_ssm_cache_dtype) - - num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0")) - if num_layers > 0: - config.pretrained_config.num_hidden_layers = num_layers - for sub_config in ["text_config", "vision_config"]: - if hasattr(config.pretrained_config, sub_config): - getattr(config.pretrained_config, - sub_config).num_hidden_layers = num_layers - - with timing("Model init total"), maybe_create_moe_load_balancer( - config, self.mapping) as moe_load_balancer: - - try: - # config will be modified in-place for some models, like Qwen2 - config_copy = copy.deepcopy(config) - with MetaInitMode(): - model = AutoModelForCausalLM.from_config(config_copy) - - memo = dict() - - def init_meta_tensor(t: torch.Tensor): - if t.device != torch.device('meta'): - return t - if t not in memo: - memo[t] = torch.empty_like(t, device='cuda') - return memo[t] - - model._apply(init_meta_tensor) - config = config_copy - - except Exception: - logger.info( - f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n" - ) - model = AutoModelForCausalLM.from_config(config) - - model.to("cuda") - rank_model_storage = get_rank_model_storage(model) - logger.info( - f"Use {rank_model_storage / (1024**3):.2f} GB for model weights." - ) - if load_format == LoadFormat.AUTO: - if hasattr(model, 'llm_checkpoint_dir'): - weights = checkpoint_loader.load_weights( - model.llm_checkpoint_dir) - else: - weights = checkpoint_loader.load_weights(checkpoint_dir) - - weight_mapper = checkpoint_loader.get_initialized_weight_mapper( - model, config) - self._call_load_weights(model.load_weights, weights, - weight_mapper) - - if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( - ): - weights = checkpoint_loader.load_weights( - self.spec_config.speculative_model_dir) - self._call_load_weights(model.load_draft_weights, weights, - weight_mapper) - - elif load_format == LoadFormat.DUMMY: - initialize_dummy_weights(model) - if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( - ): - model.draft_model.load_weights_from_target_model(model) - - elif load_format == LoadFormat.VISION_ONLY: - # Vision weights are already loaded within the model. - logger.info( - "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights." 
- ) - - else: - raise NotImplementedError( - f"No load support for load format: {load_format}") - - if isinstance(moe_load_balancer, MoeLoadBalancer): - setattr(self, "moe_load_balancer", moe_load_balancer) - moe_load_balancer.register_weight_slots_after_to_cuda() - logger.info("moe_load_balancer finalizing model...") - moe_load_balancer.finalize_model() - logger.info("moe_load_balancer finalize model done") - - torch.cuda.current_stream().synchronize() - - if drafting_loop_wrapper is not None: - model = drafting_loop_wrapper(model) - self.model_is_wrapped = True - else: - self.model_is_wrapped = False - - return model - - def _call_load_weights(self, load_method, weights, weight_mapper): - # TODO smor- this is a temporary solution to load weights. - # Once checkpoint format is unified, this method will be removed. - from inspect import getfullargspec - args = getfullargspec(load_method).args - if "weight_mapper" in args: - load_method(weights, weight_mapper=weight_mapper) - else: - load_method(weights) - def _init_max_seq_len(self): # For mm_encoder_only mode, infer_max_seq_len() is for LLM decoder models if hasattr(self.model, 'infer_max_seq_len'): diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py new file mode 100644 index 00000000000..fccce44f5f3 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -0,0 +1,329 @@ +import copy +import inspect +import os +import traceback +from typing import Callable, Optional, Tuple + +import torch + +from tensorrt_llm._utils import str_dtype_to_torch +from tensorrt_llm.logger import logger +from tensorrt_llm.lora_helper import LoraConfig +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantAlgo +from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2 + +from ..model_config import ModelConfig +from ..models import AutoModelForCausalLM +from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader +from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, + timing) +from ..modules.fused_moe.moe_load_balancer import ( + MoeLoadBalancer, maybe_create_moe_load_balancer) +from .config import LoadFormat, PyTorchConfig + +_KV_CACHE_MAP = { + "fp8": QuantAlgo.FP8.value, + "nvfp4": QuantAlgo.NVFP4.value, + "auto": "auto" +} +_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto") + + +def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig, + mamba_ssm_cache_dtype: str) -> None: + if mamba_ssm_cache_dtype == "auto": + mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype + else: + mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype) + + config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype + + +def validate_and_set_kv_cache_quant(model_config: ModelConfig, + pyt_kv_cache_dtype: str) -> QuantAlgo: + logger.info( + f'Validating KV Cache config against kv_cache_dtype="{pyt_kv_cache_dtype}"' + ) + # Quantization from hf_quant_config.json + kv_cache_quant = model_config.quant_config.kv_cache_quant_algo + # PyTorch configuration quantization + valid_pyt_quant = bool(pyt_kv_cache_dtype in _VALID_KV_CACHE_DTYPES) + mapped_pyt_quant = _KV_CACHE_MAP.get(pyt_kv_cache_dtype, None) + + # If we're letting the checkpoint dictate the quant with auto, simply + # return and do not modify the checkpoint. + if pyt_kv_cache_dtype == "auto": + logger.info( + f'KV cache quantization set to "{pyt_kv_cache_dtype}". 
Using ' + "checkpoint KV quantization.") + return + + # If we have an invalid quantization, simply raise an exception. + if not valid_pyt_quant: + raise ValueError( + "Overriding KV cache quantization with an invalid type " + f'"PyTorchConfig.kv_cache_dtype="{pyt_kv_cache_dtype}" ' + f'Accepted types are "{_VALID_KV_CACHE_DTYPES}".') + + # If we get to this point we have a valid quantization setting, but if + # we have an existing setting and it doesn't match we shouldn't proceed. + if kv_cache_quant is not None and mapped_pyt_quant != kv_cache_quant: + raise RuntimeError( + "Attempting to override KV cache quantization " + f'"{kv_cache_quant}" with PyTorchConfig.kv_cache_dtype=' + f'"{pyt_kv_cache_dtype}". You cannot override a checkpoint with a ' + "pre-quantized KV cache that doesn't match.") + + # We have an open ended KV cache in the checkpoint + # and we have a specified override. + model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 0, +) -> None: + """ + This is similar to this function in SGLang with a few changes: + https://github.com/sgl-project/sglang/blob/e074e76b31d4fff13e87a455dbc3acdaa92c537a/python/sglang/srt/model_loader/weight_utils.py#L577 + This method is used to initialize weights with dummy values for testing + models without checkpoints. Unquantized (FP16/BF16/etc) values are generated + from a uniform distribution over the interval (low, high). + For some quantized types (FP8/NVFP4), torch has no built-in way to generate random values. + We simply generate values uniformly across an interval that has been empirically verified + to not generate NaNs/inf for these. + """ + + def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]: + # These values are not necessarily the largest possible min/max, + # they need to be small enough to avoid NaNs. + if dtype in (torch.float8_e4m3fn, torch.int8): + return (-3.0, 3.0) + + elif dtype == float4_e2m1x2: + # These correspond to bits of 2 packed FP4 values. + # Because we only go up to 64, the high 4 bits will + # always be 0. But this is fine - we just need values + # that won't generate NaNs. + return (0, 64) + + else: + raise NotImplementedError(f"Unknown quantized type: {dtype}.") + + for param in model.state_dict().values(): + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + dtype = param.data.dtype + + if param.data.element_size() < 2: + # We need to do a cast/round since torch doesn't have uniform_ + # support for these dtypes. + tmp_param = torch.empty(param.data.shape, + dtype=torch.float16, + device=param.data.device) + + quant_min, quant_max = _get_random_min_max(dtype) + tmp_param = tmp_param.uniform_(quant_min, + quant_max, + generator=generator) + + param.data.copy_(tmp_param.to(dtype)) + + # Note: no need to to mess with int32 params, these are probably + # constants and not weights. 
+ elif torch.is_floating_point(param): + param.uniform_(low, high, generator=generator) + + +def get_rank_model_storage(model): + total_bytes = 0 + for _, param in model.named_parameters(): + if param.device.type == 'cuda' and param.device.index == torch.cuda.current_device( + ): + total_bytes += param.element_size() * param.nelement() + for _, buf in model.named_buffers(): + if buf.device.type == 'cuda' and buf.device.index == torch.cuda.current_device( + ): + total_bytes += buf.element_size() * buf.nelement() + return total_bytes + + +class ModelLoader: + """ + Handles the loading, configuration, and weight initialization of a PyTorch model. + This class isolates model loading logic from the main execution engine. + """ + + def __init__(self, + pytorch_backend_config: PyTorchConfig, + mapping: Mapping, + spec_config: Optional["DecodingBaseConfig"], + max_num_tokens: int, + max_seq_len: Optional[int], + lora_config: Optional[LoraConfig] = None): + """ + Initializes the ModelLoader. + + Args: + pytorch_backend_config: Configuration for the PyTorch backend. + mapping: The distributed mapping configuration. + spec_config: Configuration for speculative decoding. + max_num_tokens: The maximum number of tokens the engine will handle. + max_seq_len: The maximum sequence length. + lora_config: Configuration for LoRA. + """ + self.pytorch_backend_config = pytorch_backend_config + self.mapping = mapping + self.spec_config = spec_config + self.max_num_tokens = max_num_tokens + self.max_seq_len = max_seq_len + self.lora_config = lora_config + self.moe_load_balancer = None + + def load( + self, + checkpoint_dir: str, + checkpoint_loader: BaseCheckpointLoader, + ) -> DecoderModelForCausalLM: + """ + Loads the model, its weights, and applies necessary configurations. + + Args: + checkpoint_dir: The directory of the model checkpoint. + checkpoint_loader: The loader object for model checkpoints. + + Returns: + The loaded and initialized PyTorch model. + """ + config = self._load_and_validate_config(checkpoint_dir, + checkpoint_loader) + load_format = self.pytorch_backend_config.load_format + + with timing("Model init total"), maybe_create_moe_load_balancer( + config, self.mapping) as moe_load_balancer: + + try: + # config will be modified in-place for some models, like Qwen2 + config_copy = copy.deepcopy(config) + with MetaInitMode(): + model = AutoModelForCausalLM.from_config(config_copy) + + memo = dict() + + def init_meta_tensor(t: torch.Tensor): + if t.device != torch.device('meta'): + return t + if t not in memo: + memo[t] = torch.empty_like(t, device='cuda') + return memo[t] + + model._apply(init_meta_tensor) + config = config_copy + + except Exception: + logger.info( + f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n" + ) + model = AutoModelForCausalLM.from_config(config) + + model.to("cuda") + rank_model_storage = get_rank_model_storage(model) + logger.info( + f"Use {rank_model_storage / (1024**3):.2f} GB for model weights." 
+ ) + if load_format == LoadFormat.AUTO: + if hasattr(model, 'llm_checkpoint_dir'): + weights = checkpoint_loader.load_weights( + model.llm_checkpoint_dir) + else: + weights = checkpoint_loader.load_weights(checkpoint_dir) + + weight_mapper = checkpoint_loader.get_initialized_weight_mapper( + model, config) + self._call_load_weights(model.load_weights, weights, + weight_mapper) + + if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( + ): + weights = checkpoint_loader.load_weights( + self.spec_config.speculative_model_dir) + self._call_load_weights(model.load_draft_weights, weights, + weight_mapper) + + elif load_format == LoadFormat.DUMMY: + initialize_dummy_weights(model) + if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( + ): + model.draft_model.load_weights_from_target_model(model) + + elif load_format == LoadFormat.VISION_ONLY: + # Vision weights are already loaded within the model. + logger.info( + "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights." + ) + + else: + raise NotImplementedError( + f"No load support for load format: {load_format}") + + if isinstance(moe_load_balancer, MoeLoadBalancer): + setattr(self, "moe_load_balancer", moe_load_balancer) + moe_load_balancer.register_weight_slots_after_to_cuda() + logger.info("moe_load_balancer finalizing model...") + moe_load_balancer.finalize_model() + logger.info("moe_load_balancer finalize model done") + + torch.cuda.current_stream().synchronize() + + return model + + def _load_and_validate_config( + self, checkpoint_dir: str, + checkpoint_loader: BaseCheckpointLoader) -> ModelConfig: + """Loads and validates the model configuration.""" + config = checkpoint_loader.load_config( + checkpoint_dir, + trust_remote_code=True, + enable_min_latency=self.pytorch_backend_config.enable_min_latency, + use_cuda_graph=self.pytorch_backend_config.use_cuda_graph, + force_dynamic_quantization=self.pytorch_backend_config. + force_dynamic_quantization, + spec_config=self.spec_config, + max_num_tokens=self.max_num_tokens, + max_seq_len=self.max_seq_len, + moe_max_num_tokens=self.pytorch_backend_config.moe_max_num_tokens, + moe_load_balancer=self.pytorch_backend_config.moe_load_balancer, + lora_config=self.lora_config, + allreduce_strategy=self.pytorch_backend_config.allreduce_strategy, + mm_encoder_only=self.pytorch_backend_config.mm_encoder_only, + attn_backend=self.pytorch_backend_config.attn_backend, + moe_backend=self.pytorch_backend_config.moe_backend, + moe_disable_finalize_fusion=self.pytorch_backend_config. 
+ moe_disable_finalize_fusion) + + validate_and_set_kv_cache_quant( + config, self.pytorch_backend_config.kv_cache_dtype) + validate_and_set_mamba_ssm_cache_dtype( + config, self.pytorch_backend_config.mamba_ssm_cache_dtype) + + # Allow overriding the number of layers via environment variable + num_layers_override = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", + "0")) + if num_layers_override > 0: + config.pretrained_config.num_hidden_layers = num_layers_override + for sub_config in ["text_config", "vision_config"]: + if hasattr(config.pretrained_config, sub_config): + getattr(config.pretrained_config, + sub_config).num_hidden_layers = num_layers_override + return config + + def _call_load_weights(self, load_method: Callable, weights, weight_mapper): + """Calls the model's weight loading method with the correct arguments.""" + args = inspect.getfullargspec(load_method).args + if "weight_mapper" in args: + load_method(weights, weight_mapper=weight_mapper) + else: + load_method(weights) diff --git a/tensorrt_llm/bench/benchmark/utils/general.py b/tensorrt_llm/bench/benchmark/utils/general.py index ff3cd933ce1..a21511f38cd 100755 --- a/tensorrt_llm/bench/benchmark/utils/general.py +++ b/tensorrt_llm/bench/benchmark/utils/general.py @@ -8,7 +8,7 @@ import yaml -from tensorrt_llm._torch.pyexecutor.model_engine import \ +from tensorrt_llm._torch.pyexecutor.model_loader import \ validate_and_set_kv_cache_quant from tensorrt_llm.bench.build.build import (get_benchmark_engine_settings, get_model_config) diff --git a/tensorrt_llm/bench/dataclasses/reporting.py b/tensorrt_llm/bench/dataclasses/reporting.py index b12873b5637..70e4cae646b 100755 --- a/tensorrt_llm/bench/dataclasses/reporting.py +++ b/tensorrt_llm/bench/dataclasses/reporting.py @@ -4,7 +4,7 @@ from collections import defaultdict from typing import Any, Dict, List, NamedTuple -from tensorrt_llm._torch.pyexecutor.model_engine import \ +from tensorrt_llm._torch.pyexecutor.model_loader import \ validate_and_set_kv_cache_quant from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig from tensorrt_llm.bench.dataclasses.general import DatasetMetadata diff --git a/tensorrt_llm/executor/ipc.py b/tensorrt_llm/executor/ipc.py index 5d45ebe4c12..327dbf4f6f5 100644 --- a/tensorrt_llm/executor/ipc.py +++ b/tensorrt_llm/executor/ipc.py @@ -24,6 +24,8 @@ class ZeroMqQueue: zmq.PAIR: "PAIR", zmq.PULL: "PULL", zmq.PUSH: "PUSH", + zmq.ROUTER: "ROUTER", + zmq.DEALER: "DEALER", } def __init__(self, @@ -56,6 +58,9 @@ def __init__(self, self.name = name self.socket = self.context.socket(socket_type) + # For ROUTER sockets, track the last identity to enable replies. For now we assume there is only one client in our case. 
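+        # The identity is captured by _recv_data()/_recv_data_async() and reused by _send_data() to route the reply (e.g., the "ACK") back to that client.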
+ self._last_identity = None + self.hmac_key = address[1] if address is not None else None self.use_hmac_encryption = use_hmac_encryption @@ -68,8 +73,8 @@ def __init__(self, "Server and client should not receive HMAC key when encryption is disabled" ) - if (socket_type == zmq.PAIR - and self.is_server) or socket_type == zmq.PULL: + if (socket_type == zmq.PAIR and self.is_server + ) or socket_type == zmq.PULL or socket_type == zmq.ROUTER: self.socket.bind( self.address_endpoint ) # Binds to the address and occupy a port immediately @@ -116,13 +121,12 @@ def poll(self, timeout: int) -> bool: def put(self, obj: Any): self.setup_lazily() with nvtx_range_debug("send", color="blue", category="IPC"): - if self.use_hmac_encryption: - # Send pickled data with HMAC appended - data = pickle.dumps(obj) # nosec B301 - signed_data = self._sign_data(data) - self.socket.send(signed_data) + if self.use_hmac_encryption or self.socket_type == zmq.ROUTER: + # Need manual serialization for encryption or ROUTER multipart + data = self._prepare_data(obj) + self._send_data(data) else: - # Send data without HMAC + # Standard socket without encryption - use pyobj directly self.socket.send_pyobj(obj) def put_noblock(self, @@ -144,11 +148,10 @@ def put_noblock(self, self.setup_lazily() with nvtx_range_debug("send", color="blue", category="IPC"): - data = pickle.dumps(obj) # nosec B301 - if self.use_hmac_encryption: - data = self._sign_data(data) + + data = self._prepare_data(obj) try: - self.socket.send(data, flags=zmq.NOBLOCK) + self._send_data(data, flags=zmq.NOBLOCK) except zmq.Again: if retry > 0: time.sleep(wait_time) @@ -159,13 +162,12 @@ def put_noblock(self, async def put_async(self, obj: Any): self.setup_lazily() try: - if self.use_hmac_encryption: - # Send pickled data with HMAC appended - data = pickle.dumps(obj) # nosec B301 - signed_data = self._sign_data(data) - await self.socket.send(signed_data) + if self.use_hmac_encryption or self.socket_type == zmq.ROUTER: + # Need manual serialization for encryption or ROUTER multipart + data = self._prepare_data(obj) + await self._send_data_async(data) else: - # Send data without HMAC + # Standard socket without encryption await self.socket.send_pyobj(obj) except TypeError as e: logger.error(f"Cannot pickle {obj}") @@ -179,45 +181,11 @@ async def put_async(self, obj: Any): def get(self) -> Any: self.setup_lazily() - - if self.use_hmac_encryption: - # Receive signed data with HMAC - signed_data = self.socket.recv() - - # Split data and HMAC - data = signed_data[:-32] - actual_hmac = signed_data[-32:] - - # Verify HMAC - if not self._verify_hmac(data, actual_hmac): - raise RuntimeError("HMAC verification failed") - - obj = pickle.loads(data) # nosec B301 - else: - # Receive data without HMAC - obj = self.socket.recv_pyobj() - return obj + return self._recv_data() async def get_async(self) -> Any: self.setup_lazily() - - if self.use_hmac_encryption: - # Receive signed data with HMAC - signed_data = await self.socket.recv() - - # Split data and HMAC - data = signed_data[:-32] - actual_hmac = signed_data[-32:] - - # Verify HMAC - if not self._verify_hmac(data, actual_hmac): - raise RuntimeError("HMAC verification failed") - - obj = pickle.loads(data) # nosec B301 - else: - # Receive data without HMAC - obj = await self.socket.recv_pyobj() - return obj + return await self._recv_data_async() def close(self): if self.socket: @@ -241,6 +209,108 @@ def _sign_data(self, data_before_encoding: bytes) -> bytes: def __del__(self): self.close() + def _prepare_data(self, obj: 
Any) -> bytes: + """Serialize object and optionally add HMAC signature.""" + data = pickle.dumps(obj) # nosec B301 + if self.use_hmac_encryption: + return self._sign_data(data) + return data + + def _parse_data(self, data: bytes) -> Any: + """Parse data and optionally verify HMAC signature.""" + if self.use_hmac_encryption: + # Split data and HMAC + message_data = data[:-32] + actual_hmac = data[-32:] + + # Verify HMAC + if not self._verify_hmac(message_data, actual_hmac): + raise RuntimeError("HMAC verification failed") + + return pickle.loads(message_data) # nosec B301 + else: + return pickle.loads(data) # nosec B301 + + def _send_data(self, data: bytes, flags: int = 0): + """Send data using appropriate API based on socket type.""" + if self.socket_type == zmq.ROUTER: + if self._last_identity is None: + raise ValueError("ROUTER socket requires identity") + self.socket.send_multipart([self._last_identity, data], flags=flags) + else: + self.socket.send(data, flags=flags) + + async def _send_data_async(self, data: bytes): + """Async version of _send_data.""" + if self.socket_type == zmq.ROUTER: + if self._last_identity is None: + raise ValueError("ROUTER socket requires identity") + await self.socket.send_multipart([self._last_identity, data]) + else: + await self.socket.send(data) + + def _recv_data(self) -> Any: + """Receive data using appropriate API based on socket type.""" + if self.socket_type == zmq.ROUTER: + identity, data = self.socket.recv_multipart() + self._last_identity = identity # Store for replies + obj = self._parse_data(data) + return obj + else: + if self.use_hmac_encryption: + data = self.socket.recv() + obj = self._parse_data(data) + else: + obj = self.socket.recv_pyobj() + return obj + + async def _recv_data_async(self) -> Any: + """Async version of _recv_data.""" + if self.socket_type == zmq.ROUTER: + identity, data = await self.socket.recv_multipart() + self._last_identity = identity # Store for replies + return self._parse_data(data) + else: + if self.use_hmac_encryption: + data = await self.socket.recv() + return self._parse_data(data) + else: + return await self.socket.recv_pyobj() + + def notify_with_retry(self, message, max_retries=5, timeout=1): + """ + Notify with automatic retry on failure (for DEALER socket pattern). 
+ + Args: + message: Message to send + max_retries: Maximum retry attempts (default: 5) + timeout: Timeout in seconds for each attempt (default: 1) + + Returns: + bool: True if acknowledgment received, False if failed after all retries + """ + if self.socket_type != zmq.DEALER: + raise ValueError( + "notify_with_retry is only supported for DEALER socket for now") + + retry_count = 0 + + while retry_count < max_retries: + try: + self.put(message) + # Wait for ACK with timeout + if self.poll(timeout): + self.get() + return True + else: + retry_count += 1 + + except Exception as e: + logger.error(f"Failed to notify with retry: {e}") + retry_count += 1 + + return False + IpcQueue = ZeroMqQueue diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index 00c8562ff5c..b8ca34620ba 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -120,7 +120,9 @@ def _setup_queues(self) -> WorkerCommIpcAddrs: self.request_queue = IpcQueue(is_server=True, name="proxy_request_queue") self.worker_init_status_queue = IpcQueue( - is_server=True, name="worker_init_status_queue") + is_server=True, + socket_type=zmq.ROUTER, + name="worker_init_status_queue") # TODO[chunweiy]: Unify IpcQueue and FusedIpcQueue # Use PULL mode when enable_postprocess_parallel as there are # multiple senders from multiple processes. @@ -324,6 +326,9 @@ def mpi_done_callback(future: concurrent.futures.Future): while True: if self.worker_init_status_queue.poll(1): ready_signal, error_trace = self.worker_init_status_queue.get() + # Send ACK to the worker + self.worker_init_status_queue.put("ACK") + logger.info("get signal from executor worker") break if any(fut.done() for fut in self.mpi_futures): logger.error("Executor worker died during initialization.") diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 86c2bf6e7b9..5dbbf19f24c 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -11,6 +11,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch +import zmq from tensorrt_llm.logger import logger @@ -776,6 +777,7 @@ def worker_main( worker_init_status_queue = IpcQueue( worker_queues.worker_init_status_queue_addr, is_server=False, + socket_type=zmq.DEALER, name="worker_init_status_queue") mp_stats_queue = FusedIpcQueue(worker_queues.stats_queue_addr, is_server=False, @@ -869,7 +871,10 @@ def notify_proxy_threads_to_quit(): logger.error(traceback.format_exc()) print_colored_debug(f"error: {traceback.format_exc()}", "red") if is_leader: - worker_init_status_queue.put((e, traceback.format_exc())) + # Send error message with confirmation + error_msg = (e, traceback.format_exc()) + if not worker_init_status_queue.notify_with_retry(error_msg): + logger.error("Failed to deliver error message to proxy") return with worker: @@ -887,7 +892,12 @@ def notify_proxy_threads_to_quit(): mp_stats_queue) worker._set_iteration_result_queue(worker.kv_events_queues, kv_cache_events_queue) - worker_init_status_queue.put((ready_signal, None)) + # Send ready signal with confirmation + ready_msg = (ready_signal, None) + if not worker_init_status_queue.notify_with_retry(ready_msg): + logger.warning( + "Failed to deliver ready signal to proxy, continuing anyway" + ) while (req := request_queue.get()) is not None: if isinstance(req, CancellingRequest): worker.abort_request(req.id) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 
6cbf6ca4121..d8639e19149 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -2010,42 +2010,7 @@ def test_nvfp4_multi_gpus_chunked_prefill(self, tp_size, pp_size, ep_size, assert llm.args.moe_config.backend == moe_backend assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 - def test_nvfp4_multi_gpus_corner_case(self): - """ - This test is used to test the corner case of the NVFP4 model. - When using the same value for max_seq_len and max_num_tokens, there will be no - enough kv block for the dummy requests in CUDA graph warmup when creating - the py_executor before estimating kv cache. Then CUDA graph capture will be - triggered when estimating kv cache. This may cause some errors. - More info in https://nvbugs/5485325. - """ - kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.80, - dtype="fp8", - enable_block_reuse=False) - pytorch_config = dict(disable_overlap_scheduler=False, - cuda_graph_config=CudaGraphConfig( - enable_padding=True, max_batch_size=1024), - moe_config=MoeConfig(backend="TRTLLM")) - - mtp_config = MTPDecodingConfig(num_nextn_predict_layers=1) - with LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4", - tensor_parallel_size=8, - pipeline_parallel_size=1, - moe_expert_parallel_size=8, - kv_cache_config=kv_cache_config, - **pytorch_config, - enable_attention_dp=False, - speculative_config=mtp_config, - max_seq_len=5120, - max_num_tokens=5120) as llm: - - assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4 - - task = MMLU(self.MODEL_NAME) - task.evaluate(llm) - task = GSM8K(self.MODEL_NAME) - task.evaluate(llm) - + @skip_pre_blackwell def test_nvfp4_multi_gpus_corner_case(self): """ This test is used to test the corner case of the NVFP4 model. diff --git a/tests/integration/defs/conftest.py b/tests/integration/defs/conftest.py index d29f9840c67..77b2e91b819 100644 --- a/tests/integration/defs/conftest.py +++ b/tests/integration/defs/conftest.py @@ -284,35 +284,11 @@ def gemma_example_root(llm_root, llm_venv): "Get gemma example root" example_root = os.path.join(llm_root, "examples", "models", "core", "gemma") - # https://nvbugs/4559583 Jax dependency broke the entire pipeline in TRT container - # due to the dependency incompatibility with torch, which forced reinstall everything - # and caused pipeline to fail. We manually install gemma dependency as a WAR. - llm_venv.run_cmd(["-m", "pip", "install", "safetensors~=0.4.1", "nltk"]) - # Install Jax because it breaks dependency - google_extension = [ - "-f", - "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html", - ] - - # WAR the new posting of "nvidia-cudnn-cu12~=9.0". - # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9". 
- if "x86_64" in platform.machine(): - llm_venv.run_cmd(["-m", "pip", "install", "nvidia-cudnn-cu12~=8.9"]) - - if "Windows" in platform.system(): - llm_venv.run_cmd([ - "-m", "pip", "install", "jax~=0.4.19", "jaxlib~=0.4.19", "--no-deps" - ] + google_extension) - else: - llm_venv.run_cmd([ - "-m", - "pip", - "install", - "jax[cuda12_pip]~=0.4.19", - "jaxlib[cuda12_pip]~=0.4.19", - "--no-deps", - ] + google_extension) - llm_venv.run_cmd(["-m", "pip", "install", "flax~=0.8.0"]) + llm_venv.run_cmd([ + "-m", "pip", "install", "-r", + os.path.join(example_root, "requirements.txt") + ]) + return example_root diff --git a/tests/integration/defs/stress_test/stress_test.py b/tests/integration/defs/stress_test/stress_test.py index 03456d8d5c5..b2a02a610e7 100644 --- a/tests/integration/defs/stress_test/stress_test.py +++ b/tests/integration/defs/stress_test/stress_test.py @@ -15,6 +15,18 @@ """ Stress test script for inference of model using TensorRT-LLM with PyTorch/TRT backend. This script is used for stress testing inference performance using trtllm-serve and genai-perf. + +The script supports three test modes: +1. "stress-test": Runs performance test followed by stress test +2. "stress-stage-alone": Runs only stress test with customized parameters +3. "stress-test-with-accuracy": Runs performance test, stress test, and accuracy tests (GSM8K) + +Accuracy testing is performed using lm_eval with GSM8K dataset: +- Baseline accuracy test: Run before stress test to establish baseline +- Post-stress accuracy test: Run after stress test to verify accuracy stability + +Usage example for accuracy testing: + pytest tests/integration/defs/stress_test/stress_test.py::test_run_stress_test[stress-test-with-accuracy] """ import contextlib import json @@ -126,6 +138,14 @@ class StressTestConfig: customized_stress_concurrency: int = 128 customized_stress_request_rate: int = 20 + # Accuracy test parameters + enable_accuracy_test: bool = False # Enable accuracy testing with GSM8K + accuracy_test_timeout: int = 1200 # 20 minutes timeout for accuracy tests + accuracy_test_concurrency: int = 512 # Concurrency for accuracy tests + accuracy_test_max_retries: int = 3 # Max retries for accuracy tests + accuracy_test_max_gen_toks: int = 256 # Max generation tokens for accuracy tests + accuracy_test_max_length: int = 4096 # Max input length for accuracy tests + @property def request_count_stress_test(self) -> int: """Calculate request count for stress test""" @@ -320,8 +340,10 @@ def check_server_health(server_url: str, return False, f"Unexpected error during health check: {str(e)}" -@pytest.mark.parametrize("test_mode", ["stress-test", "stress-stage-alone"], - ids=lambda x: x) +@pytest.mark.parametrize( + "test_mode", + ["stress-test", "stress-stage-alone", "stress-test-with-accuracy"], + ids=lambda x: x) @pytest.mark.parametrize("backend", ["trt", "pytorch"], ids=lambda x: x) @pytest.mark.parametrize("capacity_scheduler_policy", ["GUARANTEED_NO_EVICT", "MAX_UTILIZATION"], @@ -416,9 +438,14 @@ def stress_test(config, elif test_mode == "stress-stage-alone": run_performance = False run_stress = True + elif test_mode == "stress-test-with-accuracy": + run_performance = True + run_stress = True else: - pytest.skip(f"Skipping test for unsupported mode: {test_mode}. " - f"Supported modes: stress-test, stress-stage-alone") + pytest.skip( + f"Skipping test for unsupported mode: {test_mode}. 
" + f"Supported modes: stress-test, stress-stage-alone, stress-test-with-accuracy" + ) return # Skip if not enough GPU memory @@ -458,9 +485,9 @@ def stress_test(config, pp_size=test_server_config.pp_size, ep_size=8, # DeepSeek-V3 or DeepSeek-R1 specific ep_size max_batch_size= - 161, # DeepSeek-V3 or DeepSeek-R1 specific max_batch_size + 2048, # DeepSeek-V3 or DeepSeek-R1 specific max_batch_size max_num_tokens= - 1160, # DeepSeek-V3 or DeepSeek-R1 specific max_num_tokens + 2048, # DeepSeek-V3 or DeepSeek-R1 specific max_num_tokens kv_cache_free_gpu_memory_fraction= 0.7, # DeepSeek-V3 or DeepSeek-R1 specific kv_cache fraction capacity_scheduler_policy=test_server_config. @@ -472,8 +499,12 @@ def stress_test(config, # Create a StressTestConfig with customized time parameters if provided if run_stress: + # Enable accuracy test for stress-test-with-accuracy mode + enable_accuracy = (test_mode == "stress-test-with-accuracy") + stress_config = StressTestConfig(model_config=config, - server_config=test_server_config) + server_config=test_server_config, + enable_accuracy_test=enable_accuracy) # Override stress_time and stress_timeout if provided if stress_time is not None: @@ -482,7 +513,8 @@ def stress_test(config, server_config=test_server_config, stress_time=stress_time, stress_timeout=stress_timeout - if stress_timeout is not None else stress_time * 2) + if stress_timeout is not None else stress_time * 2, + enable_accuracy_test=enable_accuracy) else: stress_config = None @@ -632,6 +664,12 @@ def stress_test(config, print_info( f"Server is running with model {model_name}. Starting tests...") + # Run baseline accuracy test first if enabled + baseline_accuracy_success = True + if stress_config and stress_config.enable_accuracy_test: + baseline_accuracy_success, baseline_accuracy_value = run_accuracy_test( + model_path, test_server_config, stress_config, "baseline") + # Run performance test first if enabled stage2_output = None # Initialize stage2_output to None if run_performance: @@ -664,6 +702,52 @@ def stress_test(config, stress_config, None, request_counter=request_counter) + + # Run post-stress accuracy test if enabled + post_stress_accuracy_success = True + if stress_config and stress_config.enable_accuracy_test: + post_stress_accuracy_success, post_stress_accuracy_value = run_accuracy_test( + model_path, test_server_config, stress_config, + "post_stress") + + # Report accuracy test results + if baseline_accuracy_success and post_stress_accuracy_success: + print_info("=== ACCURACY TEST SUMMARY ===") + print_info("✓ Baseline accuracy test: PASSED") + print_info("✓ Post-stress accuracy test: PASSED") + + # Compare accuracy values if both are available + if baseline_accuracy_value is not None and post_stress_accuracy_value is not None: + accuracy_drop = baseline_accuracy_value - post_stress_accuracy_value + accuracy_drop_percentage = ( + accuracy_drop / baseline_accuracy_value) * 100 + + print_info( + f"Baseline accuracy: {baseline_accuracy_value:.4f}") + print_info( + f"Post-stress accuracy: {post_stress_accuracy_value:.4f}" + ) + print_info( + f"Accuracy drop: {accuracy_drop:.4f} ({accuracy_drop_percentage:.2f}%)" + ) + + # Define threshold for significant accuracy drop (e.g., 5%) + accuracy_drop_threshold = 0.05 # 5% + # Assert that accuracy drop is within acceptable threshold + assert accuracy_drop_percentage <= ( + accuracy_drop_threshold * 100 + ), f"Accuracy drop {accuracy_drop_percentage:.2f}% exceeds threshold {accuracy_drop_threshold * 100}%" + print_info( + "✓ Model accuracy 
appears stable under stress conditions" + ) + else: + print_warning("=== ACCURACY TEST SUMMARY ===") + if not baseline_accuracy_success: + print_warning("✗ Baseline accuracy test: FAILED") + if not post_stress_accuracy_success: + print_warning("✗ Post-stress accuracy test: FAILED") + print_warning( + "Model accuracy may be affected by stress conditions") finally: # Clean up temp yaml file if os.path.exists(extra_llm_options_path): @@ -984,6 +1068,112 @@ def format_time(seconds: int) -> str: return f"{seconds}s" +def parse_accuracy_from_lm_eval_output(output_text: str) -> float: + """ + Parse accuracy value from lm_eval output for GSM8K flexible-extract exact_match + + Args: + output_text: The output text from lm_eval command + + Returns: + float: The accuracy value (0.7582 in the example) + """ + import re + + # Look for the specific pattern: |gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7559|± |0.0118| + patterns = [ + r'flexible-extract\|\s+\d+\|exact_match\|\↑\s+\|(\d+\.\d+)', + ] + + for pattern in patterns: + match = re.search(pattern, output_text) + if match: + accuracy_value = float(match.group(1)) + print_info(f"Extracted accuracy value: {accuracy_value}") + return accuracy_value + + print_warning("Could not find accuracy value in lm_eval output") + print_warning(f"Output text: {output_text}") + return None + + +def run_accuracy_test(model_path: str, + server_config: ServerConfig, + stress_config: StressTestConfig, + test_phase: str = "baseline") -> tuple[bool, float]: + """ + Run accuracy test using lm_eval with GSM8K dataset + + Args: + model_path: Path of the model being tested + server_config: Server configuration containing URL and port + stress_config: Stress test configuration containing accuracy test parameters + test_phase: Phase of the test ("baseline" or "post_stress") + + Returns: + tuple: (Boolean indicating whether the accuracy test completed successfully, accuracy value) + """ + if not stress_config.enable_accuracy_test: + print_info(f"Skipping accuracy test for {test_phase} phase (disabled)") + return True, None + + print_info(f"=== Running {test_phase.upper()} ACCURACY TEST (GSM8K) ===") + + # Create lm_eval command + lm_eval_cmd = [ + "lm_eval", "--model", "local-completions", "--tasks", "gsm8k", + "--model_args", + f"model={model_path},base_url={server_config.url}/v1/completions," + f"num_concurrent={stress_config.accuracy_test_concurrency}," + f"max_retries={stress_config.accuracy_test_max_retries}," + f"tokenized_requests=False," + f"timeout={stress_config.accuracy_test_timeout}," + f"max_gen_toks={stress_config.accuracy_test_max_gen_toks}," + f"max_length={stress_config.accuracy_test_max_length}", + "--trust_remote_code" + ] + + test_start_time = time.time() + accuracy_value = None + + try: + # Run lm_eval process with timeout monitoring + print_info(f"Running lm_eval command: {' '.join(lm_eval_cmd)}") + + # Use subprocess.run to capture output directly + result = subprocess.run(lm_eval_cmd, + capture_output=True, + text=True, + timeout=stress_config.accuracy_test_timeout) + + # Check if process completed successfully + if result.returncode == 0: + test_end_time = time.time() + duration = int(test_end_time - test_start_time) + print_info( + f"{test_phase.capitalize()} accuracy test completed successfully in {format_time(duration)}" + ) + + # Parse accuracy value from output + output_text = result.stdout + accuracy_value = parse_accuracy_from_lm_eval_output(output_text) + return True, accuracy_value + else: + print_warning( + f"lm_eval exited with non-zero 
code: {result.returncode}") + print_warning(f"stderr: {result.stderr}") + return False, None + + except subprocess.TimeoutExpired: + print_warning( + f"Accuracy test timed out after {stress_config.accuracy_test_timeout} seconds" + ) + return False, None + except Exception as e: + print_warning(f"Error during {test_phase} accuracy test: {str(e)}") + return False, None + + def extract_stress_test_metrics(artifacts_dir="./artifacts", current_model=None): """ diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 333a99726c4..b46608166c4 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1735,8 +1735,10 @@ def test_openai_multinodes_chat_tp8pp2(llm_root, llm_venv): @pytest.mark.skip_less_device_memory(80000) -@pytest.mark.parametrize( - "model_name", ["llama-3.1-model/Meta-Llama-3.1-8B", "gpt_oss/gpt-oss-20b"]) +@pytest.mark.parametrize("model_name", [ + "llama-3.1-model/Meta-Llama-3.1-8B", + pytest.param("gpt_oss/gpt-oss-20b", marks=skip_pre_hopper) +]) def test_trtllm_benchmark_serving(llm_venv, model_name): test_root = unittest_path() / "llmapi" / "apps" llm_venv.run_cmd([ diff --git a/tests/integration/test_lists/qa/llm_function_rtx6k.txt b/tests/integration/test_lists/qa/llm_function_rtx6k.txt index 69f0787908a..9f0746697a2 100644 --- a/tests/integration/test_lists/qa/llm_function_rtx6k.txt +++ b/tests/integration/test_lists/qa/llm_function_rtx6k.txt @@ -1,6 +1,3 @@ -accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2 -accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp8_tp2pp2_manage_weights -accuracy/test_cli_flow.py::TestMixtral8x7B::test_nvfp4_prequantized accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False-enable_chunked_prefill=False] diff --git a/tests/integration/test_lists/qa/llm_function_stress.txt b/tests/integration/test_lists/qa/llm_function_stress.txt new file mode 100644 index 00000000000..a4aac892ad4 --- /dev/null +++ b/tests/integration/test_lists/qa/llm_function_stress.txt @@ -0,0 +1,3 @@ +stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-GUARANTEED_NO_EVICT-pytorch-stress-test-with-accuracy] +stress_test/stress_test.py::test_run_stress_test[DeepSeek-V3_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy] +stress_test/stress_test.py::test_run_stress_test[DeepSeek-R1_tp8-stress_time_3600s_timeout_5400s-MAX_UTILIZATION-pytorch-stress-test-with-accuracy] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 31fc2ee704c..cb0eb622d03 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -39,7 +39,18 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False] - 
accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_auto_dtype[tp8-cuda_graph=False] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_auto_dtype[tp8ep4-cuda_graph=True] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_auto_dtype[tp8ep8-cuda_graph=True] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_auto_dtype[tp4-cuda_graph=False] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_auto_dtype[tp4ep2-cuda_graph=True] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_auto_dtype[tp4ep4-cuda_graph=True] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp8ep8-cuda_graph=True] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8[tp4-cuda_graph=True] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp8ep8-cuda_graph=True] - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4[tp4-cuda_graph=True] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_prefill[tp4ep4-cuda_graph=True] + - accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp4_chunked_prefill[tp4ep4-cuda_graph=True] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto] - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index faf0b1af975..d4d1a94a7a5 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -283,8 +283,6 @@ disaggregated/test_disaggregated.py::test_disaggregated_diff_max_tokens[TinyLlam disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp[DeepSeek-V3-Lite-fp8] SKIP (https://nvbugs/5465642) examples/test_multimodal.py::test_llm_multimodal_general[Mistral-Small-3.1-24B-Instruct-2503-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (https://nvbugs/5431146) accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2] SKIP (https://nvbugs/5470769) -full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108) -full:L20/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108) test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781) disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_llama_context_capacity[False-False-DeepSeek-V3-Lite-fp8/fp8] SKIP (https://nvbugs/5477404) triton_server/test_triton.py::test_python_bls_unit_tests[python-bls-unit-tests] SKIP (https://nvbugs/5477392) @@ -300,7 +298,6 @@ examples/test_phi.py::test_llm_phi_lora_1gpu[Phi-3-mini-4k-instruct-ru-lora-Phi- examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-mini-128k-instruct-fp8-float16] SKIP (https://nvbugs/5465143) examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-mini-instruct-fp8-float16] SKIP (https://nvbugs/5465143) examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-MoE-instruct-fp8-bfloat16] SKIP (https://nvbugs/5465143) 
-examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8] SKIP (https://nvbugs/5522332) accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5465143) accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2 SKIP (https://nvbugs/5465143) accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype SKIP (https://nvbugs/5481075) @@ -337,7 +334,6 @@ full:H100/accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8 full:H100/accuracy/test_llm_api_pytorch.py::TestLlama4MaverickInstruct::test_fp8_eagle3[tp8-torch_compile=True] SKIP (https://nvbugs/5483534) full:A100/test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-video-False] SKIP (https://nvbugs/5453725) test_e2e.py::test_ptp_scaffolding[DeepSeek-R1-Distill-Qwen-7B-DeepSeek-R1/DeepSeek-R1-Distill-Qwen-7B] SKIP (https://nvbugs/5517260) -accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_prefill[tp4ep4-cuda_graph=True] SKIP (https://nvbugs/5522462) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5522746) accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5522746) test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image-False] SKIP (https://nvbugs/5523925) @@ -346,3 +342,10 @@ test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-True] SKIP (https://nvbugs/5509024) test_e2e.py::test_trtllm_multimodal_benchmark_serving SKIP (https://nvbugs/5523315) accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320) +examples/test_llama.py::test_llm_llama_1gpu_fp8_kv_cache[llama-v2-7b-hf-bfloat16] SKIP (https://nvbugs/5527940) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5528070) +accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype SKIP (https://nvbugs/5527956) +test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5509024) +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] SKIP (https://nvbugs/5481198) +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[latency] SKIP (https://nvbugs/5481198) +accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[throughput] SKIP (https://nvbugs/5481198) diff --git a/tests/unittest/_torch/executor/test_pytorch_model_engine.py b/tests/unittest/_torch/executor/test_pytorch_model_engine.py index 8a06a3a9f0f..ec53a1ae832 100644 --- a/tests/unittest/_torch/executor/test_pytorch_model_engine.py +++ b/tests/unittest/_torch/executor/test_pytorch_model_engine.py @@ -67,16 +67,14 @@ def __init__(self, mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(), tp_size=tensorrt_llm.mpi_world_size(), rank=tensorrt_llm.mpi_rank()) - self.model_is_wrapped = False + model = DummyModel(self.dtype) 
super().__init__(model_path="", pytorch_backend_config=pytorch_backend_config, checkpoint_loader=None, batch_size=batch_size, max_seq_len=max_seq_len, - mapping=mapping) - - def _load_model(self, mode_path: str, **kwargs) -> torch.nn.Module: - return DummyModel(self.dtype) + mapping=mapping, + model=model) def _create_request(num_tokens, req_id: int): diff --git a/tests/unittest/_torch/executor/test_router_dealer_ipc.py b/tests/unittest/_torch/executor/test_router_dealer_ipc.py new file mode 100644 index 00000000000..3c96f1d55c3 --- /dev/null +++ b/tests/unittest/_torch/executor/test_router_dealer_ipc.py @@ -0,0 +1,373 @@ +import multiprocessing +import pickle +import time +from contextlib import contextmanager + +import pytest +import zmq + +from tensorrt_llm.executor.ipc import ZeroMqQueue + + +@contextmanager +def router_dealer_pair(use_hmac_encryption=False, + server_name="test_router", + client_name="test_dealer"): + """Context manager to create and manage a ROUTER-DEALER queue pair.""" + server_queue = ZeroMqQueue(socket_type=zmq.ROUTER, + is_server=True, + name=server_name, + use_hmac_encryption=use_hmac_encryption) + + client_queue = ZeroMqQueue(address=server_queue.address, + socket_type=zmq.DEALER, + is_server=False, + name=client_name, + use_hmac_encryption=use_hmac_encryption) + + try: + yield server_queue, client_queue + finally: + server_queue.close() + client_queue.close() + + +def basic_communication_helper(server_queue, client_queue, test_message, + expected_reply): + """Helper function to test basic bidirectional communication.""" + # Send message from client to server + client_queue.put(test_message) + + # Server should receive the message + assert server_queue.poll(2), "Server should receive message" + received_message = server_queue.get() + assert received_message == test_message + + # Send reply from server to client + server_queue.put(expected_reply) + + # Client should receive the reply + assert client_queue.poll(2), "Client should receive reply" + received_reply = client_queue.get() + assert received_reply == expected_reply + + return received_message, received_reply + + +@contextmanager +def multiprocess_runner(): + """Context manager to handle multiprocess execution and cleanup.""" + processes = [] + queues = [] + + def add_process(target, args, daemon=True): + proc = multiprocessing.Process(target=target, args=args, daemon=daemon) + processes.append(proc) + return proc + + def add_queue(): + queue = multiprocessing.Queue() + queues.append(queue) + return queue + + try: + yield add_process, add_queue + + # Start all processes + for proc in processes: + proc.start() + + # Wait for all processes to complete + for proc in processes: + proc.join(timeout=10) + + # Check if any process is still alive and terminate if needed + for proc in processes: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=2) + + # Verify all processes finished successfully + for i, proc in enumerate(processes): + assert not proc.is_alive(), f"Process {i} did not finish in time" + if proc.exitcode != 0: + pytest.fail( + f"Process {i} failed with exit code: {proc.exitcode}") + + except Exception as e: + # Emergency cleanup + for proc in processes: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=1) + raise e + + +def collect_process_results(result_queues, timeout=2): + """Collect results from multiple process result queues.""" + results = [] + for queue in result_queues: + try: + result = queue.get(timeout=timeout) + results.append(result) + except Exception as e: + 
results.append({"error": f"Failed to get result: {e}"}) + return results + + +def verify_worker_proxy_results(worker_result, + proxy_result, + expected_signal=b"READY"): + """ + Universal verification function for worker-proxy communication results. + Works identically for both encrypted and non-encrypted communication. + + Args: + worker_result: Worker process result dictionary + proxy_result: Proxy process result dictionary + expected_signal: Expected signal in the message + """ + # Basic result verification + assert worker_result[ + "error"] is None, f"Worker error: {worker_result['error']}" + assert worker_result["success"], "Worker should have succeeded" + assert proxy_result[ + "error"] is None, f"Proxy error: {proxy_result['error']}" + assert proxy_result[ + "message"] is not None, "Proxy should have received message" + + # Verify message content - same format regardless of encryption + ready_signal, error_trace = proxy_result["message"] + assert ready_signal == expected_signal + assert error_trace is None + + +def run_worker_proxy_multiprocess_test(use_hmac_encryption=False, + ready_signal=b"READY"): + """ + Generic multiprocess test for worker-proxy communication. + + Args: + use_hmac_encryption: Whether to use HMAC encryption + ready_signal: Signal to send from worker to proxy + """ + + def proxy_process(address_q, result_q): + """Generic proxy process.""" + try: + proxy_queue = ZeroMqQueue(socket_type=zmq.ROUTER, + is_server=True, + name="proxy_multiproc", + use_hmac_encryption=use_hmac_encryption) + + address_q.put(proxy_queue.address) + + try: + if proxy_queue.poll(5): + message = proxy_queue.get() + result_q.put({"message": message, "error": None}) + proxy_queue.put("ACK") + else: + result_q.put({ + "message": None, + "error": "Timeout waiting for worker message" + }) + finally: + proxy_queue.close() + except Exception as e: + result_q.put({"message": None, "error": str(e)}) + + def worker_process(address_q, result_q): + """Generic worker process.""" + try: + proxy_addr = address_q.get(timeout=5) + worker_queue = ZeroMqQueue(socket_type=zmq.DEALER, + is_server=False, + address=proxy_addr, + name="worker_multiproc", + use_hmac_encryption=use_hmac_encryption) + + try: + time.sleep(0.1) # Simulate initialization time + ready_message = (ready_signal, None) + worker_queue.put(ready_message) + + if worker_queue.poll(3): + ack = worker_queue.get() + success = ack == "ACK" + result_q.put({ + "success": success, + "error": None if success else f"Invalid ACK: {ack}", + "ack": ack if success else None + }) + else: + result_q.put({ + "success": False, + "error": "Timeout waiting for ACK" + }) + finally: + worker_queue.close() + except Exception as e: + result_q.put({"success": False, "error": str(e)}) + + # Run the multiprocess test + with multiprocess_runner() as (add_process, add_queue): + address_queue = add_queue() + proxy_result_queue = add_queue() + worker_result_queue = add_queue() + + add_process(proxy_process, (address_queue, proxy_result_queue)) + add_process(worker_process, (address_queue, worker_result_queue)) + + # Collect and verify results + worker_result, proxy_result = collect_process_results( + [worker_result_queue, proxy_result_queue]) + verify_worker_proxy_results(worker_result, + proxy_result, + expected_signal=ready_signal) + + return worker_result, proxy_result + + +class TestRouterDealerIPC: + """Test suite for ROUTER/DEALER socket communication patterns.""" + + @pytest.fixture + def zmq_context(self): + """Create a ZMQ context for testing.""" + context = 
zmq.Context() + yield context + context.term() + + @pytest.fixture + def router_socket(self, zmq_context): + """Create a ROUTER socket for testing.""" + socket = zmq_context.socket(zmq.ROUTER) + socket.bind("tcp://127.0.0.1:*") + endpoint = socket.getsockopt(zmq.LAST_ENDPOINT).decode() + yield socket, endpoint + socket.close() + + @pytest.fixture + def dealer_socket(self, zmq_context): + """Create a DEALER socket for testing.""" + socket = zmq_context.socket(zmq.DEALER) + yield socket + socket.close() + + def test_basic_router_dealer_communication(self, router_socket, + dealer_socket): + """Test basic ROUTER-DEALER communication with multipart messages.""" + router, endpoint = router_socket + dealer = dealer_socket + + # Connect dealer to router + dealer.connect(endpoint) + time.sleep(0.1) # Allow connection to establish + + # Send message from dealer to router + test_message = b"Hello Router" + dealer.send(test_message) + + # Receive multipart message at router (identity + message) + identity, message = router.recv_multipart() + assert message == test_message + assert len(identity) > 0 # Identity should be present + + # Send reply from router to dealer + reply_message = b"Hello Dealer" + router.send_multipart([identity, reply_message]) + + # Receive reply at dealer + received_reply = dealer.recv() + assert received_reply == reply_message + + def test_multipart_message_handling(self, router_socket, dealer_socket): + """Test multipart message handling with multiple frames.""" + router, endpoint = router_socket + dealer = dealer_socket + + dealer.connect(endpoint) + time.sleep(0.1) + + # Send multipart message from dealer + message_parts = [b"frame1", b"frame2", b"frame3"] + dealer.send_multipart(message_parts) + + # Receive at router (identity + all message frames) + frames = router.recv_multipart() + identity = frames[0] + received_parts = frames[1:] + + assert received_parts == message_parts + assert len(identity) > 0 + + # Send multipart reply from router + reply_parts = [b"reply1", b"reply2"] + router.send_multipart([identity] + reply_parts) + + # Receive multipart reply at dealer + received_reply = dealer.recv_multipart() + assert received_reply == reply_parts + + def test_pickle_message_serialization(self, router_socket, dealer_socket): + """Test sending pickled Python objects through ROUTER-DEALER.""" + router, endpoint = router_socket + dealer = dealer_socket + + dealer.connect(endpoint) + time.sleep(0.1) + + # Send pickled object from dealer + test_obj = {"signal": "READY", "data": [1, 2, 3], "error": None} + pickled_data = pickle.dumps(test_obj) + dealer.send(pickled_data) + + # Receive at router and unpickle + identity, message = router.recv_multipart() + received_obj = pickle.loads(message) + assert received_obj == test_obj + + # Send pickled reply from router + reply_obj = {"status": "ACK", "timestamp": 1234567890} + reply_data = pickle.dumps(reply_obj) + router.send_multipart([identity, reply_data]) + + # Receive and unpickle reply at dealer + reply_raw = dealer.recv() + received_reply = pickle.loads(reply_raw) + assert received_reply == reply_obj + + @pytest.mark.parametrize("use_hmac", [False, True]) + def test_router_dealer_communication(self, use_hmac): + """Test ROUTER-DEALER communication with and without HMAC encryption.""" + # Same message content regardless of encryption + test_message = ("READY", "initialization_complete") + reply_message = "ACK" + + with router_dealer_pair(use_hmac_encryption=use_hmac) as (server_queue, + client_queue): + 
basic_communication_helper(server_queue, client_queue, test_message, + reply_message) + + @pytest.mark.parametrize("use_hmac", [False, True]) + def test_router_dealer_basic(self, use_hmac): + """Test router_dealer with and without HMAC encryption.""" + # Same message content regardless of encryption + ready_message = (b"READY", None) + ack_message = "ACK" + + with router_dealer_pair(use_hmac_encryption=use_hmac) as (proxy_queue, + worker_queue): + received_message, received_ack = basic_communication_helper( + proxy_queue, worker_queue, ready_message, ack_message) + + # Encryption is transparent - same verification for both + assert received_ack == "ACK" + + @pytest.mark.parametrize("use_hmac", [False, True]) + def test_router_dealer_multiprocess(self, use_hmac): + """Test router_dealer using separate processes, with and without HMAC encryption.""" + # Same signal regardless of encryption + run_worker_proxy_multiprocess_test(use_hmac_encryption=use_hmac, + ready_signal=b"READY") diff --git a/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py b/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py index cffa883bd78..8f1e53c42e4 100644 --- a/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py +++ b/tests/unittest/_torch/multimodal/test_find_num_image_tokens.py @@ -1,9 +1,9 @@ -import io +import os +from pathlib import Path import pytest -import requests -from PIL import Image from transformers import AutoConfig, AutoTokenizer +from utils.llm_data import llm_models_root from tensorrt_llm import MultimodalEncoder from tensorrt_llm._torch.models.modeling_llava_next import \ @@ -12,38 +12,34 @@ Qwen2VLInputProcessorBase from tensorrt_llm._torch.shared_tensor import SharedTensorContainer from tensorrt_llm.inputs import default_multimodal_input_loader -from tensorrt_llm.inputs.utils import load_video +from tensorrt_llm.inputs.utils import load_image, load_video +test_data_root = Path( + os.path.join(llm_models_root(), "multimodals", "test_data")) example_images = [ - "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png", - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png", - "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg", + str(test_data_root / "seashore.png"), + str(test_data_root / "inpaint.png"), + str(test_data_root / "61.jpg"), ] - example_videos = [ - "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4", - "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4", + str(test_data_root / "OAI-sora-tokyo-walk.mp4"), + str(test_data_root / "world.mp4"), ] -def download_image(url: str) -> Image.Image: - """Download image from URL and return as PIL Image.""" - response = requests.get(url, timeout=30) - response.raise_for_status() - img = Image.open(io.BytesIO(response.content)) - return img.convert("RGB") - - @pytest.fixture(scope="function") def multimodal_model_configs(): """Get multimodal model configurations for testing.""" model_configs = { 'llava-v1.6-mistral-7b-hf': { 'hf_model_dir': 'llava-hf/llava-v1.6-mistral-7b-hf', + 'model_dir': + llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf", 'model_type': 'llava_next', }, 'qwen2.5-vl': { 'hf_model_dir': 'Qwen/Qwen2.5-VL-3B-Instruct', + 'model_dir': llm_models_root() / "Qwen2.5-VL-3B-Instruct", 'model_type': 'qwen2_5_vl', }, } @@ -66,7 +62,7 @@ def 
test_get_num_tokens_per_image(model_key, multimodal_model_configs): pytest.skip(f"Skipping test for {model_key} - model not available") model_config = multimodal_model_configs[model_key] - encoder_model_dir = model_config['hf_model_dir'] + encoder_model_dir = model_config['model_dir'] model_type = model_config['model_type'] # Test configuration @@ -119,10 +115,10 @@ def test_get_num_tokens_per_image(model_key, multimodal_model_configs): example_images ), f"Expected {len(example_images)} encoder outputs, got {len(encoder_outputs)}" - for image_idx, test_image_url in enumerate(example_images): + for image_idx, test_image in enumerate(example_images): # Get test image dimensions - test_image = download_image(test_image_url) + test_image = load_image(test_image, format="pil") image_width, image_height = test_image.size # Get actual embedding tensor for this image @@ -173,7 +169,7 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs): pytest.skip(f"Skipping test for {model_key} - model not available") model_config = multimodal_model_configs[model_key] - encoder_model_dir = model_config['hf_model_dir'] + encoder_model_dir = model_config['model_dir'] model_type = model_config['model_type'] # Test configuration @@ -226,10 +222,10 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs): example_videos ), f"Expected {len(example_videos)} encoder outputs, got {len(encoder_outputs)}" - for video_idx, test_video_url in enumerate(example_videos): + for video_idx, test_video in enumerate(example_videos): # Get test video dimensions - test_video = load_video(test_video_url, num_frames=8, format="pil") + test_video = load_video(test_video, num_frames=8, format="pil") # load_video returns a list of frames, we only have one video video_width, video_height = test_video[0].size diff --git a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py index e42e07654fb..30c3481cf06 100644 --- a/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py +++ b/tests/unittest/_torch/multimodal/test_mm_encoder_standalone.py @@ -2,6 +2,7 @@ import os import pytest +from utils.llm_data import llm_models_root from tensorrt_llm import MultimodalEncoder from tensorrt_llm._torch.shared_tensor import SharedTensorContainer @@ -22,8 +23,12 @@ def multimodal_model_config(): # You can extend this to support multiple models or get from environment model_configs = { 'llava-v1.6-mistral-7b-hf': { - 'model_name': 'llava-v1.6-mistral-7b-hf', - 'hf_model_dir': 'llava-hf/llava-v1.6-mistral-7b-hf', + 'model_name': + 'llava-v1.6-mistral-7b-hf', + 'hf_model_dir': + 'llava-hf/llava-v1.6-mistral-7b-hf', + 'model_dir': + llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf", } } @@ -47,7 +52,7 @@ def test_single_image_chat(model_key, multimodal_model_config): ) # Extract model information from config - encoder_model_dir = multimodal_model_config['hf_model_dir'] + encoder_model_dir = multimodal_model_config['model_dir'] # Test configuration max_tokens = 64 diff --git a/tests/unittest/llmapi/apps/_test_trtllm_serve_benchmark.py b/tests/unittest/llmapi/apps/_test_trtllm_serve_benchmark.py index b0d32fd76bb..908757b7c90 100644 --- a/tests/unittest/llmapi/apps/_test_trtllm_serve_benchmark.py +++ b/tests/unittest/llmapi/apps/_test_trtllm_serve_benchmark.py @@ -3,7 +3,7 @@ import sys import pytest -from utils.util import skip_gpu_memory_less_than_80gb +from utils.util import skip_gpu_memory_less_than_80gb, skip_pre_hopper from 
.openai_server import RemoteOpenAIServer @@ -45,9 +45,11 @@ def dataset_path(dataset_name: str): @skip_gpu_memory_less_than_80gb -@pytest.mark.parametrize( - "model_name", ["llama-3.1-model/Meta-Llama-3.1-8B", "gpt_oss/gpt-oss-20b"], - indirect=True) +@pytest.mark.parametrize("model_name", [ + "llama-3.1-model/Meta-Llama-3.1-8B", + pytest.param("gpt_oss/gpt-oss-20b", marks=skip_pre_hopper) +], + indirect=True) def test_trtllm_serve_benchmark(server: RemoteOpenAIServer, benchmark_root: str, model_path: str): model_name = model_path.split("/")[-1]
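
As a quick orientation for the IPC changes above, the sketch below wires up the two ends of the new ROUTER/DEALER init handshake the same way the added unit test (`tests/unittest/_torch/executor/test_router_dealer_ipc.py`) does. It is a minimal single-process illustration and not part of the diff: the thread standing in for the proxy, the `print` calls, and the shared `queue.Queue` for passing the address are assumptions for brevity; the `ZeroMqQueue` API it uses (`zmq.ROUTER`/`zmq.DEALER` socket types, `poll`, `get`, `put`, `notify_with_retry`) is the one introduced in `tensorrt_llm/executor/ipc.py`.

```python
"""Minimal sketch of the ROUTER/DEALER init handshake added in this diff."""
import queue
import threading

import zmq

from tensorrt_llm.executor.ipc import ZeroMqQueue


def proxy_side(address_q: queue.Queue):
    # ROUTER end: binds, waits for the worker's (ready_signal, error_trace)
    # tuple, then replies with an ACK routed via the stored DEALER identity.
    proxy_queue = ZeroMqQueue(socket_type=zmq.ROUTER,
                              is_server=True,
                              name="worker_init_status_queue",
                              use_hmac_encryption=False)
    address_q.put(proxy_queue.address)
    try:
        if proxy_queue.poll(5):
            ready_signal, error_trace = proxy_queue.get()
            print(f"proxy got {ready_signal!r}, error={error_trace!r}")
            proxy_queue.put("ACK")
    finally:
        proxy_queue.close()


address_q = queue.Queue()
proxy = threading.Thread(target=proxy_side, args=(address_q, ), daemon=True)
proxy.start()

# DEALER end: connects to the proxy address and retries until the ACK arrives.
worker_queue = ZeroMqQueue(address=address_q.get(timeout=5),
                           socket_type=zmq.DEALER,
                           is_server=False,
                           name="worker_init_status_queue",
                           use_hmac_encryption=False)
try:
    delivered = worker_queue.notify_with_retry((b"READY", None),
                                               max_retries=5,
                                               timeout=1)
    print(f"ready signal acknowledged: {delivered}")
finally:
    worker_queue.close()
proxy.join(timeout=10)
```

In the executor itself (see the `proxy.py` and `worker.py` hunks above), the two ends live in separate processes; the retry-until-ACK path is what lets the worker's ready or error signal survive a proxy that is not yet polling.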