Commit 250f5cb

Authored by lauri9, sayakpaul, and github-actions[bot]

Add AITER attention backend (#12549)

* add aiter attention backend
* Apply style fixes

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

1 parent dc6bd15 · commit 250f5cb

File tree

5 files changed (+98, -0 lines)

docs/source/en/optimization/attention_backends.md

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
 | attention family | main feature |
 |---|---|
 | FlashAttention | minimizes memory reads/writes through tiling and recomputation |
+| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
 | SageAttention | quantizes attention to int8 |
 | PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
 | xFormers | memory-efficient attention with support for various attention kernels |

@@ -139,6 +140,7 @@ Refer to the table below for a complete list of available attention backends and
 | `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
+| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
 | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
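
Note: as a usage illustration of the new `aiter` table entry above, the sketch below opts a model into the backend. It assumes `set_attention_backend` is the model-level helper diffusers exposes for backend selection (as described elsewhere on this documentation page), an AMD ROCm GPU with aiter >= 0.1.5 installed, and an illustrative Flux checkpoint; none of these specifics come from this commit.

# Hedged usage sketch: route a model's attention through the new "aiter" backend.
# ROCm devices are exposed through the "cuda" device string in PyTorch.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Backend name taken from the table above; helper assumed per the surrounding docs.
pipe.transformer.set_attention_backend("aiter")

image = pipe("a photo of an astronaut riding a horse on the moon").images[0]
image.save("flux_aiter.png")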

src/diffusers/models/attention_dispatch.py

Lines changed: 60 additions & 0 deletions

@@ -27,6 +27,8 @@

 from ..utils import (
     get_logger,
+    is_aiter_available,
+    is_aiter_version,
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,

@@ -47,13 +49,15 @@
     from ._modeling_parallel import ParallelConfig

 _REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_AITER_VERSION = "0.1.5"
 _REQUIRED_SAGE_VERSION = "2.1.1"
 _REQUIRED_FLEX_VERSION = "2.5.0"
 _REQUIRED_XLA_VERSION = "2.2"
 _REQUIRED_XFORMERS_VERSION = "0.0.29"

 _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
 _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
 _CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
 _CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
 _CAN_USE_NPU_ATTN = is_torch_npu_available()

@@ -78,6 +82,12 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

+
+if _CAN_USE_AITER_ATTN:
+    from aiter import flash_attn_func as aiter_flash_attn_func
+else:
+    aiter_flash_attn_func = None
+
 if DIFFUSERS_ENABLE_HUB_KERNELS:
     if not is_kernels_available():
         raise ImportError(

@@ -178,6 +188,9 @@ class AttentionBackendName(str, Enum):
     _FLASH_3_HUB = "_flash_3_hub"
     # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.

+    # `aiter`
+    AITER = "aiter"
+
     # PyTorch native
     FLEX = "flex"
     NATIVE = "native"

@@ -414,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
            )

+    elif backend == AttentionBackendName.AITER:
+        if not _CAN_USE_AITER_ATTN:
+            raise RuntimeError(
+                f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
+            )
+
    elif backend in [
        AttentionBackendName.SAGE,
        AttentionBackendName.SAGE_VARLEN,

@@ -1397,6 +1416,47 @@ def _flash_varlen_attention_3(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.AITER,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _aiter_flash_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    if not return_lse and torch.is_grad_enabled():
+        # aiter requires return_lse=True by assertion when gradients are enabled.
+        out, lse, *_ = aiter_flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_lse=True,
+        )
+    else:
+        out = aiter_flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_lse=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLEX,
     constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
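
Note: the new `_aiter_flash_attention` wrapper always requests the log-sum-exp when autograd is active, because `aiter.flash_attn_func` asserts `return_lse=True` in that case. The standalone sketch below mirrors that calling convention; the keyword names come from the diff, while the (batch, seq_len, heads, head_dim) layout and tensor sizes are illustrative assumptions, and a ROCm device with aiter >= 0.1.5 is required.

# Hedged sketch of the aiter calling convention the backend above wraps.
import torch
from aiter import flash_attn_func as aiter_flash_attn_func

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q)
v = torch.randn_like(q)

if torch.is_grad_enabled():
    # aiter asserts return_lse=True when gradients are enabled, so request the
    # log-sum-exp and drop it if the caller did not ask for it.
    out, lse, *_ = aiter_flash_attn_func(
        q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=False, return_lse=True
    )
else:
    out = aiter_flash_attn_func(
        q=q, k=k, v=v, dropout_p=0.0, softmax_scale=None, causal=False, return_lse=False
    )

print(out.shape)  # output keeps the layout of the query tensor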

src/diffusers/utils/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,8 @@
     get_objects_from_module,
     is_accelerate_available,
     is_accelerate_version,
+    is_aiter_available,
+    is_aiter_version,
     is_better_profanity_available,
     is_bitsandbytes_available,
     is_bitsandbytes_version,

src/diffusers/utils/import_utils.py

Lines changed: 21 additions & 0 deletions

@@ -226,6 +226,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _sageattention_available, _sageattention_version = _is_package_available("sageattention")
 _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
 _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_aiter_available, _aiter_version = _is_package_available("aiter")
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)

@@ -406,6 +407,10 @@ def is_flash_attn_3_available():
     return _flash_attn_3_available


+def is_aiter_available():
+    return _aiter_available
+
+
 def is_kornia_available():
     return _kornia_available

@@ -911,6 +916,22 @@ def is_flash_attn_version(operation: str, version: str):
     return compare_versions(parse(_flash_attn_version), operation, version)


+@cache
+def is_aiter_version(operation: str, version: str):
+    """
+    Compares the current aiter version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _aiter_available:
+        return False
+    return compare_versions(parse(_aiter_version), operation, version)
+
+
 def get_objects_from_module(module):
     """
     Returns a dict of object names and values in a module, while skipping private/internal objects
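
Note: the two helpers added here follow the same availability/version gate used for the other optional attention packages. A minimal sketch of how a caller combines them, mirroring the attention_dispatch.py hunk further up (no API beyond this diff is assumed):

# Minimal sketch of the guard pattern the new helpers enable.
from diffusers.utils import is_aiter_available, is_aiter_version

_REQUIRED_AITER_VERSION = "0.1.5"

_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)

if _CAN_USE_AITER_ATTN:
    from aiter import flash_attn_func as aiter_flash_attn_func
else:
    # Keep the symbol defined so dispatch code can raise a clear error later.
    aiter_flash_attn_func = None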

tests/others/test_attention_backends.py

Lines changed: 13 additions & 0 deletions

@@ -14,6 +14,10 @@

 Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
 "native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
+
+Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
+with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
+aiter 0.1.5.post4.dev20+ga25e55e79.
 """

 import os

@@ -44,6 +48,10 @@
         "_native_cudnn",
         torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
     ),
+    (
+        "aiter",
+        torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
+    )
 ]

 COMPILE_CASES = [

@@ -63,6 +71,11 @@
         torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
         True,
     ),
+    (
+        "aiter",
+        torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
+        True,
+    )
 ]
 # fmt: on
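
Note: each new case pairs a backend name with a 16-element bf16 slice of expected output recorded on MI355X hardware. A hedged sketch of how such a case might be consumed is below; the `run_forward` helper and the first-8/last-8 slicing convention are hypothetical stand-ins for the harness in this test file, not part of the diff.

# Hedged sketch of checking a (backend_name, expected_slice) case.
import torch


def check_forward_case(model, backend_name: str, expected_slice: torch.Tensor, run_forward) -> None:
    model.set_attention_backend(backend_name)
    output = run_forward(model)  # hypothetical helper returning the denoiser output tensor
    flat = output.flatten()
    actual_slice = torch.cat([flat[:8], flat[-8:]]).to(torch.bfloat16)
    # Slices were recorded per backend (H100 / MI355X), so allow small numerical drift.
    torch.testing.assert_close(actual_slice, expected_slice, atol=1e-3, rtol=1e-3)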

0 commit comments