Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 303efd2

Browse files
gameofdimensionfelix01.yu
and
felix01.yu
authored
Improve pos embed for Flux.1 inference on Ascend NPU (#12534)
improve pos embed for ascend npu Co-authored-by: felix01.yu <felix01.yu@vipshop.com>
1 parent 5afbcce commit 303efd2

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

‎src/diffusers/models/transformers/transformer_flux.py‎

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from ...configuration_utils import ConfigMixin, register_to_config
2424
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
25-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
25+
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
2626
from ...utils.torch_utils import maybe_allow_in_graph
2727
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
2828
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,7 +717,11 @@ def forward(
717717
img_ids = img_ids[0]
718718

719719
ids = torch.cat((txt_ids, img_ids), dim=0)
720-
image_rotary_emb = self.pos_embed(ids)
720+
if is_torch_npu_available():
721+
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
722+
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
723+
else:
724+
image_rotary_emb = self.pos_embed(ids)
721725

722726
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
723727
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /