diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py
index 84216dadbd..51f955cd74 100644
--- a/torchtitan/components/quantization/mx.py
+++ b/torchtitan/components/quantization/mx.py
@@ -30,6 +30,7 @@ class MXConverter(ModelConverter):
     enabled: bool
     filter_fqns: List[str]
     mx_config: Any  # MXLinearConfig type when imported
+    token_group_alignment_size = 32
 
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         # Ensure minimum torchao versions
@@ -39,8 +40,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         )
         torchao_version = version("torchao")
 
-        # Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
-        is_nightly_build = torchao_version.startswith("0.13.0")
+        # Require latest release or nightly builds for prototype features
+        is_nightly_build = torchao_version.startswith("0.14.0")
         if not is_nightly_build:
             raise ImportError(
                 f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
@@ -52,7 +53,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ), "MXFP8 is only supported on SM100 or architectures"
 
         # TP not yet supported with torch.compile
-
         model_compile_enabled = (
            job_config.compile.enable and "model" in job_config.compile.components
         )
@@ -61,10 +61,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         ), "TP not yet supported with torch.compile for mxfp8"
 
         # For MoE training with mxfp8, token group sizes must be multiples of 32
-        if job_config.mx.moe_fqns_prototype:
-            mxfp8_block_size = 32
-            set_token_group_alignment_size_m(mxfp8_block_size)
-            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+        self.moe_fqns = job_config.mx.moe_fqns_prototype
+        if self.moe_fqns:
+            logger.info(
+                f"Setting token group alignment size to {self.token_group_alignment_size}"
+            )
+            set_token_group_alignment_size_m(self.token_group_alignment_size)
 
         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
@@ -94,6 +96,13 @@ def convert(self, model: nn.Module):
         from torchao.prototype.mx_formats.config import MXLinearConfig
         from torchao.quantization import quantize_
 
+        # MoE conversion must take place before MXLinear conversion, otherwise the MXLinear will
+        # be converted back to nn.Linear:
+        # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
+        # TODO: add warning in torchao when this happens, or find a better way to avoid this.
+        if self.moe_fqns:
+            self._convert_moe_layers(model)
+
         assert isinstance(self.config, MXLinearConfig)
         quantize_(
             model,
@@ -102,6 +111,30 @@ def convert(self, model: nn.Module):
         )
         logger.info("Swapped to MXLinear layers")
 
+    def _convert_moe_layers(self, model: nn.Module):
+        """
+        Mutates the model in-place, replacing instances of nn.Parameter with ScaledGroupedMMTensor,
+        to perform dynamic MXFP8 quantization + scaled grouped GEMMs for the target MoE FQNs.
+        """
+        from torchao.prototype.moe_training.conversion_utils import (
+            MoEScalingType,
+            MoETrainingConfig,
+        )
+        from torchao.quantization.quant_api import quantize_
+
+        def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+            for target_fqn in self.moe_fqns:
+                if target_fqn in cur_fqn:
+                    return True
+            return False
+
+        config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+        quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+        logger.info(
+            f"Converted MoE layers matching FQNs {self.moe_fqns} "
+            "to use dynamic MXFP8 quantization with scaled grouped GEMMs"
+        )
+
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
         """
         MXFP8 doesn't require any post-optimizer hooks at the moment
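
Illustrative sketch (not applied by this patch): how the substring match in moe_module_filter_fn above decides which modules get converted. It assumes moe_fqns has already been parsed into separate target strings; the module FQNs below are hypothetical examples.

    # Standalone mock of the filter logic in _convert_moe_layers (mod argument omitted).
    moe_fqns = ["experts", "shared_expert"]  # assumed, already-parsed target fragments

    def moe_module_filter_fn(cur_fqn: str) -> bool:
        # Convert a module if any target fragment appears in its fully qualified name.
        return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

    assert moe_module_filter_fn("layers.0.moe.experts")          # hypothetical FQN, matched
    assert moe_module_filter_fn("layers.0.moe.shared_experts")   # hypothetical FQN, matched
    assert not moe_module_filter_fn("layers.0.moe.router.gate")  # left unconverted

Because the test is a plain substring check, a target such as "experts" also selects shared expert modules whose FQN contains that substring.
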
+ """ + from torchao.prototype.moe_training.conversion_utils import ( + MoEScalingType, + MoETrainingConfig, + ) + from torchao.quantization.quant_api import quantize_ + + def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: + for target_fqn in self.moe_fqns: + if target_fqn in cur_fqn: + return True + return False + + config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8) + quantize_(model, config=config, filter_fn=moe_module_filter_fn) + logger.info( + f"Converted MoE layers matching FQNS {self.moe_fqns} " + "to use dynamic MXFP8 quantization with scaled grouped GEMMs" + ) + def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): """ MXFP8 doesn't require any post-optimizer hooks at the moment diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 567df051cf..9dcc3483e8 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -22,6 +22,8 @@ ) from torch.distributed.tensor.parallel import ParallelStyle +from torchtitan.tools.utils import _round_up + TOKEN_GROUP_ALIGN_SIZE_M = 8 ValidTokenGroupAlignmentSize = Literal[8, 16, 32] @@ -253,6 +255,13 @@ def wrapper( experts_per_ep_rank = w1.shape[0] num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + # Make sure max_len of permuted token indicies is divisible by TOKEN_GROUP_ALIGN_SIZE_M, + # by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M. + x_padded_per_expert = ( + x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M + ) + padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M) + with torch.no_grad(): ( permuted_indices, @@ -262,7 +271,7 @@ def wrapper( num_tokens_per_expert, experts_per_ep_rank, num_ep_ranks, - x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M, + padded_max_len, TOKEN_GROUP_ALIGN_SIZE_M, ) diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py index d177e93dc6..b0a2047e9e 100644 --- a/torchtitan/experiments/llama4/__init__.py +++ b/torchtitan/experiments/llama4/__init__.py @@ -27,15 +27,20 @@ llama4_configs = { "debugmodel": TransformerModelArgs( - dim=256, - n_layers=6, - n_heads=16, - vocab_size=2000, + dim=5120, + n_layers=2, + n_heads=40, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=2048, rope_theta=500000, + max_seq_len=10485760, + moe_args=MoEArgs(num_experts=4), + interleave_moe_layer_step=1, ), "17bx16e": TransformerModelArgs( dim=5120, - n_layers=48, + n_layers=2, n_heads=40, n_kv_heads=8, ffn_dim_multiplier=1.2, diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index 2d07c03c50..ef029adab8 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -75,3 +75,7 @@ enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false filter_fqns = ["output", "router.gate"] moe_fqns_prototype = ["experts"] + +[mx] +filter_fqns = ["output", "router.gate"] +moe_fqns_prototype = ["experts,shared_expert"] diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index 6de020fad0..a4a2d35c22 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -67,3 +67,7 @@ components = ["model", "loss"] 
diff --git a/torchtitan/experiments/llama4/__init__.py b/torchtitan/experiments/llama4/__init__.py
index d177e93dc6..b0a2047e9e 100644
--- a/torchtitan/experiments/llama4/__init__.py
+++ b/torchtitan/experiments/llama4/__init__.py
@@ -27,15 +27,20 @@
 
 llama4_configs = {
     "debugmodel": TransformerModelArgs(
-        dim=256,
-        n_layers=6,
-        n_heads=16,
-        vocab_size=2000,
+        dim=5120,
+        n_layers=2,
+        n_heads=40,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.2,
+        multiple_of=2048,
         rope_theta=500000,
+        max_seq_len=10485760,
+        moe_args=MoEArgs(num_experts=4),
+        interleave_moe_layer_step=1,
     ),
     "17bx16e": TransformerModelArgs(
         dim=5120,
-        n_layers=48,
+        n_layers=2,
         n_heads=40,
         n_kv_heads=8,
         ffn_dim_multiplier=1.2,
diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml
index 2d07c03c50..ef029adab8 100644
--- a/torchtitan/experiments/llama4/train_configs/debug_model.toml
+++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml
@@ -75,3 +75,7 @@ enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output", "router.gate"]
 moe_fqns_prototype = ["experts"]
+
+[mx]
+filter_fqns = ["output", "router.gate"]
+moe_fqns_prototype = ["experts,shared_expert"]
diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
index 6de020fad0..a4a2d35c22 100644
--- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
+++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
@@ -67,3 +67,7 @@ components = ["model", "loss"]
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output", "router.gate"]
+
+[mx]
+filter_fqns = ["output", "router.gate"]
+moe_fqns_prototype = ["experts,shared_expert"]
diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
index bc9e9bc4f5..e2856b1e8b 100644
--- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
+++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
@@ -65,3 +65,7 @@ components = ["model", "loss"]
 enable_fsdp_float8_all_gather = false
 precompute_float8_dynamic_scale_for_fsdp = false
 filter_fqns = ["output", "router.gate"]
+
+[mx]
+filter_fqns = ["output", "router.gate"]
+moe_fqns_prototype = ["experts,shared_expert"]
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index f66361a6d2..a1191c573f 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -20,8 +20,6 @@
     flex_attention,
 )
 
-from torchtitan.tools.utils import has_cuda_capability
-
 # FlexAttention mask type. For each mask type, we initialize it at most once per
 # batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to
 # track the initialized mask.
@@ -203,14 +201,11 @@ def _init_backend(cls) -> None:
         if cls.backends:
             return
 
-        # Add CuDNN on B200 w/ highest priority
         cls.backends = [
             SDPBackend.FLASH_ATTENTION,
             SDPBackend.EFFICIENT_ATTENTION,
             SDPBackend.MATH,
         ]
-        if has_cuda_capability(10, 0):
-            cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION)
 
     def forward(
         self,
diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py
index 37273a71de..fea93995ad 100644
--- a/torchtitan/tools/utils.py
+++ b/torchtitan/tools/utils.py
@@ -202,3 +202,9 @@ def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
         yield
     finally:
         torch.set_default_dtype(old_dtype)
+
+
+def _round_up(x: int, y: int) -> int:
+    """Round up x to the nearest multiple of y."""
+    x_ceil_div_y = (x + y - 1) // y
+    return x_ceil_div_y * y