2 changes: 2 additions & 0 deletions repro.sh
@@ -0,0 +1,2 @@
NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh
rm -rf outputs/checkpoint/step-30 && NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh
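The repro presumably works as follows: the first command trains the debug model for 30 steps, and with the debug_model.toml changes below (steps = 30, interval = 15) checkpoints land at step 15 and step 30; the second command deletes the step-30 checkpoint and reruns, so training resumes from step 15 and exercises the save-then-immediately-reload path instrumented in checkpoint.py below.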
88 changes: 87 additions & 1 deletion torchtitan/components/checkpoint.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
import enum
import functools
import os
@@ -59,6 +60,7 @@ class ModelWrapper(Stateful):
def __init__(self, model: nn.Module | list[nn.Module]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = self._get_state_dict()
assert self.cache_state_dict['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.model[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr()

def _get_state_dict(self) -> dict[str, Any]:
state_dict = {
@@ -231,6 +233,11 @@ def load_state_dict(state_dict):
LR_SCHEDULER: lr_schedulers,
}
)
try:
assert self.states['model'].state_dict()['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.states[MODEL].model[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr()
except:
import fbvscode
fbvscode.set_trace()
self.ft_states = {DATALOADER: dataloader}

self.staging = False
@@ -397,6 +404,7 @@ def dcp_load(
state_dict: dict[str, Any],
checkpoint_id: str,
from_hf: bool,
step: int = -1,
) -> None:
"""Load the checkpoint with dcp.
Args:
@@ -420,12 +428,56 @@
state_dict = self.sd_adapter.from_hf(hf_state_dict)
self.states[MODEL].load_state_dict(state_dict)
else:
before_load = state_dict['tok_embeddings.weight']._local_tensor.clone()
before_load_full = state_dict['tok_embeddings.weight'].full_tensor().clone()
try:
assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
except:
logger.info(f"{torch.distributed.get_rank()=} does not have equal")
import fbvscode
fbvscode.set_trace()
dcp.load(state_dict, checkpoint_id=checkpoint_id)
after_load = state_dict['tok_embeddings.weight']._local_tensor.clone()
after_load_full = state_dict['tok_embeddings.weight'].full_tensor().clone()
try:
assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
assert torch.equal(before_load, after_load)
assert torch.equal(before_load_full, after_load_full)
# dcp.load(state_dict, checkpoint_id=checkpoint_id)
except:
logger.info(f"{torch.distributed.get_rank()=} does not have equal")
import fbvscode
fbvscode.set_trace()

for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
full_param = param.full_tensor()
ref_full_param = ref_param.full_tensor()
local_param = param._local_tensor
ref_local_param = ref_param._local_tensor
try:
# if torch.distributed.get_rank() != 3:
assert torch.equal(local_param, ref_local_param)
# assert torch.equal(full_param, ref_full_param)
except:
import fbvscode
fbvscode.set_trace()

# TODO: Since we flatten the model states in state_dict, we need to
# manually call load_state_dict() for the model. Need to fix this.
if MODEL in self.states:
# import fbvscode
# fbvscode.set_trace()
try:
assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
except:
import fbvscode
fbvscode.set_trace()
self.states[MODEL].load_state_dict(state_dict)
try:
assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
except:
import fbvscode
fbvscode.set_trace()

@torch.no_grad()
def save(self, curr_step: int, last_step: bool = False) -> None:
@@ -481,13 +533,27 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
)
GarbageCollection.collect("GC collection invoked by checkpointer.")
else:
try:
assert states['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.states[MODEL].model[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr()
except:
import fbvscode
fbvscode.set_trace()
self.dcp_save(
states,
checkpoint_id=checkpoint_id,
async_mode=AsyncMode.DISABLED,
enable_garbage_collection=True,
)
try:
assert torch.equal(self.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
assert torch.equal(self.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
except:
import fbvscode
fbvscode.set_trace()
self._purge_stale_checkpoints()
# import fbvscode
# fbvscode.set_trace()
self.load(step=-1)  # debug: immediately reload the checkpoint that was just saved

logger.info(
"Finished saving the checkpoint (or staging if async is enabled)"
@@ -522,6 +588,7 @@ def load(self, step: int = -1) -> bool:

model_only = False
from_hf = False
# torch.distributed.breakpoint()
if not os.path.exists(self.folder):
model_only = self.initial_load_model_only
from_hf = self.initial_load_in_hf
@@ -576,11 +643,23 @@
logger.info(f"Loading the checkpoint from {checkpoint_id}.")
begin = time.monotonic()
states = self._states_to_load(model_only)
try:
assert torch.equal(states['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
except:
import fbvscode
fbvscode.set_trace()
before_load = states['tok_embeddings.weight']._local_tensor.clone()  # clone, otherwise the comparison below is vacuous for an in-place load
self.dcp_load(
states,
checkpoint_id=checkpoint_id,
from_hf=from_hf,
)
after_load = states['tok_embeddings.weight']._local_tensor
try:
assert torch.equal(before_load, after_load)
except:
import fbvscode
fbvscode.set_trace()
GarbageCollection.collect("GC collection for checkpoint loading.")
logger.info(
f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -698,9 +777,16 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
states_to_load = {
k: v for k, v in self.states.items() if k not in self.exclude_from_loading
}

states_to_load = self._flattened_model_states_sd(states_to_load)

try:
assert torch.equal(self.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
assert torch.equal(self.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
assert torch.equal(states_to_load['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
except:
import fbvscode
fbvscode.set_trace()

if self.ft_manager:
states_to_load.pop(DATALOADER)

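Most of the asserts added above probe a single invariant: the cached / flattened model state dict must alias the storage of the live parameters, so that dcp.load, which restores values in place into the state_dict it is given, is also visible through the module. A minimal standalone sketch of that invariant (not part of the PR), using a plain nn.Module as a stand-in for the DTensor local tensors:

import torch
import torch.nn as nn

emb = nn.Embedding(8, 4)
cached = emb.state_dict()  # entries share storage with the live parameters

assert (
    cached["weight"].untyped_storage().data_ptr()
    == emb.weight.untyped_storage().data_ptr()
)

# An in-place write through the cached entry is visible to the module,
# which is what an in-place restore into the state dict relies on.
with torch.no_grad():
    cached["weight"].copy_(torch.ones(8, 4))
assert torch.equal(emb.weight, torch.ones(8, 4))

# Rebinding the dict entry to a fresh tensor silently breaks the alias;
# this is the failure mode the data_ptr() asserts try to catch.
cached["weight"] = torch.zeros(8, 4)
assert not torch.equal(emb.weight, cached["weight"])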
1 change: 1 addition & 0 deletions torchtitan/components/ft/manager.py
@@ -119,6 +119,7 @@ def maybe_semi_sync_training(
"""
If TorchFT is enabled and the config is set, use semi_sync_method
"""
return nullcontext()  # debug: unconditionally skip semi-sync training
semi_sync_method = ft_config.semi_sync_method
if ft_config.enable and semi_sync_method is not None:
from torchft import local_sgd
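The early return above hands every caller a no-op context manager, so call sites that presumably use maybe_semi_sync_training(...) under a with statement keep working while semi-sync training is taken out of the picture for this bisect. A tiny illustration of the contextlib.nullcontext semantics this relies on:

from contextlib import nullcontext

# nullcontext() does nothing on enter/exit; entering it yields None
# (or whatever value is passed to it).
with nullcontext() as ctx:
    assert ctx is None
with nullcontext("payload") as value:
    assert value == "payload"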
4 changes: 2 additions & 2 deletions torchtitan/models/attention.py
@@ -205,8 +205,8 @@ def _init_backend(cls) -> None:

# Add CuDNN on B200 w/ highest priority
cls.backends = [
SDPBackend.FLASH_ATTENTION,
SDPBackend.EFFICIENT_ATTENTION,
# SDPBackend.FLASH_ATTENTION,
# SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.MATH,
]
if has_cuda_capability(10, 0):
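Restricting the backend list to SDPBackend.MATH presumably trades speed for the unfused reference implementation while chasing the numerical mismatch, since the fused flash/efficient kernels are not guaranteed to be bitwise-reproducible across runs. A standalone sketch of pinning SDPA to the math backend (assuming PyTorch >= 2.3, where torch.nn.attention.sdpa_kernel is available):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)

# MATH is the unfused reference path; slower, but it removes fused-kernel
# variability when trying to get bit-exact results between runs.
with sdpa_kernel([SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 128, 64])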
3 changes: 2 additions & 1 deletion torchtitan/models/llama3/__init__.py
@@ -29,7 +29,8 @@

llama3_configs = {
"debugmodel": TransformerModelArgs(
dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000
# dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000
dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000
),
"debugmodel_flex_attn": TransformerModelArgs(
dim=256,
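The switch from vocab_size=2000 to 2017 (a prime) presumably makes the embedding's vocab dimension non-divisible by the 4-GPU world size in repro.sh, so the sharded embedding gets uneven shards, which is a plausible trigger for the checkpoint load mismatch being chased. A quick check of the resulting shard sizes, assuming torch.chunk-style splitting as used by DTensor's Shard placement:

import torch

for vocab in (2000, 2017):
    sizes = [t.shape[0] for t in torch.arange(vocab).chunk(4)]
    print(vocab, sizes)

# 2000 [500, 500, 500, 500]   (even shards)
# 2017 [505, 505, 505, 502]   (uneven: the last rank gets a smaller shard)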
6 changes: 3 additions & 3 deletions torchtitan/models/llama3/train_configs/debug_model.toml
@@ -42,7 +42,7 @@ min_lr_factor = 0.0
local_batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
steps = 30
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
@@ -55,9 +55,9 @@ pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable = false
enable = true
folder = "checkpoint"
interval = 10
interval = 15
last_save_model_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]