diff --git a/torchtitan/experiments/qwen3/README.md b/torchtitan/experiments/qwen3/README.md
index b2e052241..6fdf92340 100644
--- a/torchtitan/experiments/qwen3/README.md
+++ b/torchtitan/experiments/qwen3/README.md
@@ -1,26 +1,29 @@
 **The Qwen3 model is still under development.**
 
-#### Available features
-QWEN3 0.6B Dense model is available for:
+## Available features
+#### Supported models
+- Qwen3 dense model:
+  - Supports FSDP/HSDP, TP, DDP.
+  - Supports AC, torch.compile.
+- Qwen3 MoE model:
+  - Supports FSDP/HSDP, TP, DDP, EP.
+  - Supports AC, torch.compile.
+  - MoE models use Token Choice routing with an auxiliary-loss-free load-balancing algorithm.
 
-- FSDP/HSDP, TP, DDP, AC, compile support
 
 Other model sizes are added to the args, but toml file configs need to be added and tested.
 
-#### Download Qwen3 tokenizer
+## Download Qwen3 tokenizer
 ```python scripts/download_hf_assets.py --repo_id --assets tokenizer```
 
 eg, for Qwen3 0.6B model, the HF repo name is `Qwen/Qwen3-0.6B`. For 1.7B model, the HF repo name is `Qwen/Qwen3-1.7B`.
 
-#### Parity with HF
-Model parity test has been done and results suggest parity with HF implementation.
-
-#### To be added
+## To be added
 - Modeling
-  - Variants of Dense models up to 32B
-  - MoE alternatives
+  - CP is currently not supported because of RoPE embedding implementation details.
+  - `StateDictAdapter` support for the MoE model
 - Testing
   - Learning rate verifying: verify learning rate and schedule with real training jobs (eg, 3k stps), or find official references.
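For concreteness, the download command above with the repo id filled in (using the 0.6B example the README itself gives) would be:

```shell
python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --assets tokenizer
```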
diff --git a/torchtitan/experiments/qwen3/__init__.py b/torchtitan/experiments/qwen3/__init__.py
index b5aa870d4..68af67151 100644
--- a/torchtitan/experiments/qwen3/__init__.py
+++ b/torchtitan/experiments/qwen3/__init__.py
@@ -12,6 +12,7 @@
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
+from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
 
 from .infra.parallelize import parallelize_qwen3
@@ -104,6 +105,76 @@
         hidden_dim=25600,
         rope_theta=1000000,
     ),
+    # Qwen3-MoE models
+    "debugmodel_moe": Qwen3ModelArgs(
+        vocab_size=151936,
+        max_seq_len=4096,
+        head_dim=128,
+        dim=1024,
+        n_layers=28,
+        n_heads=16,
+        n_kv_heads=8,
+        qk_norm=True,
+        hidden_dim=3072,
+        rope_theta=1000000,
+        moe_enabled=True,
+        moe_inter_dim=768,
+        moe_args=MoEArgs(
+            num_experts=64,
+            num_shared_experts=0,
+            top_k=8,
+            score_func="softmax",
+            route_norm=True,
+            route_scale=1.0,
+            score_before_experts=False,
+        ),
+    ),
+    "30B-A3B": Qwen3ModelArgs(
+        vocab_size=151936,
+        max_seq_len=4096,
+        head_dim=128,
+        dim=2048,  # values per the HF Qwen3-30B-A3B config
+        n_layers=48,
+        n_heads=32,
+        n_kv_heads=4,
+        qk_norm=True,
+        hidden_dim=6144,
+        rope_theta=1000000,
+        moe_enabled=True,
+        moe_inter_dim=768,
+        moe_args=MoEArgs(
+            num_experts=128,
+            num_shared_experts=0,
+            top_k=8,
+            score_func="softmax",
+            route_norm=True,
+            route_scale=1.0,
+            score_before_experts=False,
+        ),
+    ),
+    "235B-A22B": Qwen3ModelArgs(
+        vocab_size=151936,
+        max_seq_len=4096,
+        head_dim=128,
+        dim=4096,
+        n_layers=94,
+        n_heads=64,
+        n_kv_heads=4,
+        qk_norm=True,
+        hidden_dim=12288,
+        rope_theta=5000000,
+        moe_enabled=True,
+        moe_inter_dim=1536,
+        moe_args=MoEArgs(
+            num_experts=128,
+            num_shared_experts=0,  # Qwen3-MoE uses no shared experts
+            top_k=8,  # num_experts_per_tok in the HF config
+            score_func="softmax",
+            route_norm=True,  # norm_topk_prob=True in the HF config
+            route_scale=1.0,
+            score_before_experts=False,
+        ),
+    ),
 }
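As a quick smoke test of the new flavor table, something like the following should work. This is a sketch: it assumes the flavor dict is exposed as `qwen3_configs` from this package, mirroring how the other torchtitan model packages name their config dicts.

```python
# Sketch (dict name assumed): inspect the new MoE flavor's routing setup.
from torchtitan.experiments.qwen3 import qwen3_configs

args = qwen3_configs["debugmodel_moe"]
assert args.moe_enabled
print(args.moe_args.num_experts, args.moe_args.top_k)  # 64 8
```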
diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py
index 494962264..8367eb445 100644
--- a/torchtitan/experiments/qwen3/infra/parallelize.py
+++ b/torchtitan/experiments/qwen3/infra/parallelize.py
@@ -23,11 +23,12 @@
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
-from torchtitan.models.llama3.infra.parallelize import (
+from torchtitan.experiments.llama4.infra.parallelize import (
     apply_compile,
-    apply_ddp,
     apply_fsdp,
+    apply_moe_ep_tp,
 )
+from torchtitan.models.llama3.infra.parallelize import apply_ddp
 from torchtitan.tools.logging import logger
 
 
@@ -84,7 +85,7 @@ def parallelize_qwen3(
         # all-gather happens in high precision.
         enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
 
-        apply_tp(
+        apply_non_moe_tp(
             model,
             world_mesh["tp"],
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
@@ -92,6 +93,21 @@
             enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
         )
 
+    if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
+        apply_moe_ep_tp(
+            model,
+            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
+            ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
+            ep_tp_mesh=(
+                world_mesh["ep", "tp"]
+                if parallel_dims.tp_enabled
+                and parallel_dims.ep_enabled
+                and parallel_dims.etp_enabled
+                else None
+            ),
+            etp_enabled=parallel_dims.etp_enabled,
+        )
+
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(
             model,
@@ -111,15 +127,30 @@ def parallelize_qwen3(
             dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
         else:
             dp_mesh_dim_names = ("dp_shard_cp",)
+        dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
+
+        # mesh dim names over which the MoE params are sharded via FSDP/HSDP
+        dp_mod_ep_mesh_dim_names = []
+        if parallel_dims.ep_enabled:
+            if parallel_dims.dp_replicate_enabled:
+                dp_mod_ep_mesh_dim_names.append("dp_replicate")
+            dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
 
         apply_fsdp(
             model,
-            world_mesh[tuple(dp_mesh_dim_names)],
+            dp_mesh,
             param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
+            ep_degree=parallel_dims.ep,
+            dp_mod_ep_mesh=(
+                world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
+                if parallel_dims.ep_enabled
+                else None
+            ),
+            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
         )
 
         if parallel_dims.dp_replicate_enabled:
@@ -149,7 +180,7 @@
     return model
 
 
-def apply_tp(
+def apply_non_moe_tp(
     model: nn.Module,
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
@@ -218,15 +249,21 @@ def apply_tp(
             "attention.k_norm": NoParallel(use_local_output=False),
             "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
-                input_layouts=(Shard(1),),
-                desired_input_layouts=(Replicate(),),
-            ),
-            "feed_forward.w1": colwise_parallel(),
-            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
-            "feed_forward.w3": colwise_parallel(),
         }
 
+        if not transformer_block.moe_enabled:
+            layer_plan.update(
+                {
+                    "feed_forward": prepare_module_input(
+                        input_layouts=(Shard(1),),
+                        desired_input_layouts=(Replicate(),),
+                    ),
+                    "feed_forward.w1": colwise_parallel(),
+                    "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
+                    "feed_forward.w3": colwise_parallel(),
+                }
+            )
+
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
diff --git a/torchtitan/experiments/qwen3/model/args.py b/torchtitan/experiments/qwen3/model/args.py
index ca0b0fce8..e197cf359 100644
--- a/torchtitan/experiments/qwen3/model/args.py
+++ b/torchtitan/experiments/qwen3/model/args.py
@@ -7,11 +7,12 @@
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 from torch import nn
 
 from torchtitan.config import JobConfig
+from torchtitan.models.moe import MoEArgs
 from torchtitan.protocols.train_spec import BaseModelArgs
 from torchtitan.tools.logging import logger
 
@@ -40,6 +41,11 @@ class Qwen3ModelArgs(BaseModelArgs):
 
     enable_weight_tying: bool = False
 
+    # MoE params
+    moe_enabled: bool = False
+    moe_inter_dim: int = 768
+    moe_args: MoEArgs = field(default_factory=MoEArgs)
+
     def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         seq_len = job_config.training.seq_len
         if seq_len > self.max_seq_len:
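One detail worth calling out in `args.py`: `moe_args` must use `field(default_factory=MoEArgs)` because `dataclasses` rejects a bare `MoEArgs()` default (an unhashable, mutable default would be shared across every `Qwen3ModelArgs` instance). A self-contained illustration with stand-in demo classes, not torchtitan's real ones:

```python
# Stand-ins illustrating the `field(default_factory=...)` pattern from args.py.
from dataclasses import dataclass, field

@dataclass
class MoEArgsDemo:
    num_experts: int = 8
    top_k: int = 2

@dataclass
class ModelArgsDemo:
    moe_enabled: bool = False
    # A bare `MoEArgsDemo()` default would raise ValueError (mutable default);
    # default_factory builds a fresh MoEArgsDemo per instance instead.
    moe_args: MoEArgsDemo = field(default_factory=MoEArgsDemo)

a, b = ModelArgsDemo(), ModelArgsDemo()
a.moe_args.top_k = 8
assert b.moe_args.top_k == 2  # no shared state between configs
```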
""" - h = x + self.attention(self.attention_norm(x), rope_cache) - out = h + self.feed_forward(self.ffn_norm(h)) - return out + x = x + self.attention(self.attention_norm(x), rope_cache) - def init_weights(self): + if self.moe_enabled: + x = x + self.moe(self.ffn_norm(x)) + else: + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + def init_weights(self, buffer_device: torch.device): for norm in (self.attention_norm, self.ffn_norm): norm.reset_parameters() self.attention.init_weights(self.weight_init_std) - self.feed_forward.init_weights(self.weight_init_std) + if self.moe_enabled: + self.moe.init_weights(self.weight_init_std, buffer_device) + else: + self.feed_forward.init_weights(self.weight_init_std) class Qwen3Model(nn.Module, ModelProtocol): @@ -384,7 +401,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: - layer.init_weights() + layer.init_weights(buffer_device) if self.norm is not None: self.norm.reset_parameters() final_out_std = self.model_args.dim**-0.5 diff --git a/torchtitan/experiments/qwen3/train_configs/qwen3_moe_debug.toml b/torchtitan/experiments/qwen3/train_configs/qwen3_moe_debug.toml new file mode 100644 index 000000000..cb3932e33 --- /dev/null +++ b/torchtitan/experiments/qwen3/train_configs/qwen3_moe_debug.toml @@ -0,0 +1,64 @@ +[job] +dump_folder = "./outputs" +description = "Qwen 3 MoE debug model training" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +save_tb_folder = "tb" + +[model] +name = "qwen3" +flavor = "debugmodel_moe" +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 3e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, 20% total steps + +[training] +local_batch_size = 4 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float16" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index d0478cc96..f55cb301b 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -43,7 +43,7 @@ num_shared_experts=2, top_k=3, score_func="softmax", - route_norm=True, + route_norm=False, score_before_experts=False, ), q_lora_rank=0, @@ -66,7 +66,7 @@ num_shared_experts=2, top_k=3, score_func="softmax", - route_norm=True, + route_norm=False, score_before_experts=False, ), q_lora_rank=0, @@ -91,7 +91,7 @@ num_shared_experts=2, top_k=6, score_func="softmax", - route_norm=True, + route_norm=False, score_before_experts=False, ), q_lora_rank=0, @@ -116,7 +116,7 @@ 
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index d0478cc96..f55cb301b 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -43,7 +43,7 @@
             num_shared_experts=2,
             top_k=3,
             score_func="softmax",
-            route_norm=True,
+            route_norm=False,
             score_before_experts=False,
         ),
         q_lora_rank=0,
@@ -66,7 +66,7 @@
             num_shared_experts=2,
             top_k=3,
             score_func="softmax",
-            route_norm=True,
+            route_norm=False,
             score_before_experts=False,
         ),
         q_lora_rank=0,
@@ -91,7 +91,7 @@
             num_shared_experts=2,
             top_k=6,
             score_func="softmax",
-            route_norm=True,
+            route_norm=False,
             score_before_experts=False,
         ),
         q_lora_rank=0,
@@ -116,7 +116,7 @@
             num_shared_experts=2,
             top_k=6,
             score_func="softmax",
-            route_norm=True,
+            route_norm=False,
             route_scale=16.0,
             score_before_experts=False,
         ),
diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index 8be14ecbf..28f59aabb 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -231,7 +231,7 @@ def forward(
             scores, k=self.top_k, dim=1
         )
 
-        if self.score_func == "sigmoid" and self.route_norm:
+        if self.route_norm:
             denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
             top_scores = top_scores / denominator
             top_scores = top_scores * self.route_scale
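The `moe.py` change widens `route_norm` to apply to softmax scoring as well (the deepseek_v3 configs flip to `route_norm=False` to preserve their previous behavior). The point of the branch: softmax is taken over all experts before top-k selection, so the kept scores sum to less than 1 and get renormalized, matching `norm_topk_prob=True` in the HF Qwen3-MoE reference. A standalone numeric check of that step, with illustrative logits only:

```python
import torch

# Softmax over all 4 experts, then keep the top 2: the kept probabilities
# no longer sum to 1 until renormalized, as in the route_norm branch above.
scores = torch.softmax(torch.tensor([2.0, 1.0, 0.5, 0.1]), dim=-1)
top_scores, _ = torch.topk(scores, k=2, dim=-1)
print(top_scores.sum())  # ~0.79, i.e. < 1
top_scores = top_scores / (top_scores.sum(dim=-1, keepdim=True) + 1e-20)
print(top_scores.sum())  # 1.0
```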