- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 2
Cuda mxfp8 moe #49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: mxfp8
Are you sure you want to change the base?
Cuda mxfp8 moe #49
Conversation
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @yiliu30, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces initial support for MXFP8 (mixed-precision FP8) quantization within Mixture-of-Experts (MoE) layers, aiming to enhance performance and memory efficiency for large language models. The core change involves a new quantization method that manages FP8 weights and their associated scales, along with an emulated forward pass for MXFP8 operations. This lays the groundwork for more efficient MoE inference.
Highlights
- New MXFP8 MoE Quantization Method: I've introduced a new quantization method, CompressedTensorsW8A8MXFp8MoEMethod, specifically designed for Mixture-of-Experts (MoE) layers using mixed-precision FP8 (MXFP8). This method focuses on per-tensor group quantization for weights and dynamic per-token quantization for input activations.
- FP8 Weight and Scale Handling: The new method creates torch.float8_e4m3fntensors for MoE weights (w13_weight,w2_weight) andtorch.uint8tensors for their corresponding MXFP8 scales (w13_weight_scale,w2_weight_scale), adhering to a group size of 32 for per-tensor group quantization.
- Emulated MXFP8 Forward Pass: An emulated forward pass for MXFP8 MoE is implemented within the applymethod. This includes selecting experts, applying masking, and performing matrix multiplications usingrun_mxfp8_emulations, which handles the dequantization of weights and optional quantization/dequantization of inputs.
- Quantization Scale Calculation Adjustment: I've updated the quant_mx_fp8function to useScaleCalculationMode.RCEILinstead ofScaleCalculationMode.FLOORfor calculating quantization scales. This change likely improves the precision or behavior of the MXFP8 quantization process.
- Refined FP8 W8A8 Strategy Check: The _is_fp8_w8a8check incompressed_tensors.pyhas been updated to no longer includeQuantizationStrategy.TENSOR_GROUP. This change ensures that the new MXFP8 MoE method, which specifically usesTENSOR_GROUP, is the primary handler for such configurations.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description | 
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. | 
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. | 
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. | 
| Help | /gemini help | Displays a list of available commands. | 
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
- 
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩ 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces support for CUDA MXFP8 MoE (Mixture of Experts) layers. The main changes include a new CompressedTensorsW8A8MXFp8MoEMethod to handle MXFP8 MoE layers, a modification in the MXFP8 quantization scheme, and an adjustment in the generic FP8 check.
The implementation of the new MoE method is a good start, but there are several areas for improvement regarding code clarity, maintainability, and potential bugs. I've pointed out a redundant import, some dead code due to contradictory logic, and opportunities to refactor the apply method for better readability by moving out local imports, nested functions, and the large HPU-specific block. Addressing these points will improve the quality of the new feature.
| if self.static_input_scales: | ||
| w13_input_scale = torch.nn.Parameter( | ||
| torch.ones(num_experts, dtype=torch.float32), | ||
| requires_grad=False, | ||
| ) | ||
| layer.register_parameter("w13_input_scale", w13_input_scale) | ||
| set_weight_attrs(w13_input_scale, extra_weight_attrs) | ||
|  | ||
| w2_input_scale = torch.nn.Parameter( | ||
| torch.ones(num_experts, dtype=torch.float32), | ||
| requires_grad=False, | ||
| ) | ||
| layer.register_parameter("w2_input_scale", w2_input_scale) | ||
| set_weight_attrs(w2_input_scale, extra_weight_attrs) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block of code handles self.static_input_scales, but the __init__ method on lines 78-82 raises a ValueError if self.static_input_scales is True. This makes this code block unreachable. The FIXME comment also suggests it should be removed. This dead code should be removed to avoid confusion and prevent potential bugs if the check in __init__ is changed in the future.
| from .schemes.compressed_tensors_w8a8_mxfp8 import ( | ||
| dequant_mx_fp8, | ||
| quant_mx_fp8, | ||
| ) | ||
|  | ||
| def run_mxfp8_emulations(x, weight, weight_scale, bias=None): | ||
| dequnat_weight = dequant_mx_fp8( | ||
| weight_fp8=weight.data, | ||
| scale_e8m0=weight_scale.data, | ||
| block_size=self.group_size, | ||
| ) | ||
| dequnat_weight = dequnat_weight.to(x.dtype) | ||
| if not envs.VLLM_DISABLE_INPUT_QDQ: | ||
| x_scale, x_quant = quant_mx_fp8(x) | ||
| dequant_x = dequant_mx_fp8( | ||
| weight_fp8=x_quant, | ||
| scale_e8m0=x_scale, | ||
| block_size=self.group_size, | ||
| ) | ||
| x = dequant_x.to(x.dtype) | ||
| out = x @ dequnat_weight.t() | ||
| return out.to(x.dtype) + (bias if bias is not None else 0) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are a couple of structural issues in this part of the apply method:
- Local Import: The import from .schemes.compressed_tensors_w8a8_mxfp8is performed inside theapplymethod. Imports should generally be at the top of the file. This improves readability, avoids repeated import overhead, and helps prevent potential circular import issues.
- Nested Function Definition: The function run_mxfp8_emulationsis defined inside theapplymethod. Defining functions inside frequently called methods can incur performance overhead due to repeated creation of the function object.
Please consider moving the import to the top of the file and defining run_mxfp8_emulations as a private helper method of the class (e.g., _run_mxfp8_emulations).
| from vllm.model_executor.utils import set_weight_attrs | ||
| from vllm.platforms import current_platform | ||
| from vllm.scalar_type import scalar_types | ||
| import vllm.envs as envs | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| "For MXFP8 Fused MoE layers, we require per tensor group" | ||
| f"Found weight: {self.weight_quant}, input {self.input_quant}" | ||
| ) | ||
| self.group_size = 32 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The group_size is hardcoded to 32. While this might be a requirement for the MXFP8 format, it's better to define it as a named constant at the module level (e.g., _MXFP8_GROUP_SIZE = 32). This improves readability and makes it clear that this is a fixed property of the scheme, rather than an arbitrary magic number.
| if envs.VLLM_USE_STATIC_MOE_HPU: | ||
| num_experts, intermediate_size_per_partition_2x, _ = ( | ||
| layer.w13_weight.shape | ||
| ) | ||
| intermediate_size_per_partition = ( | ||
| intermediate_size_per_partition_2x // 2 | ||
| ) | ||
| # FIXME: Handle mask | ||
| act_fn = F.silu | ||
| num_all_tokens, hidden_dim = x.shape | ||
| num_experts = layer.local_num_experts | ||
| total_num_experts = router_logits.size(-1) | ||
| experts_mask = torch.zeros( | ||
| (x.size(0), total_num_experts), dtype=x.dtype, device=x.device | ||
| ) | ||
| topk_ids = topk_ids.to(torch.int64) | ||
| topk_weights = topk_weights.to(x.dtype) | ||
| experts_mask.scatter_(-1, topk_ids, topk_weights) | ||
| experts_mask = experts_mask.transpose(0, 1) | ||
|  | ||
| mask_weights = torch.zeros( | ||
| (num_all_tokens, total_num_experts), | ||
| dtype=x.dtype, | ||
| device=x.device, | ||
| ) | ||
| mask_weights.scatter_(-1, topk_ids, 1) | ||
| mask_weights = mask_weights.transpose(0, 1) | ||
| # Note: ep_size equal tp_size | ||
| ep_rank = get_tensor_model_parallel_rank() | ||
| ep_shift = ep_rank * num_experts | ||
|  | ||
| for expert_index in range(num_experts): | ||
| mask_weight = mask_weights[expert_index + ep_shift].unsqueeze(1) | ||
| current_state_static = x * mask_weight | ||
|  | ||
| local_w13 = layer.w13_weight[expert_index] | ||
| local_w13_scale = layer.w13_weight_scale[expert_index] | ||
|  | ||
| local_w2 = layer.w2_weight[expert_index] | ||
| local_w2_scale = layer.w2_weight_scale[expert_index] | ||
|  | ||
| local_w1 = local_w13[:intermediate_size_per_partition, ...] | ||
| local_w1_scale = local_w13_scale[ | ||
| :intermediate_size_per_partition, ... | ||
| ] | ||
|  | ||
| local_w3 = local_w13[intermediate_size_per_partition:, ...] | ||
| local_w3_scale = local_w13_scale[ | ||
| intermediate_size_per_partition:, ... | ||
| ] | ||
|  | ||
| from .schemes.compressed_tensors_w8a8_mxfp8 import ( | ||
| dequant_mx_fp8, | ||
| quant_mx_fp8, | ||
| ) | ||
|  | ||
| def run_mxfp8_emulations(x, weight, weight_scale, bias=None): | ||
| dequnat_weight = dequant_mx_fp8( | ||
| weight_fp8=weight.data, | ||
| scale_e8m0=weight_scale.data, | ||
| block_size=self.group_size, | ||
| ) | ||
| dequnat_weight = dequnat_weight.to(x.dtype) | ||
| if not envs.VLLM_DISABLE_INPUT_QDQ: | ||
| x_scale, x_quant = quant_mx_fp8(x) | ||
| dequant_x = dequant_mx_fp8( | ||
| weight_fp8=x_quant, | ||
| scale_e8m0=x_scale, | ||
| block_size=self.group_size, | ||
| ) | ||
| x = dequant_x.to(x.dtype) | ||
| out = x @ dequnat_weight.t() | ||
| return out.to(x.dtype) + (bias if bias is not None else 0) | ||
|  | ||
| local_w1_out = run_mxfp8_emulations( | ||
| x=current_state_static, | ||
| weight=local_w1, | ||
| weight_scale=local_w1_scale, | ||
| ) | ||
| local_w3_out = run_mxfp8_emulations( | ||
| x=current_state_static, | ||
| weight=local_w3, | ||
| weight_scale=local_w3_scale, | ||
| ) | ||
| w13_out = act_fn(local_w1_out) * local_w3_out | ||
|  | ||
| local_w2_out = run_mxfp8_emulations( | ||
| x=w13_out, | ||
| weight=local_w2, | ||
| weight_scale=local_w2_scale, | ||
| ) | ||
| padded_weight = experts_mask[expert_index + ep_shift].unsqueeze( | ||
| 1 | ||
| ) | ||
| local_w2_out = local_w2_out * padded_weight | ||
| if expert_index == 0: | ||
| final_hidden_states = local_w2_out | ||
| else: | ||
| final_hidden_states += local_w2_out | ||
| return final_hidden_states | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic inside if envs.VLLM_USE_STATIC_MOE_HPU: is quite long and complex, making the apply method hard to read and maintain. Consider refactoring this block into a separate private helper method, for example _apply_hpu_static_moe(...). This would improve the clarity and modularity of the apply method.
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]> Co-authored-by: yiliu30 <[email protected]> Co-authored-by: He, Xin3 <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: yiliu30 <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: He, Xin3 <[email protected]> Co-authored-by: Yi Liu <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
Signed-off-by: yiliu30 <[email protected]>
…t#20921) Signed-off-by: Isotr0py <[email protected]>
Essential Elements of an Effective PR Description Checklist
PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED.
Purpose
Test Plan
Test Result
BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing (anything written below this line will be removed by GitHub Actions)