```diff
@@ -22,7 +22,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,7 +717,11 @@ def forward(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        image_rotary_emb = self.pos_embed(ids)
+        if is_torch_npu_available():
+            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+        else:
+            image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
```
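For context, a minimal sketch of the device-fallback pattern the second hunk applies: the rotary-embedding cos/sin tables are built on the CPU and only the results are moved to the NPU, presumably because part of that computation is problematic on the Ascend backend. The `rope_freqs` helper below is a hypothetical stand-in for the model's `pos_embed` module, not the diffusers implementation itself.

```python
# Sketch only: `rope_freqs` is a toy stand-in for `self.pos_embed` in the diff above.
import torch


def rope_freqs(ids: torch.Tensor, dim: int = 64, theta: float = 10000.0):
    """Toy rotary-embedding table: returns (cos, sin) for the given positions."""
    pos = ids.float().sum(dim=-1)  # collapse the position axes for this 1-D example
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(pos, inv_freq)  # (seq_len, dim // 2)
    return angles.cos(), angles.sin()


def image_rotary_emb_with_fallback(ids: torch.Tensor):
    # Mirrors the diff: on Ascend NPU, build the tables on CPU and move the
    # results back; everywhere else, compute directly from the given ids.
    # `torch.npu` only exists when the torch_npu extension is installed.
    if hasattr(torch, "npu") and torch.npu.is_available():
        freqs_cos, freqs_sin = rope_freqs(ids.cpu())
        return freqs_cos.npu(), freqs_sin.npu()
    return rope_freqs(ids)


# Example: 128 tokens with 3 position axes, similar to the concatenated
# txt_ids/img_ids in the hunk above.
ids = torch.zeros(128, 3)
cos, sin = image_rotary_emb_with_fallback(ids)
print(cos.shape, sin.shape)  # torch.Size([128, 32]) torch.Size([128, 32])
```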