Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.... #12206


Merged
Changes from all commits (10 commits)
0034bdc  Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.... (Aishwarya0811, Aug 21, 2025)
51a4f7f  Use dummy decorators approach for PyTorch version compatibility (Aishwarya0811, Aug 21, 2025)
22dc9d1  Update src/diffusers/models/attention_dispatch.py (Aishwarya0811, Aug 22, 2025)
7009fd8  Update src/diffusers/models/attention_dispatch.py (Aishwarya0811, Aug 22, 2025)
a831012  Update src/diffusers/models/attention_dispatch.py (Aishwarya0811, Aug 22, 2025)
6f799b0  Update src/diffusers/models/attention_dispatch.py (Aishwarya0811, Aug 22, 2025)
f515990  Merge branch 'main' into fix-pytorch-231-compatibility (Aishwarya0811, Aug 22, 2025)
5ef7da4  Move version check to top of file and use private naming as requested (Aishwarya0811, Aug 22, 2025)
0a3a228  Merge branch 'main' into fix-pytorch-231-compatibility (a-r-r-o-w, Aug 23, 2025)
ec47936  Apply style fixes (github-actions[bot], Aug 23, 2025)
30 changes: 25 additions & 5 deletions src/diffusers/models/attention_dispatch.py
@@ -110,6 +110,27 @@
 else:
     xops = None
 
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
+if torch.__version__ >= "2.4.0":
+    _custom_op = torch.library.custom_op
+    _register_fake = torch.library.register_fake
+else:
+
+    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+
+        return wrap if fn is None else fn
+
+    _custom_op = custom_op_no_op
+    _register_fake = register_fake_no_op
+
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
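For context (not part of the diff): torch.library.custom_op and torch.library.register_fake can each be used either as a decorator factory or with the function passed directly as the second positional argument, so the no-op stand-ins above mirror both call shapes. A minimal standalone sketch of that fallback pattern, with hypothetical names (and a default for mutates_args added purely for brevity):

# Sketch only; _custom_op_no_op, identity, and double are illustrative names, not code from the PR.
from typing import Callable, Optional


def _custom_op_no_op(name: str, fn: Optional[Callable] = None, /, *, mutates_args=(), device_types=None, schema=None):
    def wrap(func: Callable) -> Callable:
        return func  # leave the decorated function untouched

    # Factory form returns the wrapper; direct form returns the function itself.
    return wrap if fn is None else fn


@_custom_op_no_op("mylib::identity", mutates_args=())  # factory form
def identity(x):
    return x


def double(x):
    return 2 * x


double = _custom_op_no_op("mylib::double", double, mutates_args=())  # direct form

assert identity(3) == 3 and double(3) == 6  # both remain ordinary Python callables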
@@ -473,12 +494,11 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
 
 
-# TODO: library.custom_op and register_fake probably need version guards?
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -487,7 +507,7 @@ def _wrapped_flash_attn_3_original(
     return out, lse
 
 
-@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+@_register_fake("flash_attn_3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
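Usage illustration (a sketch, not code from this PR; the op name mylib::toy_add is hypothetical): with the guarded aliases in place, the same registration code runs on both old and new PyTorch. On torch >= 2.4 a real custom op plus a fake kernel is registered, which is what the fullgraph-tracing comment above refers to; on older versions the decorators reduce to no-ops and the function stays a plain Python callable, so importing the module no longer fails on 2.3.x.

# Sketch: registering a toy op through guarded aliases mirroring the diff above (hypothetical example).
import torch

if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:

    def _custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def _register_fake(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn


@_custom_op("mylib::toy_add", mutates_args=(), device_types="cpu")
def toy_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y


@_register_fake("mylib::toy_add")
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Shape-and-dtype-only implementation used when tracing with fake tensors.
    return torch.empty_like(x)


print(toy_add(torch.ones(2), torch.ones(2)))  # tensor([2., 2.]) on either code path

The trade-off in this sketch is that on pre-2.4 PyTorch the op is simply never registered with the dispatcher, which appears to match the intent of the guard: keep imports working on 2.3.x rather than fail at import time.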
