fix crash if tiling mode is enabled #12521
Conversation
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
CUDA should have a similar issue.
Thanks for your PR.
But before we go on reviewing it, could you please:
- Include an error trace that you get without the changes from this PR?
- Include an output with the changes from this PR?
- Additionally, the changes introduced in this PR seem non-intrusive to me. So, if you add comments to explain those changes, that'd be super nice.
HuggingFaceDocBuilderDev commented Oct 21, 2025
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Without the change, it crashes like:
Traceback (most recent call last):
File "/workspace/test.py", line 27, in
output = pipe(
^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/utils/contextlib.py", line 120, in decorate context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/workspace/diffusers/src/diffusers/pipelines/wan/pipeline_wan.py", line 645, in call
video = self.vae.decode(latents, return_dict=False)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
return method(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1248, in decode
decoded = self._decode(z).sample
^^^^^^^^^^^^^^^
File "/workspace/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1204, in _decode
return self.tiled_decode(z, return_dict=return_dict)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1374, in tiled_decode
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped _call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_im pl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 892, i n forward
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped _call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_im pl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/diffusers/src/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 708, i n forward
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2
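For illustration, the failure is a plain shape mismatch in the residual add at autoencoder_kl_wan.py line 708: with tiling enabled, the shortcut branch produces more temporal frames than the main path for a tile. A minimal sketch of the same error (only the 2-vs-4 mismatch at dimension 2, the temporal axis, comes from the trace; all other sizes are made up):

```python
import torch

# Hypothetical (B, C, T, H, W) shapes; only the 2-vs-4 temporal mismatch is from the trace.
main_path = torch.randn(1, 96, 2, 64, 64)   # what the up block produces for this tile
shortcut = torch.randn(1, 96, 4, 64, 64)    # what the avg_shortcut branch returns
main_path + shortcut
# RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2
```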
Thanks! What about the outputs? Cc: @asomoza if you wanna help test it out a bit?
however, there's another crash after this crash is fixed.
So, it doesn't work yet?
It works; the other crash is because patch_size is not considered in tiling mode. In this model it's 2, and this PR fixes that as well.
The second crash looks like:
Traceback (most recent call last):
File "/workspace/test.py", line 36, in
export_to_video(output, "5bit2v_output.mp4", fps=24)
File "/workspace/diffusers/src/diffusers/utils/export_utils.py", line 177, in export_to_video
return _legacy_export_to_video(video_frames, output_video_path, fps)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/diffusers/src/diffusers/utils/export_utils.py", line 135, in _legacy_export_to_video
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
cv2.error: OpenCV(4.11.0) /io/opencv/modules/imgproc/src/color.simd_helpers.hpp:92: error: (-15:Bad number of channels) in function 'cv::impl::{anonymous}::CvtHelper<VScn, VDcn, VDepth, sizePolicy>::CvtHelper(cv::InputArray, cv::OutputArray, int) [with VScn = cv::impl::{anonymous}::Set<3, 4>; VDcn = cv::impl::{anonymous}::Set<3, 4>; VDepth = cv::impl::{anonymous}::Set<0, 2, 5>; cv::impl::{anonymous}::SizePolicy sizePolicy = cv::impl::{anonymous}::NONE; cv::InputArray = const cv::_InputArray&; cv::OutputArray = const cv::_OutputArray&]'
Invalid number of channels in input image:
'VScn::contains(scn)'
where
'scn' is 12
This PR also fixes it.
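For context, the bad channel count lines up with the missing patch handling: 3 RGB channels times a 2x2 patch gives 3 * 2 * 2 = 12 channels, i.e. the tiled decode hands frames to export_to_video while they are still in patched form. A minimal sketch of the kind of unpatchify step tiling mode needs (not the PR's actual code; the channel ordering inside the patch is an assumption):

```python
import torch

def unpatchify_video(x: torch.Tensor, patch_size: int = 2) -> torch.Tensor:
    """Fold a (B, C*p*p, T, H, W) patched tensor back to (B, C, T, H*p, W*p)."""
    b, c_pp, t, h, w = x.shape
    p = patch_size
    c = c_pp // (p * p)
    x = x.view(b, c, p, p, t, h, w)      # assumed (C, p, p) ordering of the packed channels
    x = x.permute(0, 1, 4, 5, 2, 6, 3)   # -> (B, C, T, H, p, W, p)
    return x.reshape(b, c, t, h * p, w * p)

frames = unpatchify_video(torch.randn(1, 12, 5, 32, 32), patch_size=2)
print(frames.shape)  # torch.Size([1, 3, 5, 64, 64]) -- 3 channels, as cv2.cvtColor expects
```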
@sywangyi would you be able to post some outputs after applying the fix?
Tested it with a simple pipe.vae.enable_tiling() over the example code:
~~in fact, it doesn't work with main, but this PR also doesn't fix it, still got:~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 2
Edit: I correct myself, I made a silly mistake; this PR does fix the issue for the 5B. I'll do a comparison with main.
Here they are:
main (without tiling)
5bit2v__main_output.mp4
PR with pipe.vae.enable_tiling()
@sayakpaul @dg845 please help review. Test script:
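The attached test script did not come through in this capture, so here is a minimal sketch of an equivalent reproduction, modeled on the standard diffusers Wan example (the model id, prompt, and generation parameters are assumptions; only enable_tiling(), the output file name, and fps=24 come from this thread):

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video

# Assumed checkpoint for the "5B" model discussed above.
model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")  # or "xpu", depending on the backend

# Tiled VAE decode: this is what triggers the crashes without this PR.
pipe.vae.enable_tiling()

output = pipe(
    prompt="A cat walks on the grass, realistic",  # placeholder prompt
    height=704,
    width=1280,
    num_frames=121,
    num_inference_steps=50,
).frames[0]

export_to_video(output, "5bit2v_output.mp4", fps=24)
```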