diff --git a/scripts/DEEPSEEK_R1_ON_GAUDI.md b/scripts/DEEPSEEK_R1_ON_GAUDI.md index bada2e3bd967..989e0c89a115 100644 --- a/scripts/DEEPSEEK_R1_ON_GAUDI.md +++ b/scripts/DEEPSEEK_R1_ON_GAUDI.md @@ -89,3 +89,23 @@ ray start --address='${head_ip}:6379' --resources='{"HPU": 8, "TPU": 0}' python scripts/run_example_tp_2nodes.py --model ${YOUR_PATH}/DeepSeek-R1-static ``` +# Requantize the Official FP8 Model Using INC +- INC: https://github.com/yiliu30/vllm-fork/tree/r1-woq + +- Calibration +```bash +export OFFICIAL_FP8_MODEL=deepseek-ai/DeepSeek-R1 +# For quick test +VLLM_REQUANT_FP8_INC=1 QUANT_CONFIG=inc_measure_with_fp8kv_config.json VLLM_ENABLE_RUNTIME_DEQUANT=1 python run_example_tp.py --model ${OFFICIAL_FP8_MODEL} --tokenizer ${OFFICIAL_FP8_MODEL} --osl 32 --max_num_seqs 1 +# For calibration with the pile dataset +VLLM_REQUANT_FP8_INC=1 QUANT_CONFIG=inc_measure_with_fp8kv_config.json VLLM_ENABLE_RUNTIME_DEQUANT=1 python run_example_tp.py --model ${OFFICIAL_FP8_MODEL} --tokenizer ${OFFICIAL_FP8_MODEL} --osl 32 --max_num_seqs 1 --nprompts 512 --dataset pile +``` +- Quantization +```bash +VLLM_REQUANT_FP8_INC=1 QUANT_CONFIG=inc_quant_with_fp8kv_config.json VLLM_ENABLE_RUNTIME_DEQUANT=1 python run_example_tp.py --model ${OFFICIAL_FP8_MODEL} --tokenizer ${OFFICIAL_FP8_MODEL} --max_num_seqs 1 --fp8_kv_cache +``` + +- Evaluation +```bash +VLLM_REQUANT_FP8_INC=1 QUANT_CONFIG=inc_quant_with_fp8kv_config.json VLLM_ENABLE_RUNTIME_DEQUANT=1 python run_lm_eval.py --model ${OFFICIAL_FP8_MODEL} --tokenizer ${OFFICIAL_FP8_MODEL} --fp8_kv_cache -l 64 --batch_size 1 +``` \ No newline at end of file diff --git a/scripts/inc_measure_with_fp8kv_config.json b/scripts/inc_measure_with_fp8kv_config.json new file mode 100644 index 000000000000..ee5d0b445e8b --- /dev/null +++ b/scripts/inc_measure_with_fp8kv_config.json @@ -0,0 +1,15 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "observer": "maxabs", + "whitelist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": ["lm_head", "mlp\\.gate\\b"] + }, + "quantize_weight": false, + "dump_stats_path": "./nc_workspace_measure_kvache/inc_measure_output" +} \ No newline at end of file diff --git a/scripts/inc_quant_with_fp8kv_config.json b/scripts/inc_quant_with_fp8kv_config.json new file mode 100644 index 000000000000..1ffb4c742d5b --- /dev/null +++ b/scripts/inc_quant_with_fp8kv_config.json @@ -0,0 +1,14 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": ["lm_head", "mlp\\.gate\\b"] + }, + "dump_stats_path": "./nc_workspace_measure_kvache/inc_measure_output" +} \ No newline at end of file diff --git a/scripts/run_example_tp.py b/scripts/run_example_tp.py index 7f2953055334..db0eec91ecae 100644 --- a/scripts/run_example_tp.py +++ b/scripts/run_example_tp.py @@ -25,6 +25,7 @@ parser.add_argument("--isl", type=int, default=1024, help="input sequence length.") parser.add_argument("--osl", type=int, default=1024, help="output sequence length.") parser.add_argument("--nprompts", type=int, default=4, help="The number of prompts.") +parser.add_argument("--max_num_seqs", type=int, default=None, help="The max number of sequences.") parser.add_argument("--random", action="store_true", help="Randomly sample prompts.") parser.add_argument("--fp8_kv_cache", action="store_true", help="Use fp8 for kv cache.") args = parser.parse_args() @@ -36,8 +37,12 @@ os.environ["VLLM_EP_SIZE"] = f"{args.ep_size}"
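As a quick sanity check on the two INC configuration files added above, the sketch below is an illustrative snippet (not part of the diff) that loads both JSON files and verifies that the quantization pass points at the statistics produced during calibration; it assumes the repository layout shown in this diff, with both files under `scripts/`.

```python
import json

# Editor's sketch: the MEASURE and QUANTIZE configs must share dump_stats_path,
# because the quantization pass consumes the scales written during calibration.
with open("scripts/inc_measure_with_fp8kv_config.json") as f:
    measure_cfg = json.load(f)
with open("scripts/inc_quant_with_fp8kv_config.json") as f:
    quant_cfg = json.load(f)

assert measure_cfg["mode"] == "MEASURE" and measure_cfg["quantize_weight"] is False
assert quant_cfg["mode"] == "QUANTIZE" and quant_cfg["scale_method"] == "maxabs_hw"
assert measure_cfg["dump_stats_path"] == quant_cfg["dump_stats_path"], (
    "the quantization QUANT_CONFIG must reuse the stats written during calibration"
)
print("INC configs are consistent:", quant_cfg["dump_stats_path"])
```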
os.environ["VLLM_MLA_DISABLE_REQUANTIZATION"] = "1" os.environ["PT_HPU_WEIGHT_SHARING"] = "0" +os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG" #os.environ['VLLM_DMOE_DYNAMIC_SCALE']='1' # only works for 1.20 + dmoe patch + + + def sample_sonnet_requests( dataset_path: str, num_requests: int, @@ -160,6 +165,16 @@ def sample_gsm8k_requests( tokenizer=tokenizer, do_random=args.random, ) + elif args.dataset == "pile": + from utils import get_prompts, get_prompt_token_ids, get_pile_prompts + least_tokens = args.isl + num_samples = args.nprompts + prompts = get_pile_prompts(args.model, num_samples) + prompt_token_ids = get_prompt_token_ids( + args.model, prompts, least_tokens + ) + print(f"Got {len(prompts)} prompts, length of first prompt: {len(prompt_token_ids[0])}.") + gt = None else: prompts = [ "Hello, my name is", @@ -178,6 +193,8 @@ def sample_gsm8k_requests( param = {} if args.fp8_kv_cache: param["kv_cache_dtype"] = "fp8_inc" + if args.max_num_seqs is not None: + param["max_num_seqs"] = args.max_num_seqs if args.tp_size == 1: llm = LLM( model=model, @@ -201,10 +218,16 @@ def sample_gsm8k_requests( **param ) + # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. start = time.perf_counter() - outputs = llm.generate(prompts, sampling_params) + if args.dataset == "pile": + outputs = llm.generate( + prompts=None, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids + ) + else: + outputs = llm.generate(prompts, sampling_params) end = time.perf_counter() # Print the outputs. print(f"e2e took {end - start} seconds") @@ -218,4 +241,6 @@ def sample_gsm8k_requests( print(f"Generated text: {generated_text!r}") print(f"Ground truth: {gt_i!r}") print("====================================") + if os.getenv("VLLM_FORCE_INC", None) is not None: + llm.llm_engine.model_executor.shutdown() del llm \ No newline at end of file diff --git a/scripts/run_lm_eval.py b/scripts/run_lm_eval.py index 4fc99ec5f401..9f3fc206d975 100644 --- a/scripts/run_lm_eval.py +++ b/scripts/run_lm_eval.py @@ -16,8 +16,9 @@ parser.add_argument("--tokenizer", type=str, default=None, help="The model path.") parser.add_argument("--tp_size", type=int, default=8, help="Tensor Parallelism size.") parser.add_argument("--ep_size", type=int, default=8, help="Expert Parallelism size.") -parser.add_argument("-l", "--limit", type=int, default=64, help="test request counts.") +parser.add_argument("-l", "--limit", type=int, default=None, help="test request counts.") parser.add_argument("--batch_size", type=int, default=1, help="The batch size.") +parser.add_argument("--fp8_kv_cache", action="store_true", help="Use fp8 for kv cache.") args = parser.parse_args() os.environ["VLLM_SKIP_WARMUP"] = "true" @@ -36,6 +37,16 @@ #os.environ['VLLM_DMOE_DYNAMIC_SCALE']='1' #os.environ['VLLM_ENABLE_RUNTIME_DEQUANT']='1' +if args.task == "gsm8k": + #For testing gsm8k quickly + os.environ['VLLM_PROMPT_BS_BUCKET_MIN']='1' + os.environ['VLLM_PROMPT_BS_BUCKET_MAX']='1' + os.environ['VLLM_PROMPT_SEQ_BUCKET_MIN']='2048' + os.environ['VLLM_PROMPT_SEQ_BUCKET_STEP']='512' + os.environ['VLLM_PROMPT_SEQ_BUCKET_MAX']='2048' + os.environ['VLLM_DECODE_BS_BUCKET_MIN']='1' + os.environ['VLLM_DECODE_BS_BUCKET_MAX']='1' + if __name__ == "__main__": from lm_eval.models.vllm_causallms import VLLM @@ -44,6 +55,9 @@ model = args.model if args.tokenizer is None: args.tokenizer = model + param = {} + if args.fp8_kv_cache: + param["kv_cache_dtype"] = "fp8_inc" if args.tp_size == 1: llm 
= VLLM( pretrained=model, @@ -65,17 +79,27 @@ dtype="bfloat16", gpu_memory_utilization=0.8, batch_size=args.batch_size, + **param, ) # Run the evaluation; you can adjust num_fewshot and batch_size as needed. start = time.perf_counter() if args.task == "gsm8k": - results = simple_evaluate(model=llm, tasks=["gsm8k"], num_fewshot=5, batch_size=8, limit=args.limit) + from lm_eval.utils import make_table + + results = simple_evaluate( + model=llm, + tasks=["gsm8k"], + # num_fewshot=5, + # batch_size=8, + limit=args.limit, + ) end = time.perf_counter() e2e = end - start + print(make_table(results)) # save as json - with open(f"gsm8k_ep{args.ep_size}_result_samples_limit{args.limit}.jsonl", "w") as f: + with open(f"gsm8k_ep{args.ep_size}_result_samples_limit{str(args.limit)}.jsonl", "w") as f: json.dump(results['results'], f) json.dump({"e2e time(secs)": e2e}, f) f.write("\n") @@ -86,7 +110,7 @@ results = simple_evaluate(model=llm, tasks=["hellaswag"], num_fewshot=0, batch_size=8, limit=args.limit) end = time.perf_counter() e2e = end - start - with open(f"hallaswag_ep{args.ep_size}_result_samples_limit{args.limit}.jsonl", "w") as f: + with open(f"hallaswag_ep{args.ep_size}_result_samples_limit{str(args.limit)}.jsonl", "w") as f: json.dump(results['results'], f) json.dump({"e2e time(secs)": e2e}, f) f.write("\n") diff --git a/scripts/utils.py b/scripts/utils.py new file mode 100644 index 000000000000..e1202909d2ea --- /dev/null +++ b/scripts/utils.py @@ -0,0 +1,148 @@ +from transformers import PreTrainedTokenizerBase, AutoTokenizer +from typing import List, Dict, Any + +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from typing import List, Dict, Any +import json + +import random + +def reset_seed(seed=42): + import torch + import random + import numpy as np + print("Using seed: ", seed) + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # multi-GPU. 
+ # TODO: for future use + # torch.backends.cudnn.benchmark = False + # torch.backends.cudnn.deterministic = True + +def get_prompts(): + filename = "pile.txt" + with open(filename, "r") as f: + prompts = f.readlines() + print(f"Number of prompts: {len(prompts)}") + return prompts + + +def get_prompt_token_ids(model_path, prompts, max_length=1024): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_path) + prompt_token_ids = [] + for prompt in prompts: + tokens = tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=max_length, + ) + if len(tokens.input_ids[0]) < max_length: + continue + prompt_token_ids.append([x.item() for x in tokens.input_ids[0]]) + return prompt_token_ids + + +def get_pile_prompts(model_name, num_samples=512): + from datasets import load_dataset + from tqdm import tqdm + import transformers + + """ + autoround calibration static model: + NeelNanda/pile-10k,seed=42, iters=1 rtn, nsamples=512 seqlen=1024 + """ + + # ==-------------------------------------------------------------------------== + # Calibration parameters + least_tokens = 1024 + seed = 42 + # ==-------------------------------------------------------------------------== + + reset_seed(seed) + + dataset = load_dataset("NeelNanda/pile-10k", split="train") + dataset = dataset.shuffle(seed=seed) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_name, trust_remote_code=True + ) + num_sample = 0 + samples_lst = [] + for data in tqdm(dataset): + prompt = data["text"] + tokens = tokenizer(prompt, return_tensors="pt") + if len(tokens.input_ids[0]) < least_tokens: + continue + num_sample += 1 + samples_lst.append(prompt) + if num_sample >= num_samples: + break + return samples_lst + +#==-------------------------------------------------------------------------== +# Load custom dataset +#==-------------------------------------------------------------------------== + +def get_dataset(filepath: str) -> List[List[Dict[str, str]]]: + """ + [ + [ + {"role": "system", "content": "system prompt"}, + {"role": "user", "content": "query prompt"}, + ], + [ + {"role": "system", "content": "1. 角色设定:- 你是...."}, + {"role": "user", "content": "搜索关键词】\n梁斌是谁,做什么"}, + ], + ... + ] + + """ + with open(filepath) as f: + dataset: List[List[Dict[str, str]]] = [json.loads(line) for line in f] + return dataset + + +def sample_tc_requests( + filepath: str, + tokenizer: PreTrainedTokenizerBase, + num_requests: int = None, + do_random: bool = False, +) -> List[str]: + dataset = get_dataset(filepath) + prompts = dataset + few_shots = 0 + sampled_requests: List[str] = [] + if num_requests is None: + num_requests = len(prompts) + for j in range(num_requests): + i = ( + random.choice(range(len(prompts[few_shots:]))) + if do_random + else j + few_shots + ) + # message demo: + # [ + # {"role": "system", "content": "1. 
角色设定:- 你是...."}, + # {"role": "user", "content": "搜索关键词】\n梁斌是谁,做什么"}, + # ], + message: List[Dict[str, str]] = prompts[i] + prompt_with_template = tokenizer.apply_chat_template( + message, add_generation_prompt=True, tokenize=False + ) + sampled_requests.append(prompt_with_template) + + return sampled_requests + +def get_tokenizer(model_path) -> PreTrainedTokenizer | PreTrainedTokenizerFast: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path) + return tokenizer + + diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index e15ba5be444e..e7178a48759c 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -30,6 +30,123 @@ fused_moe_pallas = None # type: ignore logger = init_logger(__name__) +import os +VLLM_REQUANT_FP8_INC = os.getenv("VLLM_REQUANT_FP8_INC", "0") in ["1", "true"] + +# ==-------------------------------------------------------------------------== +# VLLM-HPU-EXT PATCH Start +# ==-------------------------------------------------------------------------== + +import torch.nn.functional as F +import habana_frameworks.torch.core as htcore +import habana_frameworks.torch as htorch + + +class MoeFP8Matmul(torch.nn.Module): + def __init__( + self, + block_size: Tuple[int, int] = (128, 128), + high_precision=torch.bfloat16, + ): + super().__init__() + self.block_size = block_size + self.high_precision = high_precision + self.is_dequantized = False + + def set_weight(self, w: torch.Tensor): + self.weight = w + + def set_scale_inv_fp8(self, scale_inv_fp8: torch.Tensor): + self.scale_inv_fp8 = scale_inv_fp8 + + def set_high_precision(self, high_precision=torch.bfloat16): + self.high_precision = high_precision + + def set_weight_block_size(self, block_size: Tuple[int, int] = (128, 128)): + self.block_size = block_size + + def get_dequant_weight(self): + from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + dequant_block_fp8_weight_naive, + ) + + return dequant_block_fp8_weight_naive( + self.weight, + self.scale_inv_fp8, + block_size=self.block_size, + dtype=self.high_precision, + ) + + def forward(self, state, expert_id, w): + raise NotImplementedError() + + def dequant_block_fp8_weight(self, layer: "MoeFP8Matmul") -> torch.Tensor: + # This function is called by INC in either the measurement or the quantization phase. + # In the quantization phase, INC requantizes the BF16 weight to FP8 and updates the weight. + # In the measurement phase, INC only measures the BF16 weight and does NOT update the weight. + # We therefore dequantize only once and do not keep the BF16 weight around, which would cause OOM.
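To make the block-wise dequantization that `get_dequant_weight` delegates to easier to follow, here is a simplified, self-contained sketch. It illustrates the general technique only; it is not the fork's `dequant_block_fp8_weight_naive`, and the function name, shapes, and toy sizes below are assumptions.

```python
import torch

def dequant_block_weight_sketch(w_q: torch.Tensor,
                                scale_inv: torch.Tensor,
                                block_size=(128, 128),
                                dtype=torch.bfloat16) -> torch.Tensor:
    # w_q:       (N, K) block-quantized weight
    # scale_inv: (ceil(N/bn), ceil(K/bk)) per-block dequant scales
    #            (named *_scale_inv in the checkpoint, applied multiplicatively here)
    bn, bk = block_size
    n, k = w_q.shape
    # Broadcast each per-block scale over its (bn x bk) tile, then crop to the weight shape.
    scales = scale_inv.repeat_interleave(bn, dim=0)[:n]
    scales = scales.repeat_interleave(bk, dim=1)[:, :k]
    return w_q.to(dtype) * scales.to(dtype)

# Toy example: 2x2 blocks on a 4x6 weight.
w = torch.randint(-3, 4, (4, 6)).to(torch.float32)
s = torch.full((2, 3), 0.5)
print(dequant_block_weight_sketch(w, s, block_size=(2, 2), dtype=torch.float32))
```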
+ if self.is_dequantized: + return layer.weight + + dequant_weight = layer.get_dequant_weight() + layer.is_dequantized = True + return dequant_weight + + def get_dequant_weights_func( + self, + ) -> Optional[Callable[[torch.nn.Module], torch.Tensor]]: + return self.dequant_block_fp8_weight + + +class VllmMixtureOfExpertsOpFP8(torch.nn.Module): + def __init__(self, num_total_experts: int): + super().__init__() + self.w13_list = torch.nn.ModuleList( + [MoeFP8Matmul() for _ in range(num_total_experts)] + ) + self.w2_list = torch.nn.ModuleList( + [MoeFP8Matmul() for _ in range(num_total_experts)] + ) + self.num_experts = num_total_experts + # FIXME (Yi) add experts_min and experts_max as init parameters + self.experts_min = None + self.experts_max = None + + def forward( + self, + x, + topk_ids, + topk_weights, + moe_n_slice, + n_expert_slice, + ep_shift, + ): + min_expert = self.experts_min + max_expert = self.experts_max + w13_list_slice = [] + w2_list_slice = [] + for j in range(self.num_experts): + w13_list_slice.append(self.w13_list[j].get_dequant_weight()) + w2_list_slice.append(self.w2_list[j].get_dequant_weight()) + + final_hidden_states = torch.ops.hpu.mixture_of_experts( + hidden_states=x, + expert_routing_table=topk_ids.to(torch.int64), + router_weights=topk_weights.to(x.dtype), + w12=w13_list_slice, + w3=w2_list_slice, + permuted_weights=True, + activation="silu", + experts_min=min_expert, + experts_max=max_expert, + ) + htorch.core.mark_step() + return final_hidden_states + + +# ==-------------------------------------------------------------------------== +# VLLM-HPU-EXT PATCH End +# ==-------------------------------------------------------------------------== class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -117,7 +234,8 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, + ep_rank: Optional[int] = None, ) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -129,7 +247,9 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) + e_score_correction_bias=e_score_correction_bias, + ep_rank= ep_rank, + ) def forward_cuda( self, @@ -176,8 +296,11 @@ def forward_hpu( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None + e_score_correction_bias: Optional[torch.Tensor] = None, + ep_rank = None, ): + bs, seq_len, hidden_size = x.shape + x = x.reshape(bs * seq_len, hidden_size) assert len(x.shape) == 2 import habana_frameworks.torch as htorch htorch.core.mark_step() @@ -203,8 +326,16 @@ def forward_hpu( topk_weights = topk_weights.to(x.dtype) final_hidden_states = torch.zeros_like(x) - num_experts = layer.w13_weight.shape[0] - n_expert_slice = layer.w13_weight.shape[0] // 8 + num_experts = layer.num_experts + if hasattr(layer, "w13_weight") and layer.w13_weight is not None: + assert ( + layer.w13_weight.shape[0] == num_experts + ), f"Expected {layer.w13_weight.shape[0]} experts, got {num_experts}" + # For mixtral, the `num_expert_group` is 8. 
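The slicing fallback used at this point is easiest to see with concrete numbers. The sketch below is an editorial illustration with hypothetical sizes, not code from the patch.

```python
# When num_expert_group is not supplied it falls back to 8, and the experts on
# this rank are processed in that many slices (hypothetical sizes).
num_experts = 32                                  # e.g. experts held per rank
num_expert_group = 8                              # fallback value
n_expert_slice = num_experts // num_expert_group  # -> 4 experts per slice
assert n_expert_slice * num_expert_group == num_experts
```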
+ if num_expert_group is None: + num_expert_group = 8 + num_expert_group = num_expert_group + n_expert_slice = num_experts // num_expert_group assert n_expert_slice * 8 == num_experts # w13_list = layer.hpu_fused_moe.MoeOp.w13_list @@ -361,8 +492,10 @@ def __init__( self.topk_group = topk_group self.custom_routing_function = custom_routing_function if is_hpu: - from vllm_hpu_extension.ops import DynamicFusedMOE - self.hpu_fused_moe = DynamicFusedMOE(self.num_experts) + # FIXME: (Yi) WA, should use DynamicFusedMOE for INC + if not VLLM_REQUANT_FP8_INC: + from vllm_hpu_extension.ops import DynamicFusedMOE + self.hpu_fused_moe = DynamicFusedMOE(self.num_experts) self.scoring_func = scoring_func self.e_score_correction_bias = e_score_correction_bias @@ -392,6 +525,74 @@ def __init__( self.quant_method.create_weights(layer=self, **moe_quant_params) + # FIXME: (Yi) we need to wrap the `torch.ops.hpu.mixture_of_experts` as a module, + # so that INC can patch it for measurement and quantization. + layer = self + ep_shift = self.ep_rank * self.num_experts + if VLLM_REQUANT_FP8_INC: + num_experts_on_rank = self.num_experts + num_expert_group = 1 + num_expert_per_group = num_experts_on_rank // num_expert_group + n_expert_slice = num_experts_on_rank // num_expert_group + assert n_expert_slice * num_expert_group == num_experts_on_rank + moe_n_slice = int(os.environ.get("VLLM_MOE_N_SLICE", 4)) + assert moe_n_slice == 1, f"moe_n_slice is {moe_n_slice}, expected 1 for INC" + moe_lst = [] + for i in range(moe_n_slice): + sub_expert_group = VllmMixtureOfExpertsOpFP8( + num_expert_per_group + ) + min_expert = i * n_expert_slice + max_expert = (i + 1) * n_expert_slice + + w13_list_slice = [ + layer.w13_weight[j] for j in range(min_expert, max_expert) + ] + w13_weight_scale_inv_fp8_list = [ + layer.w13_weight_scale_inv[j] + for j in range(min_expert, max_expert) + ] + w2_list_slice = [ + layer.w2_weight[j] for j in range(min_expert, max_expert) + ] + w2_weight_scale_inv_fp8_list = [ + layer.w2_weight_scale_inv[j] + for j in range(min_expert, max_expert) + ] + for index in range(len(w2_list_slice)): + sub_expert_group.w13_list[index].set_weight( + w13_list_slice[index] + ) + sub_expert_group.w13_list[index].set_scale_inv_fp8( + w13_weight_scale_inv_fp8_list[index] + ) + sub_expert_group.w13_list[index].set_weight_block_size( + layer.quant_config.weight_block_size + ) + + sub_expert_group.w2_list[index].set_weight( + w2_list_slice[index] + ) + sub_expert_group.w2_list[index].set_scale_inv_fp8( + w2_weight_scale_inv_fp8_list[index] + ) + sub_expert_group.w2_list[index].set_weight_block_size( + layer.quant_config.weight_block_size + ) + + # FIXME: (Yi) pass `experts_min` and `experts_max` to MoeOp. 
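The expert-range bookkeeping configured here maps each slice to global expert ids so the routing table can address experts across EP ranks. The sketch below restates that arithmetic with hypothetical values; it is an illustration, not code from the patch.

```python
# Editor's sketch of the experts_min / experts_max computation (values hypothetical).
num_experts_on_rank = 32
moe_n_slice = 1                                   # the INC path asserts a single slice
n_expert_slice = num_experts_on_rank // moe_n_slice
ep_rank = 2
ep_shift = ep_rank * num_experts_on_rank          # first global expert id on this rank

for i in range(moe_n_slice):
    experts_min = i * n_expert_slice + ep_shift            # inclusive lower bound
    experts_max = (i + 1) * n_expert_slice - 1 + ep_shift  # inclusive upper bound
    print(f"slice {i}: global experts [{experts_min}, {experts_max}]")  # -> [64, 95]
```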
+ setattr( + sub_expert_group, "experts_min", min_expert + ep_shift + ) + setattr( + sub_expert_group, "experts_max", max_expert - 1 + ep_shift + ) + # setattr(self, f"sub_expert_group_{i}", sub_expert_group) + moe_lst.append(sub_expert_group) + htorch.core.mark_step() + self.moe_lst = torch.nn.ModuleList(moe_lst) + htorch.core.mark_step() + def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -481,7 +682,7 @@ def _load_w13(self, expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data.copy_(loaded_weight) - if is_hpu: + if is_hpu and not VLLM_REQUANT_FP8_INC: self.hpu_fused_moe.MoeOp.w13_list[expert_id].set_weight( orig_exp_data) # print(f"loaded w13 for hpu for expert_id: {expert_id}, orig_exp_data.shape: {orig_exp_data.shape}") @@ -504,7 +705,7 @@ def _load_w2(self, shard_size) # w2, down_proj: Load into only logical weight of w2. expert_data.copy_(loaded_weight) - if is_hpu: + if is_hpu and not VLLM_REQUANT_FP8_INC: self.hpu_fused_moe.MoeOp.w2_list[expert_id].set_weight(expert_data) # print(f"loaded w2 for hpu for expert_id: {expert_id}, expert_data.shape: {expert_data.shape}") diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 62bf9094aa40..58d96ec934cb 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,7 +2,7 @@ import itertools from abc import abstractmethod -from typing import Optional +from typing import Optional, Callable import torch import torch.nn.functional as F @@ -184,6 +184,14 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]: raise NotImplementedError + def get_dequant_weights_func( + self, + ) -> Optional[Callable[[torch.nn.Module], torch.Tensor]]: + if self.quant_method is not None: + quant_method = self.quant_method + if hasattr(quant_method, "dequant_block_fp8_weight"): + return quant_method.dequant_block_fp8_weight + class ReplicatedLinear(LinearBase): """Replicated linear layer. diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 288e4c4c5f75..0f9a72e3c004 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -49,6 +49,9 @@ logger = init_logger(__name__) +VLLM_REQUANT_FP8_INC = os.getenv("VLLM_REQUANT_FP8_INC", "0") in ["1", "true"] + + class Fp8Config(QuantizationConfig): """Config class for FP8.""" @@ -275,6 +278,17 @@ def create_weights( else: layer.register_parameter("input_scale", None) + def dequant_block_fp8_weight(self, layer) -> torch.Tensor: + dequant_weight = dequant_block_fp8_weight_naive( + layer.weight, + layer.weight_scale_inv.data, + self.quant_config.weight_block_size, + original_M=layer.orig_M, + original_N=layer.orig_N, + do_unpad=True, + ) + return dequant_weight + def process_weights_after_loading(self, layer: Module) -> None: # TODO(rob): refactor block quant into separate class. 
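The `get_dequant_weights_func` hook added to `LinearBase` (and mirrored for the MoE modules) is an optional, duck-typed interface: it returns either `None` or a callable that takes the module and returns a high-precision weight. The sketch below shows how a consumer such as a measurement hook might use it; `maybe_dequant` is a hypothetical helper name, not part of the patch.

```python
import torch

def maybe_dequant(module: torch.nn.Module) -> torch.Tensor:
    # Editor's sketch: discover and use the optional dequant interface.
    getter = getattr(module, "get_dequant_weights_func", None)
    if getter is not None:
        dequant_fn = getter()
        if dequant_fn is not None:
            return dequant_fn(module)   # high-precision (e.g. BF16) weight
    return module.weight                # fall back to the stored (possibly FP8) weight
```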
if self.block_quant: @@ -498,6 +512,7 @@ def __init__(self, quant_config: Fp8Config): def create_weights(self, layer: Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): + layer.quant_config = self.quant_config if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn if self.block_quant: @@ -635,6 +650,8 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, def process_weights_after_loading(self, layer: Module) -> None: # TODO (rob): refactor block quant into separate class. + # if torch.distributed.get_rank() == 0: + # import pdb; pdb.set_trace() if self.block_quant: if current_platform.is_hpu(): if self.quant_config.enable_runtime_dequant: @@ -857,8 +874,10 @@ def forward_hpu( ep_rank=0, ): batch_size, seq_len, hidden_dim = x.shape - num_experts = layer.w13_weight.shape[0] - n_expert_slice = layer.w13_weight.shape[0] // self.moe_n_slice + num_experts = layer.num_experts + n_expert_slice = num_experts // self.moe_n_slice + # num_experts = layer.w13_weight.shape[0] + # n_expert_slice = layer.w13_weight.shape[0] // self.moe_n_slice assert n_expert_slice * self.moe_n_slice == num_experts x = x.view(-1, hidden_dim) total_num_experts = router_logits.size(-1) @@ -1054,14 +1073,28 @@ def do_dynamic_moe_with_dequant(x, topk_ids, topk_weights, w13_weight_fp8, w2_we moe_n_slice = 4 if actual_total_experts >= 64 else 1 n_expert_slice = actual_total_experts // moe_n_slice else: - w13_weight_fp8 = layer.w13_weight.data - w13_weight_scale_inv_fp8 = layer.w13_weight_scale_inv.data - w2_weight_fp8 = layer.w2_weight.data - w2_weight_scale_inv_fp8 = layer.w2_weight_scale_inv.data actual_total_experts = total_num_experts actual_num_experts = num_experts moe_n_slice = self.moe_n_slice n_expert_slice = actual_num_experts // moe_n_slice + if self.quant_config.enable_runtime_dequant and VLLM_REQUANT_FP8_INC: + assert not use_partial_experts, "Partial experts not supported with VLLM_REQUANT_FP8_INC" + # FIXME: (Yi) handle the case where moe_n_slice > 1 + final_hidden_states: torch.Tensor = torch.zeros_like(x) + for moe in layer.moe_lst: + final_hidden_states += moe( + x, + topk_ids, + topk_weights, + moe_n_slice, + n_expert_slice, + ep_shift, + ) + return final_hidden_states.view(-1, x.shape[1]) + w13_weight_fp8 = layer.w13_weight.data + w13_weight_scale_inv_fp8 = layer.w13_weight_scale_inv.data + w2_weight_fp8 = layer.w2_weight.data + w2_weight_scale_inv_fp8 = layer.w2_weight_scale_inv.data if self.quant_config.activation_scheme == "dynamic": if self.quant_config.enable_runtime_dequant: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 84f696fb8f42..3aa913b5f3de 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -416,6 +416,18 @@ def download_model(self, model_config: ModelConfig) -> None: fall_back_to_pt=True, allow_patterns_overrides=None) + def _need_patch_inc_fp8_kvcache(self, vllm_config: VllmConfig) -> bool: + user_pass_inc_as_quantization = ( + vllm_config.quant_config is not None + and "inc" in vllm_config.quant_config.get_name().lower() + ) + force_use_inc = os.environ.get("VLLM_REQUANT_FP8_INC", "0") == "1" + user_pass_inc_as_quantization = user_pass_inc_as_quantization or force_use_inc + return ( + vllm_config.cache_config.cache_dtype == "fp8_inc" + and not user_pass_inc_as_quantization + ) + def load_model(self, vllm_config: VllmConfig) -> nn.Module: 
device_config = vllm_config.device_config load_config = vllm_config.load_config @@ -443,7 +455,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: logger.warning(warning_msg) _process_weights_after_loading(model, model_config, target_device) - if vllm_config.cache_config.cache_dtype == "fp8_inc": + if self._need_patch_inc_fp8_kvcache(vllm_config): from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import PatchedVLLMKVCache from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import Fp8cfg from neural_compressor.torch.algorithms.fp8_quant.model_configs import ModuleExtraConfig, ModuleConfig diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index d39810d01475..1068851f5614 100755 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -760,6 +760,28 @@ def _set_gc_threshold(self) -> None: self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', 'false').lower() == 'true' + def _remove_duplicate_submodules_(self, model, inc_config): + # FIXME: (Yi) for deepseek v3 only + self_attn = model.model.layers[0].self_attn + for layer in model.model.layers: + self_attn = layer.self_attn + # delete attrs: q_b_proj, kv_b_proj, o_proj in self_attn + if hasattr(self_attn, "q_b_proj"): + delattr(self_attn, "q_b_proj") + if hasattr(self_attn, "kv_b_proj"): + delattr(self_attn, "kv_b_proj") + if hasattr(self_attn, "o_proj"): + delattr(self_attn, "o_proj") + + def _inc_preprocess_(self, model: torch.nn.Module, inc_config): + self._remove_duplicate_submodules_(model, inc_config) + + def _is_quant_with_inc(self): + return ( + self.model_config.quantization == "inc" + or os.getenv("VLLM_REQUANT_FP8_INC", "0") in ["1", "true"] + ) + def load_model(self) -> None: import habana_frameworks.torch.core as htcore if self.model_config.quantization == 'inc' or \ @@ -811,19 +833,34 @@ def load_model(self) -> None: ) self.model = self.lora_manager.create_lora_manager(self.model) - if self.model_config.quantization == 'inc': + if self._is_quant_with_inc(): logger.info("Preparing model with INC..") + if torch.distributed.get_rank() == 0: + logger.info(f"Original model \n {self.model}") with HabanaMemoryProfiler() as m_inc: from neural_compressor.torch.quantization import ( FP8Config, convert, prepare) config = FP8Config.from_json_file( os.getenv("QUANT_CONFIG", "")) + + self._inc_preprocess_(self.model, config) if config.measure: self.model = prepare(self.model, config) + # if torch.distributed.get_rank() == 0: + # import pdb;pdb.set_trace() + elif config.quantize: self.model = convert(self.model, config) - htcore.hpu_initialize(self.model, - mark_only_scales_as_const=True) + # if torch.distributed.get_rank() == 0: + # import pdb;pdb.set_trace() + + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + logger.info(f"INC model \n {self.model}") + htcore.hpu_initialize( + self.model, mark_only_scales_as_const=True + ) + self.inc_initialized_successfully = True logger.info("Preparing model with INC took %s", m_inc.get_summary_string()) @@ -1793,7 +1830,9 @@ def warmup_scenario(self, if is_pt_profiler_run and self.is_driver_worker: profiler = setup_profiler() profiler.start() - for _ in range(times): + logger.debug(f"Running warmup scenario: {scenario_name}") + for index in range(times): + logger.debug(f"Running warmup iteration: {index}/{times}") inputs = self.prepare_model_input(seqs) is_single_step = \ self.vllm_config.scheduler_config.num_scheduler_steps == 1 @@ -2107,7 +2146,7 @@ def 
finish_measurements(self): finalize_calibration(self.model.model) def shutdown_inc(self): - can_finalize_inc = (self.model_config.quantization == 'inc') and \ + can_finalize_inc = self._is_quant_with_inc() and \ (self.model.model is not None) and \ self.inc_initialized_successfully and \ not getattr(self, "_is_inc_finalized", False) diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py index 06d024424955..34bdbc666340 100755 --- a/vllm/worker/hpu_worker.py +++ b/vllm/worker/hpu_worker.py @@ -318,11 +318,13 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_fake_hpu_blocks = fake_hpu_cache_alloc // cache_block_size self.model_runner.bucketing_ctx.num_hpu_blocks = num_fake_hpu_blocks return num_fake_hpu_blocks, 0 + start_time = time.monotonic() with HabanaMemoryProfiler() as m: self.model_runner.profile_run() torch.hpu.synchronize() + profiling_time = time.monotonic() - start_time msg = ("Model profiling run " - f"took {m.get_summary_string()}") + f"took {m.get_summary_string()}, took time {profiling_time:.2f}s") logger.info(msg) # At this point we should've allocated the maximum workspace for all # recipes we will use the extra memory for graphs/blocks
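Taken together, the changes above gate INC handling on a few conditions spread across `loader.py` and `hpu_model_runner.py`. The following is a condensed, slightly simplified editorial restatement of that gating; the helper names are hypothetical and only the checks visible in this diff are assumed.

```python
import os
from typing import Optional

def is_quant_with_inc(quantization: Optional[str]) -> bool:
    # Mirrors hpu_model_runner._is_quant_with_inc: INC is enabled either by
    # quantization="inc" or by the VLLM_REQUANT_FP8_INC environment variable.
    return (
        quantization == "inc"
        or os.getenv("VLLM_REQUANT_FP8_INC", "0") in ("1", "true")
    )

def need_patch_inc_fp8_kvcache(cache_dtype: str, quantization: Optional[str]) -> bool:
    # Mirrors loader._need_patch_inc_fp8_kvcache: patch the KV cache for fp8_inc
    # only when INC is NOT already driving quantization.
    return cache_dtype == "fp8_inc" and not is_quant_with_inc(quantization)

# Example: with VLLM_REQUANT_FP8_INC=1 the loader-side KV-cache patch is skipped,
# because INC itself prepares/converts the model in hpu_model_runner.load_model.
os.environ["VLLM_REQUANT_FP8_INC"] = "1"
assert is_quant_with_inc(None)
assert not need_patch_inc_fp8_kvcache("fp8_inc", None)
```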