Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
179f28f
Add Llama4 Text Model
Jul 6, 2025
5518c12
bugfix
Jul 14, 2025
7b733c5
bugfixes - temperature, moe, rms
Jul 21, 2025
4f05848
bugfix gen_config
Jul 21, 2025
0f7d3a1
convert_weight works; tensor parallel bug in model compilation stage
Jul 21, 2025
dc1c4c0
bugfix gen config
Jul 21, 2025
c4f780b
bugfix convert weight
Jul 21, 2025
e4aa23d
avoid low level TIR error when topk = 1
Jul 22, 2025
8943447
explicitly define vocab_size
Jul 22, 2025
e056361
bugfix weight loading
Jul 24, 2025
c70b280
bugfix moe_sum; inference runs but gibberish
Jul 28, 2025
f824c34
text config
MasterJH5574 Jul 28, 2025
b137fbf
bugfix rope
giterator Aug 11, 2025
b2fb219
bugfix qk norm
giterator Aug 11, 2025
c0b0a56
reimplement moe, tp; still need to test
giterator Aug 17, 2025
f87ebc5
fixed TP
giterator Aug 17, 2025
d95ee78
moe activations differ
giterator Aug 22, 2025
ef397cb
custom rope for llama4
giterator Sep 9, 2025
4bf8134
updated conv template - need to debug
giterator Sep 10, 2025
b29ad59
bugfix moe
giterator Sep 12, 2025
e783505
Remove dead code, prints
giterator Sep 12, 2025
b39ff28
cleanup conv template
giterator Sep 12, 2025
46bd618
remove comments
giterator Sep 13, 2025
9531dcc
added TODO
giterator Sep 13, 2025
6c06b25
format
giterator Sep 13, 2025
bb21b40
fixed var naming
giterator Sep 17, 2025
d4b34a7
fixed lint
giterator Sep 17, 2025
02a2745
corrected 3rd party
giterator Sep 17, 2025
4da684a
removed unused vars
giterator Sep 17, 2025
6469b7d
lint changes
giterator Sep 17, 2025
272e65a
black CI
giterator Sep 17, 2025
0e76219
disable too many locals
giterator Sep 17, 2025
dc9f4ca
black CI
giterator Sep 17, 2025
580ae60
remove awq
giterator Sep 17, 2025
623f83f
black
giterator Sep 17, 2025
1737772
remove awq import
giterator Sep 17, 2025
4f1818a
remove awq import
giterator Sep 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions python/mlc_llm/conversation_template/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,27 @@

from .registry import ConvTemplateRegistry

# Llama4 - same as Llama3.1 except naming has changed slightly
# The Llama-4 chat template mirrors the Llama-3.1 structure; only the names
# of the special tokens changed (<|header_start|>/<|header_end|>/<|eot|>).
_LLAMA4_TEMPLATE = Conversation(
    name="llama-4",
    # Llama-4 does not wrap the system prompt in a dedicated template.
    system_template="",
    system_message="",
    roles={
        "user": "<|header_start|>user",
        "assistant": "<|header_start|>assistant",
        "tool": "<|header_start|>ipython",
    },
    seps=["<|eot|>"],
    role_content_sep="<|header_end|>\n\n",
    role_empty_sep="<|header_end|>\n\n",
    stop_str=[],
    # "<|end_of_text|>", "<|eom|>", "<|eot|>"
    stop_token_ids=[200001, 200007, 200008],
    # "<|begin_of_text|>"
    system_prefix_token_ids=[200000],
    add_role_after_system_message=False,
)
ConvTemplateRegistry.register_conv_template(_LLAMA4_TEMPLATE)

# Llama3.1 -- same as Llama3 except stop token ids and stop str
ConvTemplateRegistry.register_conv_template(
Conversation(
Expand Down
1 change: 1 addition & 0 deletions python/mlc_llm/interface/gen_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
# FIXME: Copy RWKV tokenizer file # pylint: disable=fixme

CONV_TEMPLATES = {
"llama-4",
"llama-3",
"llama-3_1",
"chatml",
Expand Down
Empty file.
119 changes: 119 additions & 0 deletions python/mlc_llm/model/llama4/llama4_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

import numpy as np

from mlc_llm.loader import ExternMapping
from mlc_llm.quantization import Quantization

from .llama4_model import Llama4Config, Llama4ForCausalLM


def huggingface(model_config: Llama4Config, quantization: Quantization) -> ExternMapping:
    """Returns a parameter mapping that maps from the names of MLC LLM parameters to
    the names of HuggingFace PyTorch parameters.

    Parameters
    ----------
    model_config : Llama4Config
        The configuration of the Llama model.

    quantization : Quantization
        The quantization configuration.

    Returns
    -------
    param_map : ExternMapping
        The parameter mapping from MLC to HuggingFace PyTorch.
    """
    model = Llama4ForCausalLM(model_config)
    if quantization is not None:
        model.to(quantization.model_dtype)
    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]
        spec=model.get_default_spec(),
        allow_extern=True,
    )
    named_parameters = dict(_named_params)

    mapping = ExternMapping()

    def _add_identity(mlc_name: str, hf_name: str) -> None:
        # Map a single HF tensor to one MLC parameter, applying only a cast
        # to the MLC parameter's dtype. `functools.partial` binds the dtype
        # eagerly, so each mapping keeps its own dtype.
        mapping.add_mapping(
            mlc_name,
            [hf_name],
            functools.partial(
                lambda x, dtype: x.astype(dtype),
                dtype=named_parameters[mlc_name].dtype,
            ),
        )

    for i in range(model_config.text_config.num_hidden_layers):
        # Shared expert: HF stores gate_proj and up_proj separately, while the
        # MLC model fuses them into a single gate_up_proj matrix, so the two
        # HF tensors are concatenated along axis 0.
        mlp = f"model.layers.{i}.feed_forward.shared_expert"
        mlc_name = f"{mlp}.gate_up_proj.weight"
        mlc_param = named_parameters[mlc_name]
        mapping.add_mapping(
            mlc_name,
            [
                f"language_model.{mlp}.gate_proj.weight",
                f"language_model.{mlp}.up_proj.weight",
            ],
            functools.partial(
                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
                dtype=mlc_param.dtype,
            ),
        )

        moe = f"model.layers.{i}.feed_forward"
        # Router weight. NOTE: the doubled "router.router" in the MLC name is
        # intentional — it matches the module nesting in the MLC model.
        _add_identity(f"{moe}.router.router.weight", f"language_model.{moe}.router.weight")
        # Routed experts: gate_up and down projections are already fused
        # per-expert tensors in the HF checkpoint; copy them through directly.
        _add_identity(f"{moe}.experts.gate_up_proj", f"language_model.{moe}.experts.gate_up_proj")
        _add_identity(f"{moe}.experts.down_proj", f"language_model.{moe}.experts.down_proj")

    # Every remaining parameter maps 1:1 to the HF tensor of the same name
    # under the "language_model." prefix.
    for mlc_name in named_parameters:
        if mlc_name not in mapping.param_map:
            _add_identity(mlc_name, f"language_model.{mlc_name}")
    return mapping
Loading
Loading