2 changes: 1 addition & 1 deletion common/arg.cpp
@@ -1704,7 +1704,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.system_prompt = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_DIFFUSION}));
add_opt(common_arg(
{"--no-perf"},
string_format("disable internal libllama performance timings (default: %s)", params.no_perf ? "true" : "false"),
78 changes: 78 additions & 0 deletions convert_hf_to_gguf.py
@@ -888,6 +888,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
# ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
res = "llada-moe"

if res is None:
logger.warning("\n")
@@ -8239,6 +8242,81 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("LLaDAMoEModel", "LLaDAMoEModelLM")
class LLaDAMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLADA_MOE

def set_gguf_parameters(self):
super().set_gguf_parameters()
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)

if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)

# number of experts used per token (top-k)
if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)

        # default mask token id for LLaDA-MoE, used when the config does not provide one
        mask_token_id = 156895
        if self.hparams.get("mask_token_id") is not None:
            mask_token_id = self.hparams["mask_token_id"]

self.gguf_writer.add_mask_token_id(mask_token_id)
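        # diffusion LM: attention is bidirectional and logits are not shifted relative to input positions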
self.gguf_writer.add_causal_attention(False)
self.gguf_writer.add_diffusion_shift_logits(False)

_experts: list[dict[str, Tensor]] | None = None

# Copied from: Qwen2MoeModel
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

            # each expert contributes three tensors (down/gate/up), so the layer is complete at n_experts * 3
            if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []

# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []

for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)

merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"

new_name = self.map_tensor_name(merged_name)

tensors.append((new_name, data_torch))
return tensors
else:
return []

return [(self.map_tensor_name(name), data_torch)]

# Copied from: Qwen2MoeModel
def prepare_tensors(self):
super().prepare_tensors()

if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")



@ModelBase.register("HunYuanDenseV1ForCausalLM")
class HunYuanModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -139,6 +139,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
{"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
{"name": "mellum", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
{"name": "llada-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base", },
]

# some models are known to be broken upstream, so we will skip them as exceptions
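
# Illustration only: roughly how the "chkhsh" values registered above are
# computed (see get_vocab_base_pre in convert_hf_to_gguf.py; the real probe
# text is much longer, and the tokenizer repo must be accessible).
from hashlib import sha256
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("inclusionAI/LLaDA-MoE-7B-A1B-Base")
chktok = tok.encode("Hello World!")  # placeholder probe text
chkhsh = sha256(str(chktok).encode()).hexdigest()
print(chkhsh)  # compared against the table of known pre-tokenizers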
24 changes: 17 additions & 7 deletions examples/diffusion/diffusion-cli.cpp
@@ -510,19 +510,27 @@ static void diffusion_generate(llama_context * ctx,
n_generated = params.max_length;
}

static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
if (!use_chat_template) {
return prompt;
}

auto chat_templates = common_chat_templates_init(model, "");

common_chat_templates_inputs inputs;
common_chat_msg user_msg;
user_msg.role = "user";
user_msg.content = prompt;
inputs.add_generation_prompt = true;
    common_chat_msg system_msg;

    // an optional system message is placed ahead of the user message
    if (!system_prompt.empty()) {
        system_msg.role    = "system";
        system_msg.content = system_prompt;
        inputs.messages.push_back(system_msg);
    }

common_chat_msg user_msg;
user_msg.role = "user";
user_msg.content = prompt;

inputs.messages.push_back(user_msg);
inputs.add_generation_prompt = true;

auto result = common_chat_templates_apply(chat_templates.get(), inputs);

@@ -579,7 +587,8 @@ int main(int argc, char ** argv) {
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);

const llama_vocab * vocab = llama_model_get_vocab(model);
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);

std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);

std::vector<llama_token> input_tokens = common_tokenize(vocab,
formatted_prompt,
@@ -596,6 +605,7 @@
}

llama_token mask_token_id = llama_vocab_mask(vocab);
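    // diffusion decoding denoises masked positions, so the vocab must define a mask token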

GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);

bool visual_mode = params.diffusion.visual_mode;
20 changes: 19 additions & 1 deletion gguf-py/gguf/constants.py
@@ -399,6 +399,7 @@ class MODEL_ARCH(IntEnum):
DREAM = auto()
SMALLTHINKER = auto()
LLADA = auto()
LLADA_MOE = auto()
SEED_OSS = auto()


@@ -735,6 +736,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DREAM: "dream",
MODEL_ARCH.SMALLTHINKER: "smallthinker",
MODEL_ARCH.LLADA: "llada",
MODEL_ARCH.LLADA_MOE: "llada-moe",
MODEL_ARCH.SEED_OSS: "seed_oss",
}

@@ -2693,7 +2695,23 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
MODEL_ARCH.LLADA_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
]
}

# tensors that will not be serialized
22 changes: 22 additions & 0 deletions src/llama-arch.cpp
@@ -96,6 +96,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_LLADA, "llada" },
{ LLM_ARCH_LLADA_MOE, "llada-moe" },
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -2147,6 +2148,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_LLADA_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_SEED_OSS,
{
@@ -2427,6 +2448,7 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
switch (arch) {
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
case LLM_ARCH_LLADA_MOE:
return true;
default:
return false;
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -100,6 +100,7 @@ enum llm_arch {
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_LLADA,
LLM_ARCH_LLADA_MOE,
LLM_ARCH_SEED_OSS,
LLM_ARCH_UNKNOWN,
};