
Commit 827fad6

leisuzz, J石页, and a-r-r-o-w authored

Improve performance of NPU FA (#12260)
Co-authored-by: J石页 <jiangshuo9@h-partners.com>
Co-authored-by: Aryan <aryan@huggingface.co>
1 parent 9b721db · commit 827fad6

File tree

1 file changed: +6 −3 lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 6 additions & 3 deletions
@@ -955,12 +955,13 @@ def _native_npu_attention(
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
 ) -> torch.Tensor:
-    return npu_fusion_attention(
+    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+    out = npu_fusion_attention(
         query,
         key,
         value,
-        query.size(2),  # num_heads
-        input_layout="BSND",
+        query.size(1),  # num_heads
+        input_layout="BNSD",
         pse=None,
         scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
         pre_tockens=65536,
@@ -969,6 +970,8 @@ def _native_npu_attention(
         sync=False,
         inner_precise=0,
     )[0]
+    out = out.transpose(1, 2).contiguous()
+    return out
 
 
 # Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
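
For context, the sketch below shows roughly what _native_npu_attention looks like after this commit: query, key, and value are transposed to BNSD layout (batch, num_heads, seq_len, head_dim) before the fused kernel call, num_heads is read from dim 1, and the output is transposed back to BSND. It is a minimal reconstruction, not the exact file contents: it assumes npu_fusion_attention is imported from torch_npu, and the keyword arguments that fall between the two hunks (next_tockens, keep_prob) are assumed from the surrounding code rather than taken from the diff.

# Minimal sketch of _native_npu_attention after this change (assumptions noted inline).
import math
from typing import Optional

import torch
from torch_npu import npu_fusion_attention  # requires an Ascend NPU with torch_npu installed


def _native_npu_attention(
    query: torch.Tensor,  # [batch, seq_len, num_heads, head_dim] (BSND), per the original layout
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
) -> torch.Tensor:
    # The commit switches the kernel call from BSND to BNSD layout:
    # transpose to [batch, num_heads, seq_len, head_dim] and make contiguous.
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    out = npu_fusion_attention(
        query,
        key,
        value,
        query.size(1),  # num_heads (dim 1 after the transpose)
        input_layout="BNSD",
        pse=None,
        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
        pre_tockens=65536,
        next_tockens=65536,  # assumed from surrounding code, not shown in the diff
        keep_prob=1.0 - dropout_p,  # assumed from surrounding code, not shown in the diff
        sync=False,
        inner_precise=0,
    )[0]
    # Transpose back to BSND so callers see the same output layout as before the change.
    out = out.transpose(1, 2).contiguous()
    return out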
