21 changes: 8 additions & 13 deletions tensorrt_llm/executor/executor.py
@@ -351,7 +351,8 @@ def create(
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
hf_model_dir: Optional[Path] = None,
llm_args=None,
) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
# local imports to avoid cyclic importing
from .proxy import GenerationExecutorProxy
@@ -378,6 +379,8 @@ def create(
"engine": engine,
"executor_config": executor_config,
"batched_logits_processor": batched_logits_processor,
"hf_model_dir": hf_model_dir,
"llm_args": llm_args,
}

if lora_config:
@@ -395,9 +398,7 @@ def create(
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
is_llm_executor=is_llm_executor)

# WAR: For the performance of gathering logits, we use single process worker
# for TP1 to avoid the large overhead of IPC.
@@ -408,9 +409,7 @@ def create(
"Using single process worker for TP1, this may hurt streaming generation performance."
)
return GenerationExecutorWorker(**worker_kwargs,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
is_llm_executor=is_llm_executor)

# For single-gpu case:
# Partition the workload to multiple process for streaming performance.
@@ -422,9 +421,7 @@ def create(
model_world_size=model_world_size,
mpi_session=None, # use mpi4py
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
is_llm_executor=is_llm_executor)
else:
ctx = multiprocessing.get_context("spawn")
# The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -435,9 +432,7 @@ def create(
model_world_size=model_world_size,
mpi_session=mpi_session,
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
garbage_collection_gen0_threshold=
garbage_collection_gen0_threshold)
is_llm_executor=is_llm_executor)

def wait_first_completed(
self, futures: List[GenerationResult]
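
For context, the dispatch that create() performs can be summarized by the standalone sketch below. The _Proxy/_Worker classes are hypothetical stand-ins for GenerationExecutorProxy and GenerationExecutorWorker, and only the arguments relevant to this change are shown.

from typing import Any, Dict, Optional


class _Worker:
    """Hypothetical stand-in for GenerationExecutorWorker."""

    def __init__(self, is_llm_executor: bool, **worker_kwargs: Any) -> None:
        self.is_llm_executor = is_llm_executor
        self.worker_kwargs = worker_kwargs


class _Proxy(_Worker):
    """Hypothetical stand-in for GenerationExecutorProxy."""


def create(engine: Optional[object] = None,
           hf_model_dir: Optional[str] = None,
           llm_args: Optional[object] = None,
           model_world_size: int = 1,
           is_llm_executor: bool = True) -> _Worker:
    # hf_model_dir and llm_args now ride inside worker_kwargs; the former
    # garbage_collection_gen0_threshold keyword is gone from this signature.
    worker_kwargs: Dict[str, Any] = {
        "engine": engine,
        "hf_model_dir": hf_model_dir,
        "llm_args": llm_args,
    }
    if model_world_size > 1:
        return _Proxy(is_llm_executor=is_llm_executor, **worker_kwargs)
    # Single-process worker for TP1 avoids IPC overhead when gathering logits.
    return _Worker(is_llm_executor=is_llm_executor, **worker_kwargs)
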
9 changes: 4 additions & 5 deletions tensorrt_llm/executor/proxy.py
@@ -45,7 +45,6 @@ def __init__(
worker_cls: type = GenerationExecutorWorker,
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
) -> None:
postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
)
@@ -87,14 +86,14 @@ def __init__(

self.model_world_size = model_world_size

self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
self.garbage_collection_gen0_threshold = worker_kwargs[
"llm_args"].garbage_collection_gen0_threshold if worker_kwargs.get(
"llm_args", None) is not None else None

worker_kwargs = dict(**worker_kwargs,
worker_queues=self._setup_queues(),
postproc_worker_config=postproc_worker_config,
is_llm_executor=False,
garbage_collection_gen0_threshold=self.
garbage_collection_gen0_threshold)
is_llm_executor=False)

if "log_level" not in worker_kwargs:
worker_kwargs["log_level"] = logger.level
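
The inline conditional the proxy now uses to pick up the GC threshold is equivalent to this small helper (a sketch only; it assumes llm_args, when present, exposes a garbage_collection_gen0_threshold attribute):

from typing import Any, Dict, Optional


def _gc_gen0_threshold(worker_kwargs: Dict[str, Any]) -> Optional[int]:
    # Fall back to None when llm_args was not supplied at all.
    llm_args = worker_kwargs.get("llm_args")
    if llm_args is None:
        return None
    return llm_args.garbage_collection_gen0_threshold

In other words, the threshold is no longer forwarded as its own keyword; it is read off llm_args whenever the proxy was constructed with one.
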
136 changes: 103 additions & 33 deletions tensorrt_llm/executor/worker.py
@@ -18,8 +18,10 @@
mpi_comm, mpi_rank, nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import PybindMirror
from ..llmapi.llm_args import PybindMirror, TorchLlmArgs
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import (_llguidance_tokenizer_info,
_xgrammar_tokenizer_info)
from ..llmapi.tracer import VizTracer, global_tracer, set_global_tracer
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
clear_sched_affinity, print_colored_debug,
@@ -58,7 +60,8 @@ def __init__(
postproc_worker_config: Optional[PostprocWorkerConfig] = None,
is_llm_executor: Optional[bool] = None,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
hf_model_dir: Optional[Path] = None,
llm_args: Optional[TorchLlmArgs] = None,
) -> None:
postproc_config = postproc_worker_config or PostprocWorkerConfig()
super().__init__(
@@ -79,22 +82,106 @@ def __init__(
self._await_response_helper = AwaitResponseHelper(
self) # TODO: make it weakref
self._executor_config = executor_config
self._is_pytorch_backend = getattr(self._executor_config, "backend",
None) == "pytorch"
self._is_pytorch_backend = llm_args is not None and llm_args.backend == "pytorch"

if global_mpi_size() > 1:
logger.set_rank(self.global_rank)

if isinstance(engine, list):
engine = engine[self.rank]

if executor_config is None:
executor_config = tllm.ExecutorConfig(1)
def _create_py_executor():
device_id = self.global_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)

max_batch_size = llm_args.max_batch_size
max_num_tokens = llm_args.max_num_tokens
max_seq_len = llm_args.max_seq_len

self._executor_config = tllm.ExecutorConfig(
max_beam_width=llm_args.max_beam_width,
scheduler_config=PybindMirror.maybe_to_pybind(
llm_args.scheduler_config),
batching_type=PybindMirror.maybe_to_pybind(
llm_args.batching_type) or tllm.BatchingType.INFLIGHT,
max_batch_size=max_batch_size,
max_num_tokens=max_num_tokens,
gather_generation_logits=llm_args.gather_generation_logits)

if llm_args.kv_cache_config is not None:
self._executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
llm_args.kv_cache_config)
if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
# Disable KV cache reuse for deterministic mode
self._executor_config.kv_cache_config.enable_block_reuse = False
self._executor_config.kv_cache_config.enable_partial_reuse = False
if llm_args.peft_cache_config is not None:
self._executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
llm_args.peft_cache_config)
if llm_args.decoding_config is not None:
self._executor_config.decoding_config = llm_args.decoding_config
if llm_args.guided_decoding_backend == 'xgrammar':
self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
XGRAMMAR,
**_xgrammar_tokenizer_info(self.tokenizer))
elif llm_args.guided_decoding_backend == 'llguidance':
self._executor_config.guided_decoding_config = tllm.GuidedDecodingConfig(
backend=tllm.GuidedDecodingConfig.GuidedDecodingBackend.
LLGUIDANCE,
**_llguidance_tokenizer_info(self.tokenizer))
elif llm_args.guided_decoding_backend is not None:
raise ValueError(
f"Unsupported guided decoding backend {llm_args.guided_decoding_backend}"
)

executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor, replicate=False)
self._executor_config.normalize_log_probs = llm_args.normalize_log_probs
self._executor_config.enable_chunked_context = llm_args.enable_chunked_prefill
self._executor_config.max_beam_width = llm_args.max_beam_width
if llm_args.cache_transceiver_config is not None:
self._executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
llm_args.cache_transceiver_config)
from tensorrt_llm._torch.pyexecutor.config import \
update_executor_config
update_executor_config(
self._executor_config,
backend=llm_args.backend,
pytorch_backend_config=llm_args.get_pytorch_backend_config()
if llm_args.backend in ["pytorch", "_autodeploy"] else None,
mapping=llm_args.parallel_config.to_mapping(),
speculative_config=llm_args.speculative_config,
hf_model_dir=hf_model_dir,
max_input_len=llm_args.max_input_len,
max_seq_len=max_seq_len)

self._executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor, replicate=False)
args = {
"executor_config": self._executor_config,
"checkpoint_dir": hf_model_dir,
}
if llm_args.backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
create_executor = create_py_executor
args["lora_config"] = lora_config
args[
"garbage_collection_gen0_threshold"] = llm_args.garbage_collection_gen0_threshold
elif llm_args.backend == "_autodeploy":
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
create_autodeploy_executor
create_executor = create_autodeploy_executor
else:
raise ValueError(
f"Unsupported backend config: {llm_args.backend}")
return create_executor(**args)

def _create_engine():
if executor_config is None:
executor_config = tllm.ExecutorConfig(1)

executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
processor_batched=batched_logits_processor, replicate=False)
device_id = self.global_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)

@@ -113,30 +200,11 @@ def _create_engine():
executor_config=executor_config,
managed_weights=engine.managed_weights)

if not hasattr(executor_config, "backend"):
return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
executor_config)
args = {
"executor_config": executor_config,
"checkpoint_dir": executor_config.hf_model_dir,
}
if executor_config.backend == "pytorch":
from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
create_py_executor
create_executor = create_py_executor
args["lora_config"] = lora_config
args[
"garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
elif executor_config.backend == "_autodeploy":
from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
create_autodeploy_executor
create_executor = create_autodeploy_executor
else:
raise ValueError(
f"Unsupported backend config: {executor_config.backend}")
return create_executor(**args)
return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
executor_config)

self.engine = _create_engine()
self.engine = _create_py_executor(
) if llm_args is not None else _create_engine()

self._lora_manager: Optional[LoraManager] = None
self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
@@ -631,7 +699,8 @@ def worker_main(
is_llm_executor: Optional[
bool] = True, # whether it's the main executor instance
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
hf_model_dir: Optional[Path] = None,
llm_args: Optional[TorchLlmArgs] = None,
) -> None:
mpi_comm().barrier()
print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -758,7 +827,8 @@ def notify_proxy_threads_to_quit():
postproc_worker_config=postproc_worker_config,
is_llm_executor=is_llm_executor,
lora_config=lora_config,
garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
hf_model_dir=hf_model_dir,
llm_args=llm_args)
except Exception as e:
logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
logger.error(traceback.format_exc())
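
The backend dispatch performed inside _create_py_executor reduces to the pattern below (a simplified sketch; the two factory lambdas are placeholders for create_py_executor and create_autodeploy_executor, and llm_args is assumed to carry backend and garbage_collection_gen0_threshold):

from typing import Any, Callable, Dict, Optional


def _build_factory_args(backend: str,
                        executor_config: Any,
                        hf_model_dir: Optional[str],
                        lora_config: Any,
                        llm_args: Any) -> Dict[str, Any]:
    args: Dict[str, Any] = {
        "executor_config": executor_config,
        "checkpoint_dir": hf_model_dir,
    }
    if backend == "pytorch":
        # Only the PyTorch path consumes LoRA and the gen0 GC threshold.
        args["lora_config"] = lora_config
        args["garbage_collection_gen0_threshold"] = (
            llm_args.garbage_collection_gen0_threshold)
    return args


def _select_factory(backend: str) -> Callable[..., Any]:
    factories: Dict[str, Callable[..., Any]] = {
        "pytorch": lambda **kw: ("py_executor", kw),     # placeholder factory
        "_autodeploy": lambda **kw: ("autodeploy", kw),  # placeholder factory
    }
    if backend not in factories:
        raise ValueError(f"Unsupported backend config: {backend}")
    return factories[backend]

When llm_args is None, the worker instead falls back to _create_engine() and the C++ tllm.Executor path, as shown in the diff above.
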
8 changes: 4 additions & 4 deletions tensorrt_llm/llmapi/llm.py
@@ -975,8 +975,8 @@ def _build_model(self):
return_logits = self.args.gather_generation_logits

self._executor = self._executor_cls.create(
self._engine_dir,
executor_config=self._executor_config,
engine=None,
executor_config=None,
batched_logits_processor=self.args.batched_logits_processor,
model_world_size=self.args.parallel_config.world_size,
mpi_session=self.mpi_session,
@@ -989,8 +989,8 @@ def _build_model(self):
),
is_llm_executor=True,
lora_config=self.args.lora_config,
garbage_collection_gen0_threshold=self.args.
garbage_collection_gen0_threshold)
hf_model_dir=self._hf_model_dir,
llm_args=self.args)

def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
"""Validate that users don't pass TrtLlmArgs-specific arguments when using PyTorch backend.
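
For illustration, the call that _build_model now issues has roughly the following shape (a sketch with placeholder parameters; executor_cls stands in for self._executor_cls):

from typing import Any, Optional


def _build_executor(executor_cls: Any,
                    hf_model_dir: Optional[str],
                    llm_args: Any) -> Any:
    # No prebuilt engine and no explicit executor_config on this path;
    # the worker derives its ExecutorConfig from llm_args instead.
    return executor_cls.create(
        engine=None,
        executor_config=None,
        hf_model_dir=hf_model_dir,
        llm_args=llm_args,
        is_llm_executor=True,
    )
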