Commit 6314139

Deduplicate TorchTitan main function
People are creating different train.py files and duplicating the `main` function, but in reality they just want to use different Trainer subclasses. This PR creates a main() in torchtitan/train.py to deduplicate the code.

ghstack-source-id: 2b3b9ac
Pull-Request: #1995
1 parent 2ee0aef commit 6314139
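
With the shared entry point in place, a downstream training script reduces to a Trainer subclass plus a single call into torchtitan/train.py. A minimal sketch of the intended pattern (the MyTrainer name, the file name, and the override shown here are hypothetical, for illustration only):

# my_train.py -- hypothetical downstream script, not part of this commit
from torchtitan.train import main, Trainer


class MyTrainer(Trainer):
    # Override only what differs from the stock Trainer; config parsing,
    # seed-checkpoint handling, and cleanup now live in the shared main().
    def close(self) -> None:
        # e.g. flush custom state before the base class tears everything down
        super().close()


if __name__ == "__main__":
    main(MyTrainer)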

File tree

4 files changed: +30 -93 lines changed

torchtitan/experiments/forge/example_train.py

Lines changed: 5 additions & 19 deletions
@@ -7,7 +7,7 @@
 import importlib
 import time
 from datetime import timedelta
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable

 import torch
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -17,15 +17,16 @@
 from torchtitan.components.metrics import build_metrics_processor
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.components.validate import build_validator
-from torchtitan.config import ConfigManager, JobConfig
+from torchtitan.config import JobConfig
 from torchtitan.distributed import utils as dist_utils
 from torchtitan.hf_datasets.text_datasets import build_text_dataloader
 from torchtitan.tools import utils
-from torchtitan.tools.logging import init_logger, logger
+from torchtitan.tools.logging import logger
 from torchtitan.tools.profiling import (
     maybe_enable_memory_snapshot,
     maybe_enable_profiling,
 )
+from torchtitan.train import main as train_main

 from .engine import ForgeEngine

@@ -350,19 +351,4 @@ def close(self) -> None:


 if __name__ == "__main__":
-    init_logger()
-    config_manager = ConfigManager()
-    config = config_manager.parse_args()
-    trainer: Optional[Trainer] = None
-
-    try:
-        trainer = Trainer(config)
-        trainer.train()
-    except Exception:
-        if trainer:
-            trainer.close()
-        raise
-    else:
-        trainer.close()
-        torch.distributed.destroy_process_group()
-        logger.info("Process group destroyed.")
+    train_main(Trainer)

torchtitan/experiments/torchcomms/train.py

Lines changed: 9 additions & 38 deletions
@@ -4,15 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import os
-from typing import Optional
-
-import torch
-
-from torchtitan.config import ConfigManager
 from torchtitan.distributed import ParallelDims
-from torchtitan.tools.logging import init_logger, logger
-from torchtitan.train import Trainer
+from torchtitan.train import main, Trainer

 from .parallel_dims import TorchCommsParallelDims

@@ -32,35 +25,13 @@ def _create_parallel_dims(self, parallelism_config, world_size) -> ParallelDims:
             world_size=world_size,
         )

-
-if __name__ == "__main__":
-    init_logger()
-    config_manager = ConfigManager()
-    config = config_manager.parse_args()
-    trainer: Optional[TorchCommsTrainer] = None
-
-    try:
-        trainer = TorchCommsTrainer(config)
-
-        if config.checkpoint.create_seed_checkpoint:
-            assert (
-                int(os.environ["WORLD_SIZE"]) == 1
-            ), "Must create seed checkpoint using a single device, to disable sharding."
-            assert (
-                config.checkpoint.enable
-            ), "Must enable checkpointing when creating a seed checkpoint."
-            trainer.checkpointer.save(curr_step=0, last_step=True)
-            logger.info("Created seed checkpoint")
-        else:
-            trainer.train()
-            # Call finalize on all comms after training and before destroying process group.
+    def close(self) -> None:
+        # Call finalize on all comms after training and before destroying process group.
+        if hasattr(trainer, "parallel_dims"):
             for comm in trainer.parallel_dims.comms:
                 comm.finalize()
-    except Exception:
-        if trainer:
-            trainer.close()
-        raise
-    else:
-        trainer.close()
-        torch.distributed.destroy_process_group()
-        logger.info("Process group destroyed")
+        super().close()
+
+
+if __name__ == "__main__":
+    main(TorchCommsTrainer)
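
The torchcomms change shows the pattern the shared main() expects from subclasses: backend-specific cleanup moves into a close() override, which main() invokes on both the success and the exception path before the process group is destroyed. A generic, hedged sketch of that hook (the CommAwareTrainer name and the extra_comms attribute are made up for illustration):

from torchtitan.train import Trainer


class CommAwareTrainer(Trainer):
    # Hypothetical subclass: finalize extra communicators before the base
    # class tears down the process group. hasattr() guards the case where
    # __init__ raised before the attribute was ever created.
    def close(self) -> None:
        if hasattr(self, "extra_comms"):
            for comm in self.extra_comms:
                comm.finalize()
        super().close()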

torchtitan/models/flux/train.py

Lines changed: 3 additions & 32 deletions
@@ -4,12 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import os
-from typing import Optional
-
 import torch

-from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP
+from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import utils as dist_utils

 from torchtitan.models.flux.infra.parallelize import parallelize_encoders
@@ -20,8 +17,7 @@
     pack_latents,
     preprocess_data,
 )
-from torchtitan.tools.logging import init_logger, logger
-from torchtitan.train import Trainer
+from torchtitan.train import main, Trainer


 class FluxTrainer(Trainer):
@@ -175,29 +171,4 @@ def forward_backward_step(


 if __name__ == "__main__":
-    init_logger()
-    config_manager = ConfigManager()
-    config = config_manager.parse_args()
-    trainer: Optional[FluxTrainer] = None
-
-    try:
-        trainer = FluxTrainer(config)
-        if config.checkpoint.create_seed_checkpoint:
-            assert (
-                int(os.environ["WORLD_SIZE"]) == 1
-            ), "Must create seed checkpoint using a single device, to disable sharding."
-            assert (
-                config.checkpoint.enable
-            ), "Must enable checkpointing when creating a seed checkpoint."
-            trainer.checkpointer.save(curr_step=0, last_step=True)
-            logger.info("Created seed checkpoint")
-        else:
-            trainer.train()
-    except Exception:
-        if trainer:
-            trainer.close()
-        raise
-    else:
-        trainer.close()
-        torch.distributed.destroy_process_group()
-        logger.info("Process group destroyed.")
+    main(FluxTrainer)

torchtitan/train.py

Lines changed: 13 additions & 4 deletions
@@ -8,7 +8,7 @@
 import os
 import time
 from datetime import timedelta
-from typing import Any, Generator, Iterable, Optional
+from typing import Any, Generator, Iterable

 import torch

@@ -673,14 +673,19 @@ def close(self) -> None:
         self.metrics_processor.close()


-if __name__ == "__main__":
+def main(trainer_class: type[Trainer]) -> None:
+    """Main entry point for training with a specified trainer class.
+
+    Args:
+        trainer_class: The trainer class to instantiate (e.g., Trainer, FluxTrainer, TorchCommsTrainer)
+    """
     init_logger()
     config_manager = ConfigManager()
     config = config_manager.parse_args()
-    trainer: Optional[Trainer] = None
+    trainer: Trainer | None = None

     try:
-        trainer = Trainer(config)
+        trainer = trainer_class(config)

         if config.checkpoint.create_seed_checkpoint:
             assert (
@@ -701,3 +706,7 @@ def close(self) -> None:
         trainer.close()
         torch.distributed.destroy_process_group()
         logger.info("Process group destroyed")
+
+
+if __name__ == "__main__":
+    main(Trainer)
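
For reference, the control flow that main() centralizes is the same shape the per-model __main__ blocks used before: close() runs on both the success and the failure path, while the process group is torn down only after a clean run. A standalone toy illustration of that try/except/else contract (plain Python, not torchtitan code):

class _ToyTrainer:
    # Stand-in object used only to show the cleanup contract.
    def close(self) -> None:
        print("close() called")


def run(trainer_class: type, fail: bool = False) -> None:
    trainer = None
    try:
        trainer = trainer_class()
        if fail:
            raise RuntimeError("simulated training failure")
    except Exception:
        if trainer:
            trainer.close()  # cleanup still runs on the error path
        raise
    else:
        trainer.close()  # normal shutdown
        print("process group would be destroyed here")


run(_ToyTrainer)  # prints "close() called" then the teardown message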
