AssertionError in Context Parallelism during WanImageToVideoPipeline inference: Tensor size along sharding dimension not divisible by mesh size #12536

Open
@leeguandong

Description

Hi @a-r-r-o-w, I would like to ask you about an error I am hitting when using Context Parallelism for inference.

Issue Description

Environment

  • Diffusers: 0.36.0.dev0

Problem Description

I'm trying to run image-to-video generation using WanImageToVideoPipeline with model quantization (qfloat8_e4m3fn via Quanto), frozen weights, and Context Parallelism enabled with ulysses_degree=8. The pipeline initializes successfully, but during the first inference step (at 0/20 steps), it raises an AssertionError in the Context Parallel hook:

AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size

This occurs in diffusers/hooks/context_parallel.py during the sharding of hidden_states in the transformer block's forward pass.

Expected Behavior: The pipeline should generate the video frames without crashing, distributing computation across GPUs via Context Parallelism.

Actual Behavior: The pipeline crashes immediately at the start of the denoising loop.
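
For context, here is a rough way to check whether the video-token sequence for these settings splits evenly across the mesh. This is only a back-of-the-envelope sketch: it assumes the Wan 2.1 VAE's 4x temporal / 8x spatial compression and the transformer's (1, 2, 2) patch size, which should be double-checked against the checkpoint configs.

num_frames, height, width = 81, 480, 832
ulysses_degree = 8

# Assumed Wan 2.1 factors (verify against the VAE / transformer configs)
vae_temporal, vae_spatial = 4, 8   # VAE compression
p_t, p_h, p_w = 1, 2, 2            # transformer patch size

latent_frames = (num_frames - 1) // vae_temporal + 1
latent_h, latent_w = height // vae_spatial, width // vae_spatial
seq_len = (latent_frames // p_t) * (latent_h // p_h) * (latent_w // p_w)

print(seq_len, seq_len % ulysses_degree)  # 32760, 0 under these assumptions

Under these assumptions the video-token length works out to 21 * 30 * 52 = 32760, which is divisible by 8, so the tensor that trips the assert may be one of the other inputs the hook shards rather than the video tokens themselves; I have not confirmed which one.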

Minimal Reproducible Code

Here's the full script that fails (run with torchrun --nproc_per_node=8 test.py or similar on 8 GPUs):

import torch
import os
from PIL import Image
from diffusers import (
    AutoencoderKLWan, WanPipeline, WanTransformer3DModel, ContextParallelConfig,
    WanImageToVideoPipeline
)
from diffusers.utils import export_to_video
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from optimum.quanto import freeze, qfloat8_e4m3fn, quantize
from transformers import AutoTokenizer, UMT5EncoderModel, CLIPVisionModel

# Initialize the distributed process group and bind each rank to one GPU
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

dtype = torch.bfloat16
model_id = '/share/models/checkpoints/Wan-AI/Wan2___1-I2V-14B-720P-Diffusers'

transformer = WanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
text_encoder = UMT5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=dtype
)

# Quantize text_encoder
quantize(text_encoder, weights=qfloat8_e4m3fn)
freeze(text_encoder)

# Quantize transformer
quantize(transformer, weights=qfloat8_e4m3fn)
freeze(transformer)

pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    text_encoder=text_encoder,
    torch_dtype=dtype
)
flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, flow_shift=flow_shift
)
pipe.to("cuda")

# Enable Context Parallelism on the transformer with ulysses_degree=8
transformer.set_attention_backend("_native_cudnn")
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=8)
)

image = Image.open("/share/common/AIPhoto/3.jpeg").resize((832, 480))
# .resize((800,1280))
prompt = (
    "Modern urban-style photography: a young man in a white printed T-shirt and black shorts "
    "sits on transparent glass stairs, wearing black-and-white canvas shoes, his posture casual "
    "and natural. His skin is fair, his build well proportioned, legs slightly apart, elbows "
    "resting on his knees. The background is a towering glass curtain wall and modern "
    "architecture, with the outlines of the city's high-rises visible through the glass. "
    "In a fixed shot, he slowly raises both hands [making a heart gesture with both hands]; "
    "the motion is relaxed and fluid, and the whole frame feels modern and urban. "
    "Slow motion brings out fine dynamic detail."
)
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, "
    "static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, "
    "extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, "
    "fused fingers, still picture, messy background, three legs, many people in the background, "
    "walking backwards"
)

# Must specify generator so all ranks start with same latents (or pass your own)
generator = torch.Generator().manual_seed(42)
output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=generator,
).frames[0]

if rank == 0:
    export_to_video(output, "output.mp4", fps=16)
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()

Error Traceback

0%| | 0/20 [00:00<?, ?it/s]
[rank1]: Traceback (most recent call last):
[rank1]: File "/share/gdli7/common/AIPhoto/test.py", line 46, in <module>
[rank1]: output = pipe(
[rank1]: ^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 756, in __call__
[rank1]: noise_pred = current_model(
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/models/transformers/transformer_wan.py", line 680, in forward
[rank1]: hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]: args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank1]: input_val = self._prepare_cp_input(input_val, cpm)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 209, in _prepare_cp_input
[rank1]: return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 259, in shard
[rank1]: assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
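
One way to see which input actually trips the assert would be to temporarily wrap EquipartitionSharder.shard and log every shape it is asked to split. This is just a debugging sketch: it relies only on the shard(tensor, dim, mesh) call visible in the traceback above and would be added to the script before the pipe(...) call.

from diffusers.hooks import context_parallel

_orig_shard = context_parallel.EquipartitionSharder.shard

def _logged_shard(tensor, dim, mesh):
    # Print the shape of every tensor the CP hook tries to shard, so the
    # offending (non-divisible) dimension is visible before the assert fires.
    print(f"[CP shard] shape={tuple(tensor.shape)} dim={dim} mesh_size={mesh.size()}", flush=True)
    return _orig_shard(tensor, dim, mesh)

context_parallel.EquipartitionSharder.shard = _logged_shard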
