26 | 26 |     is_flash_attn_3_available,
27 | 27 |     is_flash_attn_available,
28 | 28 |     is_flash_attn_version,
   | 29 | +    is_kernels_available,
29 | 30 |     is_sageattention_available,
30 | 31 |     is_sageattention_version,
31 | 32 |     is_torch_npu_available,
35 | 36 |     is_xformers_available,
36 | 37 |     is_xformers_version,
37 | 38 | )
38 |    | -from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
   | 39 | +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
39 | 40 |
40 | 41 |
41 | 42 | _REQUIRED_FLASH_VERSION = "2.6.3"
67 | 68 |     flash_attn_3_func = None
68 | 69 |     flash_attn_3_varlen_func = None
69 | 70 |
   | 71 | +if DIFFUSERS_ENABLE_HUB_KERNELS:
   | 72 | +    if not is_kernels_available():
   | 73 | +        raise ImportError(
   | 74 | +            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
   | 75 | +        )
   | 76 | +    from ..utils.kernels_utils import _get_fa3_from_hub
   | 77 | +
   | 78 | +    flash_attn_interface_hub = _get_fa3_from_hub()
   | 79 | +    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
   | 80 | +else:
   | 81 | +    flash_attn_3_func_hub = None
70 | 82 |
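The helper `_get_fa3_from_hub` imported above is not shown in this diff. As a rough sketch of what such a loader could look like, assuming the `kernels` library's `get_kernel` API and a hypothetical Hub repo id (the repo id below is an assumption, not taken from this diff):

```python
# Hedged sketch of a Hub kernel loader, not the actual diffusers implementation.
from kernels import get_kernel

# Assumption: an FA3 kernel repository published on the Hugging Face Hub.
_FA3_HUB_REPO_ID = "kernels-community/flash-attn3"


def _get_fa3_from_hub():
    # Fetches (or loads from cache) the compiled FA3 kernel matching the local
    # hardware and returns its Python interface, which exposes flash_attn_func.
    return get_kernel(_FA3_HUB_REPO_ID)
```

The module-level gating above then binds `flash_attn_3_func_hub` once at import time, so the Hub fetch only happens when `DIFFUSERS_ENABLE_HUB_KERNELS` is set.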
71 | 83 | if _CAN_USE_SAGE_ATTN:
72 | 84 |     from sageattention import (
@@ -153,6 +165,8 @@ class AttentionBackendName(str, Enum):
153 | 165 |     FLASH_VARLEN = "flash_varlen"
154 | 166 |     _FLASH_3 = "_flash_3"
155 | 167 |     _FLASH_VARLEN_3 = "_flash_varlen_3"
    | 168 | +    _FLASH_3_HUB = "_flash_3_hub"
    | 169 | +    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.
156 | 170 |
157 | 171 |     # PyTorch native
158 | 172 |     FLEX = "flex"
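Since `AttentionBackendName` subclasses both `str` and `Enum`, the new backend can be looked up either by attribute or by its string value, which is how user-facing backend names are usually passed around. A small illustration (the class below is trimmed to the members visible in this hunk):

```python
from enum import Enum


class AttentionBackendName(str, Enum):
    # Trimmed copy for illustration; the real enum defines many more backends.
    _FLASH_3 = "_flash_3"
    _FLASH_3_HUB = "_flash_3_hub"


# String round-trip: "_flash_3_hub" resolves to the new Hub-backed member.
assert AttentionBackendName("_flash_3_hub") is AttentionBackendName._FLASH_3_HUB
assert AttentionBackendName._FLASH_3_HUB == "_flash_3_hub"
```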
@@ -351,6 +365,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
351 | 365 |             f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
352 | 366 |         )
353 | 367 |
    | 368 | +    # TODO: add support for the Hub variant of FA3 varlen later
    | 369 | +    elif backend in [AttentionBackendName._FLASH_3_HUB]:
    | 370 | +        if not DIFFUSERS_ENABLE_HUB_KERNELS:
    | 371 | +            raise RuntimeError(
    | 372 | +                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
    | 373 | +            )
    | 374 | +        if not is_kernels_available():
    | 375 | +            raise RuntimeError(
    | 376 | +                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
    | 377 | +            )
    | 378 | +
354 | 379 |     elif backend in [
355 | 380 |         AttentionBackendName.SAGE,
356 | 381 |         AttentionBackendName.SAGE_VARLEN,
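The new branch makes the Hub backend opt-in twice over: the `DIFFUSERS_ENABLE_HUB_KERNELS` env var must be set and the `kernels` package must be installed. A hedged setup sketch from the caller's side (the import path is an assumption; the env var name, enum member, and check function come from this diff):

```python
import os

# Must be set before DIFFUSERS_ENABLE_HUB_KERNELS is read at import time,
# i.e. before importing diffusers in the same process.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

from diffusers.models.attention_dispatch import (  # assumed module path
    AttentionBackendName,
    _check_attention_backend_requirements,
)

# Raises RuntimeError if the env var is unset or `kernels` is missing.
_check_attention_backend_requirements(AttentionBackendName._FLASH_3_HUB)
```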
@@ -657,6 +682,44 @@ def _flash_attention_3(
657 | 682 |     return (out, lse) if return_attn_probs else out
658 | 683 |
659 | 684 |
    | 685 | +@_AttentionBackendRegistry.register(
    | 686 | +    AttentionBackendName._FLASH_3_HUB,
    | 687 | +    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    | 688 | +)
    | 689 | +def _flash_attention_3_hub(
    | 690 | +    query: torch.Tensor,
    | 691 | +    key: torch.Tensor,
    | 692 | +    value: torch.Tensor,
    | 693 | +    scale: Optional[float] = None,
    | 694 | +    is_causal: bool = False,
    | 695 | +    window_size: Tuple[int, int] = (-1, -1),
    | 696 | +    softcap: float = 0.0,
    | 697 | +    deterministic: bool = False,
    | 698 | +    return_attn_probs: bool = False,
    | 699 | +) -> torch.Tensor:
    | 700 | +    out = flash_attn_3_func_hub(
    | 701 | +        q=query,
    | 702 | +        k=key,
    | 703 | +        v=value,
    | 704 | +        softmax_scale=scale,
    | 705 | +        causal=is_causal,
    | 706 | +        qv=None,
    | 707 | +        q_descale=None,
    | 708 | +        k_descale=None,
    | 709 | +        v_descale=None,
    | 710 | +        window_size=window_size,
    | 711 | +        softcap=softcap,
    | 712 | +        num_splits=1,
    | 713 | +        pack_gqa=None,
    | 714 | +        deterministic=deterministic,
    | 715 | +        sm_margin=0,
    | 716 | +        return_attn_probs=return_attn_probs,
    | 717 | +    )
    | 718 | +    # When `return_attn_probs` is True, the above returns a tuple of
    | 719 | +    # actual outputs and lse.
    | 720 | +    return (out[0], out[1]) if return_attn_probs else out
    | 721 | +
    | 722 | +
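For completeness, a hedged smoke test of the op registered above. It assumes an FA3-capable GPU, the environment prepared as in the previous sketch, and the `(batch, seq_len, num_heads, head_dim)` layout used by the other flash-attention backends in this file (the layout and shapes are assumptions, not spelled out in this hunk):

```python
import torch

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Dispatches to the Hub-provided FA3 kernel bound to flash_attn_3_func_hub.
out = _flash_attention_3_hub(query=q, key=k, value=v, is_causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])
```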
660 | 723 | @_AttentionBackendRegistry.register(
661 | 724 |     AttentionBackendName._FLASH_VARLEN_3,
662 | 725 |     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],