Skip to content

Conversation

@yiliu30
Copy link
Owner

@yiliu30 yiliu30 commented Jun 30, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results

PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED.

Purpose

Test Plan

Test Result

BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing (anything written below this line will be removed by GitHub Actions)

yiliu30 added 2 commits June 29, 2025 23:28
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @yiliu30, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces initial support for MXFP8 (mixed-precision FP8) quantization within Mixture-of-Experts (MoE) layers, aiming to enhance performance and memory efficiency for large language models. The core change involves a new quantization method that manages FP8 weights and their associated scales, along with an emulated forward pass for MXFP8 operations. This lays the groundwork for more efficient MoE inference.

Highlights

  • New MXFP8 MoE Quantization Method: I've introduced a new quantization method, CompressedTensorsW8A8MXFp8MoEMethod, specifically designed for Mixture-of-Experts (MoE) layers using mixed-precision FP8 (MXFP8). This method focuses on per-tensor group quantization for weights and dynamic per-token quantization for input activations.
  • FP8 Weight and Scale Handling: The new method creates torch.float8_e4m3fn tensors for MoE weights (w13_weight, w2_weight) and torch.uint8 tensors for their corresponding MXFP8 scales (w13_weight_scale, w2_weight_scale), adhering to a group size of 32 for per-tensor group quantization.
  • Emulated MXFP8 Forward Pass: An emulated forward pass for MXFP8 MoE is implemented within the apply method. This includes selecting experts, applying masking, and performing matrix multiplications using run_mxfp8_emulations, which handles the dequantization of weights and optional quantization/dequantization of inputs.
  • Quantization Scale Calculation Adjustment: I've updated the quant_mx_fp8 function to use ScaleCalculationMode.RCEIL instead of ScaleCalculationMode.FLOOR for calculating quantization scales. This change likely improves the precision or behavior of the MXFP8 quantization process.
  • Refined FP8 W8A8 Strategy Check: The _is_fp8_w8a8 check in compressed_tensors.py has been updated to no longer include QuantizationStrategy.TENSOR_GROUP. This change ensures that the new MXFP8 MoE method, which specifically uses TENSOR_GROUP, is the primary handler for such configurations.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for CUDA MXFP8 MoE (Mixture of Experts) layers. The main changes include a new CompressedTensorsW8A8MXFp8MoEMethod to handle MXFP8 MoE layers, a modification in the MXFP8 quantization scheme, and an adjustment in the generic FP8 check.

The implementation of the new MoE method is a good start, but there are several areas for improvement regarding code clarity, maintainability, and potential bugs. I've pointed out a redundant import, some dead code due to contradictory logic, and opportunities to refactor the apply method for better readability by moving out local imports, nested functions, and the large HPU-specific block. Addressing these points will improve the quality of the new feature.

Comment on lines +180 to +193
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)

w2_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block of code handles self.static_input_scales, but the __init__ method on lines 78-82 raises a ValueError if self.static_input_scales is True. This makes this code block unreachable. The FIXME comment also suggests it should be removed. This dead code should be removed to avoid confusion and prevent potential bugs if the check in __init__ is changed in the future.

Comment on lines +283 to +304
from .schemes.compressed_tensors_w8a8_mxfp8 import (
dequant_mx_fp8,
quant_mx_fp8,
)

def run_mxfp8_emulations(x, weight, weight_scale, bias=None):
dequnat_weight = dequant_mx_fp8(
weight_fp8=weight.data,
scale_e8m0=weight_scale.data,
block_size=self.group_size,
)
dequnat_weight = dequnat_weight.to(x.dtype)
if not envs.VLLM_DISABLE_INPUT_QDQ:
x_scale, x_quant = quant_mx_fp8(x)
dequant_x = dequant_mx_fp8(
weight_fp8=x_quant,
scale_e8m0=x_scale,
block_size=self.group_size,
)
x = dequant_x.to(x.dtype)
out = x @ dequnat_weight.t()
return out.to(x.dtype) + (bias if bias is not None else 0)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There are a couple of structural issues in this part of the apply method:

  1. Local Import: The import from .schemes.compressed_tensors_w8a8_mxfp8 is performed inside the apply method. Imports should generally be at the top of the file. This improves readability, avoids repeated import overhead, and helps prevent potential circular import issues.
  2. Nested Function Definition: The function run_mxfp8_emulations is defined inside the apply method. Defining functions inside frequently called methods can incur performance overhead due to repeated creation of the function object.

Please consider moving the import to the top of the file and defining run_mxfp8_emulations as a private helper method of the class (e.g., _run_mxfp8_emulations).

from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
import vllm.envs as envs

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The vllm.envs module is imported twice (here and on line 16). Please remove this redundant import to keep the code clean.

"For MXFP8 Fused MoE layers, we require per tensor group"
f"Found weight: {self.weight_quant}, input {self.input_quant}"
)
self.group_size = 32

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The group_size is hardcoded to 32. While this might be a requirement for the MXFP8 format, it's better to define it as a named constant at the module level (e.g., _MXFP8_GROUP_SIZE = 32). This improves readability and makes it clear that this is a fixed property of the scheme, rather than an arbitrary magic number.

Comment on lines +232 to +331
if envs.VLLM_USE_STATIC_MOE_HPU:
num_experts, intermediate_size_per_partition_2x, _ = (
layer.w13_weight.shape
)
intermediate_size_per_partition = (
intermediate_size_per_partition_2x // 2
)
# FIXME: Handle mask
act_fn = F.silu
num_all_tokens, hidden_dim = x.shape
num_experts = layer.local_num_experts
total_num_experts = router_logits.size(-1)
experts_mask = torch.zeros(
(x.size(0), total_num_experts), dtype=x.dtype, device=x.device
)
topk_ids = topk_ids.to(torch.int64)
topk_weights = topk_weights.to(x.dtype)
experts_mask.scatter_(-1, topk_ids, topk_weights)
experts_mask = experts_mask.transpose(0, 1)

mask_weights = torch.zeros(
(num_all_tokens, total_num_experts),
dtype=x.dtype,
device=x.device,
)
mask_weights.scatter_(-1, topk_ids, 1)
mask_weights = mask_weights.transpose(0, 1)
# Note: ep_size equal tp_size
ep_rank = get_tensor_model_parallel_rank()
ep_shift = ep_rank * num_experts

for expert_index in range(num_experts):
mask_weight = mask_weights[expert_index + ep_shift].unsqueeze(1)
current_state_static = x * mask_weight

local_w13 = layer.w13_weight[expert_index]
local_w13_scale = layer.w13_weight_scale[expert_index]

local_w2 = layer.w2_weight[expert_index]
local_w2_scale = layer.w2_weight_scale[expert_index]

local_w1 = local_w13[:intermediate_size_per_partition, ...]
local_w1_scale = local_w13_scale[
:intermediate_size_per_partition, ...
]

local_w3 = local_w13[intermediate_size_per_partition:, ...]
local_w3_scale = local_w13_scale[
intermediate_size_per_partition:, ...
]

from .schemes.compressed_tensors_w8a8_mxfp8 import (
dequant_mx_fp8,
quant_mx_fp8,
)

def run_mxfp8_emulations(x, weight, weight_scale, bias=None):
dequnat_weight = dequant_mx_fp8(
weight_fp8=weight.data,
scale_e8m0=weight_scale.data,
block_size=self.group_size,
)
dequnat_weight = dequnat_weight.to(x.dtype)
if not envs.VLLM_DISABLE_INPUT_QDQ:
x_scale, x_quant = quant_mx_fp8(x)
dequant_x = dequant_mx_fp8(
weight_fp8=x_quant,
scale_e8m0=x_scale,
block_size=self.group_size,
)
x = dequant_x.to(x.dtype)
out = x @ dequnat_weight.t()
return out.to(x.dtype) + (bias if bias is not None else 0)

local_w1_out = run_mxfp8_emulations(
x=current_state_static,
weight=local_w1,
weight_scale=local_w1_scale,
)
local_w3_out = run_mxfp8_emulations(
x=current_state_static,
weight=local_w3,
weight_scale=local_w3_scale,
)
w13_out = act_fn(local_w1_out) * local_w3_out

local_w2_out = run_mxfp8_emulations(
x=w13_out,
weight=local_w2,
weight_scale=local_w2_scale,
)
padded_weight = experts_mask[expert_index + ep_shift].unsqueeze(
1
)
local_w2_out = local_w2_out * padded_weight
if expert_index == 0:
final_hidden_states = local_w2_out
else:
final_hidden_states += local_w2_out
return final_hidden_states

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic inside if envs.VLLM_USE_STATIC_MOE_HPU: is quite long and complex, making the apply method hard to read and maintain. Consider refactoring this block into a separate private helper method, for example _apply_hpu_static_moe(...). This would improve the clarity and modularity of the apply method.

yiliu30 and others added 11 commits June 30, 2025 02:49
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: yiliu30 <[email protected]>
Co-authored-by: He, Xin3 <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: yiliu30 <[email protected]>
Co-authored-by: Copilot <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: He, Xin3 <[email protected]>
Co-authored-by: Yi Liu <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants