34 changes: 28 additions & 6 deletions torchtitan/models/attention.py
@@ -15,6 +15,7 @@
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
+    AuxOutput,
     BlockMask,
     create_block_mask,
     flex_attention,
@@ -28,6 +29,26 @@
 FLEX_ATTN_MASK_T = tuple[str, int | None]


+class FlexAttentionWrapper(torch.nn.Module):
+    _flex_attn: ClassVar[Callable] = torch.compile(
+        flex_attention, mode="max-autotune-no-cudagraphs"
+    )
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(
+        self, *args: object, **kwargs: object
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, AuxOutput]:
+        # 1. _flex_attn has to be a class variable; otherwise there would be
+        #    multiple compiled copies of flex_attention, which can be slow.
+        # 2. Calling `self._flex_attn` is not correct: `self` would be passed
+        #    in as the first positional argument, causing an error.
+        #    `FlexAttentionWrapper._flex_attn` is correct.
+        return FlexAttentionWrapper._flex_attn(*args, **kwargs)
+
+
 class FlexAttention(torch.nn.Module):
     """FlexAttention module that uses torch.nn.attention.flex_attention.
 
@@ -46,11 +67,6 @@ class FlexAttention(torch.nn.Module):
         to the keys within the same block.
     """

-    # We registered flex_attention related attributes as class variables as we
-    # need to amortize the cost of compilation.
-    flex_attn: ClassVar[Callable] = torch.compile(
-        flex_attention, mode="max-autotune-no-cudagraphs"
-    )
     compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
     used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set()
     # Attention mask type to the created BlockMask.
@@ -71,6 +87,7 @@ def __init__(
             raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
         self.attn_mask_type = attn_mask_type
         self.fixed_block_size = fixed_block_size
+        self.attention_fn_wrapper = FlexAttentionWrapper()

         FlexAttention.used_attn_mask_types.add(self.mask_key)

@@ -86,7 +103,7 @@ def forward(
         scale: float | None = None,
     ) -> torch.Tensor:
         block_mask = FlexAttention.block_masks[self.mask_key]
-        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale)
+        return self.attention_fn_wrapper(q, k, v, block_mask=block_mask, scale=scale)

     @staticmethod
     def _get_causal_mask_mod() -> _mask_mod_signature:
@@ -251,6 +268,11 @@ def init_attention_mask(
     # while we continue debugging accuracy issues. However, we want to evaluate
     # the user experience with CP enabled.
    if cp_mesh is not None:
+        from torch.distributed.tensor.experimental._attention import _DispatchMode
+
+        torch.distributed.tensor.experimental._attention._dispatch_mode = (
+            _DispatchMode.MODULE_WRAPPER
+        )
         FlexAttention.compiled_create_block_mask = functools.partial(
             create_cp_block_mask, device_mesh=cp_mesh
         )
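
A note on the FlexAttentionWrapper added above: keeping the compiled flex_attention as a class variable lets every wrapper instance share a single compiled artifact, and calling it as `FlexAttentionWrapper._flex_attn(...)` rather than `self._flex_attn(...)` avoids Python's method binding, which would otherwise pass `self` as an extra first argument. The sketch below is a minimal, standalone illustration of that binding pitfall, using a plain function in place of the compiled flex_attention; all names in it are illustrative only.

```python
def add(a: int, b: int) -> int:
    return a + b


class Wrapper:
    # A plain function stored as a class attribute is a descriptor:
    # looking it up through an instance produces a bound method.
    _fn = add

    def call_via_self(self, a: int, b: int) -> int:
        # Bound lookup: effectively add(self, a, b) -> TypeError.
        return self._fn(a, b)

    def call_via_class(self, a: int, b: int) -> int:
        # Class lookup: the raw function, so only (a, b) are passed.
        return Wrapper._fn(a, b)


w = Wrapper()
print(w.call_via_class(1, 2))  # 3
try:
    w.call_via_self(1, 2)
except TypeError as err:
    print(f"binding pitfall: {err}")  # add() takes 2 positional arguments but 3 were given
```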
15 changes: 13 additions & 2 deletions torchtitan/models/llama3/infra/parallelize.py
@@ -14,6 +14,8 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Replicate, Shard
+
+from torch.distributed.tensor.experimental._attention import _ContextParallel
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -67,8 +69,6 @@ def parallelize_llama(
     """

     use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")

     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
@@ -90,6 +90,17 @@
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

+    if parallel_dims.cp_enabled:
+        for block in model.layers.values():
+            parallelize_module(
+                module=block.attention.sdpa.attention_fn_wrapper,
+                device_mesh=world_mesh["cp"],
+                parallelize_plan=_ContextParallel(
+                    seq_dim=2,
+                    attention_type=_ContextParallel.AttentionType.FLEX,
+                ),
+            )
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )

Review comment (Contributor), on the `attention_type=_ContextParallel.AttentionType.FLEX` line:
Does this only work for FlexAttention? Is there a plan to consolidate SDPA and FlexAttention in terms of how CP is applied?

Reply (Contributor, Author):
This will work for both SDPA and Flex. We just need to pass in a different type based on which attention is used.
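
Building on that reply, here is a hedged sketch of how the choice between the two attention types might look. It is not part of this PR: the `apply_cp` helper is hypothetical, and the `AttentionType.SDPA` member name is an assumption for illustration.

```python
from torch.distributed.tensor.experimental._attention import _ContextParallel
from torch.distributed.tensor.parallel import parallelize_module


def apply_cp(model, world_mesh, use_flex_attn: bool) -> None:
    """Hypothetical helper: apply Context Parallel to every attention block."""
    # Pick the CP attention type from the model flag; the SDPA member name is assumed.
    attention_type = (
        _ContextParallel.AttentionType.FLEX
        if use_flex_attn
        else _ContextParallel.AttentionType.SDPA
    )
    for block in model.layers.values():
        parallelize_module(
            # The wrapper module whose forward performs the attention call;
            # an SDPA path would need an analogous wrapper around its call.
            module=block.attention.sdpa.attention_fn_wrapper,
            device_mesh=world_mesh["cp"],
            parallelize_plan=_ContextParallel(
                seq_dim=2,
                attention_type=attention_type,
            ),
        )
```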
7 changes: 0 additions & 7 deletions torchtitan/models/llama3/model/args.py
@@ -45,13 +45,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
         self.max_seq_len = seq_len

-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
-        self.max_seq_len = seq_len
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())
         nparams_embedding = sum(