From 0034bdce10b8ec67fb4e209bf85387f733b6c79a Mon Sep 17 00:00:00 2001
From: AishwaryaBadlani
Date: 21 Aug 2025 14:07:21 +0500
Subject: [PATCH 1/8] Fix PyTorch 2.3.1 compatibility: add version guard for torch.library.custom_op

- Add hasattr() check for torch.library.custom_op and register_fake
- These functions were added in PyTorch 2.4, causing import failures in 2.3.1
- Both decorators and functions are now properly guarded with version checks
- Maintains backward compatibility while preserving functionality

Fixes #12195
---
 src/diffusers/models/attention_dispatch.py | 32 ++++++++++++----------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 7cc30e47ab14..ee3873562b82 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -475,23 +475,25 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 # Registrations are required for fullgraph tracing compatibility
-# TODO: library.custom_op and register_fake probably need version guards?
+
+
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3_original(
-    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    out, lse = flash_attn_3_func(query, key, value)
-    lse = lse.permute(0, 2, 1)
-    return out, lse
-
-
-@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
-def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    batch_size, seq_len, num_heads, head_dim = query.shape
-    lse_shape = (batch_size, seq_len, num_heads)
-    return torch.empty_like(query), query.new_empty(lse_shape)
+if hasattr(torch.library, 'custom_op') and hasattr(torch.library, 'register_fake'):
+    @torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+    def _wrapped_flash_attn_3_original(
+        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        out, lse = flash_attn_3_func(query, key, value)
+        lse = lse.permute(0, 2, 1)
+        return out, lse
+
+    @torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+    def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        batch_size, seq_len, num_heads, head_dim = query.shape
+        lse_shape = (batch_size, seq_len, num_heads)
+        return torch.empty_like(query), query.new_empty(lse_shape)
 
 
 # ===== Attention backends =====

From 51a4f7fe176915ab2489020b758b5addcd1fd8e9 Mon Sep 17 00:00:00 2001
From: AishwaryaBadlani
Date: 21 Aug 2025 17:28:27 +0500
Subject: [PATCH 2/8] Use dummy decorators approach for PyTorch version compatibility

- Replace hasattr check with version string comparison
- Add no-op decorator functions for PyTorch < 2.4.0
- Follows pattern from #11941 as suggested by reviewer
- Maintains cleaner code structure without indentation changes
---
 src/diffusers/models/attention_dispatch.py | 51 ++++++++++++++--------
 1 file changed, 33 insertions(+), 18 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ee3873562b82..61b3b3e1ad09 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -473,27 +473,42 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
-
-
-
-
 # Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-if hasattr(torch.library, 'custom_op') and hasattr(torch.library, 'register_fake'):
-    @torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
-    def _wrapped_flash_attn_3_original(
-        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        out, lse = flash_attn_3_func(query, key, value)
-        lse = lse.permute(0, 2, 1)
-        return out, lse
-
-    @torch.library.register_fake("flash_attn_3::_flash_attn_forward")
-    def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        batch_size, seq_len, num_heads, head_dim = query.shape
-        lse_shape = (batch_size, seq_len, num_heads)
-        return torch.empty_like(query), query.new_empty(lse_shape)
+
+if torch.__version__>= "2.4.0":
+    custom_op = torch.library.custom_op
+    register_fake = torch.library.register_fake
+else:
+    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+        return wrap if fn is None else fn
+
+    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+        return wrap if fn is None else fn
+
+    custom_op = custom_op_no_op
+    register_fake = register_fake_no_op
+
+
+@custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3_original(
+    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    out, lse = flash_attn_3_func(query, key, value)
+    lse = lse.permute(0, 2, 1)
+    return out, lse
+
+
+@register_fake("flash_attn_3::_flash_attn_forward")
+def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    batch_size, seq_len, num_heads, head_dim = query.shape
+    lse_shape = (batch_size, seq_len, num_heads)
+    return torch.empty_like(query), query.new_empty(lse_shape)
 
 
 # ===== Attention backends =====
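Note: the fallback introduced in PATCH 2/8 can be read in isolation as the sketch below. It is illustrative only: the helper names (_noop_custom_op, _noop_register_fake) and the example op mylib::double are not part of the PR, and the version check here uses packaging.version rather than the PR's plain string comparison (torch.__version__ >= "2.4.0"), which avoids lexicographic surprises such as "2.10" comparing below "2.4" and handles local suffixes like "+cu121".

    # Minimal, self-contained sketch of the dummy-decorator fallback from PATCH 2/8.
    # Helper names and the example op are hypothetical; only the overall pattern mirrors the PR.
    import torch
    from packaging import version  # assumed available (it ships with the usual Hugging Face stack)


    def _noop_custom_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        # Stand-in for torch.library.custom_op on PyTorch < 2.4: return the function unchanged.
        def wrap(func):
            return func

        return wrap if fn is None else fn


    def _noop_register_fake(op, fn=None, /, *, lib=None, _stacklevel=1):
        # Stand-in for torch.library.register_fake on PyTorch < 2.4.
        def wrap(func):
            return func

        return wrap if fn is None else fn


    if version.parse(torch.__version__).release >= (2, 4):
        _custom_op = torch.library.custom_op
        _register_fake = torch.library.register_fake
    else:
        _custom_op = _noop_custom_op
        _register_fake = _noop_register_fake


    @_custom_op("mylib::double", mutates_args=())
    def double(x: torch.Tensor) -> torch.Tensor:
        return x * 2


    @_register_fake("mylib::double")
    def _(x: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(x)


    # Works on both sides of the guard; only PyTorch >= 2.4 actually registers the custom op,
    # which is what fullgraph tracing relies on.
    print(double(torch.ones(2)))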
From 22dc9d1cbe55f3bcd7aa70ae25df5c0bf0bd384a Mon Sep 17 00:00:00 2001
From: Aishwarya Badlani <41635755+aishwarya0811@users.noreply.github.com>
Date: 23 Aug 2025 00:24:40 +0500
Subject: [PATCH 3/8] Update src/diffusers/models/attention_dispatch.py

Update all the decorator usages

Co-authored-by: Aryan
---
 src/diffusers/models/attention_dispatch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 61b3b3e1ad09..78d61b908644 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -478,8 +478,8 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
 
 if torch.__version__>= "2.4.0":
-    custom_op = torch.library.custom_op
-    register_fake = torch.library.register_fake
+    _custom_op = torch.library.custom_op
+    _register_fake = torch.library.register_fake
 else:
     def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
         def wrap(func):
             return func

From 7009fd8aeb8320e3b92b965b8fa5dce136b33fea Mon Sep 17 00:00:00 2001
From: Aishwarya Badlani <41635755+aishwarya0811@users.noreply.github.com>
Date: 23 Aug 2025 00:24:55 +0500
Subject: [PATCH 4/8] Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan
---
 src/diffusers/models/attention_dispatch.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 78d61b908644..6a8e447852b9 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -491,8 +491,8 @@ def wrap(func):
             return func
         return wrap if fn is None else fn
 
-    custom_op = custom_op_no_op
-    register_fake = register_fake_no_op
+    _custom_op = custom_op_no_op
+    _register_fake = register_fake_no_op
 
 
 @custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
From a8310129933cc35cf39b2c9e5ebf3dee0cc1b41a Mon Sep 17 00:00:00 2001
From: Aishwarya Badlani <41635755+aishwarya0811@users.noreply.github.com>
Date: 23 Aug 2025 00:25:14 +0500
Subject: [PATCH 5/8] Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan
---
 src/diffusers/models/attention_dispatch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 6a8e447852b9..d7f1f5d6ac83 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -495,7 +495,7 @@ def wrap(func):
     _register_fake = register_fake_no_op
 
 
-@custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:

From 6f799b0ca8ac0d8d2e85e7993b5d3d54f0bbe06d Mon Sep 17 00:00:00 2001
From: Aishwarya Badlani <41635755+aishwarya0811@users.noreply.github.com>
Date: 23 Aug 2025 00:25:23 +0500
Subject: [PATCH 6/8] Update src/diffusers/models/attention_dispatch.py

Co-authored-by: Aryan
---
 src/diffusers/models/attention_dispatch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index d7f1f5d6ac83..4937f96c6ae8 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -504,7 +504,7 @@ def _wrapped_flash_attn_3_original(
     return out, lse
 
 
-@register_fake("flash_attn_3::_flash_attn_forward")
+@_register_fake("flash_attn_3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
From 5ef7da4fb8922758ed218f7aff70abdaaaff3333 Mon Sep 17 00:00:00 2001
From: AishwaryaBadlani
Date: 23 Aug 2025 01:19:09 +0500
Subject: [PATCH 7/8] Move version check to top of file and use private naming as requested
---
 src/diffusers/models/attention_dispatch.py | 37 +++++++++++----------
 1 file changed, 18 insertions(+), 19 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 4937f96c6ae8..b7ff58aceb08 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -109,6 +109,24 @@
     import xformers.ops as xops
 else:
     xops = None
+
+# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
+if torch.__version__>= "2.4.0":
+    _custom_op = torch.library.custom_op
+    _register_fake = torch.library.register_fake
+else:
+    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
+        def wrap(func):
+            return func
+        return wrap if fn is None else fn
+
+    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
+        def wrap(func):
+            return func
+        return wrap if fn is None else fn
+
+    _custom_op = custom_op_no_op
+    _register_fake = register_fake_no_op
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -473,28 +491,9 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 
 # ===== torch op registrations =====
 # Registrations are required for fullgraph tracing compatibility
-# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-if torch.__version__>= "2.4.0":
-    _custom_op = torch.library.custom_op
-    _register_fake = torch.library.register_fake
-else:
-    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
-        def wrap(func):
-            return func
-        return wrap if fn is None else fn
-
-    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
-        def wrap(func):
-            return func
-        return wrap if fn is None else fn
-
-    _custom_op = custom_op_no_op
-    _register_fake = register_fake_no_op
-
-
 @_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
From ec4793637a90f0a8c60e206fc0f127fe567e59be Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: 23 Aug 2025 07:03:01 +0000
Subject: [PATCH 8/8] Apply style fixes
---
 src/diffusers/models/attention_dispatch.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index b7ff58aceb08..6a05aac215c6 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -109,24 +109,27 @@
     import xformers.ops as xops
 else:
     xops = None
-
+
 # Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
 if torch.__version__>= "2.4.0":
     _custom_op = torch.library.custom_op
     _register_fake = torch.library.register_fake
 else:
+
     def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
         def wrap(func):
             return func
+
         return wrap if fn is None else fn
-
+
     def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
         def wrap(func):
             return func
+
        return wrap if fn is None else fn
-
+
     _custom_op = custom_op_no_op
-    _register_fake = register_fake_no_op
+    _register_fake = register_fake_no_op
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -494,6 +497,7 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
 # TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
 # this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
 
+
 @_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_original(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
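Note: the practical effect of the series is that importing the patched module no longer fails on PyTorch 2.3.1, where torch.library.custom_op and torch.library.register_fake do not exist. A hypothetical smoke test (not part of the PR, assuming diffusers is installed from this branch) could look like:

    # Hypothetical check, not included in the patch series.
    import torch

    # Before the fix, this import raised AttributeError on PyTorch 2.3.1 because the
    # module applied @torch.library.custom_op at import time.
    from diffusers.models import attention_dispatch

    print(f"imported {attention_dispatch.__name__} with torch {torch.__version__}")
    if hasattr(torch.library, "custom_op"):
        # PyTorch >= 2.4: the FA3 wrapper is registered as a real custom op.
        print("custom_op available - flash_attn_3::_flash_attn_forward is registered")
    else:
        # PyTorch < 2.4: the no-op decorators leave the wrapper as a plain Python function,
        # so the import succeeds and only the fullgraph-tracing registration is skipped.
        print("custom_op missing - wrappers left undecorated, import still succeeds")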
