
Commit ecae980
[RFC][WIP][CP] Enable FlexAttention CP for llama3

This PR uses the latest CP APIs to enable FlexAttention + CP for llama3. It removes the usage of the context_parallel() context manager and instead uses `_context_parallel_shard()` to shard the input data.

ghstack-source-id: 1bff8da
Pull-Request: #1857
1 parent e0f4d77 commit ecae980
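
Conceptually, instead of wrapping forward/backward in a context manager that shards and restores buffers implicitly, the trainer now shards the batch, labels, attention masks, and any order-sensitive buffers explicitly before the forward pass. A rough sketch of the new per-batch path, assuming a model that exposes the new `get_order_sensitive_buffers()` hook and a CP device mesh that may be None; `prepare_batch` is a hypothetical name, the real code is in `torchtitan/train.py` below:

from torchtitan.distributed.utils import cp_shard  # added by this PR

def prepare_batch(model, cp_mesh, inputs, labels, attention_masks):
    # Buffers whose contents depend on token position (e.g. rotary freqs_cis)
    # must be sharded along the sequence dimension together with the tokens.
    buffers, buffer_seq_dims = model.get_order_sensitive_buffers(
        inputs.size(0), inputs.size(1)
    )
    if cp_mesh is not None:
        inputs, labels, attention_masks, buffers = cp_shard(
            cp_mesh, inputs, labels, attention_masks, buffers, buffer_seq_dims
        )
    # The (possibly sharded) buffers are then passed to model.forward()
    # as keyword arguments alongside the attention masks.
    return inputs, labels, attention_masks, buffers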

7 files changed: +197 -54 lines changed

torchtitan/distributed/utils.py

Lines changed: 63 additions & 3 deletions
@@ -19,6 +19,7 @@
 
 from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
 
@@ -200,9 +201,6 @@ def context(cp_context: Generator[None, None, None] | None = None):
             if enable_loss_parallel:
                 stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())
 
-            if cp_context:
-                stack.enter_context(cp_context)
-
             yield
 
     return context
@@ -443,3 +441,65 @@ def _clip_grad_norm_with_ep(
     torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
 
     return total_norm
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType | None,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.distributed.tensor.experimental._load_balancer import (
+        _HeadTailLoadBalancer,
+        _PTRRLoadBalancer,
+    )
+    from torch.nn.attention.flex_attention import BlockMask
+
+    seq_len = inputs.size(1)
+    cp_world_size = cp_mesh.size(0)
+    if isinstance(attention_masks, BlockMask):
+        load_balancer = _PTRRLoadBalancer(attention_masks, cp_world_size)
+    else:
+        # For multiple BlockMasks or SDPA, we use the _HeadTailLoadBalancer.
+        load_balancer = _HeadTailLoadBalancer(
+            seq_len, cp_world_size, cp_mesh.device_type
+        )
+
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+
+    if attention_masks is None:
+        return inputs, labels, None, order_sensitive_buffers
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    return inputs, labels, attention_masks, order_sensitive_buffers
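
For reference, a hedged usage sketch of `cp_shard()` outside the trainer. It assumes torchrun has initialized the default process group on CUDA devices and that the whole world forms the CP group; the buffer dict mirrors what llama3's `get_order_sensitive_buffers()` returns, and passing `None` for the masks exercises the SDPA/head-tail path:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torchtitan.distributed.utils import cp_shard

def shard_one_batch(batch_size: int = 2, seq_len: int = 128, head_dim: int = 64):
    cp_degree = dist.get_world_size()
    cp_mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=("cp",))

    inputs = torch.randint(0, 32_000, (batch_size, seq_len), device="cuda")
    labels = torch.randint(0, 32_000, (batch_size, seq_len), device="cuda")
    # An order-sensitive buffer shaped like llama3's per-batch freqs_cis.
    freqs_cis = torch.randn(batch_size, seq_len, head_dim // 2, device="cuda")

    inputs, labels, masks, buffers = cp_shard(
        cp_mesh,
        inputs,
        labels,
        None,                  # no BlockMask -> head-tail load balancer, masks stay None
        {"freqs_cis": freqs_cis},
        {"freqs_cis": 1},      # freqs_cis is sharded along its sequence dim (dim 1)
    )
    # With the head-tail load balancer, each rank holds seq_len // cp_degree tokens
    # along dim 1 of inputs, labels, and the buffers.
    assert inputs.size(1) == seq_len // cp_degree
    return inputs, labels, masks, buffers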

torchtitan/models/attention.py

Lines changed: 24 additions & 11 deletions
@@ -32,6 +32,22 @@
 ]
 
 
+class FlexAttentionKernel(torch.nn.Module):
+    """Wrapper to enable FlexCP"""
+
+    _compiled_flex_attn: ClassVar[Callable] = torch.compile(
+        flex_attention, mode="max-autotune-no-cudagraphs"
+    )
+
+    def forward(self, *args, **kwargs):
+        # 1. _compiled_flex_attn has to be a class variable; otherwise there would
+        #    be multiple compiled flex_attention instances, which can be slow.
+        # 2. `self._compiled_flex_attn` is not correct: `self` would be passed in
+        #    as the first argument, which would cause an error.
+        #    `FlexAttentionKernel._compiled_flex_attn` is correct.
+        return FlexAttentionKernel._compiled_flex_attn(*args, **kwargs)
+
+
 class FlexAttentionWrapper(torch.nn.Module):
     """Wrapper around `flex_attention` to make it torch.compile and CP compatible.
 
@@ -45,9 +61,11 @@ class FlexAttentionWrapper(torch.nn.Module):
     block_mask as a keyword argument to be compatible with _ContextParallel.
     """
 
-    _compiled_flex_attn: ClassVar[Callable] = torch.compile(
-        flex_attention, mode="max-autotune-no-cudagraphs"
-    )
+    def __init__(self) -> None:
+        super().__init__()
+        # TODO: remove this wrapper once FlexAttentionWrapper.forward() has the
+        # same signature as flex_attention() and is compatible with _ContextParallel.
+        self._flex_attention_kernel = FlexAttentionKernel()
 
     def forward(
         self,
@@ -59,15 +77,10 @@ def forward(
         scale: float | None = None,
        return_lse: bool = False,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        # 1. _compiled_flex_attn has to be a class variable, otherwise there will
-        # be multiple compiled flex_attention instances, which can be slow.
-        # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
-        # as the first argument, which will cause an error.
-        # `FlexAttentionWrapper._compiled_flex_attn` is correct.
-        # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
-        # to convert `lse` to be DTensor.
+        # Used `return_lse` instead of `return_aux` because it is easier for the TP
+        # module notation to convert `lse` to a DTensor.
 
-        return FlexAttentionWrapper._compiled_flex_attn(
+        return self._flex_attention_kernel(
            q,
            k,
            v,
torchtitan/models/llama3/infra/parallelize.py

Lines changed: 48 additions & 7 deletions
@@ -67,10 +67,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -91,6 +87,11 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
 
+    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    if parallel_dims.cp_enabled:
+        logger.info("Applied Context Parallel to the model")
+        apply_cp(model, world_mesh["cp"], use_flex_attn)
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
@@ -131,9 +132,6 @@ def parallelize_llama(
         else:
             logger.info("Applied FSDP to the model")
 
-        if parallel_dims.cp_enabled:
-            logger.info("Applied Context Parallel to the model")
-
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
@@ -328,3 +326,46 @@ def apply_ddp(
     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
     logger.info("Applied DDP to the model")
+
+
+def apply_cp(
+    model: nn.Module,
+    cp_mesh: DeviceMesh,
+    use_flex_attn: bool,
+) -> None:
+    """
+    Apply context parallelism to the model.
+    """
+    from torch.distributed.tensor.experimental._attention import (
+        _ContextParallel,
+        _enable_context_parallel_dispatcher,
+    )
+
+    # Apply context parallelism to every transformer block.
+    # TODO: make seq_dim configurable once the implementation doesn't assume 2
+    # internally.
+    if use_flex_attn:
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
+        )
+    else:
+        # This is currently required because the DTensor dispatcher does not yet
+        # dispatch SDPA to the CP implementation. We don't disable CP dispatching
+        # in TorchTitan since that is not needed, but there is a corresponding
+        # API, _disable_context_parallel_dispatcher, for users who have that
+        # use case.
+        _enable_context_parallel_dispatcher()
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
+        )
+
+    for transformer_block in model.layers.values():
+        module = transformer_block.attention.inner_attention
+        if use_flex_attn:
+            module = module._flex_attention_kernel
+
+        parallelize_module(
+            module=module,
+            device_mesh=cp_mesh,
+            parallelize_plan=cp_plan,
+        )
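
A note on the two sequence dimensions in play: `apply_cp` shards `seq_dim=2` because inside the attention kernel q/k/v are laid out as (batch, heads, seq, head_dim), while `cp_shard` uses `seq_dims=(1, 1)` for the token-level inputs and labels. A minimal sketch of applying the same plan to a standalone attention module, assuming torch.distributed is initialized, a one-dimensional 'cp' mesh, and an `attn_kernel` module whose forward matches `flex_attention`'s signature (FlexAttentionKernel in this PR):

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental._attention import _ContextParallel
from torch.distributed.tensor.parallel import parallelize_module

def apply_cp_to_kernel(attn_kernel, cp_degree: int):
    # One-dimensional mesh over the CP ranks; in torchtitan this comes from
    # parallel_dims.world_mesh["cp"].
    cp_mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=("cp",))
    cp_plan = _ContextParallel(
        seq_dim=2,  # q/k/v are (batch, heads, seq, head_dim)
        attention_type=_ContextParallel.AttentionType.FLEX,
    )
    parallelize_module(
        module=attn_kernel, device_mesh=cp_mesh, parallelize_plan=cp_plan
    )
    return attn_kernel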

torchtitan/models/llama3/model/args.py

Lines changed: 0 additions & 5 deletions
@@ -55,11 +55,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:

torchtitan/models/llama3/model/model.py

Lines changed: 14 additions & 6 deletions
@@ -92,8 +92,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.
 
-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
-    and the first seqlen elements will be sliced, but dim must match x.
+    The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).
 
     Args:
         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
@@ -104,10 +103,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     ndim = x.ndim
     assert ndim > 1
+    batch_size = x.shape[0]
     seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
+    shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
 
 
@@ -474,9 +473,18 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )
 
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
+        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
+        return ({"freqs_cis": freqs_cis}, {"freqs_cis": 1})
+
     def forward(
         self,
         tokens: torch.Tensor,
+        freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
     ):
         """
@@ -496,7 +504,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            h = layer(h, freqs_cis, attention_masks=attention_masks)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
torchtitan/protocols/model.py

Lines changed: 7 additions & 0 deletions
@@ -70,3 +70,10 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
+        return ({}, {})

torchtitan/train.py

Lines changed: 41 additions & 22 deletions
@@ -409,12 +409,10 @@ def batch_generator(
 
             yield input_dict, labels
 
-    def forward_backward_step(
+    def post_dataloader_step(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
-    ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
-
+    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any], dict[str, Any]]:
+        """Post-processing of the batch and labels after they are loaded from the dataloader."""
         inputs = input_dict["input"]
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # For arguments, like attention_masks, we have to put them in a separate
@@ -423,32 +421,53 @@ def forward_backward_step(
         extra_kwargs = {}
 
         if getattr(self.model_args, "use_flex_attn", False):
-            extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
+            extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
             )
+        else:
+            extra_kwargs["attention_masks"] = None
 
-        # apply context parallelism if cp is enabled
-        # ensure CP handles the separate freqs_cis buffer for each pp stage
-        optional_context_parallel_ctx = (
-            dist_utils.create_context_parallel_ctx(
-                cp_mesh=parallel_dims.world_mesh["cp"],
-                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
-                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
-                cp_no_restore_buffers={inputs, labels},
-                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
-            )
-            if parallel_dims.cp_enabled
+        # Get the order-sensitive buffers
+        order_sensitive_buffers = self.model_parts[0].get_order_sensitive_buffers(
+            inputs.size(0), inputs.size(1)
+        )
+        cp_mesh = (
+            self.parallel_dims.world_mesh["cp"]
+            if self.parallel_dims.cp_enabled
             else None
         )
+        if cp_mesh:
+            (
+                inputs,
+                labels,
+                extra_kwargs["attention_masks"],
+                *order_sensitive_buffers,
+            ) = dist_utils.cp_shard(
+                cp_mesh,
+                inputs,
+                labels,
+                extra_kwargs["attention_masks"],
+                *order_sensitive_buffers,
+            )
+        extra_kwargs.update(order_sensitive_buffers[0])
+        return inputs, labels, extra_inputs, extra_kwargs
+
+    def forward_backward_step(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+
+        inputs, labels, extra_inputs, extra_kwargs = self.post_dataloader_step(
+            input_dict, labels
+        )
 
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
-            with self.train_context(optional_context_parallel_ctx):
-                targets, losses = (
-                    (labels, []) if self.pp_has_last_stage else (None, None)
-                )
+            targets, losses = (labels, []) if self.pp_has_last_stage else (None, None)
+            with self.train_context():
                 if self.pp_has_first_stage:
                     self.pp_schedule.step(
                         inputs,
@@ -478,7 +497,7 @@ def forward_backward_step(
             )
         else:
             # Non-PP forward / backward
-            with self.train_context(optional_context_parallel_ctx):
+            with self.train_context():
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
                     pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)