diff --git a/mamba_ssm/utils/torch.py b/mamba_ssm/utils/torch.py index afe1dfcf..37df47c8 100644 --- a/mamba_ssm/utils/torch.py +++ b/mamba_ssm/utils/torch.py @@ -1,16 +1,18 @@ import torch from functools import partial +from typing import Callable - -def custom_amp_decorator(dec, cuda_amp_deprecated): - def decorator(func): - return dec(func) if not cuda_amp_deprecated else partial(dec, func, device_type="cuda") +def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) return decorator -if hasattr(torch.amp, "custom_fwd"): +if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] deprecated = True - from torch.amp import custom_fwd, custom_bwd + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] else: deprecated = False from torch.cuda.amp import custom_fwd, custom_bwd