1 parent 9b721db commit 827fad6
src/diffusers/models/attention_dispatch.py
@@ -955,12 +955,13 @@ def _native_npu_attention(
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
 ) -> torch.Tensor:
-    return npu_fusion_attention(
+    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+    out = npu_fusion_attention(
         query,
         key,
         value,
-        query.size(2),  # num_heads
-        input_layout="BSND",
+        query.size(1),  # num_heads
+        input_layout="BNSD",
         pse=None,
         scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
         pre_tockens=65536,
@@ -969,6 +970,8 @@ def _native_npu_attention(
         sync=False,
         inner_precise=0,
     )[0]
+    out = out.transpose(1, 2).contiguous()
+    return out


 # Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
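For context, here is a minimal, self-contained sketch of the layout handling this commit introduces. It is an illustration, not part of the diff: torch_npu's npu_fusion_attention is only available on Ascend NPU environments, so torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.1 for the scale keyword) stands in for it. The sketch assumes the dispatcher passes query/key/value as (batch, seq_len, num_heads, head_dim); with input_layout="BNSD" the NPU kernel instead expects (batch, num_heads, seq_len, head_dim), hence the transpose before the call and the transpose back afterwards.

import math

import torch
import torch.nn.functional as F


def npu_layout_attention_sketch(query, key, value, scale=None):
    # Inputs assumed to be (batch, seq_len, num_heads, head_dim), as in the patched function.
    # Transpose to (batch, num_heads, seq_len, head_dim), i.e. the "BNSD" layout.
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    scale = 1.0 / math.sqrt(query.shape[-1]) if scale is None else scale
    # Stand-in for npu_fusion_attention(..., input_layout="BNSD")[0].
    out = F.scaled_dot_product_attention(query, key, value, scale=scale)
    # Transpose back so the caller receives the same layout it passed in.
    return out.transpose(1, 2).contiguous()


q = k = v = torch.randn(2, 16, 8, 64)  # batch=2, seq_len=16, num_heads=8, head_dim=64
print(npu_layout_attention_sketch(q, k, v).shape)  # torch.Size([2, 16, 8, 64])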