diff --git a/requirements.txt b/requirements.txt index 19c81d73..15445c5a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ wheel>=0.43 pyyaml py-cpuinfo torch>=2.6.0 -transformers>=4.45.2 +transformers>=4.55.0 datasets>=2.15.0 numba diff --git a/src/instructlab/training/accelerator.py b/src/instructlab/training/accelerator.py index b03c4a45..4baa7c0e 100644 --- a/src/instructlab/training/accelerator.py +++ b/src/instructlab/training/accelerator.py @@ -1,9 +1,16 @@ # Standard from copy import deepcopy -from typing import Callable, Optional +from functools import partial +from typing import Optional +import logging # Third Party from accelerate import Accelerator as TransformersAccel +from accelerate.utils import DeepSpeedPlugin, FullyShardedDataParallelPlugin +from peft.utils.other import fsdp_auto_wrap_policy +from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data import DataLoader from transformers import get_scheduler import torch @@ -13,10 +20,13 @@ DeepSpeedOptions, DistributedBackend, ) +from instructlab.training.utils import get_module_class_from_name # Local from .model import Model +logger = logging.getLogger(__name__) + class Accelerator: def __init__( @@ -32,6 +42,7 @@ def __init__( deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False, deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None, fsdp_cpu_offload_params: Optional[bool] = False, + fsdp_use_orig_params: Optional[bool] = False, ): self.samples_per_gpu = samples_per_gpu self.save_samples = save_samples @@ -48,7 +59,8 @@ def __init__( deepspeed_cpu_offload_optimizer_ratio ) self.fsdp_cpu_offload_params = fsdp_cpu_offload_params - + self.fsdp_use_orig_params = fsdp_use_orig_params + self.lr_scheduler = None if self.distributed_framework == DistributedBackend.DEEPSPEED: # Standard accel_args = { @@ -84,7 +96,6 @@ def prepare_with_optimizer( num_epochs: int, num_warmup_steps: int, ): - self.lr_scheduler: Callable self.setup_lr_scheduler( optimizer=optimizer, lr_scheduler=lr_scheduler, @@ -120,19 +131,6 @@ def __getattr__(self, name): return getattr(self.accelerator, name) def get_fsdp_config(self): - # Standard - from functools import partial - - # Third Party - from accelerate.utils import FullyShardedDataParallelPlugin - from peft.utils.other import fsdp_auto_wrap_policy - from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload - from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy - - # First Party - from instructlab.training.utils import get_module_class_from_name - is_lora = self.model.lora_config is not None block_name = self.model._no_split_modules[0] @@ -158,20 +156,16 @@ def get_fsdp_config(self): backward_prefetch=prefetch_policy, sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy], cpu_offload=CPUOffload(self.fsdp_cpu_offload_params), + use_orig_params=self.fsdp_use_orig_params, + # TODO(osilkin): expose switch for fp32 reduction ) - # `use_orig_params` must be disabled when using LoRA and FSDP together - # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts - if self.model.lora_config is not None: - fsdp_plugin.use_orig_params = False - return fsdp_plugin def get_ds_plugin( self, world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions ): # Third Party - from accelerate.utils import DeepSpeedPlugin ds_config = { "train_batch_size": samples_per_gpu * world_size * grad_accum, @@ -248,3 +242,12 @@ def setup_fsdp( fsdp_cpu_offload_params=fsdp_cpu_offload_params, save_samples=save_samples, ) + + def take_optimizer_step(self): + """ + Take an optimizer step and update the learning rate scheduler. + """ + self.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py new file mode 100644 index 00000000..f0e10a89 --- /dev/null +++ b/src/instructlab/training/batch_loss_manager.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Batch loss management for distributed training. + +This module provides utilities for managing loss computation, accumulation, +and reduction across distributed training environments. +""" + +# Standard +from dataclasses import dataclass +import logging + +# Third Party +import torch +import torch.distributed + +# First Party +from instructlab.training.accelerator import Accelerator +from instructlab.training.model import Model +from instructlab.training.type_definitions import CollatedItem, ModelInputs + +logger = logging.getLogger("instructlab.training") + + +@dataclass +class BatchMetrics: + """Metrics collected during batch processing.""" + + total_samples: int + total_length: int + num_loss_counted_tokens: int + accumulated_loss: torch.Tensor + accumulated_aux_loss: torch.Tensor | None + grad_accum_steps: int + num_minibatches: int + + +class BatchLossManager: + """ + Manages loss computation and metrics collection for batches in distributed training. + + This class handles: + - Processing minibatches within a batch + - Accumulating losses across minibatches + - Reducing metrics across distributed ranks + - Computing average losses for logging + """ + + def __init__(self, model, accelerator, world_size: int, local_rank: int): + """ + Initialize the BatchLossManager. + + Args: + model: The model used for training + accelerator: The accelerator instance for distributed training + world_size: Number of distributed processes + local_rank: Local rank of the current process + """ + self.model: Model = model + self.accelerator: Accelerator = accelerator + self.world_size: int = world_size + self.local_rank: int = local_rank + self.torch_device = torch.device("cuda", local_rank) + + def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]: + """ + Process a batch of minibatches, computing losses and accumulating gradients. + + Args: + batch: List of minibatches to process + + Returns: + tuple: (BatchMetrics, average_loss_across_ranks) + """ + # extract batch-level info (same across all minibatches) + batch_num_loss_counted_tokens = batch[0]["batch_num_loss_counted_tokens"] + num_minibatches = len(batch) + + # initialize accumulation variables + batch_total_samples = 0 + batch_total_length = 0 + accumulated_loss = 0.0 + accumulated_aux_loss = 0.0 + grad_accum_steps = 0 + + # process each minibatch + for mb in batch: + # extract minibatch-specific info + micro_batch_size = mb["num_samples"] + total_length = mb["total_length"] + + # accumulate minibatch metrics + batch_total_samples += micro_batch_size + batch_total_length += total_length + + # prepare model inputs + model_inputs = self._prepare_model_inputs(mb) + + # compute loss and backward pass + scaled_loss, raw_losses = self.model.compute_loss( + model_inputs, self.world_size, batch_num_loss_counted_tokens + ) + self.accelerator.backward(scaled_loss) + + # accumulate losses + grad_accum_steps += 1 + accumulated_loss += raw_losses.main_loss + if raw_losses.aux_loss is not None: + accumulated_aux_loss += raw_losses.aux_loss + + # reduce metrics across ranks + batch_total_samples, batch_total_length = self._reduce_metrics( + batch_total_samples, batch_total_length + ) + + # calculate average loss across all ranks + avg_loss_across_ranks = self._compute_average_loss( + accumulated_loss, accumulated_aux_loss, batch_num_loss_counted_tokens + ) + + # create metrics object + metrics = BatchMetrics( + total_samples=int(batch_total_samples), + total_length=int(batch_total_length), + num_loss_counted_tokens=int(batch_num_loss_counted_tokens), + accumulated_loss=accumulated_loss, + accumulated_aux_loss=accumulated_aux_loss, + grad_accum_steps=grad_accum_steps, + num_minibatches=num_minibatches, + ) + + return metrics, avg_loss_across_ranks + + def _prepare_model_inputs(self, mb: CollatedItem) -> ModelInputs: + """Prepare and move model inputs to GPU.""" + model_inputs = ModelInputs( + input_ids=mb["input_ids"].to(device=self.torch_device), + labels=mb["labels"].to(device=self.torch_device), + ) + + # add optional fields onto `model_inputs` object + if "attention_mask" in mb: + model_inputs["attention_mask"] = mb["attention_mask"].to( + device=self.torch_device + ) + if "position_ids" in mb: + model_inputs["position_ids"] = mb["position_ids"].to( + device=self.torch_device + ) + + return model_inputs + + def _reduce_metrics( + self, batch_total_samples: int, batch_total_length: int + ) -> tuple[int, int]: + """Reduce rank-specific metrics across devices.""" + inputs_to_reduce = torch.tensor( + [batch_total_samples, batch_total_length], + dtype=torch.int32, + device=self.accelerator.device, + ) + + reduced_outputs = self.accelerator.reduce(inputs_to_reduce, reduction="sum") + return reduced_outputs[0].item(), reduced_outputs[1].item() + + def _compute_average_loss( + self, + accumulated_loss: torch.Tensor, + accumulated_aux_loss: torch.Tensor | None, + batch_num_loss_counted_tokens: int, + ) -> float: + """Compute average loss across all ranks for metrics logging.""" + # calculate total batch loss + total_batch_loss = ( + accumulated_loss * self.world_size / batch_num_loss_counted_tokens + ) + if self.model.is_gpt_oss and accumulated_aux_loss is not None: + total_batch_loss += accumulated_aux_loss + + # reduce across ranks + avg_loss_across_ranks = self.accelerator.reduce( + torch.tensor( + total_batch_loss.detach().item(), device=self.accelerator.device + ), + reduction="mean", + ).item() + + return avg_loss_across_ranks diff --git a/src/instructlab/training/batch_packer.py b/src/instructlab/training/batch_packer.py new file mode 100644 index 00000000..48114cac --- /dev/null +++ b/src/instructlab/training/batch_packer.py @@ -0,0 +1,241 @@ +""" +Numba-optimized batch packing using LPT (Longest Processing Time) algorithm. + +This module provides high-performance batch packing for distributed training, +using JIT compilation for optimal speed while maintaining superior load balancing. +""" + +# Third Party +from numba import int64, njit +import numpy as np + + +@njit +def _lpt_check_heap( + heap: np.ndarray, lengths: np.ndarray, max_tokens: int64, n: int64 +) -> bool: + """Check if sequences can fit using min-heap for O(log n) insertions. + + Uses a binary min-heap where heap[1] is the root (heap[0] unused). + """ + # Sort lengths in descending order (longest first) + sorted_lengths = np.sort(lengths)[::-1] + heap[:] = 0 # Reset heap + + for size in sorted_lengths: + # Add to smallest element (root of min-heap) + heap[1] += size + if heap[1] > max_tokens: + return False + + # Heapify down (sink operation) + u = 1 + while (u << 1) <= n: # While node has at least one child + v = u << 1 # Left child + rch = (u << 1) | 1 # Right child + + # Find smallest child + if rch <= n and heap[rch] < heap[v]: + v = rch + + # If parent is smaller than smallest child, we're done + if heap[u] <= heap[v]: + break + + # Swap with smallest child + heap[u], heap[v] = heap[v], heap[u] + u = v + + return True + + +@njit +def _lpt_distribute_heap( + heap: np.ndarray, + heap_id: np.ndarray, + lengths: np.ndarray, + indices: np.ndarray, + n: int64, + rank: int64, +) -> np.ndarray: + """Distribute sequences using min-heap for efficient LPT scheduling. + + Returns indices assigned to the specified rank. + """ + # Sort by length descending, keeping track of original indices + sort_idx = np.argsort(lengths)[::-1] + sorted_lengths = lengths[sort_idx] + sorted_indices = indices[sort_idx] + + # Initialize heap and rank assignments + heap[:] = 0 + heap_id[:] = np.arange(-1, n, dtype=np.int64) + + # Pre-allocate result array (worst case: all sequences to one rank) + result = np.empty(len(lengths), dtype=np.int64) + result_count = 0 + + for i in range(len(sorted_lengths)): + # Add to smallest load (root of min-heap) + heap[1] += sorted_lengths[i] + + # If this is our rank, add the index + if heap_id[1] == rank: + result[result_count] = sorted_indices[i] + result_count += 1 + + # Heapify down + u = 1 + while (u << 1) <= n: + v = u << 1 # Left child + rch = (u << 1) | 1 # Right child + + # Find smallest child + if rch <= n and heap[rch] < heap[v]: + v = rch + + # If parent is smaller, we're done + if heap[u] <= heap[v]: + break + + # Swap values and IDs + heap[u], heap[v] = heap[v], heap[u] + heap_id[u], heap_id[v] = heap_id[v], heap_id[u] + u = v + + # Return only the filled portion + return result[:result_count] + + +@njit +def _batch_to_minibatches_lpt_core( + lengths: np.ndarray, max_tokens: int64, num_ranks: int64, rank: int64 +) -> tuple: + """Core numba-optimized LPT batching algorithm. + + Returns: + minibatch_indices: 2D array of indices for this rank + minibatch_sizes: number of sequences in each minibatch + """ + n_sequences = len(lengths) + if n_sequences == 0: + return np.empty((0, 0), dtype=np.int64), np.empty(0, dtype=np.int64) + + # Get sorted indices (longest first) + sorted_indices = np.argsort(lengths)[::-1] + sorted_lengths = lengths[sorted_indices] + + # Calculate cumulative sum for efficient range queries + lengths_cumsum = np.cumsum(sorted_lengths) + + # Pre-allocate heap (1-indexed, so size n+1) + heap = np.zeros(num_ranks + 1, dtype=np.int64) + heap_id = np.zeros(num_ranks + 1, dtype=np.int64) + + # Pre-allocate output arrays + max_minibatches = n_sequences # Worst case + minibatch_indices = np.full((max_minibatches, n_sequences), -1, dtype=np.int64) + minibatch_sizes = np.zeros(max_minibatches, dtype=np.int64) + n_minibatches = 0 + + start_idx = 0 + s = 0 # Current cumulative sum + + while start_idx < n_sequences: + # Binary search for maximum sequences that fit + # Search up to the point where total tokens would exceed capacity + remaining = n_sequences - start_idx + + # Find upper bound: sequences whose cumsum doesn't exceed current + capacity + if start_idx > 0: + s = lengths_cumsum[start_idx - 1] + else: + s = 0 + + # Binary search in the remaining sequences + left = 1 + right = min( + remaining + 1, + 1 + + np.searchsorted( + lengths_cumsum[start_idx:], s + max_tokens * num_ranks, "right" + ), + ) + + while right - left > 1: + mid = (left + right) // 2 + if _lpt_check_heap( + heap, sorted_lengths[start_idx : start_idx + mid], max_tokens, num_ranks + ): + left = mid + else: + right = mid + + end_idx = start_idx + left + + # Distribute this batch using LPT + batch_indices = sorted_indices[start_idx:end_idx] + batch_lengths = sorted_lengths[start_idx:end_idx] + + my_indices = _lpt_distribute_heap( + heap, heap_id, batch_lengths, batch_indices, num_ranks, rank + ) + + # Store result + if len(my_indices) > 0: + minibatch_indices[n_minibatches, : len(my_indices)] = my_indices + minibatch_sizes[n_minibatches] = len(my_indices) + else: + # Empty minibatch for this rank - use padding + minibatch_indices[n_minibatches, 0] = -1 + minibatch_sizes[n_minibatches] = 1 + + n_minibatches += 1 + start_idx = end_idx + + # Return trimmed arrays + return (minibatch_indices[:n_minibatches], minibatch_sizes[:n_minibatches]) + + +def batch_lengths_to_minibatches_lpt( + batch_lengths: list[int], max_tokens_per_rank: int, num_ranks: int, rank: int +): + """High-performance LPT batch packing using numba optimization. + + Distributes sequences across ranks using the Longest Processing Time (LPT) + algorithm with min-heap optimization for O(n log n log k) complexity. + + This provides optimal load balancing while being significantly faster than + naive implementations through JIT compilation. + + Args: + batch_lengths: List of sequence lengths (in tokens) + max_tokens_per_rank: Maximum tokens allowed per rank per minibatch + num_ranks: Total number of distributed training ranks (GPUs) + rank: The specific rank to retrieve assigned indices for + + Returns: + List of lists, where each inner list contains indices assigned to this rank + for one minibatch. Index -1 indicates padding. + """ + if not batch_lengths: + return [] + + # Convert to numpy + lengths = np.array(batch_lengths, dtype=np.int64) + + # Call numba-optimized core + minibatch_indices, minibatch_sizes = _batch_to_minibatches_lpt_core( + lengths, max_tokens_per_rank, num_ranks, rank + ) + + # Convert back to list format + result = [] + for i in range(len(minibatch_sizes)): + size = minibatch_sizes[i] + if minibatch_indices[i, 0] == -1: + result.append([-1]) + else: + result.append(minibatch_indices[i, :size].tolist()) + + return result diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 3c4e65f3..b561a8d6 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -184,11 +184,19 @@ class TrainingArgs(BaseModel): max_seq_len: int max_batch_len: int - num_epochs: int + num_epochs: int = Field( + default=1, description="Number of epochs to run through before stopping." + ) effective_batch_size: int - save_samples: int + save_samples: int = Field( + default=0, + description="Number of samples the model should see before saving a checkpoint. Consider this to be the checkpoint save frequency. If --save_samples<=0, this feature is disabled.", + ) learning_rate: float - warmup_steps: int + warmup_steps: int = Field( + default=0, + description="Number of warmup steps to run before starting the main training loop.", + ) random_seed: int = 42 # (jkunstle) left here for compatibility, but Dolomite is removed. diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index c917bbac..53744482 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Standard -from functools import partial +from functools import lru_cache, partial from pathlib import Path import logging import os @@ -11,7 +11,12 @@ # Third Party from datasets import Dataset, load_dataset -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import ( + AutoConfig, + AutoTokenizer, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) import numpy as np import regex as re @@ -32,6 +37,14 @@ logger = logging.getLogger(__name__) +@lru_cache() +def is_gpt_oss_model(tokenizer: PreTrainedTokenizer) -> bool: + """Check if this is a GPT-OSS model based on tokenizer.""" + model_name_or_path = tokenizer.name_or_path + config = AutoConfig.from_pretrained(model_name_or_path) + return config.model_type == "gpt_oss" + + def check_valid_sample( tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, whole_sentence_tk: list[int], @@ -83,6 +96,7 @@ def unmask_message_content( pretrain_token, pretrain_end_token, tool_resp_tokens=None, + tokenizer=None, ): """ Create labels for tokens in a sequence with special handling for pretraining tokens and role-specific sequences. @@ -152,10 +166,83 @@ def find_longest_match(start_idx, sequences): default=None, ) + def is_gpt_oss_assistant_channel(pos): + """Check if current position is within a GPT-OSS assistant channel pattern.""" + # Look for pattern: <|start|>assistant<|channel|>CHANNEL_NAME<|message|> + if pos >= len(sentence_tk): + return False + + # Try to encode the expected pattern tokens + start_token = tokenizer.encode("<|start|>", add_special_tokens=False) + assistant_tokens = tokenizer.encode("assistant", add_special_tokens=False) + channel_token = tokenizer.encode("<|channel|>", add_special_tokens=False) + message_token = tokenizer.encode("<|message|>", add_special_tokens=False) + + # Look backwards from current position to see if we're in an assistant channel + # We need to find the pattern: start_token + assistant_tokens + channel_token + some_text + message_token + for lookback in range(min(pos + 1, 50)): # Look back up to 50 tokens + start_pos = pos - lookback + if start_pos < 0: + break + + # Check if we have the start of an assistant channel pattern + if ( + start_pos + + len(start_token) + + len(assistant_tokens) + + len(channel_token) + <= len(sentence_tk) + and sentence_tk[start_pos : start_pos + len(start_token)] == start_token + and sentence_tk[ + start_pos + len(start_token) : start_pos + + len(start_token) + + len(assistant_tokens) + ] + == assistant_tokens + and sentence_tk[ + start_pos + len(start_token) + len(assistant_tokens) : start_pos + + len(start_token) + + len(assistant_tokens) + + len(channel_token) + ] + == channel_token + ): + # Found assistant channel start, now look for the message token ahead + message_start = ( + start_pos + + len(start_token) + + len(assistant_tokens) + + len(channel_token) + ) + for forward in range( + min(len(sentence_tk) - message_start, 20) + ): # Look forward up to 20 tokens for message token + if ( + message_start + forward + len(message_token) <= len(sentence_tk) + and sentence_tk[ + message_start + forward : message_start + + forward + + len(message_token) + ] + == message_token + ): + # Found complete assistant channel pattern + message_end = message_start + forward + len(message_token) + # Check if current position is between message token and next special sequence + if pos >= message_end: + return True + break + break + + return False + special_sequences = [user_tokens, assist_tokens, system_tokens] if tool_resp_tokens: special_sequences.append(tool_resp_tokens) + # Check if this is a GPT-OSS model for special handling + is_gpt_oss = is_gpt_oss_model(tokenizer) + in_pretraining = False unmasking = False i = 0 @@ -171,12 +258,23 @@ def find_longest_match(start_idx, sequences): match = find_longest_match(i, special_sequences) if match: - unmasking = (match == assist_tokens) or ( - example["unmask"] and match != system_tokens - ) + # For GPT-OSS models, always unmask assistant tokens regardless of unmask flag + if is_gpt_oss and match == assist_tokens: + unmasking = True + else: + unmasking = (match == assist_tokens) or ( + example["unmask"] and match != system_tokens + ) i += len(match) continue + # Special case: Check for GPT-OSS assistant channel patterns + # For GPT-OSS, always unmask assistant channels regardless of unmask flag + if is_gpt_oss and is_gpt_oss_assistant_channel(i): + unmasking = True + elif example["unmask"] and is_gpt_oss_assistant_channel(i): + unmasking = True + if in_pretraining or unmasking: labels[i] = sentence_tk[i] i += 1 @@ -395,6 +493,7 @@ def process_messages_into_input_ids_with_chat_template(args: DataProcessArgs): pretrain_token=get_sp_token(tokenizer, "<|pretrain|>")[0], pretrain_end_token=get_sp_token(tokenizer, "<|/pretrain|>")[0], tool_resp_tokens=tool_resp_tk, + tokenizer=tokenizer, ) logger.info("Unmasking the appropriate message content...") data_with_labels = data_with_input_ids.map( @@ -473,12 +572,13 @@ def wrap_masked_messages( # here, we need to be on the lookout for both string and non-string # entries (e.g. other content types, or pure reasoning traces) - interesting_fields = ["content", "reasoning_content"] + interesting_fields = ["content", "reasoning_content", "thinking"] new_msg = {k: v for k, v in msg.items() if k not in interesting_fields} - # what's left to add then is content or reasoning_content + # what's left to add then is content, reasoning_content, or thinking content = msg.get("content", None) reasoning_content = msg.get("reasoning_content", None) + thinking = msg.get("thinking", None) # we handle these conditionally since these may become optional fields in the future. if content is not None: @@ -505,6 +605,22 @@ def wrap_masked_messages( # When not enabled, pass through unchanged new_msg["reasoning_content"] = reasoning_content + # Handle GPT-OSS "thinking" field similar to reasoning_content + if thinking is not None: + if enable_reasoning_content: + if not isinstance(thinking, str): + raise ValueError( + "Error: received an entry for `thinking` which was not a string. " + "Non-string datatypes for this field are currently unsupported, if this is intentional please raise an issue." + ) + + new_msg["thinking"] = ( + UNMASK_REASONING_BEGIN_TOKEN + thinking + UNMASK_REASONING_END_TOKEN + ) + else: + # When not enabled, pass through unchanged + new_msg["thinking"] = thinking + # MyPy wants to be very specific about types, but new_msg may contain # valid fields in each message which are hard to account for ahead of time. new_msgs += [new_msg] # type: ignore @@ -544,13 +660,12 @@ def unmask_messages( Returns: Result (ProcessedMessagesData): Dict with the resulting `input_ids`, `labels`, and `len` """ - # Check if any messages have reasoning_content that we need to handle + # Check if any messages have reasoning_content or thinking that we need to handle has_reasoning = any( - msg.get("reasoning_content") is not None + msg.get("reasoning_content") is not None or msg.get("thinking") is not None for msg in msgs if msg["role"] in unmask_roles ) - # TODO(osilkin): Here we assume that we will always unmask reasoning content, # in the future we can make this configurable. msgs_with_unmasking = wrap_masked_messages( @@ -562,7 +677,10 @@ def unmask_messages( for idx, msg in enumerate(msgs_with_unmasking): if msg["role"] in unmask_roles: regions = [] - if has_reasoning and msg.get("reasoning_content") is not None: + if has_reasoning and ( + msg.get("reasoning_content") is not None + or msg.get("thinking") is not None + ): regions.append("reasoning") if msg.get("content") is not None: regions.append("content") @@ -781,6 +899,9 @@ def unmask_sample( # TODO(osilkin): we should define an unmasking policy that # enables the user to more dynamically choose what should be unmasked and not. + # Check if this is a GPT-OSS model for special handling + is_gpt_oss = is_gpt_oss_model(tokenizer) + # if sample has `unmask` set to true, we unmask everything other than the system role, # else we only unmask assistant unmask_roles_set = {"assistant"} @@ -789,6 +910,10 @@ def unmask_sample( # the unique roles ahead of time unmask_roles_set = set(m["role"] for m in sample["messages"]) - {"system"} + # For GPT-OSS models, always unmask assistant regardless of unmask flag + elif is_gpt_oss: + unmask_roles_set = {"assistant"} + unmask_roles = list(unmask_roles_set) return unmask_messages(sample["messages"], tokenizer, unmask_roles) diff --git a/src/instructlab/training/gpt_oss_utils_correct.py b/src/instructlab/training/gpt_oss_utils_correct.py new file mode 100644 index 00000000..430a890f --- /dev/null +++ b/src/instructlab/training/gpt_oss_utils_correct.py @@ -0,0 +1,437 @@ +# SPDX-License-Identifier: Apache-2.0 + +""" +Correct GPT-OSS MXFP4 quantization implementation that matches OpenAI's format exactly. +Based on the official OSS specification. +""" + +# Standard +from typing import Dict +import logging +import re + +# Third Party +from transformers import AutoConfig, PretrainedConfig +import torch + +logger = logging.getLogger("instructlab.training") + +GROUP_SIZE = 32 # MXFP4 block size (last-dim groups) + + +# ---- E2M1 codebook (FP4: 1 sign, 2 exp, 1 mant, bias=1), 16 values ---- +# Exact values from PyTorch AO MXFP4 implementation +def _e2m1_decode_table(device=torch.device("cpu"), dtype=torch.float32): + # Exact FP4 E2M1 values from PyTorch AO - force float32 for consistency + fp4_values = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, # Positive values + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, # Negative values + ], + device=device, + dtype=dtype, + ) # Always use float32 for consistency + return fp4_values # shape [16] + + +@torch.no_grad() +def _find_closest_with_first_tie_breaking(values, table): + """ + Find closest FP4 values with OpenAI's exact tie-breaking rules. + Key insight: OpenAI respects IEEE 754 sign bits for zero values. + When input is negative zero (-0.0), prefer index 8 over index 0. + """ + # Ensure consistent precision for all calculations + values = values.to(torch.float32) + table = table.to(torch.float32) + + # Calculate squared distances with high precision + distances = (values.unsqueeze(-1) - table) ** 2 # [..., 16] + + # Start with argmin (which handles most cases correctly) + result_indices = torch.argmin(distances, dim=-1) + + # Special case: handle negative zero + # When input is negative zero and distance to both +0.0 and -0.0 is equal, + # prefer -0.0 (index 8) over +0.0 (index 0) + + min_distances = distances.min(dim=-1, keepdim=True)[0] + epsilon = 1e-10 + + # Find positions where both index 0 and 8 are tied (zero case) + zero_tie_mask = ( + (distances[..., 0] <= min_distances[..., 0] + epsilon) # Index 0 is tied + & (distances[..., 8] <= min_distances[..., 0] + epsilon) # Index 8 is tied + ) + + # Check which values are actually negative zero (sign bit = 1) + # Use torch.signbit to detect negative zero + is_negative_zero = torch.zeros_like(values, dtype=torch.bool) + + # Only check values that are actually zero + zero_mask = torch.abs(values) < 1e-10 + if zero_mask.any(): + # For zero values, check sign bit + with torch.no_grad(): + is_negative_zero = zero_mask & torch.signbit(values) + + # Apply the rule: if it's a zero tie and input is negative zero, choose index 8 + negative_zero_correction = zero_tie_mask & is_negative_zero + result_indices = torch.where(negative_zero_correction, 8, result_indices) + + return result_indices + + +# Quantize floats (normalized by the block scale) to nearest E2M1 code 0..15 +@torch.no_grad() +def _e2m1_encode(normalized: torch.Tensor) -> torch.Tensor: + # Force float32 for all calculations to ensure consistency + normalized = normalized.to(torch.float32) + table = _e2m1_decode_table(device=normalized.device, dtype=torch.float32) # [16] + + # Clamp to valid range first + normalized_clamped = torch.clamp(normalized, min=-6.0, max=6.0) + + # OPTIMIZED: Increased batch sizes and smarter memory management + # Process larger chunks since we're now batching more efficiently + if ( + normalized_clamped.dim() >= 3 and normalized_clamped.shape[0] > 32 + ): # Very large tensors + # Process in larger batches - modern GPUs can handle much more + batch_size = 32 # Increased from 4 to 32 for better GPU utilization + expert_results = [] + for start_idx in range(0, normalized_clamped.shape[0], batch_size): + end_idx = min(start_idx + batch_size, normalized_clamped.shape[0]) + expert_batch = normalized_clamped[start_idx:end_idx] + # Use proper tie-breaking that matches OpenAI's implementation + batch_indices = _find_closest_with_first_tie_breaking(expert_batch, table) + expert_results.append(batch_indices.to(torch.uint8)) + return torch.cat(expert_results, dim=0) + else: + # Process normally for smaller tensors + # Use proper tie-breaking that matches OpenAI's implementation + idx = _find_closest_with_first_tie_breaking(normalized_clamped, table) + return idx.to(torch.uint8) + + +@torch.no_grad() +def _pack_nibbles(low_nib: torch.Tensor, high_nib: torch.Tensor) -> torch.Tensor: + # both uint8 in [0,15] + return (low_nib | (high_nib << 4)).to(torch.uint8) + + +@torch.no_grad() +def _power2_scales_from_maxabs(blocks: torch.Tensor) -> torch.Tensor: + # blocks: [..., nblocks, G] + # Use exact PyTorch AO scale calculation with bit manipulation + maxabs = ( + blocks.abs().amax(dim=-1, keepdim=True).clamp_min(2 ** (-126)) + ) # [..., nblocks, 1] + + # Extract power-of-2 component from float32 representation (PyTorch AO method) + maxabs_int32 = maxabs.view(torch.int32) + extracted_pow2 = ((maxabs_int32 >> 23) & 0xFF) - 127 # Extract FP32 exponent + + # Calculate scale with target maximum power (4.0 = 2^2, so target_pow2 = 2) + target_max_pow2 = 2 # For FP4 E2M1 max value 4.0 + scale_unbiased = extracted_pow2 - target_max_pow2 + + # Clamp to valid range and remove keepdim + scale_unbiased = scale_unbiased.squeeze(-1).clamp(-127, 128) # [..., nblocks] + + # Return signed int8 exponent (transformers will handle +127 offset) + return scale_unbiased.to(torch.int8) # [..., nblocks] + + +@torch.no_grad() +def _quantize_tensor_to_mxfp4_param(weight: torch.Tensor, group_size: int = GROUP_SIZE): + """ + Returns (blocks_u8, scales_i8, meta) for a single 2D+ tensor quantized along the last dim. + + This function now uses OpenAI's exact algorithm with: + 1. Perfect signed zero handling in tie-breaking + 2. Interleaved nibble packing (even positions as low, odd as high) + 3. Correct dimensional mapping for expert parameters + """ + assert weight.ndim >= 2, "Quantize only 2D+ tensors" + x = weight.to(torch.float32) + + # Pad last dim to multiple of group_size + last = x.shape[-1] + pad = (group_size - (last % group_size)) % group_size + if pad: + x = torch.nn.functional.pad(x, (0, pad)) + + # [..., nblocks, G] + new_shape = (*x.shape[:-1], x.shape[-1] // group_size, group_size) + xb = x.view(new_shape) + + # per-block signed exponent e (int8); scale = 2**e + e_i8 = _power2_scales_from_maxabs( + xb.to(torch.float32) + ) # [..., nblocks] - ensure float32 + scale = torch.pow( + torch.tensor(2.0, device=x.device, dtype=torch.float32), e_i8.to(torch.float32) + ).unsqueeze(-1) # [..., nblocks, 1] + + y = xb * (1.0 / scale) # normalized (use reciprocal like Triton) + + # encode each element to E2M1 code [0..15] using OpenAI's exact tie-breaking + codes = _e2m1_encode(y) # uint8 [..., nblocks, G] + + # Pack using OpenAI's INTERLEAVED method: + # - Even positions (0, 2, 4, ...) become low nibbles + # - Odd positions (1, 3, 5, ...) become high nibbles + G = codes.shape[-1] + assert G % 2 == 0 + + # Split into even and odd positions (interleaved packing) + low_nibbles = codes[..., ::2] # Even positions: [0, 2, 4, 6, ...] + high_nibbles = codes[..., 1::2] # Odd positions: [1, 3, 5, 7, ...] + + # Pack nibbles: each byte = low_nibble | (high_nibble << 4) + packed = _pack_nibbles(low_nibbles, high_nibbles) # [..., nblocks, G/2] + + # Keep the 4D structure: [..., nblocks, 16] for blocks + # packed shape is [..., nblocks, G/2] where G=32, so G/2=16 + blocks_u8 = packed.contiguous() # Keep as [..., nblocks, 16] + + meta = { + "orig_shape": tuple(weight.shape), + "padded_last": int(pad), + "group_size": int(group_size), + "layout": "blocks_scales_lastdim", + "dtype": "mxfp4_e2m1", + } + return blocks_u8.to(torch.uint8), e_i8.contiguous(), meta + + +def convert_dequantized_to_quantized_format_correct( + state_dict: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: + """ + Convert dequantized GPT-OSS parameters to quantized format using the correct OSS-compatible algorithm. + + This function converts: + - experts.down_proj -> experts.down_proj_blocks + experts.down_proj_scales + - experts.gate_up_proj -> experts.gate_up_proj_blocks + experts.gate_up_proj_scales + + Using the exact MXFP4 algorithm that matches OpenAI's format. + + Args: + state_dict: Model state dict with dequantized parameters + + Returns: + State dict with quantized format parameter names and correct MXFP4 quantization + """ + converted_state_dict = {} + conversion_count = 0 + + logger.info("๐Ÿ”ง Starting CORRECT GPT-OSS parameter conversion...") + logger.info(f"๐Ÿ“ฅ Input state dict has {len(state_dict)} parameters") + + # Pattern to match MoE expert weight parameters (not biases) + moe_param_pattern = re.compile(r"experts\.(gate_up_proj|down_proj)$") + + # First, copy all non-expert parameters to save memory + expert_params_to_convert = [] + + for param_name, param_tensor in state_dict.items(): + if moe_param_pattern.search(param_name): + # Store expert params for later conversion + expert_params_to_convert.append((param_name, param_tensor)) + else: + # Keep non-expert parameters - move to CPU and convert to bf16 for memory efficiency + if param_tensor.dtype == torch.float32: + converted_param = param_tensor.to(torch.bfloat16).cpu() + logger.debug( + f"๐Ÿ’พ {param_name}: converted float32 โ†’ bf16 and moved to CPU" + ) + else: + converted_param = param_tensor.cpu() + logger.debug( + f"๐Ÿ’พ {param_name}: moved to CPU, kept {param_tensor.dtype}" + ) + + converted_state_dict[param_name] = converted_param + + # Now convert expert parameters one at a time to manage GPU memory + for param_name, param_tensor in expert_params_to_convert: + logger.info( + f"๐Ÿ”„ Converting {param_name}: {param_tensor.shape} {param_tensor.dtype}" + ) + + try: + # Use OpenAI's exact dimensional mapping discovered through reverse engineering + # OpenAI's format: dequant[expert, row, col] -> blocks[expert, col, block_idx, byte_idx] + # This means we quantize along the row dimension (dim=1), not the column dimension + + logger.info( + f"๐Ÿ”„ Processing {param_name} with OpenAI's exact dimensional mapping" + ) + logger.info(f" Input shape: {param_tensor.shape}") + + if "gate_up_proj" in param_name: + # gate_up_proj: dequantized is [experts, rows, cols] = [32, 2880, 5760] + # OpenAI quantizes each column separately: [32, 2880] -> [32, 90, 16] per column + # Result: [experts, cols, blocks_per_col, bytes_per_block] = [32, 5760, 90, 16] + experts, rows, cols = param_tensor.shape + blocks_per_col = rows // GROUP_SIZE + + logger.info( + f" Processing {cols} columns, each with {blocks_per_col} blocks" + ) + + # OPTIMIZED: Process ALL columns at once using vectorized operations + # Reshape to process all columns simultaneously: [experts, cols, rows] = [32, 5760, 2880] + # Transpose to put columns first for efficient memory access + tensor_transposed = param_tensor.transpose(1, 2) # [32, 5760, 2880] + + # Reshape for batch quantization: [32*5760, 1, 2880] + # This allows us to quantize all expert-column pairs at once + total_columns = experts * cols + reshaped_for_quant = tensor_transposed.reshape(total_columns, 1, rows) + + logger.info( + f" VECTORIZED: Quantizing {total_columns} columns simultaneously" + ) + + # Single quantization call for all columns - MASSIVE speedup! + all_blocks_flat, all_scales_flat, _ = _quantize_tensor_to_mxfp4_param( + reshaped_for_quant, GROUP_SIZE + ) + + # Reshape back to the correct format + # all_blocks_flat: [32*5760, 1, 90, 16] -> [32, 5760, 90, 16] + blocks_u8 = all_blocks_flat.squeeze(1).reshape( + experts, cols, blocks_per_col, 16 + ) + # all_scales_flat: [32*5760, 1, 90] -> [32, 5760, 90] + scales_i8 = all_scales_flat.squeeze(1).reshape( + experts, cols, blocks_per_col + ) + + else: # down_proj + # down_proj: dequantized is [experts, rows, cols] = [32, 2880, 2880] + # Same vectorized logic as gate_up_proj + experts, rows, cols = param_tensor.shape + blocks_per_col = rows // GROUP_SIZE + + logger.info( + f" Processing {cols} columns, each with {blocks_per_col} blocks" + ) + + # OPTIMIZED: Process ALL columns at once using vectorized operations + # Transpose to put columns first: [32, 2880, 2880] -> [32, 2880, 2880] + tensor_transposed = param_tensor.transpose(1, 2) # [32, 2880, 2880] + + # Reshape for batch quantization: [32*2880, 1, 2880] + total_columns = experts * cols + reshaped_for_quant = tensor_transposed.reshape(total_columns, 1, rows) + + logger.info( + f" VECTORIZED: Quantizing {total_columns} columns simultaneously" + ) + + # Single quantization call for all columns - MASSIVE speedup! + all_blocks_flat, all_scales_flat, _ = _quantize_tensor_to_mxfp4_param( + reshaped_for_quant, GROUP_SIZE + ) + + # Reshape back to the correct format + # all_blocks_flat: [32*2880, 1, 90, 16] -> [32, 2880, 90, 16] + blocks_u8 = all_blocks_flat.squeeze(1).reshape( + experts, cols, blocks_per_col, 16 + ) + # all_scales_flat: [32*2880, 1, 90] -> [32, 2880, 90] + scales_i8 = all_scales_flat.squeeze(1).reshape( + experts, cols, blocks_per_col + ) + + # Create new parameter names with _blocks and _scales + blocks_name = param_name + "_blocks" + scales_name = param_name + "_scales" + + # Add +127 offset to scales for uint8 storage (HF format) + scales_u8 = (scales_i8.to(torch.int32) + 127).clamp(0, 255).to(torch.uint8) + + # Store quantized parameters (move to CPU to save GPU memory) + converted_state_dict[blocks_name] = blocks_u8.cpu() + converted_state_dict[scales_name] = scales_u8.cpu() + + logger.info(f"โœ… {blocks_name}: {blocks_u8.shape} {blocks_u8.dtype}") + logger.info(f"โœ… {scales_name}: {scales_u8.shape} {scales_u8.dtype}") + + conversion_count += 1 + + # Clear GPU memory after each conversion + del blocks_u8, scales_i8, scales_u8 + torch.cuda.empty_cache() + + except Exception as e: + logger.error(f"โŒ Failed to convert {param_name}: {e}") + raise e + + logger.info( + f"๐ŸŽฏ Converted {conversion_count} expert parameters using correct MXFP4 algorithm" + ) + logger.info(f"๐Ÿ“Š Output state dict has {len(converted_state_dict)} parameters") + + return converted_state_dict + + +def is_gpt_oss(model_path_or_config: str | PretrainedConfig) -> bool: + """ + Determine if we should convert GPT-OSS format during saving. + """ + if not isinstance(model_path_or_config, (PretrainedConfig, str)): + raise ValueError( + f"cannot detect model: received invalid argument of type {type(model_path_or_config)}" + ) + + # convert to config + model_config = model_path_or_config + if isinstance(model_path_or_config, str): + model_config = AutoConfig.from_pretrained(model_path_or_config) + + return getattr(model_config, "model_type", None) == "gpt_oss" + + +def add_gpt_oss_quantization_config(config): + """ + Add GPT-OSS quantization configuration to a model config object. + + Args: + config: A transformers PretrainedConfig object + + Returns: + The config object with quantization settings added + """ + # add the quantization config if not present + if not hasattr(config, "quantization_config") or config.quantization_config is None: + config.quantization_config = { + "modules_to_not_convert": [ + "model.layers.*.self_attn", + "model.layers.*.mlp.router", + "model.embed_tokens", + "lm_head", + ], + "quant_method": "mxfp4", + } + logger.info("Added GPT-OSS quantization config to model config") + + return config diff --git a/src/instructlab/training/ilab_to_sdg.py b/src/instructlab/training/ilab_to_sdg.py deleted file mode 100644 index f020e478..00000000 --- a/src/instructlab/training/ilab_to_sdg.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Standard -import glob -import json -import logging -import os -import sys - -# First Party -from instructlab.training.logger import setup_root_logger - -logger = logging.getLogger("instructlab.training") - - -def ilab_to_sdb(ilab_train_data_dir, prefix, taxonomy_path): - """ - Convert a ilab train dataset to a SDG compatible format for backend training - - Parameters: - ilab_train_data_dir: path to directory where train dataset generated by ilab lives. - We automatically pick the latest training file from the ones present - there. - taxonomy_path: path to the taxonomy dataset used for generating the train dataset - """ - items_list = [] - files = glob.glob( - os.path.join(ilab_train_data_dir, prefix + "*.jsonl"), recursive=False - ) - try: - files.sort(reverse=True) - latest_train_file = files[0] - except IndexError: - logger.error("IndexError: no matching files found", exc_info=True) - return - logger.info("Converting %s", latest_train_file) - with open(latest_train_file, "r") as file: - # Read each line (which represents a JSON object) - for line in file: - line = json.loads(line) - new_dict = {"messages": []} - for key, value in line.items(): - tmp = {} - tmp["content"] = value - tmp["role"] = key - new_dict["messages"].append(tmp) - - new_dict["group"] = "lab_extension" - new_dict["dataset"] = taxonomy_path - new_dict["metadata"] = '{"num_turns": 1}' - items_list.append(new_dict) - - with open("sdg_out.jsonl", "a", encoding="utf-8") as file: - for item in items_list: - file.write(json.dumps(item)) - file.write("\n") - - -if __name__ == "__main__": - setup_root_logger() - if len(sys.argv) > 1: - ilab_train_data_dir = sys.argv[1] - prefix = sys.argv[2] - taxonomy = sys.argv[3] - else: - logger.critical("provide ilab train data dir as an argument") - sys.exit(1) - ilab_to_sdb(ilab_train_data_dir, prefix, taxonomy) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index df166fd4..dc8a171e 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -37,11 +37,12 @@ from tqdm import tqdm from transformers import AutoConfig import torch -import torch.distributed +import torch.distributed as dist # First Party from instructlab.training import config from instructlab.training.accelerator import Accelerator +from instructlab.training.batch_loss_manager import BatchLossManager from instructlab.training.config import ( DistributedBackend, ModelTypes, @@ -61,14 +62,14 @@ Model, setup_optimizer, ) -from instructlab.training.multipack_sampler import ( - find_packing_max_batch_len_and_grad_accum, -) -from instructlab.training.token_dataset import setup_dataloader, setup_dataset +from instructlab.training.sampler import get_data_loader + +# Removed old multipack_sampler import - using mini_trainer approach from instructlab.training.tokenizer_utils import setup_tokenizer from instructlab.training.utils import ( StreamablePopen, check_valid_train_args, + freeze_router_params, load_latest_full_state, save_checkpoint, save_hf_format_accelerate, @@ -82,7 +83,6 @@ def train( args, model: Model, - optimizer: torch.optim.Optimizer, accelerator: Accelerator, ): model.train() @@ -94,7 +94,8 @@ def train( metric_logger = logging.getLogger("instructlab.training.metrics") base_logger = logging.getLogger("instructlab.training") - batch_size = args.effective_batch_size // accelerator.grad_accum + # Mini_trainer approach: batch_size will be determined dynamically by data loader + # For save logic, use effective_batch_size since that's the target samples_seen = 0 if hasattr(args, "samples_seen"): @@ -102,22 +103,20 @@ def train( samples_seen = args.samples_seen if accelerator.save_samples > 0: - accelerator.save_samples = (accelerator.save_samples // batch_size) * batch_size logger.info("Number of samples per save: %d", args.save_samples) if args.save_samples_ds is not None: - args.save_samples_ds = (args.save_samples_ds // batch_size) * batch_size logger.info("Number of samples per DS save: %d", args.save_samples_ds) global_grad_norm = None - for epoch in range(args.current_epoch, args.num_epochs): - if args.sampler in ("multipack"): - accelerator.train_loader.batch_sampler.set_epoch(epoch) - elif args.sampler in ("distributed"): - accelerator.train_loader.sampler.set_epoch(epoch) - else: - raise NotADirectoryError + # Initialize the batch loss manager + batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank) + + # Blast through batches + for epoch in range(args.current_epoch, args.num_epochs): + # set the epoch for correct sampling + accelerator.train_loader.sampler.set_epoch(epoch) num_epoch_steps = len(accelerator.train_loader) if local_rank == 0: inner_pb = tqdm(range(num_epoch_steps), desc=f"Epoch {epoch}") @@ -131,51 +130,25 @@ def train( inner_pb.update(1) continue start = time.time() - num_loss_counted_tokens = float( - torch.tensor([batch.pop("num_loss_counted_tokens")]) - ) - micro_batch_size = float(torch.tensor([batch.pop("num_samples")])) - total_length = float(torch.tensor([batch.pop("total_length")])) - for k in batch: - batch[k] = batch[k].to(local_rank) - output = model( - **batch, - use_cache=False, - ) - loss = output.loss - log_loss = loss.detach().item() - - num_loss_counted_tokens, micro_batch_size, log_loss = map( - float, - accelerator.reduce( - torch.tensor( - [num_loss_counted_tokens, micro_batch_size, log_loss], - dtype=torch.float32, - device=accelerator.device, - ), - reduction="sum", - ), + + # Process the batch using the BatchLossManager + batch_metrics, avg_loss_across_ranks = batch_loss_manager.process_batch( + batch ) - samples_seen += int(micro_batch_size) - # num_loss_counted_tokens = aggregated_values[0] - loss = ( - loss / num_loss_counted_tokens * world_size - ) # dividing by the total number of non-padding tokens and multiplying by the number of GPUs so when accelerate averages by world_size, it will be the correct loss. + # Update samples seen + samples_seen += batch_metrics.total_samples + base_logger.info( - f"Epoch: {epoch}, Step: {global_step}, Rank: {torch.distributed.get_rank()}, loss = {loss}" + f"Epoch: {epoch}, Step: {global_step}, Rank: {dist.get_rank()}, loss = {avg_loss_across_ranks:.6f}, grad_accum_steps = {batch_metrics.grad_accum_steps}" ) - accelerator.backward(loss) - if global_step % accelerator.grad_accum == 0: - global_grad_norm = accelerator.clip_grad_norm_(model.parameters(), 1.0) - optimizer.step() - accelerator.lr_scheduler.step() - optimizer.zero_grad() + # Take optimizer step after all minibatches + accelerator.take_optimizer_step() if local_rank == 0: elapsed_time = time.time() - start - overall_throughput = args.samples_per_gpu * world_size / elapsed_time + overall_throughput = batch_metrics.total_samples / elapsed_time current_lr = accelerator.lr_scheduler.get_last_lr()[0] cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3) cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] @@ -197,15 +170,16 @@ def train( { "epoch": epoch, "step": global_step, - "rank": torch.distributed.get_rank(), + "rank": dist.get_rank(), "overall_throughput": overall_throughput, "lr": current_lr, "cuda_mem_allocated": cuda_mem_allocated, "cuda_malloc_retries": cuda_malloc_retries, - "num_loss_counted_tokens": int(num_loss_counted_tokens), - "num_tokens_rank0": int(total_length), - "batch_size": int(micro_batch_size), - "total_loss": float(log_loss / num_loss_counted_tokens), + "num_loss_counted_tokens": batch_metrics.num_loss_counted_tokens, + "num_tokens_rank0": batch_metrics.total_length, + "batch_size": batch_metrics.total_samples, + "num_minibatches": batch_metrics.num_minibatches, + "avg_loss": float(avg_loss_across_ranks), "samples_seen": samples_seen, "gradnorm": global_grad_norm, "total_samples": len(accelerator.train_loader.dataset), @@ -215,9 +189,7 @@ def train( extra={"step": global_step}, ) - if args.save_samples > 0 and ( - global_step * batch_size % args.save_samples == 0 - ): + if args.save_samples > 0 and (samples_seen % args.save_samples == 0): base_logger.debug(f"Saving checkpoint at step {global_step}") save_checkpoint( args=args, @@ -229,7 +201,7 @@ def train( hf_format=True, ) base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) - torch.distributed.barrier() + dist.barrier() global_step += 1 if local_rank == 0: @@ -249,7 +221,7 @@ def train( epoch=epoch, ) base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) - torch.distributed.barrier() + dist.barrier() if args.save_last: save_hf_format_accelerate( @@ -317,23 +289,17 @@ def main(args): timeout = _get_collective_timeout() if timeout is not None: - torch.distributed.init_process_group(timeout=timeout) + dist.init_process_group(timeout=timeout) else: - torch.distributed.init_process_group() + dist.init_process_group() - args.global_rank = torch.distributed.get_rank() + args.global_rank = dist.get_rank() tensor = torch.ByteTensor([False]).cuda() - torch.distributed.all_reduce(tensor) - torch.distributed.barrier() + dist.all_reduce(tensor) + dist.barrier() flash_enabled = Model.check_flash_attn_enabled(args.disable_flash_attn) - dataset = setup_dataset( - args.data_path, - mock=args.mock_data, - mock_len=args.mock_len, - ) - # This model class wraps the various AutoModel classes we support # based on model_type, and model_path -> choose auto_model lora_config = None @@ -360,7 +326,7 @@ def main(args): # Get the model class with default fallback model_class = model_class_map.get(model_type, CausalLMModel) - m = model_class( + m: Model = model_class( model_path=args.model_name_or_path, output_dir=args.output_dir, lora_config=lora_config, @@ -373,81 +339,67 @@ def main(args): args.base_model_args = m.base_model_args - try: - packing_max_batch_len, grad_accum = find_packing_max_batch_len_and_grad_accum( - num_gpus=torch.distributed.get_world_size(), - avg_sample_len=dataset.get_lengths().mean(), - effective_batch_size=args.effective_batch_size, - max_batch_len_per_gpu=args.max_batch_len, - is_padding=not flash_enabled, - dataset=dataset, - seed=args.seed, - ) - args.sampler = "multipack" - except RuntimeError as e: - logger.error(e) - - # fallback to grad accum = 1 - # NOTE: packing max batch len will not be used - packing_max_batch_len = None - grad_accum = 1 - args.sampler = "distributed" - - args.samples_per_gpu = ( - args.effective_batch_size // grad_accum // torch.distributed.get_world_size() - ) - - train_loader = setup_dataloader( - dataset, - tokenizer.pad_token_id, - num_workers=8, - flash_enabled=flash_enabled, - max_batch_len=args.max_batch_len, - packing_max_batch_len=packing_max_batch_len, - samples_per_gpu=args.samples_per_gpu, - sampler=args.sampler, + # IMPORTANT: Freeze GPT-OSS router parameters BEFORE accelerator setup + # This must happen before FSDP preparation to avoid requires_grad uniformity issues + # + # NOTE: in theory this can work for any MoE model with router parameters, but we handle + # GPT-OSS specifically + # We don't want to use use_orig_params for GPT-OSS models + fsdp_should_use_orig_params = False + if m.is_gpt_oss: + logger.info("๐ŸŽฏ Detected GPT-OSS model - freezing router parameters") + freeze_router_params(m) + # For GPT-OSS, we need to use the original parameters so we can properly + # freeze the router parameters. + fsdp_should_use_orig_params = True + + # Mini_trainer approach: simplified setup + # No complex calculations needed - the data loader handles everything + packing_max_batch_len = args.max_batch_len + + # Mini_trainer approach: use effective_batch_size as the data loader batch_size + # Let the collator handle distribution across GPUs and dynamic minibatching + batch_size = args.effective_batch_size + + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0 + train_loader = get_data_loader( + data_path=args.data_path, + batch_size=batch_size, + max_tokens_per_gpu=packing_max_batch_len, seed=args.seed, + rank=dist.get_rank(), + world_size=dist.get_world_size(), + num_workers=8, # I don't like this but am setting it for consistency + flash_enabled=flash_enabled, + pad_token_id=pad_token_id, ) - if len(train_loader) == 0: - # this happens sometimes when we have more GPUs than data to process. In this case - # we should either alert the user to switch samplers, or do it automatically and - # warn them about it happening - logger.warning( - "The dataset is too small for multipack to distribute all of the samples across GPUs. Falling back to the distributed sampler!" - ) - args.sampler = "distributed" - train_loader = setup_dataloader( - dataset, - tokenizer.pad_token_id, - num_workers=8, - flash_enabled=flash_enabled, - max_batch_len=args.max_batch_len, - packing_max_batch_len=packing_max_batch_len, - samples_per_gpu=args.samples_per_gpu, - sampler=args.sampler, - seed=args.seed, - ) if args.local_rank == 0: metric_logger.info( { - "num_gpus": torch.distributed.get_world_size(), - "avg_sample_len": dataset.get_lengths().mean(), + "num_gpus": dist.get_world_size(), + "avg_sample_len": train_loader.dataset.get_lengths().mean(), "effective_batch_size": args.effective_batch_size, "max_batch_len_per_gpu": args.max_batch_len, "packing_max_batch_len": packing_max_batch_len, - "grad_accum": grad_accum, + # grad_accum will be determined dynamically per batch "num_batches": len(train_loader), - "avg_samples_per_batch": len(dataset) / len(train_loader), - "samples_per_gpu": args.samples_per_gpu, - "total_samples": len(dataset), # emit the total number of samples + "avg_samples_per_batch": len(train_loader.dataset) / len(train_loader), + "total_samples": len( + train_loader.dataset + ), # emit the total number of samples }, extra={"hparams": True}, ) # accelerator does not need optimizer to init, in fact, the optimizer needs to be initialized AFTER the Accelerator + # Required for DeepSpeed/FSDP configuration even though mini_trainer uses dynamic approach + # These set up the distributed framework but actual batching is handled dynamically + samples_per_gpu = 1 # Base unit - actual samples determined by data loader + grad_accum = 1 # Base unit - actual accumulation handled by minibatch loop + accelerator = Accelerator( model=m, - samples_per_gpu=args.samples_per_gpu, + samples_per_gpu=samples_per_gpu, grad_accum=grad_accum, train_loader=train_loader, distributed_framework=DistributedBackend(args.distributed_training_framework), @@ -457,6 +409,7 @@ def main(args): deepspeed_cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio, fsdp_cpu_offload_params=args.cpu_offload_params_fsdp, save_samples=args.save_samples, + fsdp_use_orig_params=fsdp_should_use_orig_params, ) # optimizer needs model that has been prepared by accelerator # and then accelerator needs to be prepared AGAIN once optimizer is initialized @@ -481,12 +434,11 @@ def main(args): train( args, model=m, - optimizer=optimizer, accelerator=accelerator, ) - torch.distributed.barrier() - torch.distributed.destroy_process_group() + dist.barrier() + dist.destroy_process_group() # public API diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py index 24eac063..de863e1d 100644 --- a/src/instructlab/training/model.py +++ b/src/instructlab/training/model.py @@ -30,7 +30,12 @@ # Third Party from peft import LoraConfig from torch.optim import AdamW -from transformers import PreTrainedTokenizer +from transformers import Mxfp4Config # pylint: disable=no-name-in-module +from transformers import ( + AutoModelForCausalLM, + BitsAndBytesConfig, + PreTrainedTokenizer, +) import torch # First Party @@ -38,6 +43,8 @@ DistributedBackend, Optimizer, ) +from instructlab.training.gpt_oss_utils_correct import is_gpt_oss +from instructlab.training.type_definitions import ModelInputs, ModelLosses class Model: @@ -55,12 +62,19 @@ def __init__( self.noise_alpha = noise_alpha self.tokenizer = tokenizer self.distributed_framework = distributed_framework - bnb_config = None - if lora_config and lora_config.r > 0 and lora_quant_bits == 4: + quant_config = None + + # check model type & set on the mclasss + self.is_gpt_oss = is_gpt_oss(model_path) + if self.is_gpt_oss: # Third Party - from transformers import BitsAndBytesConfig + quant_config = Mxfp4Config(dequantize=True) - bnb_config = BitsAndBytesConfig( + # TODO: Add support for 8bit quantization + elif lora_config and lora_config.r > 0 and lora_quant_bits == 4: + # Third Party + + quant_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, @@ -69,12 +83,22 @@ def __init__( self.base_model_args = { "pretrained_model_name_or_path": model_path, - "torch_dtype": torch.bfloat16, - "quantization_config": bnb_config, + "quantization_config": quant_config, } + # load GPT-OSS in bfloat16 because it's a massive model, but otherwise + # we do not specify the default dtype so mixed precision training works + # correctly + if self.is_gpt_oss: + self.base_model_args["torch_dtype"] = torch.bfloat16 + + # set flash attention accordingly if flash_enabled: self.base_model_args["attn_implementation"] = "flash_attention_2" + if self.is_gpt_oss: + self.base_model_args["attn_implementation"] = ( + "kernels-community/vllm-flash-attn3" + ) def _post_model_init(self): """Common initialization steps that should happen after model initialization.""" @@ -308,9 +332,8 @@ def reconcile_tokenizer(self): ) # Local - from .utils import add_noisy_embeddings, convert_loss_to_reduce_sum + from .utils import add_noisy_embeddings - self.model = convert_loss_to_reduce_sum(self.model) self.model = add_noisy_embeddings(self.model, noise_alpha=self.noise_alpha) @staticmethod @@ -345,6 +368,73 @@ def check_flash_attn_enabled(disable_flash_attn: bool) -> bool: "ERROR: Trying to use Flash Attention on unsupported hardware. Please set disable_flash_attn to True." ) + def compute_loss( + self, inputs: ModelInputs, world_size: int, samples_in_batch: int + ) -> tuple[torch.Tensor, ModelLosses]: + """ + Computes the cross-entropy los for the given model and returns a final loss which can be backproped on, as well as + a dataclass with the raw losses. + + Args: + inputs (ModelInputs): dict containing the items we would need in order to make the forward computation + world_size (int): number of total nodes + samples_in_batch (int): samples in the entire mini-batch + + Returns: + tuple[torch.Tensor, ModelLosses]: + - The total loss which can be used to backprop + - Dataclass containing the raw pre-scaled losses + """ + # Forward pass to get logits + output = self( + **inputs, + use_cache=False, + ) + + # Manual loss computation with reduction="none" following mini_trainer's exact approach + + # Manually compute cross-entropy loss with reduction="none" + logits = output.logits + labels = inputs["labels"] + + # Shift logits and labels for causal LM (standard approach) + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten tokens + shift_logits = shift_logits.view(-1, shift_logits.size(-1)) + shift_labels = shift_labels.view(-1) + + # Compute loss with reduction="none" to get per-token losses + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + token_losses = loss_fct(shift_logits, shift_labels) + + # Only sum losses for non-padding tokens (labels != -100) + valid_tokens = shift_labels != -100 + + # compute losses + aux_loss = None + primary_loss = token_losses[valid_tokens].sum() + + # add the MoE auxiliary loss (currently we only support this for GPT-OSS) + if ( + self.is_gpt_oss + and hasattr(output, "aux_loss") + and output.aux_loss is not None + ): + # For GPT-OSS: separate main loss and aux loss + aux_loss = output.aux_loss.float() # Auxiliary loss stays as scalar + + # Scale main loss following mini_trainer approach: loss * world_size / batch_num_loss_counted_tokens + scaled_main_loss = primary_loss * world_size / samples_in_batch + + # For GPT-OSS: add unscaled auxiliary loss after scaling main loss + if self.is_gpt_oss and aux_loss is not None: + scaled_main_loss += aux_loss + + raw_losses = ModelLosses(main_loss=primary_loss, aux_loss=aux_loss) + return scaled_main_loss, raw_losses + class LigerModel(Model): # pylint: disable=unused-argument @@ -410,9 +500,6 @@ def __init__( lora_config=lora_config, lora_quant_bits=lora_quant_bits, ) - # Third Party - from transformers import AutoModelForCausalLM - self.model = AutoModelForCausalLM.from_pretrained(**self.base_model_args) self._post_model_init() self.model.gradient_checkpointing_enable() @@ -455,8 +542,21 @@ def setup_optimizer( optimizer_cls = DeepSpeedCPUAdam else: optimizer_cls = FusedAdam + + # Filter parameters to only include those that require gradients + # This handles cases where some parameters (e.g., frozen router params) have requires_grad=False + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + + # Count trainable parameters for logging + total_params = sum(1 for _ in model.parameters()) + trainable_count = sum(1 for p in model.parameters() if p.requires_grad) + if total_params != trainable_count: + logger.info( + f"๐Ÿ“Š Using {trainable_count}/{total_params} trainable parameters in optimizer" + ) + factory = functools.partial( - optimizer_cls, model.parameters(), lr=learning_rate, betas=betas + optimizer_cls, trainable_params, lr=learning_rate, betas=betas ) if optimizer_cls is AdamW: return factory(weight_decay=0.0) diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py deleted file mode 100644 index 6b9a4941..00000000 --- a/src/instructlab/training/multipack_sampler.py +++ /dev/null @@ -1,430 +0,0 @@ -""" -MIT License - -Copyright (c) 2023 One - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -taken from https://github.com/imoneoi/multipack_sampler -""" - -# Standard -from typing import List, Optional -import warnings - -# Third Party -from torch.utils.data import Sampler -import numba -import numpy as np -import torch -import torch.distributed as dist - - -def find_max_pack_len_with_padding( - dataset, - samples_per_minibatch, - num_gpus, - avg_sample_len, - seed, -): - """ - This function calculates the maximum batch length with padding for a given dataset. it uses a binary search to find the optimal addition to the average sample length that will result in the average batch size per minibatch being less than or equal to the number of samples per minibatch. - - Parameters: - - dataset: The dataset for which the maximum batch length is to be calculated. - - samples_per_minibatch: The number of samples per minibatch. - - num_gpus: The number of GPUs available for computation. - - avg_sample_len: The average length of a sample in the dataset. - - seed: The seed for the random number generator. - - Returns: - - The maximum batch length with padding for the given dataset. - """ - - def get_effective_samples_per_minibatch(num_tokens_per_gpu): - """ - This nested function calculates the effective number of samples per minibatch for a given number of tokens per GPU. - - Parameters: - - num_tokens_per_gpu: The number of tokens per GPU. - - Returns: - - The effective number of samples per minibatch. - - The function creates a sampler using the MultipackDistributedBatchSampler class, generates batches using the sampler, and then returns the ratio of the dataset size to the number of batches. - """ - sampler = MultipackDistributedBatchSampler( - batch_max_length=num_tokens_per_gpu, - lengths=dataset.get_lengths(), - num_replicas=torch.distributed.get_world_size(), - rank=torch.distributed.get_rank(), - seed=seed, - padding=True, - ) - batches = sampler.generate_batches() - return len(dataset) / len(batches) - - samples_per_gpu = samples_per_minibatch / num_gpus - - addition = int(avg_sample_len * 0.1 * samples_per_gpu) - packing_max_batch_len = int(avg_sample_len * samples_per_gpu) - - avg_bs_per_minibatch = get_effective_samples_per_minibatch( - packing_max_batch_len + addition - ) - while avg_bs_per_minibatch <= samples_per_minibatch: - addition *= 2 - avg_bs_per_minibatch = get_effective_samples_per_minibatch( - packing_max_batch_len + addition - ) - - l = 0 - r = addition - while r - l > 1: - addition = (l + r) // 2 - avg_bs_per_minibatch = get_effective_samples_per_minibatch( - packing_max_batch_len + addition - ) - - # check if simulation resulted in batch sizes close enough to goal and adjust if needed - if abs(avg_bs_per_minibatch - samples_per_minibatch) <= max( - 10, round(avg_bs_per_minibatch * 0.02) - ): - break - if avg_bs_per_minibatch > samples_per_minibatch: - r = addition - else: - l = addition - - return packing_max_batch_len + addition - - -def find_packing_max_batch_len_and_grad_accum( - num_gpus, - avg_sample_len, - effective_batch_size, - max_batch_len_per_gpu, - is_padding, - dataset, - seed, -): - """ - Calculate the minimum gradient accumulation steps required and the corresponding maximum batch length. - - This function determines the minimum number of gradient accumulation steps needed to process the - effective batch size within the constraints of the maximum batch length per GPU. It starts with - the assumption of a single step (no accumulation) and increases the number of steps until the - calculated batch length does not exceed the maximum allowed per GPU. The goal is to find the - lowest gradient accumulation that allows fitting the batch within GPU limits, ensuring efficient - utilization of computational resources. - - Parameters: - - num_gpus (int): The number of GPUs over which the batch is distributed. - - avg_sample_len (int): The average length of samples in the dataset, used to estimate batch length. - - effective_batch_size (int): The total batch size intended to be processed across all GPUs and - accumulation steps. - - max_batch_len_per_gpu (int): The maximum permissible number of tokens on each GPU to avoid memory overflow. - - Returns: - - Tuple[int, int]: A tuple where the first element is the maximum batch length that can be achieved - without exceeding the per-GPU limit, and the second element is the minimum number of gradient - accumulation steps required to maintain the effective batch size. - """ - - packing_max_batch_len = max_batch_len_per_gpu + 1 - grad_accum = 0 - while packing_max_batch_len > max_batch_len_per_gpu: - grad_accum += 1 - samples_per_minibatch = effective_batch_size / grad_accum - samples_per_gpu = samples_per_minibatch / num_gpus - if int(avg_sample_len * samples_per_gpu) < dataset.get_lengths().max(): - raise RuntimeError( - f"Effective batch size is too low for multipack sampling, max sample length={dataset.get_lengths().max()} and min packing length={int(avg_sample_len * samples_per_gpu)}. " - "Switching to naive distributed sampling." - ) - if is_padding: - packing_max_batch_len = find_max_pack_len_with_padding( - dataset, - samples_per_minibatch, - num_gpus, - avg_sample_len, - seed, - ) - else: - packing_max_batch_len = int((avg_sample_len) * samples_per_gpu) - - return packing_max_batch_len, grad_accum - - -@numba.njit -def ffd_check(a: np.ndarray, c: int, n: int): - # First-fit-decreasing bin packing - # Check if a[] could fit in n bins with capacity c - # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing - - a = np.sort(a)[::-1] - bins = np.full((n,), c, dtype=a.dtype) - for size in a: - not_found = True - for idx in range(n): - if bins[idx] >= size: - bins[idx] -= size - not_found = False - break - - if not_found: - return False - - return True - - -@numba.njit -def ffd_check_padding(a: np.ndarray, c: int, n: int): - # First-fit-decreasing bin packing - # Check if a[] could fit in n bins with capacity c - # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing - - a = np.sort(a)[::-1] - bins_max_lengths = np.zeros( - (n,), dtype=a.dtype - ) # Track the maximum length in each bin - bins_num_samples = np.zeros( - (n,), dtype=np.int_ - ) # Track the number of samples in each bin - - for size in a: - not_found = True - for idx in range(n): - # Calculate the new capacity if size is added to the bin - new_capacity = max(bins_max_lengths[idx], size) * ( - bins_num_samples[idx] + 1 - ) - if new_capacity <= c: - bins_max_lengths[idx] = max(bins_max_lengths[idx], size) - bins_num_samples[idx] += 1 - not_found = False - break - - if not_found: - return False - - return True - - -@numba.njit -def ffd_with_result(a: np.ndarray, c: int, start_index: int): - # First-fit-decreasing bin packing (with result return) - - indices = np.argsort(a)[::-1] - a = a[indices] - - bins = [] - bins_result = [] - for a_id, size in enumerate(a): - add_new = True - for idx in range(len(bins)): - if bins[idx] >= size: - bins[idx] -= size - bins_result[idx].append(indices[a_id] + start_index) - add_new = False - break - - if add_new: - bins.append(c - size) - bins_result.append([indices[a_id] + start_index]) - - return bins_result - - -@numba.njit -def ffd_with_result_padding(a: np.ndarray, c: int, start_index: int): - # First-fit-decreasing bin packing (with result return) - - indices = np.argsort(a)[::-1] - a = a[indices] - - bins_max_lengths = [] # Track the maximum length in each bin - bins_num_samples = [] # Track the number of samples in each bin - bins_result = [] # Track the indices of the samples in each bin - - for a_id, size in enumerate(a): - add_new = True - for idx in range(len(bins_max_lengths)): - # Calculate the new capacity if size is added to the bin - new_capacity = max(bins_max_lengths[idx], size) * ( - bins_num_samples[idx] + 1 - ) - if new_capacity <= c: - bins_max_lengths[idx] = max(bins_max_lengths[idx], size) - bins_num_samples[idx] += 1 - bins_result[idx].append(indices[a_id] + start_index) - add_new = False - break - - if add_new: - bins_max_lengths.append(size) - bins_num_samples.append(1) - bins_result.append([indices[a_id] + start_index]) - - return bins_result - - -@numba.njit -def allocate( - lengths: np.ndarray, - lengths_cumsum: np.ndarray, - rank: int, - c: int, - n: int, - padding: bool = True, -): - # Dynamic batch allocator, similar to Multifit - # https://en.wikipedia.org/wiki/Multifit_algorithm - # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) - - s = 0 - start_index = 0 - result = [] - - while True: - # binary search [l, r) - l = 1 - r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") - - while r - l > 1: - m = (l + r) // 2 - if padding: - check = ffd_check_padding(lengths[start_index : start_index + m], c, n) - else: - check = ffd_check(lengths[start_index : start_index + m], c, n) - if check: - l = m - else: - r = m - - # use length l - if padding: - batch = ffd_with_result_padding( - lengths[start_index : start_index + l], c, start_index - ) - else: - batch = ffd_with_result( - lengths[start_index : start_index + l], c, start_index - ) - assert len(batch) <= n - if len(batch) < n: - break - - start_index += l - s = lengths_cumsum[start_index - 1] - - # add local rank - result.append(batch[rank]) - - return result, s, len(result) * c * n - - -class MultipackDistributedBatchSampler(Sampler): - """Unpadded length sampling using Multipack. - Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard. - """ - - def __init__( - self, - batch_max_length: int, - lengths: List[int], - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - seed: int = 0, - padding: bool = True, - ): - # Get rank - if num_replicas is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available") - num_replicas = dist.get_world_size() - if rank is None: - if not dist.is_available(): - raise RuntimeError("Requires distributed package to be available") - rank = dist.get_rank() - - self.num_replicas = num_replicas - self.rank = rank - self.seed = seed - - self.batch_max_length = batch_max_length - self.lengths = lengths - assert isinstance(self.lengths, np.ndarray) - - self.epoch = 0 - - # statistics - self.eff_total_used = 0 - self.eff_total_slots = 0 - self.padding = padding - - def set_epoch(self, epoch: int): - self.epoch = epoch - - def generate_batches(self, set_stats=False): - indices = np.random.default_rng(seed=self.seed + self.epoch).permutation( - len(self.lengths) - ) - - # remove indices where the entries are longer than batch max length - indices = indices[self.lengths[indices] <= self.batch_max_length] - if len(indices) < len(self.lengths): - warnings.warn( - "Dropping %d samples longer than batch_max_length. Ensure that the right max_batch_length is used during data processing.", - len(self.lengths) - len(indices), - ) - - lengths = self.lengths[indices] - lengths_cumsum = np.cumsum(lengths) - - batches, total_used, total_slots = allocate( - lengths=lengths, - lengths_cumsum=lengths_cumsum, - rank=self.rank, - c=self.batch_max_length, - n=self.num_replicas, - padding=self.padding, - ) - - batches = [indices[batch] for batch in batches] - - # statistics - if set_stats: - self.eff_total_used += total_used - self.eff_total_slots += total_slots - - return batches - - def __iter__(self): - batches = self.generate_batches(set_stats=True) - return iter(batches) - - def __len__(self): - return self.num_batches() - - def num_batches(self): - batches = self.generate_batches() - return len(batches) - - def efficiency(self): - return self.eff_total_used / self.eff_total_slots diff --git a/src/instructlab/training/padded_batch_packer.py b/src/instructlab/training/padded_batch_packer.py new file mode 100644 index 00000000..3cad913b --- /dev/null +++ b/src/instructlab/training/padded_batch_packer.py @@ -0,0 +1,288 @@ +""" +Numba-optimized batch packing for padded training (non-flash attention). + +This module provides high-performance batch packing that minimizes padding +while maintaining good load balance across distributed training ranks. +""" + +# Third Party +from numba import int64, njit +import numpy as np + + +@njit +def _compute_padded_tokens(lengths: np.ndarray, start: int64, end: int64) -> int64: + """Compute total tokens including padding for a batch.""" + if start >= end: + return 0 + max_len = lengths[start] # Since sorted descending, first is max + return max_len * (end - start) + + +@njit +def _compute_wasted_tokens(lengths: np.ndarray, start: int64, end: int64) -> int64: + """Compute wasted tokens (padding) for a batch.""" + if start >= end: + return 0 + max_len = lengths[start] + total_actual = 0 + for i in range(start, end): + total_actual += lengths[i] + return max_len * (end - start) - total_actual + + +@njit +def _find_optimal_batch_size( + lengths: np.ndarray, start_idx: int64, max_tokens: int64, max_batch_size: int64 +) -> int64: + """Find optimal batch size that minimizes padding ratio while fitting constraints.""" + n = len(lengths) + if start_idx >= n: + return 0 + + # Maximum possible batch size given token constraint + max_len = lengths[start_idx] + max_sequences = ( + min(max_tokens // max_len, max_batch_size) if max_len > 0 else max_batch_size + ) + max_sequences = min(max_sequences, n - start_idx) + + if max_sequences <= 1: + return 1 + + # Find batch size that minimizes padding ratio + best_size = 1 + best_ratio = 1.0 + + for size in range(2, max_sequences + 1): + # Check if this batch size fits within token limit + padded_tokens = _compute_padded_tokens(lengths, start_idx, start_idx + size) + if padded_tokens > max_tokens: + break + + # Compute padding ratio + actual_tokens = 0 + for i in range(start_idx, start_idx + size): + actual_tokens += lengths[i] + + if padded_tokens > 0: + padding_ratio = float(padded_tokens - actual_tokens) / float(padded_tokens) + + # Prefer larger batches if padding ratio is similar (within 5%) + if padding_ratio < best_ratio - 0.05 or ( + abs(padding_ratio - best_ratio) < 0.05 and size > best_size + ): + best_ratio = padding_ratio + best_size = size + + return best_size + + +@njit +def _distribute_batches_balanced( + batch_tokens: np.ndarray, num_ranks: int64, rank: int64 +) -> np.ndarray: + """Distribute batches across ranks to balance total padded tokens.""" + n_batches = len(batch_tokens) + if n_batches == 0: + return np.empty(0, dtype=np.int64) + + # Compute cumulative load for each rank using round-robin with load balancing + rank_loads = np.zeros(num_ranks, dtype=np.int64) + batch_assignments = np.empty(n_batches, dtype=np.int64) + + # Sort batches by size (largest first) for better load balancing + sorted_indices = np.argsort(batch_tokens)[::-1] + + for i in range(n_batches): + batch_idx = sorted_indices[i] + # Assign to least loaded rank + min_rank = 0 + min_load = rank_loads[0] + for r in range(1, num_ranks): + if rank_loads[r] < min_load: + min_load = rank_loads[r] + min_rank = r + + batch_assignments[batch_idx] = min_rank + rank_loads[min_rank] += batch_tokens[batch_idx] + + # Return indices for this rank + my_batches = [] + for i in range(n_batches): + if batch_assignments[i] == rank: + my_batches.append(i) + + return np.array(my_batches, dtype=np.int64) + + +@njit +def _padded_batch_packing_core( + lengths: np.ndarray, + max_tokens: int64, + num_ranks: int64, + rank: int64, + max_batch_size: int64, +) -> tuple: + """Core numba-optimized batch packing for padded training. + + Returns: + batch_indices: 2D array where each row is a batch of sequence indices + batch_sizes: number of sequences in each batch + """ + n_sequences = len(lengths) + if n_sequences == 0: + return np.empty((0, 0), dtype=np.int64), np.empty(0, dtype=np.int64) + + # Sort by length descending for better packing + sorted_indices = np.argsort(lengths)[::-1] + sorted_lengths = lengths[sorted_indices] + + # First pass: create all batches + max_batches = n_sequences # Worst case: one sequence per batch + temp_batch_starts = np.zeros(max_batches, dtype=np.int64) + temp_batch_sizes = np.zeros(max_batches, dtype=np.int64) + temp_batch_tokens = np.zeros(max_batches, dtype=np.int64) + n_batches = 0 + + start_idx = 0 + while start_idx < n_sequences: + # Find optimal batch size for current position + batch_size = _find_optimal_batch_size( + sorted_lengths, start_idx, max_tokens, max_batch_size + ) + + if batch_size == 0: + break + + temp_batch_starts[n_batches] = start_idx + temp_batch_sizes[n_batches] = batch_size + temp_batch_tokens[n_batches] = _compute_padded_tokens( + sorted_lengths, start_idx, start_idx + batch_size + ) + + n_batches += 1 + start_idx += batch_size + + # Distribute batches across ranks + batch_tokens = temp_batch_tokens[:n_batches] + my_batch_indices = _distribute_batches_balanced(batch_tokens, num_ranks, rank) + + if len(my_batch_indices) == 0: + # Return single batch with padding indicator + result_indices = np.full((1, 1), -1, dtype=np.int64) + result_sizes = np.ones(1, dtype=np.int64) + return result_indices, result_sizes + + # Build result for this rank + n_my_batches = len(my_batch_indices) + max_batch_len = 0 + for i in range(n_my_batches): + batch_idx = my_batch_indices[i] + size = temp_batch_sizes[batch_idx] + max_batch_len = max(max_batch_len, size) + + result_indices = np.full((n_my_batches, max_batch_len), -1, dtype=np.int64) + result_sizes = np.zeros(n_my_batches, dtype=np.int64) + + for i in range(n_my_batches): + batch_idx = my_batch_indices[i] + start = temp_batch_starts[batch_idx] + size = temp_batch_sizes[batch_idx] + + for j in range(size): + result_indices[i, j] = sorted_indices[start + j] + result_sizes[i] = size + + return result_indices, result_sizes + + +def batch_lengths_to_minibatches_padded( + batch_lengths: list[int], + max_tokens_per_rank: int, + num_ranks: int, + rank: int, + max_batch_size: int = 64, +) -> list[list[int]]: + """Batch packing optimized for padded training (non-flash attention). + + Groups sequences to minimize total padding while maintaining good load + balance across ranks. Sequences within each batch are padded to the + length of the longest sequence in that batch. + + Args: + batch_lengths: List of sequence lengths (in tokens) + max_tokens_per_rank: Maximum tokens allowed per rank per batch (including padding) + num_ranks: Total number of distributed training ranks (GPUs) + rank: The specific rank to retrieve assigned indices for + max_batch_size: Maximum sequences per batch (default 64) + + Returns: + List of lists, where each inner list contains indices assigned to this rank + for one batch. Index -1 indicates padding/placeholder. + """ + if not batch_lengths: + return [] + + # Convert to numpy + lengths = np.array(batch_lengths, dtype=np.int64) + + # Call numba-optimized core + batch_indices, batch_sizes = _padded_batch_packing_core( + lengths, max_tokens_per_rank, num_ranks, rank, max_batch_size + ) + + # Convert to list format + result = [] + for i in range(len(batch_sizes)): + size = batch_sizes[i] + if batch_indices[i, 0] == -1: + result.append([-1]) + else: + result.append(batch_indices[i, :size].tolist()) + + return result + + +def compute_padding_stats(batch_lengths: list[int], batches: list[list[int]]) -> dict: + """Compute padding statistics for given batches. + + Args: + batch_lengths: Original sequence lengths + batches: List of batches (each batch is a list of indices) + + Returns: + Dictionary with padding statistics + """ + total_actual_tokens = 0 + total_padded_tokens = 0 + total_sequences = 0 + + for batch in batches: + if not batch or batch[0] == -1: + continue + + # Find max length in this batch + max_len = max(batch_lengths[idx] for idx in batch) + + # Compute tokens + for idx in batch: + actual_len = batch_lengths[idx] + total_actual_tokens += actual_len + total_padded_tokens += max_len + total_sequences += 1 + + padding_ratio = 0.0 + if total_padded_tokens > 0: + padding_ratio = ( + total_padded_tokens - total_actual_tokens + ) / total_padded_tokens + + return { + "total_sequences": total_sequences, + "total_actual_tokens": total_actual_tokens, + "total_padded_tokens": total_padded_tokens, + "total_padding_tokens": total_padded_tokens - total_actual_tokens, + "padding_ratio": padding_ratio, + "num_batches": len([b for b in batches if b and b[0] != -1]), + } diff --git a/src/instructlab/training/sampler.py b/src/instructlab/training/sampler.py new file mode 100644 index 00000000..be2b3f4b --- /dev/null +++ b/src/instructlab/training/sampler.py @@ -0,0 +1,388 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +from typing import Optional + +# Third Party +from datasets import load_dataset +from torch.utils.data import DataLoader, Dataset, Sampler +import numpy as np +import torch +import torch.distributed as dist + +# First Party +from instructlab.training.batch_packer import batch_lengths_to_minibatches_lpt +from instructlab.training.padded_batch_packer import ( + batch_lengths_to_minibatches_padded, +) +from instructlab.training.type_definitions import CollatedItem + + +class EpochSampler(Sampler): + """ + Epoch-based sampler that provides shuffled data indices for each epoch. + Replaces the naive distributed sampler with reproducible epoch-based shuffling. + """ + + def __init__(self, len_data: int, seed: int = 67, epoch: int = 0): + self.len_data = len_data + self.seed = seed + self._epoch = epoch + + @property + def epoch(self) -> int: + return self._epoch + + def set_epoch(self, epoch: int): + self._epoch = epoch + + def generate_samples(self): + g = torch.Generator() + g.manual_seed(self.seed + self._epoch) + samples = torch.randperm(self.len_data, generator=g).tolist() + return samples + + def __iter__(self): + samples = self.generate_samples() + yield from samples + + def __len__(self): + return self.len_data + + +def mb_collate_fn(minibatch, batch_num_loss_counted_tokens) -> CollatedItem: + """Collates a list of samples into a single packed batch for Flash Attention. + + This function takes a 'minibatch' (list of pre-fetched dataset samples) + and concatenates their 'input_ids', 'labels', and generates corresponding + 'position_ids'. It does *not* add padding. + + The resulting batch format is 'packed' or 'unpadded', where multiple sequences + are concatenated into single tensors. Sequence boundaries are implicitly defined + by the 'position_ids', which restart from 0 for each concatenated sequence. + + **IMPORTANT**: This format requires the downstream model's attention mechanism + (e.g., Flash Attention) to correctly handle packed sequences. Standard attention + implementations may not work correctly as they expect padded inputs and explicit + attention masks. Flash Attention typically uses mechanisms like `cu_seqlens` + (cumulative sequence lengths), derived from position IDs or sequence lengths, + to compute the correct block-diagonal attention implicitly. + + Args: + minibatch: A list of dictionaries, where each dictionary represents a + sample and contains at least 'input_ids' and 'labels'. + batch_num_loss_counted_tokens: Total number of loss-counted tokens in the batch. + + Returns: + A dictionary containing the collated batch: + - 'input_ids': Single tensor of concatenated input IDs. + - 'labels': Single tensor of concatenated labels. + - 'position_ids': Single tensor of position IDs, reset for each sequence. + - 'num_loss_counted_tokens': Total number of non-ignored label tokens (-100). + - 'num_samples': The number of sequences packed into this batch. + """ + input_ids = [] + labels = [] + position_ids = [] + total_len = 0 + num_loss_counted_tokens = 0 + num_samples = 0 + + for item in minibatch: + item_len = len(item["input_ids"]) + + input_ids.extend(item["input_ids"]) + labels.extend(item["labels"]) + position_ids.extend(range(item_len)) + + total_len += item_len + num_loss_counted_tokens += item["num_loss_counted_tokens"] + + # Dummy samples don't have labels != -100 and should not count + num_samples += 1 if item["num_loss_counted_tokens"] > 0 else 0 + + return { + "input_ids": torch.tensor([input_ids], dtype=torch.long), + "labels": torch.tensor([labels], dtype=torch.long), + "position_ids": torch.tensor([position_ids], dtype=torch.long), + "num_loss_counted_tokens": num_loss_counted_tokens, + "num_samples": num_samples, + "batch_num_loss_counted_tokens": batch_num_loss_counted_tokens, + "total_length": total_len, # Total tokens in the batch + } + + +def padded_mb_collate_fn( + minibatch, batch_num_loss_counted_tokens, pad_token_id=0 +) -> CollatedItem: + """Collates a list of samples into a padded batch for standard attention. + + This function takes a minibatch (list of dataset samples) and creates padded + tensors suitable for standard attention mechanisms. Unlike the flash attention + version, this pads all sequences to the same length and creates attention masks. + + Args: + minibatch: A list of dictionaries, where each dictionary represents a + sample and contains 'input_ids' and 'labels'. + batch_num_loss_counted_tokens: Total number of loss-counted tokens in the batch. + + Returns: + A dictionary containing the collated batch: + - 'input_ids': 2D tensor of padded input IDs [batch_size, max_len] + - 'labels': 2D tensor of padded labels [batch_size, max_len] + - 'attention_mask': 2D tensor indicating real vs padding tokens + - 'num_loss_counted_tokens': Total number of non-ignored label tokens + - 'num_samples': The number of real sequences in this batch + """ + if not minibatch: + # Return empty batch + return { + "input_ids": torch.tensor([[]], dtype=torch.long), + "labels": torch.tensor([[]], dtype=torch.long), + "attention_mask": torch.tensor([[]], dtype=torch.long), + "num_loss_counted_tokens": 0, + "num_samples": 0, + "batch_num_loss_counted_tokens": batch_num_loss_counted_tokens, + "total_length": 0, + } + + # Find max length in this batch + max_len = max(len(item["input_ids"]) for item in minibatch) + + # Prepare lists for batched tensors + padded_input_ids = [] + padded_labels = [] + attention_masks = [] + num_loss_counted_tokens = 0 + num_samples = 0 + + for item in minibatch: + item_len = len(item["input_ids"]) + + # Pad input_ids with the provided pad_token_id + pad_length = max_len - item_len + input_ids = item["input_ids"] + if isinstance(input_ids, torch.Tensor): + input_ids = input_ids.tolist() + padded_input = input_ids + [pad_token_id] * pad_length + padded_input_ids.append(padded_input) + + # Pad labels with -100 (ignore index) + labels = item["labels"] + if isinstance(labels, torch.Tensor): + labels = labels.tolist() + padded_label = labels + [-100] * pad_length + padded_labels.append(padded_label) + + # Create attention mask (1 for real tokens, 0 for padding) + attention_mask = [1] * item_len + [0] * pad_length + attention_masks.append(attention_mask) + + # Count loss tokens and samples + num_loss_counted_tokens += item["num_loss_counted_tokens"] + # Only count as a sample if it has loss-counted tokens + if item["num_loss_counted_tokens"] > 0: + num_samples += 1 + + return { + "input_ids": torch.tensor(padded_input_ids, dtype=torch.long), + "labels": torch.tensor(padded_labels, dtype=torch.long), + "attention_mask": torch.tensor(attention_masks, dtype=torch.long), + "num_loss_counted_tokens": num_loss_counted_tokens, + "num_samples": num_samples, + "batch_num_loss_counted_tokens": batch_num_loss_counted_tokens, + "total_length": max_len * len(minibatch), # Total padded tokens + } + + +class MaxTokensPerRankCollator: + """A unified collate function for PyTorch DataLoader for distributed training. + + This collator supports both flash attention (unpadded) and standard attention (padded) modes. + It takes a batch of samples and: + 1. Filters out samples longer than `max_tokens_per_rank`. + 2. Uses the appropriate batch packing algorithm to distribute samples across ranks. + 3. Collates samples into the format required by the model. + + Args: + max_tokens_per_rank (int): Maximum number of tokens allowed per rank in a minibatch. + rank (int, optional): The rank of the current process. If None, uses torch.distributed. + world_size (int, optional): Total number of ranks. If None, uses torch.distributed. + dummy_sample (dict, optional): A sample used for padding when a rank has no real samples. + flash_enabled (bool): Whether to use flash attention mode (default: True). + pad_token_id (int): Token ID to use for padding in non-flash mode (default: 0). + """ + + def __init__( + self, + max_tokens_per_rank: int, + rank: Optional[int] = None, + world_size: Optional[int] = None, + dummy_sample=None, + flash_enabled: bool = True, + pad_token_id: int = 0, + ): + self.max_tokens_per_rank = max_tokens_per_rank + self.flash_enabled = flash_enabled + self.pad_token_id = pad_token_id + + self.global_rank = rank if rank is not None else dist.get_rank() + self.world_size = ( + world_size if world_size is not None else dist.get_world_size() + ) + + if dummy_sample is None: + dummy_sample = { + "input_ids": torch.tensor([15, 14, 13, 12, 11], dtype=torch.long), + "labels": torch.tensor( + [-100, -100, -100, -100, -100], dtype=torch.long + ), + "len": 5, + "num_loss_counted_tokens": 0, + } + self.dummy_sample = dummy_sample + + # Select the appropriate batch packer and collate function + if flash_enabled: + self.batch_packer = batch_lengths_to_minibatches_lpt + self.collate_fn = mb_collate_fn + else: + self.batch_packer = batch_lengths_to_minibatches_padded + # Create a wrapper for padded collate that includes pad_token_id + self.collate_fn = ( + lambda minibatch, batch_num_loss_counted_tokens: padded_mb_collate_fn( + minibatch, batch_num_loss_counted_tokens, pad_token_id + ) + ) + + def __call__(self, batch: list[dict]): + """Processes a batch of samples into minibatches for the current rank. + + Args: + batch: A list of sample dictionaries from the Dataset. + + Returns: + A list where each element is a collated minibatch ready for processing. + """ + # Filter out samples longer than max_tokens_per_rank + batch_ = [b for b in batch if b["len"] <= self.max_tokens_per_rank] + if len(batch_) < len(batch): + print( + f"\033[38;5;196mremoved {len(batch) - len(batch_)} samples from batch because they are longer than the max tokens per gpu\033[0m" + ) + + # Extract lengths and count loss tokens + batch_lengths = [sample["len"] for sample in batch_] + batch_num_loss_counted_tokens = sum( + [sample["num_loss_counted_tokens"] for sample in batch_] + ) + + # Use the appropriate batch packer + all_minibatches_indices = self.batch_packer( + batch_lengths, self.max_tokens_per_rank, self.world_size, self.global_rank + ) + + # Collate minibatches + all_minibatches = [] + for mb_indices in all_minibatches_indices: + mb = [batch_[i] if i != -1 else self.dummy_sample for i in mb_indices] + all_minibatches.append(self.collate_fn(mb, batch_num_loss_counted_tokens)) + + return all_minibatches + + +class TokenDataset(Dataset): + """Dataset for loading tokenized data from JSONL files. + + Handles both InstructLab format and mini_trainer format data. + """ + + def __init__(self, data_path: str): + dataset = load_dataset("json", data_files=data_path, split="train") + self.dataset = dataset + + # Compute lengths if not present + if "len" not in self.dataset.column_names: + self.lengths = np.array( + self.dataset.map( + lambda x: {"len": len(x["input_ids"])}, + num_proc=8, + )["len"] + ) + else: + self.lengths = np.array(self.dataset["len"]) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index: int): + sample = self.dataset[int(index)] + + # Calculate num_loss_counted_tokens if not present + if (loss_counted_tokens := sample.get("num_loss_counted_tokens", None)) is None: + loss_counted_tokens = sum( + 1 if label != -100 else 0 for label in sample["labels"] + ) + + return { + "input_ids": torch.tensor(sample["input_ids"], dtype=torch.long), + "labels": torch.tensor(sample["labels"], dtype=torch.long), + "len": sample["len"], + "num_loss_counted_tokens": loss_counted_tokens, + } + + def get_lengths(self): + return self.lengths + + +def get_data_loader( + data_path: str, + batch_size: int, + max_tokens_per_gpu: int, + seed: int, + rank: Optional[int] = None, + world_size: Optional[int] = None, + dummy_sample: Optional[dict] = None, + num_workers: int = 0, + flash_enabled: bool = True, + pad_token_id: int = 0, +): + """Create a data loader with epoch-based sampling and batch packing. + + Args: + data_path: Path to the JSONL data file + batch_size: Number of samples to fetch per batch (before packing) + max_tokens_per_gpu: Maximum tokens allowed per GPU + seed: Random seed for sampling + rank: Current process rank + world_size: Total number of processes + dummy_sample: Sample used for padding + num_workers: Number of data loading workers + flash_enabled: Whether flash attention is enabled (affects collation strategy) + pad_token_id: Token ID to use for padding (only used when flash_enabled=False) + + Returns: + DataLoader configured with appropriate collator based on flash_enabled + """ + dataset = TokenDataset(data_path) + sampler = EpochSampler(len(dataset), seed=seed) + + # Create unified collator with appropriate mode + collate_fn = MaxTokensPerRankCollator( + max_tokens_per_gpu, + rank=rank, + world_size=world_size, + dummy_sample=dummy_sample, + flash_enabled=flash_enabled, + pad_token_id=pad_token_id, + ) + + return DataLoader( + dataset, + batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=num_workers, + persistent_workers=(num_workers > 0), + drop_last=False, + ) diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py deleted file mode 100644 index 38b3a6f9..00000000 --- a/src/instructlab/training/token_dataset.py +++ /dev/null @@ -1,138 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -# Standard -import os - -# Third Party -from datasets import load_dataset -from torch.utils.data import DataLoader, Dataset -import numpy as np -import torch - -# First Party -from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler -from instructlab.training.utils import log_rank_0, make_collate_fn - - -class TokenDataset(Dataset): - def __init__(self, data_path): - self.data = load_dataset("json", data_files=data_path, split="train") - if "len" not in self.data.column_names: - self.lengths = np.array( - self.data.map( - lambda x: {"len": len(x["input_ids"])}, - num_proc=8, - )["len"] - ) - else: - self.lengths = np.array(self.data["len"]) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[int(idx)] - input_ids = torch.tensor(item["input_ids"], dtype=torch.long) - labels = torch.tensor(item["labels"], dtype=torch.long) - attention_mask = torch.ones_like(input_ids) - - return { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - } - - def get_lengths(self): - return self.lengths - - -class MockDataset(Dataset): - def __init__(self, max_seq_len=4600): - self.input_ids = np.random.randint( - 0, 10000, size=(92000, max_seq_len), dtype=np.int16 - ) - self.labels = np.random.randint( - 0, 10000, size=(92000, max_seq_len), dtype=np.int16 - ) - - def __len__(self): - return len(self.input_ids) - - def __getitem__(self, idx): - input_ids = torch.tensor(self.input_ids[idx], dtype=torch.long) - labels = torch.tensor(self.labels[idx], dtype=torch.long) - attention_mask = torch.ones_like(input_ids) - - return { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - } - - def get_lengths(self): - return np.array([len(self.input_ids[0])] * len(self.input_ids)) - - -def setup_dataset( - data_path: str, - mock: bool = False, - mock_len: int = 2600, -) -> Dataset: - if mock: - log_rank_0("Using a mock dataset.") - dataset = MockDataset(max_seq_len=mock_len) - else: - dataset = TokenDataset(data_path) - return dataset - - -def setup_dataloader( - dataset: Dataset, - pad_token_id: int, - num_workers: int = 8, - flash_enabled=True, - max_batch_len=60000, - packing_max_batch_len=60000, - samples_per_gpu=None, - sampler="multipack", - seed=47, -) -> DataLoader: - collate_fn = make_collate_fn( - pad_token_id, flash_enabled=flash_enabled, max_batch_len=max_batch_len - ) - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - - lengths = dataset.get_lengths() - if sampler == "multipack": - sampler = MultipackDistributedBatchSampler( - batch_max_length=packing_max_batch_len, - lengths=lengths, - num_replicas=world_size, - rank=rank, - seed=seed, - padding=not flash_enabled, - ) - sampler = {"batch_sampler": sampler} - elif sampler == "distributed": - # Third Party - from torch.utils.data import DistributedSampler - - sampler = ( - DistributedSampler(dataset) if torch.distributed.is_initialized() else None - ) - sampler = { - "sampler": sampler, - "batch_size": samples_per_gpu, - } - else: - raise NotImplementedError - - dataloader = DataLoader( - dataset, - **sampler, - num_workers=num_workers, - collate_fn=collate_fn, - ) - - return dataloader diff --git a/src/instructlab/training/type_definitions.py b/src/instructlab/training/type_definitions.py index 0b8866a3..986c14e9 100644 --- a/src/instructlab/training/type_definitions.py +++ b/src/instructlab/training/type_definitions.py @@ -8,8 +8,13 @@ """ # Standard +from dataclasses import dataclass import typing as t +# Third Party +# Third-party +import torch + # For Python 3.8+ compatibility try: # Standard @@ -41,6 +46,32 @@ class Message(t.TypedDict): reasoning_content: NotRequired[str] +class CollatedItem(t.TypedDict): + """ + Items being returned by the collator function. + """ + + input_ids: Required[torch.Tensor] + labels: Required[torch.Tensor] + position_ids: NotRequired[torch.Tensor] # Only required for flash attention + attention_mask: NotRequired[torch.Tensor] # Required for non-flash attention + num_samples: Required[int] + batch_num_loss_counted_tokens: Required[int] + total_length: Required[int] + num_loss_counted_tokens: Required[int] + + +class ModelInputs(t.TypedDict): + """ + These are the inputs that models will be passed + """ + + input_ids: Required[torch.Tensor] + labels: Required[torch.Tensor] + position_ids: NotRequired[torch.Tensor] + attention_mask: NotRequired[torch.Tensor] # used when not training in padding free + + class ProcessedMessagesData(t.TypedDict): """ This class represents the data generated when a single sample is @@ -50,3 +81,9 @@ class ProcessedMessagesData(t.TypedDict): input_ids: t.List[int] labels: t.List[int] len: int + + +@dataclass +class ModelLosses: + main_loss: torch.Tensor + aux_loss: torch.Tensor | None diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 15dd2897..275a4b7e 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -14,11 +14,13 @@ import subprocess import sys import time +import traceback import warnings # Third Party # pylint: disable=no-name-in-module from accelerate import Accelerator, DistributedType +from peft import LoraConfig, LoraModel from torch import distributed as dist from torch import nn from torch.distributed import get_rank, is_initialized @@ -36,6 +38,11 @@ QuantizeDataType, TrainingArgs, ) +from instructlab.training.gpt_oss_utils_correct import ( + add_gpt_oss_quantization_config, + convert_dequantized_to_quantized_format_correct, + is_gpt_oss, +) from instructlab.training.model import Model logger = logging.getLogger("instructlab.training") @@ -391,8 +398,6 @@ def save_fsdp_lora_model( model (FSDP): FSDP model as prepared by `accelerate.Accelerator` accelerator (Accelerator): The given accelerator object. """ - # Third Party - from peft import LoraConfig, LoraModel if accelerator.distributed_type != DistributedType.FSDP: raise RuntimeError( @@ -437,6 +442,86 @@ def save_fsdp_lora_model( dist.barrier() +def save_fsdp_gpt_oss_model( + model: FSDP, + tokenizer: PreTrainedTokenizer, + accelerator: Accelerator, + output_dir: Path, +): + """Save GPT-OSS model with parameter conversion, following FSDP LoRA pattern.""" + # Local + + if accelerator.distributed_type != DistributedType.FSDP: + raise RuntimeError( + "`save_fsdp_gpt_oss_model` was called when FSDP was not being used." + ) + if not wraps(model, FSDP): + raise RuntimeError( + "`save_fsdp_gpt_oss_model` was called but provided model is not an FSDP model." + ) + + logger.info("Converting GPT-OSS parameters to quantized format for compatibility") + + # Extract state dict with FSDP configuration (same as LoRA) + sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, sd_config): + state = model.state_dict() + + # Convert parameters on main process only (same pattern as LoRA) + if accelerator.is_main_process: + clean_state = OrderedDict() + expert_params_to_process = [] + + for name, param in state.items(): + if ( + "experts." in name + and ("down_proj" in name or "gate_up_proj" in name) + and not name.endswith("_bias") + ): + expert_params_to_process.append((name, param)) + else: + clean_state[name] = deepcopy(param).cpu() + + # Process expert parameters one by one on GPU to avoid OOM + + for clean_name, param in expert_params_to_process: + # Create mini state dict with just this parameter on GPU + mini_state = { + clean_name: param.cuda() if param.device.type == "cpu" else param + } + + # Convert this parameter + mini_converted = convert_dequantized_to_quantized_format_correct(mini_state) + + # Move all results back to CPU and add to final state + for conv_name, conv_param in mini_converted.items(): + tensor_cpu = ( + conv_param.cpu() if conv_param.device.type != "cpu" else conv_param + ) + clean_state[conv_name] = deepcopy(tensor_cpu) + + # Clean up GPU memory + del mini_state, mini_converted + torch.cuda.empty_cache() + + # Save state dict using accelerator.save + output_dir.mkdir(parents=True, exist_ok=True) + + # Use accelerator.save directly on our state dict + accelerator.save( + clean_state, output_dir / "model.safetensors", safe_serialization=True + ) + + # Save config and tokenizer + # Add quantization config before saving to avoid double-write + add_gpt_oss_quantization_config(model.config) + model.config.to_json_file(f"{output_dir}/config.json") + + tokenizer.save_pretrained(output_dir) + + dist.barrier() + + def get_module_class_from_name( model: torch.nn.Module, name: str ) -> torch.nn.Module | None: @@ -484,6 +569,25 @@ def _copy_no_lora_dict(state_dict): return cleaned_state_dict +def _copy_gpt_oss_converted_dict(state_dict): + """Copy and convert GPT-OSS state dict to quantized format.""" + # Local + + # First apply standard cleaning like LoRA does + cleaned_state_dict = OrderedDict() + for param_tensor in state_dict: + cleaned_state_dict[ + param_tensor.replace(".base_layer", "").replace("base_model.model.", "") + ] = deepcopy(state_dict[param_tensor]).cpu() + + # Then apply GPT-OSS parameter name conversion + converted_state_dict = convert_dequantized_to_quantized_format_correct( + cleaned_state_dict + ) + + return converted_state_dict + + def save_dict_accelerate( accelerator: Accelerator, state_to_save, @@ -511,6 +615,34 @@ def skip_precheck_loops(): accelerator.get_state_dict = old_get_state +def save_dict_accelerate_gpt_oss( + accelerator: Accelerator, + state_to_save, + save_directory, + max_shard_size="5GB", + safe_serialization=True, +): + """Save state dict with GPT-OSS parameter conversion (same pattern as LoRA).""" + old_get_state = accelerator.get_state_dict + accelerator.get_state_dict = _copy_gpt_oss_converted_dict + + def skip_precheck_loops(): + return [] + + # The save model does a loop over modules and params in order to determine how to get state dict. Since we already have the state dict directly, we want to bypass those checks. + state_to_save.modules = skip_precheck_loops + state_to_save.parameters = skip_precheck_loops + + accelerator.save_model( + state_to_save, + save_directory=save_directory, + max_shard_size=max_shard_size, + safe_serialization=safe_serialization, + ) + + accelerator.get_state_dict = old_get_state + + def save_hf_format_accelerate( args, model, @@ -575,7 +707,14 @@ def _get_state_dict_patched(model, unwrap=False): warnings.warn( f"Adding architectures to ckpt: {model.module.config.architectures}", ) + + # For GPT-OSS models, ensure the config has proper quantization settings + # before writing to file (avoids double-write) + if is_gpt_oss(model.module.config): + add_gpt_oss_quantization_config(model.module.config) + model.module.config.to_json_file(output_config_file) + tokenizer.save_pretrained(output_dir) if is_lora: @@ -589,12 +728,45 @@ def _get_state_dict_patched(model, unwrap=False): model.module.unmerge_adapter() if not is_lora: - accelerator.save_model( - model, - save_directory=output_dir, - max_shard_size="5GB", - safe_serialization=True, - ) + # Check if this is a GPT-OSS model that needs format conversion + + if is_gpt_oss(model.module.config): + # For GPT-OSS models, check if we need FSDP handling like LoRA does + is_fsdp_gpt_oss = accelerator.distributed_type == DistributedType.FSDP + + if is_fsdp_gpt_oss: + # Use FSDP GPT-OSS saving (same pattern as LoRA FSDP) + log_rank_0( + "Converting GPT-OSS parameters to quantized format for compatibility (FSDP)" + ) + save_fsdp_gpt_oss_model( + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + output_dir=output_dir, + ) + elif accelerator.is_main_process: + # Non-FSDP path + log_rank_0( + "Converting GPT-OSS parameters to quantized format for compatibility" + ) + model_state = model.module.state_dict() + + save_dict_accelerate_gpt_oss( + accelerator, + model_state, + save_directory=output_dir, + max_shard_size="5GB", + safe_serialization=True, + ) + elif not is_gpt_oss(model.module.config): + # Standard model saving + accelerator.save_model( + model, + save_directory=output_dir, + max_shard_size="5GB", + safe_serialization=True, + ) log_rank_0(f"\033[93mModel saved in {final_output_dir}\033[0m", to_print=True) log_rank_0(f"saving took {time.time() - start} seconds") @@ -611,6 +783,7 @@ def set_random_seed(seed): torch.cuda.manual_seed_all(seed) +# TODO: move this to also live in the `Model` object def save_checkpoint( args, accelerator: Accelerator, @@ -725,3 +898,65 @@ def load_latest_full_state(args, accelerator) -> None: # previous epoch is basis for current epoch. args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 args.__dict__["samples_seen"] = training_metadata["samples_seen"] + + +def freeze_router_params(model: Model): + """ + Freeze router parameters for GPT-OSS models before FSDP setup. + + Args: + model: The model to check and potentially freeze parameters + + Returns: + bool: True if this is a GPT-OSS model and parameters were frozen + """ + + # Freeze router parameters BEFORE accelerator setup + frozen_count = 0 + for name, param in model.named_parameters(): + if param.requires_grad and "router" in name: + param.requires_grad = False + frozen_count += 1 + logger.info(f"โ„๏ธ Frozen router parameter: {name}") + + logger.info(f"โœ… Frozen {frozen_count} router parameters for GPT-OSS model") + return True + + +def test_model_inference_quick(model, tokenizer, stage_name): + """Quick inference test to check if model outputs are coherent.""" + try: + logger.info(f"๐Ÿงช Running quick inference test at stage: {stage_name}") + + # Simple test prompt + test_prompt = "The quick brown fox" + inputs = tokenizer(test_prompt, return_tensors="pt") + + # Move inputs to model device + device = next(model.parameters()).device + inputs = {k: v.to(device) for k, v in inputs.items()} + + # Generate a few tokens + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=10, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + + # Decode and log result + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + logger.info(f"๐Ÿ”ค {stage_name} OUTPUT: '{generated_text}'") + + # Check if output looks reasonable (not just repeated tokens or gibberish) + output_tokens = generated_text.split() + if len(set(output_tokens)) < 3: + logger.warning(f"โš ๏ธ {stage_name}: Output looks repetitive/corrupted!") + else: + logger.info(f"โœ… {stage_name}: Output looks reasonable") + + except Exception as e: + logger.error(f"โŒ {stage_name} inference test failed: {e}") + + traceback.print_exc() diff --git a/tests/unit/test_collators.py b/tests/unit/test_collators.py new file mode 100644 index 00000000..a11ce385 --- /dev/null +++ b/tests/unit/test_collators.py @@ -0,0 +1,449 @@ +# Standard +from unittest.mock import MagicMock, patch +import unittest + +# Third Party +import numpy as np +import torch + +# First Party +from instructlab.training.batch_packer import batch_lengths_to_minibatches_lpt +from instructlab.training.padded_batch_packer import batch_lengths_to_minibatches_padded +from instructlab.training.sampler import MaxTokensPerRankCollator + + +class TestCollators(unittest.TestCase): + """Comprehensive tests for the unified MaxTokensPerRankCollator.""" + + def setUp(self): + """Set up test fixtures.""" + # Create a mock for distributed operations when not in distributed mode + self.dist_patcher = patch("instructlab.training.sampler.dist") + self.mock_dist = self.dist_patcher.start() + self.mock_dist.get_rank.return_value = 0 + self.mock_dist.get_world_size.return_value = 1 + + def tearDown(self): + """Clean up patches.""" + self.dist_patcher.stop() + + def create_test_batch(self, num_samples=10, min_len=50, max_len=200): + """Create a test batch with varying sequence lengths.""" + batch = [] + np.random.seed(42) # For reproducibility + for i in range(num_samples): + seq_len = np.random.randint(min_len, max_len) + input_ids = torch.randint(1, 1000, (seq_len,)) + labels = torch.where( + torch.rand(seq_len) > 0.3, input_ids, torch.tensor(-100) + ) + num_loss_counted = (labels != -100).sum().item() + + batch.append( + { + "input_ids": input_ids, + "labels": labels, + "len": seq_len, + "num_loss_counted_tokens": num_loss_counted, + } + ) + + return batch + + def test_flash_mode_output_format(self): + """Test that flash mode produces the expected output format.""" + batch = self.create_test_batch(20) + + collator = MaxTokensPerRankCollator( + max_tokens_per_rank=1024, + rank=0, + world_size=1, + flash_enabled=True, + ) + + minibatches = collator(batch) + + self.assertGreater(len(minibatches), 0, "Should produce at least one minibatch") + + for mb in minibatches: + # Check required keys for flash mode + self.assertIn("input_ids", mb) + self.assertIn("labels", mb) + self.assertIn("position_ids", mb) + self.assertIn("num_loss_counted_tokens", mb) + self.assertIn("num_samples", mb) + self.assertIn("total_length", mb) + + # Should NOT have attention_mask in flash mode + self.assertNotIn("attention_mask", mb) + + # Check tensor shapes (flash mode uses concatenated format) + self.assertEqual(mb["input_ids"].dim(), 2) # [1, total_length] + self.assertEqual(mb["input_ids"].shape[0], 1) + self.assertEqual(mb["labels"].shape, mb["input_ids"].shape) + self.assertEqual(mb["position_ids"].shape, mb["input_ids"].shape) + + def test_padded_mode_output_format(self): + """Test that padded mode produces the expected output format.""" + batch = self.create_test_batch(20) + + collator = MaxTokensPerRankCollator( + max_tokens_per_rank=1024, + rank=0, + world_size=1, + flash_enabled=False, + pad_token_id=0, + ) + + minibatches = collator(batch) + + self.assertGreater(len(minibatches), 0, "Should produce at least one minibatch") + + for mb in minibatches: + # Check required keys for padded mode + self.assertIn("input_ids", mb) + self.assertIn("labels", mb) + self.assertIn("attention_mask", mb) + self.assertIn("num_loss_counted_tokens", mb) + self.assertIn("num_samples", mb) + self.assertIn("total_length", mb) + + # Should NOT have position_ids in padded mode + self.assertNotIn("position_ids", mb) + + # Check tensor shapes (padded mode uses [batch_size, max_len]) + self.assertEqual(mb["input_ids"].dim(), 2) + self.assertEqual(mb["labels"].shape, mb["input_ids"].shape) + self.assertEqual(mb["attention_mask"].shape, mb["input_ids"].shape) + + # Verify attention mask correctness + # Attention mask should be 1 for non-padding, 0 for padding + batch_size = mb["input_ids"].shape[0] + for i in range(batch_size): + # Find non-padding positions (assuming pad_token_id=0) + non_pad_mask = mb["input_ids"][i] != 0 + # Attention mask should match non-padding positions + self.assertTrue( + torch.equal(mb["attention_mask"][i], non_pad_mask.long()), + "Attention mask should be 1 for non-padding tokens", + ) + + def test_multi_rank_distribution(self): + """Test that samples are correctly distributed across multiple ranks.""" + # Create batch with known content for tracking + batch_size = 20 + batch = [] + for i in range(batch_size): + seq_len = 100 + i * 5 # Varying lengths + # Use a unique pattern for each sample to track it + input_ids = torch.full((seq_len,), i + 1000, dtype=torch.long) + labels = torch.full((seq_len,), i + 2000, dtype=torch.long) + + batch.append( + { + "input_ids": input_ids, + "labels": labels, + "len": seq_len, + "num_loss_counted_tokens": seq_len, + } + ) + + world_size = 4 + self.mock_dist.get_world_size.return_value = world_size + + # Track which samples each rank receives + rank_sample_ids = {rank: set() for rank in range(world_size)} + total_samples_per_rank = {rank: 0 for rank in range(world_size)} + + for rank in range(world_size): + self.mock_dist.get_rank.return_value = rank + + # We need to use the actual batch packer to track indices + batch_lengths = [b["len"] for b in batch] + indices = batch_lengths_to_minibatches_lpt( + batch_lengths, 1024, world_size, rank + ) + + # Track which sample indices this rank got + for minibatch_indices in indices: + for idx in minibatch_indices: + if idx != -1: # Not a dummy sample + rank_sample_ids[rank].add(idx) + total_samples_per_rank[rank] += 1 + + # Verify no sample appears on multiple ranks + all_assigned_samples = set() + for rank, samples in rank_sample_ids.items(): + for sample_id in samples: + self.assertNotIn( + sample_id, + all_assigned_samples, + f"Sample {sample_id} assigned to multiple ranks!", + ) + all_assigned_samples.add(sample_id) + + # Verify all samples are assigned exactly once + self.assertEqual( + len(all_assigned_samples), + len(batch), + "All samples should be assigned exactly once across all ranks", + ) + + # Verify reasonable load balance + samples_counts = list(total_samples_per_rank.values()) + max_samples = max(samples_counts) + min_samples = min(samples_counts) + + # Allow some imbalance but not too much + self.assertLessEqual( + max_samples - min_samples, + len(batch) // 2, + f"Load imbalance too high: max={max_samples}, min={min_samples}", + ) + + print(f"\nSample distribution across ranks: {total_samples_per_rank}") + + def test_padded_multi_rank_distribution(self): + """Test padded mode also distributes samples correctly across ranks.""" + # Similar test but for padded mode + batch_size = 20 + batch = [] + for i in range(batch_size): + seq_len = 100 + i * 5 + input_ids = torch.full((seq_len,), i + 1000, dtype=torch.long) + labels = torch.full((seq_len,), i + 2000, dtype=torch.long) + + batch.append( + { + "input_ids": input_ids, + "labels": labels, + "len": seq_len, + "num_loss_counted_tokens": seq_len, + } + ) + + world_size = 4 + self.mock_dist.get_world_size.return_value = world_size + + rank_sample_ids = {rank: set() for rank in range(world_size)} + + for rank in range(world_size): + self.mock_dist.get_rank.return_value = rank + + batch_lengths = [b["len"] for b in batch] + indices = batch_lengths_to_minibatches_padded( + batch_lengths, 1024, world_size, rank + ) + + for minibatch_indices in indices: + for idx in minibatch_indices: + if idx != -1: + rank_sample_ids[rank].add(idx) + + # Verify no duplicates + all_assigned = set() + for rank, samples in rank_sample_ids.items(): + for sample_id in samples: + self.assertNotIn( + sample_id, + all_assigned, + f"Sample {sample_id} assigned to multiple ranks in padded mode!", + ) + all_assigned.add(sample_id) + + # All samples should be assigned + self.assertEqual( + len(all_assigned), + len(batch), + "All samples should be assigned in padded mode", + ) + + def test_accumulation_behavior(self): + """Test that batches accumulate properly across multiple calls.""" + # Create multiple batches + batch1 = self.create_test_batch(10, min_len=50, max_len=100) + batch2 = self.create_test_batch(10, min_len=100, max_len=150) + batch3 = self.create_test_batch(10, min_len=150, max_len=200) + + collator = MaxTokensPerRankCollator( + max_tokens_per_rank=512, # Smaller to force multiple minibatches + rank=0, + world_size=1, + flash_enabled=True, + ) + + # Process batches + minibatches1 = collator(batch1) + minibatches2 = collator(batch2) + minibatches3 = collator(batch3) + + # Each batch should be processed independently + self.assertGreater(len(minibatches1), 0) + self.assertGreater(len(minibatches2), 0) + self.assertGreater(len(minibatches3), 0) + + # Verify that longer sequences produce more minibatches + # (due to token limit constraints) + total_tokens1 = sum(mb["total_length"] for mb in minibatches1) + total_tokens3 = sum(mb["total_length"] for mb in minibatches3) + + # Batch 3 has longer sequences, so should have more total tokens + self.assertGreater( + total_tokens3, + total_tokens1, + "Batches with longer sequences should have more total tokens", + ) + + def test_max_tokens_constraint(self): + """Test that max_tokens_per_rank constraint is respected.""" + batch = self.create_test_batch(20, min_len=100, max_len=200) + max_tokens = 256 # Small limit to test constraint + + for flash_enabled in [True, False]: + with self.subTest(flash_enabled=flash_enabled): + collator = MaxTokensPerRankCollator( + max_tokens_per_rank=max_tokens, + rank=0, + world_size=1, + flash_enabled=flash_enabled, + ) + + minibatches = collator(batch) + + for mb in minibatches: + if flash_enabled: + # In flash mode, check total concatenated length + self.assertLessEqual( + mb["total_length"], + max_tokens, + f"Minibatch exceeds max tokens: {mb['total_length']} > {max_tokens}", + ) + else: + # In padded mode, check max_len * batch_size + batch_size, max_len = mb["input_ids"].shape + padded_tokens = batch_size * max_len + # Note: padded mode might slightly exceed due to padding + # but should be close + self.assertLessEqual( + padded_tokens, + max_tokens * 1.5, # Allow some padding overhead + f"Padded batch significantly exceeds max tokens", + ) + + def test_filtering_long_sequences(self): + """Test that sequences longer than max_tokens are filtered.""" + # Create batch with some very long sequences + batch = [] + for i in range(10): + if i < 5: + seq_len = 100 # Normal length + else: + seq_len = 2000 # Very long + + batch.append( + { + "input_ids": torch.randint(1, 1000, (seq_len,)), + "labels": torch.randint(0, 1000, (seq_len,)), + "len": seq_len, + "num_loss_counted_tokens": seq_len // 2, + } + ) + + collator = MaxTokensPerRankCollator( + max_tokens_per_rank=1024, + rank=0, + world_size=1, + flash_enabled=True, + ) + + # Capture print output to verify filtering message + # Standard + import io + import sys + + captured_output = io.StringIO() + sys.stdout = captured_output + + minibatches = collator(batch) + + sys.stdout = sys.__stdout__ + output = captured_output.getvalue() + + # Should have printed a warning about filtered samples + self.assertIn("removed", output) + self.assertIn("5 samples", output) # Should filter 5 long sequences + + def test_dummy_sample_handling(self): + """Test that dummy samples are used correctly for padding.""" + batch = self.create_test_batch(5) + world_size = 4 # More ranks than samples to force dummy usage + + self.mock_dist.get_world_size.return_value = world_size + + # Check a rank that might not get real samples + self.mock_dist.get_rank.return_value = 3 + + custom_dummy = { + "input_ids": torch.tensor([999, 998, 997], dtype=torch.long), + "labels": torch.tensor([-100, -100, -100], dtype=torch.long), + "len": 3, + "num_loss_counted_tokens": 0, + } + + collator = MaxTokensPerRankCollator( + max_tokens_per_rank=1024, + rank=3, + world_size=world_size, + dummy_sample=custom_dummy, + flash_enabled=True, + ) + + minibatches = collator(batch) + + # Should have at least one minibatch (possibly with dummy) + self.assertGreater(len(minibatches), 0) + + # Check if dummy values appear in output + for mb in minibatches: + # If this rank got a dummy sample, it should have 0 loss tokens + if mb["num_samples"] == 0: + self.assertEqual( + mb["num_loss_counted_tokens"], + 0, + "Dummy samples should not contribute to loss", + ) + + def test_mode_switching(self): + """Test that the same collator can switch between modes correctly.""" + batch = self.create_test_batch(10) + + # Test flash mode + collator_flash = MaxTokensPerRankCollator( + max_tokens_per_rank=1024, + rank=0, + world_size=1, + flash_enabled=True, + ) + + # Test padded mode + collator_padded = MaxTokensPerRankCollator( + max_tokens_per_rank=1024, + rank=0, + world_size=1, + flash_enabled=False, + pad_token_id=0, + ) + + mb_flash = collator_flash(batch) + mb_padded = collator_padded(batch) + + # Verify different output formats + self.assertIn("position_ids", mb_flash[0]) + self.assertNotIn("attention_mask", mb_flash[0]) + + self.assertNotIn("position_ids", mb_padded[0]) + self.assertIn("attention_mask", mb_padded[0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_data_process.py b/tests/unit/test_data_process.py index ba0bc4cc..d0d36f6c 100644 --- a/tests/unit/test_data_process.py +++ b/tests/unit/test_data_process.py @@ -1,27 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # Standard -from unittest.mock import Mock, patch -import tempfile +from unittest.mock import MagicMock, patch import typing as t import unittest -try: - # Third Party - import pytest - - PYTEST_AVAILABLE = True -except ImportError: - PYTEST_AVAILABLE = False - -try: - # Third Party - from transformers import AutoTokenizer, PreTrainedTokenizer - - TRANSFORMERS_AVAILABLE = True -except ImportError: - TRANSFORMERS_AVAILABLE = False - PreTrainedTokenizer = None +# Third Party +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerBase # First Party from instructlab.training.data_process import ( @@ -43,10 +28,8 @@ class TestComprehensiveUnmasking(unittest.TestCase): def setUp(self): """Set up test fixtures.""" # Mock tokenizer for basic tests - if TRANSFORMERS_AVAILABLE: - self.mock_tokenizer = Mock(spec=PreTrainedTokenizer) - else: - self.mock_tokenizer = Mock() + self.mock_tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + self.mock_tokenizer.name_or_path = "test-model" # Set up token IDs for unmask tokens self.unmask_begin_id = 1001 @@ -366,7 +349,8 @@ def test_long_multi_turn_conversation(self): self.assertIsInstance(result, dict) self.assertEqual(len(result["input_ids"]), len(result["labels"])) - def test_unmask_sample_function(self): + @patch("instructlab.training.data_process.is_gpt_oss_model", return_value=False) + def test_unmask_sample_function(self, mock_is_gpt_oss): """Test the unmask_sample function with various scenarios.""" sample_scenarios = [ # Basic conversation @@ -528,10 +512,8 @@ class TestReasoningContentSupport(unittest.TestCase): def setUp(self): """Set up test fixtures.""" # Mock tokenizer for basic tests - if TRANSFORMERS_AVAILABLE: - self.mock_tokenizer = Mock(spec=PreTrainedTokenizer) - else: - self.mock_tokenizer = Mock() + self.mock_tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + self.mock_tokenizer.name_or_path = "test-model" self.mock_tokenizer.encode.side_effect = ( lambda text, add_special_tokens=False: [ hash(text) % 1000 for _ in text.split() @@ -796,7 +778,8 @@ def test_unmask_messages_with_reasoning_content(self): self.assertEqual(wrapped[0]["content"], "What is 5*7?") self.assertNotIn("reasoning_content", wrapped[0]) - def test_unmask_sample_with_reasoning_content(self): + @patch("instructlab.training.data_process.is_gpt_oss_model", return_value=False) + def test_unmask_sample_with_reasoning_content(self, mock_is_gpt_oss): """Test that unmask_sample correctly processes samples with reasoning_content.""" sample = { "messages": [ @@ -820,7 +803,8 @@ def test_unmask_sample_with_reasoning_content(self): self.assertIn("labels", result) self.assertIn("len", result) - def test_unmask_sample_with_unmask_flag(self): + @patch("instructlab.training.data_process.is_gpt_oss_model", return_value=False) + def test_unmask_sample_with_unmask_flag(self, mock_is_gpt_oss): """Test that unmask_sample correctly handles the unmask flag.""" sample = { "messages": [ @@ -846,18 +830,13 @@ def test_unmask_sample_with_unmask_flag(self): self.assertIn("len", result) -@unittest.skipUnless(TRANSFORMERS_AVAILABLE, "transformers library not available") class TestReasoningContentWithRealTokenizers(unittest.TestCase): """Test reasoning_content functionality with real tokenizers.""" - @unittest.skipUnless(PYTEST_AVAILABLE, "pytest not available") def test_with_qwen_tokenizer(self): """Test reasoning_content functionality with Qwen3-32B tokenizer.""" - try: - # Use a smaller Qwen model that's more readily available - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-32B") - except Exception as e: - self.skipTest(f"Qwen tokenizer not available: {e}") + # Use a smaller Qwen model that's more readily available + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B") # Add the unmask tokens to the tokenizer tokenizer.add_special_tokens( @@ -911,15 +890,9 @@ def test_with_qwen_tokenizer(self): self.assertNotIn(unmask_begin_id, result["input_ids"]) self.assertNotIn(unmask_end_id, result["input_ids"]) - @unittest.skipUnless(PYTEST_AVAILABLE, "pytest not available") - def test_with_mistral_tokenizer(self): - """Test reasoning_content functionality with Mistral tokenizer.""" - try: - tokenizer = AutoTokenizer.from_pretrained( - "mistralai/Mistral-7B-Instruct-v0.1" - ) - except Exception as e: - self.skipTest(f"Mistral tokenizer not available: {e}") + def test_with_phi_tokenizer(self): + """Test reasoning_content functionality with Phi-4 tokenizer.""" + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct") # Add the unmask tokens to the tokenizer tokenizer.add_special_tokens( diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index c195e7dd..4341be85 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -23,6 +23,13 @@ def mock_tokenizer(): return tokenizer +@pytest.fixture +def mock_config(): + config = MagicMock() + config.model_type = "llama" + return config + + @pytest.fixture def mock_model(): model = MagicMock() @@ -58,9 +65,12 @@ def lora_config(): ) -def test_model_initialization(mock_tokenizer, mock_model): - with patch( - "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model +def test_model_initialization(mock_tokenizer, mock_model, mock_config): + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model + ), ): model = CausalLMModel( model_path="test_model", @@ -76,8 +86,9 @@ def test_model_initialization(mock_tokenizer, mock_model): assert model.noise_alpha is None -def test_model_with_lora(mock_tokenizer, mock_model, lora_config): +def test_model_with_lora(mock_tokenizer, mock_model, mock_config, lora_config): with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), patch( "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model ), @@ -95,9 +106,12 @@ def test_model_with_lora(mock_tokenizer, mock_model, lora_config): assert model.lora_config == lora_config -def test_reconcile_tokenizer(mock_tokenizer, mock_model): - with patch( - "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model +def test_reconcile_tokenizer(mock_tokenizer, mock_model, mock_config): + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model + ), ): # Test case where tokenizer has more tokens than model mock_tokenizer.__len__.return_value = 33000 @@ -114,9 +128,12 @@ def test_reconcile_tokenizer(mock_tokenizer, mock_model): mock_model.resize_token_embeddings.assert_called_once() -def test_model_train_mode(mock_tokenizer, mock_model): - with patch( - "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model +def test_model_train_mode(mock_tokenizer, mock_model, mock_config): + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model + ), ): model = CausalLMModel( model_path="test_model", @@ -135,9 +152,12 @@ def test_model_train_mode(mock_tokenizer, mock_model): mock_model.train.assert_called_with(False) -def test_model_parameters(mock_tokenizer, mock_model): - with patch( - "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model +def test_model_parameters(mock_tokenizer, mock_model, mock_config): + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model + ), ): model = CausalLMModel( model_path="test_model", @@ -152,9 +172,12 @@ def test_model_parameters(mock_tokenizer, mock_model): mock_model.parameters.assert_called_once() -def test_model_get_projection_layers(mock_tokenizer, mock_model): - with patch( - "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model +def test_model_get_projection_layers(mock_tokenizer, mock_model, mock_config): + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model + ), ): model = CausalLMModel( model_path="test_model", @@ -214,9 +237,12 @@ def test_model_flash_attention_check( # New tests for model initializations -def test_causal_lm_model_with_flash_attention(mock_tokenizer, mock_model): - with patch( - "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model +def test_causal_lm_model_with_flash_attention(mock_tokenizer, mock_model, mock_config): + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model + ), ): model = CausalLMModel( model_path="test_model", @@ -230,10 +256,13 @@ def test_causal_lm_model_with_flash_attention(mock_tokenizer, mock_model): assert model.base_model_args["attn_implementation"] == "flash_attention_2" -def test_model_with_noise_alpha(mock_tokenizer, mock_model): +def test_model_with_noise_alpha(mock_tokenizer, mock_model, mock_config): mock_model.__class__.__name__ = "LlamaForCausalLM" - with patch( - "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model + with ( + patch("transformers.AutoConfig.from_pretrained", return_value=mock_config), + patch( + "transformers.AutoModelForCausalLM.from_pretrained", return_value=mock_model + ), ): model = CausalLMModel( model_path="test_model", diff --git a/tests/unit/test_padded_batch_packer.py b/tests/unit/test_padded_batch_packer.py new file mode 100644 index 00000000..110edbb5 --- /dev/null +++ b/tests/unit/test_padded_batch_packer.py @@ -0,0 +1,235 @@ +# Standard +from unittest.mock import patch +import unittest + +# Third Party +import numpy as np + +# First Party +from instructlab.training.padded_batch_packer import ( + batch_lengths_to_minibatches_padded, + compute_padding_stats, +) + + +class TestPaddedBatchPacker(unittest.TestCase): + """Unit tests for the padded batch packer algorithm.""" + + def test_empty_input(self): + """Test that empty input returns empty list.""" + result = batch_lengths_to_minibatches_padded([], 100, 2, 0) + self.assertEqual(result, []) + + def test_single_sequence(self): + """Test single sequence with single rank.""" + lengths = [50] + result = batch_lengths_to_minibatches_padded(lengths, 100, 1, 0) + self.assertEqual(len(result), 1) + self.assertEqual(result[0], [0]) + + def test_all_sequences_assigned(self): + """Test that all sequences are assigned exactly once across ranks.""" + lengths = [30, 40, 50, 20, 60, 70, 80, 90] + num_ranks = 4 + max_tokens = 200 + + # Collect all assigned indices across ranks + all_indices = set() + for rank in range(num_ranks): + batches = batch_lengths_to_minibatches_padded( + lengths, max_tokens, num_ranks, rank + ) + for batch in batches: + for idx in batch: + if idx != -1: + all_indices.add(idx) + + # Check all sequences are assigned + self.assertEqual(sorted(all_indices), list(range(len(lengths)))) + + def test_max_tokens_constraint(self): + """Test that batches respect max tokens constraint.""" + lengths = [80, 60, 40, 90, 50, 70] + max_tokens = 100 + num_ranks = 2 + + for rank in range(num_ranks): + batches = batch_lengths_to_minibatches_padded( + lengths, max_tokens, num_ranks, rank + ) + + for batch in batches: + if not batch or batch[0] == -1: + continue + + # Calculate padded tokens for this batch + max_len = max(lengths[idx] for idx in batch) + padded_tokens = max_len * len(batch) + + self.assertLessEqual( + padded_tokens, + max_tokens, + f"Batch exceeds max tokens: {padded_tokens} > {max_tokens}", + ) + + def test_padding_efficiency_similar_lengths(self): + """Test that similar-length sequences have minimal padding.""" + # Sequences with similar lengths + lengths = [98, 95, 97, 96, 94, 99, 93, 96] + batches = batch_lengths_to_minibatches_padded(lengths, 400, 1, 0) + stats = compute_padding_stats(lengths, batches) + + # Should have very low padding ratio + self.assertLess( + stats["padding_ratio"], + 0.1, + f"Similar lengths should have low padding, got {stats['padding_ratio']:.2%}", + ) + + def test_same_length_zero_padding(self): + """Test that same-length sequences have zero padding.""" + lengths = [50] * 10 + batches = batch_lengths_to_minibatches_padded(lengths, 200, 2, 0) + stats = compute_padding_stats(lengths, batches) + + self.assertEqual(stats["padding_ratio"], 0.0) + + def test_load_balancing(self): + """Test that load is reasonably balanced across ranks.""" + lengths = list(np.random.randint(20, 200, size=50)) + num_ranks = 4 + max_tokens = 500 + + rank_tokens = {} + for rank in range(num_ranks): + batches = batch_lengths_to_minibatches_padded( + lengths, max_tokens, num_ranks, rank + ) + stats = compute_padding_stats(lengths, batches) + rank_tokens[rank] = stats["total_padded_tokens"] + + # Check load balance + if max(rank_tokens.values()) > 0: + max_load = max(rank_tokens.values()) + min_load = min(rank_tokens.values()) + imbalance = (max_load - min_load) / max_load + + # Allow up to 50% imbalance for small test cases + self.assertLess( + imbalance, + 0.5, + f"Load imbalance too high: {imbalance:.2%}", + ) + + def test_accumulation_steps(self): + """Test batch packing with gradient accumulation steps.""" + # Simulate dataset split across accumulation steps + total_lengths = list(np.random.randint(20, 100, size=100)) + num_ranks = 4 + accumulation_steps = 4 + max_tokens = 200 + + samples_per_step = len(total_lengths) // accumulation_steps + all_processed = set() + + for step in range(accumulation_steps): + start = step * samples_per_step + end = min((step + 1) * samples_per_step, len(total_lengths)) + step_lengths = total_lengths[start:end] + + for rank in range(num_ranks): + batches = batch_lengths_to_minibatches_padded( + step_lengths, max_tokens, num_ranks, rank + ) + + for batch in batches: + for idx in batch: + if idx != -1: + global_idx = start + idx + all_processed.add(global_idx) + + # Verify all samples in the steps were processed + expected = set(range(accumulation_steps * samples_per_step)) + self.assertEqual(all_processed, expected) + + def test_deterministic_output(self): + """Test that the algorithm produces deterministic results.""" + lengths = [100, 90, 80, 70, 60, 50, 40, 30, 20, 10] + max_tokens = 200 + num_ranks = 2 + + # Run multiple times and verify same output + results = [] + for _ in range(3): + rank_results = {} + for rank in range(num_ranks): + batches = batch_lengths_to_minibatches_padded( + lengths, max_tokens, num_ranks, rank + ) + rank_results[rank] = batches + results.append(rank_results) + + # Check all runs produce same results + for i in range(1, len(results)): + for rank in range(num_ranks): + self.assertEqual( + results[0][rank], + results[i][rank], + f"Non-deterministic output for rank {rank}", + ) + + def test_edge_case_single_long_sequence(self): + """Test handling of sequences longer than max_tokens.""" + lengths = [1000] # Much longer than typical max_tokens + max_tokens = 100 + + batches = batch_lengths_to_minibatches_padded(lengths, max_tokens, 1, 0) + + # Should still process the sequence even if it exceeds max_tokens + all_indices = [] + for batch in batches: + all_indices.extend([idx for idx in batch if idx != -1]) + + # The long sequence should be included + self.assertIn(0, all_indices) + + def test_max_batch_size_constraint(self): + """Test that max_batch_size parameter is respected.""" + lengths = [10] * 20 # Many small sequences + max_tokens = 1000 # Large enough to fit many sequences + max_batch_size = 5 + + batches = batch_lengths_to_minibatches_padded( + lengths, max_tokens, 1, 0, max_batch_size=max_batch_size + ) + + for batch in batches: + if batch and batch[0] != -1: + self.assertLessEqual( + len(batch), + max_batch_size, + f"Batch size {len(batch)} exceeds max {max_batch_size}", + ) + + def test_padding_stats_computation(self): + """Test padding statistics computation.""" + lengths = [100, 90, 80, 70] + # Manually create batches for testing + batches = [[0, 1], [2, 3]] # Two batches of 2 sequences each + + stats = compute_padding_stats(lengths, batches) + + # Batch 1: max=100, sequences=100+90=190, padded=200 + # Batch 2: max=80, sequences=80+70=150, padded=160 + # Total: actual=340, padded=360, padding=20 + + self.assertEqual(stats["total_sequences"], 4) + self.assertEqual(stats["total_actual_tokens"], 340) + self.assertEqual(stats["total_padded_tokens"], 360) + self.assertEqual(stats["total_padding_tokens"], 20) + self.assertAlmostEqual(stats["padding_ratio"], 20 / 360, places=4) + self.assertEqual(stats["num_batches"], 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_unmask_messages.py b/tests/unit/test_unmask_messages.py index 700fe0c0..7b2de057 100644 --- a/tests/unit/test_unmask_messages.py +++ b/tests/unit/test_unmask_messages.py @@ -896,7 +896,10 @@ def mock_apply_chat_template(messages, **kwargs): tokenizer.apply_chat_template = mock_apply_chat_template return tokenizer - def test_unmask_sample_unmask_false_logic(self, mock_tokenizer_for_unmask_sample): + @patch("instructlab.training.data_process.is_gpt_oss_model", return_value=False) + def test_unmask_sample_unmask_false_logic( + self, mock_is_gpt_oss, mock_tokenizer_for_unmask_sample + ): """Test unmask: False logic - should only unmask assistant role.""" sample = { "messages": [ @@ -929,7 +932,10 @@ def test_unmask_sample_unmask_false_logic(self, mock_tokenizer_for_unmask_sample f"Assistant token {token} should be unmasked" ) - def test_unmask_sample_unmask_true_logic(self, mock_tokenizer_for_unmask_sample): + @patch("instructlab.training.data_process.is_gpt_oss_model", return_value=False) + def test_unmask_sample_unmask_true_logic( + self, mock_is_gpt_oss, mock_tokenizer_for_unmask_sample + ): """Test unmask: True logic - should unmask user and assistant, but not system.""" sample = { "messages": [ @@ -963,7 +969,10 @@ def test_unmask_sample_unmask_true_logic(self, mock_tokenizer_for_unmask_sample) f"User/Assistant token {token} should be unmasked" ) - def test_unmask_sample_comparison(self, mock_tokenizer_for_unmask_sample): + @patch("instructlab.training.data_process.is_gpt_oss_model", return_value=False) + def test_unmask_sample_comparison( + self, mock_is_gpt_oss, mock_tokenizer_for_unmask_sample + ): """Test that unmask: True unmasks more tokens than unmask: False.""" sample_base = { "messages": [