Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import inspect
import math
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -83,12 +83,17 @@
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
from ..utils.kernels_utils import _get_fa3_from_hub
from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub

flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func

sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
sage_attn_func_hub = sage_interface_hub.sageattn
Comment on lines -86 to +92
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a huge fan of downloading all kernels if the env variable is set, since it's downloading stuff without explicit user consent. I think we need to rethink this part a bit.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly. This is why #12475. Let's get that reviewed and merged first as it will unblock this PR and also #12387


else:
flash_attn_3_func_hub = None
sage_attn_func_hub = None

if _CAN_USE_SAGE_ATTN:
from sageattention import (
Expand Down Expand Up @@ -162,10 +167,6 @@ def wrap(func):
# - CP with sage attention, flex, xformers, other missing backends
# - Add support for normal and CP training with backends that don't support it yet

_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
Comment on lines -165 to -167
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see their usage, hence removed.



class AttentionBackendName(str, Enum):
# EAGER = "eager"
Expand All @@ -190,6 +191,7 @@ class AttentionBackendName(str, Enum):

# `sageattention`
SAGE = "sage"
SAGE_HUB = "sage_hub"
SAGE_VARLEN = "sage_varlen"
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
Expand Down Expand Up @@ -404,14 +406,14 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)

# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.SAGE_HUB]:
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
f"Attention backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
f"Attention backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)

elif backend in [
Expand Down Expand Up @@ -1756,6 +1758,31 @@ def _sage_attention(
return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
)
def _sage_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if _parallel_config is None:
out = sage_attn_func_hub(q=query, k=key, v=value)
if return_lse:
out, lse, *_ = out
else:
raise NotImplementedError("SAGE attention doesn't yet support parallelism.")

return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_VARLEN,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
Expand Down
17 changes: 12 additions & 5 deletions src/diffusers/utils/kernels_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@


_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"
_KERNEL_REVISION = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
_DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
_DEFAULT_HUB_ID_SAGE: None,
}


def _get_fa3_from_hub():
def _get_kernel_from_hub(kernel_id):
if not is_kernels_available():
return None
else:
from kernels import get_kernel

try:
# TODO: temporary revision for now. Remove when merged upstream into `main`.
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
return flash_attn_3_hub
if kernel_id not in _KERNEL_REVISION:
raise NotImplementedError(f"{kernel_id} is not implemented in Diffusers.")
kernel_hub = get_kernel(kernel_id, revision=_KERNEL_REVISION.get(kernel_id))
return kernel_hub
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
logger.error(f"An error occurred while fetching kernel '{kernel_id}' from the Hub: {e}")
raise
Loading