13 changes: 9 additions & 4 deletions src/megatron/bridge/data/loaders.py
@@ -109,8 +109,8 @@ def cyclic_iter(iter: Iterable) -> Iterator:
def get_train_valid_test_num_samples(cfg: ConfigContainer) -> tuple[int, int, int]:
"""Calculate the number of samples for train, validation, and test sets.

Determines sample counts based on training iterations, global batch size,
and evaluation interval/iterations specified in the config.
Determines sample counts based on the training mode (either specified iterations or samples),
global batch size, and evaluation interval/iterations specified in the config.

Args:
cfg: The main configuration container.
@@ -119,8 +119,8 @@ def get_train_valid_test_num_samples(cfg: ConfigContainer) -> tuple[int, int, int]:
A tuple (train_samples, valid_samples, test_samples).
"""

# Number of train/valid/test samples.
train_samples = cfg.train.train_iters * cfg.train.global_batch_size
# If train_samples is directly provided, use it
if cfg.train.train_samples is not None:
train_samples = cfg.train.train_samples
else:
# Otherwise, fall back to calculating samples from iterations and global batch size
train_samples = cfg.train.train_iters * cfg.train.global_batch_size

eval_iters = (cfg.train.train_iters // cfg.train.eval_interval + 1) * cfg.train.eval_iters
test_iters = cfg.train.eval_iters
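
For reference, a minimal sketch of the two ways the total sample count is now resolved; the standalone helper and the numbers below are illustrative only, not code from this PR:

```python
# Illustration only: mirrors the branch added above in get_train_valid_test_num_samples.
def resolve_train_samples(train_samples, train_iters, global_batch_size):
    # Sample-based training: the configured sample count is used directly.
    if train_samples is not None:
        return train_samples
    # Iteration-based training: derive samples from iterations and batch size.
    return train_iters * global_batch_size

assert resolve_train_samples(1_000_000, None, 256) == 1_000_000  # sample-based
assert resolve_train_samples(None, 10_000, 256) == 2_560_000     # iteration-based
```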

61 changes: 61 additions & 0 deletions src/megatron/bridge/recipes/utils/optimizer_utils.py
@@ -75,3 +75,64 @@ def distributed_fused_adam_with_cosine_annealing(
)

return optimizer, scheduler


def distributed_fused_adam_with_cosine_annealing_samples(
precision: str = "bf16-mixed",
lr_warmup_samples: Optional[int] = None,
lr_decay_samples: Optional[int] = None,
adam_beta1: float = 0.9,
adam_beta2: float = 0.95,
adam_eps: float = 1e-5,
weight_decay: float = 0.1,
max_lr: float = 1e-4,
min_lr: Optional[float] = None,
clip_grad: float = 1.0,
) -> tuple[OptimizerConfig, SchedulerConfig]:
"""
Creates a distributed fused Adam optimizer with cosine annealing scheduler for sample-based training.

This is the sample-based equivalent of distributed_fused_adam_with_cosine_annealing().

Args:
precision: Mixed precision mode ("bf16-mixed", "16-mixed", etc.)
lr_warmup_samples: Number of samples for learning rate warmup (None = auto from train_samples)
lr_decay_samples: Number of samples for learning rate decay (None = auto from train_samples)
adam_beta1: Adam optimizer beta1 parameter
adam_beta2: Adam optimizer beta2 parameter
adam_eps: Adam optimizer epsilon parameter
weight_decay: Weight decay coefficient
max_lr: Maximum learning rate
min_lr: Minimum learning rate (defaults to 0.1 * max_lr)
clip_grad: Gradient clipping value

Returns:
A tuple of (OptimizerConfig, SchedulerConfig) configured for sample-based training
"""
min_lr = min_lr if min_lr is not None else (0.1 * max_lr)
optimizer = OptimizerConfig(
optimizer="adam",
lr=max_lr,
min_lr=min_lr,
weight_decay=weight_decay,
bf16=precision == "bf16-mixed",
fp16=precision == "16-mixed",
adam_beta1=adam_beta1,
adam_beta2=adam_beta2,
adam_eps=adam_eps,
use_distributed_optimizer=True,
clip_grad=clip_grad,
)

scheduler = SchedulerConfig(
start_weight_decay=0.033,
end_weight_decay=0.033,
weight_decay_incr_style="constant",
lr_decay_style="cosine",
lr_warmup_samples=lr_warmup_samples,
lr_warmup_init=0.0,
lr_decay_samples=lr_decay_samples,
override_opt_param_scheduler=True,
)

return optimizer, scheduler
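
A possible usage sketch for the new helper. The import path simply mirrors the file location in this PR, and the warmup/decay sample counts are illustrative values, not defaults:

```python
# Hypothetical usage of the sample-based recipe helper added above.
from megatron.bridge.recipes.utils.optimizer_utils import (
    distributed_fused_adam_with_cosine_annealing_samples,
)

optimizer_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing_samples(
    precision="bf16-mixed",
    lr_warmup_samples=50_000,    # warm up over the first 50k samples
    lr_decay_samples=1_000_000,  # decay over 1M samples
    max_lr=3e-4,
)

# The returned scheduler carries sample-based fields rather than iteration-based ones.
assert scheduler_cfg.lr_warmup_samples == 50_000
assert scheduler_cfg.lr_decay_samples == 1_000_000
```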
135 changes: 120 additions & 15 deletions src/megatron/bridge/training/config.py
@@ -362,17 +362,26 @@ class SchedulerConfig:
"""Decay style for the annealing phase of WSD"""

lr_decay_iters: Optional[int] = None
"""number of iterations to decay learning rate over, If None defaults to `--train-iters`"""
"""number of iterations to decay learning rate over. If None, defaults to `train.train_iters`"""

lr_decay_samples: Optional[int] = None
"""number of samples to decay learning rate over. If None, defaults to `train.train_samples`"""

lr_wsd_decay_iters: Optional[int] = None
"""number of iterations for the annealing phase in the wsd schedule"""

lr_wsd_decay_samples: Optional[int] = None
"""number of samples for the annealing phase in the wsd schedule"""

lr_warmup_fraction: Optional[float] = None
"""fraction of lr-warmup-(iters/samples) to use for warmup (as a float)"""

lr_warmup_iters: int = 0
"""number of iterations to linearly warmup learning rate over."""

lr_warmup_samples: int = 0
"""number of samples to linearly warmup learning rate over."""

lr_warmup_init: float = 0.0
"""Initial value for learning rate warmup. The scheduler starts warmup from this value."""

@@ -402,7 +411,7 @@ class SchedulerConfig:
wd_incr_steps: Optional[int] = field(init=False, default=None)
wsd_decay_steps: Optional[int] = field(init=False, default=None)

def finalize(self):
def finalize(self) -> None:
"""Post-initialization checks for scheduler config."""
if self.start_weight_decay is not None:
assert self.start_weight_decay >= 0.0, "start_weight_decay should be positive."
Expand All @@ -411,6 +420,28 @@ def finalize(self):
if self.override_opt_param_scheduler:
assert not self.use_checkpoint_opt_param_scheduler, "both override and use-checkpoint are set."

# Validate mutual exclusivity between iteration-based and sample-based scheduler fields
has_iter_fields = (
self.lr_decay_iters is not None or self.lr_warmup_iters != 0 or self.lr_wsd_decay_iters is not None
)
has_sample_fields = (
self.lr_decay_samples is not None or self.lr_warmup_samples != 0 or self.lr_wsd_decay_samples is not None
)

assert not (has_iter_fields and has_sample_fields), (
f"Cannot mix iteration-based and sample-based scheduler fields. "
f"Found iteration fields: lr_decay_iters={self.lr_decay_iters}, lr_warmup_iters={self.lr_warmup_iters}, lr_wsd_decay_iters={self.lr_wsd_decay_iters}. "
f"Found sample fields: lr_decay_samples={self.lr_decay_samples}, lr_warmup_samples={self.lr_warmup_samples}, lr_wsd_decay_samples={self.lr_wsd_decay_samples}. "
f"Use either iteration fields OR sample fields, not both."
)

# Validate mutual exclusivity between lr_warmup_fraction and specific warmup fields
if self.lr_warmup_fraction is not None:
assert self.lr_warmup_iters == 0 and self.lr_warmup_samples == 0, (
f"Cannot specify lr_warmup_fraction={self.lr_warmup_fraction} with lr_warmup_iters={self.lr_warmup_iters} or lr_warmup_samples={self.lr_warmup_samples}. "
f"Use either lr_warmup_fraction OR lr_warmup_iters OR lr_warmup_samples."
)
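
A sketch of how these new checks behave, assuming `SchedulerConfig` can be constructed directly with only these keyword arguments and that its remaining fields keep their defaults:

```python
# Illustration only: exercising the mutual-exclusivity checks added above.
from megatron.bridge.training.config import SchedulerConfig

# Mixing iteration-based and sample-based fields is rejected.
mixed = SchedulerConfig(lr_warmup_iters=100, lr_decay_samples=1_000_000)
try:
    mixed.finalize()
except AssertionError as err:
    print(f"Rejected as expected: {err}")

# A purely sample-based configuration passes validation.
sample_based = SchedulerConfig(lr_warmup_samples=50_000, lr_decay_samples=1_000_000)
sample_based.finalize()
```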


@dataclass(kw_only=True)
class TrainingConfig:
@@ -457,9 +488,13 @@ class TrainingConfig:

train_iters: Optional[int] = None
"""Total number of iterations to train over all training runs.
Note that either train-iters or train-samples should be provided.
Note that either train_iters or train_samples should be provided.
"""

train_samples: Optional[int] = None
"""Total number of samples to train over all training runs.
Note that either train_iters or train_samples should be provided."""

exit_interval: Optional[int] = None
"""Exit the program after the iteration is divisible by this value."""

Expand Down Expand Up @@ -502,6 +537,21 @@ class TrainingConfig:
skip_train: bool = False
"""If set, bypass the training loop, optionally do evaluation for validation/test, and exit."""

def finalize(self) -> None:
"""Validate training mode specification and calculate train_iters from train_samples if needed."""
has_train_iters = self.train_iters is not None
has_train_samples = self.train_samples is not None

assert has_train_iters or has_train_samples, "Either train_iters or train_samples must be provided"
assert not (has_train_iters and has_train_samples), "Cannot specify both train_iters and train_samples"
if has_train_samples:
assert self.train_samples > 0, "train_samples must be positive"
assert self.rampup_batch_size is None, "Batch size rampup not supported with sample-based training yet"

# Calculate train_iters from train_samples (rampup_batch_size already validated as None)
self.train_iters = self.train_samples // self.global_batch_size
print_rank_0(f"Setting training iterations to {self.train_iters} based on {self.train_samples} samples")
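
The derived iteration count is a plain floor division; a worked example with illustrative numbers:

```python
# Illustration only: what TrainingConfig.finalize() computes in sample-based mode.
train_samples = 1_000_000
global_batch_size = 256

train_iters = train_samples // global_batch_size
assert train_iters == 3906  # the trailing 64 samples (1_000_000 - 3906 * 256) are dropped
```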


@dataclass(kw_only=True)
class CheckpointConfig:
@@ -1030,6 +1080,8 @@ def validate(self) -> None:
self.optimizer.finalize()
if hasattr(self.model, "finalize"):
self.model.finalize()

self.train.finalize()
self.scheduler.finalize()
self.checkpoint.finalize()
if self.profiling is not None:
@@ -1102,18 +1154,11 @@ def validate(self) -> None:
# optimizer state in the CPU memory of DP rank 0.
assert self.checkpoint.ckpt_format == "torch_dist"

# Scheduler
if self.scheduler.lr_decay_iters is None:
self.scheduler.lr_decay_iters = self.train.train_iters
self.scheduler.lr_decay_steps = self.scheduler.lr_decay_iters * self.train.global_batch_size
self.scheduler.wd_incr_steps = self.train.train_iters * self.train.global_batch_size
self.scheduler.wsd_decay_steps = None
if self.scheduler.lr_wsd_decay_iters is not None:
self.scheduler.wsd_decay_steps = self.scheduler.lr_wsd_decay_iters * self.train.global_batch_size
if self.scheduler.lr_warmup_fraction is not None:
self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps
else:
self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_iters * self.train.global_batch_size
# Cross-validation between training and scheduler configs
self._validate_training_scheduler_compatibility()

# Calculate scheduler steps for both iteration-based and sample-based training
self._calculate_scheduler_steps()

if self.model.context_parallel_size > 1:
assert self.model.seq_length % (self.model.context_parallel_size * 2) == 0, (
@@ -1172,6 +1217,66 @@ def validate(self) -> None:

self._sync_and_validate_external_cuda_graph()

def _validate_training_scheduler_compatibility(self) -> None:
"""Cross-validation between training and scheduler configs."""
has_train_samples = self.train.train_samples is not None

if has_train_samples:
# Sample-based training validation
assert self.scheduler.lr_decay_iters is None, (
"Use lr_decay_samples for sample-based training, not lr_decay_iters"
)
assert self.scheduler.lr_warmup_iters == 0, (
"Use lr_warmup_samples for sample-based training, not lr_warmup_iters"
)
assert not (self.scheduler.lr_warmup_fraction is not None and self.scheduler.lr_warmup_samples != 0), (
"Can only specify one of lr_warmup_fraction or lr_warmup_samples"
)
else:
# Iteration-based training validation
assert self.scheduler.lr_decay_samples is None, (
"Use lr_decay_iters for iteration-based training, not lr_decay_samples"
)
assert self.scheduler.lr_warmup_samples == 0, (
"Use lr_warmup_iters for iteration-based training, not lr_warmup_samples"
)
assert not (self.scheduler.lr_warmup_fraction is not None and self.scheduler.lr_warmup_iters != 0), (
"Can only specify one of lr_warmup_fraction or lr_warmup_iters"
)

def _calculate_scheduler_steps(self) -> None:
"""Calculate scheduler steps for both iteration-based and sample-based training."""
is_sample_based = self.train.train_samples is not None

if is_sample_based:
if self.scheduler.lr_decay_samples is None:
self.scheduler.lr_decay_samples = self.train.train_samples
self.scheduler.lr_decay_steps = self.scheduler.lr_decay_samples
self.scheduler.wd_incr_steps = self.train.train_samples

if self.scheduler.lr_wsd_decay_samples is not None:
self.scheduler.wsd_decay_steps = self.scheduler.lr_wsd_decay_samples

# Warmup calculation for sample-based training
if self.scheduler.lr_warmup_fraction is not None:
self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps
else:
self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_samples
else:
# Iteration-based training
if self.scheduler.lr_decay_iters is None:
self.scheduler.lr_decay_iters = self.train.train_iters
self.scheduler.lr_decay_steps = self.scheduler.lr_decay_iters * self.train.global_batch_size
self.scheduler.wd_incr_steps = self.train.train_iters * self.train.global_batch_size

if self.scheduler.lr_wsd_decay_iters is not None:
self.scheduler.wsd_decay_steps = self.scheduler.lr_wsd_decay_iters * self.train.global_batch_size

if self.scheduler.lr_warmup_fraction is not None:
self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps
else:
self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_iters * self.train.global_batch_size
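
A worked sketch of the two branches above with illustrative numbers: in sample-based mode the step quantities are already expressed in samples, while in iteration-based mode iteration counts are scaled by the global batch size.

```python
# Illustration only: the step arithmetic performed by _calculate_scheduler_steps().

# Sample-based mode (train_samples = 1_000_000, lr_warmup_samples = 50_000,
# lr_decay_samples left as None so it defaults to train_samples):
lr_decay_steps = 1_000_000   # == train_samples
wd_incr_steps = 1_000_000    # == train_samples
lr_warmup_steps = 50_000     # == lr_warmup_samples (no lr_warmup_fraction set)

# Iteration-based mode (train_iters = 10_000, lr_warmup_iters = 500,
# global_batch_size = 128, lr_decay_iters left as None so it defaults to train_iters):
lr_decay_steps = 10_000 * 128   # 1_280_000
wd_incr_steps = 10_000 * 128    # 1_280_000
lr_warmup_steps = 500 * 128     # 64_000
```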


def runtime_config_update(cfg: ConfigContainer) -> None:
"""Apply runtime configuration updates prior to initialization.