
Conversation


@vthumbe1503 vthumbe1503 commented Nov 12, 2025

Description

FSDP2 Allgather Perf improvement and support for FusedAdam with FSDP2

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • For FSDP2, FP8/MXFP8 tensors implement the fsdp_pre_all_gather method, which splits our tensors into a uint8 data tensor plus metadata so that the unsharded FP8/MXFP8 tensor can be reconstructed after all-gathering the uint8 data.
    Previously, constructing that metadata created a copy of the quantizer on every pre-all-gather call, i.e. on every training iteration. This PR changes it to create the copy only once, reducing CPU overhead (see the sketch after this list).

    • Before change: FSDP::pre_forward time = 2.1 ms; after change: FSDP::pre_forward time = 1.9 ms
  • Using FusedAdam with FP8 and FSDP2 currently results in a segfault, because the optimizer step only handled plain FP8 tensors and not the DTensor-wrapped FP8 tensors that FSDP2 produces. This is fixed, and FusedAdam now works as expected with FP8 and FSDP2.
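
The following is a minimal, self-contained sketch of the all-gather optimization. The classes (ToyQuantizer, ToyFP8Tensor) and hook signatures are hypothetical stand-ins, not the actual TransformerEngine or FSDP2 APIs; the real extension hooks take additional arguments (module, mesh, output shapes) that are omitted here for brevity. The point it illustrates: only plain usage flags travel in the metadata, and the quantizer is copied once on the first post-all-gather call and then reused.

```python
# Minimal, self-contained sketch of the optimization (hypothetical classes,
# not the real TransformerEngine implementation): usage flags travel through
# the all-gather metadata as plain booleans, and the quantizer is copied only
# once, on the first post-all-gather call, then reused on later iterations.
import copy

import torch


class ToyQuantizer:
    """Stand-in for a TE quantizer: tracks rowwise/columnwise usage."""

    def __init__(self) -> None:
        self.rowwise_usage = True
        self.columnwise_usage = False

    def set_usage(self, rowwise: bool, columnwise: bool) -> None:
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise


class ToyFP8Tensor:
    """Stand-in for Float8Tensor/MXFP8Tensor: uint8 payload + quantizer."""

    def __init__(self, data: torch.Tensor, quantizer: ToyQuantizer) -> None:
        self._data = data            # raw uint8 shard that FSDP2 all-gathers
        self._quantizer = quantizer

    def fsdp_pre_all_gather(self, mesh=None):
        # Return the shard plus lightweight metadata. No quantizer copy here;
        # only the usage flags are packed, which is the CPU-overhead saving.
        usage = (self._quantizer.rowwise_usage, self._quantizer.columnwise_usage)
        return (self._data,), usage

    def fsdp_post_all_gather(self, gathered: torch.Tensor, metadata, out=None):
        rowwise, columnwise = metadata
        if out is None:
            # First iteration: copy the quantizer exactly once for the
            # unsharded tensor.
            out = ToyFP8Tensor(gathered, copy.copy(self._quantizer))
        else:
            # Later iterations: only refresh the payload; reuse the quantizer.
            out._data = gathered
        out._quantizer.set_usage(rowwise, columnwise)
        return out
```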

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@vthumbe1503 vthumbe1503 marked this pull request as ready for review November 12, 2025 19:29

greptile-apps bot commented Nov 12, 2025

Greptile Overview

Greptile Summary

This PR addresses two key issues with FSDP2 FP8 support:

Performance Optimization: Reduces CPU overhead in FSDP2 allgather operations by eliminating redundant quantizer copies on every training iteration. The quantizer is now copied only once during the first fsdp_post_all_gather call, then reused in subsequent iterations. Usage flags (rowwise/columnwise) are now passed through metadata instead of being embedded in copied quantizer objects.

FusedAdam Segfault Fix: Resolves segmentation fault when using FusedAdam optimizer with FP8 FSDP2 parameters. The issue occurred because FSDP2 wraps FP8 tensors in DTensor, but the optimizer only checked for Float8Tensor instances directly. The fix adds DTensor detection and unwraps to the local Float8Tensor before accessing FP8-specific attributes.

The changes apply identical optimization patterns to both Float8Tensor and MXFP8Tensor implementations, maintaining consistency across FP8 data types.
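
As a reference for the FusedAdam fix described above, here is a minimal, hypothetical sketch of the unwrapping pattern (not the actual fused_adam.py code): detect a DTensor whose local shard is a Float8Tensor and read the FP8-specific attributes from the local tensor. The helper name is invented, and the DTensor import path assumes a recent PyTorch (older releases expose it under torch.distributed._tensor).

```python
# Hypothetical helper illustrating the type check: unwrap a DTensor-wrapped
# FP8 parameter so FP8 attributes such as ._data / ._fp8_dtype are accessible.
import torch
from torch.distributed.tensor import DTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor


def maybe_unwrap_fp8(p: torch.Tensor) -> torch.Tensor:
    """Return the tensor that actually carries the FP8 payload."""
    if isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor):
        # FSDP2 wraps each FP8 parameter in a DTensor; the FP8 payload lives
        # on the local shard (equivalently, p.to_local()).
        return p._local_tensor
    return p
```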

Confidence Score: 5/5

  • This PR is safe to merge with high confidence - it addresses clear performance and correctness issues with well-scoped changes
  • The changes are straightforward optimizations and bug fixes with clear intent: (1) removing unnecessary quantizer copies reduces CPU overhead without affecting correctness since the constructor still makes the necessary copy on first use, and (2) adding DTensor detection fixes a clear type-checking gap that caused segfaults. Both changes follow consistent patterns across Float8Tensor and MXFP8Tensor implementations.
  • No files require special attention

Important Files Changed

File Analysis

Filename | Score | Overview
transformer_engine/pytorch/optimizers/fused_adam.py | 5/5 | Adds DTensor FP8 support by checking whether a parameter is a DTensor with a Float8Tensor local tensor and unwrapping it before accessing FP8 metadata
transformer_engine/pytorch/tensor/float8_tensor.py | 5/5 | Optimizes the FSDP2 all-gather by passing usage flags in metadata instead of copying the quantizer on every iteration, reducing CPU overhead
transformer_engine/pytorch/tensor/mxfp8_tensor.py | 5/5 | Mirrors the float8_tensor.py optimization by passing usage flags in metadata and avoiding repeated quantizer copies during the FSDP2 all-gather

Sequence Diagram

sequenceDiagram
    participant FSDP2 as FSDP2 Framework
    participant FP8Tensor as Float8/MXFP8 Tensor
    participant Quantizer as Quantizer (Shared)
    participant FusedAdam as FusedAdam Optimizer

    Note over FSDP2,FusedAdam: Scenario 1: FSDP2 Allgather (Performance Optimization)
    
    FSDP2->>FP8Tensor: fsdp_pre_all_gather(module, mesh)
    FP8Tensor->>FP8Tensor: Determine rowwise/columnwise usage
    FP8Tensor->>FSDP2: Return (sharded_data, metadata with usage flags)
    Note over FP8Tensor: No quantizer.copy() - uses self._quantizer
    
    FSDP2->>FSDP2: Perform allgather on sharded_data
    
    FSDP2->>FP8Tensor: fsdp_post_all_gather(outputs, metadata, out=None)
    Note over FP8Tensor: First iteration: out is None
    FP8Tensor->>FP8Tensor: Create new tensor with self._quantizer
    Note over Quantizer: Constructor makes one copy here
    FP8Tensor->>Quantizer: set_usage(rowwise, columnwise)
    FP8Tensor->>FSDP2: Return allgathered tensor
    
    FSDP2->>FP8Tensor: fsdp_post_all_gather(outputs, metadata, out=cached)
    Note over FP8Tensor: Subsequent iterations: reuse out
    FP8Tensor->>FP8Tensor: Update out._data only
    FP8Tensor->>Quantizer: set_usage(rowwise, columnwise)
    Note over FP8Tensor: No new quantizer copy
    FP8Tensor->>FSDP2: Return updated tensor
    
    Note over FSDP2,FusedAdam: Scenario 2: FusedAdam with FSDP2 (Bug Fix)
    
    FusedAdam->>FusedAdam: step() - iterate parameters
    FusedAdam->>FusedAdam: Check isinstance(p, Float8Tensor)
    Note over FusedAdam: Old: Fails for DTensor wrapper
    FusedAdam->>FusedAdam: Check isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor)
    Note over FusedAdam: New: Detects DTensor(Float8Tensor)
    FusedAdam->>FusedAdam: p = p._local_tensor (unwrap)
    FusedAdam->>FusedAdam: Access p._data, p._fp8_dtype
    Note over FusedAdam: Now works correctly, no segfault
    FusedAdam->>FusedAdam: Apply optimizer update
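
To make the two scenarios in the diagram concrete, here is a hedged end-to-end sketch of an FSDP2 + FP8 + FusedAdam training step. It assumes recent TransformerEngine and PyTorch (the public fully_shard lives in torch.distributed.fsdp in newer releases; older versions expose it under torch.distributed._composable.fsdp), is meant to be launched with torchrun (one process per GPU), and the shapes and hyperparameters are illustrative only.

```python
# Sketch of an FSDP2 + FP8 + FusedAdam step; launch with torchrun.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.pytorch.optimizers import FusedAdam
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point (recent PyTorch)

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# FP8 parameters, so FSDP2 all-gathers uint8 data plus metadata.
with te.fp8_model_init(enabled=True):
    model = te.Linear(1024, 1024, device="cuda")

fully_shard(model)                      # parameters become DTensor shards
optimizer = FusedAdam(model.parameters(), lr=1e-3)

inp = torch.randn(8, 1024, device="cuda")
with te.fp8_autocast(enabled=True):
    out = model(inp)
out.float().sum().backward()
optimizer.step()    # previously segfaulted for DTensor-wrapped FP8 params
optimizer.zero_grad()
dist.destroy_process_group()
```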


@greptile-apps greptile-apps bot left a comment


Additional Comments (1)

  1. transformer_engine/pytorch/optimizers/fused_adam.py, line 583 (link)

    logic: When p is a DTensor wrapping a Float8Tensor, p_grad will also be a DTensor and needs to be unwrapped before accessing .data (see the sketch below).
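
A tiny hypothetical sketch of the point above (the helper is invented, not part of fused_adam.py): if the parameter is a DTensor, its gradient is typically a DTensor as well, so both may need unwrapping before touching the raw data.

```python
# Hypothetical: unwrap both the parameter and its gradient from DTensor.
from torch.distributed.tensor import DTensor


def local_param_and_grad(p):
    grad = p.grad
    if isinstance(p, DTensor):
        p = p._local_tensor
    if isinstance(grad, DTensor):
        grad = grad._local_tensor
    return p, grad
```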

3 files reviewed, 1 comment


@vthumbe1503 vthumbe1503 requested a review from ptrendx November 12, 2025 22:30
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

vthumbe1503 and others added 3 commits November 13, 2025 21:20
@vthumbe1503
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Collaborator Author

/te-ci L1 pytorch


@greptile-apps greptile-apps bot left a comment


Additional Comments (1)

  1. transformer_engine/pytorch/optimizers/fused_adam.py, line 583 (link)

    logic: When p is a DTensor wrapping Float8Tensor, verify that p_grad (obtained from p.grad before unwrapping on line 574) doesn't also need unwrapping. If p_grad is also a DTensor, accessing .data may not give the underlying FP8 uint8 data needed for the optimizer kernel.

3 files reviewed, 1 comment



@greptile-apps greptile-apps bot left a comment


3 files reviewed, no comments


