diff --git a/templates/configs/compute/bon_echo/a40_4x.yaml b/templates/configs/compute/bon_echo/a40_4x.yaml new file mode 100644 index 0000000..bfa2a96 --- /dev/null +++ b/templates/configs/compute/bon_echo/a40_4x.yaml @@ -0,0 +1,7 @@ +nodes: 1 +gpus_per_node: 4 +cpus_per_task: 32 +mem_gb: 64 +timeout_min: 60 +slurm: + partition: a40 diff --git a/templates/src/llm/README.md b/templates/src/llm/README.md index 9b621f8..3cd76b3 100644 --- a/templates/src/llm/README.md +++ b/templates/src/llm/README.md @@ -1,5 +1,24 @@ -### LLM training templates +# LLM Training Templates -This directory includes templates for LLM training tasks: +This directory includes templates for language-model workloads: -- [text_classification](text_classification/): Fine-tunes a small Transformer on AG News using Hugging Face Trainer. +- [text_classification](text_classification/): fine-tunes a small LLM on AG News via Hugging Face Trainer. +- [finetune_distributed](finetune_distributed/): distributed finetuning example adapted from VectorLM (https://github.com/VectorInstitute/vectorlm). + +## Finetune Distributed (DDP/FSDP) + +Run the distributed template: +```bash +uv run python -m llm.finetune_distributed.launch \ + compute=bon_echo/a40_4x \ + +trainer.dist.mode=fsdp --multirun +``` +You can choose **DDP** or **FSDP** mode by setting the `+trainer.dist.mode` argument (`ddp` or `fsdp`). + +A few points to clarify for this template: +- **`launch.py`** is the Hydra entrypoint; it merges config layers and hands the resolved config to Submitit. +- **`distributed_launcher.py`** is a Submitit helper; it shells out to `torch.distributed.run` so that torchrun controls per-rank workers without re-entering Hydra (the same pattern used in VectorLM). +- **`train.py`** is the torchrun worker; it loads the saved config, builds tokenizer, dataloaders, model, and optimizer, and then delegates to the Trainer. +- **`trainer_core.py`** is a minimal trainer (adapted from VectorLM’s `trainer.py`); it handles gradient accumulation, checkpointing, optional evaluation, and works with either DDP or FSDP. + +Hydra and Submitit resolve and submit jobs once. Torchrun (DDP/FSDP) needs to own process creation per GPU. Launching `torch.distributed.run` in a subprocess is the standard Hydra + Submitit approach: it avoids nested Hydra invocations, keeps the Hydra run directory stable for requeues and checkpoints, and makes local debugging under `torchrun` straightforward. 
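+
+### Overriding hyperparameters and resuming
+
+Everything in `finetune_distributed/config.yaml` sits under the `trainer.*` namespace, so any value can be overridden on the command line with the same `+trainer.…` syntax shown above. A sketch (the values here are illustrative, not recommended defaults):
+
+```bash
+uv run python -m llm.finetune_distributed.launch \
+    compute=bon_echo/a40_4x \
+    +trainer.dist.mode=ddp \
+    +trainer.train.learning_rate=1e-5 \
+    +trainer.train.num_train_epochs=2 --multirun
+```
+
+To resume a specific run, pass `+resume_from_checkpoint=<path-to-a-saved-checkpoints/step=N-directory>`; if the flag is omitted, `train.py` falls back to the newest `step=*` checkpoint found in the current run directory, which is what lets Submitit requeues pick up where they left off.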
diff --git a/templates/src/llm/finetune_distributed/__init__.py b/templates/src/llm/finetune_distributed/__init__.py new file mode 100644 index 0000000..74faeea --- /dev/null +++ b/templates/src/llm/finetune_distributed/__init__.py @@ -0,0 +1 @@ +"""LLM training template: Fine-tuning using distributed training.""" diff --git a/templates/src/llm/finetune_distributed/config.yaml b/templates/src/llm/finetune_distributed/config.yaml new file mode 100644 index 0000000..691e0d8 --- /dev/null +++ b/templates/src/llm/finetune_distributed/config.yaml @@ -0,0 +1,35 @@ +trainer: + model: + name: "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + gradient_checkpointing: false + + data: + name: "karpathy/tiny_shakespeare" + text_key: "text" + max_length: 512 + trust_remote_code: true + + train: + num_train_epochs: 1 + per_device_train_batch_size: 1 + per_device_eval_batch_size: 1 + gradient_accumulation_steps: 8 + learning_rate: 2.0e-5 + weight_decay: 0.0 + logging_steps: 20 + eval_steps: 200 + save_steps: 200 + eval_strategy: "steps" + + dist: + mode: "fsdp" # none | ddp | fsdp + backend: "nccl" + bf16: false + fp16: true + fsdp: "full_shard auto_wrap" + fsdp_config: + use_orig_params: true + limit_all_gathers: true + forward_prefetch: true + sync_module_states: true + activation_checkpointing: false diff --git a/templates/src/llm/finetune_distributed/distributed_launcher.py b/templates/src/llm/finetune_distributed/distributed_launcher.py new file mode 100644 index 0000000..cdde247 --- /dev/null +++ b/templates/src/llm/finetune_distributed/distributed_launcher.py @@ -0,0 +1,115 @@ +"""Launch script for checkpointable distributed finetuning with Hydra + Submitit.""" + +from __future__ import annotations + +import os +import subprocess +import sys + +import submitit +from omegaconf import OmegaConf + + +def _under_torchrun() -> bool: + return "LOCAL_RANK" in os.environ or "TORCHELASTIC_RUN_ID" in os.environ + + +def _running_inside_slurm() -> bool: + return "SLURM_JOB_ID" in os.environ + + +def _slurm_world(): + nnodes = int(os.environ.get("SLURM_NNODES", "1")) + node_rank = int(os.environ.get("SLURM_NODEID", "0")) + nodelist = os.environ.get("SLURM_NODELIST") or os.environ.get("SLURM_JOB_NODELIST") + if not nodelist: + master_addr = "127.0.0.1" + else: + out = subprocess.check_output(["scontrol", "show", "hostnames", nodelist]) + master_addr = out.decode().splitlines()[0].strip() + master_port = os.environ.get("MASTER_PORT", "29500") + return nnodes, node_rank, master_addr, master_port + + +def _resolve_work_dir(cfg) -> str: + env_dir = os.environ.get("HYDRA_LAUNCHER_RUN_DIR") or os.environ.get( + "HYDRA_RUN_DIR" + ) + if env_dir: + return env_dir + work_dir = getattr(cfg, "work_dir", None) + if isinstance(work_dir, str) and "${" not in work_dir: + return work_dir + return os.getcwd() + + +def _save_resolved_config(cfg, work_dir: str) -> str: + OmegaConf.set_struct(cfg, False) + cfg.work_dir = work_dir + if "paths" in cfg: + cfg.paths["work_dir"] = work_dir + cfg.paths["work_root"] = os.path.dirname(work_dir) + base = os.path.basename(work_dir) + if base.isdigit(): + cfg.experiment_name = os.path.basename(os.path.dirname(work_dir)) + else: + cfg.experiment_name = base + cfg_path = os.path.join(work_dir, "_fsdp_cfg.yaml") + OmegaConf.save(cfg, cfg_path) + return cfg_path + + +def _launch_torchrun(cfg, world_size: int, nproc_per_node: int) -> int: + if world_size <= 1 or _under_torchrun() or not _running_inside_slurm(): + return 0 + nnodes, node_rank, master_addr, master_port = _slurm_world() + work_dir = 
_resolve_work_dir(cfg) + os.makedirs(work_dir, exist_ok=True) + cfg_path = _save_resolved_config(cfg, work_dir) + cmd = [ + sys.executable, + "-m", + "torch.distributed.run", + f"--nproc_per_node={nproc_per_node}", + f"--nnodes={nnodes}", + f"--node_rank={node_rank}", + f"--master_addr={master_addr}", + f"--master_port={master_port}", + "--module", + "llm.finetune_distributed.train", + "--config", + cfg_path, + ] + return subprocess.run(cmd, check=False).returncode + + +class DistributedLauncher(submitit.helpers.Checkpointable): + """Submitit helper that spins up torchrun or falls back to a local run.""" + + def __call__(self, cfg): + """Dispatch the training job based on the selected distributed mode.""" + nnodes = int(getattr(cfg.compute, "nodes", 1)) + gpn = int(getattr(cfg.compute, "gpus_per_node", 1)) + world_size = nnodes * gpn + + if getattr(cfg.dist, "mode", "none") in {"ddp", "fsdp"}: + return _launch_torchrun(cfg, world_size, gpn) + + work_dir = _resolve_work_dir(cfg) + os.makedirs(work_dir, exist_ok=True) + cfg_path = _save_resolved_config(cfg, work_dir) + cmd = [ + sys.executable, + "-m", + "llm.finetune_distributed.train", + "--config", + cfg_path, + ] + return subprocess.run(cmd, check=False).returncode + + def checkpoint(self, *args, **kwargs): + """Checkpoint the launcher so Submitit can requeue the job.""" + return submitit.helpers.DelayedSubmission(self, *args, **kwargs) + + +__all__ = ["DistributedLauncher"] diff --git a/templates/src/llm/finetune_distributed/launch.py b/templates/src/llm/finetune_distributed/launch.py new file mode 100644 index 0000000..3789bf1 --- /dev/null +++ b/templates/src/llm/finetune_distributed/launch.py @@ -0,0 +1,41 @@ +"""Launch script for checkpointable distributed finetuning with Hydra + Submitit.""" + +import os + +import hydra +from hydra.core.hydra_config import HydraConfig +from llm.finetune_distributed.distributed_launcher import DistributedLauncher +from omegaconf import DictConfig, OmegaConf + + +_CONFIG_PATH = os.path.normpath( + os.path.join(os.path.dirname(__file__), "../../../configs") +) + + +@hydra.main(config_path=_CONFIG_PATH, config_name="_global", version_base=None) +def main(cfg: DictConfig): + """Hydra entrypoint that merges configs before launching training.""" + local_cfg_path = os.path.join(os.path.dirname(__file__), "config.yaml") + local_cfg = OmegaConf.load(local_cfg_path) + OmegaConf.set_struct(cfg, False) + cfg = OmegaConf.merge(local_cfg, cfg) + + hydra_run_dir = HydraConfig.get().runtime.output_dir + if hydra_run_dir is not None: + cfg.work_dir = hydra_run_dir + if "paths" in cfg: + cfg.paths.work_dir = hydra_run_dir + cfg.paths.work_root = os.path.dirname(hydra_run_dir) + + if "trainer" in cfg: + trainer_cfg = cfg.trainer + cfg = OmegaConf.merge(cfg, trainer_cfg) + del cfg.trainer + + runner = DistributedLauncher() + return runner(cfg) + + +if __name__ == "__main__": + main() diff --git a/templates/src/llm/finetune_distributed/train.py b/templates/src/llm/finetune_distributed/train.py new file mode 100644 index 0000000..20049bb --- /dev/null +++ b/templates/src/llm/finetune_distributed/train.py @@ -0,0 +1,346 @@ +"""Distributed finetuning worker script.""" + +from __future__ import annotations + +import argparse +import math +import os +import warnings +from dataclasses import dataclass +from functools import partial +from typing import Optional + +import torch +import torch.distributed as dist +from datasets import load_dataset +from omegaconf import DictConfig, OmegaConf +from torch.distributed.fsdp import ( 
+ BackwardPrefetch, + FullyShardedDataParallel, + ShardingStrategy, +) +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from torch.nn.parallel import DistributedDataParallel +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader, DistributedSampler +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollatorForLanguageModeling, +) +from transformers.trainer_utils import set_seed + +from .trainer_core import Trainer, is_main_process + + +@dataclass +class RuntimeConfig: + """Resolved runtime configuration used by the trainer.""" + + work_dir: str + grad_accum: int + logging_steps: int + eval_steps: int + save_steps: int + num_train_epochs: int + learning_rate: float + weight_decay: float + warmup_ratio: float + per_device_train_batch_size: int + per_device_eval_batch_size: int + max_length: int + dataset_name: str + text_column: str + model_name: str + trust_remote_code: bool + bf16: bool + fp16: bool + gradient_checkpointing: bool + dist_mode: str + dist_backend: str + seed: int + + +def _prepare_datasets( + cfg: RuntimeConfig, tokenizer +) -> tuple[DataLoader, Optional[DataLoader]]: + raw = load_dataset(cfg.dataset_name, trust_remote_code=cfg.trust_remote_code) + train_split = raw["train"] + eval_split = raw.get("validation") or raw.get("test") + + def tokenize_fn(batch): + return tokenizer( + batch[cfg.text_column], + truncation=True, + max_length=cfg.max_length, + return_special_tokens_mask=True, + ) + + tokenized_train = train_split.map( + tokenize_fn, batched=True, remove_columns=train_split.column_names + ) + + def group_texts(examples): + concatenated = [] + for ids in examples["input_ids"]: + concatenated.extend(ids) + total_length = len(concatenated) + if total_length >= cfg.max_length: + total_length = (total_length // cfg.max_length) * cfg.max_length + if total_length == 0: + return {"input_ids": [], "attention_mask": []} + return { + "input_ids": [ + concatenated[i : i + cfg.max_length] + for i in range(0, total_length, cfg.max_length) + ], + "attention_mask": [ + [1] * cfg.max_length for _ in range(0, total_length, cfg.max_length) + ], + } + + tokenized_train = tokenized_train.map(group_texts, batched=True) + tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask"]) + + collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) + + train_sampler = ( + DistributedSampler(tokenized_train, shuffle=True) + if dist.is_initialized() + else None + ) + train_loader = DataLoader( + tokenized_train, + batch_size=cfg.per_device_train_batch_size, + sampler=train_sampler, + shuffle=(train_sampler is None), + collate_fn=collator, + ) + + eval_loader = None + if eval_split is not None: + tokenized_eval = eval_split.map( + tokenize_fn, batched=True, remove_columns=eval_split.column_names + ) + tokenized_eval = tokenized_eval.map(group_texts, batched=True) + tokenized_eval.set_format(type="torch", columns=["input_ids", "attention_mask"]) + eval_sampler = ( + DistributedSampler(tokenized_eval, shuffle=False) + if dist.is_initialized() + else None + ) + eval_loader = DataLoader( + tokenized_eval, + batch_size=cfg.per_device_eval_batch_size, + sampler=eval_sampler, + shuffle=False, + collate_fn=collator, + ) + + return train_loader, eval_loader + + +def _choose_dtype(cfg: RuntimeConfig) -> torch.dtype: + if cfg.bf16 and torch.cuda.is_available(): + if hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported(): + return torch.bfloat16 + if 
is_main_process(): + msg = "BF16 not supported on this GPU; using {} instead".format( + "FP16" if cfg.fp16 else "FP32" + ) + warnings.warn(msg, stacklevel=2) + if cfg.fp16: + return torch.float16 + return torch.float32 + + +def _build_model(cfg: RuntimeConfig, device: torch.device) -> torch.nn.Module: + torch_dtype = _choose_dtype(cfg) + model = AutoModelForCausalLM.from_pretrained( + cfg.model_name, torch_dtype=torch_dtype, trust_remote_code=cfg.trust_remote_code + ) + model.to(device) + if cfg.gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): + model.gradient_checkpointing_enable() + if hasattr(model.config, "use_cache"): + model.config.use_cache = False + return model + + +def _wrap_model( + cfg: RuntimeConfig, model: torch.nn.Module, device: torch.device +) -> torch.nn.Module: + if cfg.dist_mode == "fsdp": + auto_wrap_policy = partial( + size_based_auto_wrap_policy, min_num_params=1_000_000 + ) + return FullyShardedDataParallel( + model, + auto_wrap_policy=auto_wrap_policy, + device_id=device, + sharding_strategy=ShardingStrategy.FULL_SHARD, + backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + forward_prefetch=True, + sync_module_states=True, + ) + if cfg.dist_mode == "ddp": + return DistributedDataParallel( + model, device_ids=[device.index] if device.type == "cuda" else None + ) + return model + + +def _build_optimizer(cfg: RuntimeConfig, model: torch.nn.Module) -> AdamW: + return AdamW( + model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay + ) + + +def _build_scheduler( + cfg: RuntimeConfig, optimizer: AdamW, train_loader: DataLoader +) -> LambdaLR: + total_update_steps = ( + math.ceil(len(train_loader) / cfg.grad_accum) * cfg.num_train_epochs + ) + warmup_steps = int(cfg.warmup_ratio * total_update_steps) + + def lr_lambda(current_step: int) -> float: + if warmup_steps > 0 and current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + progress = float(current_step - warmup_steps) / float( + max(1, total_update_steps - warmup_steps) + ) + return max(0.0, 1.0 - progress) + + return LambdaLR(optimizer, lr_lambda) + + +def _maybe_resume(trainer: Trainer, resume_path: Optional[str]) -> None: + ckpt = resume_path or trainer.latest_checkpoint() + if ckpt and os.path.exists(ckpt): + if dist.is_initialized(): + dist.barrier() + trainer.load_checkpoint(ckpt) + if is_main_process(): + print(f"Resumed from checkpoint: {ckpt}") + if dist.is_initialized(): + dist.barrier() + + +def run(cfg: RuntimeConfig, raw_cfg) -> None: + """Entry point used by torchrun workers.""" + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + device = ( + torch.device("cuda", local_rank) + if torch.cuda.is_available() + else torch.device("cpu") + ) + + if ( + dist.is_available() + and not dist.is_initialized() + and cfg.dist_mode in {"ddp", "fsdp"} + ): + dist.init_process_group(backend=cfg.dist_backend) + if torch.cuda.is_available(): + torch.cuda.set_device(device) + + set_seed(cfg.seed + dist.get_rank() if dist.is_initialized() else cfg.seed) + + os.makedirs(cfg.work_dir, exist_ok=True) + if is_main_process(): + OmegaConf.save(raw_cfg, os.path.join(cfg.work_dir, "resolved_config.yaml")) + + tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + train_loader, eval_loader = _prepare_datasets(cfg, tokenizer) + + model = _build_model(cfg, device) + model = _wrap_model(cfg, model, device) + + optimizer = _build_optimizer(cfg, model) + 
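+    # _build_scheduler below applies linear warmup over the first `warmup_ratio`
+    # fraction of optimizer updates, then decays the learning rate linearly to zero;
+    # total updates = ceil(len(train_loader) / grad_accum) * num_train_epochs.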
scheduler = _build_scheduler(cfg, optimizer, train_loader) + + trainer = Trainer( + model=model, + optimizer=optimizer, + scheduler=scheduler, + train_loader=train_loader, + eval_loader=eval_loader, + work_dir=cfg.work_dir, + gradient_accumulation_steps=cfg.grad_accum, + logging_steps=cfg.logging_steps, + eval_steps=cfg.eval_steps, + save_steps=cfg.save_steps, + max_epochs=cfg.num_train_epochs, + ) + + resume_path = None + if isinstance(raw_cfg, DictConfig) and "resume_from_checkpoint" in raw_cfg: + resume_path = raw_cfg.resume_from_checkpoint + elif isinstance(raw_cfg, dict) and "resume_from_checkpoint" in raw_cfg: + resume_path = raw_cfg.get("resume_from_checkpoint") + elif hasattr(raw_cfg, "resume_from_checkpoint"): + resume_path = raw_cfg.resume_from_checkpoint + _maybe_resume(trainer, resume_path) + + trainer.train() + + if is_main_process(): + trainer.save_checkpoint() + + if dist.is_initialized() and cfg.dist_mode in {"ddp", "fsdp"}: + dist.barrier() + dist.destroy_process_group() + + +def parse_args() -> argparse.Namespace: + """Parse CLI arguments when invoked directly.""" + parser = argparse.ArgumentParser(description="Distributed finetuning trainer") + parser.add_argument( + "--config", type=str, required=True, help="Path to resolved Hydra config" + ) + return parser.parse_args() + + +def build_runtime_config(cfg) -> RuntimeConfig: + """Build the dataclass used by the trainer from a resolved config.""" + return RuntimeConfig( + work_dir=cfg.work_dir, + grad_accum=cfg.train.gradient_accumulation_steps, + logging_steps=cfg.train.logging_steps, + eval_steps=cfg.train.eval_steps, + save_steps=cfg.train.save_steps, + num_train_epochs=cfg.train.num_train_epochs, + learning_rate=cfg.train.learning_rate, + weight_decay=cfg.train.weight_decay, + warmup_ratio=getattr(cfg.train, "warmup_ratio", 0.03), + per_device_train_batch_size=cfg.train.per_device_train_batch_size, + per_device_eval_batch_size=cfg.train.per_device_eval_batch_size, + max_length=cfg.data.max_length, + dataset_name=cfg.data.name, + text_column=cfg.data.text_key, + model_name=cfg.model.name, + trust_remote_code=getattr(cfg.data, "trust_remote_code", False), + bf16=getattr(cfg.dist, "bf16", False), + fp16=getattr(cfg.dist, "fp16", False), + gradient_checkpointing=getattr(cfg.model, "gradient_checkpointing", False), + dist_mode=getattr(cfg.dist, "mode", "none"), + dist_backend=getattr(cfg.dist, "backend", "nccl"), + seed=getattr(cfg.train, "seed", 42), + ) + + +def main() -> None: + """CLI entrypoint executed by torchrun workers.""" + args = parse_args() + raw_cfg = OmegaConf.load(args.config) + runtime_cfg = build_runtime_config(raw_cfg) + run(runtime_cfg, raw_cfg) + + +if __name__ == "__main__": + main() diff --git a/templates/src/llm/finetune_distributed/trainer_core.py b/templates/src/llm/finetune_distributed/trainer_core.py new file mode 100644 index 0000000..31f0d3a --- /dev/null +++ b/templates/src/llm/finetune_distributed/trainer_core.py @@ -0,0 +1,264 @@ +"""Core training utilities for the distributed fine-tuning template.""" + +from __future__ import annotations + +import os +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed.fsdp import ( + FullStateDictConfig, + FullyShardedDataParallel, + StateDictType, +) +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch.utils.data import DataLoader + + +def is_main_process() -> bool: + """Return ``True`` when running on the primary rank.""" + if not dist.is_available() or not 
dist.is_initialized(): + return True + return dist.get_rank() == 0 + + +def is_fsdp(model: torch.nn.Module) -> bool: + """Return ``True`` when the model is wrapped in ``FullyShardedDataParallel``.""" + return isinstance(model, FullyShardedDataParallel) + + +class TrainerState: + """Track progress across epochs and global steps.""" + + def __init__(self) -> None: + self.epoch: int = 0 + self.global_step: int = 0 + + def to_dict(self) -> dict[str, int]: + """Serialize the state so it can be checkpointed.""" + return {"epoch": self.epoch, "global_step": self.global_step} + + @classmethod + def from_dict(cls, data: dict[str, int]) -> "TrainerState": + """Create a state object from serialized metadata.""" + state = cls() + state.epoch = data.get("epoch", 0) + state.global_step = data.get("global_step", 0) + return state + + +class Trainer: + """Minimal trainer inspired by VectorLM's implementation.""" + + def __init__( + self, + *, + model: torch.nn.Module, + optimizer: Optimizer, + scheduler: Optional[LRScheduler], + train_loader: DataLoader, + eval_loader: Optional[DataLoader], + work_dir: str, + gradient_accumulation_steps: int, + logging_steps: int, + eval_steps: int, + save_steps: int, + max_epochs: int, + ) -> None: + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.train_loader = train_loader + self.eval_loader = eval_loader + self.work_dir = work_dir + self.grad_accum = max(gradient_accumulation_steps, 1) + self.logging_steps = logging_steps + self.eval_steps = eval_steps + self.save_steps = save_steps + self.max_epochs = max_epochs + + self.state = TrainerState() + self._log_cache: list[float] = [] + + self.checkpoint_dir = os.path.join(self.work_dir, "checkpoints") + os.makedirs(self.checkpoint_dir, exist_ok=True) + + def _model_state_dict(self) -> dict: + """Return a model state dict that respects FSDP wrapping.""" + if is_fsdp(self.model): + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False) + with FullyShardedDataParallel.state_dict_type( + self.model, StateDictType.FULL_STATE_DICT, cfg + ): + return self.model.state_dict() + return self.model.state_dict() + + def _load_model_state_dict(self, state: dict) -> None: + """Restore parameters into an optionally sharded model.""" + if is_fsdp(self.model): + cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=False) + with FullyShardedDataParallel.state_dict_type( + self.model, StateDictType.FULL_STATE_DICT, cfg + ): + self.model.load_state_dict(state) + return + self.model.load_state_dict(state) + + def latest_checkpoint(self) -> Optional[str]: + """Return the newest checkpoint directory, if it exists.""" + if not os.path.exists(self.checkpoint_dir): + return None + ckpts = [ + name for name in os.listdir(self.checkpoint_dir) if name.startswith("step=") + ] + if not ckpts: + return None + ckpts.sort(key=lambda name: int(name.split("=")[-1])) + return os.path.join(self.checkpoint_dir, ckpts[-1]) + + def save_checkpoint(self) -> None: + """Persist model/optimizer/scheduler state to disk.""" + if not is_main_process(): + return + path = os.path.join(self.checkpoint_dir, f"step={self.state.global_step}") + os.makedirs(path, exist_ok=True) + torch.save(self._model_state_dict(), os.path.join(path, "model.pt")) + torch.save(self.optimizer.state_dict(), os.path.join(path, "optimizer.pt")) + if self.scheduler is not None: + torch.save(self.scheduler.state_dict(), os.path.join(path, "scheduler.pt")) + torch.save(self.state.to_dict(), os.path.join(path, "trainer_state.pt")) + + def 
load_checkpoint(self, ckpt_path: str) -> None: + """Reload a previously saved checkpoint.""" + map_location = ( + torch.cuda.current_device() if torch.cuda.is_available() else "cpu" + ) + model_state = torch.load( + os.path.join(ckpt_path, "model.pt"), map_location=map_location + ) + self._load_model_state_dict(model_state) + + optim_state = torch.load( + os.path.join(ckpt_path, "optimizer.pt"), map_location=map_location + ) + self.optimizer.load_state_dict(optim_state) + + sched_path = os.path.join(ckpt_path, "scheduler.pt") + if self.scheduler is not None and os.path.exists(sched_path): + sched_state = torch.load(sched_path, map_location=map_location) + self.scheduler.load_state_dict(sched_state) + + state_path = os.path.join(ckpt_path, "trainer_state.pt") + if os.path.exists(state_path): + self.state = TrainerState.from_dict(torch.load(state_path)) + + def train(self) -> None: + """Run the main training loop.""" + if is_main_process(): + print(f"Starting training for {self.max_epochs} epochs") + + for epoch in range(self.state.epoch, self.max_epochs): + self.state.epoch = epoch + if hasattr(self.train_loader.sampler, "set_epoch"): + self.train_loader.sampler.set_epoch(epoch) + + self._train_one_epoch() + + should_eval = ( + self.eval_loader is not None + and self.eval_steps > 0 + and self.state.global_step % self.eval_steps == 0 + ) + if should_eval: + self.evaluate(epoch) + + if self.state.global_step >= len(self.train_loader) * self.max_epochs: + break + + if is_main_process(): + print("Training finished") + + def _train_one_epoch(self) -> None: + """Train the model for a single epoch.""" + self.model.train() + device = ( + torch.device("cuda", torch.cuda.current_device()) + if torch.cuda.is_available() + else torch.device("cpu") + ) + + for step, batch in enumerate(self.train_loader, start=1): + batch_on_device = {key: value.to(device) for key, value in batch.items()} + outputs = self.model(**batch_on_device) + loss = outputs.loss / self.grad_accum + loss.backward() + + gathered_loss = loss.detach().clone() + if dist.is_initialized(): + dist.all_reduce(gathered_loss, op=dist.ReduceOp.SUM) + gathered_loss /= dist.get_world_size() + self._log_cache.append(gathered_loss.item()) + + if step % self.grad_accum == 0: + if hasattr(self.model, "clip_grad_norm_"): + self.model.clip_grad_norm_(max_norm=1.0) + else: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), max_norm=1.0 + ) + self.optimizer.step() + if self.scheduler is not None: + self.scheduler.step() + self.optimizer.zero_grad(set_to_none=True) + self.state.global_step += 1 + + should_log = ( + self.logging_steps + and self.state.global_step % self.logging_steps == 0 + and is_main_process() + ) + if should_log: + mean_loss = sum(self._log_cache) / max(len(self._log_cache), 1) + print(f"step={self.state.global_step} loss={mean_loss:.4f}") + self._log_cache.clear() + + if self.save_steps and self.state.global_step % self.save_steps == 0: + if dist.is_initialized(): + dist.barrier() + self.save_checkpoint() + if dist.is_initialized(): + dist.barrier() + + def evaluate(self, epoch: int) -> None: + """Run an evaluation epoch and report mean loss.""" + if self.eval_loader is None: + return + self.model.eval() + device = ( + torch.device("cuda", torch.cuda.current_device()) + if torch.cuda.is_available() + else torch.device("cpu") + ) + losses = [] + with torch.no_grad(): + for batch in self.eval_loader: + batch_on_device = { + key: value.to(device) for key, value in batch.items() + } + outputs = self.model(**batch_on_device) + 
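+                # Keep each batch's loss; after the loop the stacked losses are
+                # summed across ranks and divided by the world size for a global mean.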
losses.append(outputs.loss.detach()) + if not losses: + return + loss_tensor = torch.stack(losses) + if dist.is_initialized(): + dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM) + loss_tensor /= dist.get_world_size() + mean_loss = loss_tensor.mean().item() + if is_main_process(): + print(f"epoch={epoch} eval_loss={mean_loss:.4f}") + self.model.train() + + +__all__ = ["Trainer", "TrainerState", "is_main_process"]
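
Checkpoints written by `trainer_core.Trainer.save_checkpoint()` are plain `torch.save` files (`model.pt`, `optimizer.pt`, `trainer_state.pt`, plus `scheduler.pt` when a scheduler is used) under `<work_dir>/checkpoints/step=<N>`. Below is a minimal sketch of loading `model.pt` back into an unwrapped Hugging Face model for inference or export; the checkpoint path is hypothetical and the model name must match the one used for training.

```python
"""Load a full state dict saved by Trainer.save_checkpoint() into a plain HF model."""

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_dir = "outputs/<run_dir>/checkpoints/step=200"  # hypothetical path
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"    # must match the training config

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

state = torch.load(f"{ckpt_dir}/model.pt", map_location="cpu")
# DDP checkpoints prefix every key with "module."; FSDP full state dicts do not.
state = {k.removeprefix("module."): v for k, v in state.items()}
model.load_state_dict(state)
model.eval()
```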