From 8e1ea006f0975c5f5c98423c438b47c29981b370 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: 2025年8月21日 12:25:59 +0530
Subject: [PATCH 01/13] start nunchaku.

---
 .../quantizers/gguf/gguf_quantizer.py         |   2 +-
 .../quantizers/nunchaku/nunchaku_quantizer.py | 182 ++++++++++++++++++
 src/diffusers/quantizers/nunchaku/utils.py    |  77 ++++++++
 .../quantizers/quantization_config.py         |  38 ++++
 src/diffusers/utils/__init__.py               |   1 +
 src/diffusers/utils/import_utils.py           |  22 ++-
 src/diffusers/utils/torch_utils.py            |   4 +
 7 files changed, 324 insertions(+), 2 deletions(-)
 create mode 100644 src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
 create mode 100644 src/diffusers/quantizers/nunchaku/utils.py

diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py
index aa5ebf5711a3..b0f6c8a2a95a 100644
--- a/src/diffusers/quantizers/gguf/gguf_quantizer.py
+++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -90,7 +90,7 @@ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
     def check_if_quantized_param(
         self,
         model: "ModelMixin",
-        param_value: Union["GGUFParameter", "torch.Tensor"],
+        param_value: Union["torch.Tensor"],
         param_name: str,
         state_dict: Dict[str, Any],
         **kwargs,
diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
new file mode 100644
index 000000000000..d7a26307d9d0
--- /dev/null
+++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
@@ -0,0 +1,182 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from diffusers.utils.import_utils import is_nunchaku_version
+
+from ...utils import (
+    get_module_from_name,
+    is_accelerate_available,
+    is_nunchaku_available,
+    is_torch_available,
+    logging,
+)
+from ...utils.torch_utils import is_fp8_available
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+    from ...models.modeling_utils import ModelMixin
+
+
+if is_torch_available():
+    import torch
+
+if is_accelerate_available():
+    pass
+
+if is_nunchaku_available():
+    from .utils import replace_with_nunchaku_linear
+
+logger = logging.get_logger(__name__)
+
+
+class QuantoQuantizer(DiffusersQuantizer):
+    r"""
+    Diffusers Quantizer for Optimum Quanto
+    """
+
+    use_keep_in_fp32_modules = True
+    requires_calibration = False
+    required_packages = ["nunchaku", "accelerate"]
+
+    dtype_map = {"int4": torch.int8}
+    if is_fp8_available():
+        dtype_map = {"nvfp4": torch.float8_e4m3fn}
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+    def validate_environment(self, *args, **kwargs):
+        if not torch.cuda.is_available():
+            raise RuntimeError("No GPU found. A GPU is needed for nunchaku quantization.")
+
+        if not is_nunchaku_available():
+            raise ImportError(
+                "Loading a nunchaku quantized model requires the nunchaku library (follow https://nunchaku.tech/docs/nunchaku/installation/installation.html)"
+            )
+        if not is_nunchaku_version(">=", "0.3.1"):
+            raise ImportError(
+                "Loading a nunchaku quantized model requires `nunchaku>=0.3.1`. "
+                "Please upgrade your installation by following https://nunchaku.tech/docs/nunchaku/installation/installation.html."
+ ) + + if not is_accelerate_available(): + raise ImportError( + "Loading an nunchaku quantized model requires accelerate library (`pip install accelerate`)" + ) + + # TODO: check + # device_map = kwargs.get("device_map", None) + # if isinstance(device_map, dict) and len(device_map.keys())> 1: + # raise ValueError( + # "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend" + # ) + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + # Quanto imports diffusers internally. This is here to prevent circular imports + from nunchaku.models.linear import SVDQW4A4Linear + + module, tensor_name = get_module_from_name(model, param_name) + if self.pre_quantized and isinstance(module, SVDQW4A4Linear): + return True + + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + *args, + **kwargs, + ): + """ + Create a quantized parameter. + """ + from nunchaku.models.linear import SVDQW4A4Linear + + module, tensor_name = get_module_from_name(model, param_name) + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + if self.pre_quantized: + if tensor_name in module._parameters: + module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) + if tensor_name in module._buffers: + module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device)) + + elif isinstance(module, torch.nn.Linear): + if tensor_name in module._parameters: + module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) + if tensor_name in module._buffers: + module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device) + + new_module = SVDQW4A4Linear.from_linear(module) + setattr(model, param_name, new_module) + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + precision = self.quantization_config.precision + expected_target_dtypes = [torch.int8] + if is_fp8_available(): + expected_target_dtypes.append(torch.float8_e4m3fn) + if target_dtype not in expected_target_dtypes: + new_target_dtype = self.dtype_map[precision] + + logger.info(f"target_dtype {target_dtype} is replaced by {new_target_dtype} for `nunchaku` quantization") + return new_target_dtype + else: + raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.") + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + # We force the `dtype` to be bfloat16, this is a requirement from `bitsandbytes` + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to " + "requirements of `nunchaku` to enable model loading in 4-bit. 
" + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.bfloat16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.bfloat16 + return torch_dtype + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + # TODO: deal with `device_map` + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + model = replace_with_nunchaku_linear( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + pre_quantized=self.pre_quantized, + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model, **kwargs): + return model + + # @property + # def is_serializable(self): + # return True diff --git a/src/diffusers/quantizers/nunchaku/utils.py b/src/diffusers/quantizers/nunchaku/utils.py new file mode 100644 index 000000000000..abdcd186dcdd --- /dev/null +++ b/src/diffusers/quantizers/nunchaku/utils.py @@ -0,0 +1,77 @@ +import torch.nn as nn + +from ...utils import is_accelerate_available, is_nunchaku_available, logging + + +if is_accelerate_available(): + from accelerate import init_empty_weights + +if is_nunchaku_available(): + from nunchaku.models.linear import SVDQW4A4Linear + + +logger = logging.get_logger(__name__) + + +def _replace_with_nunchaku_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." 
in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + + if quantization_config.precision in ["int4", "nvfp4"]: + model._modules[name] = SVDQW4A4Linear( + in_features, + out_features, + rank=quantization_config.rank, + bias=module.bias is not None, + dtype=model.dtype, + ) + has_been_replaced = True + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children()))> 0: + _, has_been_replaced = _replace_with_nunchaku_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): + model, _ = _replace_with_nunchaku_linear(model, modules_to_not_convert, current_key_name, quantization_config) + + has_been_replaced = any( + isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules() + ) + if not has_been_replaced: + logger.warning( + "You are loading your model in the SVDQuant method but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 871faf076e5a..e83c9e72d4f0 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum): GGUF = "gguf" TORCHAO = "torchao" QUANTO = "quanto" + NUNCHAKU = "nunchaku" if is_torchao_available(): @@ -724,3 +725,40 @@ def post_init(self): accepted_weights = ["float8", "int8", "int4", "int2"] if self.weights_dtype not in accepted_weights: raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") + + +class NunchakuConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `nunchaku`. + + Args: + TODO + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have some + modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers). 
+ """ + + group_size_map = {"int4": 64, "nvfp4": 16} + + def __init__( + self, + precision: str = "int4", + rank: int = 32, + modules_to_not_convert: Optional[List[str]] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.NUNCHAKU + self.precision = precision + self.group_size = self.group_size_map[precision] + self.modules_to_not_convert = modules_to_not_convert + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + accpeted_precision = ["int4", "nvfp4"] + if self.precision not in accpeted_precision: + raise ValueError(f"Only supported precision in {accpeted_precision} but found {self.precision}") diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b27cf981edeb..12f948414c4d 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -89,6 +89,7 @@ is_matplotlib_available, is_nltk_available, is_note_seq_available, + is_nunchaku_available, is_onnx_available, is_opencv_available, is_optimum_quanto_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index ac209afb74a6..80fd2acbbf4b 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -217,6 +217,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _torchao_available, _torchao_version = _is_package_available("torchao") _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") _optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True) +_nunchaku_available, _nunchaku_version = _is_package_available("nunchaku", get_dist_name=True) _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface") _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity") _nltk_available, _nltk_version = _is_package_available("nltk") @@ -363,6 +364,10 @@ def is_optimum_quanto_available(): return _optimum_quanto_available +def is_nunchaku_available(): + return _nunchaku_available + + def is_timm_available(): return _timm_available @@ -816,7 +821,7 @@ def is_k_diffusion_version(operation: str, version: str): def is_optimum_quanto_version(operation: str, version: str): """ - Compares the current Accelerate version to a given reference with an operation. + Compares the current quanto version to a given reference with an operation. Args: operation (`str`): @@ -829,6 +834,21 @@ def is_optimum_quanto_version(operation: str, version: str): return compare_versions(parse(_optimum_quanto_version), operation, version) +def is_nunchaku_version(operation: str, version: str): + """ + Compares the current nunchaku version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _nunchaku_available: + return False + return compare_versions(parse(_nunchaku_version), operation, version) + + def is_xformers_version(operation: str, version: str): """ Compares the current xformers version to a given reference with an operation. 
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 5bc708a60c29..c1bf2aed7890 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -197,3 +197,7 @@ def device_synchronize(device_type: Optional[str] = None):
         device_type = get_device()
     device_mod = getattr(torch, device_type, torch.cuda)
     device_mod.synchronize()
+
+
+def is_fp8_available():
+    return getattr(torch, "float8_e4m3fn", None) is not None

From 7022169c135a166497e63d47d0b8c4c404d69a54 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: 2025年8月21日 12:27:16 +0530
Subject: [PATCH 02/13] up

---
 src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
index d7a26307d9d0..e80504f650be 100644
--- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
+++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
@@ -29,9 +29,9 @@
 logger = logging.get_logger(__name__)
 
 
-class QuantoQuantizer(DiffusersQuantizer):
+class NunChakuQuantizer(DiffusersQuantizer):
     r"""
-    Diffusers Quantizer for Optimum Quanto
+    Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku)
     """
 
     use_keep_in_fp32_modules = True
@@ -68,7 +68,7 @@ def validate_environment(self, *args, **kwargs):
         # TODO: check
         # device_map = kwargs.get("device_map", None)
         # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
         #     raise ValueError(
-        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the Quanto backend"
+        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the nunchaku backend"
         #     )
 
     def check_if_quantized_param(
@@ -79,7 +79,6 @@ def check_if_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        # Quanto imports diffusers internally. This is here to prevent circular imports
         from nunchaku.models.linear import SVDQW4A4Linear
 
         module, tensor_name = get_module_from_name(model, param_name)
@@ -140,7 +139,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
 
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
-            # We force the `dtype` to be bfloat16, this is a requirement from `bitsandbytes`
+            # We force the `dtype` to be bfloat16, this is a requirement from `nunchaku`
             logger.info(
                 "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
                 "requirements of `nunchaku` to enable model loading in 4-bit. 
" From f4262b8877d61521cae820f19e1cce06065abcfe Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 13:01:19 +0530 Subject: [PATCH 03/13] up --- src/diffusers/quantizers/auto.py | 4 ++++ src/diffusers/quantizers/nunchaku/__init__.py | 1 + 2 files changed, 5 insertions(+) create mode 100644 src/diffusers/quantizers/nunchaku/__init__.py diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index ce214ae7bc17..93691b3b53a2 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -21,9 +21,11 @@ from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer from .gguf import GGUFQuantizer +from .nunchaku import NunChakuQuantizer from .quantization_config import ( BitsAndBytesConfig, GGUFQuantizationConfig, + NunchakuConfig, QuantizationConfigMixin, QuantizationMethod, QuantoConfig, @@ -39,6 +41,7 @@ "gguf": GGUFQuantizer, "quanto": QuantoQuantizer, "torchao": TorchAoHfQuantizer, + "nunchaku": NunChakuQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -47,6 +50,7 @@ "gguf": GGUFQuantizationConfig, "quanto": QuantoConfig, "torchao": TorchAoConfig, + "nunchaku": NunchakuConfig, } diff --git a/src/diffusers/quantizers/nunchaku/__init__.py b/src/diffusers/quantizers/nunchaku/__init__.py new file mode 100644 index 000000000000..04b57f831a6b --- /dev/null +++ b/src/diffusers/quantizers/nunchaku/__init__.py @@ -0,0 +1 @@ +from .nunchaku_quantizer import NunChakuQuantizer From ac1aa8bbecd1efa2fdcb821567964cbb6f11caaa Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 13:05:12 +0530 Subject: [PATCH 04/13] up --- src/diffusers/__init__.py | 21 +++++++++++++++++++ src/diffusers/utils/dummy_nunchaku_objects.py | 17 +++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 src/diffusers/utils/dummy_nunchaku_objects.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 3f0f87b92609..7b1648c3dcce 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -13,6 +13,7 @@ is_k_diffusion_available, is_librosa_available, is_note_seq_available, + is_nunchaku_available, is_onnx_available, is_opencv_available, is_optimum_quanto_available, @@ -99,6 +100,18 @@ else: _import_structure["quantizers.quantization_config"].append("TorchAoConfig") +try: + if not is_torch_available() and not is_accelerate_available() and not is_nunchaku_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils import dummy_nunchaku_objects + + _import_structure["utils.dummy_nunchaku_objects"] = [ + name for name in dir(dummy_nunchaku_objects) if not name.startswith("_") + ] +else: + _import_structure["quantizers.quantization_config"].append("NunchakuConfig") + try: if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available(): raise OptionalDependencyNotAvailable() @@ -791,6 +804,14 @@ else: from .quantizers.quantization_config import QuantoConfig + try: + if not is_nunchaku_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from .utils.dummy_optimum_quanto_objects import * + else: + from .quantizers.quantization_config import NunchakuConfig + try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/utils/dummy_nunchaku_objects.py b/src/diffusers/utils/dummy_nunchaku_objects.py new file mode 100644 index 000000000000..2de7cd7c0a69 --- /dev/null +++ b/src/diffusers/utils/dummy_nunchaku_objects.py @@ -0,0 +1,17 @@ +# This file is 
autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class NunchakuConfig(metaclass=DummyObject): + _backends = ["nunchaku"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["nunchaku"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["nunchaku"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["nunchaku"]) From 269813fcc561694f738f63e6d22d1977697cc234 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 13:40:32 +0530 Subject: [PATCH 05/13] up --- src/diffusers/quantizers/auto.py | 4 ++-- src/diffusers/quantizers/nunchaku/__init__.py | 2 +- src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index 93691b3b53a2..910557865ea0 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -21,7 +21,7 @@ from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer from .gguf import GGUFQuantizer -from .nunchaku import NunChakuQuantizer +from .nunchaku import NunchakuQuantizer from .quantization_config import ( BitsAndBytesConfig, GGUFQuantizationConfig, @@ -41,7 +41,7 @@ "gguf": GGUFQuantizer, "quanto": QuantoQuantizer, "torchao": TorchAoHfQuantizer, - "nunchaku": NunChakuQuantizer, + "nunchaku": NunchakuQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { diff --git a/src/diffusers/quantizers/nunchaku/__init__.py b/src/diffusers/quantizers/nunchaku/__init__.py index 04b57f831a6b..759039054073 100644 --- a/src/diffusers/quantizers/nunchaku/__init__.py +++ b/src/diffusers/quantizers/nunchaku/__init__.py @@ -1 +1 @@ -from .nunchaku_quantizer import NunChakuQuantizer +from .nunchaku_quantizer import NunchakuQuantizer diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py index e80504f650be..2e13d78ba349 100644 --- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py +++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py @@ -29,7 +29,7 @@ logger = logging.get_logger(__name__) -class NunChakuQuantizer(DiffusersQuantizer): +class NunchakuQuantizer(DiffusersQuantizer): r""" Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku) """ From 9e0caa7afcd5d3a4c4c65cdfc0548acd6cf70041 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 14:02:54 +0530 Subject: [PATCH 06/13] up --- .../quantizers/nunchaku/nunchaku_quantizer.py | 11 +++++---- src/diffusers/quantizers/nunchaku/utils.py | 12 ++++++---- .../quantizers/quantization_config.py | 23 +++++++++++++++++++ 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py index 2e13d78ba349..113d9a4ba1d2 100644 --- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py +++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py @@ -169,13 +169,16 @@ def _process_model_before_weight_loading( model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config, - pre_quantized=self.pre_quantized, ) model.config.quantization_config = self.quantization_config def _process_model_after_weight_loading(self, model, **kwargs): return model - # @property - # def is_serializable(self): - # return True + @property + def is_serializable(self): + return False + + 
@property + def is_trainable(self): + return False diff --git a/src/diffusers/quantizers/nunchaku/utils.py b/src/diffusers/quantizers/nunchaku/utils.py index abdcd186dcdd..af36d7e6387f 100644 --- a/src/diffusers/quantizers/nunchaku/utils.py +++ b/src/diffusers/quantizers/nunchaku/utils.py @@ -5,9 +5,7 @@ if is_accelerate_available(): from accelerate import init_empty_weights - -if is_nunchaku_available(): - from nunchaku.models.linear import SVDQW4A4Linear + logger = logging.get_logger(__name__) @@ -15,6 +13,7 @@ def _replace_with_nunchaku_linear( model, + svdq_linear_cls, modules_to_not_convert=None, current_key_name=None, quantization_config=None, @@ -36,7 +35,7 @@ def _replace_with_nunchaku_linear( out_features = module.out_features if quantization_config.precision in ["int4", "nvfp4"]: - model._modules[name] = SVDQW4A4Linear( + model._modules[name] = svdq_linear_cls( in_features, out_features, rank=quantization_config.rank, @@ -62,7 +61,10 @@ def _replace_with_nunchaku_linear( def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): - model, _ = _replace_with_nunchaku_linear(model, modules_to_not_convert, current_key_name, quantization_config) + if is_nunchaku_available(): + from nunchaku.models.linear import SVDQW4A4Linear + + model, _ = _replace_with_nunchaku_linear(model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config) has_been_replaced = any( isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules() diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index e83c9e72d4f0..057a5aec7151 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -762,3 +762,26 @@ def post_init(self): accpeted_precision = ["int4", "nvfp4"] if self.precision not in accpeted_precision: raise ValueError(f"Only supported precision in {accpeted_precision} but found {self.precision}") + + # Copied from diffusers.quantizers.bitsandbytes.quantization_config.BitsandBytesConfig.to_diff_dict with BitsandBytesConfig->NunchakuConfig + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. 
+ + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = NunchakuConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict From 5d08150a2e8ba092d26e402361d10034503c186c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 14:39:58 +0530 Subject: [PATCH 07/13] up --- src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py | 5 +++++ src/diffusers/quantizers/nunchaku/utils.py | 8 +++++--- src/diffusers/quantizers/quantization_config.py | 3 +++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py index 113d9a4ba1d2..c61ac2692892 100644 --- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py +++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py @@ -164,6 +164,11 @@ def _process_model_before_weight_loading( self.modules_to_not_convert = [self.modules_to_not_convert] self.modules_to_not_convert.extend(keep_in_fp32_modules) + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] model = replace_with_nunchaku_linear( model, diff --git a/src/diffusers/quantizers/nunchaku/utils.py b/src/diffusers/quantizers/nunchaku/utils.py index af36d7e6387f..b8b015e6eef8 100644 --- a/src/diffusers/quantizers/nunchaku/utils.py +++ b/src/diffusers/quantizers/nunchaku/utils.py @@ -5,7 +5,6 @@ if is_accelerate_available(): from accelerate import init_empty_weights - logger = logging.get_logger(__name__) @@ -40,7 +39,7 @@ def _replace_with_nunchaku_linear( out_features, rank=quantization_config.rank, bias=module.bias is not None, - dtype=model.dtype, + torch_dtype=module.weight.dtype, ) has_been_replaced = True # Store the module class in case we need to transpose the weight later @@ -50,6 +49,7 @@ def _replace_with_nunchaku_linear( if len(list(module.children()))> 0: _, has_been_replaced = _replace_with_nunchaku_linear( module, + svdq_linear_cls, modules_to_not_convert, current_key_name, quantization_config, @@ -64,7 +64,9 @@ def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key if is_nunchaku_available(): from nunchaku.models.linear import SVDQW4A4Linear - model, _ = _replace_with_nunchaku_linear(model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config) + model, _ = _replace_with_nunchaku_linear( + model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config + ) has_been_replaced = any( isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules() diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 057a5aec7151..acde80879ac7 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -750,6 +750,7 @@ def __init__( ): self.quant_method = QuantizationMethod.NUNCHAKU self.precision = precision + self.rank 
= rank self.group_size = self.group_size_map[precision] self.modules_to_not_convert = modules_to_not_convert @@ -763,6 +764,8 @@ def post_init(self): if self.precision not in accpeted_precision: raise ValueError(f"Only supported precision in {accpeted_precision} but found {self.precision}") + # TODO: should there be a check for rank? + # Copied from diffusers.quantizers.bitsandbytes.quantization_config.BitsandBytesConfig.to_diff_dict with BitsandBytesConfig->NunchakuConfig def to_diff_dict(self) -> Dict[str, Any]: """ From d35e77ece0706701965bc336817e074b8a3ad1c1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 14:48:08 +0530 Subject: [PATCH 08/13] up --- src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py index c61ac2692892..7cee0bccc32d 100644 --- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py +++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py @@ -38,12 +38,12 @@ class NunchakuQuantizer(DiffusersQuantizer): requires_calibration = False required_packages = ["nunchaku", "accelerate"] - dtype_map = {"int4": torch.int8} - if is_fp8_available(): - dtype_map = {"nvfp4": torch.float8_e4m3fn} - def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) + dtype_map = {"int4": torch.int8} + if is_fp8_available(): + dtype_map = {"nvfp4": torch.float8_e4m3fn} + self.dtype_map = dtype_map def validate_environment(self, *args, **kwargs): if not torch.cuda.is_available(): From 2a827ec19f432ccc4035a1d6a63dee636bcc918d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 14:50:57 +0530 Subject: [PATCH 09/13] up --- src/diffusers/quantizers/quantization_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index acde80879ac7..919267f59c3a 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -766,7 +766,7 @@ def post_init(self): # TODO: should there be a check for rank? 
- # Copied from diffusers.quantizers.bitsandbytes.quantization_config.BitsandBytesConfig.to_diff_dict with BitsandBytesConfig->NunchakuConfig + # Copied from diffusers.quantizers.quantization_config.BitsAndBytesConfig.to_diff_dict with BitsAndBytesConfig->NunchakuConfig def to_diff_dict(self) -> Dict[str, Any]: """ Removes all attributes from config which correspond to the default config attributes for better readability and From df58c8017e77833be570b3377906b3de2e3ac1f7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月21日 17:08:29 +0530 Subject: [PATCH 10/13] up --- .../quantizers/nunchaku/nunchaku_quantizer.py | 66 +++++++++------ src/diffusers/quantizers/nunchaku/utils.py | 81 ------------------- 2 files changed, 41 insertions(+), 106 deletions(-) delete mode 100644 src/diffusers/quantizers/nunchaku/utils.py diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py index 7cee0bccc32d..dbb68a68e863 100644 --- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py +++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py @@ -20,11 +20,6 @@ if is_torch_available(): import torch -if is_accelerate_available(): - pass - -if is_nunchaku_available(): - from .utils import replace_with_nunchaku_linear logger = logging.get_logger(__name__) @@ -79,13 +74,14 @@ def check_if_quantized_param( state_dict: Dict[str, Any], **kwargs, ): - from nunchaku.models.linear import SVDQW4A4Linear - - module, tensor_name = get_module_from_name(model, param_name) - if self.pre_quantized and isinstance(module, SVDQW4A4Linear): - return True - - return False + # TODO: revisit + # Check if the param_name is not in self.modules_to_not_convert + if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): + return False + else: + # We only quantize the weight of nn.Linear + module, _ = get_module_from_name(model, param_name) + return isinstance(module, torch.nn.Linear) def create_quantized_param( self, @@ -112,13 +108,32 @@ def create_quantized_param( module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device)) elif isinstance(module, torch.nn.Linear): - if tensor_name in module._parameters: - module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) - if tensor_name in module._buffers: - module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device) - - new_module = SVDQW4A4Linear.from_linear(module) - setattr(model, param_name, new_module) + # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module. + # But we need to have a utility that can take a pretrained param value and quantize it. Not sure + # how to do that yet. + # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better + # way to do it? + is_param = tensor_name in module._parameters + is_buffer = tensor_name in module._buffers + new_module = SVDQW4A4Linear.from_linear( + module, precision=self.quantization_config.precision, rank=self.quantization_config.rank + ) + module_name = ".".join(param_name.split(".")[:-1]) + if "." in module_name: + parent_name, leaf = module_name.rsplit(".", 1) + parent = model.get_submodule(parent_name) + else: + parent, leaf = model, module_name + + # rebind + # this will result into + # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'. 
+ if is_param: + new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) + elif is_buffer: + new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) + + setattr(parent, leaf, new_module) def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: max_memory = {key: val * 0.90 for key, val in max_memory.items()} @@ -157,24 +172,25 @@ def _process_model_before_weight_loading( keep_in_fp32_modules: List[str] = [], **kwargs, ): - # TODO: deal with `device_map` self.modules_to_not_convert = self.quantization_config.modules_to_not_convert if not isinstance(self.modules_to_not_convert, list): self.modules_to_not_convert = [self.modules_to_not_convert] self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # TODO: revisit + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + # if isinstance(device_map, dict) and len(device_map.keys())> 1: + # keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + # self.modules_to_not_convert.extend(keys_on_cpu) + # Purge `None`. # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 # in case of diffusion transformer models. For language models and others alike, `lm_head` # and tied modules are usually kept in FP32. self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] - model = replace_with_nunchaku_linear( - model, - modules_to_not_convert=self.modules_to_not_convert, - quantization_config=self.quantization_config, - ) model.config.quantization_config = self.quantization_config def _process_model_after_weight_loading(self, model, **kwargs): diff --git a/src/diffusers/quantizers/nunchaku/utils.py b/src/diffusers/quantizers/nunchaku/utils.py deleted file mode 100644 index b8b015e6eef8..000000000000 --- a/src/diffusers/quantizers/nunchaku/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import torch.nn as nn - -from ...utils import is_accelerate_available, is_nunchaku_available, logging - - -if is_accelerate_available(): - from accelerate import init_empty_weights - - -logger = logging.get_logger(__name__) - - -def _replace_with_nunchaku_linear( - model, - svdq_linear_cls, - modules_to_not_convert=None, - current_key_name=None, - quantization_config=None, - has_been_replaced=False, -): - for name, module in model.named_children(): - if current_key_name is None: - current_key_name = [] - current_key_name.append(name) - - if isinstance(module, nn.Linear) and name not in modules_to_not_convert: - # Check if the current key is not in the `modules_to_not_convert` - current_key_name_str = ".".join(current_key_name) - if not any( - (key + "." 
in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert - ): - with init_empty_weights(): - in_features = module.in_features - out_features = module.out_features - - if quantization_config.precision in ["int4", "nvfp4"]: - model._modules[name] = svdq_linear_cls( - in_features, - out_features, - rank=quantization_config.rank, - bias=module.bias is not None, - torch_dtype=module.weight.dtype, - ) - has_been_replaced = True - # Store the module class in case we need to transpose the weight later - model._modules[name].source_cls = type(module) - # Force requires grad to False to avoid unexpected errors - model._modules[name].requires_grad_(False) - if len(list(module.children()))> 0: - _, has_been_replaced = _replace_with_nunchaku_linear( - module, - svdq_linear_cls, - modules_to_not_convert, - current_key_name, - quantization_config, - has_been_replaced=has_been_replaced, - ) - # Remove the last key for recursion - current_key_name.pop(-1) - return model, has_been_replaced - - -def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): - if is_nunchaku_available(): - from nunchaku.models.linear import SVDQW4A4Linear - - model, _ = _replace_with_nunchaku_linear( - model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config - ) - - has_been_replaced = any( - isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules() - ) - if not has_been_replaced: - logger.warning( - "You are loading your model in the SVDQuant method but no linear modules were found in your model." - " Please double check your model architecture, or submit an issue on github if you think this is" - " a bug." - ) - - return model From 4af534bde942a5baf42d82871752cd74da58b0c1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月22日 10:36:27 +0530 Subject: [PATCH 11/13] up --- src/diffusers/loaders/single_file_model.py | 34 ++++++- src/diffusers/models/model_loading_utils.py | 7 ++ .../quantizers/nunchaku/nunchaku_quantizer.py | 98 +++++++++---------- src/diffusers/quantizers/nunchaku/utils.py | 80 +++++++++++++++ .../quantizers/quantization_config.py | 47 ++++++--- 5 files changed, 200 insertions(+), 66 deletions(-) create mode 100644 src/diffusers/quantizers/nunchaku/utils.py diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 16bd0441072a..0b77f558f19a 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -190,6 +190,23 @@ def _get_mapping_function_kwargs(mapping_fn, **kwargs): return mapping_kwargs +def _maybe_determine_modules_to_not_convert(quantization_config, state_dict): + if quantization_config is None: + return None + else: + is_nunchaku = quantization_config.quant_method == "nunchaku" + if not is_nunchaku: + return None + else: + no_qweight = set() + for key in state_dict: + if key.endswith(".weight"): + # module name is everything except the last piece after "." + module_name = ".".join(key.split(".")[:-1]) + no_qweight.add(module_name) + return sorted(no_qweight) + + class FromOriginalModelMixin: """ Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model. @@ -324,6 +341,18 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = user_agent=user_agent, ) if quantization_config is not None: + # For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`. 
+ original_modules_to_not_convert = quantization_config.modules_to_not_convert or [] + determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert( + quantization_config, checkpoint + ) + if determined_modules_to_not_convert: + original_modules_to_not_convert.extend(determined_modules_to_not_convert) + original_modules_to_not_convert = list(set(original_modules_to_not_convert)) + logger.info( + f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {original_modules_to_not_convert}." + ) + quantization_config.modules_to_not_convert = original_modules_to_not_convert hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) hf_quantizer.validate_environment() torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) @@ -404,8 +433,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = model = cls.from_config(diffusers_model_config) checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs) - - if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint): + if not ( + quantization_config is not None and quantization_config.quant_method == "nunchaku" + ) and _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint): diffusers_format_checkpoint = checkpoint_mapping_fn( config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs ) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 8b48ba6b4873..0fed2d4f483f 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -297,6 +297,13 @@ def load_model_dict_into_meta( offload_index = offload_weight(param, param_name, offload_folder, offload_index) elif param_device == "cpu" and state_dict_index is not None: state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) + # This check below might be a bit counter-intuitive in nature. This is because we're checking if the param + # or its module is quantized and if so, we're proceeding with creating a quantized param. This is because + # of the way pre-trained models are loaded. They're initialized under "meta" device, where + # quantization layers are first injected. Hence, for a model that is either pre-quantized or supplemented + # with a `quantization_config` during `from_pretrained`, we expect `check_if_quantized_param` to return True. + # Then depending on the quantization backend being used, we run the actual quantization step under + # `create_quantized_param`. 
elif is_quantized and ( hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device) ): diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py index dbb68a68e863..b9e2bb23bceb 100644 --- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py +++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py @@ -20,6 +20,11 @@ if is_torch_available(): import torch +if is_accelerate_available(): + pass + +if is_nunchaku_available(): + from .utils import replace_with_nunchaku_linear logger = logging.get_logger(__name__) @@ -35,10 +40,6 @@ class NunchakuQuantizer(DiffusersQuantizer): def __init__(self, quantization_config, **kwargs): super().__init__(quantization_config, **kwargs) - dtype_map = {"int4": torch.int8} - if is_fp8_available(): - dtype_map = {"nvfp4": torch.float8_e4m3fn} - self.dtype_map = dtype_map def validate_environment(self, *args, **kwargs): if not torch.cuda.is_available(): @@ -74,14 +75,13 @@ def check_if_quantized_param( state_dict: Dict[str, Any], **kwargs, ): - # TODO: revisit - # Check if the param_name is not in self.modules_to_not_convert - if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert): - return False - else: - # We only quantize the weight of nn.Linear - module, _ = get_module_from_name(model, param_name) - return isinstance(module, torch.nn.Linear) + from nunchaku.models.linear import SVDQW4A4Linear + + module, _ = get_module_from_name(model, param_name) + if self.pre_quantized and isinstance(module, SVDQW4A4Linear): + return True + + return False def create_quantized_param( self, @@ -98,42 +98,33 @@ def create_quantized_param( from nunchaku.models.linear import SVDQW4A4Linear module, tensor_name = get_module_from_name(model, param_name) + state_dict = args[0] if tensor_name not in module._parameters and tensor_name not in module._buffers: raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") - if self.pre_quantized: - if tensor_name in module._parameters: - module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) - if tensor_name in module._buffers: - module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device)) - - elif isinstance(module, torch.nn.Linear): - # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module. - # But we need to have a utility that can take a pretrained param value and quantize it. Not sure - # how to do that yet. - # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better - # way to do it? - is_param = tensor_name in module._parameters - is_buffer = tensor_name in module._buffers - new_module = SVDQW4A4Linear.from_linear( - module, precision=self.quantization_config.precision, rank=self.quantization_config.rank - ) - module_name = ".".join(param_name.split(".")[:-1]) - if "." 
in module_name: - parent_name, leaf = module_name.rsplit(".", 1) - parent = model.get_submodule(parent_name) + if isinstance(module, SVDQW4A4Linear): + if param_value.ndim == 1: + module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to( + target_device + ) + elif tensor_name == "qweight": + module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to( + target_device + ) + # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors + for t in ["lora_up", "lora_down"]: + # need to check at the state dict level for this + new_tensor_name = param_name.replace(".qweight", f".{t}") + if new_tensor_name not in state_dict: + oc, ic = param_value.shape + ic = ic * 2 # v is packed into INT8, so we need to double the size + module._parameters[t] = torch.zeros( + (0, ic) if t == "lora_down" else (oc, 0), device=param_value.device, dtype=torch.bfloat16 + ) else: - parent, leaf = model, module_name - - # rebind - # this will result into - # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'. - if is_param: - new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) - elif is_buffer: - new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) - - setattr(parent, leaf, new_module) + module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to( + target_device + ) def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: max_memory = {key: val * 0.90 for key, val in max_memory.items()} @@ -173,24 +164,25 @@ def _process_model_before_weight_loading( **kwargs, ): self.modules_to_not_convert = self.quantization_config.modules_to_not_convert - if not isinstance(self.modules_to_not_convert, list): self.modules_to_not_convert = [self.modules_to_not_convert] - self.modules_to_not_convert.extend(keep_in_fp32_modules) - - # TODO: revisit - # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` - # if isinstance(device_map, dict) and len(device_map.keys())> 1: - # keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] - # self.modules_to_not_convert.extend(keys_on_cpu) - # Purge `None`. # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 # in case of diffusion transformer models. For language models and others alike, `lm_head` # and tied modules are usually kept in FP32. 
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys())> 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + self.modules_to_not_convert.extend(keys_on_cpu) + + model = replace_with_nunchaku_linear( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + ) model.config.quantization_config = self.quantization_config def _process_model_after_weight_loading(self, model, **kwargs): diff --git a/src/diffusers/quantizers/nunchaku/utils.py b/src/diffusers/quantizers/nunchaku/utils.py new file mode 100644 index 000000000000..4d7eb1a021ea --- /dev/null +++ b/src/diffusers/quantizers/nunchaku/utils.py @@ -0,0 +1,80 @@ +import torch.nn as nn + +from ...utils import is_accelerate_available, is_nunchaku_available, logging + + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +logger = logging.get_logger(__name__) + + +def _replace_with_nunchaku_linear( + model, + svdq_linear_cls, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + + model._modules[name] = svdq_linear_cls( + in_features, + out_features, + rank=quantization_config.rank, + bias=module.bias is not None, + torch_dtype=module.weight.dtype, + ) + has_been_replaced = True + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children()))> 0: + _, has_been_replaced = _replace_with_nunchaku_linear( + module, + svdq_linear_cls, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): + if is_nunchaku_available(): + from nunchaku.models.linear import SVDQW4A4Linear + + model, _ = _replace_with_nunchaku_linear( + model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config + ) + + has_been_replaced = any( + isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules() + ) + if not has_been_replaced: + logger.warning( + "You are loading your model in the SVDQuant method but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." 
+ ) + + return model diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 919267f59c3a..c82b1e4b24e4 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -733,36 +733,61 @@ class NunchakuConfig(QuantizationConfigMixin): loaded using `nunchaku`. Args: - TODO + TODO modules_to_not_convert (`list`, *optional*, default to `None`): The list of modules to not quantize, useful for quantizing models that explicitly require to have some - modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers). + modules left in their original precision (e.g. `norm` layers in Qwen-Image). """ - group_size_map = {"int4": 64, "nvfp4": 16} - def __init__( self, - precision: str = "int4", + method: str = "svdquant", + weight_dtype: str = "int4", + weight_scale_dtype: str = None, + weight_group_size: int = 64, + activation_dtype: str = "int4", + activation_scale_dtype: str = None, + activation_group_size: int = 64, rank: int = 32, modules_to_not_convert: Optional[List[str]] = None, **kwargs, ): self.quant_method = QuantizationMethod.NUNCHAKU - self.precision = precision + self.method = method + self.weight_dtype = weight_dtype + self.weight_scale_dtype = weight_scale_dtype + self.weight_group_size = weight_group_size + self.activation_dtype = activation_dtype + self.activation_scale_dtype = activation_scale_dtype + self.activation_group_size = activation_group_size self.rank = rank - self.group_size = self.group_size_map[precision] self.modules_to_not_convert = modules_to_not_convert self.post_init() def post_init(self): r""" - Safety checker that arguments are correct + Safety checker that arguments are correct. Hardware checks were largely adapted from the official `nunchaku` + codebase. """ - accpeted_precision = ["int4", "nvfp4"] - if self.precision not in accpeted_precision: - raise ValueError(f"Only supported precision in {accpeted_precision} but found {self.precision}") + from ..utils.torch_utils import get_device + + device = get_device() + if isinstance(device, str): + device = torch.device(device) + capability = torch.cuda.get_device_capability(0 if device.index is None else device.index) + sm = f"{capability[0]}{capability[1]}" + if sm == "120": # you can only use the fp4 models + if self.weight_dtype != "fp4_e2m1_all": + raise ValueError('Please use "fp4" quantization for Blackwell GPUs.') + elif sm in ["75", "80", "86", "89"]: + if self.weight_dtype != "int4": + raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs.') + else: + raise ValueError( + f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. " + "Please use a Turing, Ampere, Ada or Blackwell GPU for this quantization configuration." + ) # TODO: should there be a check for rank? 
From 8e07445540e6c3c5868a3837dc1915ac7a7d9af7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月22日 19:21:56 +0530 Subject: [PATCH 12/13] up --- src/diffusers/loaders/single_file_model.py | 47 +++++--- src/diffusers/loaders/single_file_utils.py | 95 ++++++++++++++++ .../loaders/single_file_utils_nunchaku.py | 102 ++++++++++++++++++ .../quantizers/nunchaku/nunchaku_quantizer.py | 43 ++------ 4 files changed, 239 insertions(+), 48 deletions(-) create mode 100644 src/diffusers/loaders/single_file_utils_nunchaku.py diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 0b77f558f19a..88f95a01a646 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -42,6 +42,7 @@ convert_ltx_vae_checkpoint_to_diffusers, convert_lumina2_to_diffusers, convert_mochi_transformer_checkpoint_to_diffusers, + convert_nunchaku_flux_to_diffusers, convert_sana_transformer_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, @@ -341,18 +342,6 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = user_agent=user_agent, ) if quantization_config is not None: - # For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`. - original_modules_to_not_convert = quantization_config.modules_to_not_convert or [] - determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert( - quantization_config, checkpoint - ) - if determined_modules_to_not_convert: - original_modules_to_not_convert.extend(determined_modules_to_not_convert) - original_modules_to_not_convert = list(set(original_modules_to_not_convert)) - logger.info( - f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {original_modules_to_not_convert}." - ) - quantization_config.modules_to_not_convert = original_modules_to_not_convert hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) hf_quantizer.validate_environment() torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) @@ -433,9 +422,14 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = model = cls.from_config(diffusers_model_config) checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs) - if not ( - quantization_config is not None and quantization_config.quant_method == "nunchaku" - ) and _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint): + model_state_dict = model.state_dict() + # TODO: Only flux nunchaku checkpoint for now. Unify with how checkpoint mappers are done. + # For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`. + if quantization_config is not None and quantization_config.quant_method == "nunchaku": + diffusers_format_checkpoint = convert_nunchaku_flux_to_diffusers( + checkpoint, model_state_dict=model_state_dict + ) + elif _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint): diffusers_format_checkpoint = checkpoint_mapping_fn( config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs ) @@ -446,6 +440,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = raise SingleFileComponentError( f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint." 
) + + # This step is better off here than above because `diffusers_format_checkpoint` holds the keys we expect. + if quantization_config is not None: + original_modules_to_not_convert = quantization_config.modules_to_not_convert or [] + determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert( + quantization_config, checkpoint + ) + if determined_modules_to_not_convert: + determined_modules_to_not_convert.extend(original_modules_to_not_convert) + determined_modules_to_not_convert = list(set(determined_modules_to_not_convert)) + logger.info( + f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {determined_modules_to_not_convert}." + ) + quantization_config.modules_to_not_convert = original_modules_to_not_convert + # Update the `quant_config`. + hf_quantizer.quantization_config = quantization_config + # Check if `_keep_in_fp32_modules` is not None use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") @@ -473,6 +484,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = unexpected_keys = [ param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict ] + for k in unexpected_keys: + if "single_transformer_blocks.0" in k: + print(f"Unexpected {k=}") + for k in empty_state_dict: + if "single_transformer_blocks.0" in k: + print(f"model {k=}") device_map = {"": param_device} load_model_dict_into_meta( model, diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index ef6c41e3ce97..722d38d4f5bb 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -2189,6 +2189,101 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): return converted_state_dict +# Adapted from https://github.com/nunchaku-tech/nunchaku/blob/3ec299f439f9986a69ded320798cab4e258c871d/nunchaku/models/transformers/transformer_flux_v2.py#L395 +def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs): + from .single_file_utils_nunchaku import _unpack_qkv_state_dict + + _SMOOTH_ORIG_RE = re.compile(r"\.smooth_orig(\.|$)") + _SMOOTH_RE = re.compile(r"\.smooth(\.|$)") + + new_state_dict = {} + model_state_dict = kwargs["model_state_dict"] + + ckpt_keys = list(checkpoint.keys()) + for k in ckpt_keys: + if "qweight" in k: + # only the shape information of this tensor is needed + v = checkpoint[k] + # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors + for t in ["lora_up", "lora_down"]: + new_k = k.replace(".qweight", f".{t}") + if new_k not in ckpt_keys: + oc, ic = v.shape + ic = ic * 2 # v is packed into INT8, so we need to double the size + checkpoint[k.replace(".qweight", f".{t}")] = torch.zeros( + (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16 + ) + + for k, v in checkpoint.items(): + new_k = k # start with original, then apply independent replacements + + if k.startswith("single_transformer_blocks."): + # attention / qkv / norms + new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.") + new_k = new_k.replace(".out_proj.", ".attn.to_out.") + new_k = new_k.replace(".norm_k.", ".attn.norm_k.") + new_k = new_k.replace(".norm_q.", ".attn.norm_q.") + + # mlp heads + new_k = new_k.replace(".mlp_fc1.", ".proj_mlp.") + new_k = new_k.replace(".mlp_fc2.", ".proj_out.") + + # smooth params (use 
regex to avoid substring collisions)
+ new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
+ new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
+
+ # lora -> proj
+ new_k = new_k.replace(".lora_down", ".proj_down")
+ new_k = new_k.replace(".lora_up", ".proj_up")
+
+ elif k.startswith("transformer_blocks."):
+ # feed-forward (context & base)
+ new_k = new_k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.")
+ new_k = new_k.replace(".mlp_context_fc2.", ".ff_context.net.2.")
+ new_k = new_k.replace(".mlp_fc1.", ".ff.net.0.proj.")
+ new_k = new_k.replace(".mlp_fc2.", ".ff.net.2.")
+
+ # attention projections
+ new_k = new_k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.")
+ new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
+ new_k = new_k.replace(".out_proj.", ".attn.to_out.0.")
+ new_k = new_k.replace(".out_proj_context.", ".attn.to_add_out.")
+
+ # norms
+ new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
+ new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
+ new_k = new_k.replace(".norm_added_k.", ".attn.norm_added_k.")
+ new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.")
+
+ # smooth params
+ new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
+ new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
+
+ # lora -> proj
+ new_k = new_k.replace(".lora_down", ".proj_down")
+ new_k = new_k.replace(".lora_up", ".proj_up")
+
+ new_state_dict[new_k] = v
+
+ new_state_dict = _unpack_qkv_state_dict(new_state_dict)
+
+ # some remnant keys need to be patched
+ new_sd_keys = list(new_state_dict.keys())
+ for k in new_sd_keys:
+ if "qweight" in k:
+ no_qweight_k = ".".join(k.split(".qweight")[:-1])
+ for unexpected_k in ["wzeros"]:
+ unexpected_k = no_qweight_k + f".{unexpected_k}"
+ if unexpected_k in new_sd_keys:
+ _ = new_state_dict.pop(unexpected_k)
+ for k in model_state_dict:
+ if k not in new_state_dict:
+ # CPU device for now
+ new_state_dict[k] = torch.ones_like(k, device="cpu")
+
+ return new_state_dict
+
+
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
 converted_state_dict = {}
 keys = list(checkpoint.keys())
diff --git a/src/diffusers/loaders/single_file_utils_nunchaku.py b/src/diffusers/loaders/single_file_utils_nunchaku.py
new file mode 100644
index 000000000000..faf89a77f755
--- /dev/null
+++ b/src/diffusers/loaders/single_file_utils_nunchaku.py
@@ -0,0 +1,102 @@
+import re
+
+import torch
+
+
+_QKV_ANCHORS_NUNCHAKU = ("to_qkv", "add_qkv_proj")
+_ALLOWED_SUFFIXES_NUNCHAKU = {
+ "bias",
+ "lora_down",
+ "lora_up",
+ "qweight",
+ "smooth_factor",
+ "smooth_factor_orig",
+ "wscales",
+}
+
+_QKV_NUNCHAKU_REGEX = re.compile(
+ rf"^(?P<prefix>.*)\.(?:{'|'.join(map(re.escape, _QKV_ANCHORS_NUNCHAKU))})\.(?P<suffix>.+)$"
+)
+
+
+def _pick_split_dim(t: torch.Tensor, suffix: str) -> int:
+ """
+ Choose which dimension to split by 3. Heuristics:
+ - 1D -> dim 0
+ - 2D -> prefer dim=1 for 'qweight' (common layout [*, 3*out_features]),
+ otherwise prefer dim=0 (common layout [3*out_features, *]).
+ - If preferred dim isn't divisible by 3, try the other; else error. 
+ """ + shape = list(t.shape) + if len(shape) == 0: + raise ValueError("Cannot split a scalar into Q/K/V.") + + if len(shape) == 1: + dim = 0 + if shape[dim] % 3 == 0: + return dim + raise ValueError(f"1D tensor of length {shape[0]} not divisible by 3.") + + # len(shape)>= 2 + preferred = 1 if suffix == "qweight" else 0 + other = 0 if preferred == 1 else 1 + + if shape[preferred] % 3 == 0: + return preferred + if shape[other] % 3 == 0: + return other + + # Fall back: any dim divisible by 3 + for d, s in enumerate(shape): + if s % 3 == 0: + return d + + raise ValueError(f"None of the dims {shape} are divisible by 3 for suffix '{suffix}'.") + + +def _split_qkv(t: torch.Tensor, dim: int): + return torch.tensor_split(t, 3, dim=dim) + + +def _unpack_qkv_state_dict( + state_dict: dict, anchors=_QKV_ANCHORS_NUNCHAKU, allowed_suffixes=_ALLOWED_SUFFIXES_NUNCHAKU +): + """ + Convert fused QKV entries (e.g., '...to_qkv.bias', '...qkv_proj.wscales') into separate Q/K/V entries: + '...to_q.bias', '...to_k.bias', '...to_v.bias' '...to_q.wscales', '...to_k.wscales', '...to_v.wscales' + Returns a NEW dict; original is not modified. + + Only keys with suffix in `allowed_suffixes` are processed. Keys with non-divisible-by-3 tensors raise a ValueError. + """ + anchors = tuple(anchors) + allowed_suffixes = set(allowed_suffixes) + + new_sd: dict = {} + for k, v in state_dict.items(): + m = _QKV_NUNCHAKU_REGEX.match(k) + if m: + suffix = m.group("suffix") + if suffix not in allowed_suffixes: + # keep as-is if it's not one of the targeted suffixes + new_sd[k] = v + continue + + prefix = m.group("prefix") # everything before .to_qkv/.qkv_proj + # Decide split axis + split_dim = _pick_split_dim(v, suffix) + q, k_, vv = _split_qkv(v, dim=split_dim) + + # Build new keys + base_q = f"{prefix}.to_q.{suffix}" + base_k = f"{prefix}.to_k.{suffix}" + base_v = f"{prefix}.to_v.{suffix}" + + # Write into result dict + new_sd[base_q] = q + new_sd[base_k] = k_ + new_sd[base_v] = vv + else: + # not a fused qkv key + new_sd[k] = v + + return new_sd diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py index b9e2bb23bceb..b2886b118de9 100644 --- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py +++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py @@ -2,13 +2,7 @@ from diffusers.utils.import_utils import is_nunchaku_version -from ...utils import ( - get_module_from_name, - is_accelerate_available, - is_nunchaku_available, - is_torch_available, - logging, -) +from ...utils import get_module_from_name, is_accelerate_available, is_nunchaku_available, is_torch_available, logging from ...utils.torch_utils import is_fp8_available from ..base import DiffusersQuantizer @@ -20,15 +14,20 @@ if is_torch_available(): import torch -if is_accelerate_available(): - pass - if is_nunchaku_available(): from .utils import replace_with_nunchaku_linear logger = logging.get_logger(__name__) +KEY_MAP = { + "lora_down": "proj_down", + "lora_up": "proj_up", + "smooth_orig": "smooth_factor_orig", + "smooth": "smooth_factor", +} + + class NunchakuQuantizer(DiffusersQuantizer): r""" Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku) @@ -98,33 +97,11 @@ def create_quantized_param( from nunchaku.models.linear import SVDQW4A4Linear module, tensor_name = get_module_from_name(model, param_name) - state_dict = args[0] if tensor_name not in module._parameters and tensor_name not in module._buffers: raise ValueError(f"{module} does not have a parameter 
or a buffer named {tensor_name}.") if isinstance(module, SVDQW4A4Linear): - if param_value.ndim == 1: - module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to( - target_device - ) - elif tensor_name == "qweight": - module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to( - target_device - ) - # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors - for t in ["lora_up", "lora_down"]: - # need to check at the state dict level for this - new_tensor_name = param_name.replace(".qweight", f".{t}") - if new_tensor_name not in state_dict: - oc, ic = param_value.shape - ic = ic * 2 # v is packed into INT8, so we need to double the size - module._parameters[t] = torch.zeros( - (0, ic) if t == "lora_down" else (oc, 0), device=param_value.device, dtype=torch.bfloat16 - ) - else: - module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to( - target_device - ) + module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(target_device) def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: max_memory = {key: val * 0.90 for key, val in max_memory.items()} From 3295c6aba56b96587e1578ee590c9002c5fe457c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: 2025年8月26日 19:38:43 +0200 Subject: [PATCH 13/13] up --- src/diffusers/loaders/single_file_model.py | 13 +++++++---- src/diffusers/loaders/single_file_utils.py | 8 +++++-- .../loaders/single_file_utils_nunchaku.py | 10 ++++---- src/diffusers/quantizers/auto.py | 2 +- .../quantizers/quantization_config.py | 23 ++----------------- 5 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 88f95a01a646..2d33e2e1e26c 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -23,6 +23,7 @@ from .. import __version__ from ..quantizers import DiffusersAutoQuantizer +from ..quantizers.quantization_config import NunchakuConfig from ..utils import deprecate, is_accelerate_available, is_torch_version, logging from ..utils.torch_utils import empty_device_cache from .single_file_utils import ( @@ -442,6 +443,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = ) # This step is better off here than above because `diffusers_format_checkpoint` holds the keys we expect. + # We can move it to a separate function as well. if quantization_config is not None: original_modules_to_not_convert = quantization_config.modules_to_not_convert or [] determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert( @@ -450,12 +452,15 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = if determined_modules_to_not_convert: determined_modules_to_not_convert.extend(original_modules_to_not_convert) determined_modules_to_not_convert = list(set(determined_modules_to_not_convert)) - logger.info( + logger.debug( f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {determined_modules_to_not_convert}." ) - quantization_config.modules_to_not_convert = original_modules_to_not_convert - # Update the `quant_config`. 
- hf_quantizer.quantization_config = quantization_config
+ modified_quant_config = quantization_config.to_dict()
+ modified_quant_config["modules_to_not_convert"] = determined_modules_to_not_convert
+ # TODO: figure out a better way.
+ modified_quant_config = NunchakuConfig.from_dict(modified_quant_config)
+ setattr(hf_quantizer, "quantization_config", modified_quant_config)
+ logger.debug("TODO")
 
 # Check if `_keep_in_fp32_modules` is not None
 use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
 (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 722d38d4f5bb..9a62fc12cf63 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -2220,7 +2220,7 @@ def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
 if k.startswith("single_transformer_blocks."):
 # attention / qkv / norms
 new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
- new_k = new_k.replace(".out_proj.", ".attn.to_out.")
+ new_k = new_k.replace(".out_proj.", ".proj_out.")
 new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
 new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
@@ -2279,7 +2279,11 @@ def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
 for k in model_state_dict:
 if k not in new_state_dict:
 # CPU device for now
- new_state_dict[k] = torch.ones_like(k, device="cpu")
+ new_state_dict[k] = torch.ones_like(model_state_dict[k], device="cpu")
+
+ for k in new_state_dict:
+ if "single_transformer_blocks.0" in k and k.endswith(".weight"):
+ print(f"{k=}")
 
 return new_state_dict
 
diff --git a/src/diffusers/loaders/single_file_utils_nunchaku.py b/src/diffusers/loaders/single_file_utils_nunchaku.py
index faf89a77f755..da74aa5211fb 100644
--- a/src/diffusers/loaders/single_file_utils_nunchaku.py
+++ b/src/diffusers/loaders/single_file_utils_nunchaku.py
@@ -6,8 +6,8 @@
 _QKV_ANCHORS_NUNCHAKU = ("to_qkv", "add_qkv_proj")
 _ALLOWED_SUFFIXES_NUNCHAKU = {
 "bias",
- "lora_down",
- "lora_up",
+ "proj_down",
+ "proj_up",
 "qweight",
 "smooth_factor",
 "smooth_factor_orig",
@@ -66,14 +66,16 @@ def _unpack_qkv_state_dict(
 '...to_q.bias', '...to_k.bias', '...to_v.bias' '...to_q.wscales', '...to_k.wscales', '...to_v.wscales'
 Returns a NEW dict; original is not modified.
 
- Only keys with suffix in `allowed_suffixes` are processed. Keys with non-divisible-by-3 tensors raise a ValueError.
+ Only keys with suffix in `allowed_suffixes` are processed. Keys with non-divisible-by-3 tensors raise a ValueError.
 """
 anchors = tuple(anchors)
 allowed_suffixes = set(allowed_suffixes)
 
 new_sd: dict = {}
- for k, v in state_dict.items():
+ sd_keys = list(state_dict.keys())
+ for k in sd_keys:
 m = _QKV_NUNCHAKU_REGEX.match(k)
+ v = state_dict.pop(k)
 if m:
 suffix = m.group("suffix")
 if suffix not in allowed_suffixes:
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
index 910557865ea0..a921888a71da 100644
--- a/src/diffusers/quantizers/auto.py
+++ b/src/diffusers/quantizers/auto.py
@@ -56,7 +56,7 @@ class DiffusersAutoQuantizer:
 """
- The auto diffusers quantizer class that takes care of automatically instantiating to the correct
+ The auto diffusers quantizer class that takes care of automatically instantiating to the correct
 `DiffusersQuantizer` given the `QuantizationConfig`. 
""" diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index c82b1e4b24e4..5168382e5bef 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -791,25 +791,6 @@ def post_init(self): # TODO: should there be a check for rank? - # Copied from diffusers.quantizers.quantization_config.BitsAndBytesConfig.to_diff_dict with BitsAndBytesConfig->NunchakuConfig - def to_diff_dict(self) -> Dict[str, Any]: - """ - Removes all attributes from config which correspond to the default config attributes for better readability and - serializes to a Python dictionary. - - Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, - """ + def __repr__(self): config_dict = self.to_dict() - - # get the default config dict - default_config_dict = NunchakuConfig().to_dict() - - serializable_config_dict = {} - - # only serialize values that differ from the default config - for key, value in config_dict.items(): - if value != default_config_dict[key]: - serializable_config_dict[key] = value - - return serializable_config_dict + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

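Second, a usage sketch of the fused-QKV unpacking helper from `single_file_utils_nunchaku.py`. The toy key and length are illustrative, and the import path assumes the new file lands where this series creates it.

import torch

from diffusers.loaders.single_file_utils_nunchaku import _unpack_qkv_state_dict

# A fused bias of length 3 * 128 is split along dim 0 into three 128-dim tensors.
fused = {"transformer_blocks.0.attn.to_qkv.bias": torch.zeros(3 * 128)}
split = _unpack_qkv_state_dict(fused)
print(sorted(split))
# ['transformer_blocks.0.attn.to_k.bias',
# 'transformer_blocks.0.attn.to_q.bias',
# 'transformer_blocks.0.attn.to_v.bias']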