diff --git a/src/megatron/bridge/data/loaders.py b/src/megatron/bridge/data/loaders.py
index 6c3aeda95..62232baea 100644
--- a/src/megatron/bridge/data/loaders.py
+++ b/src/megatron/bridge/data/loaders.py
@@ -109,8 +109,8 @@ def cyclic_iter(iter: Iterable) -> Iterator:
 def get_train_valid_test_num_samples(cfg: ConfigContainer) -> tuple[int, int, int]:
     """Calculate the number of samples for train, validation, and test sets.
 
-    Determines sample counts based on training iterations, global batch size,
-    and evaluation interval/iterations specified in the config.
+    Determines sample counts based on the training mode (iteration-based or sample-based),
+    the global batch size, and the evaluation interval/iterations specified in the config.
 
     Args:
         cfg: The main configuration container.
@@ -119,8 +119,13 @@ def get_train_valid_test_num_samples(cfg: ConfigContainer) -> tuple[int, int, in
         A tuple (train_samples, valid_samples, test_samples).
     """
-    # Number of train/valid/test samples.
-    train_samples = cfg.train.train_iters * cfg.train.global_batch_size
+    # If train_samples is directly provided, use it
+    if cfg.train.train_samples is not None:
+        train_samples = cfg.train.train_samples
+    else:
+        # Otherwise, fall back to calculating samples from iterations and global batch size
+        train_samples = cfg.train.train_iters * cfg.train.global_batch_size
+
     eval_iters = (cfg.train.train_iters // cfg.train.eval_interval + 1) * cfg.train.eval_iters
     test_iters = cfg.train.eval_iters
diff --git a/src/megatron/bridge/recipes/utils/optimizer_utils.py b/src/megatron/bridge/recipes/utils/optimizer_utils.py
index 986e7726b..d031427c2 100644
--- a/src/megatron/bridge/recipes/utils/optimizer_utils.py
+++ b/src/megatron/bridge/recipes/utils/optimizer_utils.py
@@ -75,3 +75,64 @@ def distributed_fused_adam_with_cosine_annealing(
     )
 
     return optimizer, scheduler
+
+
+def distributed_fused_adam_with_cosine_annealing_samples(
+    precision: str = "bf16-mixed",
+    lr_warmup_samples: Optional[int] = None,
+    lr_decay_samples: Optional[int] = None,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.95,
+    adam_eps: float = 1e-5,
+    weight_decay: float = 0.1,
+    max_lr: float = 1e-4,
+    min_lr: Optional[float] = None,
+    clip_grad: float = 1.0,
+) -> tuple[OptimizerConfig, SchedulerConfig]:
+    """
+    Creates a distributed fused Adam optimizer with a cosine annealing scheduler for sample-based training.
+
+    This is the sample-based equivalent of distributed_fused_adam_with_cosine_annealing().
+
+    Args:
+        precision: Mixed precision mode ("bf16-mixed", "16-mixed", etc.)
+        lr_warmup_samples: Number of samples for learning rate warmup (None = auto from train_samples)
+        lr_decay_samples: Number of samples for learning rate decay (None = auto from train_samples)
+        adam_beta1: Adam optimizer beta1 parameter
+        adam_beta2: Adam optimizer beta2 parameter
+        adam_eps: Adam optimizer epsilon parameter
+        weight_decay: Weight decay coefficient
+        max_lr: Maximum learning rate
+        min_lr: Minimum learning rate (defaults to 0.1 * max_lr)
+        clip_grad: Gradient clipping value
+
+    Returns:
+        A tuple of (OptimizerConfig, SchedulerConfig) configured for sample-based training
+    """
+    min_lr = min_lr if min_lr is not None else (0.1 * max_lr)
+    optimizer = OptimizerConfig(
+        optimizer="adam",
+        lr=max_lr,
+        min_lr=min_lr,
+        weight_decay=weight_decay,
+        bf16=precision == "bf16-mixed",
+        fp16=precision == "16-mixed",
+        adam_beta1=adam_beta1,
+        adam_beta2=adam_beta2,
+        adam_eps=adam_eps,
+        use_distributed_optimizer=True,
+        clip_grad=clip_grad,
+    )
+
+    scheduler = SchedulerConfig(
+        start_weight_decay=0.033,
+        end_weight_decay=0.033,
+        weight_decay_incr_style="constant",
+        lr_decay_style="cosine",
+        lr_warmup_samples=lr_warmup_samples,
+        lr_warmup_init=0.0,
+        lr_decay_samples=lr_decay_samples,
+        override_opt_param_scheduler=True,
+    )
+
+    return optimizer, scheduler
diff --git a/src/megatron/bridge/training/config.py b/src/megatron/bridge/training/config.py
index 519189ae8..90728573b 100644
--- a/src/megatron/bridge/training/config.py
+++ b/src/megatron/bridge/training/config.py
@@ -362,17 +362,26 @@ class SchedulerConfig:
     """Decay style for the annealing phase of WSD"""
 
     lr_decay_iters: Optional[int] = None
-    """number of iterations to decay learning rate over, If None defaults to `--train-iters`"""
+    """number of iterations to decay learning rate over. If None, defaults to `train.train_iters`"""
+
+    lr_decay_samples: Optional[int] = None
+    """number of samples to decay learning rate over. If None, defaults to `train.train_samples`"""
 
     lr_wsd_decay_iters: Optional[int] = None
     """number of iterations for the annealing phase in the wsd schedule"""
 
+    lr_wsd_decay_samples: Optional[int] = None
+    """number of samples for the annealing phase in the wsd schedule"""
+
     lr_warmup_fraction: Optional[float] = None
     """fraction of lr-warmup-(iters/samples) to use for warmup (as a float)"""
 
     lr_warmup_iters: int = 0
     """number of iterations to linearly warmup learning rate over."""
 
+    lr_warmup_samples: int = 0
+    """number of samples to linearly warmup learning rate over."""
+
     lr_warmup_init: float = 0.0
     """Initial value for learning rate warmup. The scheduler starts warmup from this value."""
 
@@ -402,7 +411,7 @@ class SchedulerConfig:
     wd_incr_steps: Optional[int] = field(init=False, default=None)
     wsd_decay_steps: Optional[int] = field(init=False, default=None)
 
-    def finalize(self):
+    def finalize(self) -> None:
         """Post-initialization checks for scheduler config."""
         if self.start_weight_decay is not None:
             assert self.start_weight_decay >= 0.0, "start_weight_decay should be positive."
@@ -411,6 +420,28 @@ def finalize(self):
         if self.override_opt_param_scheduler:
             assert not self.use_checkpoint_opt_param_scheduler, "both override and use-checkpoint are set."
+        # Validate mutual exclusivity between iteration-based and sample-based scheduler fields
+        has_iter_fields = (
+            self.lr_decay_iters is not None or self.lr_warmup_iters != 0 or self.lr_wsd_decay_iters is not None
+        )
+        has_sample_fields = (
+            self.lr_decay_samples is not None or self.lr_warmup_samples != 0 or self.lr_wsd_decay_samples is not None
+        )
+
+        assert not (has_iter_fields and has_sample_fields), (
+            f"Cannot mix iteration-based and sample-based scheduler fields. "
+            f"Found iteration fields: lr_decay_iters={self.lr_decay_iters}, lr_warmup_iters={self.lr_warmup_iters}, lr_wsd_decay_iters={self.lr_wsd_decay_iters}. "
+            f"Found sample fields: lr_decay_samples={self.lr_decay_samples}, lr_warmup_samples={self.lr_warmup_samples}, lr_wsd_decay_samples={self.lr_wsd_decay_samples}. "
+            f"Use either iteration fields OR sample fields, not both."
+        )
+
+        # Validate mutual exclusivity between lr_warmup_fraction and specific warmup fields
+        if self.lr_warmup_fraction is not None:
+            assert self.lr_warmup_iters == 0 and self.lr_warmup_samples == 0, (
+                f"Cannot specify lr_warmup_fraction={self.lr_warmup_fraction} with lr_warmup_iters={self.lr_warmup_iters} or lr_warmup_samples={self.lr_warmup_samples}. "
+                f"Use either lr_warmup_fraction OR lr_warmup_iters OR lr_warmup_samples."
+            )
+
 
 @dataclass(kw_only=True)
 class TrainingConfig:
@@ -457,9 +488,13 @@ class TrainingConfig:
     train_iters: Optional[int] = None
     """Total number of iterations to train over all training runs.
 
-    Note that either train-iters or train-samples should be provided.
+    Note that either train_iters or train_samples should be provided.
     """
 
+    train_samples: Optional[int] = None
+    """Total number of samples to train over all training runs.
+    Note that either train_iters or train_samples should be provided."""
+
     exit_interval: Optional[int] = None
     """Exit the program after the iteration is divisible by this value."""
 
@@ -502,6 +537,21 @@ class TrainingConfig:
     skip_train: bool = False
     """If set, bypass the training loop, optionally do evaluation for validation/test, and exit."""
 
+    def finalize(self) -> None:
+        """Validate training mode specification and calculate train_iters from train_samples if needed."""
+        has_train_iters = self.train_iters is not None
+        has_train_samples = self.train_samples is not None
+
+        assert has_train_iters or has_train_samples, "Either train_iters or train_samples must be provided"
+        assert not (has_train_iters and has_train_samples), "Cannot specify both train_iters and train_samples"
+        if has_train_samples:
+            assert self.train_samples > 0, "train_samples must be positive"
+            assert self.rampup_batch_size is None, "Batch size rampup not supported with sample-based training yet"
+
+            # Calculate train_iters from train_samples (rampup_batch_size already validated as None)
+            self.train_iters = self.train_samples // self.global_batch_size
+            print_rank_0(f"Setting training iterations to {self.train_iters} based on {self.train_samples} samples")
+
 
 @dataclass(kw_only=True)
 class CheckpointConfig:
@@ -1030,6 +1080,8 @@ def validate(self) -> None:
         self.optimizer.finalize()
         if hasattr(self.model, "finalize"):
             self.model.finalize()
+
+        self.train.finalize()
         self.scheduler.finalize()
         self.checkpoint.finalize()
         if self.profiling is not None:
@@ -1102,18 +1154,11 @@ def validate(self) -> None:
             # optimizer state in the CPU memory of DP rank 0.
assert self.checkpoint.ckpt_format == "torch_dist" - # Scheduler - if self.scheduler.lr_decay_iters is None: - self.scheduler.lr_decay_iters = self.train.train_iters - self.scheduler.lr_decay_steps = self.scheduler.lr_decay_iters * self.train.global_batch_size - self.scheduler.wd_incr_steps = self.train.train_iters * self.train.global_batch_size - self.scheduler.wsd_decay_steps = None - if self.scheduler.lr_wsd_decay_iters is not None: - self.scheduler.wsd_decay_steps = self.scheduler.lr_wsd_decay_iters * self.train.global_batch_size - if self.scheduler.lr_warmup_fraction is not None: - self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps - else: - self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_iters * self.train.global_batch_size + # Cross-validation between training and scheduler configs + self._validate_training_scheduler_compatibility() + + # Calculate scheduler steps for both iteration-based and sample-based training + self._calculate_scheduler_steps() if self.model.context_parallel_size > 1: assert self.model.seq_length % (self.model.context_parallel_size * 2) == 0, ( @@ -1172,6 +1217,66 @@ def validate(self) -> None: self._sync_and_validate_external_cuda_graph() + def _validate_training_scheduler_compatibility(self) -> None: + """Cross-validation between training and scheduler configs.""" + has_train_samples = self.train.train_samples is not None + + if has_train_samples: + # Sample-based training validation + assert self.scheduler.lr_decay_iters is None, ( + "Use lr_decay_samples for sample-based training, not lr_decay_iters" + ) + assert self.scheduler.lr_warmup_iters == 0, ( + "Use lr_warmup_samples for sample-based training, not lr_warmup_iters" + ) + assert not (self.scheduler.lr_warmup_fraction is not None and self.scheduler.lr_warmup_samples != 0), ( + "Can only specify one of lr_warmup_fraction or lr_warmup_samples" + ) + else: + # Iteration-based training validation + assert self.scheduler.lr_decay_samples is None, ( + "Use lr_decay_iters for iteration-based training, not lr_decay_samples" + ) + assert self.scheduler.lr_warmup_samples == 0, ( + "Use lr_warmup_iters for iteration-based training, not lr_warmup_samples" + ) + assert not (self.scheduler.lr_warmup_fraction is not None and self.scheduler.lr_warmup_iters != 0), ( + "Can only specify one of lr_warmup_fraction or lr_warmup_iters" + ) + + def _calculate_scheduler_steps(self) -> None: + """Calculate scheduler steps for both iteration-based and sample-based training.""" + is_sample_based = self.train.train_samples is not None + + if is_sample_based: + if self.scheduler.lr_decay_samples is None: + self.scheduler.lr_decay_samples = self.train.train_samples + self.scheduler.lr_decay_steps = self.scheduler.lr_decay_samples + self.scheduler.wd_incr_steps = self.train.train_samples + + if self.scheduler.lr_wsd_decay_samples is not None: + self.scheduler.wsd_decay_steps = self.scheduler.lr_wsd_decay_samples + + # Warmup calculation for sample-based training + if self.scheduler.lr_warmup_fraction is not None: + self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps + else: + self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_samples + else: + # Iteration-based training + if self.scheduler.lr_decay_iters is None: + self.scheduler.lr_decay_iters = self.train.train_iters + self.scheduler.lr_decay_steps = self.scheduler.lr_decay_iters * self.train.global_batch_size + self.scheduler.wd_incr_steps = self.train.train_iters * 
self.train.global_batch_size + + if self.scheduler.lr_wsd_decay_iters is not None: + self.scheduler.wsd_decay_steps = self.scheduler.lr_wsd_decay_iters * self.train.global_batch_size + + if self.scheduler.lr_warmup_fraction is not None: + self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_fraction * self.scheduler.lr_decay_steps + else: + self.scheduler.lr_warmup_steps = self.scheduler.lr_warmup_iters * self.train.global_batch_size + def runtime_config_update(cfg: ConfigContainer) -> None: """Apply runtime configuration updates prior to initialization. diff --git a/tests/functional_tests/training/test_sample_based_training.py b/tests/functional_tests/training/test_sample_based_training.py new file mode 100644 index 000000000..a202b69f0 --- /dev/null +++ b/tests/functional_tests/training/test_sample_based_training.py @@ -0,0 +1,114 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Functional tests for sample-based training that run on 2 GPUs with torchrun.""" + +import logging + +import torch + +from megatron.bridge.models.llama import Llama32ModelProvider1B +from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing_samples +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + DistributedInitConfig, + LoggerConfig, + MockGPTDatasetConfig, + RerunStateMachineConfig, + RNGConfig, + TrainingConfig, +) +from megatron.bridge.training.gpt_step import forward_step +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.tokenizers.config import TokenizerConfig +from megatron.bridge.utils.common_utils import get_rank_safe + + +_logger: logging.Logger = logging.getLogger(__name__) + + +class TestSampleBasedTrainingFunctional: + """Functional tests for sample-based training on 2 GPUs.""" + + def test_sample_based_training_mini_run(self): + """Mini end-to-end test that runs a few training steps with sample-based training.""" + # Use the new sample-based optimizer utility + optimizer_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing_samples( + precision="bf16-mixed", + lr_warmup_samples=8, # Very small for quick test + lr_decay_samples=24, + max_lr=1e-3, + min_lr=1e-4, + ) + + cfg = ConfigContainer( + train=TrainingConfig( + micro_batch_size=1, + global_batch_size=4, # 2 GPUs * 2 data_parallel_size + train_samples=32, # Sample-based training (8 iterations) + eval_iters=2, + eval_interval=4, + skip_train=False, + ), + model=Llama32ModelProvider1B( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + sequence_parallel=False, + attention_softmax_in_fp32=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + seq_length=256, + make_vocab_size_divisible_by=128, + vocab_size=None, + ), + optimizer=optimizer_cfg, + scheduler=scheduler_cfg, + dataset=MockGPTDatasetConfig( + random_seed=1234, + sequence_length=256, + reset_position_ids=False, + 
reset_attention_mask=False, + eod_mask_loss=False, + num_dataset_builder_threads=1, + data_sharding=True, + dataloader_type="single", + num_workers=1, + ), + logger=LoggerConfig( + log_interval=2, + tensorboard_dir=None, # Disable tensorboard for testing + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", + vocab_size=10000, + ), + checkpoint=CheckpointConfig(), + dist=DistributedInitConfig(), + ddp=DistributedDataParallelConfig(use_distributed_optimizer=True), + rng=RNGConfig(seed=42), + rerun_state_machine=RerunStateMachineConfig(), + ) + + assert cfg.train.train_samples == 32 + assert cfg.scheduler.lr_decay_samples == 24 + assert cfg.scheduler.lr_warmup_samples == 8 + + pretrain(config=cfg, forward_step_func=forward_step) + + if get_rank_safe() == 0: + _logger.debug(f"Trained for {cfg.train.train_samples} samples over {cfg.train.train_iters} iterations") + _logger.debug(f"Used sample-based scheduler with {cfg.scheduler.lr_warmup_samples} warmup samples") diff --git a/tests/unit_tests/data/test_loaders.py b/tests/unit_tests/data/test_loaders.py index aae8c3741..852eaca8e 100644 --- a/tests/unit_tests/data/test_loaders.py +++ b/tests/unit_tests/data/test_loaders.py @@ -20,10 +20,71 @@ from megatron.bridge.data.loaders import ( build_train_valid_test_data_loaders, get_blend_and_blend_per_split, + get_train_valid_test_num_samples, ) from megatron.bridge.data.utils import get_dataset_provider -from megatron.bridge.recipes.llama.llama3 import llama3_8b_pretrain_config as pretrain_config +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + DistributedInitConfig, + LoggerConfig, + MockGPTDatasetConfig, + OptimizerConfig, + RerunStateMachineConfig, + RNGConfig, + SchedulerConfig, + TrainingConfig, +) from megatron.bridge.training.state import TrainState +from megatron.bridge.training.tokenizers.config import TokenizerConfig + + +def create_simple_test_config(): + """Create a simple test configuration without HuggingFace dependencies.""" + return ConfigContainer( + train=TrainingConfig( + micro_batch_size=1, + global_batch_size=32, + train_iters=1000, + eval_iters=10, + eval_interval=100, + ), + model=GPTModelProvider( + num_layers=1, + hidden_size=128, + num_attention_heads=4, + seq_length=512, + apply_rope_fusion=False, + vocab_size=1000, + make_vocab_size_divisible_by=1, + ), + optimizer=OptimizerConfig( + lr=0.001, + use_distributed_optimizer=False, + ), + scheduler=SchedulerConfig(), + dataset=MockGPTDatasetConfig( + random_seed=1234, + sequence_length=512, + reset_position_ids=False, + reset_attention_mask=False, + eod_mask_loss=False, + dataloader_type="single", + num_workers=1, + ), + logger=LoggerConfig(), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", + vocab_size=1000, + ), + checkpoint=CheckpointConfig(), + dist=DistributedInitConfig(), + ddp=DistributedDataParallelConfig(), + rng=RNGConfig(), + rerun_state_machine=RerunStateMachineConfig(), + ) class TestDataLoaders: @@ -87,18 +148,8 @@ def test_build_train_valid_test_data_loaders( ): mock_get_data_parallel_rank.return_value = 0 mock_get_data_parallel_world_size.return_value = 1 - # Avoid HF download by mocking AutoBridge - with mock.patch("megatron.bridge.recipes.llama.llama3.AutoBridge.from_hf_pretrained") as mock_from: - - class _DummyBridge: - def to_megatron_provider(self, load_weights=False): - from megatron.bridge.models.llama.llama_provider import 
Llama3ModelProvider - - return Llama3ModelProvider() - mock_from.return_value = _DummyBridge() - cfg = pretrain_config() - cfg.train.train_iters = 1000 + cfg = create_simple_test_config() cfg.dataset.finalize() dataset_provider = get_dataset_provider(cfg.dataset) train_dataloader, valid_dataloader, test_dataloader = build_train_valid_test_data_loaders( @@ -121,18 +172,8 @@ def test_build_train_valid_test_data_loaders_eval_iters_0( ): mock_get_data_parallel_rank.return_value = 0 mock_get_data_parallel_world_size.return_value = 1 - # Avoid HF download by mocking AutoBridge - with mock.patch("megatron.bridge.recipes.llama.llama3.AutoBridge.from_hf_pretrained") as mock_from: - class _DummyBridge: - def to_megatron_provider(self, load_weights=False): - from megatron.bridge.models.llama.llama_provider import Llama3ModelProvider - - return Llama3ModelProvider() - - mock_from.return_value = _DummyBridge() - cfg = pretrain_config() - cfg.train.train_iters = 1000 + cfg = create_simple_test_config() cfg.train.eval_iters = 0 cfg.dataset.finalize() dataset_provider = get_dataset_provider(cfg.dataset) @@ -147,3 +188,84 @@ def to_megatron_provider(self, load_weights=False): assert train_dataloader is not None assert valid_dataloader is None assert test_dataloader is None + + +class TestSampleBasedDataLoaders: + """Tests for sample-based training data loader functionality.""" + + def test_get_train_valid_test_num_samples_iteration_based(self): + """Test sample calculation for iteration-based training.""" + cfg = create_simple_test_config() + + train_samples, valid_samples, test_samples = get_train_valid_test_num_samples(cfg) + + expected_train_samples = cfg.train.train_iters * cfg.train.global_batch_size + expected_eval_iters = (cfg.train.train_iters // cfg.train.eval_interval + 1) * cfg.train.eval_iters + expected_valid_samples = expected_eval_iters * cfg.train.global_batch_size + expected_test_samples = cfg.train.eval_iters * cfg.train.global_batch_size + + assert train_samples == expected_train_samples + assert valid_samples == expected_valid_samples + assert test_samples == expected_test_samples + + def test_get_train_valid_test_num_samples_sample_based(self): + """Test sample calculation for sample-based training.""" + cfg = create_simple_test_config() + cfg.train.train_samples = 50000 # Use sample-based training + cfg.train.train_iters = None + + # Need to calculate train_iters first for eval sample calculation + cfg.train.train_iters = cfg.train.train_samples // cfg.train.global_batch_size + + train_samples, valid_samples, test_samples = get_train_valid_test_num_samples(cfg) + + expected_train_samples = cfg.train.train_samples # Direct sample count + expected_eval_iters = (cfg.train.train_iters // cfg.train.eval_interval + 1) * cfg.train.eval_iters + expected_valid_samples = expected_eval_iters * cfg.train.global_batch_size + expected_test_samples = cfg.train.eval_iters * cfg.train.global_batch_size + + assert train_samples == expected_train_samples + assert valid_samples == expected_valid_samples + assert test_samples == expected_test_samples + + @mock.patch("torch.distributed.broadcast") + @mock.patch("megatron.core.mpu.get_data_parallel_rank") + @mock.patch("megatron.core.mpu.get_data_parallel_world_size") + def test_build_data_loaders_sample_based( + self, mock_get_data_parallel_world_size, mock_get_data_parallel_rank, mock_broadcast + ): + """Test data loader building with sample-based training.""" + mock_get_data_parallel_rank.return_value = 0 + mock_get_data_parallel_world_size.return_value 
= 1 + + cfg = create_simple_test_config() + cfg.train.train_samples = 10000 # Sample-based training + cfg.train.train_iters = None + + # Set sample-based scheduler config + cfg.scheduler.lr_decay_samples = 8000 + cfg.scheduler.lr_decay_iters = None + cfg.scheduler.lr_warmup_samples = 1000 + cfg.scheduler.lr_warmup_iters = 0 + + # Need to validate config to calculate train_iters from train_samples + with mock.patch("megatron.bridge.utils.common_utils.get_world_size_safe", return_value=1): + cfg.validate() + + # Normal training state (no backward compatibility needed) + train_state = TrainState() + train_state.step = 0 + train_state.consumed_train_samples = 0 + train_state.consumed_valid_samples = 0 + + dataset_provider = get_dataset_provider(cfg.dataset) + + # Should build data loaders successfully + train_dataloader, valid_dataloader, test_dataloader = build_train_valid_test_data_loaders( + cfg=cfg, train_state=train_state, build_train_valid_test_datasets_provider=dataset_provider + ) + + # Verify data loaders were created + assert train_dataloader is not None + assert valid_dataloader is not None + assert test_dataloader is not None diff --git a/tests/unit_tests/recipes/utils/test_optimizer_utils.py b/tests/unit_tests/recipes/utils/test_optimizer_utils.py index fa12e266c..05c3e1117 100644 --- a/tests/unit_tests/recipes/utils/test_optimizer_utils.py +++ b/tests/unit_tests/recipes/utils/test_optimizer_utils.py @@ -16,7 +16,10 @@ from megatron.core.optimizer import OptimizerConfig -from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing +from megatron.bridge.recipes.utils.optimizer_utils import ( + distributed_fused_adam_with_cosine_annealing, + distributed_fused_adam_with_cosine_annealing_samples, +) from megatron.bridge.training.config import SchedulerConfig @@ -51,3 +54,52 @@ def test_scheduler_config(self): assert isinstance(scheduler_cfg, SchedulerConfig) assert scheduler_cfg.lr_warmup_iters == 1999 assert scheduler_cfg.lr_decay_iters == 12345 + + def test_sample_based_optimizer_config(self): + """Test sample-based optimizer config.""" + + optim_cfg, _ = distributed_fused_adam_with_cosine_annealing_samples( + precision="bf16-mixed", + adam_beta2=0.95, + adam_eps=1e-5, + weight_decay=0.1, + max_lr=1e-4, + min_lr=1e-5, + ) + + assert isinstance(optim_cfg, OptimizerConfig) + assert optim_cfg.lr == 1e-4 + assert optim_cfg.min_lr == 1e-5 + assert optim_cfg.weight_decay == 0.1 + assert optim_cfg.adam_beta2 == 0.95 + assert optim_cfg.bf16 is True + assert optim_cfg.use_distributed_optimizer is True + + def test_sample_based_scheduler_config(self): + """Test sample-based scheduler config.""" + + _, scheduler_cfg = distributed_fused_adam_with_cosine_annealing_samples( + lr_warmup_samples=1000, + lr_decay_samples=8000, + ) + + assert isinstance(scheduler_cfg, SchedulerConfig) + assert scheduler_cfg.lr_warmup_samples == 1000 + assert scheduler_cfg.lr_decay_samples == 8000 + assert scheduler_cfg.lr_warmup_iters == 0 # Should be 0 for sample-based + assert scheduler_cfg.lr_decay_iters is None # Should be None for sample-based + assert scheduler_cfg.lr_decay_style == "cosine" + + def test_sample_based_scheduler_config_with_none_defaults(self): + """Test sample-based scheduler config with None defaults (auto from train_samples).""" + + _, scheduler_cfg = distributed_fused_adam_with_cosine_annealing_samples( + lr_warmup_samples=None, # Should default to None for auto calculation + lr_decay_samples=None, # Should default to None for auto calculation + ) + + 
assert isinstance(scheduler_cfg, SchedulerConfig) + assert scheduler_cfg.lr_warmup_samples is None # Will auto-calculate from train_samples + assert scheduler_cfg.lr_decay_samples is None # Will auto-calculate from train_samples + assert scheduler_cfg.lr_warmup_iters == 0 + assert scheduler_cfg.lr_decay_iters is None diff --git a/tests/unit_tests/training/test_config.py b/tests/unit_tests/training/test_config.py index 3af6e44cd..32bffcebf 100644 --- a/tests/unit_tests/training/test_config.py +++ b/tests/unit_tests/training/test_config.py @@ -575,7 +575,7 @@ def test_scheduler_wsd_decay_steps_none(self, monkeypatch): def test_scheduler_lr_warmup_steps_from_fraction(self, monkeypatch): """Test `lr_warmup_steps` calculation from `lr_warmup_fraction`.""" gpt_model_cfg = create_test_gpt_config() - train_cfg = create_test_training_config(train_iters=1000) + train_cfg = create_test_training_config(train_iters=1000, global_batch_size=32) lr_warmup_fraction = 0.1 sched_cfg = create_test_scheduler_config( lr_warmup_fraction=lr_warmup_fraction, lr_warmup_iters=0 @@ -586,8 +586,7 @@ def test_scheduler_lr_warmup_steps_from_fraction(self, monkeypatch): ) try: container.validate() - # lr_decay_iters in scheduler_config defaults to train_config.train_iters - expected_lr_warmup_steps = lr_warmup_fraction * train_cfg.train_iters * train_cfg.global_batch_size + expected_lr_warmup_steps = lr_warmup_fraction * (train_cfg.train_iters * train_cfg.global_batch_size) assert container.scheduler.lr_warmup_steps == expected_lr_warmup_steps finally: restore_get_world_size_safe(og_ws, cfg_mod) @@ -609,12 +608,12 @@ def test_scheduler_lr_warmup_steps_from_iters(self, monkeypatch): finally: restore_get_world_size_safe(og_ws, cfg_mod) - def test_scheduler_lr_warmup_steps_fraction_precedence(self, monkeypatch): - """Test `lr_warmup_fraction` takes precedence over `lr_warmup_iters`.""" + def test_scheduler_lr_warmup_fraction_and_iters_mutual_exclusivity(self, monkeypatch): + """Test that lr_warmup_fraction and lr_warmup_iters cannot both be specified.""" gpt_model_cfg = create_test_gpt_config() train_cfg = create_test_training_config(train_iters=1000, global_batch_size=10) lr_warmup_fraction = 0.05 - lr_warmup_iters = 50 # This should be ignored + lr_warmup_iters = 50 # This should not be allowed with lr_warmup_fraction sched_cfg = create_test_scheduler_config( lr_warmup_fraction=lr_warmup_fraction, lr_warmup_iters=lr_warmup_iters ) @@ -622,9 +621,9 @@ def test_scheduler_lr_warmup_steps_fraction_precedence(self, monkeypatch): world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg, scheduler_config=sched_cfg ) try: - container.validate() - expected_lr_warmup_steps = lr_warmup_fraction * train_cfg.train_iters * train_cfg.global_batch_size - assert container.scheduler.lr_warmup_steps == expected_lr_warmup_steps + # This should fail validation due to mutual exclusivity at scheduler finalize level + with pytest.raises(AssertionError, match="Cannot specify lr_warmup_fraction=0.05 with lr_warmup_iters=50"): + container.validate() finally: restore_get_world_size_safe(og_ws, cfg_mod) @@ -1777,3 +1776,277 @@ def test_integration_with_config_container_validation(self, mock_warn_rank_0): finally: restore_get_world_size_safe(og_ws, cfg_mod) + + +class TestSampleBasedTraining: + """Tests for sample-based training configuration and validation.""" + + def test_sample_based_training_config_creation(self): + """Test creating a valid sample-based training configuration.""" + train_cfg = 
create_test_training_config(train_samples=10000, train_iters=None, global_batch_size=32) + sched_cfg = create_test_scheduler_config( + lr_decay_samples=8000, lr_warmup_samples=1000, lr_decay_iters=None, lr_warmup_iters=0 + ) + + gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg, scheduler_config=sched_cfg + ) + + try: + container.validate() + # Verify train_iters was calculated from train_samples + expected_train_iters = train_cfg.train_samples // train_cfg.global_batch_size + assert container.train.train_iters == expected_train_iters + + # Verify scheduler steps for sample-based training + assert container.scheduler.lr_decay_steps == sched_cfg.lr_decay_samples + assert container.scheduler.wd_incr_steps == train_cfg.train_samples + assert container.scheduler.lr_warmup_steps == sched_cfg.lr_warmup_samples + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_sample_based_training_with_warmup_fraction(self): + """Test sample-based training with lr_warmup_fraction.""" + train_cfg = create_test_training_config(train_samples=10000, train_iters=None, global_batch_size=32) + sched_cfg = create_test_scheduler_config( + lr_decay_samples=8000, lr_warmup_fraction=0.1, lr_warmup_samples=0, lr_decay_iters=None, lr_warmup_iters=0 + ) + + gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg, scheduler_config=sched_cfg + ) + + try: + container.validate() + # Verify warmup steps calculated from fraction of decay steps (sample count) + expected_lr_warmup_steps = sched_cfg.lr_warmup_fraction * sched_cfg.lr_decay_samples + assert container.scheduler.lr_warmup_steps == expected_lr_warmup_steps + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_training_mode_mutual_exclusivity(self): + """Test that train_iters and train_samples cannot both be specified.""" + train_cfg = create_test_training_config(train_iters=1000, train_samples=10000) + + gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg + ) + + try: + with pytest.raises(AssertionError, match="Cannot specify both train_iters and train_samples"): + container.validate() + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_training_mode_required(self): + """Test that either train_iters or train_samples must be specified.""" + train_cfg = create_test_training_config(train_iters=None) + # train_samples defaults to None + + gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg + ) + + try: + with pytest.raises(AssertionError, match="Either train_iters or train_samples must be provided"): + container.validate() + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_sample_based_scheduler_field_validation(self): + """Test that sample-based training rejects iteration-based scheduler fields.""" + train_cfg = create_test_training_config(train_samples=10000, train_iters=None) + sched_cfg = create_test_scheduler_config(lr_decay_iters=500) # Should not be used with sample-based + + gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, 
model_config=gpt_model_cfg, train_config=train_cfg, scheduler_config=sched_cfg + ) + + try: + with pytest.raises( + AssertionError, match="Use lr_decay_samples for sample-based training, not lr_decay_iters" + ): + container.validate() + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_iteration_based_scheduler_field_validation(self): + """Test that iteration-based training rejects sample-based scheduler fields.""" + train_cfg = create_test_training_config(train_iters=1000) + sched_cfg = create_test_scheduler_config(lr_decay_samples=8000) # Should not be used with iteration-based + + gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg, scheduler_config=sched_cfg + ) + + try: + with pytest.raises( + AssertionError, match="Use lr_decay_iters for iteration-based training, not lr_decay_samples" + ): + container.validate() + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_sample_based_warmup_mutual_exclusivity(self): + """Test mutual exclusivity between lr_warmup_fraction and lr_warmup_samples.""" + train_cfg = create_test_training_config(train_samples=10000, train_iters=None) + sched_cfg = create_test_scheduler_config( + lr_warmup_fraction=0.1, + lr_warmup_samples=1000, # Both specified - should fail + ) + + gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg, scheduler_config=sched_cfg + ) + + try: + # This should now fail at scheduler finalize level with detailed field values + with pytest.raises( + AssertionError, match="Cannot specify lr_warmup_fraction=0.1 with.*lr_warmup_samples=1000" + ): + container.validate() + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_sample_based_with_rampup_batch_size_fails(self): + """Test that sample-based training with rampup_batch_size raises ValueError.""" + train_cfg = create_test_training_config(train_samples=10000, train_iters=None, rampup_batch_size=[16, 8, 5000]) + + gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg + ) + + try: + with pytest.raises(AssertionError, match="Batch size rampup not supported with sample-based training yet"): + container.validate() + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_sample_based_lr_decay_samples_defaults(self): + """Test that lr_decay_samples defaults to train_samples.""" + train_cfg = create_test_training_config(train_samples=10000, train_iters=None) + sched_cfg = create_test_scheduler_config(lr_decay_samples=None) # Should default to train_samples + + gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg, scheduler_config=sched_cfg + ) + + try: + container.validate() + assert container.scheduler.lr_decay_samples == train_cfg.train_samples + assert container.scheduler.lr_decay_steps == train_cfg.train_samples + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_sample_based_wsd_decay_steps(self): + """Test WSD decay steps calculation for sample-based training.""" + train_cfg = create_test_training_config(train_samples=10000, train_iters=None) + sched_cfg = create_test_scheduler_config(lr_wsd_decay_samples=5000) + + 
gpt_model_cfg = create_test_gpt_config() + container, og_ws, cfg_mod = create_test_config_container( + world_size_override=1, model_config=gpt_model_cfg, train_config=train_cfg, scheduler_config=sched_cfg + ) + + try: + container.validate() + assert container.scheduler.wsd_decay_steps == sched_cfg.lr_wsd_decay_samples + finally: + restore_get_world_size_safe(og_ws, cfg_mod) + + def test_sample_based_vs_iteration_based_config_equivalence(self): + """Test that equivalent sample-based and iteration-based configs produce same scheduler steps.""" + from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing_samples + + # Sample-based config + sample_train_cfg = create_test_training_config(train_samples=32, train_iters=None, global_batch_size=4) + sample_optimizer_cfg, sample_scheduler_cfg = distributed_fused_adam_with_cosine_annealing_samples( + lr_warmup_samples=8, + lr_decay_samples=24, + max_lr=1e-3, + ) + + sample_model_cfg = create_test_gpt_config() + sample_container, og_ws1, cfg_mod1 = create_test_config_container( + world_size_override=1, + model_config=sample_model_cfg, + train_config=sample_train_cfg, + scheduler_config=sample_scheduler_cfg, + ) + + # Equivalent iteration-based config + iter_train_cfg = create_test_training_config(train_iters=8, global_batch_size=4) # 32 samples / 4 batch_size + iter_scheduler_cfg = create_test_scheduler_config( + lr_warmup_iters=2, # 8 samples / 4 batch_size + lr_decay_iters=6, # 24 samples / 4 batch_size + ) + + iter_model_cfg = create_test_gpt_config() + iter_container, og_ws2, cfg_mod2 = create_test_config_container( + world_size_override=1, + model_config=iter_model_cfg, + train_config=iter_train_cfg, + scheduler_config=iter_scheduler_cfg, + ) + + try: + # Validate both configurations + sample_container.validate() + iter_container.validate() + + # Both should have the same final train_iters + assert sample_container.train.train_iters == iter_container.train.train_iters == 8 + + # Both should have equivalent scheduler steps (different calculation, same result) + assert sample_container.scheduler.lr_decay_steps == 24 # Direct sample count + assert iter_container.scheduler.lr_decay_steps == 6 * 4 # lr_decay_iters * global_batch_size = 24 + assert sample_container.scheduler.lr_decay_steps == iter_container.scheduler.lr_decay_steps + + # Both should have equivalent warmup steps + assert sample_container.scheduler.lr_warmup_steps == 8 # Direct sample count + assert iter_container.scheduler.lr_warmup_steps == 2 * 4 # lr_warmup_iters * global_batch_size = 8 + assert sample_container.scheduler.lr_warmup_steps == iter_container.scheduler.lr_warmup_steps + + finally: + restore_get_world_size_safe(og_ws1, cfg_mod1) + restore_get_world_size_safe(og_ws2, cfg_mod2) + + def test_scheduler_field_mixing_validation(self): + """Test that mixing iteration-based and sample-based scheduler fields fails in scheduler finalize.""" + # This should fail at the SchedulerConfig.finalize() level, before cross-validation + sched_cfg = create_test_scheduler_config( + lr_decay_iters=100, # iteration-based + lr_decay_samples=1000, # sample-based - mixing not allowed + ) + + with pytest.raises(AssertionError, match="Cannot mix iteration-based and sample-based scheduler fields"): + sched_cfg.finalize() + + def test_scheduler_warmup_fraction_with_iters_validation(self): + """Test that lr_warmup_fraction with lr_warmup_iters fails in scheduler finalize.""" + sched_cfg = create_test_scheduler_config( + lr_warmup_fraction=0.1, + 
lr_warmup_iters=100, # Should not be mixed with lr_warmup_fraction + ) + + with pytest.raises(AssertionError, match="Cannot specify lr_warmup_fraction=0.1 with lr_warmup_iters=100"): + sched_cfg.finalize() + + def test_scheduler_warmup_fraction_with_samples_validation(self): + """Test that lr_warmup_fraction with lr_warmup_samples fails in scheduler finalize.""" + sched_cfg = create_test_scheduler_config( + lr_warmup_fraction=0.1, + lr_warmup_samples=1000, # Should not be mixed with lr_warmup_fraction + ) + + with pytest.raises(AssertionError, match="Cannot specify lr_warmup_fraction=0.1 with.*lr_warmup_samples=1000"): + sched_cfg.finalize()
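
Note for reviewers (not part of the patch): the sketch below illustrates how the pieces added in this diff fit together. It only uses names introduced or exercised above (train_samples, lr_warmup_samples, lr_decay_samples, distributed_fused_adam_with_cosine_annealing_samples); calling TrainingConfig.finalize() standalone is assumed to work the same way the unit tests call SchedulerConfig.finalize() directly, rather than being a documented entry point.

# Illustrative sketch (not part of this diff): wiring up sample-based training
# with the fields and helper added in this change.
from megatron.bridge.recipes.utils.optimizer_utils import (
    distributed_fused_adam_with_cosine_annealing_samples,
)
from megatron.bridge.training.config import TrainingConfig

# Optimizer/scheduler pair expressed in sample counts instead of iterations.
optimizer_cfg, scheduler_cfg = distributed_fused_adam_with_cosine_annealing_samples(
    precision="bf16-mixed",
    lr_warmup_samples=1_000,  # linear warmup over the first 1,000 samples
    lr_decay_samples=8_000,   # cosine decay over 8,000 samples
    max_lr=1e-4,
)

# Training length expressed in samples; TrainingConfig.finalize() derives
# train_iters as train_samples // global_batch_size.
train_cfg = TrainingConfig(
    micro_batch_size=1,
    global_batch_size=32,
    train_samples=10_000,
    eval_interval=100,
    eval_iters=10,
)
train_cfg.finalize()  # assumed callable standalone; ConfigContainer.validate() also calls it
assert train_cfg.train_iters == 10_000 // 32  # 312 iterations

# During ConfigContainer.validate(), scheduler steps stay in sample units:
# lr_decay_steps == 8_000 and lr_warmup_steps == 1_000, which matches the
# iteration-based path where steps = iters * global_batch_size.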