Commit 8f8f227

convert : add Llama4ForCausalLM (#16042)
* convert : add Llama4ForCausalLM
* handle swa
* half working version
* fix use_kq_norm
* fix use_kq_norm
1 parent c959b67 commit 8f8f227

File tree

4 files changed: +50 / -12 lines changed

convert_hf_to_gguf.py

Lines changed: 8 additions & 1 deletion
@@ -2393,7 +2393,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return []  # skip other tensors


-@ModelBase.register("Llama4ForConditionalGeneration")
+@ModelBase.register(
+    "Llama4ForConditionalGeneration",
+    "Llama4ForCausalLM",
+)
 class Llama4Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA4
     undo_permute = False
@@ -2411,6 +2414,10 @@ def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"])
         self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"])
+        if "layer_types" in self.hparams:
+            if all(lt == "full_attention" for lt in self.hparams["layer_types"]):
+                # all layers are full attention (for MobileLLM), disable swa
+                self.gguf_writer.add_sliding_window(0)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
         if name.startswith("language_model."):
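To make the converter change concrete, here is a minimal standalone sketch (not part of the commit) of the decision it encodes: if the HF config lists layer_types and every entry is "full_attention", the converter records sliding_window = 0 in the GGUF metadata. The config values and the swa_should_be_disabled helper below are illustrative assumptions.

# Sketch only: mirrors the converter's check for disabling SWA when every
# layer uses full attention (MobileLLM-style Llama4ForCausalLM configs).
# The config dict below is made up for illustration, not taken from a real model.
hypothetical_config = {
    "interleave_moe_layer_step": 1,
    "intermediate_size_moe": 2048,
    "layer_types": ["full_attention"] * 30,  # no chunked-attention layers
}

def swa_should_be_disabled(hparams: dict) -> bool:
    layer_types = hparams.get("layer_types")
    return layer_types is not None and all(lt == "full_attention" for lt in layer_types)

if swa_should_be_disabled(hypothetical_config):
    print("converter would call self.gguf_writer.add_sliding_window(0)")
else:
    print("converter keeps the default chunked-SWA metadata")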

src/llama-hparams.h

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ struct llama_hparams {
     bool causal_attn = true;
     bool use_alibi = false;
     bool attn_soft_cap = false;
-    bool use_kq_norm = true;
+    bool use_kq_norm = false;

     // for Classifiers
     uint32_t n_cls_out = 1;
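Worth noting alongside this default flip: with use_kq_norm now false by default, architectures must opt in explicitly, and the Llama 4 loader does so per model type (see the llama-model.cpp hunk below). A minimal Python sketch of that resolution, using string stand-ins for the C++ enums (an illustration, not llama.cpp code):

# Sketch only: "LLAMA4" / "17B_128E" stand in for the LLM_ARCH_* / LLM_TYPE_* enums.
def resolve_use_kq_norm(arch: str, llm_type: str) -> bool:
    if arch == "LLAMA4":
        # mirrors: hparams.use_kq_norm = type != LLM_TYPE_17B_128E;
        return llm_type != "17B_128E"
    return False  # new default in llama_hparams (other archs may still enable it in their own loaders)

print(resolve_use_kq_norm("LLAMA4", "17B_16E"))   # True:  Scout keeps KQ norm
print(resolve_use_kq_norm("LLAMA4", "17B_128E"))  # False: Maverick disables it
print(resolve_use_kq_norm("LLAMA4", "140M"))      # True:  MobileLLM-sized causal LM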

src/llama-model.cpp

Lines changed: 38 additions & 10 deletions
@@ -36,6 +36,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_80M: return "80M";
         case LLM_TYPE_109M: return "109M";
         case LLM_TYPE_137M: return "137M";
+        case LLM_TYPE_140M: return "140M";
         case LLM_TYPE_160M: return "160M";
         case LLM_TYPE_190M: return "190M";
         case LLM_TYPE_220M: return "220M";
@@ -44,13 +45,15 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_270M: return "270M";
         case LLM_TYPE_335M: return "335M";
         case LLM_TYPE_350M: return "350M";
+        case LLM_TYPE_360M: return "360M";
         case LLM_TYPE_410M: return "410M";
         case LLM_TYPE_450M: return "450M";
         case LLM_TYPE_475M: return "475M";
         case LLM_TYPE_558M: return "558M";
         case LLM_TYPE_700M: return "700M";
         case LLM_TYPE_770M: return "770M";
         case LLM_TYPE_780M: return "780M";
+        case LLM_TYPE_950M: return "950M";
         case LLM_TYPE_0_3B: return "0.3B";
         case LLM_TYPE_0_5B: return "0.5B";
         case LLM_TYPE_0_6B: return "0.6B";
@@ -622,19 +625,32 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step);

-                hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
-                hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
-                hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa == 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                    hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
+                    hparams.n_swa = 8192;
+                    hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
+                }

                 switch (hparams.n_expert) {
+                    case 0: {
+                        // MobileLLM (no MoE)
+                        switch (hparams.n_embd) {
+                            case 2048: type = LLM_TYPE_140M; break;
+                            case 4096: type = LLM_TYPE_360M; break;
+                            case 6144: type = LLM_TYPE_950M; break;
+                            default: type = LLM_TYPE_UNKNOWN;
+                        }
+                    } break;
                     case 16: type = LLM_TYPE_17B_16E; break;
                     case 128: type = LLM_TYPE_17B_128E; break;
                     default: type = LLM_TYPE_UNKNOWN;
                 }

-                if (type == LLM_TYPE_17B_128E) {
-                    hparams.use_kq_norm = false;
-                }
+                hparams.use_kq_norm = type != LLM_TYPE_17B_128E;
             } break;
         case LLM_ARCH_ARCEE:
             {
@@ -2454,9 +2470,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }

-                    GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Llama 4 requires n_moe_layer_step > 0");
                     for (int i = 0; i < n_layer; ++i) {
-                        bool is_moe_layer = (i + 1) % hparams.n_moe_layer_step == 0;
+                        bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0;

                         auto & layer = layers[i];

@@ -6328,6 +6343,14 @@ struct llm_build_llama : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);

+                if (hparams.use_kq_norm) {
+                    // Llama4TextL2Norm
+                    Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
+                    Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
+                    cb(Qcur, "Qcur_normed", il);
+                    cb(Kcur, "Kcur_normed", il);
+                }
+
                 cur = build_attn(inp_attn,
                         model.layers[il].wo, model.layers[il].bo,
                         Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
@@ -6435,7 +6458,8 @@ struct llm_build_llama_iswa : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;

-            const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
+            const bool use_rope = hparams.n_no_rope_layer_step > 0 &&
+                (il + 1) % hparams.n_no_rope_layer_step != 0;

             // norm
             cur = build_norm(inpL,
@@ -18981,7 +19005,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_LLAMA4:
             {
-                llm = std::make_unique<llm_build_llama_iswa>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) {
+                    llm = std::make_unique<llm_build_llama>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_llama_iswa>(*this, params);
+                }
             } break;
         case LLM_ARCH_DECI:
             {
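Putting the pieces of this file together, here is a minimal Python sketch (an illustration, not llama.cpp code) of the loader and graph-builder decisions the commit adds. The "attention.sliding_window" key is a stand-in for the arch-prefixed GGUF metadata key, and l2_norm only shows the math behind the ggml_rms_norm calls used for Llama4TextL2Norm.

import numpy as np

def plan_llama4(metadata: dict, n_layer: int) -> dict:
    # Stand-in for ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, ..., false)
    n_swa = metadata.get("attention.sliding_window")  # None if the key is absent
    if n_swa == 0:
        # converter wrote sliding_window = 0: no SWA, the plain llm_build_llama
        # graph is used, and RoPE runs on every layer (n_no_rope_layer_step = n_layer)
        return {"swa_type": "none", "graph": "llm_build_llama",
                "n_no_rope_layer_step": n_layer}
    # default Llama 4 behaviour: chunked SWA with an 8192 window,
    # pattern "3 chunked - 1 full", built with llm_build_llama_iswa
    return {"swa_type": "chunked", "n_swa": 8192, "swa_pattern": 4,
            "graph": "llm_build_llama_iswa"}

def l2_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Llama4TextL2Norm applied to Q and K when use_kq_norm is set:
    # an RMS norm without a learned weight, which is what ggml_rms_norm computes
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

print(plan_llama4({"attention.sliding_window": 0}, n_layer=30))  # MobileLLM-style causal LM
print(plan_llama4({}, n_layer=48))                               # Scout/Maverick-style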

src/llama-model.h

Lines changed: 3 additions & 0 deletions
@@ -28,6 +28,7 @@ enum llm_type {
     LLM_TYPE_80M,
     LLM_TYPE_109M,
     LLM_TYPE_137M,
+    LLM_TYPE_140M,
     LLM_TYPE_160M,
     LLM_TYPE_190M,
     LLM_TYPE_220M,
@@ -36,13 +37,15 @@ enum llm_type {
     LLM_TYPE_270M,
     LLM_TYPE_335M,
     LLM_TYPE_350M,
+    LLM_TYPE_360M,
     LLM_TYPE_410M,
     LLM_TYPE_450M,
     LLM_TYPE_475M,
     LLM_TYPE_558M,
     LLM_TYPE_700M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
+    LLM_TYPE_950M,
     LLM_TYPE_0_3B,
     LLM_TYPE_0_5B,
     LLM_TYPE_0_6B,
