|
| 1 | +<!-- Copyright 2025 The HuggingFace Team. All rights reserved. |
| 2 | + |
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
| 4 | +the License. You may obtain a copy of the License at |
| 5 | + |
| 6 | +http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | + |
| 8 | +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
| 9 | +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
| 10 | +specific language governing permissions and limitations under the License. --> |
| 11 | + |
| 12 | +# NVIDIA ModelOpt |
| 13 | + |
| 14 | +[NVIDIA-ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a unified library of state-of-the-art model optimization techniques like quantization, pruning, distillation, speculative decoding, etc. It compresses deep learning models for downstream deployment frameworks like TensorRT-LLM or TensorRT to optimize inference speed. |
| 15 | + |
| 16 | +Before you begin, make sure you have nvidia_modelopt installed. |
| 17 | + |
| 18 | +```bash |
| 19 | +pip install -U "nvidia_modelopt[hf]" |
| 20 | +``` |
| 21 | + |
| 22 | +Quantize a model by passing [`NVIDIAModelOptConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. |
| 23 | + |
| 24 | +The example below only quantizes the weights to FP8. |
| 25 | + |
| 26 | +```python |
| 27 | +import torch |
| 28 | +from diffusers import AutoModel, SanaPipeline, NVIDIAModelOptConfig |
| 29 | + |
| 30 | +model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers" |
| 31 | +dtype = torch.bfloat16 |
| 32 | + |
| 33 | +quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt") |
| 34 | +transformer = AutoModel.from_pretrained( |
| 35 | + model_id, |
| 36 | + subfolder="transformer", |
| 37 | + quantization_config=quantization_config, |
| 38 | + torch_dtype=dtype, |
| 39 | +) |
| 40 | +pipe = SanaPipeline.from_pretrained( |
| 41 | + model_id, |
| 42 | + transformer=transformer, |
| 43 | + torch_dtype=dtype, |
| 44 | +) |
| 45 | +pipe.to("cuda") |
| 46 | + |
| 47 | +print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB") |
| 48 | + |
| 49 | +prompt = "A cat holding a sign that says hello world" |
| 50 | +image = pipe( |
| 51 | + prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 |
| 52 | +).images[0] |
| 53 | +image.save("output.png") |
| 54 | +``` |
| 55 | + |
| 56 | +> **Note:** |
| 57 | +> |
| 58 | +> The quantization methods in NVIDIA-ModelOpt are designed to reduce the memory footprint of model weights using various QAT (Quantization-Aware Training) and PTQ (Post-Training Quantization) techniques while maintaining model performance. However, the actual performance gain during inference depends on the deployment framework (e.g., TRT-LLM, TensorRT) and the specific hardware configuration. |
| 59 | +> |
| 60 | +> More details can be found [here](https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples). |
| 61 | + |
| 62 | +## NVIDIAModelOptConfig |
| 63 | + |
| 64 | +The `NVIDIAModelOptConfig` class accepts three parameters: |
| 65 | +- `quant_type`: A string value mentioning one of the quantization types below. |
| 66 | +- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`SD3Transformer2DModel`]'s pos_embed projection blocks, one would specify: `modules_to_not_convert=["pos_embed.proj.weight"]`. |
| 67 | +- `disable_conv_quantization`: A boolean value which when set to `True` disables quantization for all convolutional layers in the model. This is useful as channel and block quantization generally don't work well with convolutional layers (used with INT4, NF4, NVFP4). If you want to disable quantization for specific convolutional layers, use `modules_to_not_convert` instead. |
| 68 | +- `algorithm`: The algorithm to use for determining scale, defaults to `"max"`. You can check modelopt documentation for more algorithms and details. |
| 69 | +- `forward_loop`: The forward loop function to use for calibrating activation during quantization. If not provided, it relies on static scale values computed using the weights only. |
| 70 | +- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`. |
| 71 | + |
| 72 | +## Supported quantization types |
| 73 | + |
| 74 | +ModelOpt supports weight-only, channel and block quantization int8, fp8, int4, nf4, and nvfp4. The quantization methods are designed to reduce the memory footprint of the model weights while maintaining the performance of the model during inference. |
| 75 | + |
| 76 | +Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation. |
| 77 | + |
| 78 | +The quantization methods supported are as follows: |
| 79 | + |
| 80 | +| **Quantization Type** | **Supported Schemes** | **Required Kwargs** | **Additional Notes** | |
| 81 | +|-----------------------|-----------------------|---------------------|----------------------| |
| 82 | +| **INT8** | `int8 weight only`, `int8 channel quantization`, `int8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |
| 83 | +| **FP8** | `fp8 weight only`, `fp8 channel quantization`, `fp8 block quantization` | `quant_type`, `quant_type + channel_quantize`, `quant_type + channel_quantize + block_quantize` | |
| 84 | +| **INT4** | `int4 weight only`, `int4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now`| |
| 85 | +| **NF4** | `nf4 weight only`, `nf4 double block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize + scale_channel_quantize` + `scale_block_quantize` | `channel_quantize = -1 and scale_channel_quantize = -1 are only supported for now` | |
| 86 | +| **NVFP4** | `nvfp4 weight only`, `nvfp4 block quantization` | `quant_type`, `quant_type + channel_quantize + block_quantize` | `channel_quantize = -1 is only supported for now`| |
| 87 | + |
| 88 | + |
| 89 | +Refer to the [official modelopt documentation](https://nvidia.github.io/TensorRT-Model-Optimizer/) for a better understanding of the available quantization methods and the exhaustive list of configuration options available. |
| 90 | + |
| 91 | +## Serializing and Deserializing quantized models |
| 92 | + |
| 93 | +To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method. |
| 94 | + |
| 95 | +```python |
| 96 | +import torch |
| 97 | +from diffusers import AutoModel, NVIDIAModelOptConfig |
| 98 | +from modelopt.torch.opt import enable_huggingface_checkpointing |
| 99 | + |
| 100 | +enable_huggingface_checkpointing() |
| 101 | + |
| 102 | +model_id = "Efficient-Large-Model/Sana_600M_1024px_diffusers" |
| 103 | +quant_config_fp8 = {"quant_type": "FP8", "quant_method": "modelopt"} |
| 104 | +quant_config_fp8 = NVIDIAModelOptConfig(**quant_config_fp8) |
| 105 | +model = AutoModel.from_pretrained( |
| 106 | + model_id, |
| 107 | + subfolder="transformer", |
| 108 | + quantization_config=quant_config_fp8, |
| 109 | + torch_dtype=torch.bfloat16, |
| 110 | +) |
| 111 | +model.save_pretrained('path/to/sana_fp8', safe_serialization=False) |
| 112 | +``` |
| 113 | + |
| 114 | +To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method. |
| 115 | + |
| 116 | +```python |
| 117 | +import torch |
| 118 | +from diffusers import AutoModel, NVIDIAModelOptConfig, SanaPipeline |
| 119 | +from modelopt.torch.opt import enable_huggingface_checkpointing |
| 120 | + |
| 121 | +enable_huggingface_checkpointing() |
| 122 | + |
| 123 | +quantization_config = NVIDIAModelOptConfig(quant_type="FP8", quant_method="modelopt") |
| 124 | +transformer = AutoModel.from_pretrained( |
| 125 | + "path/to/sana_fp8", |
| 126 | + subfolder="transformer", |
| 127 | + quantization_config=quantization_config, |
| 128 | + torch_dtype=torch.bfloat16, |
| 129 | +) |
| 130 | +pipe = SanaPipeline.from_pretrained( |
| 131 | + "Efficient-Large-Model/Sana_600M_1024px_diffusers", |
| 132 | + transformer=transformer, |
| 133 | + torch_dtype=torch.bfloat16, |
| 134 | +) |
| 135 | +pipe.to("cuda") |
| 136 | +prompt = "A cat holding a sign that says hello world" |
| 137 | +image = pipe( |
| 138 | + prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512 |
| 139 | +).images[0] |
| 140 | +image.save("output.png") |
| 141 | +``` |
0 commit comments