24 changes: 5 additions & 19 deletions torchtitan/experiments/forge/example_train.py
@@ -7,7 +7,7 @@
import importlib
import time
from datetime import timedelta
from typing import Any, Iterable, Optional
from typing import Any, Iterable

import torch
from torch.distributed.elastic.multiprocessing.errors import record
@@ -17,15 +17,16 @@
from torchtitan.components.metrics import build_metrics_processor
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.components.validate import build_validator
from torchtitan.config import ConfigManager, JobConfig
from torchtitan.config import JobConfig
from torchtitan.distributed import utils as dist_utils
from torchtitan.hf_datasets.text_datasets import build_text_dataloader
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.logging import logger
from torchtitan.tools.profiling import (
maybe_enable_memory_snapshot,
maybe_enable_profiling,
)
from torchtitan.train import main as train_main

from .engine import ForgeEngine

@@ -350,19 +351,4 @@ def close(self) -> None:


if __name__ == "__main__":
init_logger()
config_manager = ConfigManager()
config = config_manager.parse_args()
trainer: Optional[Trainer] = None

try:
trainer = Trainer(config)
trainer.train()
except Exception:
if trainer:
trainer.close()
raise
else:
trainer.close()
torch.distributed.destroy_process_group()
logger.info("Process group destroyed.")
train_main(Trainer)
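
The same pattern works for any downstream entry point: import the shared main() and pass it a Trainer subclass. A minimal sketch, assuming a hypothetical MyTrainer that only overrides what it needs:

from torchtitan.train import main, Trainer


class MyTrainer(Trainer):
    # Override hooks (e.g. forward_backward_step) as needed; config parsing,
    # seed-checkpoint handling, and cleanup all come from the shared main().
    pass


if __name__ == "__main__":
    main(MyTrainer)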
47 changes: 9 additions & 38 deletions torchtitan/experiments/torchcomms/train.py
@@ -4,15 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

import torch

from torchtitan.config import ConfigManager
from torchtitan.distributed import ParallelDims
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
from torchtitan.train import main, Trainer

from .parallel_dims import TorchCommsParallelDims

@@ -32,35 +25,13 @@ def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
world_size=world_size,
)


if __name__ == "__main__":
init_logger()
config_manager = ConfigManager()
config = config_manager.parse_args()
trainer: Optional[TorchCommsTrainer] = None

try:
trainer = TorchCommsTrainer(config)

if config.checkpoint.create_seed_checkpoint:
assert (
int(os.environ["WORLD_SIZE"]) == 1
), "Must create seed checkpoint using a single device, to disable sharding."
assert (
config.checkpoint.enable
), "Must enable checkpointing when creating a seed checkpoint."
trainer.checkpointer.save(curr_step=0, last_step=True)
logger.info("Created seed checkpoint")
else:
trainer.train()
# Call finalize on all comms after training and before destroying process group.
def close(self) -> None:
# Call finalize on all comms after training and before destroying process group.
if hasattr(trainer, "parallel_dims"):
for comm in trainer.parallel_dims.comms:
comm.finalize()
except Exception:
if trainer:
trainer.close()
raise
else:
trainer.close()
torch.distributed.destroy_process_group()
logger.info("Process group destroyed")
super().close()


if __name__ == "__main__":
main(TorchCommsTrainer)
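
Moving the comm finalization into close() means it now runs whether training succeeds or raises, because main() calls trainer.close() on both paths before tearing down. The same pattern applies to any subclass-owned resource; a sketch, with ResourceTrainer and self._client as illustrative names:

from torchtitan.train import main, Trainer


class ResourceTrainer(Trainer):
    def close(self) -> None:
        # Release subclass-owned resources first (self._client is illustrative),
        # then let the base close() shut down the checkpointer, metrics, etc.
        if hasattr(self, "_client"):
            self._client.shutdown()
        super().close()


if __name__ == "__main__":
    main(ResourceTrainer)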
35 changes: 3 additions & 32 deletions torchtitan/models/flux/train.py
@@ -4,12 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional

import torch

from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import utils as dist_utils

from torchtitan.models.flux.infra.parallelize import parallelize_encoders
@@ -20,8 +17,7 @@
pack_latents,
preprocess_data,
)
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
from torchtitan.train import main, Trainer


class FluxTrainer(Trainer):
@@ -175,29 +171,4 @@ def forward_backward_step(


if __name__ == "__main__":
init_logger()
config_manager = ConfigManager()
config = config_manager.parse_args()
trainer: Optional[FluxTrainer] = None

try:
trainer = FluxTrainer(config)
if config.checkpoint.create_seed_checkpoint:
assert (
int(os.environ["WORLD_SIZE"]) == 1
), "Must create seed checkpoint using a single device, to disable sharding."
assert (
config.checkpoint.enable
), "Must enable checkpointing when creating a seed checkpoint."
trainer.checkpointer.save(curr_step=0, last_step=True)
logger.info("Created seed checkpoint")
else:
trainer.train()
except Exception:
if trainer:
trainer.close()
raise
else:
trainer.close()
torch.distributed.destroy_process_group()
logger.info("Process group destroyed.")
main(FluxTrainer)
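
FluxTrainer keeps its own forward_backward_step override (the hunk above) and now inherits seed-checkpoint creation and teardown from main(). A sketch of the same shape for any model-specific trainer; DiffusionTrainer and its loss computation are illustrative, not Flux's actual logic:

import torch

from torchtitan.train import main, Trainer


class DiffusionTrainer(Trainer):
    def forward_backward_step(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> torch.Tensor:
        # Model-specific step (illustrative); the surrounding training loop and
        # optimizer updates stay in the base Trainer.
        preds = self.model_parts[0](input_dict["input"])
        loss = torch.nn.functional.mse_loss(preds.float(), labels.float())
        loss.backward()
        return loss


if __name__ == "__main__":
    main(DiffusionTrainer)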
44 changes: 32 additions & 12 deletions torchtitan/train.py
@@ -8,7 +8,7 @@
import os
import time
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional
from typing import Any, Generator, Iterable

import torch

@@ -410,26 +410,37 @@ def batch_generator(

yield input_dict, labels

def forward_backward_step(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims

def post_dataloading_processing(
self, input_dict: dict[str, torch.Tensor], label: torch.Tensor
) -> tuple[
        torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]
]:
"""Post processing after data loading."""
inputs = input_dict["input"]
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
# For arguments, like attention_masks, we have to put them in a separate
# dict as extra_inputs are not forwarded to other stages in PP, but
# extra_kwargs are.
extra_kwargs = {}
extra_kwargs: dict[str, Any] = {}

if getattr(self.model_args, "use_flex_attn", False):
extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
input_batch=inputs,
tokenizer=self.tokenizer,
extra_inputs=extra_inputs,
)

return inputs, label, extra_inputs, extra_kwargs

def forward_backward_step(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims

inputs, label, extra_inputs, extra_kwargs = self.post_dataloading_processing(
input_dict, labels
)
# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
optional_context_parallel_ctx = (
@@ -662,14 +673,19 @@ def close(self) -> None:
self.metrics_processor.close()


if __name__ == "__main__":
def main(trainer_class: type[Trainer]) -> None:
"""Main entry point for training with a specified trainer class.

Args:
trainer_class: The trainer class to instantiate (e.g., Trainer, FluxTrainer, TorchCommsTrainer)
"""
init_logger()
config_manager = ConfigManager()
config = config_manager.parse_args()
trainer: Optional[Trainer] = None
trainer: Trainer | None = None

try:
trainer = Trainer(config)
trainer = trainer_class(config)

if config.checkpoint.create_seed_checkpoint:
assert (
@@ -690,3 +706,7 @@ def close(self) -> None:
trainer.close()
torch.distributed.destroy_process_group()
logger.info("Process group destroyed")


if __name__ == "__main__":
main(Trainer)
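
The new post_dataloading_processing() hook gives subclasses a single place to reshape the batch before forward_backward_step(). As the comment in the hunk above notes, values placed in extra_kwargs are forwarded to every pipeline-parallel stage, while extra_inputs are not. A sketch of a subclass adding a per-stage kwarg; PackedTrainer, the "sequence_lengths" key, and the pad-id-0 assumption are all illustrative (the model parts would need to accept that keyword):

import torch

from torchtitan.train import main, Trainer


class PackedTrainer(Trainer):
    def post_dataloading_processing(
        self, input_dict: dict[str, torch.Tensor], label: torch.Tensor
    ):
        inputs, label, extra_inputs, extra_kwargs = super().post_dataloading_processing(
            input_dict, label
        )
        # Illustrative kwarg that should reach every pipeline stage; extra_inputs
        # would not be forwarded past the first stage. Assumes pad token id 0.
        extra_kwargs["sequence_lengths"] = (inputs != 0).sum(dim=-1)
        return inputs, label, extra_inputs, extra_kwargs


if __name__ == "__main__":
    main(PackedTrainer)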