diff --git a/docs/hpu.md b/docs/hpu.md new file mode 100644 index 00000000..357015c3 --- /dev/null +++ b/docs/hpu.md @@ -0,0 +1,61 @@ +# InstructLab Training on HPU + +## HPU specific changes +Next changes are required to enable training on HPU: + +|GPU|HPU| +|---|---| +|`from accelerate import Accelerator` | `from optimum.habana.accelerate import GaudiAccelerator`| +|`from accelerate.utils import FullyShardedDataParallelPlugin` | `from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin` | + +It is also recommended to use HPU optimized versions of transformers: + +```Python +from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi +adapt_transformers_to_gaudi() +``` + +## Bucketing +Multipack sampler implementation produces wide range of batches with different sample lengths and number of samples. Each of these combinations leads to graph recompilation and this recompilation takes time and slows down training. To reduce number of recompilations HPU implementation uses bucketing approach, when maximum sample length in batch is aligned to some predefined value. It is similar to padding but all samples in the batch are padded not to the longest sample but to the some slightly bigger value. + +![bucketing vs. padding](./hpu_pic/bucketing_vs_padding.png) + + +To compute bucked size, we use next algorithm: +- Firstly, we find MSB of the longest sample in the batch, let's call it S. +- Then we slice the range [2 ** S, 2 ** (S+1)] into 16 buckets of the same size. +- Then we use top boundary of the smallest suitable bucked as padding value. + +This approach limits overhead of the bucketing to 1/16 th of the longest sample and allows us to significantly reduce number of recompilations. + +## How to run +To run training build docker using next dockerfile: +```Dockerfile +FROM vault.habana.ai/gaudi-docker/1.21.0/rhel9.4/habanalabs/pytorch-installer-2.6.0:1.21.0-555 + +ARG CMAKE_ARGS="-DGGML_NATIVE=off" + +WORKDIR /app +RUN pip install git+https://github.com/instructlab/instructlab.git@v0.26.1 + +WORKDIR /app +RUN pip install git+https://github.com/huggingface/optimum-habana.git@v1.18.0 +``` + +Then make next changes to config file: +```YAML +train: + device: hpu + distributed_backend: fsdp + fsdp_cpu_offload_optimizer: false + is_padding_free: true + pipeline: accelerated + disable_flash_attn: true +``` + +And finally run this command line: +```BASH +ilab --config=./config.yaml model train --pipeline accelerated --data-path ./data.jsonl +``` + + diff --git a/docs/hpu_pic/bucketing_vs_padding.png b/docs/hpu_pic/bucketing_vs_padding.png new file mode 100644 index 00000000..cfe1b365 Binary files /dev/null and b/docs/hpu_pic/bucketing_vs_padding.png differ diff --git a/src/instructlab/training/accelerator.py b/src/instructlab/training/accelerator.py index b03c4a45..a6b6950f 100644 --- a/src/instructlab/training/accelerator.py +++ b/src/instructlab/training/accelerator.py @@ -3,7 +3,6 @@ from typing import Callable, Optional # Third Party -from accelerate import Accelerator as TransformersAccel from torch.utils.data import DataLoader from transformers import get_scheduler import torch @@ -32,6 +31,7 @@ def __init__( deepspeed_cpu_offload_optimizer_pin_memory: Optional[bool] = False, deepspeed_cpu_offload_optimizer_ratio: Optional[float] = None, fsdp_cpu_offload_params: Optional[bool] = False, + device: Optional[str] = None, ): self.samples_per_gpu = samples_per_gpu self.save_samples = save_samples @@ -48,6 +48,7 @@ def __init__( deepspeed_cpu_offload_optimizer_ratio ) self.fsdp_cpu_offload_params = fsdp_cpu_offload_params + self.device_str = device if self.distributed_framework == DistributedBackend.DEEPSPEED: # Standard @@ -69,6 +70,12 @@ def __init__( "fsdp_plugin": self.get_fsdp_config(), "mixed_precision": "bf16", } + + if device == "hpu": + from optimum.habana.accelerate import GaudiAccelerator as TransformersAccel + else: + from accelerate import Accelerator as TransformersAccel + self.accelerator = TransformersAccel( **accel_args, ) @@ -160,6 +167,10 @@ def get_fsdp_config(self): cpu_offload=CPUOffload(self.fsdp_cpu_offload_params), ) + if self.device_str == "hpu": + fsdp_plugin.use_orig_params=True + fsdp_plugin.sync_module_states=True + # `use_orig_params` must be disabled when using LoRA and FSDP together # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts if self.model.lora_config is not None: diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 3c4e65f3..3743cdc4 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -246,3 +246,5 @@ class TrainingArgs(BaseModel): log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field( default="INFO" ) + + device: Optional[str] = None diff --git a/src/instructlab/training/hpu_utils.py b/src/instructlab/training/hpu_utils.py new file mode 100644 index 00000000..15a951b7 --- /dev/null +++ b/src/instructlab/training/hpu_utils.py @@ -0,0 +1,49 @@ +import torch +from functools import lru_cache + + +@lru_cache(maxsize=None) +def is_torch_hpu_available() -> bool: + try: + import habana_frameworks.torch.core # noqa: F401 + except ImportError: + return False + return True + + +def simple_bucket(length): + """ + This bucket algorithm merely relies on the given number instead of based on + slicing the known (min, max) range for several reasons: + 1) Due to the use of the first-fit-decreasing (FFD) algorithm, the + (min, max) sequence length of each rank will be much smaller than the + (min, max) sequence length of the dataset. Bucketing on the + (min, max) sequence length of the dataset is not practical + 2) The (min, max) sequence length of a given rank is unknown until + finishing 1 epoch since the packing is done on the fly + 3) Due to the shuffling, the (min, max) sequence length of a given rank + may vary between ranks. Once the (min, max) sequence length of a + given rank changes, the bucketing also needs adjustment + + This bucket algorithm is based on the most significant set bit of the input number. + It first check what’s the most significant set bit, assuming it's bit "S", + and then slice the range [2 ** S, 2 ** (S+1)] into buckets with the same size. + By default the range is divided into 16 buckets, so the bucket size will be + 2 ** (S - 4) + For example, 0b10001 will be padded to 0b10010. + This approach can limit the overhead of bucketing (at most 1/16 of the input + number) and also prevent recompilation due to a too small bucket size. + """ + l = length + msb = 0 + while l > 0: + msb += 1 + l = l // 2 + + align = (1 << (msb - 4)) if msb >= 4 else 1 + + return (length + align - 1) // align * align + + +def bucket(length): + return simple_bucket(length) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index df166fd4..d03565d3 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -33,6 +33,14 @@ UserWarning, ) +from instructlab.training.hpu_utils import is_torch_hpu_available + +if is_torch_hpu_available(): + import habana_frameworks.torch.core as htcore + import habana_frameworks.torch.distributed.hccl + from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi + adapt_transformers_to_gaudi() + # Third Party from tqdm import tqdm from transformers import AutoConfig @@ -122,7 +130,7 @@ def train( if local_rank == 0: inner_pb = tqdm(range(num_epoch_steps), desc=f"Epoch {epoch}") - # blast through the batches in the train loader up to the last step within the epoch. + # blast through the batches in the train loader up to the last step within the epoch. for batch in accelerator.train_loader: if global_step <= args.last_step: # in the case of resuming, last_step > 0 @@ -137,10 +145,19 @@ def train( micro_batch_size = float(torch.tensor([batch.pop("num_samples")])) total_length = float(torch.tensor([batch.pop("total_length")])) for k in batch: - batch[k] = batch[k].to(local_rank) + batch[k] = batch[k].to('hpu' if args.device == "hpu" else local_rank) + + hpu_args = {} + if args.device == "hpu": + hpu_args = { + "use_flash_attention":True, + "lazy_mode":False, + } + output = model( **batch, use_cache=False, + **hpu_args, ) loss = output.loss log_loss = loss.detach().item() @@ -177,8 +194,14 @@ def train( elapsed_time = time.time() - start overall_throughput = args.samples_per_gpu * world_size / elapsed_time current_lr = accelerator.lr_scheduler.get_last_lr()[0] - cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3) - cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] + + if args.device == "hpu": + mem_allocated = torch.hpu.memory_allocated() / (1024**3) + malloc_retries = 0 + else: + mem_allocated = torch.cuda.memory_allocated() / (1024**3) + malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] + global_grad_norm = ( model.get_global_grad_norm() if hasattr(model, "get_global_grad_norm") @@ -200,8 +223,8 @@ def train( "rank": torch.distributed.get_rank(), "overall_throughput": overall_throughput, "lr": current_lr, - "cuda_mem_allocated": cuda_mem_allocated, - "cuda_malloc_retries": cuda_malloc_retries, + ("hpu" if args.device == "hpu" else "cuda") + "_mem_allocated": mem_allocated, + ("hpu" if args.device == "hpu" else "cuda") + "_malloc_retries": malloc_retries, "num_loss_counted_tokens": int(num_loss_counted_tokens), "num_tokens_rank0": int(total_length), "batch_size": int(micro_batch_size), @@ -234,7 +257,10 @@ def train( global_step += 1 if local_rank == 0: inner_pb.update(1) - torch.cuda.empty_cache() + + if args.device != "hpu": + torch.cuda.empty_cache() + if args.checkpoint_at_epoch: base_logger.debug(f"Saving checkpoint at epoch {epoch}") save_checkpoint( @@ -312,17 +338,24 @@ def main(args): args.model_type = model_conf.model_type #### distributed init ##### - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + if args.device == "hpu": + torch.hpu.set_device(int(os.environ["LOCAL_RANK"])) + else: + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + args.local_rank = int(os.environ["LOCAL_RANK"]) timeout = _get_collective_timeout() - if timeout is not None: - torch.distributed.init_process_group(timeout=timeout) - else: - torch.distributed.init_process_group() + backend = "hccl" if args.device == "hpu" else None + torch.distributed.init_process_group(backend=backend, timeout=timeout) args.global_rank = torch.distributed.get_rank() - tensor = torch.ByteTensor([False]).cuda() + + if args.device == "hpu": + tensor = torch.ByteTensor([False]).to('hpu') + else: + tensor = torch.ByteTensor([False]).cuda() + torch.distributed.all_reduce(tensor) torch.distributed.barrier() @@ -369,6 +402,7 @@ def main(args): flash_enabled=flash_enabled, noise_alpha=args.NEFTune_alpha, lora_quant_bits=args.lora_quant_bits, + device=args.device, ) args.base_model_args = m.base_model_args @@ -407,6 +441,7 @@ def main(args): samples_per_gpu=args.samples_per_gpu, sampler=args.sampler, seed=args.seed, + device=args.device, ) if len(train_loader) == 0: # this happens sometimes when we have more GPUs than data to process. In this case @@ -426,6 +461,7 @@ def main(args): samples_per_gpu=args.samples_per_gpu, sampler=args.sampler, seed=args.seed, + device=args.device, ) if args.local_rank == 0: @@ -457,6 +493,7 @@ def main(args): deepspeed_cpu_offload_optimizer_ratio=args.cpu_offload_optimizer_ratio, fsdp_cpu_offload_params=args.cpu_offload_params_fsdp, save_samples=args.save_samples, + device=args.device, ) # optimizer needs model that has been prepared by accelerator # and then accelerator needs to be prepared AGAIN once optimizer is initialized @@ -636,6 +673,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.keep_last_checkpoint_only: command.append("--keep_last_checkpoint_only") + command.append( + f"--device={train_args.device}" + ) + logger.info("Running training command as subprocess: %s", " ".join(command)) process = None interrupt: KeyboardInterrupt | Exception | None = None @@ -837,6 +878,14 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: action="store_true", help="Use Liger kernels for training.", ) + + parser.add_argument( + "--device", + type=str, + default=None, + help="PyTorch device to use.", + ) + args = parser.parse_args() set_random_seed(args.seed) main(args) diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py index 24eac063..fd72dfc3 100644 --- a/src/instructlab/training/model.py +++ b/src/instructlab/training/model.py @@ -50,11 +50,13 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + device: Optional[str] = None, ): self.lora_config = lora_config self.noise_alpha = noise_alpha self.tokenizer = tokenizer self.distributed_framework = distributed_framework + self.device = device bnb_config = None if lora_config and lora_config.r > 0 and lora_quant_bits == 4: # Third Party @@ -78,6 +80,15 @@ def __init__( def _post_model_init(self): """Common initialization steps that should happen after model initialization.""" + + if self.device == "hpu" and os.getenv("HPU_ENABLE_TORCH_COMPILE", False): + cache_size_limit = 10*1000 + torch._dynamo.config.cache_size_limit = cache_size_limit + torch._dynamo.config.accumulated_cache_size_limit = 2*cache_size_limit + self.model = torch.compile(self.model, backend="hpu_backend", dynamic=False) + for layer in self.model.model.layers: + layer.compile(backend="hpu_backend", dynamic=False) + self.reconcile_tokenizer() if self.lora_config: self.model = self.prepare_peft_model() @@ -246,7 +257,11 @@ def _is_causal_lm_model(self) -> bool: bool: True if the model is a causal language model, False otherwise. """ # Third Party - return "ForCausalLM" in self.model.__class__.__name__ + if self.device != "hpu": + class_name = self.model.__class__.__name__ + else: + class_name = self.model._orig_mod.__class__.__name__ if self.model.__class__.__name__ == 'OptimizedModule' else self.model.__class__.__name__ + return "ForCausalLM" in class_name def reconcile_tokenizer(self): if len(self.tokenizer) > self.model.config.vocab_size: @@ -302,6 +317,17 @@ def reconcile_tokenizer(self): ): self.model.config.eos_token_id = self.tokenizer.eos_token_id + if self.device == "hpu": + model = self.model._orig_mod if self.model.__class__.__name__ == 'OptimizedModule' else self.model + class_name = model.__class__.__name__ + + replace_no_split_modules = { + 'GaudiLlamaForCausalLM': ['GaudiLlamaDecoderLayer',] + } + + if class_name in replace_no_split_modules: + model._no_split_modules = replace_no_split_modules[class_name] + if not self._is_causal_lm_model(): raise ValueError( f"Model must be a causal language model, got {type(self.model)}" @@ -358,6 +384,7 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + device: Optional[str] = None, ): super().__init__( model_path=model_path, @@ -367,6 +394,7 @@ def __init__( flash_enabled=flash_enabled, lora_config=lora_config, lora_quant_bits=lora_quant_bits, + device=device, ) try: # Third Party @@ -400,6 +428,7 @@ def __init__( flash_enabled: bool = False, lora_config: Optional[LoraConfig] = None, lora_quant_bits: int = 0, + device: Optional[str] = None, ): super().__init__( model_path=model_path, @@ -409,6 +438,7 @@ def __init__( flash_enabled=flash_enabled, lora_config=lora_config, lora_quant_bits=lora_quant_bits, + device=device, ) # Third Party from transformers import AutoModelForCausalLM diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py index 6b9a4941..a48f8e12 100644 --- a/src/instructlab/training/multipack_sampler.py +++ b/src/instructlab/training/multipack_sampler.py @@ -34,6 +34,8 @@ import torch import torch.distributed as dist +from instructlab.training.hpu_utils import is_torch_hpu_available, bucket + def find_max_pack_len_with_padding( dataset, @@ -68,9 +70,14 @@ def get_effective_samples_per_minibatch(num_tokens_per_gpu): The function creates a sampler using the MultipackDistributedBatchSampler class, generates batches using the sampler, and then returns the ratio of the dataset size to the number of batches. """ + lengths=dataset.get_lengths() + if is_torch_hpu_available(): + bucket_v = np.vectorize(bucket) + lengths = bucket_v(lengths) + sampler = MultipackDistributedBatchSampler( batch_max_length=num_tokens_per_gpu, - lengths=dataset.get_lengths(), + lengths=lengths, num_replicas=torch.distributed.get_world_size(), rank=torch.distributed.get_rank(), seed=seed, diff --git a/src/instructlab/training/token_dataset.py b/src/instructlab/training/token_dataset.py index 38b3a6f9..3ca726ce 100644 --- a/src/instructlab/training/token_dataset.py +++ b/src/instructlab/training/token_dataset.py @@ -13,6 +13,7 @@ from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler from instructlab.training.utils import log_rank_0, make_collate_fn +from instructlab.training.hpu_utils import bucket class TokenDataset(Dataset): def __init__(self, data_path): @@ -96,15 +97,21 @@ def setup_dataloader( samples_per_gpu=None, sampler="multipack", seed=47, + device=None, ) -> DataLoader: collate_fn = make_collate_fn( - pad_token_id, flash_enabled=flash_enabled, max_batch_len=max_batch_len + pad_token_id, flash_enabled=flash_enabled, max_batch_len=max_batch_len, + device=device, ) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) lengths = dataset.get_lengths() if sampler == "multipack": + if device == "hpu": + bucket_v = np.vectorize(bucket) + lengths = bucket_v(lengths) + sampler = MultipackDistributedBatchSampler( batch_max_length=packing_max_batch_len, lengths=lengths, diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 15dd2897..5988813d 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -37,6 +37,7 @@ TrainingArgs, ) from instructlab.training.model import Model +from instructlab.training.hpu_utils import is_torch_hpu_available, bucket logger = logging.getLogger("instructlab.training") @@ -197,7 +198,7 @@ def listen(self): break -def make_collate_fn(pad_token_id, flash_enabled=True, max_batch_len=60000): +def make_collate_fn(pad_token_id, flash_enabled=True, max_batch_len=60000, device=None): if flash_enabled: def pad_collate_fn(batch): @@ -234,6 +235,9 @@ def pad_collate_fn(batch): lens = np.array([len(item["input_ids"]) for item in batch]) max_len = max(lens) + if device=="hpu": + max_len = bucket(max_len) + input_ids = torch.stack( [ F.pad( @@ -311,6 +315,7 @@ def reduce_sum_forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **(_deprecated_arguments if is_torch_hpu_available() else {}), ) return_dict = isinstance(output, dict) @@ -608,7 +613,10 @@ def set_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + if is_torch_hpu_available(): + torch.hpu.manual_seed_all(seed) + else: + torch.cuda.manual_seed_all(seed) def save_checkpoint(