[Quantization] Add TRT-ModelOpt as a Backend #11173


Merged
sayakpaul merged 57 commits into huggingface:main from ishan-modi:add-trtquant-backend
Sep 3, 2025

Conversation

Contributor

@ishan-modi ishan-modi commented Mar 30, 2025
edited

What does this PR do?

WIP, aimed at adding a new quantization backend (#11032). For now, this PR only supports on-the-fly quantization. Loading pre-quantized models errors out; this will be fixed by the NVIDIA team in the next release (early May).

Depends on

  • ~~this to support latest diffusers~~
  • ~~this to enable INT8 quantization~~
  • ~~this to enable NF4 quantization~~
Code
# !pip install "git+https://github.com/ishan-modi/diffusers.git@add-trtquant-backend#egg=diffusers[nvidia_modelopt]"
import torch
from tqdm import tqdm
from diffusers import SanaPipeline, SanaTransformer2DModel, StableDiffusion3Pipeline, SD3Transformer2DModel
from diffusers.quantizers.quantization_config import NVIDIAModelOptConfig

checkpoint = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
model_cls = SanaTransformer2DModel
pipe_cls = SanaPipeline
# checkpoint = "stabilityai/stable-diffusion-3-medium-diffusers"
# model_cls = SD3Transformer2DModel
# pipe_cls = StableDiffusion3Pipeline

input = {"prompt": "A capybara holding a sign that reads Hello World", "num_inference_steps": 28, "guidance_scale": 3.5}

quant_config_fp8 = {"quant_type": "FP8", "quant_method": "modelopt"}
quant_config_int8 = {"quant_type": "INT8", "quant_method": "modelopt"}
quant_config_int4 = {"quant_type": "INT4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1, "modules_to_not_convert": ["conv", "patch_embed"]}
quant_config_nf4 = {"quant_type": "NF4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1, "scale_block_quantize": 8, "scale_channel_quantize": -1, "modules_to_not_convert": ["conv"]}
quant_config_nvfp4 = {"quant_type": "NVFP4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1, "modules_to_not_convert": ["conv"]}

def test_quantization(config, checkpoint, model_cls):
    # Quantize the transformer on the fly while loading the checkpoint
    quant_config = NVIDIAModelOptConfig(**config)
    print(quant_config.get_config_from_quant_type())
    quant_model = model_cls.from_pretrained(checkpoint, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16).to('cuda')
    return quant_model

def test_quant_inference(model, input, pipe_cls, iter=1):
    # Run the full pipeline with the quantized transformer and report the average peak inference memory
    inference_memory = 0
    for _ in tqdm(range(iter)):
        with torch.no_grad():
            output = pipe_cls.from_pretrained(checkpoint, transformer=model, torch_dtype=torch.bfloat16).to('cuda')(**input).images[0]
        inference_memory += torch.cuda.max_memory_allocated()
    inference_memory /= iter
    output.save("test.png")
    print("Inference Memory: ", inference_memory / 1e6)
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

test_quant_inference(test_quantization(quant_config_fp8, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(test_quantization(quant_config_int8, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(test_quantization(quant_config_int4, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(test_quantization(quant_config_nf4, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(test_quantization(quant_config_nvfp4, checkpoint, model_cls), input, pipe_cls)
# test_quant_inference(model_cls.from_pretrained(checkpoint, subfolder="transformer", torch_dtype=torch.bfloat16).to('cuda'), input, pipe_cls)

There is a discussion with the NVIDIA team on speedups when using real_quant here.

@ishan-modi ishan-modi marked this pull request as draft March 30, 2025 07:45
Contributor Author

@sayakpaul, would you mind taking a quick look and giving suggestions?

Member

Thanks for getting started on this. I guess there is a problem here: NVIDIA/TensorRT-Model-Optimizer#165? Additionally, the API should have a TRTConfig instead of just a dict as the quantization config.

Contributor Author

I think the problem has been fixed in the newest release; I just need to bump the version in the diffusers requirements. Also, we can do the following to pass a config class:

from diffusers.quantizers.quantization_config import ModelOptConfig
quant_config = ModelOptConfig(quant_type="FP8_WO", modules_to_not_convert=["conv"])
model = SanaTransformer2DModel.from_pretrained(checkpoint, subfolder="transformer", quantization_config=quant_config...

By TRTConfig, did you mean including the config classes from ModelOptimizer here?

Member

We use names like BitsAndBytesConfig depending on the backend. See here:
https://github.com/huggingface/diffusers/blob/fb54499614f9603bfaa4c026202c5783841b3a80/src/diffusers/quantizers/quantization_config.py#L177C7-L177C25

So, in this case, we should be using TRTConfig or something similar.
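
For context, here is a small sketch of how the existing backend-named config is used today (the standard diffusers bitsandbytes path; the checkpoint and dtype are just reused from the PR description for illustration):

# Existing pattern: the quantization config class is named after the backend (bitsandbytes here).
import torch
from diffusers import BitsAndBytesConfig, SanaTransformer2DModel

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
model = SanaTransformer2DModel.from_pretrained(
    "Efficient-Large-Model/Sana_600M_1024px_diffusers",
    subfolder="transformer",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)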

Member

I think the problem has been fixed in the newest release; I just need to bump the version in the diffusers requirements.

Alright, let's try with the latest fixes then.


Contributor Author

The newer version wasn't backward compatible, hence the issues; I have fixed it.

Regarding naming, the package name is nvidia_modelopt, hence ModelOpt, but I can make it TRTModelOpt if you'd like?

Member

Doesn't it have any reliance on tensorrt?

Contributor Author

No, it doesn't; we can use TRT to compile the quantized model.

Member

No, it doesn't; we can use TRT to compile the quantized model.

Could you elaborate what you mean by this?

Member

@sayakpaul sayakpaul left a comment


This looks nice. Could you demonstrate some memory savings and any speedups when using modelopt, please? We can then add tests, docs, etc.

@sayakpaul sayakpaul requested a review from DN6 April 9, 2025 05:31
Contributor Author

Could you elaborate what you mean by this?

Yeah, so for quantizing the model we don't use TensorRT, but once the model is quantized we can compile it using TensorRT.
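
To make that concrete, a rough sketch of that workflow could look like the following. This is not from the PR; it assumes torch_tensorrt is installed and that its torch.compile backend can handle the quantized graph, and whether low-precision TensorRT kernels are actually picked up depends on the modelopt/TensorRT versions in use.

# Hedged sketch: quantize via the modelopt backend (as in the PR description), then try compiling with Torch-TensorRT.
# Assumes `torch_tensorrt` is installed; importing it registers a "torch_tensorrt" backend for torch.compile.
import torch
import torch_tensorrt  # noqa: F401

quant_model = test_quantization(quant_config_fp8, checkpoint, model_cls)  # helper from the PR description snippet
compiled_model = torch.compile(quant_model, backend="torch_tensorrt", dynamic=False)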

Contributor Author

ishan-modi commented Apr 23, 2025
edited

💾 Model & Inference Memory (in MB)

| Quantization Type | Sana Model Size | Sana Inference | SD3 Model Size | SD3 Inference | Flux Model Size | Flux Inference |
| --- | --- | --- | --- | --- | --- | --- |
| FP8 | 592.48 | 2060.80 | 2142.78 | 4405.61 | 13697.76 | 14438.87 |
| INT4 | 306.20 | 2072.30 | 1160.89 | 3443.98 | 8803.12 | 9559.62 |
| NVFP4 | 721.99 | Err | 1148.20 | Err | 8724.27 | Err |
| Original (BF16) | 1183.50 | 2642.37 | 4169.90 | 6457.41 | 23802.82 | - |

Sana = SanaTransformer2DModel, SD3 = SD3Transformer2DModel, Flux = FluxTransformer2DModel; all values in MB.

The following is the code:

import torch
from tqdm import tqdm
from diffusers import SanaTransformer2DModel, SD3Transformer2DModel, FluxTransformer2DModel
from diffusers.quantizers.quantization_config import NVIDIAModelOptConfig

checkpoint = "Efficient-Large-Model/Sana_600M_1024px_diffusers"
model_cls = SanaTransformer2DModel
# checkpoint = "stabilityai/stable-diffusion-3-medium-diffusers"
# model_cls = SD3Transformer2DModel
# checkpoint = "black-forest-labs/FLUX.1-dev"
# model_cls = FluxTransformer2DModel

# Random transformer inputs; uncomment the variant matching the selected model class above
input = lambda: (torch.randn((2, 32, 32, 32), dtype=torch.bfloat16).to('cuda'), torch.randn((2, 10, 300, 2304), dtype=torch.bfloat16).to('cuda'), torch.Tensor([0, 0]).to('cuda'))
# input = lambda: (torch.randn((1, 16, 96, 96), dtype=torch.bfloat16).to('cuda'), torch.randn((1, 300, 4096), dtype=torch.bfloat16).to('cuda'), torch.randn((1, 2048), dtype=torch.bfloat16).to('cuda'), torch.Tensor([0]).to('cuda'))
# input = lambda: (torch.randn((1, 1024, 64), dtype=torch.bfloat16).to('cuda'), torch.randn((1, 300, 4096), dtype=torch.bfloat16).to('cuda'), torch.randn((1, 768), dtype=torch.bfloat16).to('cuda'), torch.Tensor([0]).to('cuda'), torch.randn((300, 3)).to('cuda'), torch.randn((1024, 3)).to('cuda'), torch.Tensor([0]).to('cuda'))

quant_config_fp8 = {"quant_type": "FP8", "quant_method": "modelopt"}
quant_config_int4 = {"quant_type": "INT4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1}
quant_config_nvfp4 = {"quant_type": "NVFP4", "quant_method": "modelopt", "block_quantize": 128, "channel_quantize": -1, "modules_to_not_convert": ["conv"]}

def test_quantization(config, checkpoint, model_cls):
    # Quantize the transformer on the fly and report its memory footprint
    quant_config = NVIDIAModelOptConfig(**config)
    print(quant_config.get_config_from_quant_type())
    quant_model = model_cls.from_pretrained(checkpoint, subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16, device_map="balanced").to('cuda')
    print(f"Quant {config['quant_type']} Model Memory Footprint: ", quant_model.get_memory_footprint() / 1e6)
    return quant_model

def test_quant_inference(model, input, iter=10):
    # Average peak CUDA memory over `iter` forward passes on random inputs
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    inference_memory = 0
    for _ in tqdm(range(iter)):
        with torch.no_grad():
            output = model(*input())
        inference_memory += torch.cuda.max_memory_allocated()
    inference_memory /= iter
    print("Inference Memory: ", inference_memory / 1e6)

test_quant_inference(test_quantization(quant_config_fp8, checkpoint, model_cls), input)
# test_quant_inference(test_quantization(quant_config_int4, checkpoint, model_cls), input)
# test_quant_inference(test_quantization(quant_config_nvfp4, checkpoint, model_cls), input)
# test_quant_inference(model_cls.from_pretrained(checkpoint, subfolder="transformer", torch_dtype=torch.bfloat16).to('cuda'), input)

Speed Ups

There is no significant speedup across the different quantization types because internally modelopt still uses high-precision arithmetic (float32).
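
To make the latency claim concrete, here is a minimal timing harness (not part of the PR) that reuses `test_quantization` and `input` from the snippet above and simply averages wall-clock time over a few forward passes:

# Minimal latency check reusing the helpers from the snippet above (not part of the PR).
import time

def time_forward(model, n=5):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n):
        with torch.no_grad():
            model(*input())
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n

print("FP8 avg latency (s):", time_forward(test_quantization(quant_config_fp8, checkpoint, model_cls)))
print("BF16 avg latency (s):", time_forward(model_cls.from_pretrained(checkpoint, subfolder="transformer", torch_dtype=torch.bfloat16).to('cuda')))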

Sorry for being a bit late on this. @sayakpaul, let me know the next steps!


@ishan-modi ishan-modi marked this pull request as ready for review April 23, 2025 23:28
@ishan-modi ishan-modi changed the title from [WIP] Add TRT as a Backend to [Quantization] Add TRT as a Backend Apr 24, 2025
Member

@ishan-modi let us know if this is ready to be reviewed.

Contributor Author

ishan-modi commented Apr 25, 2025
edited

@sayakpaul, I think it is ready for a preliminary review; on-the-fly quantization works fine. But loading pre-quantized models errors out; this will be fixed by the NVIDIA team in the next release here (early May).

@jingyu-ml, just so that you are in the loop

Member

@sayakpaul sayakpaul left a comment
edited


Looking good so far!

~~Could you also demonstrate some memory and timing numbers with the modelopt toolkit and some visual results?~~

No need, just saw #11173 (comment). But it doesn't measure the inference memory, which is usually done via torch.cuda.max_memory_allocated(). Could we also see those numbers? Would it be possible to make it clear in the PR description that

on-the-fly quantization works fine. But loading pre-quantized models errors out and will be fixed in the next release NVIDIA/TensorRT-Model-Optimizer#185 (early May) by the NVIDIA team.

@jingyu-ml is it expected to not see any speedups in latency?
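
For reference, the peak-memory measurement mentioned above typically looks like this minimal sketch (model and inputs are placeholders):

# Minimal sketch of measuring peak inference memory via torch.cuda.max_memory_allocated().
import torch

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    _ = model(*inputs)  # `model` and `inputs` are placeholders for the quantized transformer and its dummy inputs
print("Peak inference memory (MB):", torch.cuda.max_memory_allocated() / 1e6)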

Member

@SunMarc SunMarc left a comment


Thanks for this! Just some nits; it could be nice to add this quantization scheme to transformers after this gets merged!

Member

@ishan-modi just a quick question. Do we know if the nunchaku SVDQuant method is supported through modelopt? From https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#quantization-examples-docs, it seems like it is supported. But could you confirm?

Contributor Author

@sayakpaul, yes, modelopt does support SVDQuant, but in this integration we support only min-max based calibration (see here). I think we should iteratively add advanced quantization methods like svd_quant and awq once we have the base going; let me know if you think otherwise.
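
For illustration, the min-max calibration flow in ModelOpt itself looks roughly like the following. This is a hedged sketch adapted from the TensorRT-Model-Optimizer examples, not the exact diffusers code path; `calibration_batches` is a placeholder for representative inputs.

# Hedged sketch of ModelOpt post-training quantization with min-max calibration,
# adapted from the TensorRT-Model-Optimizer examples; not the exact diffusers code path.
import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Run a few representative batches so the quantizers can collect min/max statistics.
    for batch in calibration_batches:  # placeholder iterable of sample inputs
        model(*batch)

model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)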

Member

That's fine. I asked because, if we can support svd_quant through our modelopt backend, I am happy to drop #12207. Hence I wanted to check.


Member

Will merge after @DN6 has had a chance to review. @ishan-modi can we also include a note in the docs that just performing the conversion step with modelopt won't lead to speed improvements (as pointed out here)?

@realAsma @jingyu-ml after this PR is merged, we could plan on writing a post/guide on how to take a modelopt-converted diffusers pipeline and use it in deployment settings to realize the actual speed gains.


Collaborator

@DN6 DN6 left a comment


Excellent work @ishan-modi 👍🏽 Thank you 🙏🏽

Member

@ishan-modi can we fix the remaining CI problems and then we should be good to go.


Contributor Author

@sayakpaul, should be fixed now.

@sayakpaul sayakpaul merged commit 4acbfbf into huggingface:main Sep 3, 2025
12 of 13 checks passed
Member

Congratulations on shipping this thing, @ishan-modi! Thank you!

Let's maybe now focus on the following things to maximize the potential impact:

  • SVDQuant Support
  • Guide to actually benefit from speedups

Happy to help.


@ishan-modi ishan-modi deleted the add-trtquant-backend branch September 3, 2025 05:33