How to train a diffusion model from scratch instead of using from_pretrained? #12163
I fine-tuned the pre-trained `stable-diffusion-inpainting` model on an image inpainting task and everything worked well. However, when I fine-tuned the pre-trained `stable-diffusion-v1-4` model on the same task, the loss was always NaN.
Because the two models' UNets have different numbers of input channels, I changed the UNet input channels of `stable-diffusion-v1-4` to fit the inpainting task. The code now runs, but the loss is NaN.
I don't know where the problem is. How can I train a diffusion model from scratch instead of loading it with `from_pretrained`?
Training a diffusion model from scratch, especially for image inpainting, is absolutely possible, but it requires careful setup. Since you're seeing NaN losses after manually modifying a pre-trained model, the cause is likely an input/output mismatch, improper initialization, or numerical instability during training.
Let's go over:
1. Why You're Getting NaN Loss
You said:
"I changed the UNet input channels of 'stable-diffusion-v1-4' to fit for inpainting, and now the loss is NaN."
That’s a red flag. A few things to verify:
Potential Issues:
- Input channel mismatch: did you update every layer that depends on the number of input channels (e.g., the first convolution, `conv_in`)?
- Weight init: if you modified layers without reinitializing them properly, they may produce unstable values.
- Wrong masking: inpainting models (like `stable-diffusion-inpainting`) take the mask and the masked image as extra conditioning inputs. If you don't supply these correctly, the model can learn garbage or diverge.
- Incorrect noise schedule or beta parameters: if you modify the pipeline without the correct `betas`, `alphas`, or timestep embeddings, the model won't train stably.
- Learning rate too high: a very common cause of NaN.
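To narrow down which of these is the culprit, it can help to find the first layer whose output stops being finite. Below is a minimal sketch using plain PyTorch forward hooks. The helper name `find_first_nonfinite` is hypothetical and is not part of `diffusers` or PyTorch.

```python
import torch
import torch.nn as nn


def find_first_nonfinite(model: nn.Module, *inputs, **kwargs):
    """Run one forward pass and return the name of the first module to finish
    whose output contains NaN or Inf, or None if every output is finite."""
    bad = []

    def make_hook(name):
        def hook(module, args, output):
            # Outputs can be tensors, tuples, or dict-like ModelOutputs; only
            # plain tensors and tuple/list elements are checked here.
            outs = output if isinstance(output, (tuple, list)) else (output,)
            if any(torch.is_tensor(o) and not torch.isfinite(o).all() for o in outs):
                bad.append(name)
        return hook

    # Forward hooks fire after a module returns, so children are recorded before parents.
    handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
    try:
        with torch.no_grad():
            model(*inputs, **kwargs)
    finally:
        for h in handles:
            h.remove()  # always detach the hooks again
    return bad[0] if bad else None
```

For a UNet you would pass the same positional and keyword arguments you use in a training forward pass. The returned name (e.g. a submodule inside `down_blocks`) points at where the values first blow up.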
2. How to Train a Diffusion Model from Scratch (No `from_pretrained`)
To train from scratch without `from_pretrained`, you need to initialize every component manually and build a training loop. Here's a step-by-step outline using Hugging Face `diffusers`.
2.1. Define Model Architecture
You need:
- `UNet2DConditionModel` (or a custom UNet)
- `AutoencoderKL` (if using latent diffusion)
- `DDPMScheduler` or `DDIMScheduler`
- A text encoder (e.g., CLIP's `CLIPTextModel`, if you're doing text conditioning)
Example:
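```python
from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

# Initialize from config instead of from_pretrained: load_config only downloads
# config.json from the given subfolder, so the weights are randomly initialized.
unet_config = UNet2DConditionModel.load_config("CompVis/stable-diffusion-v1-4", subfolder="unet")
# For inpainting, override the input channels: 4 latents + 1 mask + 4 masked-image latents
unet = UNet2DConditionModel.from_config(unet_config, in_channels=9)

# Note: a randomly initialized VAE produces meaningless latents; it is usually
# trained first or kept pre-trained.
vae_config = AutoencoderKL.load_config("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae = AutoencoderKL.from_config(vae_config)

# The text encoder is typically kept pre-trained and frozen (SD v1-4 uses this CLIP model)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
```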
If you want to go fully custom, write your own UNet, VAE, etc. from scratch.
2.2. Data Pipeline
Your dataset should return:
- Original image
- Mask (for inpainting)
- Masked image
- Prompt or conditioning
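For inpainting, the model input is often:
- `x_noisy`: noised latent image
- `mask`: binary mask
- `masked_image`: image with the masked regions blanked out
- `text_embeds`: prompt encoding

In the Stable Diffusion inpainting UNet, the first three are concatenated along the channel dimension, while the text embeddings enter through cross-attention. Before concatenation, the mask is downsampled and the masked image is VAE-encoded to latent size.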
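Here is a minimal sketch of how one training example could be assembled. It assumes the convention `StableDiffusionInpaintPipeline` uses: a mask value of 1 marks the region to repaint, and images are scaled to [-1, 1]. The function name `make_inpainting_example` and the rectangular mask are illustrative assumptions; a real dataset would usually draw random masks.

```python
import torch


def make_inpainting_example(image_uint8: torch.Tensor, box: tuple) -> dict:
    """image_uint8: (3, H, W) uint8 tensor; box: (top, left, height, width) region to repaint.

    Returns image and masked_image in [-1, 1] plus a (1, H, W) binary mask.
    """
    # Scale pixels from [0, 255] to [-1, 1], the range the SD VAE expects.
    image = image_uint8.float() / 127.5 - 1.0

    # Binary mask: 1 = region to inpaint, 0 = region to keep.
    _, h, w = image.shape
    top, left, bh, bw = box
    mask = torch.zeros(1, h, w)
    mask[:, top:top + bh, left:left + bw] = 1.0

    # Blank out the masked region (0 is mid-gray in [-1, 1] space).
    masked_image = image * (1.0 - mask)
    return {"image": image, "mask": mask, "masked_image": masked_image}
```

Before reaching the UNet, the mask still needs to be resized to the latent resolution, and both `image` and `masked_image` need to be encoded with the VAE.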
You should encode prompts with:
```python
inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True)
embeddings = text_encoder(**inputs).last_hidden_state
```
2.3. Training Loop Skeleton
Here’s a simplified training loop:
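```python
import torch
import torch.nn.functional as F

# Assumes vae, unet, text_encoder, tokenizer, noise_scheduler, optimizer,
# dataloader, and device are already set up, and batches are already on device.
for batch in dataloader:
    # batch: image and masked_image (B, 3, H, W) in [-1, 1], mask (B, 1, H, W), prompt (list of str)
    with torch.no_grad():
        # Encode the target image and the masked image into latents (VAE frozen)
        latents = vae.encode(batch["image"]).latent_dist.sample() * vae.config.scaling_factor
        masked_image_latents = vae.encode(batch["masked_image"]).latent_dist.sample() * vae.config.scaling_factor

        # Get text embeddings (text encoder frozen)
        text_input = tokenizer(batch["prompt"], return_tensors="pt", padding="max_length", truncation=True).to(device)
        encoder_hidden_states = text_encoder(**text_input).last_hidden_state

    # Resize the mask to latent resolution, e.g. 512x512 -> 64x64
    mask = F.interpolate(batch["mask"], size=latents.shape[-2:])

    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=device).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Inpainting: concatenate noisy latents (4) + mask (1) + masked-image latents (4) = 9 channels,
    # in the same order StableDiffusionInpaintPipeline uses
    unet_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
    model_pred = unet(unet_input, timesteps, encoder_hidden_states=encoder_hidden_states).sample

    loss = F.mse_loss(model_pred.float(), noise.float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```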
3. How to Fix NaN in Your Finetuning Attempt
Here are targeted suggestions for your NaN issue:
Check 1: Match Input Channels Properly
The inpainting UNet has 9 input channels. `StableDiffusionInpaintPipeline` feeds them in this order: 4 for the noisy latents, 1 for the downsampled mask, and 4 for the VAE-encoded masked image.
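Make sure you updated the UNet like this:
```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.conv_in = torch.nn.Conv2d(9, unet.conv_in.out_channels, kernel_size=3, padding=1)
unet.register_to_config(in_channels=9)  # keep the saved config in sync with the new layer
```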
Then reinitialize `conv_in`:
```python
torch.nn.init.kaiming_normal_(unet.conv_in.weight)
torch.nn.init.zeros_(unet.conv_in.bias)
```
Check 2: Clamp or Normalize Inputs
Ensure your inputs (mask, images) are normalized and correctly scaled (e.g., latent scaling factor 0.18215, image range [-1, 1]).
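A cheap way to catch scaling mistakes is to check value ranges before each forward pass. This is a minimal sketch: the helper `check_inpainting_batch` and its dictionary keys are assumptions that match the dataset fields listed in 2.2. The [-1, 1] and {0, 1} ranges follow the convention of the Stable Diffusion VAE and inpainting pipeline.

```python
import torch


def check_inpainting_batch(batch: dict) -> None:
    """Raise ValueError if any tensor has NaN/Inf, images leave [-1, 1],
    or the mask is not strictly binary."""
    for key in ("image", "masked_image", "mask"):
        if not torch.isfinite(batch[key]).all():
            raise ValueError(f"{key} contains NaN or Inf")
    for key in ("image", "masked_image"):
        x = batch[key]
        if x.min() < -1.0 or x.max() > 1.0:
            raise ValueError(
                f"{key} is outside [-1, 1]: min={x.min().item():.3f}, max={x.max().item():.3f}"
            )
    mask = batch["mask"]
    if not ((mask == 0) | (mask == 1)).all():
        raise ValueError("mask must contain only 0 and 1")
```

This catches images left in [0, 255] and soft or interpolated masks. Images in [0, 1] fall inside the allowed range, so that mistake still has to be ruled out by checking the preprocessing itself.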
Check 3: Reduce LR and Use Grad Clipping
Clip gradients after `loss.backward()` and before `optimizer.step()`:
```python
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
```
Also try lowering your learning rate by 10x to see whether that prevents the NaN.
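Both suggestions can be combined with a common safeguard: skip any update whose loss or gradient norm is not finite, so a single bad batch cannot corrupt the weights. This is a minimal sketch in plain PyTorch. `guarded_step` is a hypothetical helper, and the 1e-5 learning rate is only an illustrative "lowered" value, not a tuned setting.

```python
import torch
import torch.nn as nn


def guarded_step(model: nn.Module, optimizer: torch.optim.Optimizer,
                 loss: torch.Tensor, max_norm: float = 1.0) -> bool:
    """Backpropagate, clip, and step. Returns False and skips the update
    if the loss or the gradient norm is NaN/Inf."""
    optimizer.zero_grad(set_to_none=True)
    if not torch.isfinite(loss):
        return False
    loss.backward()
    # clip_grad_norm_ rescales gradients in place and returns the norm measured before clipping.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if not torch.isfinite(total_norm):
        optimizer.zero_grad(set_to_none=True)  # discard the bad gradients
        return False
    optimizer.step()
    return True


model = nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # illustrative low learning rate
x, y = torch.randn(8, 4), torch.randn(8, 1)
ok = guarded_step(model, optimizer, nn.functional.mse_loss(model(x), y))
skipped = guarded_step(model, optimizer, torch.tensor(float("nan"), requires_grad=True))
print(ok, skipped)  # True False
```

If steps keep getting skipped, treat that as a symptom rather than a fix. While debugging, `torch.autograd.set_detect_anomaly(True)` makes autograd raise an error at the backward operation that first produces NaN; it slows training considerably.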
4. TL;DR: Training a Diffusion Model from Scratch
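You can avoid `from_pretrained` by:
```python
from diffusers import UNet2DConditionModel

# Downloads only the UNet's config.json; the weights are randomly initialized
config = UNet2DConditionModel.load_config("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet = UNet2DConditionModel.from_config(config)
```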
Or fully define a custom model. Then:
- Build a working data pipeline
- Define noise scheduler
- Encode text inputs
- Add masks for inpainting
- Train with MSE loss between predicted and actual noise
Hope that helps! 😊
Thanks a lot, I'll give it a try. 😊