
Conversation

@wwwjn (Contributor) commented Sep 8, 2025

As titled.

Recommended training config with high TPS and MFU:

  1. local batch size=2, FSDP=8, full AC

When torch.compile is enabled, FSDP=8 gives the following results:

[rank7]:[titan] 2025-09-11 15:23:55,616 - root - INFO - step: 50  loss:  7.1291  grad_norm:  8.3760  memory: 81.85GiB(86.17%)  tps: 1,874  tflops: 389.87  mfu: 39.42%
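For reference, here is a minimal sketch of how the recommended setting above could map onto the torchtitan TOML. Only the parallelism and optimizer field names are taken from the review comments below; the section names and the remaining fields are assumptions and may differ across torchtitan versions:

```toml
# Sketch only, not the exact preset added in this PR.
# Recommended: local batch size 2, 8-way FSDP, full activation checkpointing, seq_len 4096.
[training]
batch_size = 2            # local (per data-parallel rank) batch size; field name assumed
seq_len = 4096

[parallelism]             # section name assumed; field names quoted from the review below
data_parallel_replicate_degree = 1
data_parallel_shard_degree = 8          # FSDP degree (-1 would infer it from the remaining GPUs)
fsdp_reshard_after_forward = "default"  # default / never / always
tensor_parallel_degree = 1

[activation_checkpoint]   # section name assumed
mode = "full"             # full AC

[optimizer]
name = "AdamW"
lr = 3e-4
```

torch.compile was additionally enabled for the tflops/MFU numbers quoted just above; the exact compile option is left out of the sketch since its spelling varies across torchtitan versions.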

Some other recommended training configs (no compile applied for the following benchmarks):

| Local bs (global bs) | seq_len | Training parallelism | Performance |
| --- | --- | --- | --- |
| 16 (64) | 4096 | FSDP 4, TP 8, full AC | step: 20, loss: 8.0074, grad_norm: 5.1354, memory: 43.45GiB (45.73%), tps: 876, tflops: 182.14, mfu: 18.42% |
| 32 (128) | 4096 | FSDP 4, TP 8, full AC | step: 10, loss: 10.7079, grad_norm: 12.3490, memory: 70.79GiB (74.51%), tps: 883, tflops: 183.59, mfu: 18.56% |
| 32 (256) | 4096 | FSDP 8, TP 8, full AC | step: 70, loss: 6.7005, grad_norm: 5.5571, memory: 63.27GiB (66.60%), tps: 873, tflops: 181.63, mfu: 18.36% |
| 64 (256) | 4096 | FSDP 4, TP 8, full AC | OOM |
| 64 (512) | 4096 | FSDP 8, TP 8, full AC | OOM |
| 1 (8) | 4096 | FSDP 8 | step: 50, loss: 7.1103, grad_norm: 2.5687, memory: 81.47GiB (85.76%), tps: 1,444, tflops: 300.45, mfu: 30.38% |
| 2 (16) | 4096 | FSDP 8 | step: 50, loss: 7.0498, grad_norm: 4.9964, memory: 89.33GiB (94.04%), tps: 1,614, tflops: 335.70, mfu: 33.94% |
| 4 (64) | 4096 | FSDP 16 | step: 80, loss: 6.6554, grad_norm: 4.3011, memory: 78.48GiB (82.54%), tps: 1,518, tflops: 315.72, mfu: 31.92% |
| 8 | 4096 | FSDP 16 | OOM |
| 16 | 4096 | FSDP 16 | OOM |
| 16 | 4096 | FSDP 32 | OOM |
| 16 | 4096 | FSDP 16, TP 2 | OOM |
| 16 | 4096 | FSDP 8, TP 4 | step: 90, loss: 7.4707, grad_norm: 8.5404, memory: 70.20GiB (73.89%), tps: 1,139, tflops: 236.94, mfu: 23.96% |
| 32 | 4096 | FSDP 8, TP 4 | OOM |
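(In this table, the global batch size in parentheses is local batch size × data-parallel (FSDP) degree, e.g. 4 × 16 = 64 for the FSDP 16 row and 16 × 4 = 64 for the FSDP 4, TP 8 row; TP ranks share the same data.)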

@wwwjn requested a review from tianyu-l September 8, 2025 20:18
@tianyu-l (Contributor) left a comment:

Curious why you pick FSDP=8, TP=4 when FSDP 16 seems to give better MFU?

```toml
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 4
```
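(For context on the hunk above: as far as I understand, data_parallel_shard_degree = -1 lets FSDP take whatever GPUs remain after the other parallelisms, so on a 32-GPU job with tensor_parallel_degree = 4 this resolves to the FSDP=8, TP=4 setting discussed here.)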
```toml
[optimizer]
name = "AdamW"
lr = 3e-4
```
@tianyu-l (Contributor):

I'm not sure how far we'd like to go, but ideally a recommended config should have its lr and batch size verified by some loss-convergence runs. So it's about both perf and reasonable convergence.

@wwwjn (Contributor, author) replied:

For the learning rate, I checked the official tech report and it doesn't mention a recommended learning rate. I will kick off a training job to verify convergence as well.

@wwwjn (Contributor, author) commented Sep 10, 2025

> Curious why you pick FSDP=8, TP=4 when FSDP 16 seems to give better MFU?

I checked the overall time for each step; the total time for each profiler step is:

  1. FSDP=16, local bs = 4: 10s for each step, mfu: 31.92%
  2. FSDP=8, TP=4, local bs=16: 15s for each step, mfu: 23.96%
  3. FSDP=4, TP=8, local bs=32: 19.6s for each step, mfu: 18.36%

Comparing these 3 settings, I chose setting 2 because it overall runs faster for the same number of data samples, with OK MFU. But I would also want to let users choose whether they want better MFU/tflops or the shortest total training time.
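As a rough sanity check on the throughput claim (assuming the tps reported in the logs is per GPU, which is consistent with the step times above): setting 1 runs on 16 GPUs at ~1,518 tps each, i.e. roughly 24K tokens/s globally, while setting 2 runs on 32 GPUs at ~1,139 tps each, i.e. roughly 36K tokens/s globally, which is why it gets through the same amount of data faster despite the lower per-GPU MFU.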

@tianyu-l (Contributor) commented:
@wwwjn
I see you are trying to optimize throughput rather than per-device MFU. But even for throughput, the comparison doesn't seem fair.
The FSDP 16 run uses half the GPUs compared with the other two. I feel that if you switch to FSDP 32 with local batch size 4, it will give the best throughput plus better MFU compared to the other FSDP+TP runs.

I think naively, the better MFU, the faster you train?

@wwwjn (Contributor, author) commented Sep 10, 2025

> @wwwjn I see you are trying to optimize throughput rather than per-device MFU. But even for throughput, the comparison doesn't seem fair. The FSDP 16 run uses half the GPUs compared with the other two. I feel that if you switch to FSDP 32 with local batch size 4, it will give the best throughput plus better MFU compared to the other FSDP+TP runs.
>
> I think naively, the better MFU, the faster you train?

Agreed: better tflops (assuming the total flops for the model are the same) means faster training. That would be simpler, so we will target optimizing tflops/MFU (gbs 64 seems OK).
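(For reference, the MFU numbers in these logs are consistent with achieved tflops divided by an assumed H100 BF16 dense peak of roughly 989 TFLOPS per GPU, e.g. 315.72 / 989 ≈ 31.9% and 335.70 / 989 ≈ 33.9%, so ranking configs by tflops and ranking them by MFU give the same ordering here.)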

Here's the FSDP=16 profiler trace: good comm/comp overlap, but CPU bound. (screenshot omitted)

```
@@ -0,0 +1,64 @@
# NOTE: this toml config is a preset for 8 H100 GPUs.
```
@tianyu-l (Contributor):

FSDP 8 sounds good enough, although not all H100s have 96GB memory (we are using something like https://github.com/pytorch/torchtitan/blob/main/benchmarks/llama3_h100_202412_torchtitan.md?plain=1#L13); some of them only have 80GB. We can add that to the note here.
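(For what it's worth, the memory percentages in the benchmark logs above imply roughly 95GiB devices, e.g. 81.85GiB / 86.17% ≈ 95GiB and 89.33GiB / 94.04% ≈ 95GiB, so the local batch size 2 run at ~89GiB would not fit on an 80GB H100.)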

@wwwjn (Contributor, author) commented Sep 18, 2025

Adding a little more context on the learning rate here:

  1. The largest lr I've tested is 1e-2, and the model converges smoothly in the first 5k steps. (screenshot omitted)
  2. We tested learning rates of 1.5e-4, 3e-4, 8e-4, 1e-3, 5e-3, and 1e-2, and all of them converge, with 3e-4 and 8e-4 converging slightly faster than the other lrs (judging solely by the training loss). (screenshots omitted)

@wwwjn merged commit d66b72a into main Sep 18, 2025
4 checks passed
@tianyu-l deleted the qwen-32b branch September 18, 2025 02:41