From 76a73c4bff4d9594bd04a18b1e5b76af9b698bbe Mon Sep 17 00:00:00 2001
From: Wei Feng
Date: Fri, 22 Aug 2025 15:19:40 -0700
Subject: [PATCH 1/8] debug fsdp uneven sharding load checkpoint

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/models/llama3/__init__.py                    | 2 +-
 torchtitan/models/llama3/train_configs/debug_model.toml | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index a34b4463f8..558e13108f 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -29,7 +29,7 @@
 llama3_configs = {
     "debugmodel": TransformerModelArgs(
-        dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000
+        dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000
     ),
     "debugmodel_flex_attn": TransformerModelArgs(
         dim=256,
diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml
index ecabf6e5db..c57c739050 100644
--- a/torchtitan/models/llama3/train_configs/debug_model.toml
+++ b/torchtitan/models/llama3/train_configs/debug_model.toml
@@ -42,7 +42,7 @@ min_lr_factor = 0.0
 local_batch_size = 8
 seq_len = 2048
 max_norm = 1.0 # grad norm clipping
-steps = 10
+steps = 30
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
@@ -55,9 +55,9 @@ pipeline_parallel_degree = 1
 context_parallel_degree = 1
 
 [checkpoint]
-enable = false
+enable = true
 folder = "checkpoint"
-interval = 10
+interval = 15
 last_save_model_only = false
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

From 5b665f59e73c9f3f19ae40ff8dc77f130f339412 Mon Sep 17 00:00:00 2001
From: Wei Feng
Date: Fri, 22 Aug 2025 15:22:44 -0700
Subject: [PATCH 2/8] add repro command

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 repro.sh | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 repro.sh

diff --git a/repro.sh b/repro.sh
new file mode 100644
index 0000000000..46d13d0c8d
--- /dev/null
+++ b/repro.sh
@@ -0,0 +1,2 @@
+NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh
+rm -rf outputs/checkpoint/step-30 && NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

From b71efeb77ecbe9263efd078fa1cc43e6fdad10a5 Mon Sep 17 00:00:00 2001
From: Wei Feng
Date: Fri, 22 Aug 2025 17:19:26 -0700
Subject: [PATCH 3/8] same model ref_model

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/components/checkpoint.py |   3 +
 torchtitan/models/attention.py      |   4 +-
 torchtitan/train.py                 | 126 ++++++++++++++++++++++++++--
 3 files changed, 122 insertions(+), 11 deletions(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index fcec601850..98805fca92 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -425,6 +425,8 @@ def dcp_load(
         # TODO: Since we flatten the model states in state_dict, we need to
         # manually call load_state_dict() for the model. Need to fix this.
if MODEL in self.states: + # import fbvscode + # fbvscode.set_trace() self.states[MODEL].load_state_dict(state_dict) @torch.no_grad() @@ -522,6 +524,7 @@ def load(self, step: int = -1) -> bool: model_only = False from_hf = False + # torch.distributed.breakpoint() if not os.path.exists(self.folder): model_only = self.initial_load_model_only from_hf = self.initial_load_in_hf diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index f66361a6d2..f141b39729 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -205,8 +205,8 @@ def _init_backend(cls) -> None: # Add CuDNN on B200 w/ highest priority cls.backends = [ - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, + # SDPBackend.FLASH_ATTENTION, + # SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH, ] if has_cuda_capability(10, 0): diff --git a/torchtitan/train.py b/torchtitan/train.py index 758a5a6995..a8016f2569 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,6 +11,9 @@ from typing import Any, Generator, Iterable, Optional import torch +torch.backends.cuda.enable_flash_sdp(False) +torch.backends.cuda.enable_mem_efficient_sdp(False) +torch.backends.cuda.enable_math_sdp(True) from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module @@ -154,8 +157,11 @@ def __init__(self, job_config: JobConfig): logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) - with torch.device("meta"): + with torch.device("cuda"): + # import fbvscode + # fbvscode.set_trace() model = self.train_spec.model_cls(model_args) + # model = torch.nn.Linear(1024, 1024, device="cuda") # Build the collection of model converters. No-op if `model.converters` empty model_converters = build_model_converters(job_config, parallel_dims) @@ -257,15 +263,19 @@ def __init__(self, job_config: JobConfig): ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel - model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) + import copy + model.init_weights() + ref_model = copy.deepcopy(model) + ref_model = self.train_spec.parallelize_fn(ref_model, parallel_dims, job_config) + ref_model.train() + self.ref_model_parts = [ref_model] - model.to_empty(device=init_device) - with torch.no_grad(): - model.init_weights(buffer_device=buffer_device) + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.train() - self.model_parts = [model] + + self.ft_manager.maybe_set_all_reduce_hook(self.model_parts) # initialize device memory monitor and get peak flops for MFU calculation @@ -294,6 +304,19 @@ def __init__(self, job_config: JobConfig): self.model_parts ) ) + + self.ref_optimizers = self.train_spec.build_optimizers_fn( + self.ref_model_parts, job_config.optimizer, parallel_dims, self.ft_manager + ) + self.ref_lr_schedulers = self.train_spec.build_lr_schedulers_fn( + self.ref_optimizers, job_config.lr_scheduler, job_config.training.steps + ) + self.ref_optimizers.register_step_post_hook( + lambda *args, **kwargs: model_converters.post_optimizer_hook( + self.ref_model_parts + ) + ) + self.metrics_processor.optimizers = self.optimizers self.metrics_processor.model_parts = self.model_parts @@ -320,6 +343,24 @@ def __init__(self, job_config: JobConfig): ft_manager=self.ft_manager, ) + self.ref_checkpointer = CheckpointManager( + dataloader=self.dataloader, + 
model_parts=self.ref_model_parts, + optimizers=self.ref_optimizers, + lr_schedulers=self.ref_lr_schedulers, + states={"train_state": self}, + checkpoint_config=job_config.checkpoint, + sd_adapter=( + self.train_spec.state_dict_adapter( + model_args, job_config.model.hf_assets_path + ) + if self.train_spec.state_dict_adapter + else None + ), + base_folder=job_config.job.dump_folder, + ft_manager=self.ft_manager, + ) + loss_parallel_enabled = ( parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel ) @@ -460,16 +501,32 @@ def forward_backward_step( with self.maybe_enable_amp: pred = model_parts[0](inputs) loss = self.loss_fn(pred, labels) + + import copy + ref_inputs = copy.deepcopy(inputs) + ref_pred = self.ref_model_parts[0](ref_inputs) + ref_loss = self.loss_fn(ref_pred, labels) + + try: + assert torch.equal(pred, ref_pred) + assert torch.equal(loss, ref_loss) + except: + import fbvscode + fbvscode.set_trace() + # need to free to before bwd to avoid peaking memory del pred + del ref_pred loss.backward() + ref_loss.backward() - return loss + return loss, ref_loss def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): self.optimizers.zero_grad() + self.ref_optimizers.zero_grad() # Save the current step learning rate for logging lr = self.lr_schedulers.schedulers[0].get_last_lr()[0] @@ -478,25 +535,72 @@ def train_step( parallel_dims = self.parallel_dims accumulated_losses = [] + ref_accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. for microbatch in range(self.gradient_accumulation_steps): + # import fbvscode + # fbvscode.set_trace() + for param, ref_param in zip(self.model_parts[0].parameters(), self.ref_model_parts[0].parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + try: + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + input_dict, labels = next(data_iterator) - loss = self.forward_backward_step(input_dict, labels) + loss, ref_loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) + ref_accumulated_losses.append(ref_loss.detach()) + + for param, ref_param in zip(self.model_parts[0].parameters(), self.ref_model_parts[0].parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + full_param_grad = param.grad.full_tensor() + ref_full_param_grad = ref_param.grad.full_tensor() + try: + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + try: + assert torch.equal(full_param_grad, ref_full_param_grad) + except: + import fbvscode + fbvscode.set_trace() + grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, - foreach=True, + foreach=False, pp_mesh=( parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None ), ep_enabled=parallel_dims.ep_enabled, ) + ref_grad_norm = dist_utils.clip_grad_norm_( + [p for m in self.ref_model_parts for p in m.parameters()], + self.job_config.training.max_norm, + foreach=False, + pp_mesh=( + parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None + ), + ep_enabled=parallel_dims.ep_enabled, + ) + try: + assert torch.equal(grad_norm, ref_grad_norm) + except: + import fbvscode + fbvscode.set_trace() self.checkpointer.maybe_wait_for_staging() + self.ref_checkpointer.maybe_wait_for_staging() self.optimizers.step() 
self.lr_schedulers.step() + self.ref_optimizers.step() + self.ref_lr_schedulers.step() # Reduce the data collected over gradient accumulation steps. loss = torch.sum(torch.stack(accumulated_losses)) @@ -540,6 +644,10 @@ def train(self): job_config = self.job_config self.checkpointer.load(step=job_config.checkpoint.load_step) + self.ref_checkpointer.load(step=job_config.checkpoint.load_step) + + # import fbvscode + # fbvscode.set_trace() logger.info(f"Training starts at step {self.step + 1}") leaf_folder = ( From 728d3e23d0f2758dd360a45ea10eb95d6dc69cb9 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Mon, 25 Aug 2025 15:02:25 -0700 Subject: [PATCH 4/8] repro numeric diffrences before/after dcp load Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/models/llama3/__init__.py | 3 ++- torchtitan/train.py | 26 ++++++++++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 558e13108f..3cf92a0224 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -29,7 +29,8 @@ llama3_configs = { "debugmodel": TransformerModelArgs( - dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000 + dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000 + # dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000 ), "debugmodel_flex_attn": TransformerModelArgs( dim=256, diff --git a/torchtitan/train.py b/torchtitan/train.py index a8016f2569..90eb338927 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -541,7 +541,7 @@ def train_step( for microbatch in range(self.gradient_accumulation_steps): # import fbvscode # fbvscode.set_trace() - for param, ref_param in zip(self.model_parts[0].parameters(), self.ref_model_parts[0].parameters()): + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): full_param = param.full_tensor() ref_full_param = ref_param.full_tensor() try: @@ -555,12 +555,13 @@ def train_step( accumulated_losses.append(loss.detach()) ref_accumulated_losses.append(ref_loss.detach()) - for param, ref_param in zip(self.model_parts[0].parameters(), self.ref_model_parts[0].parameters()): + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): full_param = param.full_tensor() ref_full_param = ref_param.full_tensor() full_param_grad = param.grad.full_tensor() ref_full_param_grad = ref_param.grad.full_tensor() try: + assert torch.equal(full_param, ref_full_param) except: import fbvscode @@ -643,8 +644,8 @@ def train_step( def train(self): job_config = self.job_config - self.checkpointer.load(step=job_config.checkpoint.load_step) - self.ref_checkpointer.load(step=job_config.checkpoint.load_step) + # self.checkpointer.load(step=job_config.checkpoint.load_step) + # self.ref_checkpointer.load(step=job_config.checkpoint.load_step) # import fbvscode # fbvscode.set_trace() @@ -698,6 +699,23 @@ def train(self): self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) ) + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + try: + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + 
self.checkpointer.load(step=job_config.checkpoint.load_step) + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + try: + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() # Run validation if validator is available if ( From 5429c430d05b52c23b619e2dd886f5b789bdc53b Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Tue, 26 Aug 2025 15:39:34 -0700 Subject: [PATCH 5/8] repro at checkpoint Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchtitan/components/checkpoint.py | 22 +++++++++++++++++++++- torchtitan/train.py | 19 +++++++++++-------- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 98805fca92..f7f8367803 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -397,6 +397,7 @@ def dcp_load( state_dict: dict[str, Any], checkpoint_id: str, from_hf: bool, + step: int = -1, ) -> None: """Load the checkpoint with dcp. Args: @@ -420,7 +421,15 @@ def dcp_load( state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) else: + before_load = state_dict['tok_embeddings.weight']._local_tensor dcp.load(state_dict, checkpoint_id=checkpoint_id) + after_load = state_dict['tok_embeddings.weight']._local_tensor + try: + assert torch.equal(before_load, after_load) + # dcp.load(state_dict, checkpoint_id=checkpoint_id) + except: + import fbvscode + fbvscode.set_trace() # TODO: Since we flatten the model states in state_dict, we need to # manually call load_state_dict() for the model. Need to fix this. @@ -490,6 +499,9 @@ def save(self, curr_step: int, last_step: bool = False) -> None: enable_garbage_collection=True, ) self._purge_stale_checkpoints() + # import fbvscode + # fbvscode.set_trace() + self.load(step=-1) logger.info( "Finished saving the checkpoint (or staging if async is enabled)" @@ -579,11 +591,18 @@ def load(self, step: int = -1) -> bool: logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) + before_load = states['tok_embeddings.weight']._local_tensor self.dcp_load( states, checkpoint_id=checkpoint_id, from_hf=from_hf, ) + after_load = states['tok_embeddings.weight']._local_tensor + try: + assert torch.equal(before_load, after_load) + except: + import fbvscode + fbvscode.set_trace() GarbageCollection.collect("GC collection for checkpoint loading.") logger.info( f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds." 
@@ -702,7 +721,8 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
             k: v for k, v in self.states.items() if k not in self.exclude_from_loading
         }
-        states_to_load = self._flattened_model_states_sd(states_to_load)
+        # states_to_load = self._flattened_model_states_sd(states_to_load)
+        states_to_load = self._flattened_model_states_sd()
 
         if self.ft_manager:
             states_to_load.pop(DATALOADER)
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 90eb338927..841784469a 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -699,19 +699,22 @@ def train(self):
         self.checkpointer.save(
             self.step, last_step=(self.step == job_config.training.steps)
         )
-        for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
-            full_param = param.full_tensor()
-            ref_full_param = ref_param.full_tensor()
-            try:
-                assert torch.equal(full_param, ref_full_param)
-            except:
-                import fbvscode
-                fbvscode.set_trace()
+        # for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
+        #     full_param = param.full_tensor()
+        #     ref_full_param = ref_param.full_tensor()
+        #     try:
+        #         assert torch.equal(full_param, ref_full_param)
+        #     except:
+        #         import fbvscode
+        #         fbvscode.set_trace()
         self.checkpointer.load(step=job_config.checkpoint.load_step)
         for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
             full_param = param.full_tensor()
             ref_full_param = ref_param.full_tensor()
+            local_param = param._local_tensor
+            ref_local_param = ref_param._local_tensor
             try:
+                assert torch.equal(local_param, ref_local_param)
                 assert torch.equal(full_param, ref_full_param)
             except:
                 import fbvscode
                 fbvscode.set_trace()

From f12ca725515e18f857f0df020c5f74bb55fcab59 Mon Sep 17 00:00:00 2001
From: Wei Feng
Date: Thu, 11 Sep 2025 14:26:25 -0700
Subject: [PATCH 6/8] repro state dict numeric for uneven sharding

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchtitan/components/checkpoint.py  |  4 +++
 torchtitan/models/llama3/__init__.py |  4 +--
 torchtitan/train.py                  | 38 +++++++++++++++++++++-------
 3 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index f7f8367803..01409cc28b 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -422,12 +422,16 @@ def dcp_load(
             self.states[MODEL].load_state_dict(state_dict)
         else:
             before_load = state_dict['tok_embeddings.weight']._local_tensor
+            before_load_full = state_dict['tok_embeddings.weight'].full_tensor()
             dcp.load(state_dict, checkpoint_id=checkpoint_id)
             after_load = state_dict['tok_embeddings.weight']._local_tensor
+            after_load_full = state_dict['tok_embeddings.weight'].full_tensor()
             try:
                 assert torch.equal(before_load, after_load)
+                assert torch.equal(before_load_full, after_load_full)
                 # dcp.load(state_dict, checkpoint_id=checkpoint_id)
             except:
+                logger.info(f"{torch.distributed.get_rank()=} does not have equal")
                 import fbvscode
                 fbvscode.set_trace()
diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index 558e13108f..edd2fa3918 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -29,8 +29,8 @@
 llama3_configs = {
     "debugmodel": TransformerModelArgs(
-        dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000
-        # dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000
+ # dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000 + dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000 ), "debugmodel_flex_attn": TransformerModelArgs( dim=256, diff --git a/torchtitan/train.py b/torchtitan/train.py index 841784469a..a3c53210f2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -699,15 +699,35 @@ def train(self): self.checkpointer.save( self.step, last_step=(self.step == job_config.training.steps) ) - # for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): - # full_param = param.full_tensor() - # ref_full_param = ref_param.full_tensor() - # try: - # assert torch.equal(full_param, ref_full_param) - # except: - # import fbvscode - # fbvscode.set_trace() - self.checkpointer.load(step=job_config.checkpoint.load_step) + # def reset_model_parameters(model): + # from torch.distributed.fsdp import FSDPModule + # for fsdp_model in model.modules(): + # if not isinstance(fsdp_model, FSDPModule): + # continue + # # import fbvscode + # # fbvscode.set_trace() + # fsdp_model.reshard() + # state = fsdp_model._get_fsdp_state() + # if not (fsdp_param_group := state._fsdp_param_group): + # continue + # with torch.no_grad(): + # for fsdp_param in fsdp_param_group.fsdp_params: + # fsdp_param.reset_sharded_param() + # reset_model_parameters(self.model_parts[0]) + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + local_param = param._local_tensor + ref_local_param = ref_param._local_tensor + try: + assert torch.equal(local_param, ref_local_param) + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + if self.checkpointer._should_save(self.step, last_step=(self.step == job_config.training.steps)): + self.checkpointer.load(step=self.step) + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): full_param = param.full_tensor() ref_full_param = ref_param.full_tensor() From 9e29bc51937686baef7a39bc639c8728c7c1d1a3 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Thu, 11 Sep 2025 17:27:44 -0700 Subject: [PATCH 7/8] remove cached_state_dict --- torchtitan/components/checkpoint.py | 81 ++++++++++++++++++++++++++--- torchtitan/train.py | 68 +++++++++++++++--------- 2 files changed, 117 insertions(+), 32 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 01409cc28b..c2e3be5999 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -67,7 +67,10 @@ def _get_state_dict(self) -> dict[str, Any]: return state_dict def state_dict(self) -> dict[str, Any]: - return self.cache_state_dict + # import fbvscode + # fbvscode.set_trace() + # return self.cache_state_dict + return self._get_state_dict() def load_state_dict(self, state_dict: dict[str, Any]) -> None: func = functools.partial( @@ -421,12 +424,19 @@ def dcp_load( state_dict = self.sd_adapter.from_hf(hf_state_dict) self.states[MODEL].load_state_dict(state_dict) else: - before_load = state_dict['tok_embeddings.weight']._local_tensor - before_load_full = state_dict['tok_embeddings.weight'].full_tensor() + before_load = state_dict['tok_embeddings.weight']._local_tensor.clone() + before_load_full = state_dict['tok_embeddings.weight'].full_tensor().clone() + try: + 
assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) + except: + logger.info(f"{torch.distributed.get_rank()=} does not have equal") + import fbvscode + fbvscode.set_trace() dcp.load(state_dict, checkpoint_id=checkpoint_id) - after_load = state_dict['tok_embeddings.weight']._local_tensor - after_load_full = state_dict['tok_embeddings.weight'].full_tensor() + after_load = state_dict['tok_embeddings.weight']._local_tensor.clone() + after_load_full = state_dict['tok_embeddings.weight'].full_tensor().clone() try: + assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) assert torch.equal(before_load, after_load) assert torch.equal(before_load_full, after_load_full) # dcp.load(state_dict, checkpoint_id=checkpoint_id) @@ -435,12 +445,59 @@ def dcp_load( import fbvscode fbvscode.set_trace() + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + local_param = param._local_tensor + ref_local_param = ref_param._local_tensor + try: + # if torch.distributed.get_rank() != 3: + assert torch.equal(local_param, ref_local_param) + # assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + # TODO: Since we flatten the model states in state_dict, we need to # manually call load_state_dict() for the model. Need to fix this. if MODEL in self.states: # import fbvscode # fbvscode.set_trace() + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + local_param = param._local_tensor + ref_local_param = ref_param._local_tensor + try: + # if torch.distributed.get_rank() != 3: + assert torch.equal(local_param, ref_local_param) + # assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + try: + assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) + except: + import fbvscode + fbvscode.set_trace() self.states[MODEL].load_state_dict(state_dict) + try: + assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) + except: + import fbvscode + fbvscode.set_trace() + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + local_param = param._local_tensor + ref_local_param = ref_param._local_tensor + try: + # if torch.distributed.get_rank() != 3: + assert torch.equal(local_param, ref_local_param) + # assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() @torch.no_grad() def save(self, curr_step: int, last_step: bool = False) -> None: @@ -595,6 +652,11 @@ def load(self, step: int = -1) -> bool: logger.info(f"Loading the checkpoint from {checkpoint_id}.") begin = time.monotonic() states = self._states_to_load(model_only) + try: + assert torch.equal(states['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) + except: + import fbvscode + fbvscode.set_trace() before_load = 
states['tok_embeddings.weight']._local_tensor self.dcp_load( states, @@ -725,8 +787,13 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]: k: v for k, v in self.states.items() if k not in self.exclude_from_loading } - # states_to_load = self._flattened_model_states_sd(states_to_load) - states_to_load = self._flattened_model_states_sd() + states_to_load = self._flattened_model_states_sd(states_to_load) + + try: + assert torch.equal(states_to_load['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) + except: + import fbvscode + fbvscode.set_trace() if self.ft_manager: states_to_load.pop(DATALOADER) diff --git a/torchtitan/train.py b/torchtitan/train.py index a3c53210f2..67bb16548d 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -696,9 +696,10 @@ def train(self): logger.warning("Ran out of data; last step was canceled.") break - self.checkpointer.save( - self.step, last_step=(self.step == job_config.training.steps) - ) + self.checkpointer.model_parts = self.model_parts + self.checkpointer.ref_model_parts = self.ref_model_parts + + # def reset_model_parameters(model): # from torch.distributed.fsdp import FSDPModule # for fsdp_model in model.modules(): @@ -714,31 +715,48 @@ def train(self): # for fsdp_param in fsdp_param_group.fsdp_params: # fsdp_param.reset_sharded_param() # reset_model_parameters(self.model_parts[0]) - for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): - full_param = param.full_tensor() - ref_full_param = ref_param.full_tensor() - local_param = param._local_tensor - ref_local_param = ref_param._local_tensor - try: - assert torch.equal(local_param, ref_local_param) - assert torch.equal(full_param, ref_full_param) - except: - import fbvscode - fbvscode.set_trace() + if self.checkpointer._should_save(self.step, last_step=(self.step == job_config.training.steps)): + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + local_param = param._local_tensor + ref_local_param = ref_param._local_tensor + try: + # if torch.distributed.get_rank() != 3: + assert torch.equal(local_param, ref_local_param) + # assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() + self.checkpointer.save( + self.step, last_step=(self.step == job_config.training.steps) + ) + for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + local_param = param._local_tensor + ref_local_param = ref_param._local_tensor + try: + # if torch.distributed.get_rank() != 3: + assert torch.equal(local_param, ref_local_param) + # assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() self.checkpointer.load(step=self.step) - for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): - full_param = param.full_tensor() - ref_full_param = ref_param.full_tensor() - local_param = param._local_tensor - ref_local_param = ref_param._local_tensor - try: - assert torch.equal(local_param, ref_local_param) - assert torch.equal(full_param, ref_full_param) - except: - import fbvscode - fbvscode.set_trace() + for 
(param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): + full_param = param.full_tensor() + ref_full_param = ref_param.full_tensor() + local_param = param._local_tensor + ref_local_param = ref_param._local_tensor + try: + assert torch.equal(local_param, ref_local_param) + assert torch.equal(full_param, ref_full_param) + except: + import fbvscode + fbvscode.set_trace() # Run validation if validator is available if ( From d65284075f5de39160cc94a5dda8350ce246e1a3 Mon Sep 17 00:00:00 2001 From: Wei Feng Date: Fri, 12 Sep 2025 16:12:40 -0700 Subject: [PATCH 8/8] check _local_tensor storage ptr --- torchtitan/components/checkpoint.py | 50 ++++++++++++----------------- torchtitan/components/ft/manager.py | 1 + torchtitan/train.py | 30 +++++++++++++++++ 3 files changed, 52 insertions(+), 29 deletions(-) diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index c2e3be5999..9ad9b2a829 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from copy import deepcopy import enum import functools import os @@ -59,6 +60,7 @@ class ModelWrapper(Stateful): def __init__(self, model: nn.Module | list[nn.Module]) -> None: self.model = [model] if isinstance(model, nn.Module) else model self.cache_state_dict = self._get_state_dict() + assert self.cache_state_dict['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.model[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr() def _get_state_dict(self) -> dict[str, Any]: state_dict = { @@ -67,10 +69,7 @@ def _get_state_dict(self) -> dict[str, Any]: return state_dict def state_dict(self) -> dict[str, Any]: - # import fbvscode - # fbvscode.set_trace() - # return self.cache_state_dict - return self._get_state_dict() + return self.cache_state_dict def load_state_dict(self, state_dict: dict[str, Any]) -> None: func = functools.partial( @@ -234,6 +233,11 @@ def load_state_dict(state_dict): LR_SCHEDULER: lr_schedulers, } ) + try: + assert self.states['model'].state_dict()['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr(), self.states[MODEL].model[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr() + except: + import fbvscode + fbvscode.set_trace() self.ft_states = {DATALOADER: dataloader} self.staging = False @@ -463,18 +467,6 @@ def dcp_load( if MODEL in self.states: # import fbvscode # fbvscode.set_trace() - for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): - full_param = param.full_tensor() - ref_full_param = ref_param.full_tensor() - local_param = param._local_tensor - ref_local_param = ref_param._local_tensor - try: - # if torch.distributed.get_rank() != 3: - assert torch.equal(local_param, ref_local_param) - # assert torch.equal(full_param, ref_full_param) - except: - import fbvscode - fbvscode.set_trace() try: assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) except: @@ -486,18 +478,6 @@ def dcp_load( except: import fbvscode fbvscode.set_trace() - for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()): - full_param = param.full_tensor() - ref_full_param = 
ref_param.full_tensor() - local_param = param._local_tensor - ref_local_param = ref_param._local_tensor - try: - # if torch.distributed.get_rank() != 3: - assert torch.equal(local_param, ref_local_param) - # assert torch.equal(full_param, ref_full_param) - except: - import fbvscode - fbvscode.set_trace() @torch.no_grad() def save(self, curr_step: int, last_step: bool = False) -> None: @@ -553,12 +533,23 @@ def save(self, curr_step: int, last_step: bool = False) -> None: ) GarbageCollection.collect("GC collection invoked by checkpointer.") else: + try: + assert states['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.states[MODEL].model[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr() + except: + import fbvscode + fbvscode.set_trace() self.dcp_save( states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.DISABLED, enable_garbage_collection=True, ) + try: + assert torch.equal(self.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) + assert torch.equal(self.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) + except: + import fbvscode + fbvscode.set_trace() self._purge_stale_checkpoints() # import fbvscode # fbvscode.set_trace() @@ -786,10 +777,11 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]: states_to_load = { k: v for k, v in self.states.items() if k not in self.exclude_from_loading } - states_to_load = self._flattened_model_states_sd(states_to_load) try: + assert torch.equal(self.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) + assert torch.equal(self.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) assert torch.equal(states_to_load['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor) except: import fbvscode diff --git a/torchtitan/components/ft/manager.py b/torchtitan/components/ft/manager.py index 5d64d34b09..de6ab3e549 100644 --- a/torchtitan/components/ft/manager.py +++ b/torchtitan/components/ft/manager.py @@ -119,6 +119,7 @@ def maybe_semi_sync_training( """ If TorchFT is enabled and the config is set, use semi_sync_method """ + return nullcontext() semi_sync_method = ft_config.semi_sync_method if ft_config.enable and semi_sync_method is not None: from torchft import local_sgd diff --git a/torchtitan/train.py b/torchtitan/train.py index 67bb16548d..36e34353a6 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -551,7 +551,17 @@ def train_step( fbvscode.set_trace() input_dict, labels = next(data_iterator) + try: + assert self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.model_parts[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr() + except: + import fbvscode + fbvscode.set_trace() loss, ref_loss = self.forward_backward_step(input_dict, labels) + try: + assert self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.model_parts[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr() + except: + import fbvscode + fbvscode.set_trace() accumulated_losses.append(loss.detach()) ref_accumulated_losses.append(ref_loss.detach()) @@ -603,6 +613,13 @@ def 
train_step(
         self.ref_optimizers.step()
         self.ref_lr_schedulers.step()
+        # try:
+        #     assert torch.equal(self.checkpointer.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+        #     assert torch.equal(self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+        # except:
+        #     import fbvscode
+        #     fbvscode.set_trace()
+
         # Reduce the data collected over gradient accumulation steps.
         loss = torch.sum(torch.stack(accumulated_losses))
@@ -644,6 +661,13 @@ def train_step(
     def train(self):
         job_config = self.job_config
+        try:
+            assert torch.equal(self.checkpointer.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+            assert torch.equal(self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+        except:
+            import fbvscode
+            fbvscode.set_trace()
+
         # self.checkpointer.load(step=job_config.checkpoint.load_step)
         # self.ref_checkpointer.load(step=job_config.checkpoint.load_step)
@@ -691,6 +715,12 @@ def train(self):
             self.step += 1
             self.gc_handler.run(self.step)
             try:
+                # try:
+                #     assert torch.equal(self.checkpointer.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+                #     assert torch.equal(self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+                # except:
+                #     import fbvscode
+                #     fbvscode.set_trace()
                 self.train_step(data_iterator)
             except DataloaderStopIteration:
                 logger.warning("Ran out of data; last step was canceled.")
                 break
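
Note on the debug pattern repeated in patches 3-8: the same parameter/shard comparison between model and ref_model is inlined at every call site. Below is a minimal consolidated sketch of that check, assuming FSDP2 DTensor parameters that expose _local_tensor and full_tensor() exactly as used in the diffs above; the helper name compare_model_params is hypothetical, and breakpoint() stands in for the fbvscode tracing used in the patches.

    import torch
    import torch.nn as nn

    def compare_model_params(model: nn.Module, ref_model: nn.Module) -> None:
        # Walk both models in parallel; the ordering matches because ref_model is a
        # deepcopy of model taken before parallelize_fn is applied (see PATCH 3/8).
        for (name, param), (_, ref_param) in zip(
            model.named_parameters(), ref_model.named_parameters()
        ):
            # Compare the local shard first (no collective needed).
            if not torch.equal(param._local_tensor, ref_param._local_tensor):
                print(f"local shard mismatch on {name}")
                breakpoint()
            # Then compare the gathered full tensor (all-gathers across the mesh).
            if not torch.equal(param.full_tensor(), ref_param.full_tensor()):
                print(f"full tensor mismatch on {name}")
                breakpoint()

Usage mirrors the inline checks in the patches, e.g. compare_model_params(self.model_parts[0], self.ref_model_parts[0]) before and after checkpointer.save() / checkpointer.load().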