Conversation

weifengpy (Contributor) commented Sep 12, 2025

Temporary fix for #1136, where the loss is worse when resuming training from a checkpoint.

The loss becomes exactly the same after disabling cache_state_dict. I still need to understand why this only happens with uneven sharding, but am providing this workaround to unblock customers.
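
For context, here is a minimal sketch of the pattern the workaround changes; the Checkpointer class and cache_state_dict flag below are illustrative stand-ins, not torchtitan's exact API. The point is that a state dict captured at construction time can go stale once FSDP2 pads parameters in place during lazy init, whereas re-fetching it at save/load time sees the current storage.

import torch.nn as nn


class Checkpointer:
    # hypothetical, simplified checkpointer; names are assumptions for illustration
    def __init__(self, model: nn.Module, cache_state_dict: bool = False):
        self.model = model
        self.cache_state_dict = cache_state_dict
        # caching here happens before FSDP2 lazy init, so under uneven sharding the
        # cached entries can reference pre-padding local tensors
        self._cached_state_dict = model.state_dict() if cache_state_dict else None

    def model_state_dict(self) -> dict:
        if self.cache_state_dict and self._cached_state_dict is not None:
            return self._cached_state_dict
        # workaround path: re-fetch so the dict reflects the current (padded) storage
        return self.model.state_dict()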

1st run, saving checkpoints at steps 10 and 20: NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

[rank0]:[titan] 2025-09-11 17:33:50,431 - root - INFO - step:  1  loss:  8.0403  grad_norm:  1.3470  memory:  1.02GiB(1.08%)  tps: 27,483  tflops: 1.97  mfu: 0.20%
[rank0]:[titan] 2025-09-11 17:33:50,432 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-09-11 17:33:50,479 - root - INFO - step:  2  loss:  7.7561  grad_norm:  1.4095  memory:  1.07GiB(1.13%)  tps: 346,268  tflops: 24.77  mfu: 2.50%
[rank0]:[titan] 2025-09-11 17:33:50,520 - root - INFO - step:  3  loss:  7.0386  grad_norm:  1.8312  memory:  1.07GiB(1.13%)  tps: 403,390  tflops: 28.86  mfu: 2.92%
[rank0]:[titan] 2025-09-11 17:33:50,561 - root - INFO - step:  4  loss:  6.2345  grad_norm:  2.2490  memory:  1.07GiB(1.13%)  tps: 399,989  tflops: 28.61  mfu: 2.89%
[rank0]:[titan] 2025-09-11 17:33:50,601 - root - INFO - step:  5  loss:  5.2014  grad_norm:  2.4889  memory:  1.07GiB(1.13%)  tps: 414,189  tflops: 29.63  mfu: 3.00%
[rank0]:[titan] 2025-09-11 17:33:50,647 - root - INFO - step:  6  loss:  4.6400  grad_norm:  2.2484  memory:  1.07GiB(1.13%)  tps: 360,624  tflops: 25.80  mfu: 2.61%
[rank0]:[titan] 2025-09-11 17:33:50,687 - root - INFO - step:  7  loss:  4.2797  grad_norm:  2.0226  memory:  1.07GiB(1.13%)  tps: 415,778  tflops: 29.74  mfu: 3.01%
[rank0]:[titan] 2025-09-11 17:33:50,726 - root - INFO - step:  8  loss:  4.0060  grad_norm:  1.7718  memory:  1.07GiB(1.13%)  tps: 418,295  tflops: 29.92  mfu: 3.03%
[rank0]:[titan] 2025-09-11 17:33:50,766 - root - INFO - step:  9  loss:  3.9769  grad_norm:  1.4904  memory:  1.07GiB(1.13%)  tps: 418,463  tflops: 29.94  mfu: 3.03%
[rank0]:[titan] 2025-09-11 17:33:50,810 - root - INFO - step: 10  loss:  3.6987  grad_norm:  1.4993  memory:  1.07GiB(1.13%)  tps: 367,048  tflops: 26.26  mfu: 2.65%
[rank0]:[titan] 2025-09-11 17:33:50,811 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank0]:[titan] 2025-09-11 17:33:51,843 - root - INFO - [GC] GC collection invoked by checkpointer. 0.01 seconds
[rank0]:[titan] 2025-09-11 17:33:51,843 - root - INFO - Finished saving the checkpoint (or staging if async is enabled)in 1.03 seconds.
[rank0]:[titan] 2025-09-11 17:33:51,892 - root - INFO - step: 11  loss:  3.5623  grad_norm:  1.3196  memory:  1.07GiB(1.13%)  tps: 15,158  tflops: 1.08  mfu: 0.11%
[rank0]:[titan] 2025-09-11 17:33:51,933 - root - INFO - step: 12  loss:  3.5165  grad_norm:  1.0960  memory:  1.07GiB(1.13%)  tps: 401,532  tflops: 28.72  mfu: 2.90%
[rank0]:[titan] 2025-09-11 17:33:51,976 - root - INFO - step: 13  loss:  3.4645  grad_norm:  0.9459  memory:  1.07GiB(1.13%)  tps: 377,112  tflops: 26.98  mfu: 2.73%
[rank0]:[titan] 2025-09-11 17:33:52,026 - root - INFO - step: 14  loss:  3.3633  grad_norm:  0.9229  memory:  1.07GiB(1.13%)  tps: 330,803  tflops: 23.66  mfu: 2.39%
[rank0]:[titan] 2025-09-11 17:33:52,067 - root - INFO - step: 15  loss:  3.3642  grad_norm:  0.8076  memory:  1.07GiB(1.13%)  tps: 409,659  tflops: 29.31  mfu: 2.96%
[rank0]:[titan] 2025-09-11 17:33:52,114 - root - INFO - step: 16  loss:  3.3604  grad_norm:  0.7525  memory:  1.07GiB(1.13%)  tps: 351,022  tflops: 25.11  mfu: 2.54%
[rank0]:[titan] 2025-09-11 17:33:52,159 - root - INFO - step: 17  loss:  3.2142  grad_norm:  0.7370  memory:  1.07GiB(1.13%)  tps: 365,952  tflops: 26.18  mfu: 2.65%
[rank0]:[titan] 2025-09-11 17:33:52,199 - root - INFO - step: 18  loss:  3.1889  grad_norm:  0.7102  memory:  1.07GiB(1.13%)  tps: 408,904  tflops: 29.25  mfu: 2.96%
[rank0]:[titan] 2025-09-11 17:33:52,242 - root - INFO - step: 19  loss:  3.3118  grad_norm:  0.6511  memory:  1.07GiB(1.13%)  tps: 379,809  tflops: 27.17  mfu: 2.75%
[rank0]:[titan] 2025-09-11 17:33:52,283 - root - INFO - step: 20  loss:  3.2661  grad_norm:  0.6239  memory:  1.07GiB(1.13%)  tps: 408,047  tflops: 29.19  mfu: 2.95%

2nd run, resuming from the step-10 checkpoint: rm -rf outputs/checkpoint/step-20 && NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

[rank0]:[titan] 2025-09-11 17:34:17,543 - root - INFO - Training starts at step 11
[rank0]:[titan] 2025-09-11 17:34:17,884 - root - INFO - step: 11  loss:  3.5623  grad_norm:  1.3196  memory:  1.03GiB(1.09%)  tps: 10,552  tflops: 0.75  mfu: 0.08%
[rank0]:[titan] 2025-09-11 17:34:17,930 - root - INFO - step: 12  loss:  3.5165  grad_norm:  1.0960  memory:  1.07GiB(1.13%)  tps: 360,919  tflops: 25.82  mfu: 2.61%
[rank0]:[titan] 2025-09-11 17:34:17,975 - root - INFO - step: 13  loss:  3.4645  grad_norm:  0.9459  memory:  1.07GiB(1.13%)  tps: 361,003  tflops: 25.83  mfu: 2.61%
[rank0]:[titan] 2025-09-11 17:34:18,017 - root - INFO - step: 14  loss:  3.3633  grad_norm:  0.9229  memory:  1.07GiB(1.13%)  tps: 401,298  tflops: 28.71  mfu: 2.90%
[rank0]:[titan] 2025-09-11 17:34:18,057 - root - INFO - step: 15  loss:  3.3642  grad_norm:  0.8077  memory:  1.07GiB(1.13%)  tps: 404,723  tflops: 28.95  mfu: 2.93%
[rank0]:[titan] 2025-09-11 17:34:18,106 - root - INFO - step: 16  loss:  3.3604  grad_norm:  0.7525  memory:  1.07GiB(1.13%)  tps: 339,546  tflops: 24.29  mfu: 2.46%
[rank0]:[titan] 2025-09-11 17:34:18,154 - root - INFO - step: 17  loss:  3.2142  grad_norm:  0.7370  memory:  1.07GiB(1.13%)  tps: 340,067  tflops: 24.33  mfu: 2.46%
[rank0]:[titan] 2025-09-11 17:34:18,195 - root - INFO - step: 18  loss:  3.1889  grad_norm:  0.7102  memory:  1.07GiB(1.13%)  tps: 403,231  tflops: 28.85  mfu: 2.92%
[rank0]:[titan] 2025-09-11 17:34:18,243 - root - INFO - step: 19  loss:  3.3118  grad_norm:  0.6511  memory:  1.07GiB(1.13%)  tps: 346,384  tflops: 24.78  mfu: 2.51%
[rank0]:[titan] 2025-09-11 17:34:18,284 - root - INFO - step: 20  loss:  3.2661  grad_norm:  0.6239  memory:  1.07GiB(1.13%)  tps: 399,322  tflops: 28.57  mfu: 2.89%

weifengpy marked this pull request as draft September 12, 2025 00:47
weifengpy (Contributor, Author) commented

Keeping this as a draft; it is here to unblock customers who urgently need it. I still need to understand how cache_state_dict is related to uneven sharding.

fegin (Contributor) commented Sep 12, 2025

This is an excellent finding, but I don't understand why. We cache during the Checkpointer ctor, which should be after the model is wrapped and before training. However, the timing of DCP.load should also be the same, so caching or not caching shouldn't make a big difference.

weifengpy (Contributor, Author) commented

Discussed with @fegin: the core issue is making sure model.state_dict() returns padded parameters. There are two ways to achieve this:

  • let fsdp2 do in-place padding
  • call fsdp2 lazy init in state_dict()

A minimal repro:

# torchrun --standalone --nproc_per_node=2 run_fsdp2.py

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

def main():
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    input = torch.randn((3, 3), device=device)
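    # nn.Linear(3, 3) sharded over 2 ranks is uneven: the last rank gets padded at lazy init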
    model = nn.Linear(3, 3, bias=False)
    fully_shard(model, reshard_after_forward=True)
    # cache the state dict before lazy init, mimicking the Checkpointer's cache_state_dict
    state_dict = model.state_dict()

    print(f"rank:{torch.distributed.get_rank()} before lazy_init {state_dict['weight']._local_tensor.untyped_storage().data_ptr() == model.weight._local_tensor.untyped_storage().data_ptr()}")
    loss = model(input).sum()
    # the last rank now holds a stale reference: FSDP2 padded model.parameters() during lazy init
    print(f"rank:{torch.distributed.get_rank()} after lazy_init {state_dict['weight']._local_tensor.untyped_storage().data_ptr() == model.weight._local_tensor.untyped_storage().data_ptr()}")


if __name__ == "__main__":
    main()
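
As a follow-up, the sketch below (same setup and assumptions as the repro above; the file name run_fsdp2_fixed.py is hypothetical) shows the flip side: if the state dict is fetched only after lazy init has run, its local tensors share storage with model.weight on every rank, which is effectively what disabling cache_state_dict buys us.

# torchrun --standalone --nproc_per_node=2 run_fsdp2_fixed.py

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

def main():
    dist.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    device = f"cuda:{gpu_id}"
    torch.cuda.set_device(device)
    torch.manual_seed(0)
    input = torch.randn((3, 3), device=device)
    model = nn.Linear(3, 3, bias=False)
    fully_shard(model, reshard_after_forward=True)

    # run one forward first so FSDP2 lazy init (and its padding) has already happened
    loss = model(input).sum()

    # fetched after lazy init: the cached entry and the live parameter share storage on all ranks
    state_dict = model.state_dict()
    print(f"rank:{torch.distributed.get_rank()} after lazy_init {state_dict['weight']._local_tensor.untyped_storage().data_ptr() == model.weight._local_tensor.untyped_storage().data_ptr()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()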
