Description
Hi @a-r-r-o-w, I would like to ask about an error I hit when using Context Parallelism for inference.
Issue Description
Environment
- Diffusers: 0.36.0.dev0
Problem Description
I'm trying to run image-to-video generation using WanImageToVideoPipeline with model quantization (qfloat8_e4m3fn via Quanto), frozen weights, and Context Parallelism enabled with ulysses_degree=8. The pipeline initializes successfully, but during the first inference step (at 0/20 steps), it raises an AssertionError in the Context Parallel hook:
AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
This occurs in diffusers/hooks/context_parallel.py during the sharding of hidden_states in the transformer block's forward pass.
Expected Behavior: The pipeline should generate the video frames without crashing, distributing computation across GPUs via Context Parallelism.
Actual Behavior: Crashes immediately at the start of the denoising loop.
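For reference, here is a rough back-of-the-envelope check of the sequence lengths that get sharded at ulysses_degree=8 for this configuration. The compression factors and token counts below are my assumptions (Wan2.1-style 4x temporal / 8x spatial VAE compression, transformer patch size (1, 2, 2), fixed-length UMT5 text sequence, CLIP image embeds with a CLS token), not values read from the diffusers code:

# Hedged sanity check: approximate token counts for 480x832, 81 frames.
# The compression factors and fixed token counts below are assumptions.
ulysses_degree = 8

latent_frames = (81 - 1) // 4 + 1          # 21 (assumed 4x temporal compression)
latent_h, latent_w = 480 // 8, 832 // 8    # 60, 104 (assumed 8x spatial compression)
video_tokens = latent_frames * (latent_h // 2) * (latent_w // 2)  # 21 * 30 * 52 = 32760

text_tokens = 512    # assumed fixed-length UMT5 text sequence
image_tokens = 257   # assumed CLIP image embeds: 256 patches + CLS token

for name, n in [("video", video_tokens), ("text", text_tokens), ("image", image_tokens)]:
    print(f"{name}: {n} tokens, divisible by {ulysses_degree}: {n % ulysses_degree == 0}")

If those assumptions hold, the video and text sequences divide evenly by 8, but the 257-token image embedding would not, which may be the tensor that trips the assertion shown in the traceback below.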
Minimal Reproducible Code
Here's the full script that's failing (run with torchrun --nproc_per_node=8 test.py or similar for 8 GPUs):
import torch
import os
from PIL import Image
from diffusers import (
    AutoencoderKLWan,
    WanPipeline,
    WanTransformer3DModel,
    ContextParallelConfig,
    WanImageToVideoPipeline,
)
from diffusers.utils import export_to_video
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from optimum.quanto import freeze, qfloat8_e4m3fn, quantize
from transformers import AutoTokenizer, UMT5EncoderModel, CLIPVisionModel

torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

dtype = torch.bfloat16
model_id = '/share/models/checkpoints/Wan-AI/Wan2___1-I2V-14B-720P-Diffusers'

transformer = WanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
text_encoder = UMT5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=dtype
)

# Quantize text_encoder
quantize(text_encoder, weights=qfloat8_e4m3fn)
freeze(text_encoder)

# Quantize transformer
quantize(transformer, weights=qfloat8_e4m3fn)
freeze(transformer)

pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, transformer=transformer, text_encoder=text_encoder, torch_dtype=dtype
)

flow_shift = 5.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config, flow_shift=flow_shift
)
pipe.to("cuda")

transformer.set_attention_backend("_native_cudnn")
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=8)
)

image = Image.open("/share/common/AIPhoto/3.jpeg").resize((832, 480))  # .resize((800,1280))
prompt = (
    "Modern urban-style photography: a young man in a white printed T-shirt and black shorts "
    "sits on a transparent glass staircase, wearing black-and-white canvas shoes, his posture "
    "casual and natural. His skin is fair, his build well proportioned, legs slightly apart, "
    "elbows resting on his knees; behind him are towering glass curtain walls and modern "
    "architecture, with the city skyline visible through the glass. In a fixed shot, he slowly "
    "raises both hands [making a heart gesture with both hands]; the movement is relaxed and "
    "fluid, and the whole frame feels modern and urban. Slow motion reveals fine dynamic detail."
)
negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, "
    "static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, "
    "extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, "
    "fused fingers, still picture, messy background, three legs, many people in the background, "
    "walking backwards"
)

# Must specify generator so all ranks start with same latents (or pass your own)
generator = torch.Generator().manual_seed(42)

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    num_inference_steps=20,
    generator=generator,
).frames[0]

if rank == 0:
    export_to_video(output, "output.mp4", fps=16)

if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
Error Traceback
0%| | 0/20 [00:00<?, ?it/s]
[rank1]: Traceback (most recent call last):
[rank1]: File "/share/gdli7/common/AIPhoto/test.py", line 46, in <module>
[rank1]: output = pipe(
[rank1]: ^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 756, in __call__
[rank1]: noise_pred = current_model(
[rank1]: ^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/models/transformers/transformer_wan.py", line 680, in forward
[rank1]: hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/hooks.py", line 188, in new_forward
[rank1]: args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 157, in pre_forward
[rank1]: input_val = self._prepare_cp_input(input_val, cpm)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 209, in _prepare_cp_input
[rank1]: return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/conda/lib/python3.11/site-packages/diffusers-0.36.0.dev0-py3.11.egg/diffusers/hooks/context_parallel.py", line 259, in shard
[rank1]: assert tensor.size()[dim] % mesh.size() == 0, (
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: AssertionError: Tensor size along dimension to be sharded must be divisible by mesh size
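For completeness, here is a small debugging sketch (my own addition, not part of the failing script) that prints the shape of every tensor reaching the transformer's forward before the context-parallel hook shards anything. register_forward_pre_hook(..., with_kwargs=True) is standard PyTorch; attaching it to the pipe from the script above is just for illustration:

# Hypothetical debugging helper: print shapes of all tensor inputs to the
# transformer so it is easy to see which dimension is not divisible by 8.
import torch

def debug_shapes(module, args, kwargs):
    for i, a in enumerate(args):
        if torch.is_tensor(a):
            print(f"arg[{i}]: {tuple(a.shape)}")
    for k, v in kwargs.items():
        if torch.is_tensor(v):
            print(f"{k}: {tuple(v.shape)}")

handle = pipe.transformer.register_forward_pre_hook(debug_shapes, with_kwargs=True)
# ... run the pipeline once under the same torchrun command, inspect the output ...
# handle.remove()

Running this with the same torchrun invocation should make it clear which input's sharded dimension fails the divisibility check.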