-
Comfy and A1111 have been supporting Float8 for some time now:
A1111 reports quite nice improvements in VRAM consumption. Timing takes a hit because of the casting overhead, but that's okay in the interest of the reduced VRAM, IMO. So, I tried using `quanto`:

```python
import argparse

from quanto.quantize import quantize, freeze
import torch
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline

CKPT = "runwayml/stable-diffusion-v1-5"
NUM_INFERENCE_STEPS = 50
WARM_UP_ITERS = 5
PROMPT = "ghibli style, a fantasy landscape with castles"

TORCH_DTYPES = {"fp32": torch.float32, "fp16": torch.float16}
UNET_FP8_DTYPES = {"fp8_e4m3fn": torch.float8_e4m3fn, "fp8_e5m2": torch.float8_e5m2}


def load_pipeline(torch_dtype, unet_in_float8=None):
    pipe = DiffusionPipeline.from_pretrained(
        CKPT, torch_dtype=torch_dtype, use_safetensors=True
    ).to("cuda")
    if unet_in_float8:
        quantize(pipe.unet, weights=unet_in_float8)
        freeze(pipe.unet)
    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, batch_size=1):
    _ = pipe(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=batch_size,
        generator=torch.manual_seed(0),
    )


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--torch_dtype", type=str, default="fp32", choices=list(TORCH_DTYPES.keys()))
    parser.add_argument("--unet_in_float8", type=str, default=None, choices=list(UNET_FP8_DTYPES.keys()))
    args = parser.parse_args()

    pipeline = load_pipeline(
        TORCH_DTYPES[args.torch_dtype],
        UNET_FP8_DTYPES[args.unet_in_float8] if args.unet_in_float8 else None,
    )

    for _ in range(WARM_UP_ITERS):
        run_inference(pipeline, args.batch_size)

    time = benchmark_fn(run_inference, pipeline, args.batch_size)
    memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
    print(
        f"batch_size: {args.batch_size}, torch_dtype: {args.torch_dtype}, unet_in_float8: {args.unet_in_float8} in {time} seconds."
    )
    print(f"Memory: {memory}GB.")

    img_name = f"bs@{args.batch_size}-dtype@{args.torch_dtype}-unet_fp8@{args.unet_in_float8}.png"
    image = pipeline(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=args.batch_size,
    ).images[0]
    image.save(img_name)
```

Here are the stats and resultant images (batch size of 1):
As we can see, we're able to obtain a good amount of VRAM reduction here in comparison to FP16. Do we want to achieve that in `diffusers`?

Edit: int8 is even better: #7023 (comment). See also: huggingface/optimum-quanto#74.

Cc: @dacorvo.
-
@sayakpaul did you try to run inference just through the unet (i.e., skip the VAE in case it's using that much memory?)
-
As we can see in the code I provided, it's just the UNet; that's what the other repos also do.
But I guess I am misunderstanding something here.
-
@sayakpaul a couple of comments:
- float8 weights are actually less accurate than int8 weights, so you should try with int8 weights also,
- float8 activations are on the other hand quite efficient,
- regardless of the target weight quantization, quanto does a fake quantize by default to allow the weights to be tuned. This is why you don't see the VRAM decrease. To actually store float8/int8 weights you need to call `freeze(model)` (see the sketch below).
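To make that last point concrete, here is a minimal sketch assuming the `quanto` API used elsewhere in this thread (top-level `quantize`/`freeze`; in recent releases the package is `optimum.quanto`):

```python
import torch
from diffusers import DiffusionPipeline
from quanto import freeze, qint8, quantize  # `optimum.quanto` in recent releases

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Fake quantization: the weights keep their original storage dtype so they
# can still be tuned; VRAM usage does not drop yet.
quantize(pipe.unet, weights=qint8)

# freeze() replaces the fake-quantized weights with actually quantized
# tensors, which is when the VRAM savings materialize.
freeze(pipe.unet)
```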
-
I see. Thanks for explaining. However, I think we had kind of established here that using float8 might be a better choice due to its non-linear representation and wider range.
I can check the distribution of the UNet params and get back to you here.
-
This was for activations, not weights. Activations indeed have a non-linear distribution that is better captured by `float8`.
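For reference, a minimal sketch of what quantizing activations to float8 (with int8 weights) could look like, assuming quanto's `activations` argument and `Calibration` context manager behave as in its README; exact names may differ by version:

```python
import torch
from diffusers import DiffusionPipeline
from quanto import Calibration, freeze, qfloat8_e4m3fn, qint8, quantize

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# int8 weights (better accuracy) + float8 activations (better fit for their
# non-linear distribution), per the comments above.
quantize(pipe.unet, weights=qint8, activations=qfloat8_e4m3fn)

# Activation scales have to be calibrated on a few representative samples
# before freezing the weights.
with Calibration():
    pipe("ghibli style, a fantasy landscape with castles", num_inference_steps=2)

freeze(pipe.unet)
```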
-
Oh you crushed all my numbers.
Changing to `quantize(pipe.unet, weights=torch.int8)` yielded:

```
batch_size: 1, torch_dtype: fp16, unet: torch.int8 in 1.959 seconds.
Memory: 2.655GB.
```

No loss in quality either.
-
That said, if you don't see any loss for `float8`, and if the weights follow a linear distribution, this actually means that your weights might be compatible with a 4-bit encoding, i.e. `int4`.
With the latest version of `quanto`, you can choose `quanto.qint4` as a quantization target.
-
I am currently using the `qconv2d` branch. I will merge with `main` and try that too.
-
Cc: @younesbelkada for feedback as well (as he is our in-house ninja for working with reduced precision).
-
SD with batch size of 4:
-
SDXL with batch size of 1 (steps: 30)
-
Would someone explain why UNet int8 / UNet FP8 are slower than FP16?
-
@dingkwang Because GPUs usually have more hardware for dealing with FP32 and FP16 than for FP8, which cancels out the possible speed benefits of FP8 (like being able to fit more numbers in the cache). On top of that, with weight-only quantization the weights still have to be dequantized back to the compute dtype before each matmul, which adds casting overhead; a rough sketch of that overhead is below.
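To illustrate, here is a hypothetical per-channel int8 weight-only linear in plain PyTorch (not quanto's actual kernels, just the general pattern): the quantized weight is cast back to FP16 and rescaled on every call, and that extra work is what offsets the memory savings in terms of speed.

```python
import torch


def fp16_linear(x: torch.Tensor, w_fp16: torch.Tensor) -> torch.Tensor:
    # Plain half-precision matmul: runs directly on the GPU's FP16 units.
    return x @ w_fp16.t()


def weight_only_int8_linear(x: torch.Tensor, w_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # The weight is stored in int8, but the matmul still runs in FP16, so the
    # weight is dequantized (cast + rescale) on every forward pass.
    w_fp16 = w_int8.to(x.dtype) * scale
    return x @ w_fp16.t()


x = torch.randn(1, 320, dtype=torch.float16, device="cuda")
w = torch.randn(320, 320, dtype=torch.float16, device="cuda")

# Per-output-channel symmetric quantization of the weight.
scale = w.abs().amax(dim=1, keepdim=True) / 127
w_int8 = torch.clamp((w / scale).round(), -127, 127).to(torch.int8)

y_ref = fp16_linear(x, w)
y_q = weight_only_int8_linear(x, w_int8, scale)
print((y_ref - y_q).abs().max())  # small error, but the extra cast costs time
```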
-
@dacorvo I plotted the distribution of the UNet weights as well:
```python
from diffusers import UNet2DConditionModel
import matplotlib.pyplot as plt

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
).eval()

weights = []
for name, param in unet.named_parameters():
    if "weight" in name:
        weights.append(param.view(-1).cpu().detach().numpy())

plt.figure(figsize=(10, 6))
for i, weight in enumerate(weights):
    plt.hist(weight, bins=50, alpha=0.5, label=f'Layer {i+1}')
plt.xlabel('Weight values')
plt.ylabel('Frequency')
plt.title('Distribution of Weights in the Neural Network')
plt.savefig("sdxl_unet_weight_dist.png", bbox_inches="tight", dpi=300)
```
SDXL
SD v1.5
Weights seem to be concentrated around 0. Does this fit the bill for `quanto.qint4`?
-
Not so sure: I was not expecting such a high peak around zero. `int4` will result in extreme pruning here, I suspect.
So if it works, it would mean that these close-to-zero weights are actually not required. If so, this is a hint that the weights are actually sparse -> this opens the way to even smarter optimization techniques.
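As a hypothetical way to quantify that concern (not something from the thread), one could measure how many weights would round to exactly zero under a per-tensor symmetric int4 quantizer; quanto typically quantizes weights per axis or per group, so treat this only as a rough upper bound on the pruning effect:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
).eval()

total, zeroed = 0, 0
for name, param in unet.named_parameters():
    if "weight" not in name:
        continue
    w = param.detach().flatten()
    # Per-tensor symmetric int4 maps [-max, max] onto roughly [-7, 7], so the
    # quantization step is max/7; anything within half a step of 0 rounds to 0.
    step = w.abs().max() / 7
    zeroed += (w.abs() < step / 2).sum().item()
    total += w.numel()

print(f"{zeroed / total:.1%} of UNet weights would round to 0 under per-tensor int4")
```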
-
I am trying `int4`. However, it errors out on `from quanto import quantize, freeze`:

```
ImportError: cannot import name 'qbitsdtype' from 'quanto.tensor' (/home/sayak/quanto/quanto/tensor/__init__.py)
```
-
> I am trying `int4`. However, it errors out on `from quanto import quantize, freeze`:
> `ImportError: cannot import name 'qbitsdtype' from 'quanto.tensor' (/home/sayak/quanto/quanto/tensor/__init__.py)`

I rebased the branch. I did a refactoring and `qbitsdtype` is now `qtype`. Note that you also need to pass a `quanto.dtype` instead of a `torch.dtype` now.
-
Could you elaborate on that a bit? Specifically, what would the code changes look like for the code snippet, and also for using `int4`?
-
Just import `qint4`, `qint8`, ... from quanto and use them as the `weights` parameter in `quantize` instead of `torch.int8` / `torch.float8`.
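For instance, a minimal sketch of that change relative to the earlier snippet (same pipeline setup, only the `weights` argument differs):

```python
import torch
from diffusers import DiffusionPipeline
from quanto import freeze, qint4, qint8, quantize

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Before: quantize(pipe.unet, weights=torch.int8)
# After: pass a quanto dtype instead of a torch dtype.
quantize(pipe.unet, weights=qint8)  # or weights=qint4 for 4-bit
freeze(pipe.unet)
```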
-
Okay. So, this is my updated testing snippet now:
```python
import argparse

from quanto import quantize, freeze, qint4, qint8, qfloat8_e4m3fn
import torch
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline

CKPT = "stabilityai/stable-diffusion-xl-base-1.0"
NUM_INFERENCE_STEPS = 30
WARM_UP_ITERS = 5
PROMPT = "ghibli style, a fantasy landscape with castles"

TORCH_DTYPES = {"fp32": torch.float32, "fp16": torch.float16}
UNET_DTYPES = {"fp8": qfloat8_e4m3fn, "int8": qint8, "int4": qint4}


def load_pipeline(torch_dtype, unet_dtype=None):
    pipe = DiffusionPipeline.from_pretrained(CKPT, torch_dtype=torch_dtype, use_safetensors=True).to("cuda")
    if unet_dtype:
        quantize(pipe.unet, weights=unet_dtype)
        freeze(pipe.unet)
    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, batch_size=1):
    _ = pipe(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=batch_size,
        generator=torch.manual_seed(0),
    )


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--torch_dtype", type=str, default="fp32", choices=list(TORCH_DTYPES.keys()))
    parser.add_argument("--unet_dtype", type=str, default=None, choices=list(UNET_DTYPES.keys()))
    args = parser.parse_args()

    pipeline = load_pipeline(TORCH_DTYPES[args.torch_dtype], UNET_DTYPES[args.unet_dtype] if args.unet_dtype else None)

    for _ in range(WARM_UP_ITERS):
        run_inference(pipeline, args.batch_size)

    time = benchmark_fn(run_inference, pipeline, args.batch_size)
    memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
    print(
        f"batch_size: {args.batch_size}, torch_dtype: {args.torch_dtype}, unet_dtype: {args.unet_dtype} in {time} seconds."
    )
    print(f"Memory: {memory}GB.")

    img_name = f"bs@{args.batch_size}-dtype@{args.torch_dtype}-unet_dtype@{args.unet_dtype}.png"
    image = pipeline(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=args.batch_size,
    ).images[0]
    image.save(img_name)
```
It is leading to:
```
Traceback (most recent call last):
  File "/home/sayak/diffusers/benchmark_reduced_precision_unet.py", line 53, in <module>
    pipeline = load_pipeline(TORCH_DTYPES[args.torch_dtype], UNET_DTYPES[args.unet_dtype] if args.unet_dtype else None)
  File "/home/sayak/diffusers/benchmark_reduced_precision_unet.py", line 22, in load_pipeline
    freeze(pipe.unet)
  File "/home/sayak/quanto/quanto/quantize.py", line 40, in freeze
    m.freeze()
  File "/home/sayak/quanto/quanto/nn/qmodule.py", line 112, in freeze
    qweight = self.qweight()
  File "/home/sayak/quanto/quanto/nn/qlinear.py", line 44, in qweight
    raise ValueError(f"Invalid quantized weights type {self.weights}")
ValueError: Invalid quantized weights type quanto.qfloat8_e4m3fn
```
-
Hum yeah, for `QLinear` I limited it to integer types. Let me push a small modification.
-
> Hum yeah, for `QLinear` I limited it to integer types. Let me push a small modification.

It should be OK now.
-
And we're at:
```
batch_size: 1, torch_dtype: fp16, unet_dtype: int4 in 5.688 seconds.
Memory: 6.819GB.
```
Reference is here: #7023 (comment). Pretty nice memory savings :)
-
@dacorvo I am getting:
```
RuntimeError: Promotion for Float8 Types is not supported, attempted to promote Float8_e4m3fn and Half
```
-
Yes, I know, I got that too. I think this comes from one of my latest changes to fuse dequantization and matmul.
-
Should I wait for a fix?
I am working on a final script for the community to experiment with and also to publish the results as I gather them.
-
Can you share your config? I cannot reproduce atm.
-
Full code:
```python
import argparse

from quanto import quantize, freeze, qint4, qint8, qfloat8_e4m3fn
import torch
import torch.utils.benchmark as benchmark
from diffusers import DiffusionPipeline

WARM_UP_ITERS = 5
PROMPT = "ghibli style, a fantasy landscape with castles"

TORCH_DTYPES = {"fp32": torch.float32, "fp16": torch.float16}
UNET_DTYPES = {"fp8": qfloat8_e4m3fn, "int8": qint8, "int4": qint4}


def load_pipeline(ckpt_id, torch_dtype, unet_dtype=None):
    pipe = DiffusionPipeline.from_pretrained(ckpt_id, torch_dtype=torch_dtype).to("cuda")
    if unet_dtype:
        quantize(pipe.unet, weights=unet_dtype)
        freeze(pipe.unet)
    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, num_inference_steps, batch_size=1):
    _ = pipe(
        prompt=PROMPT,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=batch_size,
        generator=torch.manual_seed(0),
    )


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f})
    return f"{(t0.blocked_autorange().mean):.3f}"


def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ckpt_id",
        type=str,
        default="runwayml/stable-diffusion-v1-5",
        choices=["runwayml/stable-diffusion-v1-5", "stabilityai/stable-diffusion-xl-base-1.0"],
    )
    parser.add_argument("--num_inference_steps", type=int, default=50)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--torch_dtype", type=str, default="fp32", choices=list(TORCH_DTYPES.keys()))
    parser.add_argument("--unet_dtype", type=str, default=None, choices=list(UNET_DTYPES.keys()))
    args = parser.parse_args()

    pipeline = load_pipeline(
        args.ckpt_id, TORCH_DTYPES[args.torch_dtype], UNET_DTYPES[args.unet_dtype] if args.unet_dtype else None
    )

    for _ in range(WARM_UP_ITERS):
        run_inference(pipeline, args.num_inference_steps, args.batch_size)

    time = benchmark_fn(run_inference, pipeline, args.num_inference_steps, args.batch_size)
    memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
    print(
        f"ckpt: {args.ckpt_id} batch_size: {args.batch_size}, "
        f"torch_dtype: {args.torch_dtype}, unet_dtype: {args.unet_dtype} in {time} seconds and {memory} GBs."
    )

    # Use the sanitized checkpoint id so the filename has no slashes.
    ckpt_id = args.ckpt_id.replace("/", "_")
    img_name = f"ckpt@{ckpt_id}-bs@{args.batch_size}-dtype@{args.torch_dtype}-unet_dtype@{args.unet_dtype}.png"
    image = pipeline(
        prompt=PROMPT,
        num_inference_steps=args.num_inference_steps,
        num_images_per_prompt=args.batch_size,
    ).images[0]
    image.save(img_name)
```
I am on `quanto` `main`.
Run it with:

```
python benchmark_reduced_precision_unet.py --torch_dtype=fp16 --unet_dtype=fp8
```
-
Has any progress been made on supporting fp8_e4m3fn in diffusers?
-
Well, you can use layerwise casting as well as FP8 from torchao.
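A hedged sketch of both options (API names to the best of my knowledge; please double-check against the current diffusers and torchao docs): diffusers models expose `enable_layerwise_casting`, which stores the weights in FP8 and upcasts them layer by layer at compute time, and torchao's `quantize_` can apply FP8 weight-only quantization.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Option 1: layerwise casting -- weights stored in float8_e4m3fn and upcast
# to the compute dtype per layer during the forward pass.
pipe.unet.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float16
)

# Option 2: FP8 weight-only quantization via torchao (alternative to option 1).
# from torchao.quantization import quantize_, float8_weight_only
# quantize_(pipe.unet, float8_weight_only())
```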