47 changes: 40 additions & 7 deletions torchtitan/components/quantization/mx.py
@@ -30,6 +30,7 @@ class MXConverter(ModelConverter):
enabled: bool
filter_fqns: List[str]
mx_config: Any # MXLinearConfig type when imported
token_group_alignment_size = 32
Contributor:
Maybe make this a MACRO if it's not really a "variable". E.g. you can keep a central map in quantization/__init__.py for both fp8 and mx.

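A minimal sketch of the central-map idea floated above — the map name, the illustrative fp8 value, and its placement in quantization/__init__.py are assumptions, not part of this PR:

```python
# Hypothetical single source of truth in torchtitan/components/quantization/__init__.py,
# so the fp8 and mx converters stop hard-coding their own alignment constants.
TOKEN_GROUP_ALIGNMENT_SIZES: dict[str, int] = {
    "float8": 16,  # illustrative value only
    "mx": 32,      # mxfp8 block size, matching token_group_alignment_size above
}
```

Each converter would then look up its own entry instead of carrying a class attribute.
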
def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
# Ensure minimum torchao versions
@@ -39,8 +40,8 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
)
torchao_version = version("torchao")

# Last torchao release was 0.12.0, so nightly build starts with 0.13.0+git...
is_nightly_build = torchao_version.startswith("0.13.0")
# Require latest release or nightly builds for prototype features
is_nightly_build = torchao_version.startswith("0.14.0")
if not is_nightly_build:
raise ImportError(
f"torchao version {torchao_version} is too old, please install torchao nightly build and try again"
@@ -52,7 +53,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
), "MXFP8 is only supported on SM100 or architectures"

# TP not yet supported with torch.compile

model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
@@ -61,10 +61,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
), "TP not yet supported with torch.compile for mxfp8"

# For MoE training with mxfp8, token group sizes must be multiples of 32
if job_config.mx.moe_fqns_prototype:
mxfp8_block_size = 32
set_token_group_alignment_size_m(mxfp8_block_size)
logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
self.moe_fqns = job_config.mx.moe_fqns_prototype
if self.moe_fqns:
logger.info(
f"Setting token group alignment size to {self.token_group_alignment_size}"
)
set_token_group_alignment_size_m(self.token_group_alignment_size)

# Configure MXFP8
from torchao.prototype.mx_formats.config import (
@@ -94,6 +96,13 @@ def convert(self, model: nn.Module):
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.quantization import quantize_

# MoE conversion must take place before MXLinear conversion, otherwise the MXLinear will
# be converted back to nn.Linear:
# https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
# TODO: add warning in torchao when this happens, or find a better way to avoid this.
if self.moe_fqns:
self._convert_moe_layers(model)

Contributor:
it feels like the converter registered for self.config should handle this case specifically

Contributor Author:
as in having separate converters registered for mxfp8 moe, fp8 moe etc? If so, I kind of agree actually, the converters are becoming a bit of a mess, and having separate converters would also allow users to convert just dense or just MoE (or both), rather than the current state of having to convert dense in order to convert MoE.

Contributor:
Yeah that is one way; another thing I was thinking is that https://github.com/pytorch/ao/blob/93030e750186ace1c1c2ee7a849e2818a9f0ffde/torchao/prototype/moe_training/conversion_utils.py#L50 should be able to gracefully handle the case where the module has already been converted
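
For illustration, one torchtitan-side shape the "handle it gracefully" idea could take — a guard in the MoE filter (written as a drop-in for the closure in _convert_moe_layers below) so modules an earlier pass already swapped are skipped; the class-name check, and whether this alone removes the ordering constraint, are assumptions rather than verified torchao behavior:

```python
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Hypothetical guard: if an earlier conversion pass already replaced this
    # module (e.g. with MXLinear), leave it alone instead of relying on call order.
    if mod.__class__.__name__ == "MXLinear":
        return False
    return any(target_fqn in cur_fqn for target_fqn in self.moe_fqns)
```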

assert isinstance(self.config, MXLinearConfig)
quantize_(
model,
@@ -102,6 +111,30 @@
)
logger.info("Swapped to MXLinear layers")

def _convert_moe_layers(self, model: nn.Module):
"""
Mutates the model in place, replacing nn.Parameter instances with ScaledGroupedMMTensor,
to perform dynamic MXFP8 quantization + scaled grouped GEMMs for the target MoE FQNs.
"""
from torchao.prototype.moe_training.conversion_utils import (
MoEScalingType,
MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
for target_fqn in self.moe_fqns:
if target_fqn in cur_fqn:
return True
return False
Comment on lines +125 to +129
Contributor:
do you think we can put this in quantization/utils.py, together with the similar thing in the fp8 file?

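If the comment above were followed, the shared piece might look roughly like this — the helper name and its home in torchtitan/components/quantization/utils.py are assumptions:

```python
# Hypothetical shared helper usable by both the fp8 and mx converters.
from collections.abc import Callable

import torch.nn as nn


def build_moe_module_filter(target_fqns: list[str]) -> Callable[[nn.Module, str], bool]:
    """Return a quantize_-style filter_fn matching modules whose FQN contains any target FQN."""

    def filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        return any(target_fqn in cur_fqn for target_fqn in target_fqns)

    return filter_fn
```

Each converter would then call build_moe_module_filter(self.moe_fqns) instead of defining its own closure.
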
config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
quantize_(model, config=config, filter_fn=moe_module_filter_fn)
logger.info(
f"Converted MoE layers matching FQNS {self.moe_fqns} "
"to use dynamic MXFP8 quantization with scaled grouped GEMMs"
)

def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
"""
MXFP8 doesn't require any post-optimizer hooks at the moment
11 changes: 10 additions & 1 deletion torchtitan/distributed/expert_parallel.py
@@ -22,6 +22,8 @@
)
from torch.distributed.tensor.parallel import ParallelStyle

from torchtitan.tools.utils import _round_up


TOKEN_GROUP_ALIGN_SIZE_M = 8
ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
@@ -253,6 +255,13 @@ def wrapper(
experts_per_ep_rank = w1.shape[0]
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

# Make sure max_len of permuted token indices is divisible by TOKEN_GROUP_ALIGN_SIZE_M,
# by padding it to the nearest multiple of TOKEN_GROUP_ALIGN_SIZE_M.
x_padded_per_expert = (
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M
)
padded_max_len = _round_up(x_padded_per_expert, TOKEN_GROUP_ALIGN_SIZE_M)

with torch.no_grad():
(
permuted_indices,
@@ -262,7 +271,7 @@
num_tokens_per_expert,
experts_per_ep_rank,
num_ep_ranks,
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
padded_max_len,
TOKEN_GROUP_ALIGN_SIZE_M,
)

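As a quick sanity check of the new padding (all numbers illustrative): with 1000 permuted tokens, 4 experts per EP rank, and TOKEN_GROUP_ALIGN_SIZE_M = 32, the max length passed to the kernel becomes 1152 rather than the unaligned 1128:

```python
TOKEN_GROUP_ALIGN_SIZE_M = 32
experts_per_ep_rank = 4
num_permuted_tokens = 1000  # stands in for x.shape[0]

# Mirrors the code path above: one alignment block per expert, then round up.
x_padded_per_expert = num_permuted_tokens + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M  # 1128
padded_max_len = ((x_padded_per_expert + TOKEN_GROUP_ALIGN_SIZE_M - 1)
                  // TOKEN_GROUP_ALIGN_SIZE_M) * TOKEN_GROUP_ALIGN_SIZE_M  # 1152
assert padded_max_len % TOKEN_GROUP_ALIGN_SIZE_M == 0
```
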
15 changes: 10 additions & 5 deletions torchtitan/experiments/llama4/__init__.py
@@ -27,15 +27,20 @@

llama4_configs = {
"debugmodel": TransformerModelArgs(
dim=256,
n_layers=6,
n_heads=16,
vocab_size=2000,
dim=5120,
n_layers=2,
n_heads=40,
n_kv_heads=8,
ffn_dim_multiplier=1.2,
multiple_of=2048,
rope_theta=500000,
max_seq_len=10485760,
moe_args=MoEArgs(num_experts=4),
interleave_moe_layer_step=1,
),
"17bx16e": TransformerModelArgs(
dim=5120,
n_layers=48,
n_layers=2,
n_heads=40,
n_kv_heads=8,
ffn_dim_multiplier=1.2,
4 changes: 4 additions & 0 deletions torchtitan/experiments/llama4/train_configs/debug_model.toml
@@ -75,3 +75,7 @@ enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output", "router.gate"]
moe_fqns_prototype = ["experts"]

[mx]
filter_fqns = ["output", "router.gate"]
moe_fqns_prototype = ["experts,shared_expert"]
Contributor:
shared_expert is now made up of 3 nn.Linear instead of GroupedExperts, so maybe you don't want to include them? In your paste they are converted to MXLinear

Contributor Author:
oh, i didn't know that this had changed. will update accordingly

@@ -67,3 +67,7 @@ components = ["model", "loss"]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output", "router.gate"]

[mx]
filter_fqns = ["output", "router.gate"]
moe_fqns_prototype = ["experts,shared_expert"]
@@ -65,3 +65,7 @@ components = ["model", "loss"]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output", "router.gate"]

[mx]
filter_fqns = ["output", "router.gate"]
moe_fqns_prototype = ["experts,shared_expert"]
5 changes: 0 additions & 5 deletions torchtitan/models/attention.py
@@ -20,8 +20,6 @@
flex_attention,
)

from torchtitan.tools.utils import has_cuda_capability

# FlexAttention mask type. For each mask type, we initialize it at most once per
# batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to
# track the initialized mask.
@@ -203,14 +201,11 @@ def _init_backend(cls) -> None:
if cls.backends:
return

# Add CuDNN on B200 w/ highest priority
cls.backends = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]
if has_cuda_capability(10, 0):
cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION)

def forward(
self,
6 changes: 6 additions & 0 deletions torchtitan/tools/utils.py
@@ -202,3 +202,9 @@ def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
yield
finally:
torch.set_default_dtype(old_dtype)


def _round_up(x: int, y: int) -> int:
"""Round up x to the nearest multiple of y."""
x_ceil_div_y = (x + y - 1) // y
return x_ceil_div_y * y
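
A couple of illustrative calls, assuming _round_up is imported from torchtitan.tools.utils as in the expert_parallel.py change above:

```python
from torchtitan.tools.utils import _round_up

assert _round_up(37, 8) == 40  # bumped to the next multiple of 8
assert _round_up(40, 8) == 40  # already-aligned values pass through unchanged
```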