diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index d0b00450f3b..4c5f6ae3dc0 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -52,8 +52,7 @@
 from ..modules.attention import MLA
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
-from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
-                                 create_moe)
+from ..modules.fused_moe import DeepSeekV3MoeRoutingMethod, create_moe
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -502,8 +501,11 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, do_finalize):
         # max-throughput
-        use_dp_padding = False
         if self.use_dp and self.mapping.tp_size > 1:
+            max_num_token = max(all_rank_num_tokens)
+            hidden_states = torch.nn.functional.pad(
+                hidden_states,
+                (0, 0, 0, max_num_token - hidden_states.shape[0]))
             # FP4 all_gather moves this bf16 allgather in to after topk and fp4 quantization
             # to reduce allreduce BW
             if disable_fp4_allgather() and not self.experts.enable_alltoall:
@@ -511,14 +513,6 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
                                           self.mapping,
                                           dim=0,
                                           sizes=all_rank_num_tokens)
-            elif not isinstance(self.experts, CutlassFusedMoE) or (
-                    not self.experts.has_fp8_qdq and self.experts.has_nvfp4):
-                # Use padding when not using the cutlass path or when x_sf in self.experts is not None
-                use_dp_padding = True
-                max_num_token = max(all_rank_num_tokens)
-                hidden_states = torch.nn.functional.pad(
-                    hidden_states,
-                    (0, 0, 0, max_num_token - hidden_states.shape[0]))

         router_logits = self.gate(hidden_states)

@@ -528,7 +522,7 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
             do_finalize=do_finalize,
             output_dtype=hidden_states.dtype,
             all_rank_num_tokens=all_rank_num_tokens,
-            use_dp_padding=use_dp_padding,
+            use_dp_padding=True,
         )

         return routed_output
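
For reference, a minimal standalone sketch of the padding semantics this change relies on: with attention data parallelism enabled, each rank pads its hidden_states along the token dimension up to max(all_rank_num_tokens), which is why use_dp_padding can now be passed unconditionally as True. The shapes and the all_rank_num_tokens values below are illustrative only, not taken from the model.

    # Sketch, not TensorRT-LLM code: shows what the pad call in the diff does.
    import torch
    import torch.nn.functional as F

    all_rank_num_tokens = [5, 8, 3]      # example token counts per DP rank
    hidden_states = torch.randn(5, 16)   # this rank's tokens: [num_tokens, hidden_size]

    max_num_token = max(all_rank_num_tokens)
    # For a 2-D tensor, the pad tuple is (last-dim left, last-dim right,
    # first-dim front, first-dim back), so this appends zero rows along the
    # token dimension only.
    padded = F.pad(hidden_states, (0, 0, 0, max_num_token - hidden_states.shape[0]))

    assert padded.shape == (max_num_token, 16)  # every rank now holds 8 rows

Padding every rank to the same row count means downstream all-gather/all-to-all steps see uniform shapes, at the cost of carrying a few zero rows on the smaller ranks.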