diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index d638a6bd26..7b0b0c81e9 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -7,7 +7,7 @@ import importlib import time from datetime import timedelta -from typing import Any, Iterable, Optional +from typing import Any, Iterable import torch from torch.distributed.elastic.multiprocessing.errors import record @@ -17,15 +17,16 @@ from torchtitan.components.metrics import build_metrics_processor from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.components.validate import build_validator -from torchtitan.config import ConfigManager, JobConfig +from torchtitan.config import JobConfig from torchtitan.distributed import utils as dist_utils from torchtitan.hf_datasets.text_datasets import build_text_dataloader from torchtitan.tools import utils -from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.logging import logger from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, maybe_enable_profiling, ) +from torchtitan.train import main from .engine import ForgeEngine @@ -350,19 +351,4 @@ def close(self) -> None: if __name__ == "__main__": - init_logger() - config_manager = ConfigManager() - config = config_manager.parse_args() - trainer: Optional[Trainer] = None - - try: - trainer = Trainer(config) - trainer.train() - except Exception: - if trainer: - trainer.close() - raise - else: - trainer.close() - torch.distributed.destroy_process_group() - logger.info("Process group destroyed.") + main(Trainer) diff --git a/torchtitan/experiments/torchcomms/train.py b/torchtitan/experiments/torchcomms/train.py index 29aba39f41..07716a4b36 100644 --- a/torchtitan/experiments/torchcomms/train.py +++ b/torchtitan/experiments/torchcomms/train.py @@ -4,15 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os -from typing import Optional - -import torch - -from torchtitan.config import ConfigManager from torchtitan.distributed import ParallelDims -from torchtitan.tools.logging import init_logger, logger -from torchtitan.train import Trainer +from torchtitan.train import main, Trainer from .parallel_dims import TorchCommsParallelDims @@ -32,35 +25,13 @@ def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims: world_size=world_size, ) + def close(self) -> None: + # Call finalize on all comms after training and before destroying process group. + if hasattr(self, "parallel_dims"): + for comm in self.parallel_dims.comms: + comm.finalize() + super().close() -if __name__ == "__main__": - init_logger() - config_manager = ConfigManager() - config = config_manager.parse_args() - trainer: Optional[TorchCommsTrainer] = None - - try: - trainer = TorchCommsTrainer(config) - if config.checkpoint.create_seed_checkpoint: - assert ( - int(os.environ["WORLD_SIZE"]) == 1 - ), "Must create seed checkpoint using a single device, to disable sharding." - assert ( - config.checkpoint.enable - ), "Must enable checkpointing when creating a seed checkpoint." - trainer.checkpointer.save(curr_step=0, last_step=True) - logger.info("Created seed checkpoint") - else: - trainer.train() - # Call finalize on all comms after training and before destroying process group. 
- for comm in trainer.parallel_dims.comms: - comm.finalize() - except Exception: - if trainer: - trainer.close() - raise - else: - trainer.close() - torch.distributed.destroy_process_group() - logger.info("Process group destroyed") +if __name__ == "__main__": + main(TorchCommsTrainer) diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 9bb3cd48bf..5af9959050 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -4,12 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os -from typing import Optional - import torch -from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import utils as dist_utils from torchtitan.models.flux.infra.parallelize import parallelize_encoders @@ -20,8 +17,7 @@ pack_latents, preprocess_data, ) -from torchtitan.tools.logging import init_logger, logger -from torchtitan.train import Trainer +from torchtitan.train import main, Trainer class FluxTrainer(Trainer): @@ -175,29 +171,4 @@ def forward_backward_step( if __name__ == "__main__": - init_logger() - config_manager = ConfigManager() - config = config_manager.parse_args() - trainer: Optional[FluxTrainer] = None - - try: - trainer = FluxTrainer(config) - if config.checkpoint.create_seed_checkpoint: - assert ( - int(os.environ["WORLD_SIZE"]) == 1 - ), "Must create seed checkpoint using a single device, to disable sharding." - assert ( - config.checkpoint.enable - ), "Must enable checkpointing when creating a seed checkpoint." - trainer.checkpointer.save(curr_step=0, last_step=True) - logger.info("Created seed checkpoint") - else: - trainer.train() - except Exception: - if trainer: - trainer.close() - raise - else: - trainer.close() - torch.distributed.destroy_process_group() - logger.info("Process group destroyed.") + main(FluxTrainer) diff --git a/torchtitan/train.py b/torchtitan/train.py index 4d3ed12e8e..b05e15551c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -8,7 +8,7 @@ import os import time from datetime import timedelta -from typing import Any, Generator, Iterable, Optional +from typing import Any, Generator, Iterable import torch @@ -410,26 +410,67 @@ def batch_generator( yield input_dict, labels - def forward_backward_step( + def post_dataloading_process( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor - ) -> torch.Tensor: - model_parts = self.model_parts - parallel_dims = self.parallel_dims - + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: + """ + Post-processing hook after data loading and before model forward pass. + + This method processes the raw data from the dataloader and prepares it for + the model's forward pass. It separates the main input tensor from auxiliary + inputs and constructs additional keyword arguments (e.g., attention masks). + + This method can be overridden in subclasses to customize data processing + for different training strategies (e.g., converting tensors to DTensors, + applying custom transformations, etc.). + + Args: + input_dict: Dictionary containing tensors from the dataloader. Must + contain an "input" key with the main input tensor. May contain + additional keys for auxiliary inputs (e.g., position ids). + labels: Target labels for the batch. 
+ + Returns: + A tuple of (inputs, labels, extra_inputs, extra_kwargs) where: + - inputs: Main input tensor extracted from input_dict["input"]. + - labels: Target labels (unchanged from input parameter). + - extra_inputs: Dict of auxiliary input tensors (all keys except + "input" from input_dict). These are passed to the model forward + but are NOT forwarded across pipeline parallel stages. + - extra_kwargs: Dict of additional keyword arguments for model forward. + These ARE forwarded across pipeline parallel stages. Contains + attention_masks if flex attention is enabled. + + Note: + The distinction between extra_inputs and extra_kwargs is important for + pipeline parallelism: extra_kwargs are forwarded to all pipeline stages, + while extra_inputs are only available to the first stage. + """ inputs = input_dict["input"] extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} # For arguments, like attention_masks, we have to put them in a separate # dict as extra_inputs are not forwarded to other stages in PP, but # extra_kwargs are. - extra_kwargs = {} + extra_kwargs: dict[str, Any] = {} if getattr(self.model_args, "use_flex_attn", False): - extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( + extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, extra_inputs=extra_inputs, ) + return inputs, labels, extra_inputs, extra_kwargs + + def forward_backward_step( + self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + ) -> torch.Tensor: + model_parts = self.model_parts + parallel_dims = self.parallel_dims + + inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( + input_dict, labels + ) # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage optional_context_parallel_ctx = ( @@ -662,14 +703,19 @@ def close(self) -> None: self.metrics_processor.close() -if __name__ == "__main__": +def main(trainer_class: type[Trainer]) -> None: + """Main entry point for training with a specified trainer class. + + Args: + trainer_class: The trainer class to instantiate (e.g., Trainer, FluxTrainer, TorchCommsTrainer) + """ init_logger() config_manager = ConfigManager() config = config_manager.parse_args() - trainer: Optional[Trainer] = None + trainer: Trainer | None = None try: - trainer = Trainer(config) + trainer = trainer_class(config) if config.checkpoint.create_seed_checkpoint: assert ( @@ -690,3 +736,7 @@ def close(self) -> None: trainer.close() torch.distributed.destroy_process_group() logger.info("Process group destroyed") + + +if __name__ == "__main__": + main(Trainer)
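
For downstream code, the net effect of this patch is that a custom trainer only needs to subclass `Trainer`, override the hooks it cares about (`post_dataloading_process`, `close`, etc.), and hand the class to `main()`. Below is a minimal sketch of that pattern; the `MyTrainer` name and the module path `my_train.py` are illustrative and not part of this patch, and the overrides simply delegate to the base class to show where custom logic would go.

    # my_train.py -- hypothetical downstream trainer built on the refactored entry point.
    from typing import Any

    import torch

    from torchtitan.train import main, Trainer


    class MyTrainer(Trainer):
        def post_dataloading_process(
            self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]:
            # Reuse the base preprocessing, then adjust the outputs as needed.
            inputs, labels, extra_inputs, extra_kwargs = super().post_dataloading_process(
                input_dict, labels
            )
            # Anything placed in extra_kwargs is forwarded across pipeline parallel
            # stages; extra_inputs are only visible to the first stage.
            return inputs, labels, extra_inputs, extra_kwargs

        def close(self) -> None:
            # Release subclass-owned resources here (cf. TorchCommsTrainer finalizing
            # its comms), then let the base class close its own resources.
            super().close()


    if __name__ == "__main__":
        main(MyTrainer)

As in the flux and torchcomms changes above, `main()` keeps the shared behavior (logger init, config parsing, seed-checkpoint handling, error cleanup, and process group destruction), so subclasses no longer duplicate that boilerplate in their own `__main__` blocks.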