diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py
index 2e84d9abc44..88562d27495 100644
--- a/tensorrt_llm/executor/executor.py
+++ b/tensorrt_llm/executor/executor.py
@@ -351,7 +351,8 @@ def create(
             postproc_worker_config: Optional[PostprocWorkerConfig] = None,
             is_llm_executor: Optional[bool] = None,
             lora_config: Optional[LoraConfig] = None,
-            garbage_collection_gen0_threshold: Optional[int] = None,
+            hf_model_dir: Optional[Path] = None,
+            llm_args=None,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
         from .proxy import GenerationExecutorProxy
@@ -378,6 +379,8 @@ def create(
             "engine": engine,
             "executor_config": executor_config,
             "batched_logits_processor": batched_logits_processor,
+            "hf_model_dir": hf_model_dir,
+            "llm_args": llm_args,
         }

         if lora_config:
@@ -395,9 +398,7 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor,
-                garbage_collection_gen0_threshold=
-                garbage_collection_gen0_threshold)
+                is_llm_executor=is_llm_executor)

         # WAR: For the performance of gathering logits, we use single process worker
         # for TP1 to avoid the large overhead of IPC.
@@ -408,9 +409,7 @@ def create(
                     "Using single process worker for TP1, this may hurt streaming generation performance."
                 )
                 return GenerationExecutorWorker(**worker_kwargs,
-                                                is_llm_executor=is_llm_executor,
-                                                garbage_collection_gen0_threshold=
-                                                garbage_collection_gen0_threshold)
+                                                is_llm_executor=is_llm_executor)

             # For single-gpu case:
             # Partition the workload to multiple process for streaming performance.
@@ -422,9 +421,7 @@ def create(
                     model_world_size=model_world_size,
                     mpi_session=None,  # use mpi4py
                     postproc_worker_config=postproc_worker_config,
-                    is_llm_executor=is_llm_executor,
-                    garbage_collection_gen0_threshold=
-                    garbage_collection_gen0_threshold)
+                    is_llm_executor=is_llm_executor)
             else:
                 ctx = multiprocessing.get_context("spawn")
                 # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -435,9 +432,7 @@ def create(
                     model_world_size=model_world_size,
                     mpi_session=mpi_session,
                     postproc_worker_config=postproc_worker_config,
-                    is_llm_executor=is_llm_executor,
-                    garbage_collection_gen0_threshold=
-                    garbage_collection_gen0_threshold)
+                    is_llm_executor=is_llm_executor)

     def wait_first_completed(
             self, futures: List[GenerationResult]
diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py
index 05ff864e1dd..a09d092df2c 100644
--- a/tensorrt_llm/executor/proxy.py
+++ b/tensorrt_llm/executor/proxy.py
@@ -45,7 +45,6 @@ def __init__(
         worker_cls: type = GenerationExecutorWorker,
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
     ) -> None:
         postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
         )
@@ -87,14 +86,14 @@ def __init__(

         self.model_world_size = model_world_size

-        self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
+        self.garbage_collection_gen0_threshold = worker_kwargs[
+            "llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
+                "llm_args", None) is not None else None

         worker_kwargs = dict(**worker_kwargs,
                              worker_queues=self._setup_queues(),
                              postproc_worker_config=postproc_worker_config,
-                             is_llm_executor=False,
-                             garbage_collection_gen0_threshold=self.
-                             garbage_collection_gen0_threshold)
+                             is_llm_executor=False)

         if "log_level" not in worker_kwargs:
             worker_kwargs["log_level"] = logger.level
diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py
index 6ebd7adc03d..0881238b286 100644
--- a/tensorrt_llm/executor/worker.py
+++ b/tensorrt_llm/executor/worker.py
@@ -18,8 +18,10 @@
                     mpi_comm, mpi_rank, nvtx_range_debug)
 from ..bindings import executor as tllm
 from ..builder import ConfigEncoder, Engine, EngineConfig
-from ..llmapi.llm_args import PybindMirror
+from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
 from ..llmapi.mpi_session import set_mpi_session_cpp
+from ..llmapi.tokenizer import (_llguidance_tokenizer_info,
+                                _xgrammar_tokenizer_info)
 from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
                             clear_sched_affinity, print_colored_debug,
@@ -58,7 +60,8 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -79,8 +82,7 @@ def __init__(
         self._await_response_helper = AwaitResponseHelper(
             self)  # TODO: make it weakref
         self._executor_config = executor_config
-        self._is_pytorch_backend = getattr(self._executor_config, "backend",
-                                           None) == "pytorch"
+        self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"

         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)
@@ -88,13 +90,98 @@ def __init__(
         if isinstance(engine, list):
             engine = engine[self.rank]

-        if executor_config is None:
-            executor_config = tllm.ExecutorConfig(1)
+        def _create_py_executor():
+            device_id = self.global_rank % torch.cuda.device_count()
+            torch.cuda.set_device(device_id)
+
+            max_batch_size = llm_args.max_batch_size
+            max_num_tokens = llm_args.max_num_tokens
+            max_seq_len = llm_args.max_seq_len
+
+            self._executor_config = tllm.ExecutorConfig(
+                max_beam_width=llm_args.max_beam_width,
+                scheduler_config=PybindMirror.maybe_to_pybind(
+                    llm_args.scheduler_config),
+                batching_type=PybindMirror.maybe_to_pybind(
+                    llm_args.batching_type) or tllm.BatchingType.INFLIGHT,
+                max_batch_size=max_batch_size,
+                max_num_tokens=max_num_tokens,
+                gather_generation_logits=llm_args.gather_generation_logits)
+
+            if llm_args.kv_cache_config is not None:
+                self._executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
+                    llm_args.kv_cache_config)
+            if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
+                # Disable KV cache reuse for deterministic mode
+                self._executor_config.kv_cache_config.enable_block_reuse = False
+                self._executor_config.kv_cache_config.enable_partial_reuse = False
+            if llm_args.peft_cache_config is not None:
+                self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
+                    llm_args.peft_cache_config)
+            if llm_args.decoding_config is not None:
+                self._executor_config.decoding_config = llm_args.decoding_config
+            if llm_args.guided_decoding_backend == 'xgrammar':
+                self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
+                    backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
+                    XGRAMMAR,
+                    **_xgrammar_tokenizer_info(self.tokenizer))
+            elif llm_args.guided_decoding_backend == 'llguidance':
+                self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
+                    backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
+                    LLGUIDANCE,
+                    **_llguidance_tokenizer_info(self.tokenizer))
+            elif llm_args.guided_decoding_backend is not None:
+                raise ValueError(
+                    f"Unsupported guided decoding backend {llm_args.guided_decoding_backend}"
+                )

-        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
-            processor_batched=batched_logits_processor, replicate=False)
+            self._executor_config.normalize_log_probs = llm_args.normalize_log_probs
+            self._executor_config.enable_chunked_context = llm_args.enable_chunked_prefill
+            self._executor_config.max_beam_width = llm_args.max_beam_width
+            if llm_args.cache_transceiver_config is not None:
+                self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
+                    llm_args.cache_transceiver_config)
+            from tensorrt_llm._torch.pyexecutor.config import \
+                update_executor_config
+            update_executor_config(
+                self._executor_config,
+                backend=llm_args.backend,
+                pytorch_backend_config=llm_args.get_pytorch_backend_config()
+                if llm_args.backend in ["pytorch", "_autodeploy"] else None,
+                mapping=llm_args.parallel_config.to_mapping(),
+                speculative_config=llm_args.speculative_config,
+                hf_model_dir=hf_model_dir,
+                max_input_len=llm_args.max_input_len,
+                max_seq_len=max_seq_len)
+
+            self._executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
+            args = {
+                "executor_config": self._executor_config,
+                "checkpoint_dir": hf_model_dir,
+            }
+            if llm_args.backend == "pytorch":
+                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
+                    create_py_executor
+                create_executor = create_py_executor
+                args["lora_config"] = lora_config
+                args[
+                    "garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
+            elif llm_args.backend == "_autodeploy":
+                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
+                    create_autodeploy_executor
+                create_executor = create_autodeploy_executor
+            else:
+                raise ValueError(
+                    f"Unsupported backend config: {llm_args.backend}")
+            return create_executor(**args)

         def _create_engine():
+            if executor_config is None:
+                executor_config = tllm.ExecutorConfig(1)
+
+            executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
+                processor_batched=batched_logits_processor, replicate=False)
             device_id = self.global_rank % torch.cuda.device_count()
             torch.cuda.set_device(device_id)

@@ -113,30 +200,11 @@ def _create_engine():
                                  executor_config=executor_config,
                                  managed_weights=engine.managed_weights)

-            if not hasattr(executor_config, "backend"):
-                return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
-                                     executor_config)
-            args = {
-                "executor_config": executor_config,
-                "checkpoint_dir": executor_config.hf_model_dir,
-            }
-            if executor_config.backend == "pytorch":
-                from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
-                    create_py_executor
-                create_executor = create_py_executor
-                args["lora_config"] = lora_config
-                args[
-                    "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
-            elif executor_config.backend == "_autodeploy":
-                from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
-                    create_autodeploy_executor
-                create_executor = create_autodeploy_executor
-            else:
-                raise ValueError(
-                    f"Unsupported backend config: {executor_config.backend}")
-            return create_executor(**args)
+            return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
+                                 executor_config)

-        self.engine = _create_engine()
+        self.engine = _create_py_executor(
+        ) if llm_args is not None else _create_engine()

         self._lora_manager: Optional[LoraManager] = None
         self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -631,7 +699,8 @@ def worker_main(
         is_llm_executor: Optional[
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
-        garbage_collection_gen0_threshold: Optional[int] = None,
+        hf_model_dir: Optional[Path] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -758,7 +827,8 @@ def notify_proxy_threads_to_quit():
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
             lora_config=lora_config,
-            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
+            hf_model_dir=hf_model_dir,
+            llm_args=llm_args)
     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
         logger.error(traceback.format_exc())
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 934813aa4c4..10e29a3a3c1 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -975,8 +975,8 @@ def _build_model(self):
             return_logits = self.args.gather_generation_logits

         self._executor = self._executor_cls.create(
-            self._engine_dir,
-            executor_config=self._executor_config,
+            engine=None,
+            executor_config=None,
             batched_logits_processor=self.args.batched_logits_processor,
             model_world_size=self.args.parallel_config.world_size,
             mpi_session=self.mpi_session,
@@ -989,8 +989,8 @@ def _build_model(self):
             ),
             is_llm_executor=True,
             lora_config=self.args.lora_config,
-            garbage_collection_gen0_threshold=self.args.
-            garbage_collection_gen0_threshold)
+            hf_model_dir=self._hf_model_dir,
+            llm_args=self.args)

     def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
         """Validate that users don't pass TrtLlmArgs-specific arguments when using PyTorch backend.
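
Note (reviewer sketch, not part of the patch): with this change the executor is configured from llm_args plus hf_model_dir instead of a prebuilt ExecutorConfig; the worker builds the config itself in _create_py_executor() and only falls back to _create_engine() (the TRT engine path) when llm_args is None. Below is a minimal illustration of the new call path, mirroring LLM._build_model() above; the TorchLlmArgs construction helper and the checkpoint path are assumptions for illustration, not code taken from this diff.

    from pathlib import Path

    from tensorrt_llm.executor import GenerationExecutor
    from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

    # Hypothetical checkpoint location; any local HF model directory works.
    hf_model_dir = Path("/path/to/hf/checkpoint")
    # Assumed construction; any fully initialized TorchLlmArgs instance is fine.
    llm_args = TorchLlmArgs.from_kwargs(model=str(hf_model_dir))

    executor = GenerationExecutor.create(
        engine=None,            # no prebuilt TRT engine in the PyTorch path
        executor_config=None,   # derived inside the worker from llm_args
        batched_logits_processor=None,
        model_world_size=llm_args.parallel_config.world_size,
        hf_model_dir=hf_model_dir,
        llm_args=llm_args,      # a non-None llm_args routes the worker to _create_py_executor()
    )

Callers that still hand over a built TRT engine and leave llm_args unset keep the previous behavior: _create_engine() wraps the engine in tllm.Executor exactly as before.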