From 99fc34728ccb708a91c70f4af627f4eda74b8e18 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Fri, 5 Sep 2025 17:30:46 -0700 Subject: [PATCH 01/11] move weights loading related logic to ModelLoader Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- .../_torch/pyexecutor/model_engine.py | 305 +---------------- .../_torch/pyexecutor/model_loader.py | 311 ++++++++++++++++++ tensorrt_llm/bench/benchmark/utils/general.py | 2 +- tensorrt_llm/bench/dataclasses/reporting.py | 2 +- 4 files changed, 327 insertions(+), 293 deletions(-) create mode 100644 tensorrt_llm/_torch/pyexecutor/model_loader.py diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 1c313db240e..e35495a0535 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1,12 +1,9 @@ import bisect import contextlib -import copy import functools import gc import inspect import math -import os -import traceback import weakref from abc import ABC, abstractmethod from contextlib import contextmanager @@ -17,16 +14,13 @@ import tensorrt_llm.bindings.internal.userbuffers as ub from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc, - str_dtype_to_torch, torch_dtype_to_str, - trace_func) + torch_dtype_to_str, trace_func) from tensorrt_llm.inputs.multimodal import (MultimodalParams, MultimodalRuntimeData) from tensorrt_llm.logger import logger from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraModelConfig from tensorrt_llm.mapping import CpType, Mapping -from tensorrt_llm.models.modeling_utils import QuantAlgo -from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2 from ..attention_backend.interface import (AttentionMetadata, AttentionRuntimeFeatures) @@ -40,13 +34,10 @@ from ..distributed.communicator import init_pp_comm from ..expert_statistic import ExpertStatistic from ..metadata import KVCacheParams -from ..model_config import ModelConfig, MoeLoadBalancerConfig -from ..models import AutoModelForCausalLM from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader -from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, - timing) -from ..modules.fused_moe.moe_load_balancer import ( - MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer) +from ..models.modeling_utils import DecoderModelForCausalLM +from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer, + MoeLoadBalancerIterContext) from ..speculative import (SpecMetadata, get_num_extra_kv_tokens, get_spec_metadata, update_spec_config_from_model_config) @@ -55,12 +46,13 @@ from ..utils import (get_model_extra_attrs, set_per_request_piecewise_cuda_graph_flag, set_torch_compiling, with_model_extra_attrs) -from .config import LoadFormat, PyTorchConfig +from .config import PyTorchConfig from .config_utils import is_mla from .cuda_graph_runner import CUDAGraphRunner from .guided_decoder import CapturableGuidedDecoder from .layerwise_nvtx_marker import LayerwiseNvtxMarker from .llm_request import get_draft_token_length +from .model_loader import ModelLoader from .resource_manager import (BaseResourceManager, KVCacheManager, ResourceManager, ResourceManagerType) from .sampler import SampleStateTensors @@ -95,137 +87,6 @@ def warmup(self, resource_manager: ResourceManager) -> None: return -_KV_CACHE_MAP = { - "fp8": QuantAlgo.FP8.value, - "nvfp4": QuantAlgo.NVFP4.value, - "auto": "auto" -} 
-_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto") - - -def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig, - mamba_ssm_cache_dtype: str) -> None: - if mamba_ssm_cache_dtype == "auto": - mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype - else: - mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype) - - config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype - - -def validate_and_set_kv_cache_quant(model_config: ModelConfig, - pyt_kv_cache_dtype: str) -> QuantAlgo: - logger.info( - f'Validating KV Cache config against kv_cache_dtype="{pyt_kv_cache_dtype}"' - ) - # Quantization from hf_quant_config.json - kv_cache_quant = model_config.quant_config.kv_cache_quant_algo - # PyTorch configuration quantization - valid_pyt_quant = bool(pyt_kv_cache_dtype in _VALID_KV_CACHE_DTYPES) - mapped_pyt_quant = _KV_CACHE_MAP.get(pyt_kv_cache_dtype, None) - - # If we're letting the checkpoint dictate the quant with auto, simply - # return and do not modify the checkpoint. - if pyt_kv_cache_dtype == "auto": - logger.info( - f'KV cache quantization set to "{pyt_kv_cache_dtype}". Using ' - "checkpoint KV quantization.") - return - - # If we have an invalid quantization, simply raise an exception. - if not valid_pyt_quant: - raise ValueError( - "Overriding KV cache quantization with an invalid type " - f'"PyTorchConfig.kv_cache_dtype="{pyt_kv_cache_dtype}" ' - f'Accepted types are "{_VALID_KV_CACHE_DTYPES}".') - - # If we get to this point we have a valid quantization setting, but if - # we have an existing setting and it doesn't match we shouldn't proceed. - if kv_cache_quant is not None and mapped_pyt_quant != kv_cache_quant: - raise RuntimeError( - "Attempting to override KV cache quantization " - f'"{kv_cache_quant}" with PyTorchConfig.kv_cache_dtype=' - f'"{pyt_kv_cache_dtype}". You cannot override a checkpoint with a ' - "pre-quantized KV cache that doesn't match.") - - # We have an open ended KV cache in the checkpoint - # and we have a specified override. - model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant - - -def initialize_dummy_weights( - model: torch.nn.Module, - low: float = -1e-3, - high: float = 1e-3, - seed: int = 0, -) -> None: - """ - This is similar to this function in SGLang with a few changes: - https://github.com/sgl-project/sglang/blob/e074e76b31d4fff13e87a455dbc3acdaa92c537a/python/sglang/srt/model_loader/weight_utils.py#L577 - - This method is used to initialize weights with dummy values for testing - models without checkpoints. Unquantized (FP16/BF16/etc) values are generated - from a uniform distribution over the interval (low, high). - - For some quantized types (FP8/NVFP4), torch has no built-in way to generate random values. - We simply generate values uniformly across an interval that has been empirically verified - to not generate NaNs/inf for these. - """ - - def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]: - # These values are not necessarily the largest possible min/max, - # they need to be small enough to avoid NaNs. - if dtype in (torch.float8_e4m3fn, torch.int8): - return (-3.0, 3.0) - - elif dtype == float4_e2m1x2: - # These correspond to bits of 2 packed FP4 values. - # Because we only go up to 64, the high 4 bits will - # always be 0. But this is fine - we just need values - # that won't generate NaNs. 
- return (0, 64) - - else: - raise NotImplementedError(f"Unknown quantized type: {dtype}.") - - for param in model.state_dict().values(): - generator = torch.Generator(device=param.data.device) - generator.manual_seed(seed) - dtype = param.data.dtype - - if param.data.element_size() < 2: - # We need to do a cast/round since torch doesn't have uniform_ - # support for these dtypes. - tmp_param = torch.empty(param.data.shape, - dtype=torch.float16, - device=param.data.device) - - quant_min, quant_max = _get_random_min_max(dtype) - tmp_param = tmp_param.uniform_(quant_min, - quant_max, - generator=generator) - - param.data.copy_(tmp_param.to(dtype)) - - # Note: no need to to mess with int32 params, these are probably - # constants and not weights. - elif torch.is_floating_point(param): - param.uniform_(low, high, generator=generator) - - -def get_rank_model_storage(model): - total_bytes = 0 - for _, param in model.named_parameters(): - if param.device.type == 'cuda' and param.device.index == torch.cuda.current_device( - ): - total_bytes += param.element_size() * param.nelement() - for _, buf in model.named_buffers(): - if buf.device.type == 'cuda' and buf.device.index == torch.cuda.current_device( - ): - total_bytes += buf.element_size() * buf.nelement() - return total_bytes - - def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int], max_batch_size: int, max_num_tokens: int, max_draft_len: int, @@ -302,20 +163,17 @@ def __init__( ) attn_backend = pytorch_backend_config.attn_backend - self.model = self._load_model( - model_path, + loader = ModelLoader( + pytorch_backend_config=pytorch_backend_config, mapping=self.mapping, - checkpoint_loader=checkpoint_loader, - attn_backend=attn_backend, - moe_backend=pytorch_backend_config.moe_backend, - moe_disable_finalize_fusion=pytorch_backend_config. - moe_disable_finalize_fusion, - load_format=pytorch_backend_config.load_format, + spec_config=self.spec_config, max_num_tokens=max_num_tokens, - moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens, - moe_load_balancer=pytorch_backend_config.moe_load_balancer, + max_seq_len=max_seq_len, lora_config=lora_config, - drafting_loop_wrapper=drafting_loop_wrapper) + ) + self.model = loader.load(checkpoint_dir=model_path, + checkpoint_loader=checkpoint_loader, + drafting_loop_wrapper=drafting_loop_wrapper) # In case that some tests use stub models and override `_load_model`. if not hasattr(self.model, 'extra_attrs'): self.model.extra_attrs = {} @@ -944,141 +802,6 @@ def __del__(self) -> None: # Release model weights. release_gc() - def _load_model(self, - checkpoint_dir: str, - checkpoint_loader: BaseCheckpointLoader, - load_format: LoadFormat, - max_num_tokens: int, - moe_max_num_tokens: Optional[int] = None, - moe_load_balancer: Optional[MoeLoadBalancerConfig] = None, - lora_config: Optional[LoraConfig] = None, - drafting_loop_wrapper: Optional[Callable[ - [torch.nn.Module], torch.nn.Module]] = None, - **kwargs) -> DecoderModelForCausalLM: - config = checkpoint_loader.load_config( - checkpoint_dir, - trust_remote_code=True, - enable_min_latency=self.pytorch_backend_config.enable_min_latency, - use_cuda_graph=self.pytorch_backend_config.use_cuda_graph, - force_dynamic_quantization=self.pytorch_backend_config. 
- force_dynamic_quantization, - spec_config=self.spec_config, - max_num_tokens=max_num_tokens, - max_seq_len=self.max_seq_len, - moe_max_num_tokens=moe_max_num_tokens, - moe_load_balancer=moe_load_balancer, - lora_config=lora_config, - allreduce_strategy=self.pytorch_backend_config.allreduce_strategy, - mm_encoder_only=self.pytorch_backend_config.mm_encoder_only, - **kwargs) - - validate_and_set_kv_cache_quant( - config, self.pytorch_backend_config.kv_cache_dtype) - validate_and_set_mamba_ssm_cache_dtype( - config, self.pytorch_backend_config.mamba_ssm_cache_dtype) - - num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0")) - if num_layers > 0: - config.pretrained_config.num_hidden_layers = num_layers - for sub_config in ["text_config", "vision_config"]: - if hasattr(config.pretrained_config, sub_config): - getattr(config.pretrained_config, - sub_config).num_hidden_layers = num_layers - - with timing("Model init total"), maybe_create_moe_load_balancer( - config, self.mapping) as moe_load_balancer: - - try: - # config will be modified in-place for some models, like Qwen2 - config_copy = copy.deepcopy(config) - with MetaInitMode(): - model = AutoModelForCausalLM.from_config(config_copy) - - memo = dict() - - def init_meta_tensor(t: torch.Tensor): - if t.device != torch.device('meta'): - return t - if t not in memo: - memo[t] = torch.empty_like(t, device='cuda') - return memo[t] - - model._apply(init_meta_tensor) - config = config_copy - - except Exception: - logger.info( - f"Fallback to regular model init: {traceback.format_exc(limit=1)}\n" - ) - model = AutoModelForCausalLM.from_config(config) - - model.to("cuda") - rank_model_storage = get_rank_model_storage(model) - logger.info( - f"Use {rank_model_storage / (1024**3):.2f} GB for model weights." - ) - if load_format == LoadFormat.AUTO: - if hasattr(model, 'llm_checkpoint_dir'): - weights = checkpoint_loader.load_weights( - model.llm_checkpoint_dir) - else: - weights = checkpoint_loader.load_weights(checkpoint_dir) - - weight_mapper = checkpoint_loader.get_initialized_weight_mapper( - model, config) - self._call_load_weights(model.load_weights, weights, - weight_mapper) - - if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( - ): - weights = checkpoint_loader.load_weights( - self.spec_config.speculative_model_dir) - self._call_load_weights(model.load_draft_weights, weights, - weight_mapper) - - elif load_format == LoadFormat.DUMMY: - initialize_dummy_weights(model) - if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( - ): - model.draft_model.load_weights_from_target_model(model) - - elif load_format == LoadFormat.VISION_ONLY: - # Vision weights are already loaded within the model. - logger.info( - "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights." 
- ) - - else: - raise NotImplementedError( - f"No load support for load format: {load_format}") - - if isinstance(moe_load_balancer, MoeLoadBalancer): - setattr(self, "moe_load_balancer", moe_load_balancer) - moe_load_balancer.register_weight_slots_after_to_cuda() - logger.info("moe_load_balancer finalizing model...") - moe_load_balancer.finalize_model() - logger.info("moe_load_balancer finalize model done") - - torch.cuda.current_stream().synchronize() - - if drafting_loop_wrapper is not None: - model = drafting_loop_wrapper(model) - self.model_is_wrapped = True - else: - self.model_is_wrapped = False - - return model - - def _call_load_weights(self, load_method, weights, weight_mapper): - # TODO smor- this is a temporary solution to load weights. - # Once checkpoint format is unified, this method will be removed. - from inspect import getfullargspec - args = getfullargspec(load_method).args - if "weight_mapper" in args: - load_method(weights, weight_mapper=weight_mapper) - else: - load_method(weights) - def _init_max_seq_len(self): # For mm_encoder_only mode, infer_max_seq_len() is for LLM decoder models if hasattr(self.model, 'infer_max_seq_len'): diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py new file mode 100644 index 00000000000..389e24096f5 --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -0,0 +1,311 @@ +import copy +import inspect +import os +import traceback +from typing import Callable, Optional, Tuple + +import torch + +from tensorrt_llm._utils import str_dtype_to_torch +from tensorrt_llm.logger import logger +from tensorrt_llm.lora_helper import LoraConfig +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.modeling_utils import QuantAlgo +from tensorrt_llm.quantization.utils.fp4_utils import float4_e2m1x2 + +from ..model_config import ModelConfig +from ..models import AutoModelForCausalLM +from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader +from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, + timing) +from ..modules.fused_moe.moe_load_balancer import ( + MoeLoadBalancer, maybe_create_moe_load_balancer) +from .config import LoadFormat, PyTorchConfig + +# Constants from the original file for KV cache validation +_KV_CACHE_MAP = { + "fp8": QuantAlgo.FP8.value, + "nvfp4": QuantAlgo.NVFP4.value, + "auto": "auto" +} +_VALID_KV_CACHE_DTYPES = ("fp8", "nvfp4", "auto") + + +def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig, + mamba_ssm_cache_dtype: str) -> None: + if mamba_ssm_cache_dtype == "auto": + mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype + else: + mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype) + + config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype + + +def validate_and_set_kv_cache_quant(model_config: ModelConfig, + pyt_kv_cache_dtype: str) -> QuantAlgo: + logger.info( + f'Validating KV Cache config against kv_cache_dtype="{pyt_kv_cache_dtype}"' + ) + # Quantization from hf_quant_config.json + kv_cache_quant = model_config.quant_config.kv_cache_quant_algo + # PyTorch configuration quantization + valid_pyt_quant = bool(pyt_kv_cache_dtype in _VALID_KV_CACHE_DTYPES) + mapped_pyt_quant = _KV_CACHE_MAP.get(pyt_kv_cache_dtype, None) + + # If we're letting the checkpoint dictate the quant with auto, simply + # return and do not modify the checkpoint. + if pyt_kv_cache_dtype == "auto": + logger.info( + f'KV cache quantization set to "{pyt_kv_cache_dtype}". 
Using ' + "checkpoint KV quantization.") + return + + # If we have an invalid quantization, simply raise an exception. + if not valid_pyt_quant: + raise ValueError( + "Overriding KV cache quantization with an invalid type " + f'"PyTorchConfig.kv_cache_dtype="{pyt_kv_cache_dtype}" ' + f'Accepted types are "{_VALID_KV_CACHE_DTYPES}".') + + # If we get to this point we have a valid quantization setting, but if + # we have an existing setting and it doesn't match we shouldn't proceed. + if kv_cache_quant is not None and mapped_pyt_quant != kv_cache_quant: + raise RuntimeError( + "Attempting to override KV cache quantization " + f'"{kv_cache_quant}" with PyTorchConfig.kv_cache_dtype=' + f'"{pyt_kv_cache_dtype}". You cannot override a checkpoint with a ' + "pre-quantized KV cache that doesn't match.") + + # We have an open ended KV cache in the checkpoint + # and we have a specified override. + model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant + + +class ModelLoader: + """ + Handles the loading, configuration, and weight initialization of a PyTorch model. + This class isolates model loading logic from the main execution engine. + """ + + def __init__(self, + pytorch_backend_config: PyTorchConfig, + mapping: Mapping, + spec_config: Optional["DecodingBaseConfig"], + max_num_tokens: int, + max_seq_len: Optional[int], + lora_config: Optional[LoraConfig] = None): + """ + Initializes the ModelLoader. + + Args: + pytorch_backend_config: Configuration for the PyTorch backend. + mapping: The distributed mapping configuration. + spec_config: Configuration for speculative decoding. + max_num_tokens: The maximum number of tokens the engine will handle. + max_seq_len: The maximum sequence length. + lora_config: Configuration for LoRA. + """ + self.pytorch_backend_config = pytorch_backend_config + self.mapping = mapping + self.spec_config = spec_config + self.max_num_tokens = max_num_tokens + self.max_seq_len = max_seq_len + self.lora_config = lora_config + self.moe_load_balancer = None + + def load( + self, + checkpoint_dir: str, + checkpoint_loader: BaseCheckpointLoader, + drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], + torch.nn.Module]] = None + ) -> DecoderModelForCausalLM: + """ + Loads the model, its weights, and applies necessary configurations. + + Args: + checkpoint_dir: The directory of the model checkpoint. + checkpoint_loader: The loader object for model checkpoints. + drafting_loop_wrapper: An optional wrapper for speculative decoding models. + + Returns: + The loaded and initialized PyTorch model. 
+ """ + config = self._load_and_validate_config(checkpoint_dir, + checkpoint_loader) + + with timing("Model init total"), maybe_create_moe_load_balancer( + config, self.mapping) as moe_load_balancer: + + # Attempt to initialize the model on the meta device for speed + try: + config_copy = copy.deepcopy(config) + with MetaInitMode(): + model = AutoModelForCausalLM.from_config(config_copy) + self._materialize_meta_model(model) + config = config_copy + except Exception: + logger.info("Fallback to regular model init: " + f"{traceback.format_exc(limit=1)}\n") + model = AutoModelForCausalLM.from_config(config) + + model.to("cuda") + + logger.info("Use %.2f GB for model weights.", + self._get_rank_model_storage(model) / (1024**3)) + + self._load_weights(model, config, checkpoint_dir, checkpoint_loader) + + if isinstance(moe_load_balancer, MoeLoadBalancer): + self.moe_load_balancer = moe_load_balancer + moe_load_balancer.register_weight_slots_after_to_cuda() + logger.info("moe_load_balancer finalizing model...") + moe_load_balancer.finalize_model() + logger.info("moe_load_balancer finalize model done") + + torch.cuda.current_stream().synchronize() + + if drafting_loop_wrapper is not None: + model = drafting_loop_wrapper(model) + + return model + + def _load_weights(self, model: DecoderModelForCausalLM, config: ModelConfig, + checkpoint_dir: str, + checkpoint_loader: BaseCheckpointLoader): + """Handles the logic for loading weights based on the specified format.""" + load_format = self.pytorch_backend_config.load_format + + if load_format == LoadFormat.AUTO: + checkpoint_path = (getattr(model, 'llm_checkpoint_dir', None) + or checkpoint_dir) + weights = checkpoint_loader.load_weights(checkpoint_path) + weight_mapper = checkpoint_loader.get_initialized_weight_mapper( + model, config) + self._call_load_weights(model.load_weights, weights, weight_mapper) + + # Load draft model weights if needed for speculative decoding + if self.spec_config and self.spec_config.spec_dec_mode.need_load_draft_weights( + ): + draft_weights = checkpoint_loader.load_weights( + self.spec_config.speculative_model_dir) + self._call_load_weights(model.load_draft_weights, draft_weights, + weight_mapper) + + elif load_format == LoadFormat.DUMMY: + self._initialize_dummy_weights(model) + if self.spec_config and self.spec_config.spec_dec_mode.need_load_draft_weights( + ): + model.draft_model.load_weights_from_target_model(model) + + elif load_format == LoadFormat.VISION_ONLY: + logger.info( + "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights." + ) + + else: + raise NotImplementedError( + f"No load support for load format: {load_format}") + + def _load_and_validate_config( + self, checkpoint_dir: str, + checkpoint_loader: BaseCheckpointLoader) -> ModelConfig: + """Loads and validates the model configuration.""" + config = checkpoint_loader.load_config( + checkpoint_dir, + trust_remote_code=True, + enable_min_latency=self.pytorch_backend_config.enable_min_latency, + use_cuda_graph=self.pytorch_backend_config.use_cuda_graph, + force_dynamic_quantization=self.pytorch_backend_config. 
+ force_dynamic_quantization, + spec_config=self.spec_config, + max_num_tokens=self.max_num_tokens, + max_seq_len=self.max_seq_len, + moe_max_num_tokens=self.pytorch_backend_config.moe_max_num_tokens, + moe_load_balancer=self.pytorch_backend_config.moe_load_balancer, + lora_config=self.lora_config, + allreduce_strategy=self.pytorch_backend_config.allreduce_strategy, + mm_encoder_only=self.pytorch_backend_config.mm_encoder_only) + + validate_and_set_kv_cache_quant( + config, self.pytorch_backend_config.kv_cache_dtype) + validate_and_set_mamba_ssm_cache_dtype( + config, self.pytorch_backend_config.mamba_ssm_cache_dtype) + + # Allow overriding the number of layers via environment variable + num_layers_override = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", + "0")) + if num_layers_override > 0: + config.pretrained_config.num_hidden_layers = num_layers_override + for sub_config in ["text_config", "vision_config"]: + if hasattr(config.pretrained_config, sub_config): + getattr(config.pretrained_config, + sub_config).num_hidden_layers = num_layers_override + return config + + @staticmethod + def _materialize_meta_model(model: torch.nn.Module): + """Converts a model on the 'meta' device to a materialized model on CUDA.""" + memo = {} + + def init_meta_tensor(t: torch.Tensor): + if t.device != torch.device('meta'): + return t + if t not in memo: + memo[t] = torch.empty_like(t, device='cuda') + return memo[t] + + model._apply(init_meta_tensor) + + @staticmethod + def _call_load_weights(load_method: Callable, weights, weight_mapper): + """Calls the model's weight loading method with the correct arguments.""" + args = inspect.getfullargspec(load_method).args + if "weight_mapper" in args: + load_method(weights, weight_mapper=weight_mapper) + else: + load_method(weights) + + @staticmethod + def _get_rank_model_storage(model: torch.nn.Module) -> int: + """Calculates the total memory in bytes used by the model's weights and buffers on the current device.""" + total_bytes = 0 + current_device_idx = torch.cuda.current_device() + for param in model.parameters(): + if param.device.type == 'cuda' and param.device.index == current_device_idx: + total_bytes += param.element_size() * param.nelement() + for buf in model.buffers(): + if buf.device.type == 'cuda' and buf.device.index == current_device_idx: + total_bytes += buf.element_size() * buf.nelement() + return total_bytes + + @staticmethod + def _initialize_dummy_weights(model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 0) -> None: + """Initializes model weights with random dummy values for testing purposes.""" + + # This function's logic is copied directly from the original file + def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]: + if dtype in (torch.float8_e4m3fn, torch.int8): + return (-3.0, 3.0) + elif dtype == float4_e2m1x2: + return (0, 64) + else: + raise NotImplementedError(f"Unknown quantized type: {dtype}.") + + for param in model.state_dict().values(): + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + dtype = param.data.dtype + + if param.data.element_size() < 2: + tmp_param = torch.empty_like(param.data, + dtype=torch.float16, + device=param.data.device) + quant_min, quant_max = _get_random_min_max(dtype) + tmp_param.uniform_(quant_min, quant_max, generator=generator) + param.data.copy_(tmp_param.to(dtype)) + elif torch.is_floating_point(param): + param.uniform_(low, high, generator=generator) diff --git a/tensorrt_llm/bench/benchmark/utils/general.py 
b/tensorrt_llm/bench/benchmark/utils/general.py index ff3cd933ce1..a21511f38cd 100755 --- a/tensorrt_llm/bench/benchmark/utils/general.py +++ b/tensorrt_llm/bench/benchmark/utils/general.py @@ -8,7 +8,7 @@ import yaml -from tensorrt_llm._torch.pyexecutor.model_engine import \ +from tensorrt_llm._torch.pyexecutor.model_loader import \ validate_and_set_kv_cache_quant from tensorrt_llm.bench.build.build import (get_benchmark_engine_settings, get_model_config) diff --git a/tensorrt_llm/bench/dataclasses/reporting.py b/tensorrt_llm/bench/dataclasses/reporting.py index b12873b5637..70e4cae646b 100755 --- a/tensorrt_llm/bench/dataclasses/reporting.py +++ b/tensorrt_llm/bench/dataclasses/reporting.py @@ -4,7 +4,7 @@ from collections import defaultdict from typing import Any, Dict, List, NamedTuple -from tensorrt_llm._torch.pyexecutor.model_engine import \ +from tensorrt_llm._torch.pyexecutor.model_loader import \ validate_and_set_kv_cache_quant from tensorrt_llm.bench.dataclasses.configuration import RuntimeConfig from tensorrt_llm.bench.dataclasses.general import DatasetMetadata From 018b022b9dadde651bbe2666ed04b7ba148f53f1 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Fri, 5 Sep 2025 17:39:19 -0700 Subject: [PATCH 02/11] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e35495a0535..92445327e25 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -174,6 +174,7 @@ def __init__( self.model = loader.load(checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader, drafting_loop_wrapper=drafting_loop_wrapper) + self.model_is_wrapped = drafting_loop_wrapper is not None # In case that some tests use stub models and override `_load_model`. if not hasattr(self.model, 'extra_attrs'): self.model.extra_attrs = {} From cf84a58e1b115ecd82ce12c4d718e16583677a63 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Mon, 8 Sep 2025 14:46:42 -0700 Subject: [PATCH 03/11] clean Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- .../_torch/pyexecutor/model_engine.py | 2 +- .../_torch/pyexecutor/model_loader.py | 234 ++++++++++-------- 2 files changed, 130 insertions(+), 106 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index bfc8e02b69d..da2a133cab0 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -175,7 +175,7 @@ def __init__( self.model = loader.load(checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader, drafting_loop_wrapper=drafting_loop_wrapper) - self.model_is_wrapped = drafting_loop_wrapper is not None + self.model_is_wrapped = loader.model_is_wrapped # In case that some tests use stub models and override `_load_model`. 
if not hasattr(self.model, 'extra_attrs'): self.model.extra_attrs = {} diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 389e24096f5..35bbd2d6bc7 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -81,6 +81,77 @@ def validate_and_set_kv_cache_quant(model_config: ModelConfig, model_config.quant_config.kv_cache_quant_algo = mapped_pyt_quant +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 0, +) -> None: + """ + This is similar to this function in SGLang with a few changes: + https://github.com/sgl-project/sglang/blob/e074e76b31d4fff13e87a455dbc3acdaa92c537a/python/sglang/srt/model_loader/weight_utils.py#L577 + This method is used to initialize weights with dummy values for testing + models without checkpoints. Unquantized (FP16/BF16/etc) values are generated + from a uniform distribution over the interval (low, high). + For some quantized types (FP8/NVFP4), torch has no built-in way to generate random values. + We simply generate values uniformly across an interval that has been empirically verified + to not generate NaNs/inf for these. + """ + + def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]: + # These values are not necessarily the largest possible min/max, + # they need to be small enough to avoid NaNs. + if dtype in (torch.float8_e4m3fn, torch.int8): + return (-3.0, 3.0) + + elif dtype == float4_e2m1x2: + # These correspond to bits of 2 packed FP4 values. + # Because we only go up to 64, the high 4 bits will + # always be 0. But this is fine - we just need values + # that won't generate NaNs. + return (0, 64) + + else: + raise NotImplementedError(f"Unknown quantized type: {dtype}.") + + for param in model.state_dict().values(): + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + dtype = param.data.dtype + + if param.data.element_size() < 2: + # We need to do a cast/round since torch doesn't have uniform_ + # support for these dtypes. + tmp_param = torch.empty(param.data.shape, + dtype=torch.float16, + device=param.data.device) + + quant_min, quant_max = _get_random_min_max(dtype) + tmp_param = tmp_param.uniform_(quant_min, + quant_max, + generator=generator) + + param.data.copy_(tmp_param.to(dtype)) + + # Note: no need to to mess with int32 params, these are probably + # constants and not weights. + elif torch.is_floating_point(param): + param.uniform_(low, high, generator=generator) + + +def get_rank_model_storage(model): + total_bytes = 0 + for _, param in model.named_parameters(): + if param.device.type == 'cuda' and param.device.index == torch.cuda.current_device( + ): + total_bytes += param.element_size() * param.nelement() + for _, buf in model.named_buffers(): + if buf.device.type == 'cuda' and buf.device.index == torch.cuda.current_device( + ): + total_bytes += buf.element_size() * buf.nelement() + return total_bytes + + class ModelLoader: """ Handles the loading, configuration, and weight initialization of a PyTorch model. 
@@ -133,31 +204,77 @@ def load( """ config = self._load_and_validate_config(checkpoint_dir, checkpoint_loader) + load_format = self.pytorch_backend_config.load_format with timing("Model init total"), maybe_create_moe_load_balancer( config, self.mapping) as moe_load_balancer: - # Attempt to initialize the model on the meta device for speed try: + # config will be modified in-place for some models, like Qwen2 config_copy = copy.deepcopy(config) with MetaInitMode(): model = AutoModelForCausalLM.from_config(config_copy) - self._materialize_meta_model(model) + + memo = dict() + + def init_meta_tensor(t: torch.Tensor): + if t.device != torch.device('meta'): + return t + if t not in memo: + memo[t] = torch.empty_like(t, device='cuda') + return memo[t] + + model._apply(init_meta_tensor) config = config_copy + except Exception: - logger.info("Fallback to regular model init: " - f"{traceback.format_exc(limit=1)}\n") + logger.info( + f"Fallback to regular model init: {traceback.format_exc(limit=1)}\n" + ) model = AutoModelForCausalLM.from_config(config) model.to("cuda") + rank_model_storage = get_rank_model_storage(model) + logger.info( + f"Use {rank_model_storage / (1024**3):.2f} GB for model weights." + ) + if load_format == LoadFormat.AUTO: + if hasattr(model, 'llm_checkpoint_dir'): + weights = checkpoint_loader.load_weights( + model.llm_checkpoint_dir) + else: + weights = checkpoint_loader.load_weights(checkpoint_dir) + + weight_mapper = checkpoint_loader.get_initialized_weight_mapper( + model, config) + self._call_load_weights(model.load_weights, weights, + weight_mapper) - logger.info("Use %.2f GB for model weights.", - self._get_rank_model_storage(model) / (1024**3)) + if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( + ): + weights = checkpoint_loader.load_weights( + self.spec_config.speculative_model_dir) + self._call_load_weights(model.load_draft_weights, weights, + weight_mapper) + + elif load_format == LoadFormat.DUMMY: + initialize_dummy_weights(model) + if self.spec_config is not None and self.spec_config.spec_dec_mode.need_load_draft_weights( + ): + model.draft_model.load_weights_from_target_model(model) + + elif load_format == LoadFormat.VISION_ONLY: + # Vision weights are already loaded within the model. + logger.info( + "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights." 
+ ) - self._load_weights(model, config, checkpoint_dir, checkpoint_loader) + else: + raise NotImplementedError( + f"No load support for load format: {load_format}") if isinstance(moe_load_balancer, MoeLoadBalancer): - self.moe_load_balancer = moe_load_balancer + setattr(self, "moe_load_balancer", moe_load_balancer) moe_load_balancer.register_weight_slots_after_to_cuda() logger.info("moe_load_balancer finalizing model...") moe_load_balancer.finalize_model() @@ -167,46 +284,12 @@ def load( if drafting_loop_wrapper is not None: model = drafting_loop_wrapper(model) + self.model_is_wrapped = True + else: + self.model_is_wrapped = False return model - def _load_weights(self, model: DecoderModelForCausalLM, config: ModelConfig, - checkpoint_dir: str, - checkpoint_loader: BaseCheckpointLoader): - """Handles the logic for loading weights based on the specified format.""" - load_format = self.pytorch_backend_config.load_format - - if load_format == LoadFormat.AUTO: - checkpoint_path = (getattr(model, 'llm_checkpoint_dir', None) - or checkpoint_dir) - weights = checkpoint_loader.load_weights(checkpoint_path) - weight_mapper = checkpoint_loader.get_initialized_weight_mapper( - model, config) - self._call_load_weights(model.load_weights, weights, weight_mapper) - - # Load draft model weights if needed for speculative decoding - if self.spec_config and self.spec_config.spec_dec_mode.need_load_draft_weights( - ): - draft_weights = checkpoint_loader.load_weights( - self.spec_config.speculative_model_dir) - self._call_load_weights(model.load_draft_weights, draft_weights, - weight_mapper) - - elif load_format == LoadFormat.DUMMY: - self._initialize_dummy_weights(model) - if self.spec_config and self.spec_config.spec_dec_mode.need_load_draft_weights( - ): - model.draft_model.load_weights_from_target_model(model) - - elif load_format == LoadFormat.VISION_ONLY: - logger.info( - "LoadFormat.VISION_ONLY: skipping weight loading; using preloaded vision weights." 
- ) - - else: - raise NotImplementedError( - f"No load support for load format: {load_format}") - def _load_and_validate_config( self, checkpoint_dir: str, checkpoint_loader: BaseCheckpointLoader) -> ModelConfig: @@ -243,69 +326,10 @@ def _load_and_validate_config( sub_config).num_hidden_layers = num_layers_override return config - @staticmethod - def _materialize_meta_model(model: torch.nn.Module): - """Converts a model on the 'meta' device to a materialized model on CUDA.""" - memo = {} - - def init_meta_tensor(t: torch.Tensor): - if t.device != torch.device('meta'): - return t - if t not in memo: - memo[t] = torch.empty_like(t, device='cuda') - return memo[t] - - model._apply(init_meta_tensor) - - @staticmethod - def _call_load_weights(load_method: Callable, weights, weight_mapper): + def _call_load_weights(self, load_method: Callable, weights, weight_mapper): """Calls the model's weight loading method with the correct arguments.""" args = inspect.getfullargspec(load_method).args if "weight_mapper" in args: load_method(weights, weight_mapper=weight_mapper) else: load_method(weights) - - @staticmethod - def _get_rank_model_storage(model: torch.nn.Module) -> int: - """Calculates the total memory in bytes used by the model's weights and buffers on the current device.""" - total_bytes = 0 - current_device_idx = torch.cuda.current_device() - for param in model.parameters(): - if param.device.type == 'cuda' and param.device.index == current_device_idx: - total_bytes += param.element_size() * param.nelement() - for buf in model.buffers(): - if buf.device.type == 'cuda' and buf.device.index == current_device_idx: - total_bytes += buf.element_size() * buf.nelement() - return total_bytes - - @staticmethod - def _initialize_dummy_weights(model: torch.nn.Module, - low: float = -1e-3, - high: float = 1e-3, - seed: int = 0) -> None: - """Initializes model weights with random dummy values for testing purposes.""" - - # This function's logic is copied directly from the original file - def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]: - if dtype in (torch.float8_e4m3fn, torch.int8): - return (-3.0, 3.0) - elif dtype == float4_e2m1x2: - return (0, 64) - else: - raise NotImplementedError(f"Unknown quantized type: {dtype}.") - - for param in model.state_dict().values(): - generator = torch.Generator(device=param.data.device) - generator.manual_seed(seed) - dtype = param.data.dtype - - if param.data.element_size() < 2: - tmp_param = torch.empty_like(param.data, - dtype=torch.float16, - device=param.data.device) - quant_min, quant_max = _get_random_min_max(dtype) - tmp_param.uniform_(quant_min, quant_max, generator=generator) - param.data.copy_(tmp_param.to(dtype)) - elif torch.is_floating_point(param): - param.uniform_(low, high, generator=generator) From 746e486f5e1fec42e205f3208c73c4dc8a3e3731 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Mon, 8 Sep 2025 15:13:15 -0700 Subject: [PATCH 04/11] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_loader.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 35bbd2d6bc7..f4f86ad277a 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -22,7 +22,6 @@ MoeLoadBalancer, maybe_create_moe_load_balancer) from .config import LoadFormat, PyTorchConfig -# Constants from the 
original file for KV cache validation _KV_CACHE_MAP = { "fp8": QuantAlgo.FP8.value, "nvfp4": QuantAlgo.NVFP4.value, @@ -189,7 +188,7 @@ def load( checkpoint_dir: str, checkpoint_loader: BaseCheckpointLoader, drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], - torch.nn.Module]] = None + torch.nn.Module]] = None, ) -> DecoderModelForCausalLM: """ Loads the model, its weights, and applies necessary configurations. @@ -308,7 +307,11 @@ def _load_and_validate_config( moe_load_balancer=self.pytorch_backend_config.moe_load_balancer, lora_config=self.lora_config, allreduce_strategy=self.pytorch_backend_config.allreduce_strategy, - mm_encoder_only=self.pytorch_backend_config.mm_encoder_only) + mm_encoder_only=self.pytorch_backend_config.mm_encoder_only, + attn_backend=self.pytorch_backend_config.attn_backend, + moe_backend=self.pytorch_backend_config.moe_backend, + moe_disable_finalize_fusion=self.pytorch_backend_config. + moe_disable_finalize_fusion) validate_and_set_kv_cache_quant( config, self.pytorch_backend_config.kv_cache_dtype) From c64d32076619125128b13290a1343bc95a960fe6 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Tue, 9 Sep 2025 09:52:42 -0700 Subject: [PATCH 05/11] fix ci Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- .../_torch/pyexecutor/model_engine.py | 35 +++++++++++-------- .../_torch/pyexecutor/model_loader.py | 9 ----- .../executor/test_pytorch_model_engine.py | 8 ++--- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index da2a133cab0..e3ca240c891 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -141,6 +141,7 @@ def __init__( is_draft_model: bool = False, drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None, + model: Optional[torch.nn.Module] = None, ): self.ub_buffers = None self.batch_size = batch_size @@ -163,19 +164,24 @@ def __init__( self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures( ) - attn_backend = pytorch_backend_config.attn_backend - loader = ModelLoader( - pytorch_backend_config=pytorch_backend_config, - mapping=self.mapping, - spec_config=self.spec_config, - max_num_tokens=max_num_tokens, - max_seq_len=max_seq_len, - lora_config=lora_config, - ) - self.model = loader.load(checkpoint_dir=model_path, - checkpoint_loader=checkpoint_loader, - drafting_loop_wrapper=drafting_loop_wrapper) - self.model_is_wrapped = loader.model_is_wrapped + if model is None: + loader = ModelLoader( + pytorch_backend_config=pytorch_backend_config, + mapping=self.mapping, + spec_config=self.spec_config, + max_num_tokens=max_num_tokens, + max_seq_len=max_seq_len, + lora_config=lora_config, + ) + self.model = loader.load(checkpoint_dir=model_path, + checkpoint_loader=checkpoint_loader) + else: + self.model = model + if drafting_loop_wrapper is not None: + self.model = drafting_loop_wrapper(self.model) + self.model_is_wrapped = True + else: + self.model_is_wrapped = False # In case that some tests use stub models and override `_load_model`. 
if not hasattr(self.model, 'extra_attrs'): self.model.extra_attrs = {} @@ -246,7 +252,8 @@ def __init__( self.is_warmup = False - self.attn_backend = get_attention_backend(attn_backend) + self.attn_backend = get_attention_backend( + pytorch_backend_config.attn_backend) if self.is_spec_decode: self.spec_metadata = None diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index f4f86ad277a..e061cc6ca1c 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -187,8 +187,6 @@ def load( self, checkpoint_dir: str, checkpoint_loader: BaseCheckpointLoader, - drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], - torch.nn.Module]] = None, ) -> DecoderModelForCausalLM: """ Loads the model, its weights, and applies necessary configurations. @@ -196,7 +194,6 @@ def load( Args: checkpoint_dir: The directory of the model checkpoint. checkpoint_loader: The loader object for model checkpoints. - drafting_loop_wrapper: An optional wrapper for speculative decoding models. Returns: The loaded and initialized PyTorch model. @@ -281,12 +278,6 @@ def init_meta_tensor(t: torch.Tensor): torch.cuda.current_stream().synchronize() - if drafting_loop_wrapper is not None: - model = drafting_loop_wrapper(model) - self.model_is_wrapped = True - else: - self.model_is_wrapped = False - return model def _load_and_validate_config( diff --git a/tests/unittest/_torch/executor/test_pytorch_model_engine.py b/tests/unittest/_torch/executor/test_pytorch_model_engine.py index 8a06a3a9f0f..ec53a1ae832 100644 --- a/tests/unittest/_torch/executor/test_pytorch_model_engine.py +++ b/tests/unittest/_torch/executor/test_pytorch_model_engine.py @@ -67,16 +67,14 @@ def __init__(self, mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(), tp_size=tensorrt_llm.mpi_world_size(), rank=tensorrt_llm.mpi_rank()) - self.model_is_wrapped = False + model = DummyModel(self.dtype) super().__init__(model_path="", pytorch_backend_config=pytorch_backend_config, checkpoint_loader=None, batch_size=batch_size, max_seq_len=max_seq_len, - mapping=mapping) - - def _load_model(self, mode_path: str, **kwargs) -> torch.nn.Module: - return DummyModel(self.dtype) + mapping=mapping, + model=model) def _create_request(num_tokens, req_id: int): From 199facecd92f797b15498c6772a7ce03584a4f48 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Mon, 22 Sep 2025 12:01:26 +0800 Subject: [PATCH 06/11] clean Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 12 ++++++++++-- tensorrt_llm/_torch/pyexecutor/model_loader.py | 13 +------------ 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 433c870a2a3..8623c7a000c 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -37,8 +37,8 @@ from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids from ..models.modeling_utils import DecoderModelForCausalLM -from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer, - MoeLoadBalancerIterContext) +from ..modules.fused_moe.moe_load_balancer import ( + MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer) from ..speculative import (SpecMetadata, 
get_num_extra_kv_tokens, get_spec_metadata, update_spec_config_from_model_config) @@ -175,6 +175,14 @@ def __init__( ) self.model = loader.load(checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader) + moe_load_balancer = maybe_create_moe_load_balancer( + self.model.config, self.mapping) + if isinstance(moe_load_balancer, MoeLoadBalancer): + setattr(self, "moe_load_balancer", moe_load_balancer) + moe_load_balancer.register_weight_slots_after_to_cuda() + logger.info("moe_load_balancer finalizing model...") + moe_load_balancer.finalize_model() + logger.info("moe_load_balancer finalize model done") else: self.model = model if drafting_loop_wrapper is not None: diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index fccce44f5f3..3d68b64acd4 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -18,8 +18,6 @@ from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, timing) -from ..modules.fused_moe.moe_load_balancer import ( - MoeLoadBalancer, maybe_create_moe_load_balancer) from .config import LoadFormat, PyTorchConfig _KV_CACHE_MAP = { @@ -202,9 +200,7 @@ def load( checkpoint_loader) load_format = self.pytorch_backend_config.load_format - with timing("Model init total"), maybe_create_moe_load_balancer( - config, self.mapping) as moe_load_balancer: - + with timing("Model init total"): try: # config will be modified in-place for some models, like Qwen2 config_copy = copy.deepcopy(config) @@ -269,13 +265,6 @@ def init_meta_tensor(t: torch.Tensor): raise NotImplementedError( f"No load support for load format: {load_format}") - if isinstance(moe_load_balancer, MoeLoadBalancer): - setattr(self, "moe_load_balancer", moe_load_balancer) - moe_load_balancer.register_weight_slots_after_to_cuda() - logger.info("moe_load_balancer finalizing model...") - moe_load_balancer.finalize_model() - logger.info("moe_load_balancer finalize model done") - torch.cuda.current_stream().synchronize() return model From c3a01d604bc29da181b8f48cfc9347412a430698 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Mon, 22 Sep 2025 12:16:33 +0800 Subject: [PATCH 07/11] fix ci Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_loader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 3d68b64acd4..180b4c99319 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -276,6 +276,7 @@ def _load_and_validate_config( config = checkpoint_loader.load_config( checkpoint_dir, trust_remote_code=True, + mapping=self.mapping, enable_min_latency=self.pytorch_backend_config.enable_min_latency, use_cuda_graph=self.pytorch_backend_config.use_cuda_graph, force_dynamic_quantization=self.pytorch_backend_config. 
From 7169c51e0a7834f993e9725b60717465da7abc86 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Mon, 22 Sep 2025 12:24:05 +0800 Subject: [PATCH 08/11] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 8623c7a000c..8576644af8e 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -176,7 +176,7 @@ def __init__( self.model = loader.load(checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader) moe_load_balancer = maybe_create_moe_load_balancer( - self.model.config, self.mapping) + self.model.model_config, self.mapping) if isinstance(moe_load_balancer, MoeLoadBalancer): setattr(self, "moe_load_balancer", moe_load_balancer) moe_load_balancer.register_weight_slots_after_to_cuda() From 21ff6577069f6eea1ba56e67958f4299c42277b9 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Tue, 23 Sep 2025 15:06:51 +0800 Subject: [PATCH 09/11] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 6 +++--- tensorrt_llm/_torch/pyexecutor/model_loader.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index b490d1e77b8..eab94ceb614 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -173,10 +173,10 @@ def __init__( max_seq_len=max_seq_len, lora_config=lora_config, ) - self.model = loader.load(checkpoint_dir=model_path, - checkpoint_loader=checkpoint_loader) + self.model, config = loader.load( + checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader) moe_load_balancer = maybe_create_moe_load_balancer( - self.model.model_config, self.mapping) + config, self.mapping) if isinstance(moe_load_balancer, MoeLoadBalancer): setattr(self, "moe_load_balancer", moe_load_balancer) moe_load_balancer.register_weight_slots_after_to_cuda() diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 180b4c99319..c7a73a9ec53 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -185,7 +185,7 @@ def load( self, checkpoint_dir: str, checkpoint_loader: BaseCheckpointLoader, - ) -> DecoderModelForCausalLM: + ) -> Tuple[DecoderModelForCausalLM, ModelConfig]: """ Loads the model, its weights, and applies necessary configurations. 
@@ -267,7 +267,7 @@ def init_meta_tensor(t: torch.Tensor): torch.cuda.current_stream().synchronize() - return model + return model, config def _load_and_validate_config( self, checkpoint_dir: str, From 8904e9c4c12e7038d48b5c4fd7c63766fa4c271c Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 24 Sep 2025 08:47:57 +0800 Subject: [PATCH 10/11] clean Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index c7a73a9ec53..4186099005f 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -179,7 +179,6 @@ def __init__(self, self.max_num_tokens = max_num_tokens self.max_seq_len = max_seq_len self.lora_config = lora_config - self.moe_load_balancer = None def load( self, From 797f5e1cebea79e997541cbda9eee36e36516a67 Mon Sep 17 00:00:00 2001 From: junq <22017000+QiJune@users.noreply.github.com> Date: Wed, 24 Sep 2025 09:38:10 +0800 Subject: [PATCH 11/11] fix Signed-off-by: junq <22017000+QiJune@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/model_engine.py | 12 +++--------- tensorrt_llm/_torch/pyexecutor/model_loader.py | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index eab94ceb614..115bd2ce393 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -37,8 +37,8 @@ from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids from ..models.modeling_utils import DecoderModelForCausalLM -from ..modules.fused_moe.moe_load_balancer import ( - MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer) +from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer, + MoeLoadBalancerIterContext) from ..speculative import (SpecMetadata, get_num_extra_kv_tokens, get_spec_metadata, update_spec_config_from_model_config) @@ -173,16 +173,10 @@ def __init__( max_seq_len=max_seq_len, lora_config=lora_config, ) - self.model, config = loader.load( + self.model, moe_load_balancer = loader.load( checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader) - moe_load_balancer = maybe_create_moe_load_balancer( - config, self.mapping) if isinstance(moe_load_balancer, MoeLoadBalancer): setattr(self, "moe_load_balancer", moe_load_balancer) - moe_load_balancer.register_weight_slots_after_to_cuda() - logger.info("moe_load_balancer finalizing model...") - moe_load_balancer.finalize_model() - logger.info("moe_load_balancer finalize model done") else: self.model = model if drafting_loop_wrapper is not None: diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 4186099005f..68e877e20b4 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -16,8 +16,9 @@ from ..model_config import ModelConfig from ..models import AutoModelForCausalLM from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader -from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, - timing) +from ..models.modeling_utils import MetaInitMode, timing +from 
..modules.fused_moe.moe_load_balancer import ( + MoeLoadBalancer, maybe_create_moe_load_balancer) from .config import LoadFormat, PyTorchConfig _KV_CACHE_MAP = { @@ -184,7 +185,7 @@ def load( self, checkpoint_dir: str, checkpoint_loader: BaseCheckpointLoader, - ) -> Tuple[DecoderModelForCausalLM, ModelConfig]: + ): """ Loads the model, its weights, and applies necessary configurations. @@ -199,7 +200,8 @@ def load( checkpoint_loader) load_format = self.pytorch_backend_config.load_format - with timing("Model init total"): + with timing("Model init total"), maybe_create_moe_load_balancer( + config, self.mapping) as moe_load_balancer: try: # config will be modified in-place for some models, like Qwen2 config_copy = copy.deepcopy(config) @@ -264,9 +266,15 @@ def init_meta_tensor(t: torch.Tensor): raise NotImplementedError( f"No load support for load format: {load_format}") + if isinstance(moe_load_balancer, MoeLoadBalancer): + moe_load_balancer.register_weight_slots_after_to_cuda() + logger.info("moe_load_balancer finalizing model...") + moe_load_balancer.finalize_model() + logger.info("moe_load_balancer finalize model done") + torch.cuda.current_stream().synchronize() - return model, config + return model, moe_load_balancer def _load_and_validate_config( self, checkpoint_dir: str,
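
For reference, a minimal usage sketch of the ModelLoader API as it stands after patch 11/11. The call pattern mirrors the updated model_engine.py in this series; the wrapper function name and the idea of passing pytorch_backend_config, mapping, and the checkpoint loader in from the caller are illustrative assumptions, not part of the patches.

    # Sketch only: mirrors how PyTorchModelEngine constructs the loader after this series.
    # All arguments are supplied by the caller here to avoid assuming constructor defaults.
    from tensorrt_llm._torch.pyexecutor.model_loader import ModelLoader


    def build_model(model_path, checkpoint_loader, pytorch_backend_config,
                    mapping, max_num_tokens, max_seq_len):
        loader = ModelLoader(
            pytorch_backend_config=pytorch_backend_config,
            mapping=mapping,
            spec_config=None,  # no speculative decoding in this sketch
            max_num_tokens=max_num_tokens,
            max_seq_len=max_seq_len,
            lora_config=None,
        )
        # As of patch 11/11, load() returns the model together with an optional
        # MoeLoadBalancer (or None), which the engine keeps if it is an instance
        # of MoeLoadBalancer.
        model, moe_load_balancer = loader.load(
            checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader)
        return model, moe_load_balancer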