diff --git a/repro.sh b/repro.sh
new file mode 100644
index 0000000000..46d13d0c8d
--- /dev/null
+++ b/repro.sh
@@ -0,0 +1,2 @@
+NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh
+rm -rf outputs/checkpoint/step-30 && NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh
diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py
index fcec601850..9ad9b2a829 100644
--- a/torchtitan/components/checkpoint.py
+++ b/torchtitan/components/checkpoint.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from copy import deepcopy
 import enum
 import functools
 import os
@@ -59,6 +60,7 @@ class ModelWrapper(Stateful):
     def __init__(self, model: nn.Module | list[nn.Module]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
         self.cache_state_dict = self._get_state_dict()
+        assert self.cache_state_dict['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.model[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr()
 
     def _get_state_dict(self) -> dict[str, Any]:
         state_dict = {
@@ -231,6 +233,11 @@ def load_state_dict(state_dict):
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        try:
+            assert self.states['model'].state_dict()['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.states[MODEL].model[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr()
+        except:
+            import fbvscode
+            fbvscode.set_trace()
 
         self.ft_states = {DATALOADER: dataloader}
 
         self.staging = False
@@ -397,6 +404,7 @@ def dcp_load(
         state_dict: dict[str, Any],
         checkpoint_id: str,
         from_hf: bool,
+        step: int = -1,
     ) -> None:
         """Load the checkpoint with dcp.
         Args:
@@ -420,12 +428,56 @@ def dcp_load(
             state_dict = self.sd_adapter.from_hf(hf_state_dict)
             self.states[MODEL].load_state_dict(state_dict)
         else:
+            before_load = state_dict['tok_embeddings.weight']._local_tensor.clone()
+            before_load_full = state_dict['tok_embeddings.weight'].full_tensor().clone()
+            try:
+                assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+            except:
+                logger.info(f"{torch.distributed.get_rank()=} does not have equal")
+                import fbvscode
+                fbvscode.set_trace()
             dcp.load(state_dict, checkpoint_id=checkpoint_id)
+            after_load = state_dict['tok_embeddings.weight']._local_tensor.clone()
+            after_load_full = state_dict['tok_embeddings.weight'].full_tensor().clone()
+            try:
+                assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+                assert torch.equal(before_load, after_load)
+                assert torch.equal(before_load_full, after_load_full)
+                # dcp.load(state_dict, checkpoint_id=checkpoint_id)
+            except:
+                logger.info(f"{torch.distributed.get_rank()=} does not have equal")
+                import fbvscode
+                fbvscode.set_trace()
+
+            for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
+                full_param = param.full_tensor()
+                ref_full_param = ref_param.full_tensor()
+                local_param = param._local_tensor
+                ref_local_param = ref_param._local_tensor
+                try:
+                    # if torch.distributed.get_rank() != 3:
+                    assert torch.equal(local_param, ref_local_param)
+                    # assert torch.equal(full_param, ref_full_param)
+                except:
+                    import fbvscode
+                    fbvscode.set_trace()
 
             # TODO: Since we flatten the model states in state_dict, we need to
             # manually call load_state_dict() for the model.
             if MODEL in self.states:
+                # import fbvscode
+                # fbvscode.set_trace()
+                try:
+                    assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+                except:
+                    import fbvscode
+                    fbvscode.set_trace()
                 self.states[MODEL].load_state_dict(state_dict)
+                try:
+                    assert torch.equal(state_dict['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+                except:
+                    import fbvscode
+                    fbvscode.set_trace()
 
     @torch.no_grad()
     def save(self, curr_step: int, last_step: bool = False) -> None:
@@ -481,13 +533,27 @@ def save(self, curr_step: int, last_step: bool = False) -> None:
             )
             GarbageCollection.collect("GC collection invoked by checkpointer.")
         else:
+            try:
+                assert states['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.states[MODEL].model[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr()
+            except:
+                import fbvscode
+                fbvscode.set_trace()
             self.dcp_save(
                 states,
                 checkpoint_id=checkpoint_id,
                 async_mode=AsyncMode.DISABLED,
                 enable_garbage_collection=True,
             )
+            try:
+                assert torch.equal(self.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+                assert torch.equal(self.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+            except:
+                import fbvscode
+                fbvscode.set_trace()
         self._purge_stale_checkpoints()
+        # import fbvscode
+        # fbvscode.set_trace()
+        self.load(step=-1)
 
         logger.info(
             "Finished saving the checkpoint (or staging if async is enabled)"
@@ -522,6 +588,7 @@ def load(self, step: int = -1) -> bool:
         model_only = False
         from_hf = False
 
+        # torch.distributed.breakpoint()
         if not os.path.exists(self.folder):
             model_only = self.initial_load_model_only
             from_hf = self.initial_load_in_hf
@@ -576,11 +643,23 @@ def load(self, step: int = -1) -> bool:
         logger.info(f"Loading the checkpoint from {checkpoint_id}.")
         begin = time.monotonic()
         states = self._states_to_load(model_only)
+        try:
+            assert torch.equal(states['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+        except:
+            import fbvscode
+            fbvscode.set_trace()
+        before_load = states['tok_embeddings.weight']._local_tensor
         self.dcp_load(
             states,
             checkpoint_id=checkpoint_id,
             from_hf=from_hf,
         )
+        after_load = states['tok_embeddings.weight']._local_tensor
+        try:
+            assert torch.equal(before_load, after_load)
+        except:
+            import fbvscode
+            fbvscode.set_trace()
         GarbageCollection.collect("GC collection for checkpoint loading.")
         logger.info(
             f"Finished loading the checkpoint in {time.monotonic() - begin:.2f} seconds."
@@ -698,9 +777,16 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
         states_to_load = {
             k: v for k, v in self.states.items() if k not in self.exclude_from_loading
         }
-        states_to_load = self._flattened_model_states_sd(states_to_load)
+        try:
+            assert torch.equal(self.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+            assert torch.equal(self.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+            assert torch.equal(states_to_load['tok_embeddings.weight']._local_tensor, self.states[MODEL].model[0].tok_embeddings.weight._local_tensor)
+        except:
+            import fbvscode
+            fbvscode.set_trace()
+
         if self.ft_manager:
             states_to_load.pop(DATALOADER)
diff --git a/torchtitan/components/ft/manager.py b/torchtitan/components/ft/manager.py
index 5d64d34b09..de6ab3e549 100644
--- a/torchtitan/components/ft/manager.py
+++ b/torchtitan/components/ft/manager.py
@@ -119,6 +119,7 @@ def maybe_semi_sync_training(
     """
     If TorchFT is enabled and the config is set, use semi_sync_method
     """
+    return nullcontext()
     semi_sync_method = ft_config.semi_sync_method
     if ft_config.enable and semi_sync_method is not None:
         from torchft import local_sgd
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index f66361a6d2..f141b39729 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -205,8 +205,8 @@ def _init_backend(cls) -> None:
         # Add CuDNN on B200 w/ highest priority
         cls.backends = [
-            SDPBackend.FLASH_ATTENTION,
-            SDPBackend.EFFICIENT_ATTENTION,
+            # SDPBackend.FLASH_ATTENTION,
+            # SDPBackend.EFFICIENT_ATTENTION,
             SDPBackend.MATH,
         ]
 
         if has_cuda_capability(10, 0):
diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index a34b4463f8..edd2fa3918 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -29,7 +29,8 @@
 llama3_configs = {
     "debugmodel": TransformerModelArgs(
-        dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000
+        # dim=256, n_layers=6, n_heads=16, vocab_size=2000, rope_theta=500000
+        dim=256, n_layers=6, n_heads=16, vocab_size=2017, rope_theta=500000
     ),
     "debugmodel_flex_attn": TransformerModelArgs(
         dim=256,
diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml
index ecabf6e5db..c57c739050 100644
--- a/torchtitan/models/llama3/train_configs/debug_model.toml
+++ b/torchtitan/models/llama3/train_configs/debug_model.toml
@@ -42,7 +42,7 @@ min_lr_factor = 0.0
 local_batch_size = 8
 seq_len = 2048
 max_norm = 1.0 # grad norm clipping
-steps = 10
+steps = 30
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
@@ -55,9 +55,9 @@ pipeline_parallel_degree = 1
 context_parallel_degree = 1
 
 [checkpoint]
-enable = false
+enable = true
 folder = "checkpoint"
-interval = 10
+interval = 15
 last_save_model_only = false
 export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 758a5a6995..36e34353a6 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -11,6 +11,9 @@
 from typing import Any, Generator, Iterable, Optional
 
 import torch
+torch.backends.cuda.enable_flash_sdp(False)
+torch.backends.cuda.enable_mem_efficient_sdp(False)
+torch.backends.cuda.enable_math_sdp(True)
 from torch.distributed.elastic.multiprocessing.errors import record
 
 import torchtitan.protocols.train_spec as train_spec_module
@@ -154,8 +157,11 @@ def __init__(self, job_config: JobConfig):
         logger.info(
             f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
-        with torch.device("meta"):
+        with torch.device("cuda"):
+            # import fbvscode
+            # fbvscode.set_trace()
             model = self.train_spec.model_cls(model_args)
+            # model = torch.nn.Linear(1024, 1024, device="cuda")
 
         # Build the collection of model converters. No-op if `model.converters` empty
         model_converters = build_model_converters(job_config, parallel_dims)
@@ -257,15 +263,19 @@ def __init__(self, job_config: JobConfig):
             ensure_pp_loss_visible(parallel_dims, job_config, color)
         else:
             # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
-            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)
+            import copy
+            model.init_weights()
+            ref_model = copy.deepcopy(model)
+            ref_model = self.train_spec.parallelize_fn(ref_model, parallel_dims, job_config)
+            ref_model.train()
+            self.ref_model_parts = [ref_model]
-            model.to_empty(device=init_device)
-            with torch.no_grad():
-                model.init_weights(buffer_device=buffer_device)
+            model = self.train_spec.parallelize_fn(model, parallel_dims, job_config)
             model.train()
-
             self.model_parts = [model]
+
+
         self.ft_manager.maybe_set_all_reduce_hook(self.model_parts)
 
         # initialize device memory monitor and get peak flops for MFU calculation
@@ -294,6 +304,19 @@ def __init__(self, job_config: JobConfig):
                 self.model_parts
             )
         )
+
+        self.ref_optimizers = self.train_spec.build_optimizers_fn(
+            self.ref_model_parts, job_config.optimizer, parallel_dims, self.ft_manager
+        )
+        self.ref_lr_schedulers = self.train_spec.build_lr_schedulers_fn(
+            self.ref_optimizers, job_config.lr_scheduler, job_config.training.steps
+        )
+        self.ref_optimizers.register_step_post_hook(
+            lambda *args, **kwargs: model_converters.post_optimizer_hook(
+                self.ref_model_parts
+            )
+        )
+
         self.metrics_processor.optimizers = self.optimizers
         self.metrics_processor.model_parts = self.model_parts
@@ -320,6 +343,24 @@ def __init__(self, job_config: JobConfig):
             ft_manager=self.ft_manager,
         )
 
+        self.ref_checkpointer = CheckpointManager(
+            dataloader=self.dataloader,
+            model_parts=self.ref_model_parts,
+            optimizers=self.ref_optimizers,
+            lr_schedulers=self.ref_lr_schedulers,
+            states={"train_state": self},
+            checkpoint_config=job_config.checkpoint,
+            sd_adapter=(
+                self.train_spec.state_dict_adapter(
+                    model_args, job_config.model.hf_assets_path
+                )
+                if self.train_spec.state_dict_adapter
+                else None
+            ),
+            base_folder=job_config.job.dump_folder,
+            ft_manager=self.ft_manager,
+        )
+
         loss_parallel_enabled = (
             parallel_dims.tp_enabled and not parallelism_config.disable_loss_parallel
         )
@@ -460,16 +501,32 @@ def forward_backward_step(
             with self.maybe_enable_amp:
                 pred = model_parts[0](inputs)
                 loss = self.loss_fn(pred, labels)
+
+            import copy
+            ref_inputs = copy.deepcopy(inputs)
+            ref_pred = self.ref_model_parts[0](ref_inputs)
+            ref_loss = self.loss_fn(ref_pred, labels)
+
+            try:
+                assert torch.equal(pred, ref_pred)
+                assert torch.equal(loss, ref_loss)
+            except:
+                import fbvscode
+                fbvscode.set_trace()
+
             # need to free to before bwd to avoid peaking memory
             del pred
+            del ref_pred
             loss.backward()
+            ref_loss.backward()
 
-        return loss
+        return loss, ref_loss
 
     def train_step(
         self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
     ):
         self.optimizers.zero_grad()
+        self.ref_optimizers.zero_grad()
 
         # Save the current step learning rate for logging
         lr = self.lr_schedulers.schedulers[0].get_last_lr()[0]
@@ -478,25 +535,90 @@ def train_step(
         parallel_dims = self.parallel_dims
 
         accumulated_losses = []
+        ref_accumulated_losses = []
         # If data runs out during gradient accumulation, that
         # entire step will not be executed.
         for microbatch in range(self.gradient_accumulation_steps):
+            # import fbvscode
+            # fbvscode.set_trace()
+            for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
+                full_param = param.full_tensor()
+                ref_full_param = ref_param.full_tensor()
+                try:
+                    assert torch.equal(full_param, ref_full_param)
+                except:
+                    import fbvscode
+                    fbvscode.set_trace()
+
             input_dict, labels = next(data_iterator)
-            loss = self.forward_backward_step(input_dict, labels)
+            try:
+                assert self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.model_parts[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr()
+            except:
+                import fbvscode
+                fbvscode.set_trace()
+            loss, ref_loss = self.forward_backward_step(input_dict, labels)
+            try:
+                assert self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor.untyped_storage().data_ptr() == self.model_parts[0].tok_embeddings.weight._local_tensor.untyped_storage().data_ptr()
+            except:
+                import fbvscode
+                fbvscode.set_trace()
             accumulated_losses.append(loss.detach())
+            ref_accumulated_losses.append(ref_loss.detach())
 
+        for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
+            full_param = param.full_tensor()
+            ref_full_param = ref_param.full_tensor()
+            full_param_grad = param.grad.full_tensor()
+            ref_full_param_grad = ref_param.grad.full_tensor()
+            try:
+
+                assert torch.equal(full_param, ref_full_param)
+            except:
+                import fbvscode
+                fbvscode.set_trace()
+            try:
+                assert torch.equal(full_param_grad, ref_full_param_grad)
+            except:
+                import fbvscode
+                fbvscode.set_trace()
+
         grad_norm = dist_utils.clip_grad_norm_(
             [p for m in self.model_parts for p in m.parameters()],
             self.job_config.training.max_norm,
-            foreach=True,
+            foreach=False,
             pp_mesh=(
                 parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
             ),
             ep_enabled=parallel_dims.ep_enabled,
         )
+        ref_grad_norm = dist_utils.clip_grad_norm_(
+            [p for m in self.ref_model_parts for p in m.parameters()],
+            self.job_config.training.max_norm,
+            foreach=False,
+            pp_mesh=(
+                parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
+            ),
+            ep_enabled=parallel_dims.ep_enabled,
+        )
+        try:
+            assert torch.equal(grad_norm, ref_grad_norm)
+        except:
+            import fbvscode
+            fbvscode.set_trace()
         self.checkpointer.maybe_wait_for_staging()
+        self.ref_checkpointer.maybe_wait_for_staging()
         self.optimizers.step()
         self.lr_schedulers.step()
+        self.ref_optimizers.step()
+        self.ref_lr_schedulers.step()
+
+        # try:
+        #     assert torch.equal(self.checkpointer.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+        #     assert torch.equal(self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+        # except:
+        #     import fbvscode
+        #     fbvscode.set_trace()
 
         # Reduce the data collected over gradient accumulation steps.
         loss = torch.sum(torch.stack(accumulated_losses))
@@ -539,7 +661,18 @@ def train(self):
         job_config = self.job_config
 
-        self.checkpointer.load(step=job_config.checkpoint.load_step)
+        try:
+            assert torch.equal(self.checkpointer.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+            assert torch.equal(self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+        except:
+            import fbvscode
+            fbvscode.set_trace()
+
+        # self.checkpointer.load(step=job_config.checkpoint.load_step)
+        # self.ref_checkpointer.load(step=job_config.checkpoint.load_step)
+
+        # import fbvscode
+        # fbvscode.set_trace()
 
         logger.info(f"Training starts at step {self.step + 1}")
         leaf_folder = (
@@ -582,14 +715,78 @@ def train(self):
             self.step += 1
             self.gc_handler.run(self.step)
             try:
+                # try:
+                #     assert torch.equal(self.checkpointer.states['model']._get_state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+                #     assert torch.equal(self.checkpointer.states['model'].state_dict()['tok_embeddings.weight']._local_tensor, self.checkpointer.states['model'].model[0].tok_embeddings.weight._local_tensor)
+                # except:
+                #     import fbvscode
+                #     fbvscode.set_trace()
                 self.train_step(data_iterator)
             except DataloaderStopIteration:
                 logger.warning("Ran out of data; last step was canceled.")
                 break
 
-            self.checkpointer.save(
-                self.step, last_step=(self.step == job_config.training.steps)
-            )
+            self.checkpointer.model_parts = self.model_parts
+            self.checkpointer.ref_model_parts = self.ref_model_parts
+
+
+            # def reset_model_parameters(model):
+            #     from torch.distributed.fsdp import FSDPModule
+            #     for fsdp_model in model.modules():
+            #         if not isinstance(fsdp_model, FSDPModule):
+            #             continue
+            #         # import fbvscode
+            #         # fbvscode.set_trace()
+            #         fsdp_model.reshard()
+            #         state = fsdp_model._get_fsdp_state()
+            #         if not (fsdp_param_group := state._fsdp_param_group):
+            #             continue
+            #         with torch.no_grad():
+            #             for fsdp_param in fsdp_param_group.fsdp_params:
+            #                 fsdp_param.reset_sharded_param()
+            # reset_model_parameters(self.model_parts[0])
+
+            if self.checkpointer._should_save(self.step, last_step=(self.step == job_config.training.steps)):
+                for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
+                    full_param = param.full_tensor()
+                    ref_full_param = ref_param.full_tensor()
+                    local_param = param._local_tensor
+                    ref_local_param = ref_param._local_tensor
+                    try:
+                        # if torch.distributed.get_rank() != 3:
+                        assert torch.equal(local_param, ref_local_param)
+                        # assert torch.equal(full_param, ref_full_param)
+                    except:
+                        import fbvscode
+                        fbvscode.set_trace()
+                self.checkpointer.save(
+                    self.step, last_step=(self.step == job_config.training.steps)
+                )
+                for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
+                    full_param = param.full_tensor()
+                    ref_full_param = ref_param.full_tensor()
+                    local_param = param._local_tensor
+                    ref_local_param = ref_param._local_tensor
+                    try:
+                        # if torch.distributed.get_rank() != 3:
+                        assert torch.equal(local_param, ref_local_param)
+                        # assert torch.equal(full_param, ref_full_param)
+                    except:
+                        import fbvscode
+                        fbvscode.set_trace()
+                self.checkpointer.load(step=self.step)
+
+                for (param_name, param), (_, ref_param) in zip(self.model_parts[0].named_parameters(), self.ref_model_parts[0].named_parameters()):
+                    full_param = param.full_tensor()
+                    ref_full_param = ref_param.full_tensor()
+                    local_param = param._local_tensor
+                    ref_local_param = ref_param._local_tensor
+                    try:
+                        assert torch.equal(local_param, ref_local_param)
+                        assert torch.equal(full_param, ref_full_param)
+                    except:
+                        import fbvscode
+                        fbvscode.set_trace()
 
             # Run validation if validator is available
             if (