Commit 130fd8d

[core] use kernels to support _flash_3_hub attention backend (#12236)
* feat: try loading fa3 using kernels when available.
* up
* change to Hub.
* up
* up
* up
* switch env var.
* up
* up
* up
* up
* up
* up
1 parent bcd4d77 commit 130fd8d

File tree

3 files changed: +88 −1 lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 64 additions & 1 deletion
```diff
@@ -26,6 +26,7 @@
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
+    is_kernels_available,
     is_sageattention_available,
     is_sageattention_version,
     is_torch_npu_available,
@@ -35,7 +36,7 @@
     is_xformers_available,
     is_xformers_version,
 )
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS


 _REQUIRED_FLASH_VERSION = "2.6.3"
@@ -67,6 +68,17 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

+if DIFFUSERS_ENABLE_HUB_KERNELS:
+    if not is_kernels_available():
+        raise ImportError(
+            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
+        )
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+else:
+    flash_attn_3_func_hub = None

 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -153,6 +165,8 @@ class AttentionBackendName(str, Enum):
     FLASH_VARLEN = "flash_varlen"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
+    _FLASH_3_HUB = "_flash_3_hub"
+    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.

     # PyTorch native
     FLEX = "flex"
@@ -351,6 +365,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )

+    # TODO: add support Hub variant of FA3 varlen later
+    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+            )
+        if not is_kernels_available():
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+            )
+
     elif backend in [
         AttentionBackendName.SAGE,
         AttentionBackendName.SAGE_VARLEN,
@@ -657,6 +682,44 @@ def _flash_attention_3(
     return (out, lse) if return_attn_probs else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
+    out = flash_attn_3_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+        return_attn_probs=return_attn_probs,
+    )
+    # When `return_attn_probs` is True, the above returns a tuple of
+    # actual outputs and lse.
+    return (out[0], out[1]) if return_attn_probs else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
```
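For context, a minimal end-to-end usage sketch (not part of this commit): it assumes the `kernels` package is installed, an FA3-capable GPU is available, that the model exposes the dispatcher-backed `set_attention_backend` helper, and the checkpoint ID below is a placeholder.

```python
# Hypothetical usage sketch, not part of this commit.
# Assumptions: `kernels` is installed, an FA3-capable GPU is present, and
# `set_attention_backend` is the dispatcher-backed helper on the model.
import os

# Must be set before diffusers is imported, because the Hub kernel is
# fetched at module import time when the flag is enabled.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "some-org/some-model",  # placeholder checkpoint ID
    torch_dtype=torch.bfloat16,
).to("cuda")

# Route the transformer's attention through the FA3 kernel pulled from the Hub.
pipe.transformer.set_attention_backend("_flash_3_hub")

image = pipe("a prompt").images[0]
```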

src/diffusers/utils/constants.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -46,6 +46,7 @@
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
+DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES

 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
```
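A note on the parsing above, as a standalone sketch: assuming `ENV_VARS_TRUE_VALUES` is the usual `{"1", "ON", "YES", "TRUE"}` set used by the other flags in this module, the new flag behaves roughly like this.

```python
# Standalone sketch of the flag parsing; ENV_VARS_TRUE_VALUES is assumed to be
# the {"1", "ON", "YES", "TRUE"} set used by the other flags in constants.py.
import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumption

DIFFUSERS_ENABLE_HUB_KERNELS = (
    os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
)

# `export DIFFUSERS_ENABLE_HUB_KERNELS=yes` (or 1/true/on, case-insensitive)
# turns the Hub-backed FA3 path on; unset or any other value leaves it off.
```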

src/diffusers/utils/kernels_utils.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -0,0 +1,23 @@
+from ..utils import get_logger
+from .import_utils import is_kernels_available
+
+
+logger = get_logger(__name__)
+
+
+_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+
+
+def _get_fa3_from_hub():
+    if not is_kernels_available():
+        return None
+    else:
+        from kernels import get_kernel
+
+        try:
+            # TODO: temporary revision for now. Remove when merged upstream into `main`.
+            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
+            return flash_attn_3_hub
+        except Exception as e:
+            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            raise
```
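For illustration, a minimal sketch of how this private helper is consumed, mirroring what `attention_dispatch.py` does at import time. The direct tensor call and its shapes/dtypes are assumptions for illustration, not part of the commit.

```python
# Minimal sketch: assumes `kernels` is installed, the Hub is reachable, and an
# FA3-capable GPU is present. Argument names mirror the call made in
# attention_dispatch.py; the shapes/dtypes below are illustrative only.
import torch
from diffusers.utils.kernels_utils import _get_fa3_from_hub

fa3_interface = _get_fa3_from_hub()  # returns None when `kernels` is unavailable
if fa3_interface is not None:
    flash_attn_3_func_hub = fa3_interface.flash_attn_func  # same attribute attention_dispatch binds

    # (batch, seq_len, num_heads, head_dim) in bf16, as FA-style kernels expect.
    q = k = v = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
    # Per the wrapper in attention_dispatch.py, leaving return_attn_probs off
    # yields just the attention output.
    out = flash_attn_3_func_hub(q=q, k=k, v=v, causal=False)
```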

0 commit comments

