FSDP2 Allgather Perf improvement and support for FusedAdam with FSDP2 #2370
base: main
Conversation
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <[email protected]>
…ne into fix_ci_error Signed-off-by: Varun Thumbe <[email protected]>
Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
This PR addresses two key issues with FSDP2 FP8 support:
Performance Optimization: Reduces CPU overhead in FSDP2 allgather operations by eliminating redundant quantizer copies on every training iteration. The quantizer is now copied only once, during the first allgather, and reused on subsequent iterations.
FusedAdam Segfault Fix: Resolves a segmentation fault when using the FusedAdam optimizer with FP8 FSDP2 parameters. The issue occurred because FSDP2 wraps FP8 tensors in DTensor, but the optimizer only checked for plain Float8Tensor parameters.
The changes apply identical optimization patterns to both Float8Tensor and MXFP8Tensor.
Confidence Score: 5/5
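To make the segfault fix concrete, here is a minimal sketch of a DTensor-aware parameter check. The helper name `_unwrap_fp8_param` is hypothetical and the import paths assume a recent PyTorch and Transformer Engine; treat this as an illustration of the idea, not the PR's actual code.

```python
# Illustrative sketch only -- not the PR's actual implementation.
from torch.distributed.tensor import DTensor  # public path in recent PyTorch

from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor


def _unwrap_fp8_param(p):
    """Return the underlying Float8Tensor, or None if p is not an FP8 parameter.

    FSDP2 shards FP8 parameters as DTensor(Float8Tensor), so an isinstance check
    on Float8Tensor alone misses them; check the DTensor's local tensor as well.
    """
    if isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor):
        return p._local_tensor  # FSDP2 case: unwrap to the local FP8 shard
    if isinstance(p, Float8Tensor):
        return p                # plain (non-sharded) FP8 parameter
    return None                 # regular high-precision parameter
```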
Sequence Diagram
sequenceDiagram
participant FSDP2 as FSDP2 Framework
participant FP8Tensor as Float8/MXFP8 Tensor
participant Quantizer as Quantizer (Shared)
participant FusedAdam as FusedAdam Optimizer
Note over FSDP2,FusedAdam: Scenario 1: FSDP2 Allgather (Performance Optimization)
FSDP2->>FP8Tensor: fsdp_pre_all_gather(module, mesh)
FP8Tensor->>FP8Tensor: Determine rowwise/columnwise usage
FP8Tensor->>FSDP2: Return (sharded_data, metadata with usage flags)
Note over FP8Tensor: No quantizer.copy() - uses self._quantizer
FSDP2->>FSDP2: Perform allgather on sharded_data
FSDP2->>FP8Tensor: fsdp_post_all_gather(outputs, metadata, out=None)
Note over FP8Tensor: First iteration: out is None
FP8Tensor->>FP8Tensor: Create new tensor with self._quantizer
Note over Quantizer: Constructor makes one copy here
FP8Tensor->>Quantizer: set_usage(rowwise, columnwise)
FP8Tensor->>FSDP2: Return allgathered tensor
FSDP2->>FP8Tensor: fsdp_post_all_gather(outputs, metadata, out=cached)
Note over FP8Tensor: Subsequent iterations: reuse out
FP8Tensor->>FP8Tensor: Update out._data only
FP8Tensor->>Quantizer: set_usage(rowwise, columnwise)
Note over FP8Tensor: No new quantizer copy
FP8Tensor->>FSDP2: Return updated tensor
Note over FSDP2,FusedAdam: Scenario 2: FusedAdam with FSDP2 (Bug Fix)
FusedAdam->>FusedAdam: step() - iterate parameters
FusedAdam->>FusedAdam: Check isinstance(p, Float8Tensor)
Note over FusedAdam: Old: Fails for DTensor wrapper
FusedAdam->>FusedAdam: Check isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor)
Note over FusedAdam: New: Detects DTensor(Float8Tensor)
FusedAdam->>FusedAdam: p = p._local_tensor (unwrap)
FusedAdam->>FusedAdam: Access p._data, p._fp8_dtype
Note over FusedAdam: Now works correctly, no segfault
FusedAdam->>FusedAdam: Apply optimizer update
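The caching pattern in Scenario 1 can be sketched roughly as follows. The class, attribute names, hook signatures, and metadata layout here are assumptions based on the diagram (the real FSDP2 extension hooks carry more arguments), so read this as a schematic of the caching idea rather than the actual implementation.

```python
# Schematic sketch of the FSDP2 allgather hooks -- signatures and metadata layout are assumed.
def make_unsharded_fp8_tensor(gathered_data, quantizer):
    """Hypothetical stand-in for building an unsharded FP8 tensor; the real
    constructor is where the single quantizer copy happens."""
    raise NotImplementedError


class Fp8TensorSketch:
    def fsdp_pre_all_gather(self, module, mesh):
        # Ship only the raw uint8 shard; pass the *shared* quantizer and the
        # rowwise/columnwise usage flags as metadata instead of copying the
        # quantizer on every iteration.
        usage = (self._quantizer.rowwise_usage, self._quantizer.columnwise_usage)
        return (self._data,), (self._quantizer, usage)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (gathered_data,) = all_gather_outputs
        quantizer, (rowwise, columnwise) = metadata
        if out is not None:
            # Subsequent iterations: reuse the cached unsharded tensor; only the
            # raw data and usage flags are refreshed, no new quantizer copy.
            out._data = gathered_data
            out._quantizer.set_usage(rowwise=rowwise, columnwise=columnwise)
            return
        # First iteration: construct the unsharded tensor once; its constructor
        # makes the one and only quantizer copy.
        unsharded = make_unsharded_fp8_tensor(gathered_data, quantizer)
        unsharded._quantizer.set_usage(rowwise=rowwise, columnwise=columnwise)
        return unsharded, (gathered_data,)
```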
Additional Comments (1)
- transformer_engine/pytorch/optimizers/fused_adam.py, line 583: logic: When `p` is a `DTensor` wrapping a `Float8Tensor`, `p_grad` will also be a `DTensor`. Need to unwrap it before accessing `.data`.
3 files reviewed, 1 comment
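A short sketch of what the reviewer is asking for; the function name is hypothetical and this is not the code at line 583, just an illustration of unwrapping the gradient alongside the parameter.

```python
# Illustration of the review comment: unwrap the gradient too, not just the parameter.
from torch.distributed.tensor import DTensor


def unwrap_param_and_grad(p):
    """Return (local_param, local_grad) with any DTensor wrappers removed."""
    p_grad = p.grad
    if isinstance(p, DTensor):
        p = p._local_tensor  # FP8 shard held by the DTensor
        if isinstance(p_grad, DTensor):
            # The gradient of a DTensor parameter is itself a DTensor; without
            # unwrapping, .data would not expose the raw uint8 FP8 payload.
            p_grad = p_grad._local_tensor
    return p, p_grad
```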
/te-ci L1 pytorch
Signed-off-by: Varun Thumbe <[email protected]>
…503/TransformerEngine into optimization_for_fsdp_allgather Signed-off-by: Varun Thumbe <[email protected]>
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
/te-ci L1 pytorch
Additional Comments (1)
- transformer_engine/pytorch/optimizers/fused_adam.py, line 583: logic: When `p` is a `DTensor` wrapping `Float8Tensor`, verify that `p_grad` (obtained from `p.grad` before unwrapping on line 574) doesn't also need unwrapping. If `p_grad` is also a `DTensor`, accessing `.data` may not give the underlying FP8 uint8 data needed for the optimizer kernel.
3 files reviewed, 1 comment
3 files reviewed, no comments
Description
FSDP2 Allgather Perf improvement and support for FusedAdam with FSDP2
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
For FSDP2, FP8/MXFP8 tensors implement the fsdp_pre_all_gather method, which splits our tensors into a uint8 data tensor and metadata so that the unsharded FP8/MXFP8 tensor can be reconstructed after allgathering the uint8 data.
Previously, we created a copy of the quantizer while constructing the metadata on every pre_all_gather call, i.e. on each training iteration. This PR changes that so the copy is created only once, reducing CPU overhead.
Using FusedAdam with FP8 and FSDP2 currently results in a segfault, because the optimizer step only handled plain FP8 tensors and not the DTensor-wrapped FP8 tensors that FSDP2 produces. This is fixed now, and FusedAdam works as expected with FP8 and FSDP2.
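For reference, a hedged end-to-end sketch of the combination this PR targets (FSDP2-sharded FP8 parameters optimized by FusedAdam). The layer size, mesh setup, and training loop are placeholders, and the imports assume a recent PyTorch (with the public fully_shard API) and Transformer Engine; treat it as a usage illustration, not a test from this PR.

```python
# Hedged usage sketch -- sizes, mesh, and loop are placeholders; assumes torch.distributed
# has already been initialized (e.g. via torchrun) and CUDA + FP8-capable hardware.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.optimizers import FusedAdam
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # public path in recent PyTorch

mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

# Allocate FP8 parameters so that FSDP2 all-gathers uint8 data (the path optimized here).
with te.fp8_model_init(enabled=True):
    model = te.Linear(4096, 4096)

fully_shard(model, mesh=mesh)               # FP8 params become DTensor(Float8Tensor)
optimizer = FusedAdam(model.parameters())   # now handles DTensor-wrapped FP8 params

for _ in range(3):
    inp = torch.randn(16, 4096, device="cuda")
    with te.fp8_autocast(enabled=True):
        out = model(inp)
    out.sum().backward()
    optimizer.step()
    optimizer.zero_grad()
```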
Checklist: