src/instructlab/training/accelerator.py (19 changes: 16 additions & 3 deletions)

@@ -3,7 +3,12 @@
from typing import Callable, Optional

# Third Party
from accelerate import Accelerator as TransformersAccel
from instructlab.training.hpu_utils import is_torch_hpu_available
if is_torch_hpu_available():
    from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel
else:
    from accelerate import Accelerator as TransformersAccel

from torch.utils.data import DataLoader
from transformers import get_scheduler
import torch
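
Since both branches bind to the same TransformersAccel name, the rest of the module stays backend-agnostic. A minimal illustrative sketch, not part of the diff; the constructor argument is just an example of a standard accelerate option:

# Illustrative sketch only: whichever class was imported above is bound to the
# same name, so call sites are identical on HPU and non-HPU machines.
accelerator = TransformersAccel(
    gradient_accumulation_steps=4,  # example value; any supported kwarg works the same way
)
print(type(accelerator).__name__)  # "GaudiAccelerator" on an HPU install, "Accelerator" otherwise
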
@@ -124,7 +129,11 @@ def get_fsdp_config(self):
        from functools import partial

        # Third Party
        from accelerate.utils import FullyShardedDataParallelPlugin
        if is_torch_hpu_available():
            from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
        else:
            from accelerate.utils import FullyShardedDataParallelPlugin

        from peft.utils.other import fsdp_auto_wrap_policy
        from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
@@ -152,14 +161,18 @@ def get_fsdp_config(self):
        prefetch_policy = (
            BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
        )
        fsdp_plugin = FullyShardedDataParallelPlugin(
        fsdp_plugin_class = (
            GaudiFullyShardedDataParallelPlugin
            if is_torch_hpu_available()
            else FullyShardedDataParallelPlugin
        )
        fsdp_plugin = fsdp_plugin_class(
            auto_wrap_policy=wrap_policy,
            limit_all_gathers=True,
            backward_prefetch=prefetch_policy,
            sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
            cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
        )

        if is_torch_hpu_available():
            fsdp_plugin.use_orig_params = True
            fsdp_plugin.sync_module_states = True

        # `use_orig_params` must be disabled when using LoRA and FSDP together
        # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
        if self.model.lora_config is not None:
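
For context on the LoRA note that ends the hunk above, a hedged sketch of what the disabling typically looks like, following the linked PEFT guide. The lora_config attribute appears in the hunk; the model handle passed to fsdp_auto_wrap_policy is an assumption, not something shown in this diff.

# Illustrative sketch only; not the PR's code.
if self.model.lora_config is not None:
    fsdp_plugin.use_orig_params = False  # required when combining LoRA with FSDP, per the PEFT docs
    fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model.model)  # assumed model handle
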
src/instructlab/training/hpu_utils.py (62 changes: 62 additions & 0 deletions)
@@ -0,0 +1,62 @@
import os
import torch
from functools import lru_cache


@lru_cache(maxsize=None)
def is_torch_hpu_available() -> bool:
    try:
        import habana_frameworks.torch.core  # noqa: F401
    except ImportError:
        return False
    return True
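
A small usage sketch; the pick_device helper is hypothetical and not part of this file. Because of lru_cache, the habana_frameworks import is attempted at most once per process.

# Hypothetical usage sketch, not part of the PR.
import torch

def pick_device() -> str:
    if is_torch_hpu_available():
        return "hpu"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
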


def simple_bucket(length):
"""
This bucket algorithm merely relies on the given number instead of based on
slicing the known (min, max) range for several reasons:
1) Due to the use of the first-fit-decreasing (FFD) algorithm, the
(min, max) sequence length of each rank will be much smaller than the
(min, max) sequence length of the dataset. Bucketing on the
(min, max) sequence length of the dataset is not practical
2) The (min, max) sequence length of a given rank is unknown until
finishing 1 epoch since the packing is done on the fly
3) Due to the shuffling, the (min, max) sequence length of a given rank
may vary between ranks. Once the (min, max) sequence length of a
given rank changes, the bucketing also needs adjustment

This bucket algorithm is based on the most significant set bit of the input number.
It first check what’s the most significant set bit, assuming it's bit "S",
and then slice the range [2 ** S, 2 ** (S+1)] into buckets with the same size.
By default the range is divided into 16 buckets, so the bucket size will be
2 ** (S - 4)
For example, 0b10001 will be padded to 0b10010.
This approach can limit the overhead of bucketing (at most 1/16 of the input
number) and also prevent recompilation due to a too small bucket size.
"""
    # bit_length() is the 1-based position of the most significant set bit
    msb = length.bit_length()

    # bucket size: 1/8 of the power-of-two range that `length` falls into
    align = (1 << (msb - 4)) if msb >= 4 else 1

    return (length + align - 1) // align * align


def bucket(length):
    return simple_bucket(length)
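
A few worked values for the arithmetic above (illustration only, computed from the function):

# Worked examples for bucket()/simple_bucket():
#   17   -> 5 bits,  align 2,  padded to 18   (0b10001 -> 0b10010)
#   300  -> 9 bits,  align 32, padded to 320
#   1000 -> 10 bits, align 64, padded to 1024
assert bucket(17) == 18
assert bucket(300) == 320
assert bucket(1000) == 1024
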


def save_hpu_model(model, output_dir):
    from safetensors.torch import save_file

    state_dict = model.state_dict()
    # strip the "_orig_mod." prefix that torch.compile adds to parameter names
    remove_prefix = "_orig_mod."
    clean_state_dict = {
        k[len(remove_prefix) :] if k.startswith(remove_prefix) else k: v
        for k, v in state_dict.items()
    }
    save_file(clean_state_dict, os.path.join(output_dir, "model.safetensors"))
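
A hypothetical call site; the path and the model variable are assumptions, and save_hpu_model itself only writes the file, so the directory must already exist:

# Hypothetical usage sketch, not part of the PR.
import os

output_dir = "output/hpu_checkpoint"  # example path
os.makedirs(output_dir, exist_ok=True)
save_hpu_model(model, output_dir)  # model assumed in scope; writes <output_dir>/model.safetensors
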