
Bug in initialization of UNet1DModel GaussianFourier time projection #12110

Status: Open
Labels: bug (Something isn't working)
Opened by @SammyAgrawal

Description

Describe the bug

TL;DR:
In the constructor of UNet1DModel (line 104, seen here), the embedding size of the Gaussian Fourier time projection is hardcoded to 8 for no apparent reason. It should instead be block_out_channels[0].

Explanation:
The embedding_size argument is passed to GaussianFourierProjection and determines the output dimension of self.time_proj. When timestep embeddings are enabled, the output of self.time_proj is fed into the timestep-embedding MLP (self.time_mlp). The input dimensionality of that MLP is defined as 2 * block_out_channels[0], but the tensor actually fed into it always has dimension 2 * embedding_size, i.e. 2 * 8 = 16 with the hardcoded value. That is exactly the mismatch in the RuntimeError below: the Fourier projection emits 16 features while linear_1 expects 64 (2 * 32, the default block_out_channels[0]). In the same constructor, the positional time embedding is already initialized from block_out_channels[0]; only the Gaussian Fourier branch is hardcoded, so this should be a very simple, easily fixable bug.
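For reference, here is a rough sketch of the relevant constructor branch and the proposed one-line change (paraphrased from unet_1d.py around line 104, not copied verbatim; keyword arguments other than embedding_size may differ slightly in the actual source):

# In UNet1DModel.__init__, roughly:
if time_embedding_type == "fourier":
    self.time_proj = GaussianFourierProjection(
        embedding_size=block_out_channels[0],  # proposed fix; currently hardcoded to 8
        set_W_to_weight=False,
        log=False,
        flip_sin_to_cos=flip_sin_to_cos,
    )
    timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
    # the positional branch already uses block_out_channels[0]
    self.time_proj = Timesteps(
        block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
    )
    timestep_input_dim = block_out_channels[0]

With this change the Fourier projection outputs 2 * block_out_channels[0] features, matching the timestep_input_dim used to build self.time_mlp.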

Reproduction

import diffusers
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
unet = diffusers.UNet1DModel(use_timestep_embedding=True, act_fn='silu').to(device)
unet(torch.randn(32, unet.config.in_channels, 64).to(device), 0)
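The mismatch can also be seen without running the full forward pass. This is a small sketch assuming the default config (block_out_channels[0] == 32) and the submodule names time_proj / time_mlp that appear in the traceback below:

t = torch.zeros(32, device=device)
fourier_features = unet.time_proj(t)   # Gaussian Fourier time projection
print(fourier_features.shape)          # torch.Size([32, 16]) -> 2 * the hardcoded embedding_size of 8
print(unet.time_mlp.linear_1)          # Linear(in_features=64, ...) -> expects 2 * block_out_channels[0]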

Logs

RuntimeError Traceback (most recent call last)
Cell In[272], line 1
----> 1 unet(torch.randn(32, unet.config.in_channels, 64).to(device), 0)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
 1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
 1742 # If we don't have any hooks, we want to skip the rest of the logic in
 1743 # this function, and just call forward.
 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
 1745 or _global_backward_pre_hooks or _global_backward_hooks
 1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
 1749 result = None
 1750 called_always_called_hooks = set()
File /srv/conda/envs/notebook/lib/python3.12/site-packages/diffusers/models/unets/unet_1d.py:228, in UNet1DModel.forward(self, sample, timestep, return_dict)
 226 timestep_embed = self.time_proj(timesteps)
 227 if self.config.use_timestep_embedding:
--> 228 timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
 229 else:
 230 timestep_embed = timestep_embed[..., None]
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
 1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
 1742 # If we don't have any hooks, we want to skip the rest of the logic in
 1743 # this function, and just call forward.
 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
 1745 or _global_backward_pre_hooks or _global_backward_hooks
 1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
 1749 result = None
 1750 called_always_called_hooks = set()
File /srv/conda/envs/notebook/lib/python3.12/site-packages/diffusers/models/embeddings.py:1308, in TimestepEmbedding.forward(self, sample, condition)
 1306 if condition is not None:
 1307 sample = sample + self.cond_proj(condition)
-> 1308 sample = self.linear_1(sample)
 1310 if self.act is not None:
 1311 sample = self.act(sample)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
 1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
 1742 # If we don't have any hooks, we want to skip the rest of the logic in
 1743 # this function, and just call forward.
 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
 1745 or _global_backward_pre_hooks or _global_backward_hooks
 1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
 1749 result = None
 1750 called_always_called_hooks = set()
File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
 124 def forward(self, input: Tensor) -> Tensor:
--> 125 return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 64x128)

System Info


  • 🤗 Diffusers version: 0.34.0
  • Platform: Linux-6.6.56+-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.12.11
  • PyTorch version (GPU?): 2.5.1.post303 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.33.1
  • Transformers version: not installed
  • Accelerate version: not installed
  • PEFT version: not installed
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: Tesla T4, 15360 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

No response
