135 changes: 135 additions & 0 deletions convert_hf_to_gguf.py
@@ -9033,6 +9033,141 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))


@ModelBase.register("MegrezMoeForCausalLM", "MegrezMoEForCausalLM")
class MegrezMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.MEGREZ_MOE

def set_vocab(self):
# Megrez-MoE uses Qwen-style BPE tokenizer
# Use standard GPT2 vocab loading which handles BPE correctly
try:
self._set_vocab_gpt2()
except Exception:
# Fallback to Qwen-specific handling if needed
self._set_vocab_qwen()
# Note: special_vocab.add_to_gguf() is already called within
# _set_vocab_gpt2() and _set_vocab_qwen(), so no need to call it again

def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams

# MoE expert configuration
num_experts = hparams.get("num_experts") or hparams.get("n_routed_experts")
if num_experts is None:
raise ValueError("Missing 'num_experts' or 'n_routed_experts' in model config")
self.gguf_writer.add_expert_count(num_experts)

# Shared expert FFN size
hidden_size = hparams.get("hidden_size", 2048)
shared_expert_ffn_size = int(hidden_size * 2.75)
self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_ffn_size)
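# Illustrative arithmetic only: with the default hidden_size of 2048 this
# writes int(2048 * 2.75) = 5632. The 2.75 ratio is fixed here rather than
# read from the model config.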

# Per-expert FFN size (should be consistent across all experts)
moe_intermediate_size = hparams.get("moe_intermediate_size")
if moe_intermediate_size is None:
raise ValueError("Missing 'moe_intermediate_size' in model config")
if not isinstance(moe_intermediate_size, list):
moe_intermediate_size = [moe_intermediate_size]

# Validate all experts have same size
if not all(n == moe_intermediate_size[0] for n in moe_intermediate_size):
raise ValueError(f"All experts must have same FFN size, got: {moe_intermediate_size}")
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size[0])

# Shared expert count
num_shared_expert = hparams.get("num_shared_expert") or hparams.get("n_shared_experts")
if num_shared_expert is None:
raise ValueError("Missing 'num_shared_expert' or 'n_shared_experts' in model config")
if not isinstance(num_shared_expert, list):
num_shared_expert = [num_shared_expert]

if not all(n == num_shared_expert[0] for n in num_shared_expert):
raise ValueError(f"All layers must have same shared expert count, got: {num_shared_expert}")
self.gguf_writer.add_expert_shared_count(num_shared_expert[0])

# RoPE scaling (Megrez may use dynamic scaling)
rope_scaling = hparams.get("rope_scaling")
if rope_scaling and rope_scaling.get("type") == "dynamic":
alpha = rope_scaling.get("alpha", 1000)
base = hparams.get("rope_theta", 10000.0)
hidden_size = hparams.get("hidden_size")
num_attention_heads = hparams.get("num_attention_heads")

if hidden_size is None or num_attention_heads is None:
raise ValueError("Missing 'hidden_size' or 'num_attention_heads' for RoPE scaling")

dim = hidden_size // num_attention_heads
scaled_base = base * (alpha ** (dim / (dim - 2)))

self.gguf_writer.add_rope_freq_base(scaled_base)
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_rope_scaling_factor(1)
self.gguf_writer.add_rope_scaling_orig_ctx_len(256 * 1024)
self.gguf_writer.add_context_length(256 * 1024)
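# Worked example of the dynamic-NTK adjustment above (illustrative values,
# not taken from an actual Megrez-MoE config):
#   hidden_size = 2048, num_attention_heads = 16 -> dim = 128
#   base = 10000.0, alpha = 1000 (the default above)
#   scaled_base = 10000.0 * 1000 ** (128 / 126) ~= 1.1e7
# The alpha scaling is folded into rope_freq_base at conversion time, which
# is why rope_scaling_type is written as NONE with factor 1.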

_experts: list[dict[str, Tensor]] | None = None

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name == "lm_head.weight":
if self.hparams.get("tie_word_embeddings", False):
logger.info("Skipping tied output layer 'lm_head.weight'")
return []

# Handle MoE gate bias (e_score_correction_bias) - map to exp_probs_b
if "e_score_correction_bias" in name:
# This is the expert selection bias - map to blk.N.exp_probs_b
# Format: model.layers.N.mlp.gate.e_score_correction_bias -> blk.N.exp_probs_b
layer_num = int(name.split(".")[2]) # Extract layer number
new_name = f"blk.{layer_num}.exp_probs_b"
return [(new_name, data_torch)]

# Handle shared FFN (non-expert layers) - pass through directly
if name.find("mlp.down_proj") != -1 or name.find("mlp.gate_proj") != -1 or name.find("mlp.up_proj") != -1:
if name.find("mlp.experts") == -1:
# This is a shared FFN layer, not an expert - pass through
return [(self.map_tensor_name(name), data_torch)]

if name.find("mlp.experts") != -1:
n_experts = self.hparams.get("num_experts") or self.hparams.get("n_routed_experts")
if n_experts is None:
raise ValueError("Missing 'num_experts' or 'n_routed_experts' in config")
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
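# All three projections (gate/down/up) for every routed expert in this
# block have now been buffered, so merge them into stacked 3D tensors
# that map to blk.{bid}.ffn_{gate,down,up}_exps.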
tensors: list[tuple[str, Tensor]] = []
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))

return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
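For reference, a conversion run with the new MegrezMoEModel class would typically look like the line below. This is a sketch: the model path and output file name are placeholders, and only the standard --outfile/--outtype flags of convert_hf_to_gguf.py are assumed.

python convert_hf_to_gguf.py /path/to/Megrez-MoE-HF --outfile megrez-moe-f16.gguf --outtype f16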
25 changes: 25 additions & 0 deletions gguf-py/gguf/constants.py
@@ -427,6 +427,7 @@ class MODEL_ARCH(IntEnum):
COGVLM = auto()
MINIMAXM2 = auto()
PANGU_EMBED = auto()
MEGREZ_MOE = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -795,6 +796,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.MINIMAXM2: "minimax-m2",
MODEL_ARCH.COGVLM: "cogvlm",
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MEGREZ_MOE: "megrez-moe",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -1549,6 +1551,29 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.MEGREZ_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.QWEN3VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
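A minimal sanity check for the new gguf-py registrations, assuming this branch's gguf-py package is importable; the assertions only restate the entries added in this diff:

import gguf

# the architecture enum and its canonical name should both resolve
assert gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.MEGREZ_MOE] == "megrez-moe"
# the per-arch tensor list should include the expert-bias tensor the converter emits
assert gguf.MODEL_TENSOR.FFN_EXP_PROBS_B in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.MEGREZ_MOE]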
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -89,6 +89,7 @@ add_library(llama
models/mamba.cpp
models/minicpm3.cpp
models/minimax-m2.cpp
models/megrez-moe.cpp
models/mpt.cpp
models/nemotron-h.cpp
models/nemotron.cpp
26 changes: 26 additions & 0 deletions src/llama-arch.cpp
@@ -108,6 +108,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
{ LLM_ARCH_COGVLM, "cogvlm" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MEGREZ_MOE, "megrez-moe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -2378,6 +2379,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_MEGREZ_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
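// expert-selection bias; the converter maps mlp.gate.e_score_correction_bias to this name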
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_PANGU_EMBED,
{
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -112,6 +112,7 @@ enum llm_arch {
LLM_ARCH_MINIMAX_M2,
LLM_ARCH_COGVLM,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MEGREZ_MOE,
LLM_ARCH_UNKNOWN,
};

5 changes: 4 additions & 1 deletion src/llama-context.cpp
@@ -1386,7 +1386,10 @@ void llama_context::output_reorder() {
//

uint32_t llama_context::graph_max_nodes() const {
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
// Megrez-MoE creates many intermediate tensors (~35 per MoE layer)
// Use higher factor (9u) instead of base 8u to account for this overhead
uint32_t factor = (model.arch == LLM_ARCH_MEGREZ_MOE) ? 9u : 8u;
return std::max<uint32_t>(1024u, factor * model.n_tensors());
}

llm_graph_result * llama_context::get_gf_res_reserve() const {
93 changes: 87 additions & 6 deletions src/llama-model.cpp
@@ -2174,12 +2174,26 @@ void llama_model::load_hparams(llama_model_loader & ml) {
case LLM_ARCH_PANGU_EMBED:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1
case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_MEGREZ_MOE:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
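// the expert FFN-length and shared-expert keys mirror the metadata written by
// MegrezMoEModel.set_gguf_parameters() during conversion; the gating-function key is optional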

switch (hparams.n_layer) {
case 31: type = LLM_TYPE_7B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}

@@ -2631,9 +2645,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;

layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_gate_exps =
create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
layer.ffn_down_exps =
create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
layer.ffn_up_exps =
create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
}
} break;
case LLM_ARCH_LLAMA4:
@@ -3319,9 +3336,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// MoE branch
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;

layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_gate_exps =
create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
layer.ffn_down_exps =
create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
layer.ffn_up_exps =
create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);

// Shared expert branch
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
@@ -3332,6 +3352,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0);
}
} break;
case LLM_ARCH_MEGREZ_MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

if (i == 0) {
// dense FFN
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
} else {
// MoE FFN
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, i), {n_expert}, 0);

if (n_expert == 0) {
throw std::runtime_error("n_expert must be > 0 for MEGREZ_MOE");
}
if (n_expert_used == 0) {
throw std::runtime_error("n_expert_used must be > 0 for MEGREZ_MOE");
}

// Shared expert branch
const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;

layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0);

// Routed expert branch
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;

layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i),
{ n_embd, n_ff_exp, n_expert }, TENSOR_NOT_REQUIRED);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i),
{ n_ff_exp, n_embd, n_expert }, TENSOR_NOT_REQUIRED);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i),
{ n_embd, n_ff_exp, n_expert }, TENSOR_NOT_REQUIRED);
}
}
} break;
case LLM_ARCH_QWEN3:
case LLM_ARCH_QWEN3VL:
{
@@ -7165,6 +7241,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_jais>(*this, params);
} break;
case LLM_ARCH_MEGREZ_MOE:
{
llm = std::make_unique<llm_build_megrez_moe>(*this, params);
} break;
case LLM_ARCH_NEMOTRON:
{
llm = std::make_unique<llm_build_nemotron>(*this, params);
@@ -7509,6 +7589,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GPTNEOX:
case LLM_ARCH_CODESHELL:
case LLM_ARCH_ORION:
case LLM_ARCH_MEGREZ_MOE:
case LLM_ARCH_NEMOTRON:
case LLM_ARCH_EXAONE:
case LLM_ARCH_EXAONE4: