23 changes: 13 additions & 10 deletions torchtitan/experiments/qwen3/README.md
@@ -1,26 +1,29 @@
**The Qwen3 model is still under development.**


## Available features
- Qwen3 dense model:
    - Supports FSDP/HSDP, TP, DDP.
    - Supports AC, torch.compile.
- Qwen3 MoE model:
    - Supports FSDP/HSDP, TP, DDP, EP.
    - Supports AC, torch.compile.
    - MoE models use Token Choice routing with an auxiliary-loss-free load-balancing algorithm (a minimal sketch follows below).

Other model sizes are already defined in the model args, but their `.toml` configs still need to be added and tested.
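
For intuition, here is a minimal sketch of the Token Choice routing with auxiliary-loss-free (bias-based) load balancing mentioned above. The function names and the simple sign-based bias update are illustrative assumptions, not torchtitan's implementation:

```python
import torch

def route(scores: torch.Tensor, bias: torch.Tensor, top_k: int):
    """scores: [num_tokens, num_experts] softmax router outputs; bias: [num_experts]."""
    # The bias only influences which experts each token selects (Token Choice);
    # the combine weights still come from the unbiased scores.
    _, expert_idx = (scores + bias).topk(top_k, dim=-1)
    gates = scores.gather(-1, expert_idx)
    # With route_norm=True, gates would additionally be renormalized to sum to 1.
    return expert_idx, gates

@torch.no_grad()
def update_bias(bias: torch.Tensor, expert_idx: torch.Tensor, num_experts: int, lr: float = 1e-3):
    # Auxiliary-loss-free balancing: lower the bias of overloaded experts and
    # raise it for underloaded ones, nudging future routing toward balance.
    load = torch.bincount(expert_idx.flatten(), minlength=num_experts).float()
    bias -= lr * torch.sign(load - load.mean())
    return bias
```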

## Download Qwen3 tokenizer
```python scripts/download_hf_assets.py --repo_id <hf_repo_name> --assets tokenizer```

For example, for the Qwen3 0.6B model the HF repo name is `Qwen/Qwen3-0.6B`; for the 1.7B model it is `Qwen/Qwen3-1.7B`.

## Parity with HF

A model parity test has been run, and the results indicate parity with the HF implementation.
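
A minimal sketch of what such a parity check can look like (assumptions: both models accept the same token ids, and the HF model returns logits via `.logits`; this is not the actual test harness):

```python
import torch

@torch.no_grad()
def check_parity(titan_model, hf_model, vocab_size=151936, seq_len=128):
    tokens = torch.randint(0, vocab_size, (1, seq_len))
    titan_logits = titan_model(tokens)   # [1, seq_len, vocab_size]
    hf_logits = hf_model(tokens).logits  # HF CausalLM-style output
    torch.testing.assert_close(titan_logits, hf_logits, rtol=1e-3, atol=1e-4)
```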

## To be added
- Modeling
    - Variants of dense models up to 32B
    - MoE alternatives
    - CP is currently not supported because of RoPE embedding implementation details.
    - `StateDictAdapter` support for MoE model
> **Contributor:** this is next step right?
>
> **Contributor Author:** Yes, this is on my plate.


- Testing
    - Learning rate verification: verify the learning rate and schedule with real training jobs (e.g., 3k steps), or find official references.
71 changes: 71 additions & 0 deletions torchtitan/experiments/qwen3/__init__.py
@@ -12,6 +12,7 @@
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.moe import MoEArgs
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_qwen3
@@ -104,6 +105,76 @@
hidden_dim=25600,
rope_theta=1000000,
),
# Qwen3-MoE models
"debugmodel_moe": Qwen3ModelArgs(
vocab_size=151936,
max_seq_len=4096,
head_dim=128,
dim=1024,
n_layers=28,
n_heads=16,
n_kv_heads=8,
qk_norm=True,
hidden_dim=3072,
rope_theta=1000000,
moe_enabled=True,
moe_inter_dim=768,
moe_args=MoEArgs(
num_experts=64,
num_shared_experts=0,
top_k=8,
score_func="softmax",
route_norm=True,
route_scale=1.0,
score_before_experts=False,
),
),
"30B-A3B": Qwen3ModelArgs(
vocab_size=151936,
max_seq_len=4096,
head_dim=128,
dim=5120,
n_layers=64,
n_heads=64,
n_kv_heads=8,
qk_norm=True,
hidden_dim=25600,
rope_theta=1000000,
moe_enabled=True,
moe_inter_dim=768,
moe_args=MoEArgs(
num_experts=128,
num_shared_experts=0,
top_k=8,
score_func="softmax",
route_norm=True,
route_scale=1.0,
score_before_experts=False,
),
),
"235B-A22B": Qwen3ModelArgs(
vocab_size=151936,
max_seq_len=4096,
head_dim=128,
dim=4096,
n_layers=94,
n_heads=64,
n_kv_heads=4,
qk_norm=True,
hidden_dim=12288,
rope_theta=5000000,
moe_enabled=True,
moe_inter_dim=1536,
moe_args=MoEArgs(
num_experts=128,
num_shared_experts=0, # no shared experts, double check
top_k=8, # num_experts_per_tok
score_func="softmax", # need double check
route_norm=True,
route_scale=1.0, # not needed, need double check
score_before_experts=False,
),
),
}


61 changes: 49 additions & 12 deletions torchtitan/experiments/qwen3/infra/parallelize.py
@@ -23,11 +23,12 @@
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import NoParallel, ParallelDims
from torchtitan.distributed.activation_checkpoint import apply_ac
from torchtitan.models.llama3.infra.parallelize import (
from torchtitan.experiments.llama4.infra.parallelize import (
apply_compile,
apply_ddp,
apply_fsdp,
apply_moe_ep_tp,
)
from torchtitan.models.llama3.infra.parallelize import apply_ddp
from torchtitan.tools.logging import logger


@@ -84,14 +85,29 @@ def parallelize_qwen3(
# all-gather happens in high precision.
enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

apply_tp(
apply_non_moe_tp(
model,
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
)

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
model,
tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
ep_tp_mesh=(
world_mesh["ep", "tp"]
if parallel_dims.tp_enabled
and parallel_dims.ep_enabled
and parallel_dims.etp_enabled
else None
),
etp_enabled=parallel_dims.etp_enabled,
)

if job_config.activation_checkpoint.mode != "none":
apply_ac(
model,
@@ -111,15 +127,30 @@
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
else:
dp_mesh_dim_names = ("dp_shard_cp",)
dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]

# mesh dim names over which the MoE params are sharded via FSDP/HSDP
dp_mod_ep_mesh_dim_names = []
if parallel_dims.ep_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mod_ep_mesh_dim_names.append("dp_replicate")
dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")

apply_fsdp(
model,
world_mesh[tuple(dp_mesh_dim_names)],
dp_mesh,
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
pp_enabled=parallel_dims.pp_enabled,
cpu_offload=job_config.training.enable_cpu_offload,
reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
ep_degree=parallel_dims.ep,
dp_mod_ep_mesh=(
world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
if parallel_dims.ep_enabled
else None
),
gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
)

if parallel_dims.dp_replicate_enabled:
@@ -149,7 +180,7 @@ def parallelize_qwen3(
return model


def apply_tp(
def apply_non_moe_tp(
model: nn.Module,
tp_mesh: DeviceMesh,
loss_parallel: bool,
@@ -218,15 +249,21 @@ def apply_tp(
"attention.k_norm": NoParallel(use_local_output=False),
"attention.wo": rowwise_parallel(output_layouts=Shard(1)),
"ffn_norm": SequenceParallel(),
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": colwise_parallel(),
"feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
"feed_forward.w3": colwise_parallel(),
}

if not transformer_block.moe_enabled:
layer_plan.update(
{
"feed_forward": prepare_module_input(
input_layouts=(Shard(1),),
desired_input_layouts=(Replicate(),),
),
"feed_forward.w1": colwise_parallel(),
"feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
"feed_forward.w3": colwise_parallel(),
}
)

parallelize_module(
module=transformer_block,
device_mesh=tp_mesh,
8 changes: 7 additions & 1 deletion torchtitan/experiments/qwen3/model/args.py
@@ -7,11 +7,12 @@
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from dataclasses import dataclass, field

from torch import nn

from torchtitan.config import JobConfig
from torchtitan.models.moe import MoEArgs
from torchtitan.protocols.train_spec import BaseModelArgs

from torchtitan.tools.logging import logger
@@ -40,6 +41,11 @@ class Qwen3ModelArgs(BaseModelArgs):

enable_weight_tying: bool = False

# MoE params
moe_enabled: bool = False
moe_inter_dim: int = 768
moe_args: MoEArgs = field(default_factory=MoEArgs)

def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
seq_len = job_config.training.seq_len
if seq_len > self.max_seq_len:
35 changes: 26 additions & 9 deletions torchtitan/experiments/qwen3/model/model.py
@@ -12,6 +12,7 @@
from torch import nn

from torchtitan.models.attention import build_attention
from torchtitan.models.moe import MoE
from torchtitan.protocols.train_spec import ModelProtocol

from .args import Qwen3ModelArgs
@@ -282,9 +283,18 @@ def __init__(self, layer_id: int, model_args: Qwen3ModelArgs):
self.dim = model_args.dim

self.attention = Attention(model_args)
self.feed_forward = FeedForward(
dim=model_args.dim, hidden_dim=model_args.hidden_dim
)

self.moe_enabled = model_args.moe_enabled
if self.moe_enabled:
self.moe = MoE(
model_args.moe_args,
dim=model_args.dim,
hidden_dim=model_args.moe_inter_dim,
)
else:
self.feed_forward = FeedForward(
dim=model_args.dim, hidden_dim=model_args.hidden_dim
)
self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)

@@ -309,15 +319,22 @@ def forward(
torch.Tensor: Output tensor after applying attention and feedforward layers.

"""
h = x + self.attention(self.attention_norm(x), rope_cache)
out = h + self.feed_forward(self.ffn_norm(h))
return out
x = x + self.attention(self.attention_norm(x), rope_cache)

def init_weights(self):
if self.moe_enabled:
x = x + self.moe(self.ffn_norm(x))
else:
x = x + self.feed_forward(self.ffn_norm(x))
return x

def init_weights(self, buffer_device: torch.device):
for norm in (self.attention_norm, self.ffn_norm):
norm.reset_parameters()
self.attention.init_weights(self.weight_init_std)
self.feed_forward.init_weights(self.weight_init_std)
if self.moe_enabled:
self.moe.init_weights(self.weight_init_std, buffer_device)
else:
self.feed_forward.init_weights(self.weight_init_std)


class Qwen3Model(nn.Module, ModelProtocol):
@@ -384,7 +401,7 @@ def init_weights(
nn.init.normal_(self.tok_embeddings.weight)
for layer in self.layers.values():
if layer is not None:
layer.init_weights()
layer.init_weights(buffer_device)
if self.norm is not None:
self.norm.reset_parameters()
final_out_std = self.model_args.dim**-0.5
64 changes: 64 additions & 0 deletions torchtitan/experiments/qwen3/train_configs/qwen3_moe_debug.toml
@@ -0,0 +1,64 @@
[job]
dump_folder = "./outputs"
description = "Qwen 3 MoE debug model training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 1
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "qwen3"
flavor = "debugmodel_moe"
hf_assets_path = "./tests/assets/tokenizer"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2 # lr scheduler warm up, 20% total steps

[training]
local_batch_size = 4
seq_len = 4096
max_norm = 1.0 # grad norm clipping
steps = 10
dataset = "c4_test"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
expert_parallel_degree = 1
expert_tensor_parallel_degree = 1

[checkpoint]
enable = false
folder = "checkpoint"
interval = 10
last_save_model_only = false
export_dtype = "float16"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "op" # a positive int = AC every Nth layer; "op" = AC based on the ops policy

[compile]
enable = false
components = ["model", "loss"]

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]
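
To try this config, the run can be launched with torchtitan's usual entry point, e.g. `CONFIG_FILE=./torchtitan/experiments/qwen3/train_configs/qwen3_moe_debug.toml ./run_train.sh` (assuming the repo's standard `run_train.sh` launcher).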