
Conversation

soulitzer
Contributor

@soulitzer soulitzer commented Sep 1, 2025

CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ./run_train.sh --parallelism.expert_parallel_degree 4 --model.hf_assets_path "./assets/hf/deepseek-moe-16b-base"

Before (not saving a2a and to_copy)

[rank0]:[titan] 2025-09-01 16:03:30,779 - root - INFO - step:  1  loss: 12.0469  grad_norm:  1.8420  memory: 61.72GiB(77.99%)  tps: 297  tflops: 4.48  mfu: 1.43%
[rank0]:[titan] 2025-09-01 16:03:30,780 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-01 16:04:22,882 - root - INFO - step: 10  loss: 11.2800  grad_norm:  2.4321  memory: 70.02GiB(88.48%)  tps: 708  tflops: 10.65  mfu: 3.41%

https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_recompute.json.gz&bucket=pytorch

After (saving a2a and to_copy)

[rank0]:[titan] 2025-09-01 16:01:39,691 - root - INFO - step:  1  loss: 12.0470  grad_norm:  1.8420  memory: 64.49GiB(81.50%)  tps: 321  tflops: 4.82  mfu: 1.55%
[rank0]:[titan] 2025-09-01 16:01:39,691 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-01 16:02:25,603 - root - INFO - step: 10  loss: 11.2801  grad_norm:  2.4322  memory: 74.53GiB(94.17%)  tps: 803  tflops: 12.08  mfu: 3.87%

(save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a_wait_tensor.json.gz&bucket=pytorch

(don't save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a.json.gz&bucket=pytorch

@meta-cla meta-cla bot added the CLA Signed label Sep 1, 2025
num_tokens_per_expert.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
.to(torch.device("cpu"), non_blocking=False)
Contributor

Oh, could you remind me of the reason we have to use non_blocking=False?
I think it may not matter too much, as these two d2h syncs are adjacent to each other.
If we do have to do this, we can remove the non_blocking arg, since False is the default.

Contributor Author

@soulitzer soulitzer Sep 2, 2025

That is a good point; we can avoid blocking for the first .to(), although I don't think it changed tps very much.
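
For illustration, a minimal sketch of the pattern being discussed (placeholder tensor names and shapes, not the torchtitan code): the first device-to-host copy can stay non-blocking because the adjacent blocking copy on the same stream acts as the synchronization point.

import torch

if torch.cuda.is_available():
    ep_size = 4
    num_tokens_per_expert = torch.randint(0, 8, (ep_size * 4,), device="cuda")
    input_splits = torch.randint(0, 8, (ep_size,), device="cuda")

    # First d2h copy: can be issued non-blocking.
    counts_cpu = (
        num_tokens_per_expert.view(ep_size, -1)
        .sum(dim=1)
        .to(torch.device("cpu"), non_blocking=True)
    )
    # Second, adjacent d2h copy: non_blocking defaults to False, so it blocks the
    # host and, being stream-ordered after the first copy, ensures both CPU
    # tensors are ready before they are read.
    splits_cpu = input_splits.to(torch.device("cpu"))

    print(counts_cpu.tolist(), splits_cpu.tolist())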

torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
torch.ops._c10d_functional.all_to_all_single.default,
Contributor

Could you confirm that this saves both:

  1. the dist.all_to_all_single to obtain routing info
  2. the actual all_to_all_single_autograd to route tokens

I think ideally we'd like both to be saved.

Contributor Author

Good catch, dist.all_to_all_single is actually a different op.

Do you know why we use two different all-to-alls here? I don't think dist.all_to_all_single works with SAC; it mutates an output tensor that the user provides and returns a work object.

Contributor

Oh, there may not be a strong reason to.
Could you try the funcol version? https://github.com/pytorch/pytorch/blob/main/torch/distributed/_functional_collectives.py#L445
If it works we can switch to this one, and hopefully the AC policy would capture both, because underneath the same torch.ops._c10d_functional.all_to_all_single.default gets called.

Contributor Author

Updated to use the funcol version.
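
For reference, a single-rank sketch contrasting the two APIs (illustrative only; the gloo group setup and variable names are placeholders, not the torchtitan code):

import os

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

num_tokens_per_expert = torch.tensor([3, 5, 2, 7], dtype=torch.int64)

# In-place collective: mutates a caller-provided output tensor and, with
# async_op=True, returns a Work object, so there is no output for SAC to save.
out = torch.empty_like(num_tokens_per_expert)
dist.all_to_all_single(out, num_tokens_per_expert)

# Functional collective: returns a new tensor and dispatches through
# torch.ops._c10d_functional.all_to_all_single.default, which an SAC op
# save-list can match.
gathered = funcol.all_to_all_single(
    num_tokens_per_expert, None, None, group=dist.group.WORLD
)
gathered = torch.ops._c10d_functional.wait_tensor(gathered)

dist.destroy_process_group()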

@fegin
Contributor

fegin commented Sep 3, 2025

Just an FYI, @soulitzer: #1675 conflicts with this PR.

)
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
num_tokens_per_expert_group
)
Contributor Author

We need an explicit wait because num_tokens_per_expert_group gets used by a Triton kernel, which doesn't know that the AsyncCollectiveTensor needs to be unwrapped.

Contributor

could you make this a comment in the code? I think it's very helpful.

Contributor

@tianyu-l tianyu-l left a comment

Had one more question regarding saving the results of wait_tensor.
Also it would be great if you could share some profiler traces in the PR summary.

torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
torch.ops._c10d_functional.all_to_all_single.default,
torch.ops._c10d_functional.wait_tensor.default,
Contributor

I wonder if this has any side effects, since the mapping from collectives to wait is many-to-one. In particular,

  1. Would this line save all the communication results, not only from a2a but also e.g. the TP all-gather?
  2. Would not having this line mean that none of the communication results are saved? I.e., did the torch.ops._c10d_functional.reduce_scatter_tensor.default line take effect?

Contributor Author

@soulitzer soulitzer Sep 4, 2025

I don't think there should be side effects, unless there are other wait_tensors being called explicitly.

  1. Ordinarily, the AsyncCollectiveTensor triggers the wait before executing the op in its torch dispatch, so the wait would actually be hidden from SAC (user modes execute before user subclasses unwrap). SAC should only be able to see / save the wait if we call it explicitly here.

  2. The lines for reduce scatter, etc. will save the AsyncCollectiveTensor, and in the original forward, when the wait happens via the subclass's torch dispatch, the wait result should be cached onto the AsyncCollectiveTensor, so a second wait should not be triggered during recompute.

That being said, I'm not entirely sure what happens when you execute wait explicitly on an AsyncCollectiveTensor whose collective has already been waited on. Checking again, removing it doesn't seem to affect tps, so I think I will remove it.
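
For reference, here is a generic sketch of how an op save-list like the ones in this diff feeds into SAC through the public torch.utils.checkpoint API. The op list and module below are stand-ins (a toy addmm example), and the d2h-copy branch mirrors the _to_copy/cpu check from the diff; this is not the torchtitan AC wrapper.

from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Stand-in save-list; the real one lists the attention and collective ops above.
_save_list = {torch.ops.aten.addmm.default}

def _policy(ctx, op, *args, **kwargs):
    # Save outputs of listed ops during forward; recompute everything else.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    # Also save d2h copies (aten._to_copy to cpu) so the blocking sync is not
    # repeated during recompute, mirroring the check in the diff.
    if (
        op == torch.ops.aten._to_copy.default
        and "device" in kwargs
        and str(kwargs["device"]) == "cpu"
    ):
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = partial(create_selective_checkpoint_contexts, _policy)

mlp = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)
out = checkpoint(mlp, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()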

@soulitzer
Contributor Author

Added links to some profiler traces in the summary. From staring at the traces, saving the wait_tensor reduces the CPU overhead from 200us to 40us, but doesn't really seem to affect tps, so I'm removing it to minimize the risk of side effects.

)
num_tokens_per_expert_group = torch.ops._c10d_functional.wait_tensor(
num_tokens_per_expert_group
)
Contributor

could you make this a comment in the code? I think it's very helpful.

.to(torch.device("cpu"), non_blocking=True)
.to(torch.device("cpu"), non_blocking=False)
)
# NOTE: this would incur a device-to-host sync
Contributor

please move this note to the actual blocking call above

and "device" in kwargs
and str(kwargs["device"]) == "cpu"
):
return CheckpointPolicy.MUST_SAVE
Contributor

Hmm, did this take effect? I would guess we don't need any d2h sync in backward anymore, but in the traces I'm still seeing them in backward.

(save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a_wait_tensor.json.gz&bucket=pytorch
(don't save wait_tensor) https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/jw3468/20250904/sac_save_to_copy_a2a.json.gz&bucket=pytorch

Contributor Author

Hmm, there is still a cudaStreamSync in FlexAttentionBackward, but that is expected since SAC only takes effect for the replay of the forward. Is there another place where you see it?

Contributor

According to @drisspg, this H2D in FlexAttentionBackward is from the eager implementation. #1683 will fix the FlexAttention compilation issue.

Contributor

@tianyu-l tianyu-l left a comment

looks great!

@tianyu-l tianyu-l merged commit c4e2291 into main Sep 6, 2025
8 checks passed
Labels: CLA Signed, high priority, module: activation checkpointing