Commit e682af2

Qwen Image Edit Support (#12164)
* feat(qwen-image): add qwen-image-edit support
* fix(qwen image):
  - compatible with torch.compile in new rope setting
  - fix init import
  - add prompt truncation in img2img and inpaint pipe
  - remove unused logic and comment
  - add copy statement
  - guard logic for rope video shape tuple
* fix(qwen image):
  - make fix-copies
  - update doc
1 parent a58a4f6 commit e682af2

File tree: 9 files changed (+949 −89 lines)
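
For orientation, here is a minimal usage sketch of the new pipeline this commit wires up. The checkpoint id, image URL, and exact call arguments below are assumptions for illustration, not taken from this diff.

```python
# Hypothetical quick-start for the QwenImageEditPipeline added in this commit.
import torch
from diffusers import QwenImageEditPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

source = load_image("https://example.com/input.png")  # placeholder URL
edited = pipe(
    image=source,
    prompt="Replace the background with a snowy mountain landscape",
    num_inference_steps=50,
).images[0]
edited.save("edited.png")
```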

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -492,6 +492,7 @@
             "QwenImageImg2ImgPipeline",
             "QwenImageInpaintPipeline",
             "QwenImagePipeline",
+            "QwenImageEditPipeline",
             "ReduxImageEncoder",
             "SanaControlNetPipeline",
             "SanaPAGPipeline",
@@ -1123,6 +1124,7 @@
             PixArtAlphaPipeline,
             PixArtSigmaPAGPipeline,
             PixArtSigmaPipeline,
+            QwenImageEditPipeline,
             QwenImageImg2ImgPipeline,
             QwenImageInpaintPipeline,
             QwenImagePipeline,
```
src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 32 additions & 23 deletions
```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -161,17 +160,17 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
-        pos_index = torch.arange(1024)
-        neg_index = torch.arange(1024).flip(0) * -1 - 1
-        pos_freqs = torch.cat(
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        neg_freqs = torch.cat(
+        self.neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -180,10 +179,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
-        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
-        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
 
-        # 是否使用 scale rope
+        # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
         self.scale_rope = scale_rope
 
     def rope_params(self, index, dim, theta=10000):
@@ -201,35 +198,47 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
             txt_length: [bs] a list of 1 integers representing the length of the text
         """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
-        frame, height, width = video_fhw
-        rope_key = f"{frame}_{height}_{width}"
-
-        if not torch.compiler.is_compiling():
-            if rope_key not in self.rope_cache:
-                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
-            vid_freqs = self.rope_cache[rope_key]
-        else:
-            vid_freqs = self._compute_video_freqs(frame, height, width)
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs = []
+        max_vid_index = 0
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+
+            if not torch.compiler.is_compiling():
+                if rope_key not in self.rope_cache:
+                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self.rope_cache[rope_key]
+            else:
+                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            vid_freqs.append(video_freq)
 
-        if self.scale_rope:
-            max_vid_index = max(height // 2, width // 2)
-        else:
-            max_vid_index = max(height, width)
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
 
         max_len = max(txt_seq_lens)
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)
 
         return vid_freqs, txt_freqs
 
     @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width):
+    def _compute_video_freqs(self, frame, height, width, idx=0):
         seq_lens = frame * height * width
         freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
 
-        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
             freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
             freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
```
src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -391,6 +391,7 @@
         "QwenImagePipeline",
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
+        "QwenImageEditPipeline",
     ]
     try:
         if not is_onnx_available():
@@ -708,7 +709,12 @@
         from .paint_by_example import PaintByExamplePipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-        from .qwenimage import QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline
+        from .qwenimage import (
+            QwenImageEditPipeline,
+            QwenImageImg2ImgPipeline,
+            QwenImageInpaintPipeline,
+            QwenImagePipeline,
+        )
         from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
```

src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -26,6 +26,7 @@
     _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
     _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
     _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
+    _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -35,6 +36,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .pipeline_qwenimage import QwenImagePipeline
+        from .pipeline_qwenimage_edit import QwenImageEditPipeline
         from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
         from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
 else:
```

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 7 additions & 21 deletions
```diff
@@ -253,6 +253,9 @@ def encode_prompt(
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -316,20 +319,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
 
-    @staticmethod
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
@@ -402,8 +391,7 @@ def prepare_latents(
         shape = (batch_size, 1, num_channels_latents, height, width)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -414,9 +402,7 @@
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, latent_image_ids
+        return latents
 
     @property
     def guidance_scale(self):
@@ -594,7 +580,7 @@ def __call__(
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-        latents, latent_image_ids = self.prepare_latents(
+        latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -604,7 +590,7 @@
             generator,
             latents,
         )
-        img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
 
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
```
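
The `img_shapes` change above mirrors the multi-image RoPE rework in the transformer: each batch entry now holds a list of `(frame, height, width)` tuples instead of a single tuple, so the RoPE module can concatenate rotary frequencies for several latent grids per sample (e.g. target latents plus a packed reference image in the edit pipeline). A small sketch of the structure, with illustrative values that are not taken from this diff:

```python
# Text-to-image: one latent grid per sample, matching what this file now builds.
height, width, vae_scale_factor = 1024, 1024, 8  # assumed values
latent_h = height // vae_scale_factor // 2
latent_w = width // vae_scale_factor // 2
batch_size = 2

img_shapes = [[(1, latent_h, latent_w)]] * batch_size
# -> [[(1, 64, 64)], [(1, 64, 64)]]

# Editing (hypothetical layout): the per-sample list can also carry the shape of a
# packed reference image; the RoPE forward pass iterates over every entry and
# concatenates their frequencies along the sequence dimension.
ref_h, ref_w = 48, 64  # assumed reference-image latent grid
img_shapes_edit = [[(1, latent_h, latent_w), (1, ref_h, ref_w)]] * batch_size
```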

0 commit comments