
[Qwen-image] encoder_hidden_states_mask is not used #12294

Status: Open
Labels: bug (Something isn't working)
Opened by @JingyeChen-Canva

Description

Describe the bug

During the text encoding stage, the text encoder produces an encoder_hidden_states_mask to mask out certain regions. This typically happens when training with a batch size greater than 1, to ensure that the produced text embeddings all have the same sequence length.

However, when I check the following code, it seems that encoder_hidden_states_mask is never used, even though it is passed into the function.

It indeed works when the batch size is 1. However, I believe it is a bug when training Qwen-image with a batch size greater than 1, since the padded text positions are then attended to as if they were valid tokens.
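To make the setup concrete, here is a minimal, generic PyTorch sketch (not the actual Qwen-image pipeline code; the shapes and variable names are illustrative) of how batching two prompts of different lengths produces padded embeddings plus the mask that attention is supposed to respect:

import torch
import torch.nn.functional as F

# Hypothetical per-prompt text embeddings of different lengths (batch of 2).
emb_a = torch.randn(7, 64)   # prompt A: 7 tokens
emb_b = torch.randn(12, 64)  # prompt B: 12 tokens

# Pad to the batch max length so the samples can be stacked into one tensor.
max_len = max(emb_a.shape[0], emb_b.shape[0])
encoder_hidden_states = torch.stack(
    [F.pad(e, (0, 0, 0, max_len - e.shape[0])) for e in (emb_a, emb_b)]
)  # (2, 12, 64)

# 1 marks real tokens, 0 marks padding.
encoder_hidden_states_mask = torch.stack(
    [F.pad(torch.ones(e.shape[0]), (0, max_len - e.shape[0])) for e in (emb_a, emb_b)]
)  # (2, 12); the first row ends in five zeros

If the attention processor ignores encoder_hidden_states_mask, the five padded positions of prompt A take part in the joint attention as if they were real tokens.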

Reproduction

from typing import Optional

import torch
import torch.nn.functional as F

# Attention, dispatch_attention_fn and apply_rotary_emb_qwen are diffusers
# internals defined alongside this processor (see
# diffusers/models/transformers/transformer_qwenimage.py and related modules).


class QwenDoubleStreamAttnProcessor2_0:
    """
    Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
    implements joint attention computation where text and image streams are processed together.
    """

    _attention_backend = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,  # Image stream
        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,  # Accepted here, but never used below
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

        seq_txt = encoder_hidden_states.shape[1]

        # Compute QKV for image stream (sample projections)
        img_query = attn.to_q(hidden_states)
        img_key = attn.to_k(hidden_states)
        img_value = attn.to_v(hidden_states)

        # Compute QKV for text stream (context projections)
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

        # Reshape for multi-head attention
        img_query = img_query.unflatten(-1, (attn.heads, -1))
        img_key = img_key.unflatten(-1, (attn.heads, -1))
        img_value = img_value.unflatten(-1, (attn.heads, -1))
        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply RoPE
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        # Concatenate for joint attention
        # Order: [text, image]
        joint_query = torch.cat([txt_query, img_query], dim=1)
        joint_key = torch.cat([txt_key, img_key], dim=1)
        joint_value = torch.cat([txt_value, img_value], dim=1)

        # Compute joint attention
        joint_hidden_states = dispatch_attention_fn(
            joint_query,
            joint_key,
            joint_value,
            attn_mask=attention_mask,  # encoder_hidden_states_mask is never folded into this mask
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )

        # Reshape back
        joint_hidden_states = joint_hidden_states.flatten(2, 3)
        joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

        # Split attention outputs back
        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part

        # Apply output projections
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # dropout
        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output
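A possible direction for a fix (a sketch under my own assumptions, not the actual diffusers patch): expand encoder_hidden_states_mask to cover the joint [text, image] sequence and pass the result as the attention mask. The helper below is hypothetical; it produces a boolean mask broadcastable to (batch, heads, query_len, key_len), which is the form F.scaled_dot_product_attention accepts (True = attend, False = mask out). Whether dispatch_attention_fn expects exactly this layout would need checking against its backends.

import torch

def joint_attn_mask(encoder_hidden_states_mask: torch.Tensor, img_seq_len: int) -> torch.Tensor:
    # Hypothetical helper: extend the text padding mask over the joint
    # [text, image] key sequence used by the processor above.
    batch = encoder_hidden_states_mask.shape[0]
    img_valid = torch.ones(
        batch, img_seq_len, dtype=torch.bool, device=encoder_hidden_states_mask.device
    )  # image tokens are never padding
    # Keys are ordered [text, image], matching the torch.cat calls above.
    joint_valid = torch.cat([encoder_hidden_states_mask.bool(), img_valid], dim=1)
    # (batch, 1, 1, txt_len + img_len), broadcastable over heads and queries.
    return joint_valid[:, None, None, :]

# Inside __call__, before dispatch_attention_fn, something like:
#     if attention_mask is None and encoder_hidden_states_mask is not None:
#         attention_mask = joint_attn_mask(encoder_hidden_states_mask, img_query.shape[1])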

Logs

System Info

A100

Who can help?

@sayakpaul
