Commit 453a4b0

[RFC][WIP][CP] Enable FlexAttention CP for llama3
This PR uses the latest CP APIs to enable FlexAttention + CP for llama3. It removes the use of the `context_parallel()` context manager and instead uses `_context_parallel_shard()` to shard the input data.
ghstack-source-id: d30bc9f
Pull-Request: #1857
1 parent d28e253 commit 453a4b0
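
At a high level, the new flow shards the batch once along the sequence dimension before the forward pass, rather than wrapping forward/backward in a context manager. A minimal sketch of what sequence-dim sharding means, using plain tensor chunking as a stand-in for the private `_context_parallel_shard` API (the `shard_along_seq` helper, ranks, and shapes are illustrative assumptions, not part of this PR):

import torch

# Stand-in for sequence sharding: each CP rank keeps its own slice of the
# sequence dimension. The PR itself delegates this to the private
# torch.distributed.tensor.experimental._attention._context_parallel_shard API.
def shard_along_seq(x: torch.Tensor, seq_dim: int, cp_rank: int, cp_world_size: int) -> torch.Tensor:
    return x.chunk(cp_world_size, dim=seq_dim)[cp_rank].contiguous()

# Hypothetical shapes: batch=2, seq_len=8, CP degree 2.
inputs = torch.arange(16).reshape(2, 8)
labels = torch.arange(16).reshape(2, 8)

local_inputs = shard_along_seq(inputs, seq_dim=1, cp_rank=0, cp_world_size=2)
local_labels = shard_along_seq(labels, seq_dim=1, cp_rank=0, cp_world_size=2)
print(local_inputs.shape)  # torch.Size([2, 4]) -- each rank sees half of the sequence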

File tree (5 files changed, +184 −70 lines):

  torchtitan/distributed/utils.py
  torchtitan/models/attention.py
  torchtitan/models/llama3/infra/parallelize.py
  torchtitan/models/llama3/model/args.py
  torchtitan/train.py

torchtitan/distributed/utils.py

Lines changed: 50 additions & 0 deletions

@@ -19,6 +19,7 @@
 
 from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
 
@@ -449,3 +450,52 @@ def _clip_grad_norm_with_ep(
     torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
 
     return total_norm
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType | None,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.nn.attention.flex_attention import BlockMask
+
+    load_balancer = None
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+
+    if attention_masks is None:
+        return inputs, labels, None, order_sensitive_buffers
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    return inputs, labels, attention_masks, order_sensitive_buffers
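
A hedged sketch of how this new helper could be invoked, assuming a run launched with torchrun at world size 2 and CUDA available; the buffer name `freqs_cis`, all shapes, and the keyword-style call are illustrative, and the `attention_masks=None` (SDPA) path is shown since building a BlockMask is orthogonal here:

# Hedged usage sketch (run under torchrun --nproc_per_node=2); not part of the PR.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torchtitan.distributed.utils import cp_shard

cp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("cp",))  # 1-D CP mesh

inputs = torch.randint(0, 32000, (8, 4096), device="cuda")   # [batch, seq] token ids
labels = torch.randint(0, 32000, (8, 4096), device="cuda")
freqs_cis = torch.randn(4096, 64, device="cuda")              # hypothetical rotary table

inputs, labels, masks, buffers = cp_shard(
    cp_mesh,
    inputs,
    labels,
    attention_masks=None,                               # SDPA path: no BlockMask to shard
    order_sensitive_buffers={"freqs_cis": freqs_cis},
    order_sensitive_buffers_seq_dims={"freqs_cis": 0},  # shard the table on its seq dim
)
print(inputs.shape)  # the sequence dim is now split across the two CP ranks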

torchtitan/models/attention.py

Lines changed: 2 additions & 13 deletions

@@ -16,7 +16,6 @@
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
     AuxOutput,
-    BlockMask,
     create_block_mask,
     flex_attention,
 )
@@ -49,23 +48,13 @@ class FlexAttentionWrapper(torch.nn.Module):
         flex_attention, mode="max-autotune-no-cudagraphs"
     )
 
-    def forward(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        *,
-        block_mask: BlockMask,
-        scale: float | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]:
+    def forward(self, *args, **kwargs) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]:
         # 1. _compiled_flex_attn has to be a class variable, otherwise there will
         # be multiple compiled flex_attention instances, which can be slow.
         # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
         # as the first argument, which will cause an error.
         # `FlexAttentionWrapper._compiled_flex_attn` is correct.
-        return FlexAttentionWrapper._compiled_flex_attn(
-            q, k, v, block_mask=block_mask, scale=scale
-        )
+        return FlexAttentionWrapper._compiled_flex_attn(*args, **kwargs)
 
 
 class ScaledDotProductAttentionWrapper(torch.nn.Module):
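
For context on the comment in `forward`: storing a plain function as a class attribute turns instance access into method binding, so `self._compiled_flex_attn(...)` would silently prepend `self` as the first positional argument. A minimal, self-contained illustration of that pitfall with an ordinary function (not torch.compile):

def add(a, b):
    return a + b

class Wrapper:
    fn = add  # function stored as a class attribute

w = Wrapper()
print(Wrapper.fn(1, 2))   # 3 -- class-level access passes the arguments through unchanged
try:
    w.fn(1, 2)            # instance access binds `w`, i.e. calls add(w, 1, 2)
except TypeError as e:
    print("instance access failed:", e)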

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 90 additions & 7 deletions

@@ -27,6 +27,7 @@
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger
 
 
@@ -67,10 +68,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -91,6 +88,11 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
 
+    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    if parallel_dims.cp_enabled:
+        logger.info("Applied Context Parallel to the model")
+        apply_cp(model, world_mesh["cp"], use_flex_attn)
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
@@ -131,9 +133,6 @@ def parallelize_llama(
         else:
             logger.info("Applied FSDP to the model")
 
-        if parallel_dims.cp_enabled:
-            logger.info("Applied Context Parallel to the model")
-
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
@@ -335,3 +334,87 @@ def apply_ddp(
     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
     logger.info("Applied DDP to the model")
+
+
+def apply_cp(
+    model: nn.Module,
+    cp_mesh: DeviceMesh,
+    use_flex_attn: bool,
+) -> None:
+    """
+    Apply context parallelism to the model.
+    """
+    from torch.distributed.tensor.experimental._attention import (
+        _ContextParallel,
+        _enable_context_parallel_dispatcher,
+    )
+
+    # Apply context parallelism to every transformer block.
+    # TODO: make seq_dim configurable once the implementation doesn't assume 2
+    # internally.
+    if use_flex_attn:
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
+        )
+    else:
+        # This is currently required because the DTensor dispatcher is not
+        # enabled to dispatch SDPA to the CP implementation. We don't disable
+        # the CP dispatching in TorchTitan as it is not needed, but there is a
+        # corresponding API, _disable_context_parallel_dispatcher, for users
+        # who have that use case.
+        _enable_context_parallel_dispatcher()
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
+        )
+
+    for transformer_block in model.layers.values():
+        parallelize_module(
+            module=transformer_block.attention.inner_attention,
+            device_mesh=cp_mesh,
+            parallelize_plan=cp_plan,
+        )
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.nn.attention.flex_attention import BlockMask
+
+    load_balancer = None
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+    return inputs, labels, attention_masks, order_sensitive_buffers
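
One thing worth spelling out is the two sequence-dim conventions above: the `_ContextParallel` plan and the BlockMask sharding use `seq_dim=2`, while `cp_shard` shards the token inputs on dim 1. That is consistent with the usual layouts, shown below with illustrative sizes (the exact llama3 dimensions are assumptions here, not taken from this diff):

import torch

batch, seq_len, n_heads, head_dim = 2, 4096, 32, 128

tokens = torch.empty(batch, seq_len, dtype=torch.long)   # cp_shard: seq_dims=(1, 1)
q = torch.empty(batch, n_heads, seq_len, head_dim)       # attention wrapper input: seq_dim=2

# Both refer to the same logical sequence length that CP splits across ranks.
assert tokens.shape[1] == q.shape[2] == seq_len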

torchtitan/models/llama3/model/args.py

Lines changed: 0 additions & 5 deletions

@@ -55,11 +55,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:

torchtitan/train.py

Lines changed: 42 additions & 45 deletions

@@ -426,50 +426,48 @@ def forward_backward_step(
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
             )
+        else:
+            extra_args["attention_masks"] = None
 
         # Get the order sensitive buffers
         order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
             inputs.size(0), inputs.size(1)
         )
-        extra_args.update(order_sensitive_buffers[0])
-
-        # apply context parallelism if cp is enabled
-        # ensure CP handles the separate freqs_cis buffer for each pp stage
         cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
-        optional_context_parallel_ctx = (
-            dist_utils.create_context_parallel_ctx(
-                cp_mesh=parallel_dims.world_mesh["cp"],
-                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
-                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
-                cp_no_restore_buffers={inputs, labels},
-                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+        if cp_mesh:
+            (
+                inputs,
+                labels,
+                extra_args["attention_masks"],
+                *order_sensitive_buffers,
+            ) = dist_utils.cp_shard(
+                cp_mesh,
+                inputs,
+                labels,
+                extra_args["attention_masks"],
+                *order_sensitive_buffers,
             )
-            if parallel_dims.cp_enabled
-            else None
-        )
+        extra_args.update(order_sensitive_buffers[0])
 
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
-            with self.train_context(optional_context_parallel_ctx):
-                targets, losses = (
-                    (labels, []) if self.pp_has_last_stage else (None, None)
+            targets, losses = (labels, []) if self.pp_has_last_stage else (None, None)
+            if self.pp_has_first_stage:
+                self.pp_schedule.step(
+                    inputs,
+                    **extra_inputs,
+                    **extra_args,
+                    target=targets,
+                    losses=losses,
+                    input_batch=inputs,
+                )
+            else:
+                self.pp_schedule.step(
+                    **extra_args,
+                    target=targets,
+                    losses=losses,
+                    input_batch=inputs,
                 )
-            if self.pp_has_first_stage:
-                self.pp_schedule.step(
-                    inputs,
-                    **extra_inputs,
-                    **extra_args,
-                    target=targets,
-                    losses=losses,
-                    input_batch=inputs,
-                )
-            else:
-                self.pp_schedule.step(
-                    **extra_args,
-                    target=targets,
-                    losses=losses,
-                    input_batch=inputs,
-                )
 
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -483,18 +481,17 @@ def forward_backward_step(
             )
         else:
             # Non-PP forward / backward
-            with self.train_context(optional_context_parallel_ctx):
-                assert len(model_parts) == 1
-                with self.maybe_enable_amp:
-                    pred = model_parts[0](
-                        inputs,
-                        **extra_inputs,
-                        **extra_args,
-                    )
-                    loss = self.loss_fn(pred, labels)
-                # need to free pred before bwd to avoid peaking memory
-                del pred
-                loss.backward()
+            assert len(model_parts) == 1
+            with self.maybe_enable_amp:
+                pred = model_parts[0](
+                    inputs,
+                    **extra_inputs,
+                    **extra_args,
+                )
+                loss = self.loss_fn(pred, labels)
+            # need to free pred before bwd to avoid peaking memory
+            del pred
+            loss.backward()
 
         return loss
 
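
The splat into `dist_utils.cp_shard(...)` and the later `extra_args.update(order_sensitive_buffers[0])` only line up if `get_order_sensitive_buffers` returns a pair of (name → tensor, name → sequence dim). A minimal sketch of that assumed contract (the function body and the `freqs_cis` buffer are hypothetical, inferred from the call sites rather than from the model code):

import torch

def get_order_sensitive_buffers(batch_size: int, seq_len: int):
    # Hypothetical: one rotary-embedding table that must be sharded on dim 0.
    freqs_cis = torch.randn(seq_len, 64)
    return ({"freqs_cis": freqs_cis}, {"freqs_cis": 0})

order_sensitive_buffers = get_order_sensitive_buffers(8, 4096)

# Without CP, the trainer simply merges the buffer dict into extra_args:
extra_args = {}
extra_args.update(order_sensitive_buffers[0])

# With CP, both dicts are splatted as the last two arguments of cp_shard:
buffers, seq_dims = order_sensitive_buffers
print(list(buffers), seq_dims)  # ['freqs_cis'] {'freqs_cis': 0}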
