How to train a diffusion model from scratch but not from from_pretrained? #12163

Answered by alishanawer
micklexqg asked this question in Q&A

I finetuned the pre-trained 'stable-diffusion-inpainting' model on an image inpainting task and everything worked well, but when I finetuned the pre-trained 'stable-diffusion-v1-4' model on the same inpainting task, the loss is always NaN.
Since the two models have different UNet input channels, I changed the UNet input channels of 'stable-diffusion-v1-4' to fit the inpainting task. The code now runs, but the loss is NaN.
I don't know where the problem is, and I wonder: how do I train a diffusion model from scratch rather than with from_pretrained?


Training a diffusion model from scratch, especially for image inpainting, is absolutely possible, but involves careful setup. Since you're seeing NaN losses after manually modifying a pre-trained model, it’s likely due to input/output mismatches, improper initialization, or numerical instability in training.

Let's go over the likely causes first, then a from-scratch training recipe.


1. Why You're Getting NaN Loss

You said:

"I changed the UNet input channels of 'stable-diffusion-v1-4' to fit for inpainting, and now the loss is NaN."

That's a red flag. A few things to verify (a quick NaN-debugging sketch follows this list):

Potential Issues:

  • Input channel mismatch: Did you update all layers that depend on input channels (e.g., first convolution)?
  • Weight init: If you modified layers without reinitializing them properly, they might produce unstable values.
  • Wrong masking: Inpainting models (like stable-diffusion-inpainting) use masked images and a conditioning image. If you're not supplying these correctly, the model can learn garbage or explode.
  • Incorrect noise schedule or beta parameters: If you're modifying the pipeline without correct betas, alphas, or timestep embeddings, the model won't train stably.
  • Learning rate too high: A very common cause of NaN.
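If none of these jump out, a small debugging harness can localize the failure. A minimal sketch, assuming your forward pass and loss live in a hypothetical training_step helper:

import torch

# Anomaly detection reports the op that first produced NaN/Inf in the
# backward pass. It is slow, so enable it only while debugging.
torch.autograd.set_detect_anomaly(True)

for step, batch in enumerate(dataloader):
    loss = training_step(batch)  # hypothetical: your forward pass + loss
    if not torch.isfinite(loss):
        raise RuntimeError(f"Non-finite loss at step {step}: {loss.item()}")
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()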

2. How to Train a Diffusion Model from Scratch (No from_pretrained)

To train from scratch without from_pretrained, you need to initialize every component manually and build a training loop. Here's a step-by-step outline using Hugging Face diffusers.


2.1. Define Model Architecture

You must build:

  • UNet2DConditionModel (or a custom one)
  • AutoencoderKL (if using latent diffusion)
  • DDPMScheduler or DDIMScheduler
  • TextEncoder (e.g., from CLIP, if you're doing text conditioning)

Example:

from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

# Initialize from config (random weights) instead of from_pretrained.
# Note: from_config takes a config object, so load the config first,
# and point at the right subfolder of the pipeline repo.
unet_config = UNet2DConditionModel.load_config("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet = UNet2DConditionModel.from_config(unet_config)
vae_config = AutoencoderKL.load_config("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae = AutoencoderKL.from_config(vae_config)
# The text encoder and tokenizer below still load pretrained CLIP weights;
# training a text encoder from scratch is rarely practical.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

If you want to go fully custom, you can write your own UNet, VAE, etc., from scratch, or instantiate the diffusers classes with explicit hyperparameters, as sketched below.
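A minimal sketch of the latter, with no hub download at all. The hyperparameter values below follow the SD v1-style architecture but are illustrative; adjust them to your task and compute budget:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=64,            # latent resolution (512 px images / VAE factor 8)
    in_channels=9,             # 4 noisy latent + 1 mask + 4 masked-image latent
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(320, 640, 1280, 1280),
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    cross_attention_dim=768,   # hidden size of CLIP ViT-L/14
)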


2.2. Data Pipeline

Your dataset should return:

  • Original image
  • Mask (for inpainting)
  • Masked image
  • Prompt or conditioning

For inpainting, the model input is often:

  • x_noisy — Noised latent image
  • mask — Binary mask
  • masked_image — Image with regions masked out
  • text_embeds — Prompt encoding

You should encode prompts with:

inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True)
embeddings = text_encoder(**inputs).last_hidden_state
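Putting these pieces together, a dataset might look like this minimal sketch. The class and its preprocessing are assumptions to illustrate the field names used by the training loop below; adapt them to how your data is stored:

from torch.utils.data import Dataset

class InpaintingDataset(Dataset):
    # Hypothetical sketch: assumes images are tensors in [-1, 1]
    # and masks are {0, 1} tensors of shape (1, H, W)
    def __init__(self, images, masks, prompts):
        self.images, self.masks, self.prompts = images, masks, prompts

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        image = self.images[i]
        mask = self.masks[i]
        masked_image = image * (1 - mask)  # zero out the region to inpaint
        return {
            "image": image,
            "mask": mask,
            "masked_image": masked_image,
            "prompt": self.prompts[i],
        }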

2.3. Training Loop Skeleton

Here's a simplified training loop (it assumes device, dataloader, and an optimizer are already set up; a sketch of that setup follows the loop):

import torch
import torch.nn.functional as F

for batch in dataloader:
    optimizer.zero_grad()
    # Encode images to latents and apply the SD latent scaling factor
    latents = vae.encode(batch["image"]).latent_dist.sample() * 0.18215
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps,
        (latents.shape[0],), device=device
    ).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    # Get text embeddings
    text_input = tokenizer(
        batch["prompt"], return_tensors="pt",
        padding="max_length", truncation=True
    ).to(device)
    encoder_hidden_states = text_encoder(**text_input).last_hidden_state
    # Inpainting: concatenate noisy latents, downsampled mask, and
    # masked-image latents along the channel dim (4 + 1 + 4 = 9 channels).
    # Note: UNet2DConditionModel does not take mask/masked_image via
    # added_cond_kwargs; channel concatenation is how SD inpainting works.
    masked_latents = vae.encode(batch["masked_image"]).latent_dist.sample() * 0.18215
    mask = F.interpolate(batch["mask"], size=latents.shape[-2:])
    model_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
    model_pred = unet(
        sample=model_input,
        timestep=timesteps,
        encoder_hidden_states=encoder_hidden_states,
    ).sample
    loss = F.mse_loss(model_pred, noise)
    loss.backward()
    optimizer.step()
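A minimal sketch of the setup the loop assumes, training only the UNet while the VAE and text encoder stay frozen (the learning rate is illustrative):

import torch

optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.train()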

3. How to Fix NaN in Your Finetuning Attempt

Here are targeted suggestions for your NaN issue:

Check 1: Match Input Channels Properly

The inpainting model's UNet takes 9 input channels: 4 for the noisy latent, 1 for the (downsampled) mask, and 4 for the masked-image latents, concatenated in that order.

Make sure you updated UNet like this:

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet.conv_in = torch.nn.Conv2d(9, unet.conv_in.out_channels, kernel_size=3, padding=1)
unet.register_to_config(in_channels=9)  # keep the config in sync with the new layer

Then reinitialize conv_in:

torch.nn.init.kaiming_normal_(unet.conv_in.weight)
torch.nn.init.zeros_(unet.conv_in.bias)
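As an alternative to fresh Kaiming initialization, a common community trick (not an official recipe) is to keep the pretrained weights for the original 4 latent channels and zero-initialize only the 5 new channels, so the extra inputs start as a near no-op. Starting from the freshly loaded pretrained UNet:

import torch

old_conv = unet.conv_in  # assumes this still holds the pretrained 4-channel weights
new_conv = torch.nn.Conv2d(9, old_conv.out_channels, kernel_size=3, padding=1)
with torch.no_grad():
    new_conv.weight.zero_()
    new_conv.weight[:, :4] = old_conv.weight  # reuse pretrained latent channels
    new_conv.bias.copy_(old_conv.bias)
unet.conv_in = new_conv
unet.register_to_config(in_channels=9)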

Check 2: Clamp or Normalize Inputs

Ensure your inputs are normalized and correctly scaled: images in [-1, 1], masks binary in {0, 1}, and latents multiplied by the scaling factor 0.18215.
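A couple of cheap assertions inside the loop can catch scaling bugs early (illustrative, under those conventions):

import torch

assert batch["image"].min() >= -1.0 and batch["image"].max() <= 1.0
assert torch.all((batch["mask"] == 0) | (batch["mask"] == 1))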

Check 3: Reduce LR and Use Grad Clipping

Try:

torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)  # call between loss.backward() and optimizer.step()

Lower your learning rate by 10x to see if it prevents the NaN.


4. TL;DR: Training a Diffusion Model from Scratch

You can avoid from_pretrained by:

from diffusers import UNet2DConditionModel

# Architecture from the v1-4 config, weights randomly initialized
config = UNet2DConditionModel.load_config("CompVis/stable-diffusion-v1-4", subfolder="unet")
unet = UNet2DConditionModel.from_config(config)

Or fully define a custom model. Then:

  • Build a working data pipeline
  • Define noise scheduler
  • Encode text inputs
  • Add masks for inpainting
  • Train with MSE loss between predicted and actual noise

Hope that helps! 😊


Thanks a lot, I will have a try.😊

Answer selected by micklexqg