Support float8, int8, int4 in diffusers? #7023

sayakpaul started this conversation in General

Comfy and A1111 have been supporting float8 for some time now:

A1111 reports quite nice improvements in VRAM consumption:

[screenshot: A1111 VRAM comparison]

Timing takes a hit because of the casting overhead, but that's okay in the interest of the reduced VRAM, IMO.
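
For illustration, here is a rough sketch of the "store weights in FP8, upcast at compute time" idea that causes that casting overhead. This is not the actual Comfy/A1111 implementation; the hooks and module filtering below are my own simplification.

```python
# Rough sketch only: keep Linear/Conv2d parameters in float8 for storage and
# cast them to float16 right before each forward pass, then back to float8.
# The repeated casting is exactly where the timing overhead comes from.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

STORAGE_DTYPE = torch.float8_e4m3fn
COMPUTE_DTYPE = torch.float16

def upcast(module, args):
    module.to(COMPUTE_DTYPE)  # upcast params for the actual matmul/conv

def downcast(module, args, output):
    module.to(STORAGE_DTYPE)  # shrink params back to FP8 storage

for m in pipe.unet.modules():
    if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
        m.to(STORAGE_DTYPE)
        m.register_forward_pre_hook(upcast)
        m.register_forward_hook(downcast)

image = pipe("ghibli style, a fantasy landscape with castles").images[0]
```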

So, I tried using quanto to potentially benefit from FP8 (benchmark run on a 4090):

```python
import argparse
from quanto.quantize import quantize, freeze
import torch
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline

CKPT = "runwayml/stable-diffusion-v1-5"
NUM_INFERENCE_STEPS = 50
WARM_UP_ITERS = 5
PROMPT = "ghibli style, a fantasy landscape with castles"

TORCH_DTYPES = {"fp32": torch.float32, "fp16": torch.float16}
UNET_FP8_DTYPES = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e5m2": torch.float8_e5m2}


def load_pipeline(torch_dtype, unet_in_float8=None):
    pipe = DiffusionPipeline.from_pretrained(
        CKPT, torch_dtype=torch_dtype, use_safetensors=True
    ).to("cuda")
    if unet_in_float8:
        quantize(pipe.unet, weights=unet_in_float8)
        freeze(pipe.unet)
    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, batch_size=1):
    _ = pipe(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=batch_size,
        generator=torch.manual_seed(0),
    )


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--torch_dtype", type=str, default="fp32", choices=list(TORCH_DTYPES.keys()))
    parser.add_argument("--unet_in_float8", type=str, default=None, choices=list(UNET_FP8_DTYPES.keys()))
    args = parser.parse_args()

    pipeline = load_pipeline(
        TORCH_DTYPES[args.torch_dtype], UNET_FP8_DTYPES[args.unet_in_float8] if args.unet_in_float8 else None
    )

    for _ in range(WARM_UP_ITERS):
        run_inference(pipeline, args.batch_size)

    time = benchmark_fn(run_inference, pipeline, args.batch_size)
    memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
    print(
        f"batch_size: {args.batch_size}, torch_dtype: {args.torch_dtype}, unet_in_float8: {args.unet_in_float8} in {time} seconds."
    )
    print(f"Memory: {memory}GB.")

    img_name = f"bs@{args.batch_size}-dtype@{args.torch_dtype}-unet_fp8@{args.unet_in_float8}.png"
    image = pipeline(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=args.batch_size,
    ).images[0]
    image.save(img_name)
```

Here are the stats and resultant images (batch size of 1):

| Settings | Timing (seconds) | Memory (GB) | Resultant image |
| --- | --- | --- | --- |
| FP32 | 3.338 | 6.071 | Resultant Image 1 |
| FP16 | 1.314 | 3.193 | Resultant Image 2 |
| FP16 + UNet FP8 | 2.131 | 2.652 | Resultant Image 3 |

As we can see, we obtain a good amount of VRAM reduction compared to FP16. Do we want to support this natively in diffusers, or is supporting it via quanto preferable? I am okay with the latter.

Edit: int8 is even better: #7023 (comment).

See also: huggingface/optimum-quanto#74. Cc: @dacorvo.

Curious to know your thoughts here: @yiyixuxu @DN6.


Replies: 10 comments 22 replies

---

@sayakpaul did you try to run inference through just the UNet (i.e., skipping the VAE in case it's the one using that much memory)?

1 reply

---

sayakpaul Feb 19, 2024
Maintainer Author

As we can see in the code I provided, it's just the UNet; that's what the other repos do as well.

But I guess I am misunderstanding something here.

---

@sayakpaul a couple of comments:

  • float8 weights are actually less efficient than int8 weights in terms of accuracy, so you should also try int8 weights,
  • float8 activations, on the other hand, are quite efficient,
  • regardless of the target weight quantization, quanto does a fake quantize by default so that the weights can still be tuned. This explains why the VRAM would not decrease. To actually store float8/int8 weights you need to call freeze(model); a minimal sketch follows this list.
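
A minimal sketch of that quantize-then-freeze flow, assuming the top-level quanto imports (`quantize`, `freeze`, `qint8`) that appear later in this thread; adapt to the quanto version you have installed:

```python
# Minimal sketch: quantize() alone keeps fake-quantized higher-precision
# weights (useful for tuning/calibration); freeze() materializes the real
# int8 weights, which is what actually reduces VRAM.
import torch
from diffusers import DiffusionPipeline
from quanto import quantize, freeze, qint8

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

quantize(pipe.unet, weights=qint8)  # fake-quantized at this point
freeze(pipe.unet)                   # now the UNet stores int8 weights

image = pipe("ghibli style, a fantasy landscape with castles").images[0]
```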

7 replies

---

sayakpaul Feb 19, 2024
Maintainer Author

I see. Thanks for explaining. However, I think we had kind of established here that float8 might be a better choice due to its non-linear representation and wider range.

I can check the distribution of the UNet params and get back to you here.

---

This was for activations, not weights. Activations indeed have a non-linear distribution that is better captured by float8.

---

sayakpaul Feb 19, 2024
Maintainer Author

Oh you crushed all my numbers.

Changing to `quantize(pipe.unet, weights=torch.int8)` yielded:

```
batch_size: 1, torch_dtype: fp16, unet: torch.int8 in 1.959 seconds.
Memory: 2.655GB.
```

No loss in quality either.

---

That said, if you don't see any loss for float8, and if the weights follow a linear distribution, this actually means that your weights might be compatible with a 4-bit encoding, i.e. int4.
With the latest version of quanto, you can choose quanto.qint4 as a quantization target.

---

sayakpaul Feb 19, 2024
Maintainer Author

I am currently using the qconv2d branch. I will merge with main and try that too.

---

sayakpaul
Feb 19, 2024
Maintainer Author

Cc: @younesbelkada for feedback as well (as he is our in-house ninja for working with reduced precision).

0 replies

---

sayakpaul
Feb 19, 2024
Maintainer Author

SD with batch size of 4

| Settings | Timing (seconds) | Memory (GB) |
| --- | --- | --- |
| FP32 | 11.057 | 8.902 |
| FP16 | 3.801 | 4.587 |
| FP16 + UNet FP8 | 4.332 | 3.803 |

0 replies

---

sayakpaul
Feb 19, 2024
Maintainer Author

SDXL with batch size of 1 (steps: 30)

| Settings | Timing (seconds) | Memory (GB) |
| --- | --- | --- |
| FP32 | 12.180 | 16.827 |
| FP16 | 4.112 | 10.468 |
| FP16 + UNet int8 | 4.710 | 8.135 |

2 replies

---

Would someone explain why UNet int8 / UNet FP8 is slower than FP16?

---

@dingkwang Because GPUs usually have more hardware for dealing with FP32 and FP16 than with FP8, which cancels out the possible speed benefits of FP8 (like being able to fit more numbers in the cache).

---

sayakpaul
Feb 19, 2024
Maintainer Author

@dacorvo I plotted the distribution of the UNet weights as well:

```python
from diffusers import UNet2DConditionModel
import matplotlib.pyplot as plt

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet").eval()

weights = []
for name, param in unet.named_parameters():
    if "weight" in name:
        weights.append(param.view(-1).cpu().detach().numpy())

plt.figure(figsize=(10, 6))
for i, weight in enumerate(weights):
    plt.hist(weight, bins=50, alpha=0.5, label=f"Layer {i+1}")
plt.xlabel("Weight values")
plt.ylabel("Frequency")
plt.title("Distribution of Weights in the Neural Network")
plt.savefig("sdxl_unet_weight_dist.png", bbox_inches="tight", dpi=300)
```

SDXL

[plot: sdxl_unet_weight_dist]

SD v1.5

[plot: sd_unet_weight_dist]

Weights seem to be concentrated around 0. Does this fit the bill for quanto.qint4?

2 replies

---

Not so sure: I was not expecting such a high peak around zero. I suspect int4 will result in extreme pruning here.
So if it works, it would mean that these close-to-zero weights are actually not required. If so, this is a hint that the weights are actually sparse, which opens the way to even smarter optimization techniques.
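
One quick way to probe that sparsity hypothesis is to measure what fraction of the UNet weight values fall inside a narrow band around zero; this is a hypothetical check and the threshold below is arbitrary:

```python
# Hypothetical sparsity check: count how many UNet weight values sit within a
# small band around zero. A large fraction would support the sparsity idea.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
).eval()

THRESHOLD = 1e-3  # arbitrary "close to zero" band
total = 0
near_zero = 0
for name, param in unet.named_parameters():
    if "weight" in name:
        total += param.numel()
        near_zero += (param.detach().abs() < THRESHOLD).sum().item()

print(f"{near_zero / total:.1%} of weight values have |w| < {THRESHOLD}")
```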

---

sayakpaul Feb 19, 2024
Maintainer Author

I am trying int4. However, it errors out on `from quanto import quantize, freeze`:

```
ImportError: cannot import name 'qbitsdtype' from 'quanto.tensor' (/home/sayak/quanto/quanto/tensor/__init__.py)
```

---

> I am trying int4. However, it errors out on `from quanto import quantize, freeze`:
>
> ```
> ImportError: cannot import name 'qbitsdtype' from 'quanto.tensor' (/home/sayak/quanto/quanto/tensor/__init__.py)
> ```

I rebased the branch. I did a refactoring and qbitsdtype is now qtype. Note that you also need to pass a quanto.dtype instead of a torch.dtype now.

4 replies

---

sayakpaul Feb 19, 2024
Maintainer Author

Could you elaborate on that a bit? Specifically, what would the code changes look like for the code snippet, and also for using int4?

---

Just import qint4, qint8, ... from quanto and use them as the weights parameter in quantize instead of torch.int8 / torch.float8.

---

sayakpaul Feb 19, 2024
Maintainer Author

Okay. So, this is my updated testing snippet now:

```python
import argparse
from quanto import quantize, freeze, qint4, qint8, qfloat8_e4m3fn
import torch
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline

CKPT = "stabilityai/stable-diffusion-xl-base-1.0"
NUM_INFERENCE_STEPS = 30
WARM_UP_ITERS = 5
PROMPT = "ghibli style, a fantasy landscape with castles"

TORCH_DTYPES = {"fp32": torch.float32, "fp16": torch.float16}
UNET_DTYPES = {"fp8": qfloat8_e4m3fn, "int8": qint8, "int4": qint4}


def load_pipeline(torch_dtype, unet_dtype=None):
    pipe = DiffusionPipeline.from_pretrained(CKPT, torch_dtype=torch_dtype, use_safetensors=True).to("cuda")
    if unet_dtype:
        quantize(pipe.unet, weights=unet_dtype)
        freeze(pipe.unet)
    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, batch_size=1):
    _ = pipe(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=batch_size,
        generator=torch.manual_seed(0),
    )


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--torch_dtype", type=str, default="fp32", choices=list(TORCH_DTYPES.keys()))
    parser.add_argument("--unet_dtype", type=str, default=None, choices=list(UNET_DTYPES.keys()))
    args = parser.parse_args()

    pipeline = load_pipeline(TORCH_DTYPES[args.torch_dtype], UNET_DTYPES[args.unet_dtype] if args.unet_dtype else None)

    for _ in range(WARM_UP_ITERS):
        run_inference(pipeline, args.batch_size)

    time = benchmark_fn(run_inference, pipeline, args.batch_size)
    memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
    print(
        f"batch_size: {args.batch_size}, torch_dtype: {args.torch_dtype}, unet_dtype: {args.unet_dtype} in {time} seconds."
    )
    print(f"Memory: {memory}GB.")

    img_name = f"bs@{args.batch_size}-dtype@{args.torch_dtype}-unet_dtype@{args.unet_dtype}.png"
    image = pipeline(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=args.batch_size,
    ).images[0]
    image.save(img_name)
```

It is leading to:

```
Traceback (most recent call last):
  File "/home/sayak/diffusers/benchmark_reduced_precision_unet.py", line 53, in <module>
    pipeline = load_pipeline(TORCH_DTYPES[args.torch_dtype], UNET_DTYPES[args.unet_dtype] if args.unet_dtype else None)
  File "/home/sayak/diffusers/benchmark_reduced_precision_unet.py", line 22, in load_pipeline
    freeze(pipe.unet)
  File "/home/sayak/quanto/quanto/quantize.py", line 40, in freeze
    m.freeze()
  File "/home/sayak/quanto/quanto/nn/qmodule.py", line 112, in freeze
    qweight = self.qweight()
  File "/home/sayak/quanto/quanto/nn/qlinear.py", line 44, in qweight
    raise ValueError(f"Invalid quantized weights type {self.weights}")
ValueError: Invalid quantized weights type quanto.qfloat8_e4m3fn
```

---

Hum yeah for QLinear I limited it to integer types. Let me push a small modification.

---

> Hum yeah for QLinear I limited it to integer types. Let me push a small modification.

It should be OK now.

1 reply

---

sayakpaul Feb 23, 2024
Maintainer Author

And we're at:

```
batch_size: 1, torch_dtype: fp16, unet_dtype: int4 in 5.688 seconds.
Memory: 6.819GB.
```

Reference is here: #7023 (comment). Pretty nice memory savings :)

---

sayakpaul
Feb 23, 2024
Maintainer Author

@dacorvo I am getting:

```
RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half
```

4 replies

---

Yes, I know, I got that too. I think this comes from one of my latest changes to fuse dequantization and matmul.

---

sayakpaul Feb 23, 2024
Maintainer Author

Should I wait for a fix?

I am working on a final script for the community to experiment with and also to publish the results as I gather them.

---

Can you share your config? I cannot reproduce atm.

---

sayakpaul Feb 23, 2024
Maintainer Author

Full code:

```python
import argparse
from quanto import quantize, freeze, qint4, qint8, qfloat8_e4m3fn
import torch
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline

WARM_UP_ITERS = 5
PROMPT = "ghibli style, a fantasy landscape with castles"

TORCH_DTYPES = {"fp32": torch.float32, "fp16": torch.float16}
UNET_DTYPES = {"fp8": qfloat8_e4m3fn, "int8": qint8, "int4": qint4}


def load_pipeline(ckpt_id, torch_dtype, unet_dtype=None):
    pipe = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=torch_dtype).to("cuda")
    if unet_dtype:
        quantize(pipe.unet, weights=unet_dtype)
        freeze(pipe.unet)
    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, num_inference_steps, batch_size=1):
    _ = pipe(
        prompt=PROMPT,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=batch_size,
        generator=torch.manual_seed(0),
    )


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_id",
        type=str,
        default="runwayml/stable-diffusion-v1-5",
        choices=["runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-xl-base-1.0"],
    )
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--torch_dtype", type=str, default="fp32", choices=list(TORCH_DTYPES.keys()))
    parser.add_argument("--unet_dtype", type=str, default=None, choices=list(UNET_DTYPES.keys()))
    args = parser.parse_args()

    pipeline = load_pipeline(
        args.ckpt_id, TORCH_DTYPES[args.torch_dtype], UNET_DTYPES[args.unet_dtype] if args.unet_dtype else None
    )

    for _ in range(WARM_UP_ITERS):
        run_inference(pipeline, args.num_inference_steps, args.batch_size)

    time = benchmark_fn(run_inference, pipeline, args.num_inference_steps, args.batch_size)
    memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
    print(
        f"ckpt: {args.ckpt_id} batch_size: {args.batch_size}, "
        f"torch_dtype: {args.torch_dtype}, unet_dtype: {args.unet_dtype} in {time} seconds and {memory} GBs."
    )

    # use the sanitized checkpoint id (no "/") in the filename so the save doesn't fail
    ckpt_id = args.ckpt_id.replace("/", "_")
    img_name = f"ckpt@{ckpt_id}-bs@{args.batch_size}-dtype@{args.torch_dtype}-unet_dtype@{args.unet_dtype}.png"
    image = pipeline(
        prompt=PROMPT,
        num_inference_steps=args.num_inference_steps,
        num_images_per_prompt=args.batch_size,
    ).images[0]
    image.save(img_name)
```

I am on quanto main.

Run it with:

```
python benchmark_reduced_precision_unet.py --torch_dtype=fp16 --unet_dtype=fp8
```

---

Has any progress been made on supporting fp8_e4m3fn in diffusers?

1 reply

---

sayakpaul Sep 28, 2025
Maintainer Author

Well, you can use layerwise casting as well as FP8 from torchao.
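
For example, a minimal layerwise-casting sketch (API names as of recent diffusers releases; double-check the docs for your version):

```python
# Sketch: store UNet weights in float8 and upcast them to the compute dtype
# on the fly inside each module's forward pass (diffusers layerwise casting).
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.unet.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float16
)

image = pipe("ghibli style, a fantasy landscape with castles").images[0]
```

For actual low-precision compute rather than just FP8 storage, torchao-based quantization can also be applied to the denoiser through diffusers' quantization-config support; see the diffusers docs for the exact options.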
