Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 35cb2c8

Browse files
Apply ruff formatting to QwenImage warning implementation
- Fix whitespace and string quote consistency
- Add trailing commas where appropriate
- Clean up formatting per diffusers code standards

1 parent: 39462a4 · commit: 35cb2c8

File tree

2 files changed

+45
-39
lines changed

2 files changed

+45
-39
lines changed

‎src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -164,22 +164,28 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
164164
self._current_max_len = 1024
165165
pos_index = torch.arange(self._current_max_len)
166166
neg_index = torch.arange(self._current_max_len).flip(0) * -1 - 1
167-
self.register_buffer('pos_freqs', torch.cat(
168-
[
169-
self.rope_params(pos_index, self.axes_dim[0], self.theta),
170-
self.rope_params(pos_index, self.axes_dim[1], self.theta),
171-
self.rope_params(pos_index, self.axes_dim[2], self.theta),
172-
],
173-
dim=1,
174-
))
175-
self.register_buffer('neg_freqs', torch.cat(
176-
[
177-
self.rope_params(neg_index, self.axes_dim[0], self.theta),
178-
self.rope_params(neg_index, self.axes_dim[1], self.theta),
179-
self.rope_params(neg_index, self.axes_dim[2], self.theta),
180-
],
181-
dim=1,
182-
))
167+
self.register_buffer(
168+
"pos_freqs",
169+
torch.cat(
170+
[
171+
self.rope_params(pos_index, self.axes_dim[0], self.theta),
172+
self.rope_params(pos_index, self.axes_dim[1], self.theta),
173+
self.rope_params(pos_index, self.axes_dim[2], self.theta),
174+
],
175+
dim=1,
176+
),
177+
)
178+
self.register_buffer(
179+
"neg_freqs",
180+
torch.cat(
181+
[
182+
self.rope_params(neg_index, self.axes_dim[0], self.theta),
183+
self.rope_params(neg_index, self.axes_dim[1], self.theta),
184+
self.rope_params(neg_index, self.axes_dim[2], self.theta),
185+
],
186+
dim=1,
187+
),
188+
)
183189
self.rope_cache = {}
184190

185191
# 是否使用 scale rope
@@ -199,22 +205,22 @@ def _expand_pos_freqs_if_needed(self, required_len):
199205
"""Expand pos_freqs and neg_freqs if required length exceeds current size"""
200206
if required_len <= self._current_max_len:
201207
return
202-
208+
203209
# Calculate new size (use next power of 2 or round to nearest 512 for efficiency)
204210
new_max_len = max(required_len, int((required_len + 511) // 512) * 512)
205-
211+
206212
# Log warning about potential quality degradation for long prompts
207213
if required_len > 512:
208214
logger.warning(
209215
f"QwenImage model was trained on prompts up to 512 tokens. "
210216
f"Current prompt requires {required_len} tokens, which may lead to unpredictable behavior. "
211217
f"Consider using shorter prompts for better results."
212218
)
213-
219+
214220
# Generate expanded indices
215221
pos_index = torch.arange(new_max_len, device=self.pos_freqs.device)
216222
neg_index = torch.arange(new_max_len, device=self.neg_freqs.device).flip(0) * -1 - 1
217-
223+
218224
# Generate expanded frequency embeddings
219225
new_pos_freqs = torch.cat(
220226
[
@@ -224,7 +230,7 @@ def _expand_pos_freqs_if_needed(self, required_len):
224230
],
225231
dim=1,
226232
).to(device=self.pos_freqs.device, dtype=self.pos_freqs.dtype)
227-
233+
228234
new_neg_freqs = torch.cat(
229235
[
230236
self.rope_params(neg_index, self.axes_dim[0], self.theta),
@@ -233,12 +239,12 @@ def _expand_pos_freqs_if_needed(self, required_len):
233239
],
234240
dim=1,
235241
).to(device=self.neg_freqs.device, dtype=self.neg_freqs.dtype)
236-
242+
237243
# Update buffers
238-
self.register_buffer('pos_freqs', new_pos_freqs)
239-
self.register_buffer('neg_freqs', new_neg_freqs)
244+
self.register_buffer("pos_freqs", new_pos_freqs)
245+
self.register_buffer("neg_freqs", new_neg_freqs)
240246
self._current_max_len = new_max_len
241-
247+
242248
# Clear cache since dimensions changed
243249
self.rope_cache = {}
244250

@@ -281,11 +287,11 @@ def forward(self, video_fhw, txt_seq_lens, device):
281287
max_vid_index = max(height, width)
282288

283289
max_len = max(txt_seq_lens)
284-
290+
285291
# Expand pos_freqs if needed to accommodate max_vid_index + max_len
286292
required_len = max_vid_index + max_len
287293
self._expand_pos_freqs_if_needed(required_len)
288-
294+
289295
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
290296

291297
return vid_freqs, txt_freqs

‎tests/pipelines/qwenimage/test_qwenimage.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -241,43 +241,43 @@ def test_long_prompt_no_error(self):
241241
components = self.get_dummy_components()
242242
pipe = self.pipeline_class(**components)
243243
pipe.to(device)
244-
244+
245245
# Create a very long prompt that exceeds 1024 tokens when combined with image positioning
246246
# Repeat a long phrase to simulate a real long prompt scenario
247247
long_phrase = "A beautiful, detailed, high-resolution, photorealistic image showing "
248248
long_prompt = (long_phrase * 50)[:1200] # Ensure we exceed 1024 characters
249-
249+
250250
inputs = {
251251
"prompt": long_prompt,
252252
"generator": torch.Generator(device=device).manual_seed(0),
253253
"num_inference_steps": 2,
254254
"guidance_scale": 3.0,
255255
"true_cfg_scale": 1.0,
256256
"height": 32, # Small size for fast test
257-
"width": 32, # Small size for fast test
257+
"width": 32, # Small size for fast test
258258
"max_sequence_length": 1200, # Allow long sequence
259259
"output_type": "pt",
260260
}
261-
261+
262262
# This should not raise a RuntimeError about tensor dimension mismatch
263263
_ = pipe(**inputs)
264264

265265
def test_long_prompt_warning(self):
266266
"""Test that long prompts trigger appropriate warning about training limitation"""
267267
from diffusers.utils import logging
268-
268+
269269
components = self.get_dummy_components()
270270
pipe = self.pipeline_class(**components)
271271
pipe.to(torch_device)
272-
272+
273273
# Create prompt that will exceed 512 tokens to trigger warning
274274
long_phrase = "A detailed photorealistic description of a complex scene with many elements "
275275
long_prompt = (long_phrase * 20)[:800] # Create a prompt that will exceed 512 tokens
276-
277-
# Capture transformer logging
276+
277+
# Capture transformer logging
278278
logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage")
279279
logger.setLevel(logging.WARNING)
280-
280+
281281
with CaptureLogger(logger) as cap_logger:
282282
_ = pipe(
283283
prompt=long_prompt,
@@ -286,11 +286,11 @@ def test_long_prompt_warning(self):
286286
guidance_scale=3.0,
287287
true_cfg_scale=1.0,
288288
height=32, # Small size for fast test
289-
width=32, # Small size for fast test
289+
width=32, # Small size for fast test
290290
max_sequence_length=900, # Allow long sequence
291-
output_type="pt"
291+
output_type="pt",
292292
)
293-
293+
294294
# Verify warning was logged about the 512-token training limitation
295295
self.assertTrue("512 tokens" in cap_logger.out)
296296
self.assertTrue("unpredictable behavior" in cap_logger.out)

0 commit comments

Comments
(0)

Page converted by AltStyle (link to original)