 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger
 
 
@@ -67,10 +68,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -91,6 +88,11 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
 
+    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    if parallel_dims.cp_enabled:
+        logger.info("Applied Context Parallel to the model")
+        apply_cp(model, world_mesh["cp"], use_flex_attn)
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
@@ -131,9 +133,6 @@ def parallelize_llama(
         else:
             logger.info("Applied FSDP to the model")
 
-        if parallel_dims.cp_enabled:
-            logger.info("Applied Context Parallel to the model")
-
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
@@ -335,3 +334,87 @@ def apply_ddp(
     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
     logger.info("Applied DDP to the model")
+
+
+def apply_cp(
+    model: nn.Module,
+    cp_mesh: DeviceMesh,
+    use_flex_attn: bool,
+) -> None:
+    """
+    Apply context parallelism to the model.
+    """
+    from torch.distributed.tensor.experimental._attention import (
+        _ContextParallel,
+        _enable_context_parallel_dispatcher,
+    )
+
+    # Apply context parallelism to every transformer block.
+    # TODO: make seq_dim configurable once the implementation doesn't assume 2
+    # internally.
+    if use_flex_attn:
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
+        )
+    else:
+        # This is currently required because the DTensor dispatcher does not
+        # yet dispatch SDPA to the CP implementation. TorchTitan never needs
+        # to disable CP dispatching, so we don't; the corresponding API,
+        # _disable_context_parallel_dispatcher, exists for users who do.
+        _enable_context_parallel_dispatcher()
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
+        )
+
+    for transformer_block in model.layers.values():
+        parallelize_module(
+            module=transformer_block.attention.inner_attention,
+            device_mesh=cp_mesh,
+            parallelize_plan=cp_plan,
+        )
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.nn.attention.flex_attention import BlockMask
+
+    load_balancer = None
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+    return inputs, labels, attention_masks, order_sensitive_buffers
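
For context, a minimal, hypothetical sketch of how the two new helpers could be wired into a training step when CP is enabled. The function name `cp_training_step`, the model's forward signature, and the loss computation are illustrative assumptions, not part of this change; `apply_cp` is assumed to have been applied once inside `parallelize_llama`, as the diff above does, and `cp_shard` refers to the function defined above.

# Illustrative-only sketch, not part of this commit.
import torch
from torch.distributed.device_mesh import DeviceMesh


def cp_training_step(
    model: torch.nn.Module,
    cp_mesh: DeviceMesh,
    inputs: torch.Tensor,       # [batch, seq] token ids
    labels: torch.Tensor,       # [batch, seq] targets
    attention_masks,            # BlockMask or dict[str, BlockMask] (FlexAttention)
    order_sensitive_buffers: dict[str, torch.Tensor],
    order_sensitive_buffers_seq_dims: dict[str, int],
) -> torch.Tensor:
    # Shard the per-step tensors along their sequence dimensions so each CP
    # rank attends only over its local shard. apply_cp(model, cp_mesh,
    # use_flex_attn) must already have been applied to the attention modules.
    inputs, labels, attention_masks, order_sensitive_buffers = cp_shard(
        cp_mesh,
        inputs,
        labels,
        attention_masks,
        order_sensitive_buffers,
        order_sensitive_buffers_seq_dims,
    )
    # Forward + loss on the local shard; the call signature is assumed for
    # illustration. A real trainer would also forward the sharded
    # order_sensitive_buffers (e.g. positional buffers) to the model.
    pred = model(inputs, attention_masks=attention_masks)
    loss = torch.nn.functional.cross_entropy(
        pred.flatten(0, 1).float(), labels.flatten(0, 1)
    )
    return loss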