×ばつ1024) + long prompts required accessing pos_freqs[1024:2048] from a 1024-element buffer. Solution Added dynamic buffer expansion that automatically resizes pos_freqs when needed: Only expands when required (memory efficient) Uses register_buffer() for proper tensor management Maintains backward compatibility and performance Changes Added _expand_pos_freqs_if_needed() method Modified forward() to check expansion requirements Added test case for long prompt scenarios Before: pos_freqs[1024:2048] → IndexError After: Auto-expands buffer → Success Fixes #12083 Before submitting This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). Did you read the contributor guideline? Did you read our philosophy doc (important for complex PRs)? Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. Issue: #12083 Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings. Note: This is a bug fix with internal implementation changes only. No public API changes require documentation updates. Did you write any new necessary tests? Added: test_long_prompt_no_error() in tests/pipelines/qwenimage/test_qwenimage.py Implementation Details Architecture Analysis Our solution follows established patterns in the diffusers codebase: Buffer Management: Uses register_buffer() like other components (modeling_ctx_clip.py, embeddings.py) Dynamic Computation: Mirrors pattern in get_1d_rotary_pos_embed() which computes frequencies on-demand Memory Alignment: Rounds to 512-token boundaries following PyTorch optimization practices Performance Impact Memory: Minimal - only expands when needed (one-time cost) Speed: Negligible - expansion happens once then cached Quality: Zero impact - identical mathematical operations, just larger buffer Backward Compatibility API: No changes to public interface Behavior: Existing short prompts work exactly as before Performance: Same performance characteristics for existing use cases Who can review? This PR affects: Model implementations: @sayakpaul The fix is focused on the QwenImage transformer implementation and follows established PyTorch patterns for dynamic buffer management.">
Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Fix Qwen-Image long prompt dimension mismatch error (issue #12083) #12087

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
robin-ede wants to merge 9 commits into huggingface:main
base: main
Choose a base branch
Loading
from robin-ede:fix/qwen-image-long-prompt-issue-12083
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
9 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 76 additions & 16 deletions src/diffusers/models/transformers/transformer_qwenimage.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,31 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(1024)
neg_index = torch.arange(1024).flip(0) * -1 - 1
self.pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
# Initialize with default size 1024, but allow dynamic expansion
self._current_max_len = 1024
pos_index = torch.arange(self._current_max_len)
neg_index = torch.arange(self._current_max_len).flip(0) * -1 - 1
self.register_buffer(
"pos_freqs",
torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
),
)
self.neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
self.register_buffer(
"neg_freqs",
torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
),
)
self.rope_cache = {}

Expand All @@ -193,6 +201,53 @@ def rope_params(self, index, dim, theta=10000):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs

def _expand_pos_freqs_if_needed(self, required_len):
"""Expand pos_freqs and neg_freqs if required length exceeds current size"""
if required_len <= self._current_max_len:
return

# Calculate new size (use next power of 2 or round to nearest 512 for efficiency)
new_max_len = max(required_len, int((required_len + 511) // 512) * 512)

# Log warning about potential quality degradation for long prompts
if required_len > 512:
logger.warning(
f"QwenImage model was trained on prompts up to 512 tokens. "
f"Current prompt requires {required_len} tokens, which may lead to unpredictable behavior. "
f"Consider using shorter prompts for better results."
)

# Generate expanded indices
pos_index = torch.arange(new_max_len, device=self.pos_freqs.device)
neg_index = torch.arange(new_max_len, device=self.neg_freqs.device).flip(0) * -1 - 1

# Generate expanded frequency embeddings
new_pos_freqs = torch.cat(
[
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
],
dim=1,
).to(device=self.pos_freqs.device, dtype=self.pos_freqs.dtype)

new_neg_freqs = torch.cat(
[
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
],
dim=1,
).to(device=self.neg_freqs.device, dtype=self.neg_freqs.dtype)

# Update buffers
self.register_buffer("pos_freqs", new_pos_freqs)
self.register_buffer("neg_freqs", new_neg_freqs)
self._current_max_len = new_max_len

# Clear cache since dimensions changed
self.rope_cache = {}

def forward(self, video_fhw, txt_seq_lens, device):
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
Expand Down Expand Up @@ -232,6 +287,11 @@ def forward(self, video_fhw, txt_seq_lens, device):
max_vid_index = max(height, width)

max_len = max(txt_seq_lens)

# Expand pos_freqs if needed to accommodate max_vid_index + max_len
required_len = max_vid_index + max_len
self._expand_pos_freqs_if_needed(required_len)

txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]

return vid_freqs, txt_freqs
Expand Down
78 changes: 77 additions & 1 deletion tests/pipelines/qwenimage/test_qwenimage.py
View file Open in desktop
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
QwenImagePipeline,
QwenImageTransformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from diffusers.utils.testing_utils import CaptureLogger, enable_full_determinism, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
Expand Down Expand Up @@ -234,3 +234,79 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
expected_diff_max,
"VAE tiling should not affect the inference results",
)

def test_long_prompt_no_error(self):
# Test for issue #12083: long prompts should not cause dimension mismatch errors
device = torch_device
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)

# Create a long prompt that approaches but stays within limits
# This tests the original issue fix without triggering the warning
phrase = "A beautiful, detailed, high-resolution, photorealistic image showing "
long_prompt = phrase * 40 # Generates ~800 tokens, well within limits

# Verify token count for test clarity
tokenizer = components["tokenizer"]
token_count = len(tokenizer.encode(long_prompt))
required_len = 32 + token_count # height/width + tokens
# Should be large enough to test the fix but not trigger expansion warning
self.assertGreater(token_count, 500, f"Test prompt should be substantial (got {token_count} tokens)")
self.assertLess(required_len, 1024, f"Test should stay within limits (got {required_len})")

inputs = {
"prompt": long_prompt,
"generator": torch.Generator(device=device).manual_seed(0),
"num_inference_steps": 2,
"guidance_scale": 3.0,
"true_cfg_scale": 1.0,
"height": 32, # Small size for fast test
"width": 32, # Small size for fast test
"max_sequence_length": 1024, # Allow long sequence (max allowed)
"output_type": "pt",
}

# This should not raise a RuntimeError about tensor dimension mismatch
_ = pipe(**inputs)

def test_long_prompt_warning(self):
"""Test that long prompts trigger appropriate warning about training limitation"""
from diffusers.utils import logging

components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)

# Create a long prompt that will exceed the RoPE expansion threshold
# The warning is triggered when required_len = max(height, width) + text_tokens > _current_max_len
# Since _current_max_len is 1024 and height=width=32, we need > 992 tokens
phrase = "A detailed photorealistic image showing many beautiful elements and complex artistic creative features with intricate designs."
long_prompt = phrase * 58 # Generates ~1045 tokens, ensuring required_len > 1024

# Verify we exceed the threshold (for test robustness)
tokenizer = components["tokenizer"]
token_count = len(tokenizer.encode(long_prompt))
required_len = 32 + token_count # height/width + tokens
self.assertGreater(required_len, 1024, f"Test prompt must exceed threshold (got {required_len})")

# Capture transformer logging
logger = logging.get_logger("diffusers.models.transformers.transformer_qwenimage")
logger.setLevel(logging.WARNING)

with CaptureLogger(logger) as cap_logger:
_ = pipe(
prompt=long_prompt,
generator=torch.Generator(device=torch_device).manual_seed(0),
num_inference_steps=2,
guidance_scale=3.0,
true_cfg_scale=1.0,
height=32, # Small size for fast test
width=32, # Small size for fast test
max_sequence_length=1024, # Allow long sequence
output_type="pt",
)

# Verify warning was logged about the 512-token training limitation
self.assertTrue("512 tokens" in cap_logger.out)
self.assertTrue("unpredictable behavior" in cap_logger.out)

AltStyle によって変換されたページ (->オリジナル) /