Description
Describe the bug
TLDR:
In the constructor of UNet1DModel (line 104, seen here), the Fourier embedding size is hardcoded to 8 for no apparent reason. It should instead be block_out_channels[0].
Explanation:
The embedding_size argument is passed into GaussianFourierProjection and determines the output dimension of self.time_proj. If the user enables timestep embeddings, the output of self.time_proj is fed into the timestep embedding MLP (self.time_mlp). The input dimensionality of that MLP is defined as 2 * block_out_channels[0], but the input actually fed into it always has 2 * embedding_size = 16 features, because embedding_size is hardcoded to 8. You can see below that the positional time embedding is initialized from block_out_channels[0]; only the Gaussian Fourier one is hardcoded. I think this is a very simple, easily fixable bug.
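To make the dimension mismatch concrete, here is a minimal sketch that reproduces it from the two embedding modules alone, without building the full UNet. I am assuming the public GaussianFourierProjection and TimestepEmbedding classes in diffusers.models.embeddings; the constructor defaults used here are approximations, and block_out_channels_0 is just a local name for the default block_out_channels[0].

import torch
from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding

block_out_channels_0 = 32  # UNet1DModel's default block_out_channels[0]

# What the constructor builds today: a projection hardcoded to embedding_size=8,
# whose output therefore has 2 * 8 = 16 features ...
time_proj = GaussianFourierProjection(embedding_size=8)

# ... while the timestep MLP expects 2 * block_out_channels[0] = 64 input features
# (the traceback in the Logs section below shows linear_1 as a 64 -> 128 layer).
time_mlp = TimestepEmbedding(in_channels=2 * block_out_channels_0, time_embed_dim=4 * block_out_channels_0)

emb = time_proj(torch.tensor([1.0]))
print(emb.shape)  # torch.Size([1, 16])
time_mlp(emb)     # RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 64x128)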
Reproduction
import diffusers
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
unet = diffusers.UNet1DModel(use_timestep_embedding=True, act_fn='silu').to(device)
unet(torch.randn(32, unet.config.in_channels, 64).to(device), 0)  # raises the RuntimeError shown in the logs below
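As a sanity check on the proposed fix, the same two embedding modules fit together once embedding_size is tied to block_out_channels[0]. Again, this is only a hedged sketch against the public classes in diffusers.models.embeddings rather than the actual constructor change:

import torch
from diffusers.models.embeddings import GaussianFourierProjection, TimestepEmbedding

block_out_channels_0 = 32  # UNet1DModel's default block_out_channels[0]

# Proposed value: embedding_size = block_out_channels[0], so the projection
# outputs 2 * 32 = 64 features, matching the MLP's expected input width.
time_proj = GaussianFourierProjection(embedding_size=block_out_channels_0)
time_mlp = TimestepEmbedding(in_channels=2 * block_out_channels_0, time_embed_dim=4 * block_out_channels_0)

out = time_mlp(time_proj(torch.tensor([1.0])))
print(out.shape)  # torch.Size([1, 128]) -- no shape mismatch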
Logs
RuntimeError                              Traceback (most recent call last)
Cell In[272], line 1
----> 1 unet(torch.randn(32, unet.config.in_channels, 64).to(device), 0)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /srv/conda/envs/notebook/lib/python3.12/site-packages/diffusers/models/unets/unet_1d.py:228, in UNet1DModel.forward(self, sample, timestep, return_dict)
    226 timestep_embed = self.time_proj(timesteps)
    227 if self.config.use_timestep_embedding:
--> 228     timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
    229 else:
    230     timestep_embed = timestep_embed[..., None]

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /srv/conda/envs/notebook/lib/python3.12/site-packages/diffusers/models/embeddings.py:1308, in TimestepEmbedding.forward(self, sample, condition)
   1306 if condition is not None:
   1307     sample = sample + self.cond_proj(condition)
-> 1308 sample = self.linear_1(sample)
   1310 if self.act is not None:
   1311     sample = self.act(sample)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File /srv/conda/envs/notebook/lib/python3.12/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
    124 def forward(self, input: Tensor) -> Tensor:
--> 125     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x16 and 64x128)
System Info
- 🤗 Diffusers version: 0.34.0
- Platform: Linux-6.6.56+-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.12.11
- PyTorch version (GPU?): 2.5.1.post303 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.33.1
- Transformers version: not installed
- Accelerate version: not installed
- PEFT version: not installed
- Bitsandbytes version: not installed
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: Tesla T4, 15360 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
No response