7 changes: 7 additions & 0 deletions torchtitan/distributed/activation_checkpoint.py
@@ -76,6 +76,13 @@ def _apply_ac_to_transformer_block(

     def _get_custom_policy(meta):
         def _custom_policy(ctx, func, *args, **kwargs):
+            if (
+                func == torch.ops.aten._to_copy.default
+                and "cuda" in str(args[0].device)
+                and "device" in kwargs
+                and str(kwargs["device"]) == "cpu"
+            ):
+                return CheckpointPolicy.MUST_SAVE
Contributor:
Hmm, did this take effect? I would guess we don't need any d2h sync in backward anymore, but in the traces I'm still seeing them in backward.

(save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a_wait_tensor.json.gz&bucket=pytorch
(don't save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a.json.gz&bucket=pytorch

Contributor Author:
Hmm, there is still a cudaStreamSynchronize in FlexAttentionBackward, but that is expected, since SAC only takes effect for the replay of the forward. Is there another place where you see it?

Contributor:
According to @drisspg, this H2D in FlexAttentionBackward comes from the eager implementation; #1683 will fix the FlexAttention compilation issue.

             mode = "recompute" if ctx.is_recompute else "forward"
             mm_count_key = f"{mode}_mm_count"
             if func == torch.ops.aten.mm.default:
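For context, a minimal sketch of how a policy function like `_custom_policy` is plugged into selective activation checkpointing via `create_selective_checkpoint_contexts` (PyTorch 2.4+). This is not torchtitan's actual wiring; the toy `block`, the shapes, and the single CUDA device are assumptions for illustration only.

```python
# Hypothetical sketch: save GPU->CPU _to_copy outputs under SAC so the backward
# replay does not re-issue the device-to-host copy (and the sync it implies).
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def policy(ctx, func, *args, **kwargs):
    # Mirror of the policy above: keep the result of cuda->cpu _to_copy calls.
    if (
        func == torch.ops.aten._to_copy.default
        and "cuda" in str(args[0].device)
        and str(kwargs.get("device")) == "cpu"
    ):
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def block(x):
    # Toy stand-in for a block whose forward moves a small tensor to CPU.
    splits = x.sum(dim=1).to("cpu")
    return x * splits.to(x.device).unsqueeze(1)


x = torch.randn(8, 16, device="cuda", requires_grad=True)
out = checkpoint(
    block,
    x,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy),
)
out.sum().backward()
```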
24 changes: 14 additions & 10 deletions torchtitan/distributed/expert_parallel.py
@@ -8,9 +8,11 @@
 from typing import Callable, Literal

 import torch
-import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed._functional_collectives import all_to_all_single_autograd
+from torch.distributed._functional_collectives import (
+    all_to_all_single,
+    all_to_all_single_autograd,
+)
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -90,26 +92,28 @@ def _token_dispatch(self, mod, inputs, device_mesh):

         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
-            num_tokens_per_expert_group = num_tokens_per_expert.new_empty(
-                num_tokens_per_expert.shape[0]
-            )
-            dist.all_to_all_single(
-                num_tokens_per_expert_group,
+            num_tokens_per_expert_group = all_to_all_single(
                 num_tokens_per_expert,
                 None,
                 None,
                 group=device_mesh.get_group(),
             )
+            # Need to wait explicitly because it is used by a triton kernel later
+            # which doesn't realize that AsyncCollectiveTensor needs unwrapping
+            num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
+                num_tokens_per_expert_group
+            )
Contributor Author:
We need an explicit wait because num_tokens_per_expert_group gets used by a triton kernel, which doesn't realize that AsyncCollectiveTensor needs to be unwrapped.

Contributor:
could you make this a comment in the code? I think it's very helpful.
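To expand on the thread above: functional collectives return their result wrapped in an AsyncCollectiveTensor; ordinary torch ops unwrap the wrapper and implicitly wait on the collective, but a custom Triton kernel reads the raw memory and never triggers that wait, hence the explicit wait_tensor. A small sketch of that behavior, assuming an already-initialized NCCL process group and a CUDA device per rank (the setup code is illustrative, not part of the PR):

```python
# Illustrative only; assumes dist.init_process_group("nccl") has already run
# (e.g. under torchrun) and each rank has a CUDA device.
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import all_to_all_single

world_size = dist.get_world_size()
counts = torch.ones(world_size, dtype=torch.int64, device="cuda")

out = all_to_all_single(counts, None, None, group=dist.group.WORLD)
# `out` is an AsyncCollectiveTensor wrapper. Ordinary torch ops (out.sum(),
# out.view(...), ...) unwrap it and wait on the collective automatically, but a
# custom Triton kernel reads the raw memory and never triggers that wait, so
# the result is waited on and unwrapped explicitly:
out = torch.ops._c10d_functional.wait_tensor(out)
# `out` now behaves as a plain tensor whose values are ready on the device.
```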

             input_splits = (
                 num_tokens_per_expert.view(ep_size, -1)
                 .sum(dim=1)
                 .to(torch.device("cpu"), non_blocking=True)
             )
+            # NOTE: this would incur a device-to-host sync
             output_splits = (
                 num_tokens_per_expert_group.view(ep_size, -1)
                 .sum(dim=1)
-                .to(torch.device("cpu"), non_blocking=True)
+                .to(torch.device("cpu"), non_blocking=False)
             )
-            # NOTE: this would incur a device-to-host sync
-            torch.cuda.current_stream().synchronize()
             self.input_splits = input_splits.tolist()
             self.output_splits = output_splits.tolist()

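The NOTE above now marks the blocking `.to("cpu")` itself as the point where the device-to-host sync happens, replacing the explicit `torch.cuda.current_stream().synchronize()`. A standalone sketch of that pattern, with made-up sizes in place of the real expert counts:

```python
# Illustrative sketch: a blocking device-to-host copy is itself the sync point.
import torch

ep_size = 4
# Made-up per-expert token counts standing in for num_tokens_per_expert_group.
num_tokens_per_expert_group = torch.randint(0, 128, (ep_size * 2,), device="cuda")

output_splits = (
    num_tokens_per_expert_group.view(ep_size, -1)
    .sum(dim=1)
    # non_blocking=False: the host blocks here until the values reach the CPU
    .to(torch.device("cpu"), non_blocking=False)
)
print(output_splits.tolist())  # values are already valid; no extra synchronize needed
```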
1 change: 1 addition & 0 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -38,6 +38,7 @@
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    torch.ops._c10d_functional.all_to_all_single.default,
     # for low precision training, it's useful to always save
     # the result of max, since the absolute maximum is
     # used to compute the scaling factor for quantization.
1 change: 1 addition & 0 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -35,6 +35,7 @@
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    torch.ops._c10d_functional.all_to_all_single.default,
     # for low precision training, it's useful to always save
     # the result of max, since the absolute maximum is
     # used to compute the scaling factor for quantization.
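Both parallelize.py changes add the functional all_to_all_single op to the SAC save-list, so the output of the token-shuffle collective is stored during the forward rather than re-issued during recompute. A minimal sketch of how such a save-list is typically consumed by a policy; torchtitan's real policy also alternates on aten.mm (the mm_count logic shown in activation_checkpoint.py above), and the list below is abbreviated for illustration:

```python
# Sketch of an op save-list driving a selective-AC policy (abbreviated).
import torch
from torch.utils.checkpoint import CheckpointPolicy

_save_list = {
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # added in this PR: keep the all-to-all output so backward does not rerun it
    torch.ops._c10d_functional.all_to_all_single.default,
}


def _policy(ctx, func, *args, **kwargs):
    # Save outputs of ops in the list; prefer recomputing everything else.
    if func in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE
```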