[CP][RFC] Enable FlexCP for llama3 with parallelize_module #1707
base: main
@@ -14,6 +14,8 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
 from torch.distributed.tensor import Replicate, Shard
+
+from torch.distributed.tensor.experimental._attention import _ContextParallel
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     parallelize_module,
@@ -67,8 +69,6 @@ def parallelize_llama(
     """

-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
@@ -90,6 +90,17 @@ def parallelize_llama(
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])

+    if parallel_dims.cp_enabled:
+        for block in model.layers.values():
+            parallelize_module(
+                module=block.attention.sdpa.attention_fn_wrapper,
+                device_mesh=world_mesh["cp"],
+                parallelize_plan=_ContextParallel(
+                    seq_dim=2,
+                    attention_type=_ContextParallel.AttentionType.FLEX,
+                ),
+            )
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )

Review thread on parallelize_plan=_ContextParallel(...):

So after this change, we only need to specify the context-parallel plan for the attention module here, and CP for the other modules is still handled by the context manager.

Reply: You can check the discussion in pytorch/pytorch#162542. It's definitely good to remove the context manager, but that may also have implications for how users should write the model, like the wrapper in this PR.

Review thread on attention_type=_ContextParallel.AttentionType.FLEX:

Does this only work for FlexAttention?

Reply: This will work for both SDPA and Flex; we just need to pass in a different type based on which attention is used.
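To make that reply concrete, here is a minimal sketch (not from this PR) of how the plan registration could branch on the configured attention backend. It assumes _ContextParallel.AttentionType also exposes an SDPA member, which the reply implies but this diff does not show, and it reuses the module path from the diff, which an SDPA-only model would likely replace with its own wrapper:

```python
from torch.distributed.tensor.experimental._attention import _ContextParallel
from torch.distributed.tensor.parallel import parallelize_module


def apply_cp_to_attention(model, cp_mesh, use_flex_attn: bool) -> None:
    # Hypothetical helper: pick the CP attention type from the model config.
    # Assumes AttentionType has an SDPA member alongside FLEX, per the reply above.
    attn_type = (
        _ContextParallel.AttentionType.FLEX
        if use_flex_attn
        else _ContextParallel.AttentionType.SDPA
    )
    for block in model.layers.values():
        parallelize_module(
            # Module path copied from the diff; an SDPA-only model would point
            # this at its own thin attention wrapper instead.
            module=block.attention.sdpa.attention_fn_wrapper,
            device_mesh=cp_mesh,
            parallelize_plan=_ContextParallel(
                seq_dim=2,  # q/k/v laid out as (batch, heads, seq, head_dim)
                attention_type=attn_type,
            ),
        )
```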
Review thread on the attention wrapper:

IIUC, for FlexAttention we need this wrapper because the block mask has to be obtained inside the FlexAttention class before calling the wrapper. For SDPA it seems unnecessary; it is already a very thin wrapper. If the concern is code branching, the code is going to branch a couple of lines below anyway, so I think it's fine.

Reply: It's not just about unification. The wrapper must have the exact function signature of scaled_dot_product_attention, and our ScaledDotProductAttention doesn't meet this requirement. More importantly, we don't want this wrapper to break when the core library changes the function signature of scaled_dot_product_attention or flex_attention. So the best UX is to always ask users to wrap the APIs with forward being def forward(*args, **kwargs) -> Any, and TorchTitan should follow this rule as well.
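A minimal sketch of the wrapper style the reply describes, using PyTorch's public scaled_dot_product_attention and flex_attention entry points; the class names are illustrative and not taken from the PR:

```python
from typing import Any

import torch.nn as nn
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.functional import scaled_dot_product_attention


class ScaledDotProductAttentionWrapper(nn.Module):
    # Forward everything untouched so the wrapper keeps the exact signature of
    # scaled_dot_product_attention even if the core library changes it.
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return scaled_dot_product_attention(*args, **kwargs)


class FlexAttentionWrapper(nn.Module):
    # Same idea for flex_attention; the block mask is built by the caller and
    # passed through in args/kwargs.
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return flex_attention(*args, **kwargs)
```

A _ContextParallel plan like the one in this diff would then be registered on instances of such wrappers (block.attention.sdpa.attention_fn_wrapper above).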