From c6fdc2c286b46711170354b64a40cde5614000df Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Fri, 22 Aug 2025 11:09:27 +0000
Subject: [PATCH] fix moe example

Signed-off-by: yiliu30
---
 torchao/prototype/moe_quant/llama4_quant.py            | 5 ++++-
 torchao/prototype/moe_quant/quantizable_moe_modules.py | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchao/prototype/moe_quant/llama4_quant.py b/torchao/prototype/moe_quant/llama4_quant.py
index 36e684d47d..b87e1d2f5a 100644
--- a/torchao/prototype/moe_quant/llama4_quant.py
+++ b/torchao/prototype/moe_quant/llama4_quant.py
@@ -58,7 +58,10 @@ def convert_fn(module):
 
 
 model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
-model = Llama4ForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+
+dtype = torch.bfloat16
+torch.set_default_dtype(dtype)
+model = Llama4ForCausalLM.from_pretrained(model_id, torch_dtype=dtype)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 _replace_with_custom_fn_if_matches_filter(
diff --git a/torchao/prototype/moe_quant/quantizable_moe_modules.py b/torchao/prototype/moe_quant/quantizable_moe_modules.py
index d806f50b4f..553525ce44 100644
--- a/torchao/prototype/moe_quant/quantizable_moe_modules.py
+++ b/torchao/prototype/moe_quant/quantizable_moe_modules.py
@@ -30,7 +30,7 @@ def __init__(
     def forward(self, x: Tensor) -> Tensor:
         batch_size = x.shape[0]
         x = x.view(-1, self.hidden_dim)  # x: [T, D]
-        scores = self.router(x)  # [T, E]
+        scores = self.router(x)[0]  # [T, E]
         scores = F.softmax(scores, dim=-1)
         scores, expert_indices = torch.topk(
             scores, self.top_k, dim=-1
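
Note (not part of the patch, added for context): the [0] in self.router(x)[0] implies that the router module copied over by convert_fn returns a tuple whose first element is the per-token score tensor; that this reflects a change in the upstream transformers Llama4 router is an assumption here, not something the patch states. Likewise, torch.set_default_dtype(dtype) makes any modules the example constructs after loading (for instance, the quantizable MoE replacements) default to bfloat16. Below is a minimal, self-contained sketch of the router interface the patched forward assumes; TupleRouter is a hypothetical stand-in, not the transformers implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

# As in the patched example: tensors and module parameters created
# after this call default to bfloat16.
torch.set_default_dtype(torch.bfloat16)

class TupleRouter(nn.Module):
    """Hypothetical stand-in for a router whose forward returns a tuple
    with the score tensor first, matching what self.router(x)[0] expects."""

    def __init__(self, hidden_dim: int, num_experts: int):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        logits = self.gate(x)  # [T, E]
        return logits, logits  # second element is a placeholder extra output

router = TupleRouter(hidden_dim=16, num_experts=4)
x = torch.randn(8, 16)  # [T, D], created in bf16 via the default dtype
scores = router(x)[0]   # [T, E]: the tuple indexing the patch adds
scores = F.softmax(scores, dim=-1)
top_scores, expert_indices = torch.topk(scores, k=2, dim=-1)
print(top_scores.shape, expert_indices.shape)  # torch.Size([8, 2]) torch.Size([8, 2])

Without the [0], the softmax and topk calls in the quantizable module's forward would receive a tuple instead of a tensor and raise a TypeError, which is consistent with the "fix moe example" subject line.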