diff --git a/composer/algorithms/gradient_clipping/gradient_clipping.py b/composer/algorithms/gradient_clipping/gradient_clipping.py
index abd712b152..73464d9c8e 100644
--- a/composer/algorithms/gradient_clipping/gradient_clipping.py
+++ b/composer/algorithms/gradient_clipping/gradient_clipping.py
@@ -25,7 +25,7 @@ def apply_gradient_clipping(
     clipping_type: str,
     clipping_threshold: float,
     fsdp_enabled: bool,
-):
+) -> Union[torch.Tensor, None]:
     """Clips all gradients in model based on specified clipping_type.
 
     Args:
@@ -41,12 +41,16 @@ def apply_gradient_clipping(
             threshold by which if grad_norm / weight_norm is greater than this threshold then
             scale gradients by this threshold * (weight_norm / grad_norm) (for 'adaptive').
         fsdp_enabled (bool): Bool of if the model is a FSDP model or not.
+
+    Returns:
+        Union[torch.Tensor, None]: The total gradient norm before clipping for 'norm' clipping type,
+            None otherwise.
     """
     if fsdp_enabled:
         for module in model.modules():
             if isinstance(module, FullyShardedDataParallel) and module.check_is_root():
                 if clipping_type == 'norm':
-                    module.clip_grad_norm_(max_norm=clipping_threshold)
+                    return module.clip_grad_norm_(max_norm=clipping_threshold)
                 elif clipping_type == 'value':
                     module.clip_grad_norm_(max_norm=clipping_threshold, norm_type=float('inf'))
                 else:
@@ -56,12 +60,14 @@ def apply_gradient_clipping(
     if clipping_type == 'adaptive':
         _apply_agc(parameters, clipping_threshold=clipping_threshold)
     elif clipping_type == 'norm':
-        torch.nn.utils.clip_grad_norm_(parameters, max_norm=clipping_threshold)
+        return torch.nn.utils.clip_grad_norm_(parameters, max_norm=clipping_threshold)
     elif clipping_type == 'value':
         torch.nn.utils.clip_grad_value_(parameters, clip_value=clipping_threshold)
     else:
         raise ValueError(f"clipping_type must be 'adaptive', 'norm', or 'value' not {clipping_type} ")
 
+    return None
+
 
 def _apply_agc(
     parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
@@ -122,24 +128,51 @@ class GradientClipping(Algorithm):
            to (for 'value'), what values to clip the gradient norms to (for 'norm'), and
            threshold by which if grad_norm / weight_norm is greater than this threshold then
            scale gradients by this threshold * (weight_norm / grad_norm) (for 'adaptive').
+        clipping_frequency_window (int, optional): Number of steps to use for calculating
+            the rolling average of clipping frequency. Only used for 'norm' clipping type.
+            Defaults to 100.
     """
 
-    def __init__(self, clipping_type: str, clipping_threshold: float):
+    def __init__(self, clipping_type: str, clipping_threshold: float, clipping_frequency_window: int = 100):
         self.clipping_type = clipping_type
         self.clipping_threshold = clipping_threshold
+        self.clipping_frequency_window = clipping_frequency_window
+        self._clipping_history = []
 
     def match(self, event: Event, state: State) -> bool:
         return event in [Event.INIT, Event.AFTER_TRAIN_BATCH]
 
     def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
         if event == Event.AFTER_TRAIN_BATCH:
-            apply_gradient_clipping(
+            maybe_grad_norm = apply_gradient_clipping(
                 model=state.model,
                 clipping_type=self.clipping_type,
                 clipping_threshold=self.clipping_threshold,
                 fsdp_enabled=state.fsdp_config_version == 1,
             )
+            if self.clipping_type == 'norm':
+                if maybe_grad_norm is None:
+                    raise RuntimeError("Expected gradient norm to be returned for 'norm' clipping type, but got None")
+
+                grad_norm = maybe_grad_norm.item()
+
+                # Log the gradient norm before clipping
+                logger.log_metrics({'gradient_norm/unclipped_magnitude': grad_norm})
+
+                # Log whether clipping was applied
+                clipping_applied = grad_norm > self.clipping_threshold
+                logger.log_metrics({'gradient_norm/clipped': float(clipping_applied)})
+
+                # Track clipping frequency
+                self._clipping_history.append(float(clipping_applied))
+                # Keep only last N steps for frequency calculation
+                if len(self._clipping_history) > self.clipping_frequency_window:
+                    self._clipping_history.pop(0)
+
+                clipping_frequency = sum(self._clipping_history) / len(self._clipping_history)
+                logger.log_metrics({'gradient_norm/clipping_frequency': clipping_frequency})
+
 
 
 def _get_clipped_gradient_coeff(weights: torch.Tensor, grad: torch.Tensor, clipping_threshold: float = 0.01):
     """Clips all gradients in model based on ratio of gradient norms to parameter norms.