Commit 4a9dbd5

enable compilation in qwen image. (#12061)

Authored by sayakpaul and a-r-r-o-w

* update
* update
* update
* enable compilation in qwen image.
* add tests

Co-authored-by: Aryan <aryan@huggingface.co>

1 parent 630d27f · commit 4a9dbd5

File tree

3 files changed: +137 -24 lines changed


‎src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 31 additions & 24 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -162,15 +163,15 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         self.axes_dim = axes_dim
         pos_index = torch.arange(1024)
         neg_index = torch.arange(1024).flip(0) * -1 - 1
-        self.pos_freqs = torch.cat(
+        pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        self.neg_freqs = torch.cat(
+        neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
+        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
+        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
 
         # 是否使用 scale rope
         self.scale_rope = scale_rope
@@ -198,33 +201,17 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
             txt_length: [bs] a list of 1 integers representing the length of the text
         """
-        if self.pos_freqs.device != device:
-            self.pos_freqs = self.pos_freqs.to(device)
-            self.neg_freqs = self.neg_freqs.to(device)
-
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
         frame, height, width = video_fhw
         rope_key = f"{frame}_{height}_{width}"
 
-        if rope_key not in self.rope_cache:
-            seq_lens = frame * height * width
-            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
-            if self.scale_rope:
-                freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
-                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
-                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            else:
-                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-            self.rope_cache[rope_key] = freqs.clone().contiguous()
-        vid_freqs = self.rope_cache[rope_key]
+        if not torch.compiler.is_compiling():
+            if rope_key not in self.rope_cache:
+                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
+            vid_freqs = self.rope_cache[rope_key]
+        else:
+            vid_freqs = self._compute_video_freqs(frame, height, width)
 
         if self.scale_rope:
             max_vid_index = max(height // 2, width // 2)
@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):
 
         return vid_freqs, txt_freqs
 
+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+
 
 class QwenDoubleStreamAttnProcessor2_0:
     """
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
     _supports_gradient_checkpointing = True
     _no_split_modules = ["QwenImageTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["QwenImageTransformerBlock"]
 
     @register_to_config
    def __init__(
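
Why the restructuring above matters for torch.compile: the previous forward() mutated module state on the fly (moving pos_freqs/neg_freqs onto the right device and writing into the rope_cache dict), which is awkward for graph capture. After this commit the frequency tables are registered as non-persistent buffers, so they follow the module's device automatically and stay out of the state dict, and the dict cache is consulted only in eager mode; while compiling, the rotary frequencies are recomputed through the lru_cache-decorated _compute_video_freqs helper, so nothing on self is mutated inside the traced region. The following is a minimal, hypothetical sketch of that pattern (FreqCache and _compute are illustrative names, not diffusers API):

import functools

import torch
import torch.nn as nn


class FreqCache(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        # Non-persistent buffer: moves with .to(device)/.cuda() and is excluded
        # from the state dict, so forward() needs no manual device check.
        self.register_buffer("table", torch.randn(1024, dim), persistent=False)
        self.cache = {}

    @functools.lru_cache(maxsize=None)
    def _compute(self, height: int, width: int) -> torch.Tensor:
        # Deterministic, shape-dependent work that is safe to memoize per (height, width).
        return self.table[: height * width].clone().contiguous()

    def forward(self, height: int, width: int) -> torch.Tensor:
        if not torch.compiler.is_compiling():
            key = f"{height}_{width}"
            if key not in self.cache:
                self.cache[key] = self._compute(height, width)
            return self.cache[key]
        # Under torch.compile, skip the dict mutation and recompute instead;
        # lru_cache still deduplicates repeated eager calls outside the graph.
        return self._compute(height, width)

The new _repeated_blocks = ["QwenImageTransformerBlock"] class attribute ties into diffusers' regional-compilation support; assuming your installed diffusers version exposes it, model.compile_repeated_blocks(fullgraph=True) compiles only the listed block class rather than the whole transformer.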

‎tests/models/test_modeling_common.py

Lines changed: 5 additions & 0 deletions

@@ -1711,6 +1711,11 @@ def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")
 
+        if self.model_class.__name__ == "QwenImageTransformer2DModel":
+            pytest.skip(
+                "QwenImageTransformer2DModel doesn't support group offloading with disk. Needs to be investigated."
+            )
+
         def _has_generator_arg(model):
             sig = inspect.signature(model.forward)
             params = sig.parameters
Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import QwenImageTransformer2DModel
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = QwenImageTransformer2DModel
+    main_input_name = "hidden_states"
+    # We override the items here because the transformer under consideration is small.
+    model_split_percents = [0.7, 0.6, 0.6]
+
+    # Skip setting testing with default: AttnProcessor
+    uses_custom_attn_processor = True
+
+    @property
+    def dummy_input(self):
+        return self.prepare_dummy_input()
+
+    @property
+    def input_shape(self):
+        return (16, 16)
+
+    @property
+    def output_shape(self):
+        return (16, 16)
+
+    def prepare_dummy_input(self, height=4, width=4):
+        batch_size = 1
+        num_latent_channels = embedding_dim = 16
+        sequence_length = 7
+        vae_scale_factor = 4
+
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+            "txt_seq_lens": encoder_hidden_states_mask.sum(dim=1).tolist(),
+        }
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "patch_size": 2,
+            "in_channels": 16,
+            "out_channels": 4,
+            "num_layers": 2,
+            "attention_head_dim": 16,
+            "num_attention_heads": 3,
+            "joint_attention_dim": 16,
+            "guidance_embeds": False,
+            "axes_dims_rope": (8, 4, 4),
+        }
+
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"QwenImageTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+    model_class = QwenImageTransformer2DModel
+
+    def prepare_init_args_and_inputs_for_common(self):
+        return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
+
+    def prepare_dummy_input(self, height, width):
+        return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
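
For orientation, here is a hedged usage sketch (not part of the commit) that builds the same tiny configuration as prepare_init_args_and_inputs_for_common and the same dummy inputs as prepare_dummy_input, then runs one forward pass through torch.compile. The device selection, eval()/no_grad() wrapping, and the plain torch.compile call are assumptions about how one might exercise this outside the TorchCompileTesterMixin harness:

import torch

from diffusers import QwenImageTransformer2DModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Tiny test configuration, mirroring prepare_init_args_and_inputs_for_common above.
model = QwenImageTransformer2DModel(
    patch_size=2,
    in_channels=16,
    out_channels=4,
    num_layers=2,
    attention_head_dim=16,
    num_attention_heads=3,
    joint_attention_dim=16,
    guidance_embeds=False,
    axes_dims_rope=(8, 4, 4),
).to(device)
model.eval()

compiled_model = torch.compile(model)

# Dummy inputs mirroring prepare_dummy_input above (height = width = 4).
height = width = 4
vae_scale_factor = 4
hidden_states = torch.randn(1, height * width, 16, device=device)
encoder_hidden_states = torch.randn(1, 7, 16, device=device)
encoder_hidden_states_mask = torch.ones(1, 7, dtype=torch.long, device=device)
timestep = torch.tensor([1.0], device=device)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)]

with torch.no_grad():
    output = compiled_model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        timestep=timestep,
        img_shapes=img_shapes,
        txt_seq_lens=encoder_hidden_states_mask.sum(dim=1).tolist(),
    )

With the rope cache bypassed during compilation (see the transformer_qwenimage.py hunks above), repeated compiled calls with the same img_shapes avoid Python-side cache mutation inside the captured graph.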

0 commit comments
