4 changes: 1 addition & 3 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -245,9 +245,7 @@ def forward(
         return {"logits": logits_flat}


-def create_autodeploy_executor(
-    executor_config: ExecutorConfig, checkpoint_dir: str = None, engine_dir: str = None
-):
+def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: str = None):
     """Create an AutoDeploy executor from the given configuration and checkpoint directory.

     This is the entrypoint API to the _autodeploy backend.
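With engine_dir gone, the AutoDeploy entrypoint takes only the executor config and the checkpoint directory. A minimal call-shape sketch, assuming an illustrative ExecutorConfig and a local HF checkpoint path (neither is taken from this diff):

from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
    create_autodeploy_executor
from tensorrt_llm.bindings.executor import ExecutorConfig

# Illustrative config; real values are filled in by the LLM-API plumbing.
executor_config = ExecutorConfig(max_beam_width=1)

# New signature: the engine_dir keyword no longer exists.
executor = create_autodeploy_executor(
    executor_config, checkpoint_dir="/path/to/hf/checkpoint")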
3 changes: 0 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -99,7 +99,6 @@ class PyTorchConfig:
     'tokens_per_block',
     'mapping',
     'hf_model_dir',
-    'trt_engine_dir',
 ]


@@ -111,7 +110,6 @@ def update_executor_config(
         build_config: Optional[BuildConfig] = None,
         speculative_config: Optional[SpecConfig] = None,
         hf_model_dir: Optional[str] = None,
-        trt_engine_dir: Optional[str] = None,
         max_input_len: Optional[int] = None,
         max_seq_len: Optional[int] = None):
     if backend is None:
@@ -135,7 +133,6 @@
     executor_config.tokens_per_block = executor_config.tokens_per_block or build_config.plugin_config.tokens_per_block

     executor_config.hf_model_dir = hf_model_dir
-    executor_config.trt_engine_dir = trt_engine_dir

     if max_input_len is not None:
         executor_config.max_input_len = max_input_len
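Callers of update_executor_config now forward only the HF checkpoint location; there is no trt_engine_dir keyword left. A sketch of the updated call shape, with backend, mapping, and length limits as illustrative placeholders:

from tensorrt_llm._torch.pyexecutor.config import update_executor_config
from tensorrt_llm.bindings.executor import ExecutorConfig
from tensorrt_llm.mapping import Mapping

executor_config = ExecutorConfig(max_beam_width=1)

# trt_engine_dir is gone; only hf_model_dir is threaded through.
update_executor_config(
    executor_config,
    backend="pytorch",
    mapping=Mapping(world_size=1, rank=0),  # assumed single-rank mapping
    hf_model_dir="/path/to/hf/checkpoint",
    max_input_len=2048,
    max_seq_len=4096)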
2 changes: 0 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -283,8 +283,6 @@ def __init__(
         self.is_cuda_graph_dummy = False
         self.py_lora_task_layer_module_configs = None

-        self.py_tokens = super().get_tokens()
-
         self.py_return_log_probs = return_log_probs
         self.py_return_context_logits = return_context_logits
         self.py_return_generation_logits = return_generation_logits
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -180,7 +180,6 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
 def create_py_executor(
         executor_config: ExecutorConfig,
         checkpoint_dir: str = None,
-        engine_dir: str = None,
         lora_config: Optional[LoraConfig] = None,
         garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
     _mangle_executor_config(executor_config)
2 changes: 0 additions & 2 deletions tensorrt_llm/executor/worker.py
@@ -119,7 +119,6 @@ def _create_engine():
             args = {
                 "executor_config": executor_config,
                 "checkpoint_dir": executor_config.hf_model_dir,
-                "engine_dir": executor_config.trt_engine_dir,
             }
             if executor_config.backend == "pytorch":
                 from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
@@ -135,7 +134,6 @@
             else:
                 raise ValueError(
                     f"Unsupported backend config: {executor_config.backend}")
-
             return create_executor(**args)

         self.engine = _create_engine()
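The dispatch in worker.py now builds the same two-entry args dict for every backend. The helper below is a hypothetical standalone rendering of that pattern; the _autodeploy branch is inferred from the matching signature change in ad_executor.py rather than shown in this hunk:

from tensorrt_llm.bindings.executor import ExecutorConfig


def _build_executor(executor_config: ExecutorConfig):
    # Both factories now share the same two keyword arguments.
    args = {
        "executor_config": executor_config,
        "checkpoint_dir": executor_config.hf_model_dir,
    }
    if executor_config.backend == "pytorch":
        from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
            create_py_executor
        create_executor = create_py_executor
    elif executor_config.backend == "_autodeploy":
        # Inferred branch: mirrors the new create_autodeploy_executor signature.
        from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
            create_autodeploy_executor
        create_executor = create_autodeploy_executor
    else:
        raise ValueError(
            f"Unsupported backend config: {executor_config.backend}")
    return create_executor(**args)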
1 change: 0 additions & 1 deletion tensorrt_llm/llmapi/llm.py
@@ -695,7 +695,6 @@ def _build_model(self):
             if self._on_trt_backend else None,
             speculative_config=self.args.speculative_config,
             hf_model_dir=self._hf_model_dir,
-            trt_engine_dir=self._engine_dir,
             max_input_len=self.args.max_input_len,
             max_seq_len=max_seq_len)
         self._executor_config.llm_parallel_config = self.args.parallel_config
112 changes: 105 additions & 7 deletions tensorrt_llm/llmapi/llm_utils.py
@@ -19,9 +19,10 @@
 # yapf: disable
 from ..bindings.executor import (BatchingType, CapacitySchedulerPolicy,
                                  ContextChunkingPolicy, ExecutorConfig,
-                                 KvCacheRetentionConfig, SchedulerConfig)
+                                 GuidedDecodingConfig, KvCacheRetentionConfig,
+                                 PeftCacheConfig, SchedulerConfig)
 # yapf: enable
-from ..builder import BuildConfig, Engine, build
+from ..builder import BuildConfig, Engine, EngineConfig, build
 from ..llmapi.llm_args import TrtLlmArgs
 from ..logger import logger
 from ..mapping import Mapping
@@ -30,15 +31,16 @@
 from ..module import Module
 from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
                           get_build_cache_config_from_env)
-from .llm_args import (CalibConfig, DraftTargetDecodingConfig,
+from .llm_args import (BaseLlmArgs, CalibConfig, DraftTargetDecodingConfig,
                        EagleDecodingConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig,
-                       MTPDecodingConfig, NGramDecodingConfig, _ModelFormatKind,
-                       _ModelWrapper, _ParallelConfig, get_model_format,
-                       update_llm_args_with_extra_dict,
+                       MTPDecodingConfig, NGramDecodingConfig, PybindMirror,
+                       _ModelFormatKind, _ModelWrapper, _ParallelConfig,
+                       get_model_format, update_llm_args_with_extra_dict,
                        update_llm_args_with_extra_options)
 from .mpi_session import MPINodeState, MpiSession
-from .tokenizer import TransformersTokenizer, load_hf_tokenizer
+from .tokenizer import (TransformersTokenizer, _xgrammar_tokenizer_info,
+                        load_hf_tokenizer)
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
 from .utils import (download_hf_model, download_hf_pretrained_config,
                     enable_llm_debug, get_directory_size_in_gb, print_colored,
@@ -855,6 +857,102 @@ class LlmBuildStats:
     build_steps_info: List[Tuple[str, float]] = field(default_factory=list)


+def llm_args_to_executor_config(args: BaseLlmArgs, tokenizer) -> ExecutorConfig:
+    max_batch_size = args.max_batch_size
+    max_num_tokens = args.max_num_tokens
+    max_seq_len = args.max_seq_len
+
+    build_config = args.build_config if isinstance(
+        args, TrtLlmArgs) else BuildConfig()
+
+    max_batch_size = max_batch_size or build_config.max_batch_size
+    max_num_tokens = max_num_tokens or build_config.max_num_tokens
+    max_seq_len = max_seq_len or build_config.max_seq_len
+
+    executor_config = ExecutorConfig(
+        max_beam_width=args.max_beam_width,
+        scheduler_config=PybindMirror.maybe_to_pybind(args.scheduler_config),
+        batching_type=PybindMirror.maybe_to_pybind(args.batching_type)
+        or BatchingType.INFLIGHT,
+        max_batch_size=max_batch_size,
+        max_num_tokens=max_num_tokens,
+        gather_generation_logits=args.gather_generation_logits)
+    if args.backend is None:
+        # also set executor_config.max_seq_len in TRT workflow, to deduce default max_tokens
+        if max_seq_len is not None:
+            executor_config.max_seq_len = max_seq_len
+        else:
+            engine_config = EngineConfig.from_json_file(args.model /
+                                                        "config.json")
+            executor_config.max_seq_len = engine_config.build_config.max_seq_len
+
+    if args.kv_cache_config is not None:
+        executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
+            args.kv_cache_config)
+    if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
+        # Disable KV cache reuse for deterministic mode
+        executor_config.kv_cache_config.enable_block_reuse = False
+        executor_config.kv_cache_config.enable_partial_reuse = False
+
+    if args.peft_cache_config is not None:
+        executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
+            args.peft_cache_config)
+    elif isinstance(args,
+                    TrtLlmArgs) and args.build_config.plugin_config.lora_plugin:
+        engine_config = EngineConfig.from_json_file(args.model / "config.json")
+        lora_config = engine_config.build_config.lora_config
+        max_lora_rank = lora_config.max_lora_rank
+        num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \
+            len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
+        executor_config.peft_cache_config = PeftCacheConfig(
+            num_device_module_layer=max_lora_rank * num_lora_modules *
+            args.max_loras,
+            num_host_module_layer=max_lora_rank * num_lora_modules *
+            args.max_cpu_loras,
+        )
+    if args.decoding_config is not None:
+        executor_config.decoding_config = args.decoding_config
+
+    if args.guided_decoding_backend == 'xgrammar':
+        executor_config.guided_decoding_config = GuidedDecodingConfig(
+            backend=GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
+            **_xgrammar_tokenizer_info(tokenizer))
+    elif args.guided_decoding_backend is not None:
+        raise ValueError(
+            f"Unrecognized guided decoding backend {args.guided_decoding_backend}"
+        )
+
+    executor_config.normalize_log_probs = args.normalize_log_probs
+    executor_config.enable_chunked_context = args.enable_chunked_prefill
+    executor_config.max_beam_width = args.max_beam_width or args.build_config.max_beam_width
+    if isinstance(
+            args,
+            TrtLlmArgs) and args.extended_runtime_perf_knob_config is not None:
+        executor_config.extended_runtime_perf_knob_config = PybindMirror.maybe_to_pybind(
+            args.extended_runtime_perf_knob_config)
+
+    if args.cache_transceiver_config is not None:
+        executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
+            args.cache_transceiver_config)
+
+    from tensorrt_llm._torch.pyexecutor.config import update_executor_config
+    update_executor_config(
+        executor_config,
+        backend=args.backend,
+        pytorch_backend_config=args.get_pytorch_backend_config()
+        if args.backend in ["pytorch", "_autodeploy"] else None,
+        mapping=args.parallel_config.to_mapping(),
+        build_config=args.build_config
+        if isinstance(args, TrtLlmArgs) else None,
+        speculative_config=args.speculative_config,
+        hf_model_dir=self._hf_model_dir,
+        max_input_len=args.max_input_len,
+        max_seq_len=max_seq_len)
+
+    executor_config.llm_parallel_config = args.parallel_config
+    return executor_config
+
+
 __all__ = [
     'LlmArgs',
     'LlmBuildStats',
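The new llm_args_to_executor_config helper centralizes the ExecutorConfig assembly that llm.py previously performed inline. A hedged usage sketch follows; constructing TrtLlmArgs directly and loading the tokenizer this way are illustrative assumptions, and since the hunk still reads self._hf_model_dir when forwarding hf_model_dir, the sketch presumes that reference resolves to the checkpoint directory once the helper is wired into the LLM class:

from tensorrt_llm.llmapi.llm_args import TrtLlmArgs
from tensorrt_llm.llmapi.llm_utils import llm_args_to_executor_config
from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer

# Illustrative inputs; in practice both come from the LLM constructor path.
args = TrtLlmArgs(model="/path/to/hf/checkpoint", max_batch_size=8)
tokenizer = load_hf_tokenizer("/path/to/hf/checkpoint")

executor_config = llm_args_to_executor_config(args, tokenizer)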