From bbb9a6619fc8c2d82d36797bc851c4ecb23eb36a Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Mon, 29 Jul 2024 22:57:38 -0700 Subject: [PATCH 01/23] add automicrobatching for non-powers-of-2 + adaptive sync hooks --- composer/distributed/dist_strategy.py | 24 ++- composer/trainer/_patch_pytorch.py | 74 +++++++ composer/trainer/trainer.py | 208 ++++++++++++-------- composer/utils/__init__.py | 12 ++ composer/utils/automicrobatching.py | 266 ++++++++++++++++++++++++++ tests/trainer/test_fsdp.py | 68 ++++++- 6 files changed, 563 insertions(+), 89 deletions(-) create mode 100644 composer/utils/automicrobatching.py diff --git a/composer/distributed/dist_strategy.py b/composer/distributed/dist_strategy.py index be81652881..d8a5b4b8d1 100644 --- a/composer/distributed/dist_strategy.py +++ b/composer/distributed/dist_strategy.py @@ -7,7 +7,7 @@ import logging import warnings from contextlib import contextmanager, nullcontext -from typing import Any, Callable, ContextManager, Iterator, Optional, Sequence, Union, cast +from typing import Any, Callable, ContextManager, Iterator, Optional, Sequence, Union, Tuple, cast import torch from packaging import version @@ -203,7 +203,7 @@ def prepare_fsdp_module( device: Device, auto_microbatching: bool, te_rng_seed: int = 1234, -) -> None: +) -> Tuple[list, dict]: """Prepare a module (assumed ComposerModel) and optimizer for use with :class:`torch.distributed.fsdp.FullyShardedDataParallel`. Args: @@ -230,6 +230,9 @@ def prepare_fsdp_module( 'some weights may be randomly initialized when loading a checkpoint.', ) + # Handles of FSDP sync hooks if automicrobatching is on + hook_handles = [] + # Check if other ranks OOMed after forward/backward pass when using auto microbatching. This # may happen when close to memory limit or with uneven memory usage across ranks. 
Since we # need to do this before the model weights are gathered for the next FSDP block, we wrap every @@ -512,9 +515,6 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]: ret = obj.fsdp_wrap_fn(module) if isinstance(ret, dict): ret = set_custom_fsdp_module_kwargs(ret, process_group_cache) - if ret and auto_microbatching: - module.register_forward_hook(sync_hook) - module.register_full_backward_hook(sync_hook) return ret _auto_wrap_policy = CustomPolicy(lambda_fn) @@ -531,9 +531,6 @@ def __auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable): should_be_wrapped = obj.fsdp_wrap_fn(module) - if should_be_wrapped and auto_microbatching: - module.register_forward_hook(sync_hook) - module.register_full_backward_hook(sync_hook) return should_be_wrapped def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool: @@ -567,6 +564,15 @@ def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_num log.info(f'Calling prepare_te_modules_for_fsdp to enable TE weights sharding') prepare_te_modules_for_fsdp(fsdp_obj) + + if auto_microbatching: + for _, module in fsdp_obj.named_modules(): + if isinstance(module, FullyShardedDataParallel): + hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) + hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) + else: + hook_handles.append(module.register_full_backward_hook(sync_hook)) + if hasattr(fsdp_obj, '_exec_order_data'): if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'): fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config.forward_prefetch_limit @@ -727,3 +733,5 @@ def _check_fn(module: torch.nn.Module) -> bool: assert optimizer_specific_info is not None optimizer_specific_info.update({'params': list(model.parameters())}) optim.add_param_group(optimizer_specific_info) + + return hook_handles, dict(fsdp_obj.named_modules()) \ No newline at end of file diff --git a/composer/trainer/_patch_pytorch.py b/composer/trainer/_patch_pytorch.py index 3f19df7d2a..25cf365721 100644 --- a/composer/trainer/_patch_pytorch.py +++ b/composer/trainer/_patch_pytorch.py @@ -30,9 +30,16 @@ from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform from torch.distributed.utils import _replace_by_prefix +from torch.distributed.fsdp._flat_param import FlatParamHandle log = logging.getLogger(__name__) +def patch_unshard_for_automicrobatching(auto_microbatch_size_found=False): + if auto_microbatch_size_found: + FlatParamHandle.unshard = (unshard) + print("dropping monkey") + else: + FlatParamHandle.unshard = (unshard_with_sync) def patch_pytorch(): """Monkey patches pytorch functions based on pytorch version.""" @@ -122,6 +129,73 @@ def patch_pytorch(): _MeshEnv.create_child_mesh = create_child_mesh DeviceMesh.__getitem__ = device_mesh__getitem__ +@no_type_check +def unshard(self): + """ + Run the unshard logic. + This is an unpatched method from pytorch, meant to be reverted to + whenever automicrobatching turns off its hooks for increased throughput. + This includes all-gathering the flat parameter + and switching to using the unsharded flat parameter. If the handle does + not need unsharding, then this only switches to using the unsharded + flat parameter. For ``NO_SHARD``, this is a no-op. 
+ If FSDP is in :meth:`summon_full_params` and the handle uses parameter + mixed precision, then the parameter is forced to full precision. + """ + if not self.needs_unshard(): + # Even when not needing an unshard, we should switch to using + # the unsharded flat parameter + unsharded_flat_param = ( + self._get_padded_unsharded_flat_param() + if self.uses_sharded_strategy + else self.flat_param + ) + self._use_unsharded_flat_param(unsharded_flat_param) + return + unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param) + self._use_unsharded_flat_param(padded_unsharded_flat_param) + +@no_type_check +def unshard_with_sync(self): + """ + Run the unshard logic, but with a sync after a :meth:`_alloc_padded_unsharded_flat_param` + to prevent deadlocks when some ranks OOM after the alloc call and others do not. + This is a patched method from pytorch, meant to be called when automicrobatching + turns on hooks in its search process for the optimal non-OOMing microbatch size. + This includes all-gathering the flat parameter + and switching to using the unsharded flat parameter. If the handle does + not need unsharding, then this only switches to using the unsharded + flat parameter. For ``NO_SHARD``, this is a no-op. + If FSDP is in :meth:`summon_full_params` and the handle uses parameter + mixed precision, then the parameter is forced to full precision. + """ + if not self.needs_unshard(): + # Even when not needing an unshard, we should switch to using + # the unsharded flat parameter + unsharded_flat_param = ( + self._get_padded_unsharded_flat_param() + if self.uses_sharded_strategy + else self.flat_param + ) + self._use_unsharded_flat_param(unsharded_flat_param) + return + unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + + # Check if any other rank hit an OOM + found_cuda_oom_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True) + + dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') + found_cuda_oom = found_cuda_oom_tensor.item() + # Signal current rank is still in batch + all_ranks_finished_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True) + + dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') + + if found_cuda_oom == 1: + raise RuntimeError('CUDA out of memory encountered on a different rank') + padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param) + self._use_unsharded_flat_param(padded_unsharded_flat_param) def build_metadata( self, diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 7f28d722ba..3cfa30697a 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -17,7 +17,6 @@ import textwrap import time import warnings -from collections import defaultdict from copy import deepcopy from pathlib import Path from typing import ( @@ -41,18 +40,17 @@ from packaging import version from torch._dynamo import OptimizedModule from torch.cuda.amp.grad_scaler import GradScaler -from torch.distributed.fsdp import FullyShardedDataParallel -from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp import FullyShardedDataParallel from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader, DistributedSampler from torchmetrics import Metric if 
version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state # type: ignore + from torch.amp.grad_scaler import GradScaler # type: ignore else: - from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state # type: ignore + from torch.cuda.amp.grad_scaler import GradScaler # type: ignore from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver, OptimizerMonitor from composer.core import ( @@ -99,7 +97,7 @@ from composer.models import ComposerModel from composer.optim import ComposerScheduler, DecoupledSGDW, compile_composer_scheduler from composer.profiler import Profiler -from composer.trainer._patch_pytorch import patch_pytorch +from composer.trainer._patch_pytorch import patch_pytorch, patch_unshard_for_automicrobatching from composer.trainer._scale_schedule import scale_pytorch_scheduler from composer.trainer._scaler import ClosureGradScaler from composer.utils import ( @@ -132,6 +130,13 @@ parse_uri, partial_format, reproducibility, + _create_sync_hook, + _clear_incomplete_train_states, + _found_ooms_across_ranks, + _update_num_consecutive_thrashes, + _handle_downward_search_in_automicrobatching, + _handle_upward_search_in_automicrobatching, + _handle_thrashing_in_automicrobatching, ) if is_xla_installed(): @@ -322,58 +327,6 @@ def _is_cuda_oom(e: RuntimeError): return True return False - -def _fsdp_reshard_and_cleanup(model: torch.nn.Module): - """Manually reshard and clean up FSDP model. - - When an exception like OOM happens, _post_backward_final_callback, which - is registered as a backward callback, will not run. We manually call it to cleanup - loose memory. - """ - for __, module in model.named_modules(): - if isinstance(module, FullyShardedDataParallel): - if module.check_is_root(): - # Only call _post_backward_final_callback on root module. It will - # traverse and reshard all FSDP sub-modules - _post_backward_final_callback(module, module) - - -def _adjust_device_train_microbatch_size(state: State): - """Adjust device_train_microbatch_size if we encounter OOM. - - Args: - state (State): State of trainer. - """ - # If any rank hit CUDA OOM, update device_train_microbatch_size and retry. Raise runtime error - # if training 1 sample at a time still resulted in CUDA out of memory. - assert state.device_train_microbatch_size is not None - if state.device_train_microbatch_size == 1: - raise RuntimeError(( - 'CUDA out of memory. The train loop failed with an internal microbatch of size 1.' - 'The GPU does not have enough memory to process even 1 sample during train.' - )) - else: - original_microbatch_size = state.device_train_microbatch_size - state.device_train_microbatch_size = max(int(original_microbatch_size / 2), 1) - warnings.warn( - RuntimeWarning( - 'CUDA out of memory detected. Train microbatch size will be decreased from ' - f'{original_microbatch_size} -> {state.device_train_microbatch_size}.', - ), - ) - # Clear gradients in case failure happened during backwards pass - if hasattr(state, 'outputs'): - del state.outputs - if hasattr(state, 'loss'): - del state.loss - for optimizer in state.optimizers: - optimizer.zero_grad(set_to_none=True) - if state.scaler is not None: - state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - _fsdp_reshard_and_cleanup(state.model) - torch.cuda.empty_cache() - - def _adjust_device_eval_microbatch_size(evaluator: Evaluator): """Adjust device_eval_microbatch_size if we encounter OOM. 
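
The sync hooks registered in dist_strategy.py, the patched unshard_with_sync above, and _found_ooms_across_ranks all rely on the same cross-rank signalling pattern: a MAX all-reduce over a "found OOM" flag paired with a MIN all-reduce over an "all ranks finished" flag, so healthy ranks learn about a peer's OOM at the next module boundary instead of hanging inside an FSDP collective. The sketch below is a minimal, standalone illustration of that pattern using raw torch.distributed (the patch itself goes through Composer's dist wrapper, whose all_reduce takes reduce_operation='MAX'/'MIN'); make_sync_hook and attach_sync_hooks are illustrative names, not Composer APIs.

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel

    def make_sync_hook(device: torch.device):
        def sync_hook(*args):
            # MAX-reduce a "did any rank OOM?" flag; a healthy rank contributes 0,
            # while a rank that caught an OOM contributes 1 from its own code path.
            found_oom = torch.tensor([0], dtype=torch.uint8, device=device)
            dist.all_reduce(found_oom, op=dist.ReduceOp.MAX)
            # MIN-reduce an "all ranks finished the batch?" flag; contributing 0
            # tells the other ranks this rank is still inside forward/backward.
            still_in_batch = torch.tensor([0], dtype=torch.uint8, device=device)
            dist.all_reduce(still_in_batch, op=dist.ReduceOp.MIN)
            if found_oom.item() == 1:
                # Bail out so every rank retries the batch at a smaller microbatch
                # size instead of deadlocking inside a later collective.
                raise RuntimeError('CUDA out of memory encountered on a different rank')
        return sync_hook

    def attach_sync_hooks(model: torch.nn.Module, device: torch.device) -> list:
        """Register the hook on every FSDP-wrapped submodule and return handles."""
        hook = make_sync_hook(device)
        handles = []
        for _, module in model.named_modules():
            if isinstance(module, FullyShardedDataParallel):
                handles.append(module.register_forward_pre_hook(hook, prepend=True))
                handles.append(module.register_full_backward_pre_hook(hook, prepend=True))
        return handles  # call handle.remove() on each once a stable size is found

Removing the handles (and restoring the unpatched unshard) once a stable microbatch size is found is what recovers the throughput cost of these extra collectives.
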
@@ -1167,7 +1120,12 @@ def __init__( raise ValueError( '`Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.', ) - + + # Automicrobatching + self.auto_microbatch_size_found = False + self.num_alloc_retries = 0 + self.num_consecutive_thrashes = 0 + self.auto_microbatch_hooks = [] if auto_microbatching and profiler: raise ValueError( "`device_train_microbatch_size='auto'` is not compatible with the profiler. It is " @@ -1251,6 +1209,7 @@ def __init__( if parallelism_config is not None: # Patch PyTorch to fix distributed bugs patch_pytorch() + patch_unshard_for_automicrobatching(self.auto_microbatch_size_found) # Reproducibility rank_zero_seed, seed = _distribute_and_get_random_seed(seed, device) @@ -1668,7 +1627,7 @@ def __init__( if self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and not self.state.load_monolith_rank0_only: # Init with globally fixed seed so all HSDP replicas have the same initial weights with reproducibility.seed_context(self.state.rank_zero_seed): - prepare_fsdp_module( + self.automicrobatch_hook_handles, self.fsdp_modules = prepare_fsdp_module( model, optimizers, self.state.fsdp_config, @@ -1838,7 +1797,7 @@ def __init__( ): # Init with globally fixed seed so all HSDP replicas have the same initial weights with reproducibility.seed_context(self.state.rank_zero_seed): - prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching) + self.automicrobatch_hook_handles, self.fsdp_modules = prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching) self.engine.run_event(Event.AFTER_LOAD) @@ -2733,8 +2692,22 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: # Any in-place changes to a microbatch will be reflected in the device batch. 
device_batch = self.state.batch + # Automicrobatching + lowest_oom_microbatch_size = None + highest_non_oom_microbatch_size = self.state.device_train_microbatch_size + max_search_steps = 5 + num_search_steps = 0 + original_microbatch_size = self.state.device_train_microbatch_size + lower_bound_microbatch_size = self.state.device_train_microbatch_size + searching_for_non_thrashing_microbatch_size = False + first_success = False + sync_hook = _create_sync_hook(self.state) + # Retry until we successfully complete training and return loss while True: + if not self.auto_microbatch_size_found: + log.info("Searching for optimal microbatch size with automicrobatching.") + log.info("Testing microbatch size = " + str(self.state.device_train_microbatch_size)) # Reset train_metrics on every batch # Placing reset here ensures that if auto grad accum catches an OOM, incomplete metric state is cleared if self.state.train_metrics is not None: # pyright: ignore[reportUnnecessaryComparison] @@ -2771,7 +2744,10 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: else: optimizer.step() except RuntimeError as e: - if self.state.auto_microbatching and _is_cuda_oom(e): + if self.state.auto_microbatching and str(e) == 'CUDA out of memory encountered on a different rank': + log.debug((f"A Different Rank OOM'd.")) + found_cuda_oom = 1 + elif self.state.auto_microbatching and _is_cuda_oom(e): log.debug((f"Rank {dist.get_global_rank()} OOM'd.")) found_cuda_oom = 1 elif self.state.auto_microbatching and ('cuda' in str(e).lower() or 'c10' in str(e).lower()): @@ -2785,27 +2761,91 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: raise if self.state.auto_microbatching: - all_ranks_finished = False - while not all_ranks_finished: - # Propagate across all ranks if any rank hit CUDA OOM - found_cuda_oom_tensor = self.state.device.tensor_to_device( - torch.tensor([found_cuda_oom], dtype=torch.uint8), - ) - dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') - found_cuda_oom = found_cuda_oom_tensor.item() - # Check if any rank is still not done with the batch. 
This may happen if only a - # subset of ranks OOM, leaving some batches still in the forward pass - all_ranks_finished_tensor = self.state.device.tensor_to_device(torch.tensor([1], dtype=torch.uint8)) - dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') - all_ranks_finished = all_ranks_finished_tensor.item() == 1 - if found_cuda_oom == 1: - _adjust_device_train_microbatch_size(self.state) - # Skip return and rerun after handling oom + # Sync for OOMs + found_cuda_oom = _found_ooms_across_ranks(self.state, found_cuda_oom) + + # Sync for alloc retries + if torch.cuda.is_available() and self.auto_microbatch_size_found and not searching_for_non_thrashing_microbatch_size: + self.num_consecutive_thrashes = _update_num_consecutive_thrashes(self.state, self.num_consecutive_thrashes, self.num_alloc_retries) + + if found_cuda_oom == 1: + # Manually clean up state and reshard if an OOM prevents a batch from finishing + _clear_incomplete_train_states(self.state) + self.auto_microbatch_size_found = False + self.num_consecutive_thrashes = 0 + + # Readd sync hooks if they were previously turned off + if len(self.auto_microbatch_hooks) == 0: + patch_unshard_for_automicrobatching(False) + for _, module in self.fsdp_modules.items(): + if isinstance(module, FullyShardedDataParallel): + self.automicrobatch_hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) + self.automicrobatch_hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) + else: + self.automicrobatch_hook_handles.append(module.register_full_backward_hook(sync_hook)) + + if self.state.device_train_microbatch_size == 1: + raise RuntimeError(( + 'CUDA out of memory. The train loop failed with an internal microbatch of size 1.' + 'The GPU does not have enough memory to process even 1 sample during train.' 
+ )) + + lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size, num_search_steps = _handle_downward_search_in_automicrobatching(self.state, lowest_oom_microbatch_size, + highest_non_oom_microbatch_size, lower_bound_microbatch_size, + num_search_steps, max_search_steps) continue + else: + if not self.first_batch_complete and not first_success: + # First successful microbatch size found + first_success = True + _clear_incomplete_train_states(self.state) + continue # Rerun with the same size since this is our first successful batch completion + + if self.num_consecutive_thrashes >= 2: + searching_for_non_thrashing_microbatch_size = True + self.num_consecutive_thrashes = 0 + _clear_incomplete_train_states(self.state) + lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size = _handle_thrashing_in_automicrobatching(self.state) + + # Readd sync hooks if they were previously turned off + if len(self.auto_microbatch_hooks) == 0: + patch_unshard_for_automicrobatching(False) + for _, module in self.fsdp_modules.items(): + if isinstance(module, FullyShardedDataParallel): + self.automicrobatch_hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) + self.automicrobatch_hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) + else: + self.automicrobatch_hook_handles.append(module.register_full_backward_hook(sync_hook)) + continue + + if not self.auto_microbatch_size_found: # microbatch size found in previous search + search_upwards, highest_non_oom_microbatch_size, num_search_steps = _handle_upward_search_in_automicrobatching(self.state, lowest_oom_microbatch_size, + highest_non_oom_microbatch_size, num_search_steps, + max_search_steps) + if search_upwards: + continue + # Log microbatch and return loss if we've completed without OOMing. 
assert self.state.device_train_microbatch_size is not None + if original_microbatch_size != self.state.device_train_microbatch_size: + warnings.warn( + RuntimeWarning( + 'Automicrobatching changed the microbatch size from ' + f'{original_microbatch_size} -> {self.state.device_train_microbatch_size}.', + ), + ) + if len(self.auto_microbatch_hooks) > 0: + patch_unshard_for_automicrobatching(True) + for handle in self.auto_microbatch_hooks: + handle.remove() + self.auto_microbatch_hooks.clear() + self.auto_microbatch_size_found = True + if torch.cuda.is_available(): + memory_stats = torch.cuda.memory_stats() + self.num_alloc_retries = memory_stats["num_alloc_retries"] self.logger.log_metrics({'trainer/device_train_microbatch_size': self.state.device_train_microbatch_size}) self.first_batch_complete = True + self.engine.run_event(Event.AFTER_TRAIN_BATCH) return total_loss_dict def _train_microbatches( @@ -2901,8 +2941,6 @@ def _train_microbatches( for optimizer in ensure_tuple(self.state.optimizers): self.state.scaler.unscale_(optimizer) - self.engine.run_event(Event.AFTER_TRAIN_BATCH) - return total_loss_dict['loss/train/total'] def _train_microbatch( @@ -3389,7 +3427,17 @@ def _eval_loop( last_wct = datetime.datetime.now() + sync_hook = _create_sync_hook(self.state) + with torch.no_grad(), model_eval_mode(self.state.model): + if self.first_batch_complete: + patch_unshard_for_automicrobatching(False) + for _ , module in self.fsdp_modules.items(): + if isinstance(module, FullyShardedDataParallel): + self.automicrobatch_hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) + self.automicrobatch_hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) + else: + self.automicrobatch_hook_handles.append(module.register_full_backward_hook(sync_hook)) self.state.set_dataloader(data_spec.dataloader, evaluator.label, subset_num_batches) assert self.state.dataloader is not None, 'dataloader is set' diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index 20fa44e092..5dd76bf67a 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -8,6 +8,18 @@ convert_nested_dict_to_flat_dict, extract_hparams, ) +from composer.utils.automicrobatching import ( + _create_sync_hook, + _fsdp_reshard_and_cleanup, + _double_device_train_microbatch_size, + _closest_lower_power_of_2, + _clear_incomplete_train_states, + _found_ooms_across_ranks, + _update_num_consecutive_thrashes, + _handle_downward_search_in_automicrobatching, + _handle_upward_search_in_automicrobatching, + _handle_thrashing_in_automicrobatching, +) from composer.utils.batch_helpers import batch_get, batch_set from composer.utils.checkpoint import ( PartialFilePath, diff --git a/composer/utils/automicrobatching.py b/composer/utils/automicrobatching.py new file mode 100644 index 0000000000..e04eb20435 --- /dev/null +++ b/composer/utils/automicrobatching.py @@ -0,0 +1,266 @@ +import torch +import logging +from composer.core import State +from composer.utils import dist +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback +from collections import defaultdict +from packaging import version + +if version.parse(torch.__version__) >= version.parse('2.3.0'): + from torch.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore +else: + from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore + +log = logging.getLogger(__name__) + +__all__ = [ + 
'_create_sync_hook', + '_fsdp_reshard_and_cleanup', + '_double_device_train_microbatch_size', + '_closest_lower_power_of_2', + '_clear_incomplete_train_states', + '_found_ooms_across_ranks', + '_update_num_consecutive_thrashes', + '_handle_downward_search_in_automicrobatching', + '_handle_upward_search_in_automicrobatching', + '_handle_thrashing_in_automicrobatching' +] + +def _create_sync_hook(state): + def sync_hook(*args): + # Check if any other rank hit an OOM + found_cuda_oom_tensor = state.device.tensor_to_device(torch.tensor([0], dtype=torch.uint8)) + dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') + found_cuda_oom = found_cuda_oom_tensor.item() + # Signal current rank is still in batch + all_ranks_finished_tensor = state.device.tensor_to_device(torch.tensor([0], dtype=torch.uint8)) + dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') + + if found_cuda_oom == 1: + raise RuntimeError('CUDA out of memory encountered on a different rank') + + return sync_hook + +def _fsdp_reshard_and_cleanup(model: torch.nn.Module): + """Manually reshard and clean up FSDP model. + + When an exception like OOM happens, _post_backward_final_callback, which + is registered as a backward callback, will not run. We manually call it to cleanup + loose memory. + """ + for __, module in model.named_modules(): + if isinstance(module, FullyShardedDataParallel): + if module.check_is_root(): + # Only call _post_backward_final_callback on root module. It will + # traverse and reshard all FSDP sub-modules + _post_backward_final_callback(module, module) + + +def _double_device_train_microbatch_size(state: State): + """Double device_train_microbatch_size when automicrobatching searches upward for a higher non-OOM microbatch size. + + Args: + state (State): State of trainer. + """ + # If any rank hit CUDA OOM, update device_train_microbatch_size and retry. Raise runtime error + # if training 1 sample at a time still resulted in CUDA out of memory. + assert state.device_train_microbatch_size is not None + assert state.train_dataloader is not None + + try: + batch_size = getattr(state.train_dataloader, 'batch_size') + except AttributeError as e: + # Error message when `device_train_microbatch_size` is 'auto' + raise AttributeError( + "`device_train_microbatch_size='auto'` requires the `state.train_dataloader` to have a `batch_size` attribute.", + ) from e + + original_microbatch_size = state.device_train_microbatch_size + # Device train microbatch size can't be greater than the device train batch size + state.device_train_microbatch_size = min(int(original_microbatch_size * 2), batch_size) + +def _closest_lower_power_of_2(microbatch_size: int): + """Find the highest lower power of 2 to serve as a lower bound device_train_microbatch_size when automicrobatching + searches downward, due to either thrashing or when a previously non-OOMing microbatch size is now OOMing. + Args: + microbatch_size (int): Current device train microbatch size. + """ + if microbatch_size <= 1: + return 1 + return 1 << ((microbatch_size - 1).bit_length() - 1) + +def _clear_incomplete_train_states(state: State): + """Manually clear gradients when automicrobatching reruns a batch. + Before automicrobatching tries a new higher or lower microbatch size, clear the + training states and memory of the previous run of the batch to reset the memory to + before the batch was run. 
+ """ + if hasattr(state, 'outputs'): + del state.outputs + if hasattr(state, 'loss'): + del state.loss + for optimizer in state.optimizers: + optimizer.zero_grad(set_to_none=True) + if state.scaler is not None: + state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + _fsdp_reshard_and_cleanup(state.model) + torch.cuda.empty_cache() + +def _found_ooms_across_ranks(state: State, found_cuda_oom: bool): + """Check if at least one rank, including the local rank, OOM'd in the forward/backward pass + when using automicrobatching. This may happen when close to memory limit or with uneven memory + usage across ranks. + + Ensure that all ranks are out of microbatch training before completing batch training or finding + a new microbatch size. Return whether at least one rank OOM'd. + """ + + all_ranks_finished = False + while not all_ranks_finished: + # Propagate across all ranks if any rank hit CUDA OOM + found_cuda_oom_tensor = state.device.tensor_to_device( + torch.tensor([found_cuda_oom], dtype=torch.uint8), + ) + dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') + found_cuda_oom = found_cuda_oom_tensor.item() + # Check if any rank is still not done with the batch. This may happen if only a + # subset of ranks OOM, leaving some batches still in the forward pass + all_ranks_finished_tensor = state.device.tensor_to_device(torch.tensor([1], dtype=torch.uint8)) + dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') + all_ranks_finished = all_ranks_finished_tensor.item() == 1 + return found_cuda_oom + +def _update_num_consecutive_thrashes(state: State, num_consecutive_thrashes: int, num_alloc_retries: int): + """Update the number of consecutive batches where we experienced alloc retries. + Consecutive alloc retries in GPU memory usually indicate thrashing, where GPU memory usage is so close + to the memory limit that it hinders throughput. + """ + # Check for alloc retries between batches + stats = torch.cuda.memory_stats() + cur_num_alloc_retries = stats["num_alloc_retries"] + + if cur_num_alloc_retries - num_alloc_retries > 0: + alloc_retry_this_batch = 1 + log.info("Found alloc retries this batch: " + str(num_alloc_retries) + " to " + str(cur_num_alloc_retries)) + else: + alloc_retry_this_batch = 0 + + # Propagate across all ranks if any rank had alloc retries this batch + alloc_retry_tensor = state.device.tensor_to_device( + torch.tensor([alloc_retry_this_batch], dtype=torch.uint8), + ) + dist.all_reduce(alloc_retry_tensor, reduce_operation='MAX') + alloc_retry_this_batch = alloc_retry_tensor.item() == 1 + if alloc_retry_this_batch: + num_consecutive_thrashes += 1 + else: + num_consecutive_thrashes = 0 + return num_consecutive_thrashes + +def _handle_downward_search_in_automicrobatching(state: State, lowest_oom_microbatch_size: int, highest_non_oom_microbatch_size: int, lower_bound_microbatch_size: int, num_search_steps: int, max_search_steps: int): + """Search downward for the highest non-OOMing microbatch size. + + This method is only called when an OOM was seen this batch with the current state.device_train_microbatch_size. + If this is the first time automicrobatching is searching for a non-OOMing microbatch size, or the previously highest non-OOMing power of 2 + microbatch size is now OOMing, automicrobatching searches for the next highest power of 2 to test as a microbatch size. This resets num_search_steps + to 1. 
+ Otherwise, while automicrobatching has searched for less than max_search_steps, automicrobatching binary searches downwards between the highest recorded + non-OOMing microbatch size and the lowest recorded OOMing microbatch size. + Once automicrobatching has searched for max_search_steps, if the last tested microbatch size OOM'd, choose the highest previously + recorded non-OOMing microbatch size. For the edge case where that microbatch size OOMs upon retry, binary search downward between + that value and lower_bound_microbatch_size, which is the highest power of 2 guaranteed to not OOM. + """ + # Find closest lower power of 2 if previously non-OOM microbatch size is OOMing or this is the first microbatch size search + if state.device_train_microbatch_size == lower_bound_microbatch_size: + lowest_oom_microbatch_size = state.device_train_microbatch_size + lower_bound_microbatch_size = _closest_lower_power_of_2(state.device_train_microbatch_size) + state.device_train_microbatch_size = lower_bound_microbatch_size + highest_non_oom_microbatch_size = state.device_train_microbatch_size + + num_search_steps = 1 + # Skip return and continue searching for the highest non-OOM size in the new lower range + else: + if num_search_steps < max_search_steps: + lowest_oom_microbatch_size = state.device_train_microbatch_size + median_microbatch_size = int((lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2) + state.device_train_microbatch_size = median_microbatch_size + + num_search_steps += 1 + + # Optimization so we don't repeat a converged value + if lowest_oom_microbatch_size == highest_non_oom_microbatch_size: + num_search_steps = max_search_steps + 1 # go to else protocol + lowest_oom_microbatch_size = state.device_train_microbatch_size + highest_non_oom_microbatch_size = lower_bound_microbatch_size + state.device_train_microbatch_size = int((lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2) + + # Skip return and decrease dtms, continuing the search for the highest non-OOM size + elif num_search_steps == max_search_steps: + state.device_train_microbatch_size = highest_non_oom_microbatch_size + + num_search_steps += 1 + # Skip return and rerun to obtain loss - committing to this dtms unless retrying it OOMs + else: + # Only end up here if a previously non-OOM microbatch size is no longer successful in the same training step, and it's not the original microbatch size + + lowest_oom_microbatch_size = state.device_train_microbatch_size + highest_non_oom_microbatch_size = lower_bound_microbatch_size + state.device_train_microbatch_size = int((lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2) + + # Skip return and continue searching for the highest non-OOM size in this narrower range + return lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size, num_search_steps + +def _handle_upward_search_in_automicrobatching(state: State, lowest_oom_microbatch_size: int, highest_non_oom_microbatch_size: int, num_search_steps: int, max_search_steps: int): + """Searches upward for the highest non-OOMing microbatch size. + + This method is only called when the current state.device_train_microbatch_size did not OOM and automicrobatching is actively searching for a new + microbatch size, either because this is the first search or a previously working microbatch size OOM'd. + If the microbatch size is already equal to the batch size, automicrobatching commits to this microbatch size. 
+ Otherwise, while automicrobatching has searched for less than max_search_steps, automicrobatching binary searches upwards between the highest recorded + non-OOMing microbatch size and the lowest recorded OOMing microbatch size. + """ + assert state.train_dataloader is not None + try: + batch_size = getattr(state.train_dataloader, 'batch_size') + except AttributeError as e: + # Error message when `device_train_microbatch_size` is 'auto' + raise AttributeError( + "`device_train_microbatch_size='auto'` requires the `state.train_dataloader` to have a `batch_size` attribute.", + ) from e + + search_upwards = False + + if state.device_train_microbatch_size != batch_size: + if num_search_steps == 0: + highest_non_oom_microbatch_size = state.device_train_microbatch_size + _double_device_train_microbatch_size(state) + _clear_incomplete_train_states(state) + search_upwards = True + elif num_search_steps < max_search_steps: # Previous OOMs found in this training step + highest_non_oom_microbatch_size = state.device_train_microbatch_size + median_microbatch_size = int((highest_non_oom_microbatch_size + lowest_oom_microbatch_size) // 2) + state.device_train_microbatch_size = median_microbatch_size + + num_search_steps += 1 + + # Optimization so we don't repeat a converged value + if median_microbatch_size == highest_non_oom_microbatch_size: + num_search_steps = max_search_steps + + _clear_incomplete_train_states(state) + search_upwards = True + # Else: reached max search steps and found a non-OOM microbatch size + return search_upwards, highest_non_oom_microbatch_size, num_search_steps + +def _handle_thrashing_in_automicrobatching(state: State): + """Searches downward for the highest non-OOMing microbatch size that also doesn't thrash. + This method is only called when two consecutive batches have alloc retries, indicating thrashing, + where GPU memory usage is so close to the memory limit that it hinders throughput. + Automicrobatching searches for the next highest power of 2 to use as the microbatch size. 
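+    For example (illustrative sizes), if two consecutive batches at a microbatch size of 24 hit allocator retries, the microbatch size drops to 16, 16 becomes the non-OOMing upper bound for any later search, and 24 is recorded as the lowest problematic size.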
+ """ + lowest_oom_microbatch_size = state.device_train_microbatch_size + lower_bound_microbatch_size = _closest_lower_power_of_2(state.device_train_microbatch_size) + highest_non_oom_microbatch_size = lower_bound_microbatch_size + state.device_train_microbatch_size = lower_bound_microbatch_size + return lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index b077d22131..e42ef17a73 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -7,7 +7,7 @@ import torch from packaging import version from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from composer.models import ComposerClassifier, ComposerModel from composer.trainer.trainer import Trainer, _fsdp_reshard_and_cleanup @@ -205,6 +205,72 @@ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limi trainer.fit() +class SimpleDatasetForAuto(Dataset): + + def __init__(self, size: int = 256, feature_size: int = 1, num_classes: int = 2): + self.size = size + self.feature_size = feature_size + self.num_classes = num_classes + self.x = None + self.y = None + + def __len__(self): + return self.size + + def __getitem__(self, index: int): + # Note: lazily generate data so it runs after Composer seeds everything, giving the same + # dataset across multiple calls when using the same seed. + if self.x is None: + self.x = torch.randn(self.size, self.feature_size) + if self.y is None: + self.y = torch.randint(0, self.num_classes, size=(self.size,), dtype=torch.long) + return self.x[index] + +class SimpleMLPForTestingOOM(ComposerModel): + + def __init__(self, num_features: int = 128, device: str = 'cuda'): + super().__init__() + self.device = device + self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.fc2 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.fc3 = torch.nn.Linear(num_features, num_features, device=device, bias=False) + self.rank = dist.get_global_rank() + self.iter = 0 + + def forward(self, x): + x = self.fc1(x) + if self.rank == 0 and x.shape[0] >= 64: + raise RuntimeError('CUDA out of memory') + x = self.fc2(x) + x = self.fc3(x) + self.iter += 1 + return x + + def loss(self, outputs, batch): + return torch.sum(outputs) + +@pytest.mark.gpu +@world_size(2) +def test_automicrobatching_fsdp(world_size: int): + model = SimpleMLPForTestingOOM() + model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] + dataset = SimpleDatasetForAuto(size=256, feature_size=128) + train_dataloader = DataLoader(dataset, batch_size=64) + trainer = Trainer( + model=model, + train_dataloader=train_dataloader, + fsdp_config={ + 'forward_prefetch_limit': 1, + 'backward_prefetch_limit': 1, + }, + max_duration='1ba', + device='gpu', + device_train_microbatch_size='auto', + dist_timeout=20, + ) + trainer.fit() + @pytest.mark.gpu @world_size(2) From 42288897cdbea79e428101285e3df13f70b83b73 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Mon, 29 Jul 2024 23:14:17 -0700 Subject: [PATCH 02/23] include auto helpers in _all_ --- composer/utils/__init__.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index 5dd76bf67a..807696a246 100644 --- a/composer/utils/__init__.py +++ 
b/composer/utils/__init__.py @@ -176,4 +176,14 @@ 'validate_credentials', 'build_remote_backend', 'RemoteFilesExistingCheckStatus', + '_create_sync_hook', + '_fsdp_reshard_and_cleanup', + '_double_device_train_microbatch_size', + '_closest_lower_power_of_2', + '_clear_incomplete_train_states', + '_found_ooms_across_ranks', + '_update_num_consecutive_thrashes', + '_handle_downward_search_in_automicrobatching', + '_handle_upward_search_in_automicrobatching', + '_handle_thrashing_in_automicrobatching' ] From a537c4c7c4c816b3fc6873b7db1ad6d33a060a6d Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Mon, 29 Jul 2024 23:20:24 -0700 Subject: [PATCH 03/23] fix circular imports --- composer/core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/core/__init__.py b/composer/core/__init__.py index 1198ffd32b..6c3709f0a9 100644 --- a/composer/core/__init__.py +++ b/composer/core/__init__.py @@ -10,6 +10,7 @@ from composer.core.algorithm import Algorithm from composer.core.callback import Callback +from composer.core.state import State from composer.core.data_spec import DataSpec, ensure_data_spec from composer.core.engine import Engine, Trace from composer.core.evaluator import Evaluator, ensure_evaluator @@ -17,7 +18,6 @@ from composer.core.passes import AlgorithmPass from composer.core.precision import Precision, get_precision_context from composer.core.serializable import Serializable -from composer.core.state import State from composer.core.time import Time, Timestamp, TimeUnit, ensure_time from composer.core.types import JSON, Batch, Dataset, MemoryFormat, TrainerMode From ff806d1615a9e260823c922428fbe88b4caead56 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Mon, 29 Jul 2024 23:32:11 -0700 Subject: [PATCH 04/23] remove circular import --- composer/core/__init__.py | 2 +- composer/trainer/trainer.py | 223 +++++++++++++++++++++++++++- composer/utils/__init__.py | 34 ++--- composer/utils/automicrobatching.py | 221 --------------------------- 4 files changed, 234 insertions(+), 246 deletions(-) diff --git a/composer/core/__init__.py b/composer/core/__init__.py index 6c3709f0a9..1198ffd32b 100644 --- a/composer/core/__init__.py +++ b/composer/core/__init__.py @@ -10,7 +10,6 @@ from composer.core.algorithm import Algorithm from composer.core.callback import Callback -from composer.core.state import State from composer.core.data_spec import DataSpec, ensure_data_spec from composer.core.engine import Engine, Trace from composer.core.evaluator import Evaluator, ensure_evaluator @@ -18,6 +17,7 @@ from composer.core.passes import AlgorithmPass from composer.core.precision import Precision, get_precision_context from composer.core.serializable import Serializable +from composer.core.state import State from composer.core.time import Time, Timestamp, TimeUnit, ensure_time from composer.core.types import JSON, Batch, Dataset, MemoryFormat, TrainerMode diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 3cfa30697a..66fff43aa6 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -130,13 +130,8 @@ parse_uri, partial_format, reproducibility, - _create_sync_hook, - _clear_incomplete_train_states, - _found_ooms_across_ranks, - _update_num_consecutive_thrashes, - _handle_downward_search_in_automicrobatching, - _handle_upward_search_in_automicrobatching, - _handle_thrashing_in_automicrobatching, + _fsdp_reshard_and_cleanup, + _closest_lower_power_of_2 ) if is_xla_installed(): @@ -327,6 +322,220 @@ def _is_cuda_oom(e: 
RuntimeError): return True return False +def _double_device_train_microbatch_size(state: State): + """Double device_train_microbatch_size when automicrobatching searches upward for a higher non-OOM microbatch size. + + Args: + state (State): State of trainer. + """ + # If any rank hit CUDA OOM, update device_train_microbatch_size and retry. Raise runtime error + # if training 1 sample at a time still resulted in CUDA out of memory. + assert state.device_train_microbatch_size is not None + assert state.train_dataloader is not None + + try: + batch_size = getattr(state.train_dataloader, 'batch_size') + except AttributeError as e: + # Error message when `device_train_microbatch_size` is 'auto' + raise AttributeError( + "`device_train_microbatch_size='auto'` requires the `state.train_dataloader` to have a `batch_size` attribute.", + ) from e + + original_microbatch_size = state.device_train_microbatch_size + # Device train microbatch size can't be greater than the device train batch size + state.device_train_microbatch_size = min(int(original_microbatch_size * 2), batch_size) + +def _found_ooms_across_ranks(state: State, found_cuda_oom: bool): + """Check if at least one rank, including the local rank, OOM'd in the forward/backward pass + when using automicrobatching. This may happen when close to memory limit or with uneven memory + usage across ranks. + + Ensure that all ranks are out of microbatch training before completing batch training or finding + a new microbatch size. Return whether at least one rank OOM'd. + """ + + all_ranks_finished = False + while not all_ranks_finished: + # Propagate across all ranks if any rank hit CUDA OOM + found_cuda_oom_tensor = state.device.tensor_to_device( + torch.tensor([found_cuda_oom], dtype=torch.uint8), + ) + dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') + found_cuda_oom = found_cuda_oom_tensor.item() + # Check if any rank is still not done with the batch. This may happen if only a + # subset of ranks OOM, leaving some batches still in the forward pass + all_ranks_finished_tensor = state.device.tensor_to_device(torch.tensor([1], dtype=torch.uint8)) + dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') + all_ranks_finished = all_ranks_finished_tensor.item() == 1 + return found_cuda_oom + +def _update_num_consecutive_thrashes(state: State, num_consecutive_thrashes: int, num_alloc_retries: int): + """Update the number of consecutive batches where we experienced alloc retries. + Consecutive alloc retries in GPU memory usually indicate thrashing, where GPU memory usage is so close + to the memory limit that it hinders throughput. 
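+    For example, if torch.cuda.memory_stats()['num_alloc_retries'] was 3 after the previous batch and is 5 now, this rank saw allocator retries during the current batch; that signal is MAX-reduced across ranks, and the consecutive-thrash counter increments if any rank retried and resets to zero otherwise.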
+ """ + # Check for alloc retries between batches + stats = torch.cuda.memory_stats() + cur_num_alloc_retries = stats["num_alloc_retries"] + + if cur_num_alloc_retries - num_alloc_retries > 0: + alloc_retry_this_batch = 1 + log.info("Found alloc retries this batch: " + str(num_alloc_retries) + " to " + str(cur_num_alloc_retries)) + else: + alloc_retry_this_batch = 0 + + # Propagate across all ranks if any rank had alloc retries this batch + alloc_retry_tensor = state.device.tensor_to_device( + torch.tensor([alloc_retry_this_batch], dtype=torch.uint8), + ) + dist.all_reduce(alloc_retry_tensor, reduce_operation='MAX') + alloc_retry_this_batch = alloc_retry_tensor.item() == 1 + if alloc_retry_this_batch: + num_consecutive_thrashes += 1 + else: + num_consecutive_thrashes = 0 + return num_consecutive_thrashes + +def _handle_downward_search_in_automicrobatching(state: State, lowest_oom_microbatch_size: int, highest_non_oom_microbatch_size: int, lower_bound_microbatch_size: int, num_search_steps: int, max_search_steps: int): + """Search downward for the highest non-OOMing microbatch size. + + This method is only called when an OOM was seen this batch with the current state.device_train_microbatch_size. + If this is the first time automicrobatching is searching for a non-OOMing microbatch size, or the previously highest non-OOMing power of 2 + microbatch size is now OOMing, automicrobatching searches for the next highest power of 2 to test as a microbatch size. This resets num_search_steps + to 1. + Otherwise, while automicrobatching has searched for less than max_search_steps, automicrobatching binary searches downwards between the highest recorded + non-OOMing microbatch size and the lowest recorded OOMing microbatch size. + Once automicrobatching has searched for max_search_steps, if the last tested microbatch size OOM'd, choose the highest previously + recorded non-OOMing microbatch size. For the edge case where that microbatch size OOMs upon retry, binary search downward between + that value and lower_bound_microbatch_size, which is the highest power of 2 guaranteed to not OOM. 
+ """ + # Find closest lower power of 2 if previously non-OOM microbatch size is OOMing or this is the first microbatch size search + if state.device_train_microbatch_size == lower_bound_microbatch_size: + lowest_oom_microbatch_size = state.device_train_microbatch_size + lower_bound_microbatch_size = _closest_lower_power_of_2(state.device_train_microbatch_size) + state.device_train_microbatch_size = lower_bound_microbatch_size + highest_non_oom_microbatch_size = state.device_train_microbatch_size + + num_search_steps = 1 + # Skip return and continue searching for the highest non-OOM size in the new lower range + else: + if num_search_steps < max_search_steps: + lowest_oom_microbatch_size = state.device_train_microbatch_size + median_microbatch_size = int((lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2) + state.device_train_microbatch_size = median_microbatch_size + + num_search_steps += 1 + + # Optimization so we don't repeat a converged value + if lowest_oom_microbatch_size == highest_non_oom_microbatch_size: + num_search_steps = max_search_steps + 1 # go to else protocol + lowest_oom_microbatch_size = state.device_train_microbatch_size + highest_non_oom_microbatch_size = lower_bound_microbatch_size + state.device_train_microbatch_size = int((lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2) + + # Skip return and decrease dtms, continuing the search for the highest non-OOM size + elif num_search_steps == max_search_steps: + state.device_train_microbatch_size = highest_non_oom_microbatch_size + + num_search_steps += 1 + # Skip return and rerun to obtain loss - committing to this dtms unless retrying it OOMs + else: + # Only end up here if a previously non-OOM microbatch size is no longer successful in the same training step, and it's not the original microbatch size + + lowest_oom_microbatch_size = state.device_train_microbatch_size + highest_non_oom_microbatch_size = lower_bound_microbatch_size + state.device_train_microbatch_size = int((lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2) + + # Skip return and continue searching for the highest non-OOM size in this narrower range + return lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size, num_search_steps + +def _handle_upward_search_in_automicrobatching(state: State, lowest_oom_microbatch_size: int, highest_non_oom_microbatch_size: int, num_search_steps: int, max_search_steps: int): + """Searches upward for the highest non-OOMing microbatch size. + + This method is only called when the current state.device_train_microbatch_size did not OOM and automicrobatching is actively searching for a new + microbatch size, either because this is the first search or a previously working microbatch size OOM'd. + If the microbatch size is already equal to the batch size, automicrobatching commits to this microbatch size. + Otherwise, while automicrobatching has searched for less than max_search_steps, automicrobatching binary searches upwards between the highest recorded + non-OOMing microbatch size and the lowest recorded OOMing microbatch size. 
+ """ + assert state.train_dataloader is not None + try: + batch_size = getattr(state.train_dataloader, 'batch_size') + except AttributeError as e: + # Error message when `device_train_microbatch_size` is 'auto' + raise AttributeError( + "`device_train_microbatch_size='auto'` requires the `state.train_dataloader` to have a `batch_size` attribute.", + ) from e + + search_upwards = False + + if state.device_train_microbatch_size != batch_size: + if num_search_steps == 0: + highest_non_oom_microbatch_size = state.device_train_microbatch_size + _double_device_train_microbatch_size(state) + _clear_incomplete_train_states(state) + search_upwards = True + elif num_search_steps < max_search_steps: # Previous OOMs found in this training step + highest_non_oom_microbatch_size = state.device_train_microbatch_size + median_microbatch_size = int((highest_non_oom_microbatch_size + lowest_oom_microbatch_size) // 2) + state.device_train_microbatch_size = median_microbatch_size + + num_search_steps += 1 + + # Optimization so we don't repeat a converged value + if median_microbatch_size == highest_non_oom_microbatch_size: + num_search_steps = max_search_steps + + _clear_incomplete_train_states(state) + search_upwards = True + # Else: reached max search steps and found a non-OOM microbatch size + return search_upwards, highest_non_oom_microbatch_size, num_search_steps + +def _handle_thrashing_in_automicrobatching(state: State): + """Searches downward for the highest non-OOMing microbatch size that also doesn't thrash. + This method is only called when two consecutive batches have alloc retries, indicating thrashing, + where GPU memory usage is so close to the memory limit that it hinders throughput. + Automicrobatching searches for the next highest power of 2 to use as the microbatch size. + """ + lowest_oom_microbatch_size = state.device_train_microbatch_size + lower_bound_microbatch_size = _closest_lower_power_of_2(state.device_train_microbatch_size) + highest_non_oom_microbatch_size = lower_bound_microbatch_size + state.device_train_microbatch_size = lower_bound_microbatch_size + return lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size + + +def _clear_incomplete_train_states(state: State): + """Manually clear gradients when automicrobatching reruns a batch. + Before automicrobatching tries a new higher or lower microbatch size, clear the + training states and memory of the previous run of the batch to reset the memory to + before the batch was run. 
+ """ + if hasattr(state, 'outputs'): + del state.outputs + if hasattr(state, 'loss'): + del state.loss + for optimizer in state.optimizers: + optimizer.zero_grad(set_to_none=True) + if state.scaler is not None: + state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + _fsdp_reshard_and_cleanup(state.model) + torch.cuda.empty_cache() + +def _create_sync_hook(state: State): + def sync_hook(*args): + # Check if any other rank hit an OOM + found_cuda_oom_tensor = state.device.tensor_to_device(torch.tensor([0], dtype=torch.uint8)) + dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') + found_cuda_oom = found_cuda_oom_tensor.item() + # Signal current rank is still in batch + all_ranks_finished_tensor = state.device.tensor_to_device(torch.tensor([0], dtype=torch.uint8)) + dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') + + if found_cuda_oom == 1: + raise RuntimeError('CUDA out of memory encountered on a different rank') + + return sync_hook + def _adjust_device_eval_microbatch_size(evaluator: Evaluator): """Adjust device_eval_microbatch_size if we encounter OOM. diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index 807696a246..259abb63ae 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -9,16 +9,16 @@ extract_hparams, ) from composer.utils.automicrobatching import ( - _create_sync_hook, + # _create_sync_hook, _fsdp_reshard_and_cleanup, - _double_device_train_microbatch_size, - _closest_lower_power_of_2, - _clear_incomplete_train_states, - _found_ooms_across_ranks, - _update_num_consecutive_thrashes, - _handle_downward_search_in_automicrobatching, - _handle_upward_search_in_automicrobatching, - _handle_thrashing_in_automicrobatching, + # _double_device_train_microbatch_size, + _closest_lower_power_of_2, + # _clear_incomplete_train_states, + # _found_ooms_across_ranks, + # _update_num_consecutive_thrashes, + # _handle_downward_search_in_automicrobatching, + # _handle_upward_search_in_automicrobatching, + # _handle_thrashing_in_automicrobatching, ) from composer.utils.batch_helpers import batch_get, batch_set from composer.utils.checkpoint import ( @@ -176,14 +176,14 @@ 'validate_credentials', 'build_remote_backend', 'RemoteFilesExistingCheckStatus', - '_create_sync_hook', + # '_create_sync_hook', '_fsdp_reshard_and_cleanup', - '_double_device_train_microbatch_size', + # '_double_device_train_microbatch_size', '_closest_lower_power_of_2', - '_clear_incomplete_train_states', - '_found_ooms_across_ranks', - '_update_num_consecutive_thrashes', - '_handle_downward_search_in_automicrobatching', - '_handle_upward_search_in_automicrobatching', - '_handle_thrashing_in_automicrobatching' + # '_clear_incomplete_train_states', + # '_found_ooms_across_ranks', + # '_update_num_consecutive_thrashes', + # '_handle_downward_search_in_automicrobatching', + # '_handle_upward_search_in_automicrobatching', + # '_handle_thrashing_in_automicrobatching' ] diff --git a/composer/utils/automicrobatching.py b/composer/utils/automicrobatching.py index e04eb20435..9f7d60729f 100644 --- a/composer/utils/automicrobatching.py +++ b/composer/utils/automicrobatching.py @@ -15,33 +15,10 @@ log = logging.getLogger(__name__) __all__ = [ - '_create_sync_hook', '_fsdp_reshard_and_cleanup', - '_double_device_train_microbatch_size', '_closest_lower_power_of_2', - '_clear_incomplete_train_states', - '_found_ooms_across_ranks', - '_update_num_consecutive_thrashes', - '_handle_downward_search_in_automicrobatching', - 
'_handle_upward_search_in_automicrobatching', - '_handle_thrashing_in_automicrobatching' ] -def _create_sync_hook(state): - def sync_hook(*args): - # Check if any other rank hit an OOM - found_cuda_oom_tensor = state.device.tensor_to_device(torch.tensor([0], dtype=torch.uint8)) - dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') - found_cuda_oom = found_cuda_oom_tensor.item() - # Signal current rank is still in batch - all_ranks_finished_tensor = state.device.tensor_to_device(torch.tensor([0], dtype=torch.uint8)) - dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') - - if found_cuda_oom == 1: - raise RuntimeError('CUDA out of memory encountered on a different rank') - - return sync_hook - def _fsdp_reshard_and_cleanup(model: torch.nn.Module): """Manually reshard and clean up FSDP model. @@ -56,30 +33,6 @@ def _fsdp_reshard_and_cleanup(model: torch.nn.Module): # traverse and reshard all FSDP sub-modules _post_backward_final_callback(module, module) - -def _double_device_train_microbatch_size(state: State): - """Double device_train_microbatch_size when automicrobatching searches upward for a higher non-OOM microbatch size. - - Args: - state (State): State of trainer. - """ - # If any rank hit CUDA OOM, update device_train_microbatch_size and retry. Raise runtime error - # if training 1 sample at a time still resulted in CUDA out of memory. - assert state.device_train_microbatch_size is not None - assert state.train_dataloader is not None - - try: - batch_size = getattr(state.train_dataloader, 'batch_size') - except AttributeError as e: - # Error message when `device_train_microbatch_size` is 'auto' - raise AttributeError( - "`device_train_microbatch_size='auto'` requires the `state.train_dataloader` to have a `batch_size` attribute.", - ) from e - - original_microbatch_size = state.device_train_microbatch_size - # Device train microbatch size can't be greater than the device train batch size - state.device_train_microbatch_size = min(int(original_microbatch_size * 2), batch_size) - def _closest_lower_power_of_2(microbatch_size: int): """Find the highest lower power of 2 to serve as a lower bound device_train_microbatch_size when automicrobatching searches downward, due to either thrashing or when a previously non-OOMing microbatch size is now OOMing. @@ -90,177 +43,3 @@ def _closest_lower_power_of_2(microbatch_size: int): return 1 return 1 << ((microbatch_size - 1).bit_length() - 1) -def _clear_incomplete_train_states(state: State): - """Manually clear gradients when automicrobatching reruns a batch. - Before automicrobatching tries a new higher or lower microbatch size, clear the - training states and memory of the previous run of the batch to reset the memory to - before the batch was run. - """ - if hasattr(state, 'outputs'): - del state.outputs - if hasattr(state, 'loss'): - del state.loss - for optimizer in state.optimizers: - optimizer.zero_grad(set_to_none=True) - if state.scaler is not None: - state.scaler._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) - _fsdp_reshard_and_cleanup(state.model) - torch.cuda.empty_cache() - -def _found_ooms_across_ranks(state: State, found_cuda_oom: bool): - """Check if at least one rank, including the local rank, OOM'd in the forward/backward pass - when using automicrobatching. This may happen when close to memory limit or with uneven memory - usage across ranks. - - Ensure that all ranks are out of microbatch training before completing batch training or finding - a new microbatch size. 
Return whether at least one rank OOM'd. - """ - - all_ranks_finished = False - while not all_ranks_finished: - # Propagate across all ranks if any rank hit CUDA OOM - found_cuda_oom_tensor = state.device.tensor_to_device( - torch.tensor([found_cuda_oom], dtype=torch.uint8), - ) - dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX') - found_cuda_oom = found_cuda_oom_tensor.item() - # Check if any rank is still not done with the batch. This may happen if only a - # subset of ranks OOM, leaving some batches still in the forward pass - all_ranks_finished_tensor = state.device.tensor_to_device(torch.tensor([1], dtype=torch.uint8)) - dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN') - all_ranks_finished = all_ranks_finished_tensor.item() == 1 - return found_cuda_oom - -def _update_num_consecutive_thrashes(state: State, num_consecutive_thrashes: int, num_alloc_retries: int): - """Update the number of consecutive batches where we experienced alloc retries. - Consecutive alloc retries in GPU memory usually indicate thrashing, where GPU memory usage is so close - to the memory limit that it hinders throughput. - """ - # Check for alloc retries between batches - stats = torch.cuda.memory_stats() - cur_num_alloc_retries = stats["num_alloc_retries"] - - if cur_num_alloc_retries - num_alloc_retries > 0: - alloc_retry_this_batch = 1 - log.info("Found alloc retries this batch: " + str(num_alloc_retries) + " to " + str(cur_num_alloc_retries)) - else: - alloc_retry_this_batch = 0 - - # Propagate across all ranks if any rank had alloc retries this batch - alloc_retry_tensor = state.device.tensor_to_device( - torch.tensor([alloc_retry_this_batch], dtype=torch.uint8), - ) - dist.all_reduce(alloc_retry_tensor, reduce_operation='MAX') - alloc_retry_this_batch = alloc_retry_tensor.item() == 1 - if alloc_retry_this_batch: - num_consecutive_thrashes += 1 - else: - num_consecutive_thrashes = 0 - return num_consecutive_thrashes - -def _handle_downward_search_in_automicrobatching(state: State, lowest_oom_microbatch_size: int, highest_non_oom_microbatch_size: int, lower_bound_microbatch_size: int, num_search_steps: int, max_search_steps: int): - """Search downward for the highest non-OOMing microbatch size. - - This method is only called when an OOM was seen this batch with the current state.device_train_microbatch_size. - If this is the first time automicrobatching is searching for a non-OOMing microbatch size, or the previously highest non-OOMing power of 2 - microbatch size is now OOMing, automicrobatching searches for the next highest power of 2 to test as a microbatch size. This resets num_search_steps - to 1. - Otherwise, while automicrobatching has searched for less than max_search_steps, automicrobatching binary searches downwards between the highest recorded - non-OOMing microbatch size and the lowest recorded OOMing microbatch size. - Once automicrobatching has searched for max_search_steps, if the last tested microbatch size OOM'd, choose the highest previously - recorded non-OOMing microbatch size. For the edge case where that microbatch size OOMs upon retry, binary search downward between - that value and lower_bound_microbatch_size, which is the highest power of 2 guaranteed to not OOM. 
- """ - # Find closest lower power of 2 if previously non-OOM microbatch size is OOMing or this is the first microbatch size search - if state.device_train_microbatch_size == lower_bound_microbatch_size: - lowest_oom_microbatch_size = state.device_train_microbatch_size - lower_bound_microbatch_size = _closest_lower_power_of_2(state.device_train_microbatch_size) - state.device_train_microbatch_size = lower_bound_microbatch_size - highest_non_oom_microbatch_size = state.device_train_microbatch_size - - num_search_steps = 1 - # Skip return and continue searching for the highest non-OOM size in the new lower range - else: - if num_search_steps < max_search_steps: - lowest_oom_microbatch_size = state.device_train_microbatch_size - median_microbatch_size = int((lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2) - state.device_train_microbatch_size = median_microbatch_size - - num_search_steps += 1 - - # Optimization so we don't repeat a converged value - if lowest_oom_microbatch_size == highest_non_oom_microbatch_size: - num_search_steps = max_search_steps + 1 # go to else protocol - lowest_oom_microbatch_size = state.device_train_microbatch_size - highest_non_oom_microbatch_size = lower_bound_microbatch_size - state.device_train_microbatch_size = int((lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2) - - # Skip return and decrease dtms, continuing the search for the highest non-OOM size - elif num_search_steps == max_search_steps: - state.device_train_microbatch_size = highest_non_oom_microbatch_size - - num_search_steps += 1 - # Skip return and rerun to obtain loss - committing to this dtms unless retrying it OOMs - else: - # Only end up here if a previously non-OOM microbatch size is no longer successful in the same training step, and it's not the original microbatch size - - lowest_oom_microbatch_size = state.device_train_microbatch_size - highest_non_oom_microbatch_size = lower_bound_microbatch_size - state.device_train_microbatch_size = int((lowest_oom_microbatch_size + highest_non_oom_microbatch_size) // 2) - - # Skip return and continue searching for the highest non-OOM size in this narrower range - return lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size, num_search_steps - -def _handle_upward_search_in_automicrobatching(state: State, lowest_oom_microbatch_size: int, highest_non_oom_microbatch_size: int, num_search_steps: int, max_search_steps: int): - """Searches upward for the highest non-OOMing microbatch size. - - This method is only called when the current state.device_train_microbatch_size did not OOM and automicrobatching is actively searching for a new - microbatch size, either because this is the first search or a previously working microbatch size OOM'd. - If the microbatch size is already equal to the batch size, automicrobatching commits to this microbatch size. - Otherwise, while automicrobatching has searched for less than max_search_steps, automicrobatching binary searches upwards between the highest recorded - non-OOMing microbatch size and the lowest recorded OOMing microbatch size. 
- """ - assert state.train_dataloader is not None - try: - batch_size = getattr(state.train_dataloader, 'batch_size') - except AttributeError as e: - # Error message when `device_train_microbatch_size` is 'auto' - raise AttributeError( - "`device_train_microbatch_size='auto'` requires the `state.train_dataloader` to have a `batch_size` attribute.", - ) from e - - search_upwards = False - - if state.device_train_microbatch_size != batch_size: - if num_search_steps == 0: - highest_non_oom_microbatch_size = state.device_train_microbatch_size - _double_device_train_microbatch_size(state) - _clear_incomplete_train_states(state) - search_upwards = True - elif num_search_steps < max_search_steps: # Previous OOMs found in this training step - highest_non_oom_microbatch_size = state.device_train_microbatch_size - median_microbatch_size = int((highest_non_oom_microbatch_size + lowest_oom_microbatch_size) // 2) - state.device_train_microbatch_size = median_microbatch_size - - num_search_steps += 1 - - # Optimization so we don't repeat a converged value - if median_microbatch_size == highest_non_oom_microbatch_size: - num_search_steps = max_search_steps - - _clear_incomplete_train_states(state) - search_upwards = True - # Else: reached max search steps and found a non-OOM microbatch size - return search_upwards, highest_non_oom_microbatch_size, num_search_steps - -def _handle_thrashing_in_automicrobatching(state: State): - """Searches downward for the highest non-OOMing microbatch size that also doesn't thrash. - This method is only called when two consecutive batches have alloc retries, indicating thrashing, - where GPU memory usage is so close to the memory limit that it hinders throughput. - Automicrobatching searches for the next highest power of 2 to use as the microbatch size. 
- """ - lowest_oom_microbatch_size = state.device_train_microbatch_size - lower_bound_microbatch_size = _closest_lower_power_of_2(state.device_train_microbatch_size) - highest_non_oom_microbatch_size = lower_bound_microbatch_size - state.device_train_microbatch_size = lower_bound_microbatch_size - return lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size From 40252744517b80b87330bd2f690dcc1942a6f542 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Mon, 29 Jul 2024 23:38:57 -0700 Subject: [PATCH 05/23] remove import state --- composer/utils/automicrobatching.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/composer/utils/automicrobatching.py b/composer/utils/automicrobatching.py index 9f7d60729f..029c44e062 100644 --- a/composer/utils/automicrobatching.py +++ b/composer/utils/automicrobatching.py @@ -1,10 +1,7 @@ import torch import logging -from composer.core import State -from composer.utils import dist from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback -from collections import defaultdict from packaging import version if version.parse(torch.__version__) >= version.parse('2.3.0'): From 9ad171901cbb2dcc144c11a08e793995ada5d25d Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Mon, 29 Jul 2024 23:43:20 -0700 Subject: [PATCH 06/23] dist --- composer/trainer/_patch_pytorch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/composer/trainer/_patch_pytorch.py b/composer/trainer/_patch_pytorch.py index 25cf365721..4a9f6633c1 100644 --- a/composer/trainer/_patch_pytorch.py +++ b/composer/trainer/_patch_pytorch.py @@ -32,6 +32,8 @@ from torch.distributed.utils import _replace_by_prefix from torch.distributed.fsdp._flat_param import FlatParamHandle +from composer.utils import dist + log = logging.getLogger(__name__) def patch_unshard_for_automicrobatching(auto_microbatch_size_found=False): From 896b9995ffa63678e7e7ebbd6557cef86b94c9d3 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 00:24:36 -0700 Subject: [PATCH 07/23] fix imports --- composer/trainer/trainer.py | 4 ++-- composer/utils/automicrobatching.py | 6 ------ 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 66fff43aa6..7b867dff5b 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -48,9 +48,9 @@ from torchmetrics import Metric if version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import GradScaler # type: ignore + from torch.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state # type: ignore else: - from torch.cuda.amp.grad_scaler import GradScaler # type: ignore + from torch.cuda.amp.grad_scaler import GradScaler, _refresh_per_optimizer_state # type: ignore from composer.callbacks import CheckpointSaver, MemorySnapshot, OOMObserver, OptimizerMonitor from composer.core import ( diff --git a/composer/utils/automicrobatching.py b/composer/utils/automicrobatching.py index 029c44e062..0774fefa58 100644 --- a/composer/utils/automicrobatching.py +++ b/composer/utils/automicrobatching.py @@ -2,12 +2,6 @@ import logging from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp._runtime_utils import _post_backward_final_callback -from packaging import version - -if version.parse(torch.__version__) >= version.parse('2.3.0'): - from torch.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore 
-else: - from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state # type: ignore log = logging.getLogger(__name__) From 7146f23eba6fd660a045b7ddb50b6a7b8445d518 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 00:29:28 -0700 Subject: [PATCH 08/23] import defaultdict --- composer/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 7b867dff5b..a51d14a0e7 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -38,6 +38,7 @@ import torch.nn as nn import torch.utils.data from packaging import version +from collections import defaultdict from torch._dynamo import OptimizedModule from torch.cuda.amp.grad_scaler import GradScaler from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler From cd2fe9f1af2730c766a3b8caf5ef4f1500e0b986 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 09:46:01 -0700 Subject: [PATCH 09/23] log for hook on off --- composer/trainer/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index a51d14a0e7..7a47e33fe1 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2986,6 +2986,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: # Readd sync hooks if they were previously turned off if len(self.auto_microbatch_hooks) == 0: + print("readding hooks for OOM") patch_unshard_for_automicrobatching(False) for _, module in self.fsdp_modules.items(): if isinstance(module, FullyShardedDataParallel): @@ -3019,6 +3020,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: # Readd sync hooks if they were previously turned off if len(self.auto_microbatch_hooks) == 0: + print("readd hooks from thrashing") patch_unshard_for_automicrobatching(False) for _, module in self.fsdp_modules.items(): if isinstance(module, FullyShardedDataParallel): @@ -3045,6 +3047,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: ), ) if len(self.auto_microbatch_hooks) > 0: + print("remove hooks from batch completion") patch_unshard_for_automicrobatching(True) for handle in self.auto_microbatch_hooks: handle.remove() @@ -3641,6 +3644,7 @@ def _eval_loop( with torch.no_grad(), model_eval_mode(self.state.model): if self.first_batch_complete: + print("readd hooks for eval") patch_unshard_for_automicrobatching(False) for _ , module in self.fsdp_modules.items(): if isinstance(module, FullyShardedDataParallel): From b1b16cda67f34170d4968bad1d70f1bc91899d6b Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 10:05:17 -0700 Subject: [PATCH 10/23] fixed hook readd bug --- composer/trainer/trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 7a47e33fe1..63d7430f02 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1335,7 +1335,7 @@ def __init__( self.auto_microbatch_size_found = False self.num_alloc_retries = 0 self.num_consecutive_thrashes = 0 - self.auto_microbatch_hooks = [] + if auto_microbatching and profiler: raise ValueError( "`device_train_microbatch_size='auto'` is not compatible with the profiler. 
It is " @@ -2985,7 +2985,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: self.num_consecutive_thrashes = 0 # Readd sync hooks if they were previously turned off - if len(self.auto_microbatch_hooks) == 0: + if len(self.automicrobatch_hook_handless) == 0: print("readding hooks for OOM") patch_unshard_for_automicrobatching(False) for _, module in self.fsdp_modules.items(): @@ -3019,7 +3019,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size = _handle_thrashing_in_automicrobatching(self.state) # Readd sync hooks if they were previously turned off - if len(self.auto_microbatch_hooks) == 0: + if len(self.automicrobatch_hook_handles) == 0: print("readd hooks from thrashing") patch_unshard_for_automicrobatching(False) for _, module in self.fsdp_modules.items(): @@ -3046,12 +3046,12 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: f'{original_microbatch_size} -> {self.state.device_train_microbatch_size}.', ), ) - if len(self.auto_microbatch_hooks) > 0: + if len(self.automicrobatch_hook_handles) > 0: print("remove hooks from batch completion") patch_unshard_for_automicrobatching(True) - for handle in self.auto_microbatch_hooks: + for handle in self.automicrobatch_hook_handles: handle.remove() - self.auto_microbatch_hooks.clear() + self.automicrobatch_hook_handles.clear() self.auto_microbatch_size_found = True if torch.cuda.is_available(): memory_stats = torch.cuda.memory_stats() From 476e028d5362e15d3246ac130c6d028858f7f9f2 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 10:12:01 -0700 Subject: [PATCH 11/23] rename hooks to fsdp hooks, will only trigger if fsdp --- composer/trainer/trainer.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 63d7430f02..4a76ec3ef3 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1335,6 +1335,7 @@ def __init__( self.auto_microbatch_size_found = False self.num_alloc_retries = 0 self.num_consecutive_thrashes = 0 + self.automicrobatch_fsdp_hook_handles = [] if auto_microbatching and profiler: raise ValueError( @@ -1837,7 +1838,7 @@ def __init__( if self.state.fsdp_config is not None and self.state.fsdp_config.auto_wrap and not self.state.load_monolith_rank0_only: # Init with globally fixed seed so all HSDP replicas have the same initial weights with reproducibility.seed_context(self.state.rank_zero_seed): - self.automicrobatch_hook_handles, self.fsdp_modules = prepare_fsdp_module( + self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module( model, optimizers, self.state.fsdp_config, @@ -2007,7 +2008,7 @@ def __init__( ): # Init with globally fixed seed so all HSDP replicas have the same initial weights with reproducibility.seed_context(self.state.rank_zero_seed): - self.automicrobatch_hook_handles, self.fsdp_modules = prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching) + self.automicrobatch_fsdp_hook_handles, self.fsdp_modules = prepare_fsdp_module(model, optimizers, self.state.fsdp_config, precision, device, auto_microbatching) self.engine.run_event(Event.AFTER_LOAD) @@ -2985,15 +2986,15 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: self.num_consecutive_thrashes = 0 # Readd sync hooks if they were previously turned off - if 
len(self.automicrobatch_hook_handless) == 0: + if len(self.automicrobatch_fsdp_hook_handless) == 0: print("readding hooks for OOM") patch_unshard_for_automicrobatching(False) for _, module in self.fsdp_modules.items(): if isinstance(module, FullyShardedDataParallel): - self.automicrobatch_hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) - self.automicrobatch_hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) + self.automicrobatch_fsdp_hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) + self.automicrobatch_fsdp_hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) else: - self.automicrobatch_hook_handles.append(module.register_full_backward_hook(sync_hook)) + self.automicrobatch_fsdp_hook_handles.append(module.register_full_backward_hook(sync_hook)) if self.state.device_train_microbatch_size == 1: raise RuntimeError(( @@ -3019,15 +3020,15 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size = _handle_thrashing_in_automicrobatching(self.state) # Readd sync hooks if they were previously turned off - if len(self.automicrobatch_hook_handles) == 0: + if len(self.automicrobatch_fsdp_hook_handles) == 0: print("readd hooks from thrashing") patch_unshard_for_automicrobatching(False) for _, module in self.fsdp_modules.items(): if isinstance(module, FullyShardedDataParallel): - self.automicrobatch_hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) - self.automicrobatch_hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) + self.automicrobatch_fsdp_hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) + self.automicrobatch_fsdp_hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) else: - self.automicrobatch_hook_handles.append(module.register_full_backward_hook(sync_hook)) + self.automicrobatch_fsdp_hook_handles.append(module.register_full_backward_hook(sync_hook)) continue if not self.auto_microbatch_size_found: # microbatch size found in previous search @@ -3046,12 +3047,12 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: f'{original_microbatch_size} -> {self.state.device_train_microbatch_size}.', ), ) - if len(self.automicrobatch_hook_handles) > 0: + if len(self.automicrobatch_fsdp_hook_handles) > 0: print("remove hooks from batch completion") patch_unshard_for_automicrobatching(True) - for handle in self.automicrobatch_hook_handles: + for handle in self.automicrobatch_fsdp_hook_handles: handle.remove() - self.automicrobatch_hook_handles.clear() + self.automicrobatch_fsdp_hook_handles.clear() self.auto_microbatch_size_found = True if torch.cuda.is_available(): memory_stats = torch.cuda.memory_stats() @@ -3648,10 +3649,10 @@ def _eval_loop( patch_unshard_for_automicrobatching(False) for _ , module in self.fsdp_modules.items(): if isinstance(module, FullyShardedDataParallel): - self.automicrobatch_hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) - self.automicrobatch_hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) + self.automicrobatch_fsdp_hook_handles.append(module.register_forward_pre_hook(sync_hook, prepend=True)) + self.automicrobatch_fsdp_hook_handles.append(module.register_full_backward_pre_hook(sync_hook, prepend=True)) else: - 
self.automicrobatch_hook_handles.append(module.register_full_backward_hook(sync_hook)) + self.automicrobatch_fsdp_hook_handles.append(module.register_full_backward_hook(sync_hook)) self.state.set_dataloader(data_spec.dataloader, evaluator.label, subset_num_batches) assert self.state.dataloader is not None, 'dataloader is set' From 069336423e502220d06b8df3812448c2aa9b50b4 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 10:16:27 -0700 Subject: [PATCH 12/23] only invoke hook logic if fsdp enabled --- composer/trainer/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 4a76ec3ef3..2240366125 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2986,7 +2986,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: self.num_consecutive_thrashes = 0 # Readd sync hooks if they were previously turned off - if len(self.automicrobatch_fsdp_hook_handless) == 0: + if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handless) == 0: print("readding hooks for OOM") patch_unshard_for_automicrobatching(False) for _, module in self.fsdp_modules.items(): @@ -3020,7 +3020,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: lowest_oom_microbatch_size, highest_non_oom_microbatch_size, lower_bound_microbatch_size = _handle_thrashing_in_automicrobatching(self.state) # Readd sync hooks if they were previously turned off - if len(self.automicrobatch_fsdp_hook_handles) == 0: + if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) == 0: print("readd hooks from thrashing") patch_unshard_for_automicrobatching(False) for _, module in self.fsdp_modules.items(): @@ -3047,7 +3047,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: f'{original_microbatch_size} -> {self.state.device_train_microbatch_size}.', ), ) - if len(self.automicrobatch_fsdp_hook_handles) > 0: + if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) > 0: print("remove hooks from batch completion") patch_unshard_for_automicrobatching(True) for handle in self.automicrobatch_fsdp_hook_handles: @@ -3644,7 +3644,7 @@ def _eval_loop( sync_hook = _create_sync_hook(self.state) with torch.no_grad(), model_eval_mode(self.state.model): - if self.first_batch_complete: + if self.state.fsdp_enabled and self.first_batch_complete: print("readd hooks for eval") patch_unshard_for_automicrobatching(False) for _ , module in self.fsdp_modules.items(): From df404ac880f3409b5f18ca61618a14dcfb0c9441 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 10:16:53 -0700 Subject: [PATCH 13/23] typo --- composer/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 2240366125..efbd288ccf 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2986,7 +2986,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: self.num_consecutive_thrashes = 0 # Readd sync hooks if they were previously turned off - if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handless) == 0: + if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) == 0: print("readding hooks for OOM") patch_unshard_for_automicrobatching(False) for _, module in self.fsdp_modules.items(): From 0b6d6cec91f78a925340165ac1dbe86b8fce17fb Mon Sep 17 00:00:00 2001 From: 
jack-zhang_data Date: Tue, 30 Jul 2024 10:47:49 -0700 Subject: [PATCH 14/23] fix seq length warmup --- .../seq_length_warmup/seq_length_warmup.py | 29 ++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/composer/algorithms/seq_length_warmup/seq_length_warmup.py b/composer/algorithms/seq_length_warmup/seq_length_warmup.py index 2ab0eecee0..6306f920b7 100644 --- a/composer/algorithms/seq_length_warmup/seq_length_warmup.py +++ b/composer/algorithms/seq_length_warmup/seq_length_warmup.py @@ -5,6 +5,8 @@ import logging import textwrap +import warnings +from collections import defaultdict from typing import Mapping, Optional import torch @@ -284,7 +286,32 @@ def _activate_model(self, state: State, logger: Logger) -> None: batch_clone[k] = v[:, :self.max_seq_length].contiguous() # In-line to avoid circular dependency - from composer.trainer.trainer import _adjust_device_train_microbatch_size, _is_cuda_oom + from composer.trainer.trainer import _clear_incomplete_train_states, _is_cuda_oom + + def _adjust_device_train_microbatch_size(state: State): + """Adjust device_train_microbatch_size if we encounter OOM. + + Args: + state (State): State of trainer. + """ + # If any rank hit CUDA OOM, update device_train_microbatch_size and retry. Raise runtime error + # if training 1 sample at a time still resulted in CUDA out of memory. + assert state.device_train_microbatch_size is not None + if state.device_train_microbatch_size == 1: + raise RuntimeError(( + 'CUDA out of memory. The train loop failed with an internal microbatch of size 1.' + 'The GPU does not have enough memory to process even 1 sample during train.' + )) + else: + original_microbatch_size = state.device_train_microbatch_size + state.device_train_microbatch_size = max(int(original_microbatch_size / 2), 1) + warnings.warn( + RuntimeWarning( + 'CUDA out of memory detected. Train microbatch size will be decreased from ' + f'{original_microbatch_size} -> {state.device_train_microbatch_size}.', + ), + ) + _clear_incomplete_train_states # This loop tries to do a forward/backward pass using the current microbatch size. 
# If it hits an OOM error, it halves `state.device_train_microbatch_size` and tries again From 153c41301581c30e7b177f235de28fcf3d0439c8 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 15:53:22 -0700 Subject: [PATCH 15/23] only patch flat param handle unshard if > 2.3 --- composer/trainer/_patch_pytorch.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/composer/trainer/_patch_pytorch.py b/composer/trainer/_patch_pytorch.py index 4a9f6633c1..8f9b1de36a 100644 --- a/composer/trainer/_patch_pytorch.py +++ b/composer/trainer/_patch_pytorch.py @@ -30,18 +30,18 @@ from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform from torch.distributed.utils import _replace_by_prefix -from torch.distributed.fsdp._flat_param import FlatParamHandle from composer.utils import dist log = logging.getLogger(__name__) def patch_unshard_for_automicrobatching(auto_microbatch_size_found=False): - if auto_microbatch_size_found: - FlatParamHandle.unshard = (unshard) - print("dropping monkey") - else: - FlatParamHandle.unshard = (unshard_with_sync) + if version.parse(torch.__version__ >= version.parse('2.3.1')): + from torch.distributed.fsdp._flat_param import FlatParamHandle + if auto_microbatch_size_found: + FlatParamHandle.unshard = (unshard) + else: + FlatParamHandle.unshard = (unshard_with_sync) def patch_pytorch(): """Monkey patches pytorch functions based on pytorch version.""" From d98926d03dda367f90b798badd7664a13d892385 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 16:00:54 -0700 Subject: [PATCH 16/23] fix version comparison --- composer/trainer/_patch_pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/trainer/_patch_pytorch.py b/composer/trainer/_patch_pytorch.py index 8f9b1de36a..3aef3a05ab 100644 --- a/composer/trainer/_patch_pytorch.py +++ b/composer/trainer/_patch_pytorch.py @@ -36,7 +36,7 @@ log = logging.getLogger(__name__) def patch_unshard_for_automicrobatching(auto_microbatch_size_found=False): - if version.parse(torch.__version__ >= version.parse('2.3.1')): + if version.parse(torch.__version__) >= version.parse('2.3.1'): from torch.distributed.fsdp._flat_param import FlatParamHandle if auto_microbatch_size_found: FlatParamHandle.unshard = (unshard) From 3e82ef669a74d24b5228e97d38b0d8acdfa07824 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 16:41:43 -0700 Subject: [PATCH 17/23] mark unit test --- tests/trainer/test_fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index e42ef17a73..ac112415d8 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -250,6 +250,7 @@ def loss(self, outputs, batch): return torch.sum(outputs) @pytest.mark.gpu +@pytest.mark.parametrize('device', _INIT_DEVICES) @world_size(2) def test_automicrobatching_fsdp(world_size: int): model = SimpleMLPForTestingOOM() From a09b844d4cdce82bb428200b94dca555b55c986a Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 17:20:52 -0700 Subject: [PATCH 18/23] remove device mark --- tests/trainer/test_fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index ac112415d8..e42ef17a73 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -250,7 +250,6 @@ def loss(self, outputs, batch): return torch.sum(outputs) 
@pytest.mark.gpu -@pytest.mark.parametrize('device', _INIT_DEVICES) @world_size(2) def test_automicrobatching_fsdp(world_size: int): model = SimpleMLPForTestingOOM() From 33840c5c24ca13a5d117da6df434441e4c4bc438 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 18:56:06 -0700 Subject: [PATCH 19/23] filter user warnigns out --- tests/trainer/test_fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index e42ef17a73..91a27e2d3f 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -250,6 +250,7 @@ def loss(self, outputs, batch): return torch.sum(outputs) @pytest.mark.gpu +@pytest.mark.filterwarnings("ignore:device_train_microbatch_size='auto'") @world_size(2) def test_automicrobatching_fsdp(world_size: int): model = SimpleMLPForTestingOOM() From e87c9f6933e95960275990816a30611dfbe7c5fb Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 18:58:15 -0700 Subject: [PATCH 20/23] fix --- tests/trainer/test_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 91a27e2d3f..4c33af3333 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -250,7 +250,7 @@ def loss(self, outputs, batch): return torch.sum(outputs) @pytest.mark.gpu -@pytest.mark.filterwarnings("ignore:device_train_microbatch_size='auto'") +@pytest.mark.filterwarnings("ignore:`device_train_microbatch_size='auto'` may potentially fail with unexpected.*") @world_size(2) def test_automicrobatching_fsdp(world_size: int): model = SimpleMLPForTestingOOM() From 86c32a0c43a544f7f7d1529776d0a889429df93c Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 19:01:14 -0700 Subject: [PATCH 21/23] dist sampler --- tests/trainer/test_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 4c33af3333..67d2d5a5bc 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -257,7 +257,7 @@ def test_automicrobatching_fsdp(world_size: int): model.fc1._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] model.fc2._fsdp_wrap = True # pyright: ignore[reportGeneralTypeIssues] dataset = SimpleDatasetForAuto(size=256, feature_size=128) - train_dataloader = DataLoader(dataset, batch_size=64) + train_dataloader = DataLoader(dataset, batch_size=64, sampler=dist.get_sampler(dataset)) trainer = Trainer( model=model, train_dataloader=train_dataloader, From a193b76b2497e15af1a7d7e30c7365d965875cc2 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Tue, 30 Jul 2024 19:03:32 -0700 Subject: [PATCH 22/23] ignore runtime warning --- tests/trainer/test_fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 67d2d5a5bc..39a840ef03 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -251,6 +251,7 @@ def loss(self, outputs, batch): @pytest.mark.gpu @pytest.mark.filterwarnings("ignore:`device_train_microbatch_size='auto'` may potentially fail with unexpected.*") +@pytest.mark.filterwarnings("ignore:Automicrobatching changed the microbatch size from*") @world_size(2) def test_automicrobatching_fsdp(world_size: int): model = SimpleMLPForTestingOOM() From 0b6a30d30c5a921f32cd787e603b4f147b7c3764 Mon Sep 17 00:00:00 2001 From: jack-zhang_data Date: Wed, 31 Jul 2024 10:01:38 -0700 Subject: [PATCH 23/23] only drop hooks after 3 consecutive successes with this 
microbatch size --- composer/trainer/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index efbd288ccf..f27b10cb2d 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1335,6 +1335,7 @@ def __init__( self.auto_microbatch_size_found = False self.num_alloc_retries = 0 self.num_consecutive_thrashes = 0 + self.num_consecutive_non_OOM_batches = 0 self.automicrobatch_fsdp_hook_handles = [] if auto_microbatching and profiler: @@ -2984,6 +2985,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: _clear_incomplete_train_states(self.state) self.auto_microbatch_size_found = False self.num_consecutive_thrashes = 0 + self.num_consecutive_non_OOM_batches = 0 # Readd sync hooks if they were previously turned off if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) == 0: @@ -3047,7 +3049,8 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]: f'{original_microbatch_size} -> {self.state.device_train_microbatch_size}.', ), ) - if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) > 0: + self.num_consecutive_non_OOM_batches += 1 + if self.state.fsdp_enabled and len(self.automicrobatch_fsdp_hook_handles) > 0 and self.num_consecutive_non_OOM_batches >= 3: print("remove hooks from batch completion") patch_unshard_for_automicrobatching(True) for handle in self.automicrobatch_fsdp_hook_handles:
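
Read together, the patches above converge on a fairly simple control loop inside Trainer._train_batch: keep the FSDP sync hooks (and the patched FlatParamHandle.unshard) active while automicrobatching is still searching or recovering from an OOM or thrashing, binary-search the microbatch size between the last known-good and known-bad values, and only strip the hooks for throughput once the same microbatch size has survived three consecutive batches. The sketch below is a minimal, self-contained illustration of that loop only: SimpleAutoMicrobatcher, on_oom, and on_batch_success are hypothetical names, the search is deliberately simplified relative to the _handle_upward_search/_handle_downward_search helpers, and nothing here is the Composer API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleAutoMicrobatcher:
    """Toy model of the control flow: not the Composer API, no torch/FSDP required."""

    batch_size: int
    device_train_microbatch_size: int
    max_search_steps: int = 3
    hooks_active: bool = True            # stands in for automicrobatch_fsdp_hook_handles being non-empty
    num_consecutive_non_oom_batches: int = 0
    _lowest_oom: Optional[int] = None
    _highest_ok: int = 1
    _search_steps: int = 0

    @staticmethod
    def _closest_lower_power_of_2(n: int) -> int:
        # Same closed form as composer.utils.automicrobatching._closest_lower_power_of_2
        if n <= 1:
            return 1
        return 1 << ((n - 1).bit_length() - 1)

    def on_oom(self) -> None:
        # Any rank OOMed: re-arm the sync hooks and restart the search from a
        # power-of-2 lower bound, then binary-search back up toward the OOM size.
        self.hooks_active = True
        self.num_consecutive_non_oom_batches = 0
        self._lowest_oom = self.device_train_microbatch_size
        self.device_train_microbatch_size = self._closest_lower_power_of_2(self.device_train_microbatch_size)
        self._highest_ok = self.device_train_microbatch_size
        self._search_steps = 1

    def on_batch_success(self) -> None:
        # Batch finished on all ranks without OOM.
        if self._lowest_oom is not None and self._search_steps < self.max_search_steps:
            # Keep searching upward between the known-good and known-bad sizes.
            self._highest_ok = self.device_train_microbatch_size
            self.device_train_microbatch_size = min(
                (self._highest_ok + self._lowest_oom) // 2,
                self.batch_size,
            )
            self._search_steps += 1
            return
        # Search has settled; only drop the hooks (and revert the patched unshard)
        # after three consecutive clean batches, as in patch 23.
        self.num_consecutive_non_oom_batches += 1
        if self.hooks_active and self.num_consecutive_non_oom_batches >= 3:
            self.hooks_active = False


if __name__ == '__main__':
    mb = SimpleAutoMicrobatcher(batch_size=48, device_train_microbatch_size=48)
    mb.on_oom()                      # 48 OOMs -> fall back to 32, then search 32..48
    for _ in range(6):
        mb.on_batch_success()
    print(mb.device_train_microbatch_size, mb.hooks_active)   # e.g. 44 False

Only _closest_lower_power_of_2 mirrors the real helper one-for-one; everything else compresses several trainer-side helpers into two callbacks to keep the illustration short.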