2 changes: 1 addition & 1 deletion requirements.txt
@@ -3,7 +3,7 @@ wheel>=0.43
 pyyaml
 py-cpuinfo
 torch>=2.6.0
-transformers>=4.45.2
+transformers>=4.55.0
 
 datasets>=2.15.0
 numba
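
The only functional change here is raising the transformers floor from 4.45.2 to 4.55.0. For an already-provisioned environment, a quick pre-flight check along these lines can confirm the installed version meets the new pin; the snippet is illustrative and not part of the PR:

# Illustrative pre-flight check (not part of this PR): verify that the installed
# transformers package satisfies the new ">=4.55.0" pin from requirements.txt.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("transformers"))
if installed < Version("4.55.0"):
    raise RuntimeError(
        f"transformers {installed} is older than the 4.55.0 minimum pinned above"
    )
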
47 changes: 25 additions & 22 deletions src/instructlab/training/accelerator.py
@@ -1,9 +1,16 @@
 # Standard
 from copy import deepcopy
-from typing import Callable, Optional
+from functools import partial
+from typing import Optional
+import logging
 
 # Third Party
 from accelerate import Accelerator as TransformersAccel
+from accelerate.utils import DeepSpeedPlugin, FullyShardedDataParallelPlugin
+from peft.utils.other import fsdp_auto_wrap_policy
+from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 from torch.utils.data import DataLoader
 from transformers import get_scheduler
 import torch
@@ -13,10 +20,13 @@
     DeepSpeedOptions,
     DistributedBackend,
 )
+from instructlab.training.utils import get_module_class_from_name
 
 # Local
 from .model import Model
 
+logger = logging.getLogger(__name__)
+
 
 class Accelerator:
     def __init__(
@@ -32,6 +42,7 @@ def __init__(
         deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False,
         deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None,
         fsdp_cpu_offload_params: Optional[bool] = False,
+        fsdp_use_orig_params: Optional[bool] = False,
     ):
         self.samples_per_gpu = samples_per_gpu
         self.save_samples = save_samples
@@ -48,7 +59,8 @@ def __init__(
             deepspeed_cpu_offload_optimizer_ratio
         )
         self.fsdp_cpu_offload_params = fsdp_cpu_offload_params
-
+        self.fsdp_use_orig_params = fsdp_use_orig_params
+        self.lr_scheduler = None
         if self.distributed_framework == DistributedBackend.DEEPSPEED:
             # Standard
             accel_args = {
@@ -84,7 +96,6 @@ def prepare_with_optimizer(
         num_epochs: int,
         num_warmup_steps: int,
     ):
-        self.lr_scheduler: Callable
         self.setup_lr_scheduler(
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
@@ -120,19 +131,6 @@ def __getattr__(self, name):
         return getattr(self.accelerator, name)
 
     def get_fsdp_config(self):
-        # Standard
-        from functools import partial
-
-        # Third Party
-        from accelerate.utils import FullyShardedDataParallelPlugin
-        from peft.utils.other import fsdp_auto_wrap_policy
-        from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
-        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
-        from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-
-        # First Party
-        from instructlab.training.utils import get_module_class_from_name
-
         is_lora = self.model.lora_config is not None
         block_name = self.model._no_split_modules[0]
 
@@ -158,20 +156,16 @@ def get_fsdp_config(self):
             backward_prefetch=prefetch_policy,
             sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
             cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
+            use_orig_params=self.fsdp_use_orig_params,
+            # TODO(osilkin): expose switch for fp32 reduction
         )
 
-        # `use_orig_params` must be disabled when using LoRA and FSDP together
-        # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
-        if self.model.lora_config is not None:
-            fsdp_plugin.use_orig_params = False
 
         return fsdp_plugin
 
     def get_ds_plugin(
         self, world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOptions
     ):
-        # Third Party
-        from accelerate.utils import DeepSpeedPlugin
 
         ds_config = {
             "train_batch_size": samples_per_gpu * world_size * grad_accum,
@@ -248,3 +242,12 @@ def setup_fsdp(
             fsdp_cpu_offload_params=fsdp_cpu_offload_params,
             save_samples=save_samples,
         )
+
+    def take_optimizer_step(self):
+        """
+        Take an optimizer step and update the learning rate scheduler.
+        """
+        self.clip_grad_norm_(self.model.parameters(), 1.0)
+        self.optimizer.step()
+        self.lr_scheduler.step()
+        self.optimizer.zero_grad()
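
Taken together, these changes hoist the FSDP/DeepSpeed imports to module scope, expose `use_orig_params` as an explicit `fsdp_use_orig_params` constructor flag (the automatic LoRA override is gone, so callers must pass `False` themselves when combining LoRA with FSDP), and bundle gradient clipping, the optimizer step, the LR-scheduler step, and `zero_grad()` into `take_optimizer_step()`. A rough usage sketch follows; the training loop and most keyword arguments are assumptions for illustration, and only `fsdp_use_orig_params` and `take_optimizer_step()` come from this diff:

# Hypothetical training-step sketch -- `model`, `train_loader`, optimizer setup,
# and most keyword arguments are assumed; only `fsdp_use_orig_params` and
# `take_optimizer_step()` are introduced by this diff.
accelerator = Accelerator(
    model=model,
    samples_per_gpu=samples_per_gpu,
    save_samples=save_samples,
    distributed_framework=DistributedBackend.FSDP,  # assumed enum member name
    fsdp_cpu_offload_params=False,
    fsdp_use_orig_params=False,  # keep False when combining LoRA with FSDP
)
accelerator.prepare_with_optimizer(
    optimizer=optimizer,          # full signature not shown in this diff
    lr_scheduler="cosine",
    num_epochs=num_epochs,
    num_warmup_steps=num_warmup_steps,
)

for batch in train_loader:
    loss = ...                    # forward pass / loss computation elsewhere
    accelerator.backward(loss)
    accelerator.take_optimizer_step()  # clip, optimizer.step(), scheduler.step(), zero_grad()
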
188 changes: 188 additions & 0 deletions src/instructlab/training/batch_loss_manager.py
@@ -0,0 +1,188 @@
# SPDX-License-Identifier: Apache-2.0
"""
Batch loss management for distributed training.

This module provides utilities for managing loss computation, accumulation,
and reduction across distributed training environments.
"""

# Standard
from dataclasses import dataclass
import logging

# Third Party
import torch
import torch.distributed

# First Party
from instructlab.training.accelerator import Accelerator
from instructlab.training.model import Model
from instructlab.training.type_definitions import CollatedItem, ModelInputs

logger = logging.getLogger("instructlab.training")


@dataclass
class BatchMetrics:
"""Metrics collected during batch processing."""

total_samples: int
total_length: int
num_loss_counted_tokens: int
accumulated_loss: torch.Tensor
accumulated_aux_loss: torch.Tensor | None
grad_accum_steps: int
num_minibatches: int


class BatchLossManager:
"""
Manages loss computation and metrics collection for batches in distributed training.

This class handles:
- Processing minibatches within a batch
- Accumulating losses across minibatches
- Reducing metrics across distributed ranks
- Computing average losses for logging
"""

def __init__(self, model, accelerator, world_size: int, local_rank: int):
"""
Initialize the BatchLossManager.

Args:
model: The model used for training
accelerator: The accelerator instance for distributed training
world_size: Number of distributed processes
local_rank: Local rank of the current process
"""
self.model: Model = model
self.accelerator: Accelerator = accelerator
self.world_size: int = world_size
self.local_rank: int = local_rank
self.torch_device = torch.device("cuda", local_rank)

def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]:
"""
Process a batch of minibatches, computing losses and accumulating gradients.

Args:
batch: List of minibatches to process

Returns:
tuple: (BatchMetrics, average_loss_across_ranks)
"""
# extract batch-level info (same across all minibatches)
batch_num_loss_counted_tokens = batch[0]["batch_num_loss_counted_tokens"]
num_minibatches = len(batch)

# initialize accumulation variables
batch_total_samples = 0
batch_total_length = 0
accumulated_loss = 0.0
accumulated_aux_loss = 0.0
grad_accum_steps = 0

# process each minibatch
for mb in batch:
# extract minibatch-specific info
micro_batch_size = mb["num_samples"]
total_length = mb["total_length"]

# accumulate minibatch metrics
batch_total_samples += micro_batch_size
batch_total_length += total_length

# prepare model inputs
model_inputs = self._prepare_model_inputs(mb)

# compute loss and backward pass
scaled_loss, raw_losses = self.model.compute_loss(
model_inputs, self.world_size, batch_num_loss_counted_tokens
)
self.accelerator.backward(scaled_loss)

# accumulate losses
grad_accum_steps += 1
accumulated_loss += raw_losses.main_loss
if raw_losses.aux_loss is not None:
accumulated_aux_loss += raw_losses.aux_loss

# reduce metrics across ranks
batch_total_samples, batch_total_length = self._reduce_metrics(
batch_total_samples, batch_total_length
)

# calculate average loss across all ranks
avg_loss_across_ranks = self._compute_average_loss(
accumulated_loss, accumulated_aux_loss, batch_num_loss_counted_tokens
)

# create metrics object
metrics = BatchMetrics(
total_samples=int(batch_total_samples),
total_length=int(batch_total_length),
num_loss_counted_tokens=int(batch_num_loss_counted_tokens),
accumulated_loss=accumulated_loss,
accumulated_aux_loss=accumulated_aux_loss,
grad_accum_steps=grad_accum_steps,
num_minibatches=num_minibatches,
)

return metrics, avg_loss_across_ranks

def _prepare_model_inputs(self, mb: CollatedItem) -> ModelInputs:
"""Prepare and move model inputs to GPU."""
model_inputs = ModelInputs(
input_ids=mb["input_ids"].to(device=self.torch_device),
labels=mb["labels"].to(device=self.torch_device),
)

# add optional fields onto `model_inputs` object
if "attention_mask" in mb:
model_inputs["attention_mask"] = mb["attention_mask"].to(
device=self.torch_device
)
if "position_ids" in mb:
model_inputs["position_ids"] = mb["position_ids"].to(
device=self.torch_device
)

return model_inputs

def _reduce_metrics(
self, batch_total_samples: int, batch_total_length: int
) -> tuple[int, int]:
"""Reduce rank-specific metrics across devices."""
inputs_to_reduce = torch.tensor(
[batch_total_samples, batch_total_length],
dtype=torch.int32,
device=self.accelerator.device,
)

reduced_outputs = self.accelerator.reduce(inputs_to_reduce, reduction="sum")
return reduced_outputs[0].item(), reduced_outputs[1].item()

def _compute_average_loss(
self,
accumulated_loss: torch.Tensor,
accumulated_aux_loss: torch.Tensor | None,
batch_num_loss_counted_tokens: int,
) -> float:
"""Compute average loss across all ranks for metrics logging."""
# calculate total batch loss
total_batch_loss = (
accumulated_loss * self.world_size / batch_num_loss_counted_tokens
)
if self.model.is_gpt_oss and accumulated_aux_loss is not None:
total_batch_loss += accumulated_aux_loss

# reduce across ranks
avg_loss_across_ranks = self.accelerator.reduce(
torch.tensor(
total_batch_loss.detach().item(), device=self.accelerator.device
),
reduction="mean",
).item()

return avg_loss_across_ranks
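
For context on how this class is meant to be driven, here is a minimal, hypothetical integration sketch; the loop, logging, and surrounding objects are assumptions, while the constructor arguments, the `process_batch()` contract, and the `BatchMetrics` fields come from the file above:

# Hypothetical driver loop -- only BatchLossManager's public surface is taken
# from this file; `train_loader`, `accelerator`, and the logging are assumed.
loss_manager = BatchLossManager(
    model=model,
    accelerator=accelerator,
    world_size=world_size,
    local_rank=local_rank,
)

for batch in train_loader:  # each `batch` is a list[CollatedItem] of minibatches
    metrics, avg_loss = loss_manager.process_batch(batch)  # runs backward() internally
    accelerator.take_optimizer_step()  # gradients were accumulated inside process_batch()
    if local_rank == 0:
        logger.info(
            "samples=%d loss_tokens=%d grad_accum=%d avg_loss=%.4f",
            metrics.total_samples,
            metrics.num_loss_counted_tokens,
            metrics.grad_accum_steps,
            avg_loss,
        )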