16 changes: 8 additions & 8 deletions examples/offline_inference/basic/run_basic.sh
@@ -12,14 +12,14 @@ MODEL_PATH=/data6/yiliu4/Qwen3-15B-A2B-Base-MXFP4-fp8attention/
MODEL_PATH=/data5/yliu7/HF_HOME/Yi30/gpt-oss-20b-BF16-MXFP8
MODEL_PATH="/data5/yliu7/HF_HOME/unsloth-gpt-oss-20b-BF16-ar-MXFP4/"

# PYTHONPATH=/home/yliu7/workspace/inc/3rd-party/vllm/vllm/model_executor/layers/quantization/auto_round_vllm_extension/:$PYTHONPATH \
# PYTHONPATH=/home/yiliu7/workspace/vllm/vllm/model_executor/layers/quantization/auto_round_vllm_extension/:$PYTHONPATH \
# VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \
# VLLM_ENABLE_AR_EXT=1 \
# VLLM_ENABLE_V1_MULTIPROCESSING=0 \
# VLLM_ENABLE_STATIC_MOE=1 \
# VLLM_AR_MXFP4_MODULAR_MOE=0 \
# python basic_local_2.py --tp 1 -e --model_path $MODEL_PATH
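# Flag semantics below are illustrative only, inferred from the extension code
# in this PR (not an authoritative reference):
#   VLLM_ENABLE_AR_EXT=1              enable the auto_round vLLM extension
#   VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0   keep MXFP4 weights packed and dequantize at apply time
#   VLLM_ENABLE_STATIC_MOE=1          use the static per-expert MoE loop in apply()
#   VLLM_AR_MXFP4_MODULAR_MOE=0       skip the modular fused-MoE path
#   VLLM_ENABLE_V1_MULTIPROCESSING=0  keep the engine in-process for easier debugging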
PYTHONPATH=/home/yliu7/workspace/inc/3rd-party/vllm/vllm/model_executor/layers/quantization/auto_round_vllm_extension/:$PYTHONPATH \
PYTHONPATH=/home/yiliu7/workspace/vllm/vllm/model_executor/layers/quantization/auto_round_vllm_extension/:$PYTHONPATH \
VLLM_MXFP4_PRE_UNPACK_WEIGHTS=0 \
VLLM_ENABLE_AR_EXT=1 \
VLLM_ENABLE_V1_MULTIPROCESSING=0 \
VLLM_ENABLE_STATIC_MOE=1 \
VLLM_AR_MXFP4_MODULAR_MOE=0 \
python basic_local_2.py --tp 1 -e --model_path $MODEL_PATH


PYTHONPATH=/home/yliu7/workspace/inc/3rd-party/vllm/vllm/model_executor/layers/quantization/auto_round_vllm_extension/:$PYTHONPATH \
@@ -40,7 +40,6 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str):
quant_method = AutoRoundMoEMethod.get_moe_method(self, layer, prefix)
else:
quant_method = super().get_quant_method(layer, prefix)
logger.debug(f"Apply {quant_method.__class__.__name__} to {prefix}")
return quant_method


@@ -0,0 +1,118 @@

from typing import Optional

import torch

FLOAT_TO_E2M1 = [
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
]

# Module-level device tensor cache
_DEVICE_E2M1_TENSORS = {}

# Constants for FP4 values (E2M1 format)
_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def get_e2m1_tensor(device):
"""Get device-specific E2M1 lookup tensor, creating it if needed."""
device_str = str(device)
if device_str not in _DEVICE_E2M1_TENSORS:
_DEVICE_E2M1_TENSORS[device_str] = torch.tensor(
_E2M1_VALUES, dtype=torch.float32, device=device
)
return _DEVICE_E2M1_TENSORS[device_str]




def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
m, n = x.shape
device = x.device

# Create lookup table for FP4 values to indices
# Map the absolute values to 0-7 indices
kE2M1 = get_e2m1_tensor(x.device)

# Find closest valid FP4 value index for each element
abs_x = torch.abs(x)
abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1) # [m, n, 8]
abs_indices = torch.argmin(abs_diff_x, dim=-1) # [m, n]

# Apply sign bit (bit 3) to get final 4-bit representation
indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)

# Reshape to prepare for packing pairs of values
indices = indices.reshape(-1)

# Handle odd length by padding if necessary
if indices.numel() % 2 != 0:
indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])

# Reshape to pair consecutive elements
indices = indices.reshape(-1, 2)

# Pack pairs of 4-bit values into 8-bit values
packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)
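# Example: values (0.5, -1.0) map to codes (0b0001, 0b1010); the first value
# occupies the low nibble, so the packed byte is 0b1010_0001 = 0xA1.
# The final reshape assumes n is even so values pair up within each row.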

return packed.reshape(m, n // 2)


def unpack_fp4_from_uint8(
a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
"""
Unpacks uint8 values into fp4. Each uint8 holds two fp4 values: the low
nibble stores the first value and the high nibble the next. Each 4-bit
field is an index (sign bit in bit 3 plus a 3-bit magnitude) that is
mapped to an E2M1 value.

:param a: tensor to unpack
:param m: original dim 0 size of the unpacked tensor
:param n: original dim 1 size of the unpacked tensor
:param dtype: dense dtype to cast the unpacked tensor to
"""
assert a.dtype == torch.uint8, f"expected uint8, got {a.dtype}"

# Vectorized nibble processing
a_flat = a.flatten()
high = (a_flat & 0xF0) >> 4 # Upper nibbles
low = a_flat & 0x0F # Lower nibbles

# Combine nibbles for batch processing
combined = torch.stack((low, high), dim=1).flatten()

# Vectorized sign and magnitude extraction
signs = (combined & 0x08).to(torch.bool) # Sign bits
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices

# Device-aware lookup and sign application
kE2M1 = get_e2m1_tensor(a.device)

values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

# Reshape to final form
return values.reshape(m, n).to(dtype=dtype)



def cast_to_fp4(x):
sign = torch.sign(x)
x = torch.abs(x)
x[(x >= 0.0) & (x <= 0.25)] = 0.0
x[(x > 0.25) & (x < 0.75)] = 0.5
x[(x >= 0.75) & (x <= 1.25)] = 1.0
x[(x > 1.25) & (x < 1.75)] = 1.5
x[(x >= 1.75) & (x <= 2.5)] = 2.0
x[(x > 2.5) & (x < 3.5)] = 3.0
x[(x >= 3.5) & (x <= 5.0)] = 4.0
x[x > 5.0] = 6.0
return x * sign
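
A minimal round-trip sketch of the pack/unpack helpers above (illustrative only; the input values are made up, and the outputs are the nearest representable E2M1 values):

import torch  # plus the two helpers defined above

x = torch.tensor([[0.4, -1.6, 2.1, -6.0]])
packed = pack_fp4_to_uint8(x)                   # shape [1, 2], dtype uint8
restored = unpack_fp4_from_uint8(packed, 1, 4)  # bfloat16 by default
# restored == [[0.5, -1.5, 2.0, -6.0]], each entry snapped to the E2M1 grid
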
Comment on lines +107 to +118

Following the removal of its only call site in mxfp4_qdq_utils.py, this function cast_to_fp4 is now unused and can be removed. Additionally, its implementation using multiple masked assignments is inefficient on GPUs.
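
If cast_to_fp4 is kept, a minimal vectorized sketch is shown below (an illustrative alternative, not part of the PR; it reuses get_e2m1_tensor from this file, and values falling exactly on a rounding boundary may round differently than the masked-assignment version):

def cast_to_fp4_vectorized(x: torch.Tensor) -> torch.Tensor:
    # candidate E2M1 magnitudes [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] on x's device
    kE2M1 = get_e2m1_tensor(x.device)
    # decision boundaries halfway between consecutive magnitudes
    midpoints = (kE2M1[:-1] + kE2M1[1:]) / 2
    # index of the nearest representable magnitude for each |x|
    idx = torch.bucketize(x.abs(), midpoints)
    # restore the sign and the caller's dtype
    return (kE2M1[idx] * torch.sign(x)).to(x.dtype)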

@@ -137,7 +137,7 @@ def create_weights(
H = hidden_size
IN = intermediate_size_per_partition
if self.has_bias:
# TODO: @yiliu30: use the dtype in CK
# TODO: yiliu30 use the dtype in CK
bias_dtype = torch.bfloat16
w13_bias = torch.nn.Parameter(
torch.zeros(E, 2 * IN, dtype=bias_dtype), requires_grad=False
@@ -155,16 +155,12 @@ def create_weights(
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> Optional[FusedMoEQuantConfig]:
# TODO: @yiliu30: implement it
if envs.VLLM_AR_MXFP4_MODULAR_MOE:
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
ocp_mx_moe_quant_config,
)
self.input_dtype = "mxfp4"
self.weight_dtype = "mxfp4"
# breakpoint()
return ocp_mx_moe_quant_config(
quant_dtype=self.input_dtype,
weight_dtype=self.weight_dtype,
@@ -181,9 +177,6 @@ def get_fused_moe_quant_config(
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if envs.VLLM_ENABLE_STATIC_MOE:
if envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS:
logger.debug(
f"start processing weights for {getattr(layer, 'prefix', 'unknown')}"
)
weight_name_lst = ["w13_weight", "w2_weight"]
from .mxfp4_qdq_utils import dequant_mxfp4_to_fp8

@@ -383,8 +376,11 @@ def apply(
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,)
e_score_correction_bias=e_score_correction_bias,
)
assert self.fused_experts is None

# There are three implementations:
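# (as far as this diff shows)
# 1. VLLM_AR_MXFP4_MODULAR_MOE: the modular fused-MoE path via fused_experts
#    with an ocp_mx (mxfp4) quant config.
# 2. VLLM_ENABLE_STATIC_MOE without VLLM_MXFP4_PRE_UNPACK_WEIGHTS: a static
#    per-expert loop that emulates MXFP4 on the packed weights
#    (run_mxfp4_emulations).
# 3. Otherwise, with pre-unpacked weights: a static per-expert loop over the
#    unpacked fp8 weights (mxfp4_gemm_with_unpacked_weight).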

if envs.VLLM_AR_MXFP4_MODULAR_MOE:
from vllm.model_executor.layers.fused_moe import fused_experts
Expand All @@ -409,8 +405,7 @@ def apply(
quant_config=self.moe_quant_config,
)
return out



num_all_tokens, hidden_dim = x.shape
num_experts = layer.local_num_experts
total_num_experts = router_logits.size(-1)
@@ -454,7 +449,6 @@ def apply(
self.experts_mask_buffer.zero_()
experts_mask = self.experts_mask_buffer


topk_ids = topk_ids.to(torch.int64)
topk_weights = topk_weights.to(x.dtype)
experts_mask.scatter_(-1, topk_ids, topk_weights)
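# e.g. with 2 tokens, 4 experts, and top-2 routing:
#   topk_ids     = [[0, 2], [1, 3]]
#   topk_weights = [[0.7, 0.3], [0.6, 0.4]]
#   experts_mask -> [[0.7, 0.0, 0.3, 0.0],
#                    [0.0, 0.6, 0.0, 0.4]]
# After the transpose below, row e holds each token's routing weight for
# expert e (zero where the token was not routed to that expert).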
@@ -464,10 +458,7 @@ def apply(
experts_mask = experts_mask[:num_all_tokens, :total_num_experts]
experts_mask = experts_mask.transpose(0, 1)
# Note: ep_size equals tp_size
if expert_map is not None:
ep_rank = get_tensor_model_parallel_rank()
else:
ep_rank = 0
ep_rank = get_tensor_model_parallel_rank() if expert_map is not None else 0
ep_shift = ep_rank * num_experts

if envs.VLLM_ENABLE_STATIC_MOE and not envs.VLLM_MXFP4_PRE_UNPACK_WEIGHTS:
@@ -481,28 +472,18 @@ def apply(

local_w13_packed = layer.w13_weight_packed[expert_index]
local_w13_scale = layer.w13_weight_scale[expert_index]
# local_w13_global_scale = layer.w13_weight_global_scale[expert_index]
# local_w13_input_global_scale = layer.w13_input_global_scale[
# expert_index
# ]
local_w2_packed = layer.w2_weight_packed[expert_index]
local_w2_scale = layer.w2_weight_scale[expert_index]
# local_w2_global_scale = layer.w2_weight_global_scale[expert_index]
# local_w2_input_global_scale = layer.w2_input_global_scale[expert_index]

local_w1_packed = local_w13_packed[
:intermediate_size_per_partition, ...
]
local_w1_scale = local_w13_scale[:intermediate_size_per_partition, ...]
# local_w1_global_scale = local_w13_global_scale[0]
# local_w1_input_global_scale = local_w13_input_global_scale[0]

local_w3_packed = local_w13_packed[
intermediate_size_per_partition:, ...
]
local_w3_scale = local_w13_scale[intermediate_size_per_partition:, ...]
# local_w3_global_scale = local_w13_global_scale[1]
# local_w3_input_global_scale = local_w13_input_global_scale[1]

from .mxfp4_qdq_utils import run_mxfp4_emulations

@@ -515,8 +496,6 @@ def apply(
local_w3_bias = local_w13_bias[intermediate_size_per_partition:]
local_w2_bias = layer.w2_bias[expert_index]

# local_w13_input_global_scale_max = local_w13_input_global_scale.max()

local_w1_out = run_mxfp4_emulations(
x=current_state_static,
weight=local_w1_packed,
@@ -560,15 +539,17 @@ def apply(

local_unpacked_w2 = layer.w2_weight_unpacked[expert_index]
local_w2_scale = layer.w2_weight_scale[expert_index]

local_unpacked_w1 = local_unpacked_w13[:intermediate_size_per_partition, ...]

local_unpacked_w1 = local_unpacked_w13[
:intermediate_size_per_partition, ...
]
half_scale = local_w13_scale.shape[0] // 2
local_w1_scale = local_w13_scale[:half_scale, ...]
local_unpacked_w3 = local_unpacked_w13[intermediate_size_per_partition:, ...]
local_unpacked_w3 = local_unpacked_w13[
intermediate_size_per_partition:, ...
]
local_w3_scale = local_w13_scale[half_scale:, ...]



local_w1_bias = None
local_w2_bias = None
local_w3_bias = None
Expand All @@ -584,32 +565,29 @@ def apply(
x=current_state_static,
weigth_fp8=local_unpacked_w1,
weight_scale_bf16=local_w1_scale,
bias=local_w1_bias
bias=local_w1_bias,
)
local_w3_out = mxfp4_gemm_with_unpacked_weight(
x=current_state_static,
weigth_fp8=local_unpacked_w3,
weight_scale_bf16=local_w3_scale,
bias=local_w3_bias
bias=local_w3_bias,
)

w13_out = apply_act(local_w1_out, local_w3_out, activation)


local_w2_out = mxfp4_gemm_with_unpacked_weight(
x=w13_out,
weigth_fp8=local_unpacked_w2,
weight_scale_bf16=local_w2_scale,
bias=local_w2_bias
)

padded_weight = experts_mask[expert_index + ep_shift].unsqueeze(
1
bias=local_w2_bias,
)

padded_weight = experts_mask[expert_index + ep_shift].unsqueeze(1)
local_w2_out = local_w2_out * padded_weight
if expert_index == 0:
final_hidden_states = local_w2_out
else:
final_hidden_states += local_w2_out
return final_hidden_states
raise NotImplementedError(f"Not implemented for now.")
raise NotImplementedError("Not implemented for now.")
@@ -26,7 +26,7 @@
# ==------------------------------------------------------------------------==


class AutoRoundMoEMethodMXFP8(AutoRoundMoEMethod):
class AutoRoundMoEMethodMXFp8Impl(AutoRoundMoEMethod):
def __init__(
self,
quant_config: "AutoRoundConfig", # type: ignore # noqa E501