diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 3f0f87b92609..7b1648c3dcce 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -13,6 +13,7 @@
     is_k_diffusion_available,
     is_librosa_available,
     is_note_seq_available,
+    is_nunchaku_available,
     is_onnx_available,
     is_opencv_available,
     is_optimum_quanto_available,
@@ -99,6 +100,18 @@
 else:
     _import_structure["quantizers.quantization_config"].append("TorchAoConfig")
 
+try:
+    if not is_torch_available() and not is_accelerate_available() and not is_nunchaku_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_nunchaku_objects
+
+    _import_structure["utils.dummy_nunchaku_objects"] = [
+        name for name in dir(dummy_nunchaku_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["quantizers.quantization_config"].append("NunchakuConfig")
+
 try:
     if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
         raise OptionalDependencyNotAvailable()
@@ -791,6 +804,14 @@
     else:
         from .quantizers.quantization_config import QuantoConfig
 
+    try:
+        if not is_nunchaku_available():
+            raise OptionalDependencyNotAvailable()
+    except OptionalDependencyNotAvailable:
+        from .utils.dummy_nunchaku_objects import *
+    else:
+        from .quantizers.quantization_config import NunchakuConfig
+
     try:
         if not is_onnx_available():
             raise OptionalDependencyNotAvailable()
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 16bd0441072a..2d33e2e1e26c 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -23,6 +23,7 @@
 
 from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
+from ..quantizers.quantization_config import NunchakuConfig
 from ..utils import deprecate, is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache
 from .single_file_utils import (
@@ -42,6 +43,7 @@
     convert_ltx_vae_checkpoint_to_diffusers,
     convert_lumina2_to_diffusers,
     convert_mochi_transformer_checkpoint_to_diffusers,
+    convert_nunchaku_flux_to_diffusers,
     convert_sana_transformer_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
@@ -190,6 +192,23 @@ def _get_mapping_function_kwargs(mapping_fn, **kwargs):
     return mapping_kwargs
 
 
+def _maybe_determine_modules_to_not_convert(quantization_config, state_dict):
+    if quantization_config is None:
+        return None
+    else:
+        is_nunchaku = quantization_config.quant_method == "nunchaku"
+        if not is_nunchaku:
+            return None
+        else:
+            no_qweight = set()
+            for key in state_dict:
+                if key.endswith(".weight"):
+                    # module name is everything except the last piece after "."
+                    module_name = ".".join(key.split(".")[:-1])
+                    no_qweight.add(module_name)
+            return sorted(no_qweight)
+
+
 class FromOriginalModelMixin:
     """
     Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
@@ -404,8 +423,14 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         model = cls.from_config(diffusers_model_config)
 
         checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
-
-        if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
+        model_state_dict = model.state_dict()
+        # TODO: Only flux nunchaku checkpoint for now. Unify with how checkpoint mappers are done.
+        # For `nunchaku` checkpoints, we might want to determine the `modules_to_not_convert`.
+        if quantization_config is not None and quantization_config.quant_method == "nunchaku":
+            diffusers_format_checkpoint = convert_nunchaku_flux_to_diffusers(
+                checkpoint, model_state_dict=model_state_dict
+            )
+        elif _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint):
             diffusers_format_checkpoint = checkpoint_mapping_fn(
                 config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
             )
@@ -416,6 +441,27 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             raise SingleFileComponentError(
                 f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
             )
+
+        # This step is better off here than above because `diffusers_format_checkpoint` holds the keys we expect.
+        # We can move it to a separate function as well.
+        if quantization_config is not None:
+            original_modules_to_not_convert = quantization_config.modules_to_not_convert or []
+            determined_modules_to_not_convert = _maybe_determine_modules_to_not_convert(
+                quantization_config, checkpoint
+            )
+            if determined_modules_to_not_convert:
+                determined_modules_to_not_convert.extend(original_modules_to_not_convert)
+                determined_modules_to_not_convert = list(set(determined_modules_to_not_convert))
+                logger.debug(
+                    f"`modules_to_not_convert` in the quantization_config was updated from {quantization_config.modules_to_not_convert} to {determined_modules_to_not_convert}."
+                )
+                modified_quant_config = quantization_config.to_dict()
+                modified_quant_config["modules_to_not_convert"] = determined_modules_to_not_convert
+                # TODO: figure out a better way.
+                modified_quant_config = NunchakuConfig.from_dict(modified_quant_config)
+                setattr(hf_quantizer, "quantization_config", modified_quant_config)
+                logger.debug("TODO")
+
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
             (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
         )
@@ -443,6 +489,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         unexpected_keys = [
             param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
         ]
+        for k in unexpected_keys:
+            if "single_transformer_blocks.0" in k:
+                print(f"Unexpected {k=}")
+        for k in empty_state_dict:
+            if "single_transformer_blocks.0" in k:
+                print(f"model {k=}")
         device_map = {"": param_device}
         load_model_dict_into_meta(
             model,
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index ef6c41e3ce97..9a62fc12cf63 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -2189,6 +2189,105 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
     return converted_state_dict
 
 
+# Adapted from https://github.com/nunchaku-tech/nunchaku/blob/3ec299f439f9986a69ded320798cab4e258c871d/nunchaku/models/transformers/transformer_flux_v2.py#L395
+def convert_nunchaku_flux_to_diffusers(checkpoint, **kwargs):
+    from .single_file_utils_nunchaku import _unpack_qkv_state_dict
+
+    _SMOOTH_ORIG_RE = re.compile(r"\.smooth_orig(\.|$)")
+    _SMOOTH_RE = re.compile(r"\.smooth(\.|$)")
+
+    new_state_dict = {}
+    model_state_dict = kwargs["model_state_dict"]
+
+    ckpt_keys = list(checkpoint.keys())
+    for k in ckpt_keys:
+        if "qweight" in k:
+            # only the shape information of this tensor is needed
+            v = checkpoint[k]
+            # if the tensor has qweight, but does not have low-rank branch, we need to add some artificial tensors
+            for t in ["lora_up", "lora_down"]:
+                new_k = k.replace(".qweight", f".{t}")
+                if new_k not in ckpt_keys:
+                    oc, ic = v.shape
+                    ic = ic * 2  # v is packed into INT8, so we need to double the size
+                    checkpoint[k.replace(".qweight", f".{t}")] = torch.zeros(
+                        (0, ic) if t == "lora_down" else (oc, 0), device=v.device, dtype=torch.bfloat16
+                    )
+
+    for k, v in checkpoint.items():
+        new_k = k  # start with original, then apply independent replacements
+
+        if k.startswith("single_transformer_blocks."):
+            # attention / qkv / norms
+            new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
+            new_k = new_k.replace(".out_proj.", ".proj_out.")
+            new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
+            new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
+
+            # mlp heads
+            new_k = new_k.replace(".mlp_fc1.", ".proj_mlp.")
+            new_k = new_k.replace(".mlp_fc2.", ".proj_out.")
+
+            # smooth params (use regex to avoid substring collisions)
+            new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
+            new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
+
+            # lora -> proj
+            new_k = new_k.replace(".lora_down", ".proj_down")
+            new_k = new_k.replace(".lora_up", ".proj_up")
+
+        elif k.startswith("transformer_blocks."):
+            # feed-forward (context & base)
+            new_k = new_k.replace(".mlp_context_fc1.", ".ff_context.net.0.proj.")
+            new_k = new_k.replace(".mlp_context_fc2.", ".ff_context.net.2.")
+            new_k = new_k.replace(".mlp_fc1.", ".ff.net.0.proj.")
+            new_k = new_k.replace(".mlp_fc2.", ".ff.net.2.")
+
+            # attention projections
+            new_k = new_k.replace(".qkv_proj_context.", ".attn.add_qkv_proj.")
+            new_k = new_k.replace(".qkv_proj.", ".attn.to_qkv.")
+            new_k = new_k.replace(".out_proj.", ".attn.to_out.0.")
+            new_k = new_k.replace(".out_proj_context.", ".attn.to_add_out.")
+
+            # norms
+            new_k = new_k.replace(".norm_k.", ".attn.norm_k.")
+            new_k = new_k.replace(".norm_q.", ".attn.norm_q.")
+            new_k = new_k.replace(".norm_added_k.", ".attn.norm_added_k.")
+            new_k = new_k.replace(".norm_added_q.", ".attn.norm_added_q.")
+
+            # smooth params
+            new_k = _SMOOTH_ORIG_RE.sub(r".smooth_factor_orig\1", new_k)
+            new_k = _SMOOTH_RE.sub(r".smooth_factor\1", new_k)
+
+            # lora -> proj
+            new_k = new_k.replace(".lora_down", ".proj_down")
+            new_k = new_k.replace(".lora_up", ".proj_up")
+
+        new_state_dict[new_k] = v
+
+    new_state_dict = _unpack_qkv_state_dict(new_state_dict)
+
+    # some remnant keys need to be patched
+    new_sd_keys = list(new_state_dict.keys())
+    for k in new_sd_keys:
+        if "qweight" in k:
+            no_qweight_k = ".".join(k.split(".qweight")[:-1])
+            for unexpected_k in ["wzeros"]:
+                unexpected_k = no_qweight_k + f".{unexpected_k}"
+                if unexpected_k in new_sd_keys:
+                    _ = new_state_dict.pop(unexpected_k)
+    for k in model_state_dict:
+        if k not in new_state_dict:
+            # CPU device for now
+            new_state_dict[k] = torch.ones_like(model_state_dict[k], device="cpu")
+
+    for k in new_state_dict:
+        if "single_transformer_blocks.0" in k and k.endswith(".weight"):
+            print(f"{k=}")
+
+    return new_state_dict
+
+
 def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
diff --git a/src/diffusers/loaders/single_file_utils_nunchaku.py b/src/diffusers/loaders/single_file_utils_nunchaku.py
new file mode 100644
index 000000000000..da74aa5211fb
--- /dev/null
+++ b/src/diffusers/loaders/single_file_utils_nunchaku.py
@@ -0,0 +1,104 @@
+import re
+
+import torch
+
+
+_QKV_ANCHORS_NUNCHAKU = ("to_qkv", "add_qkv_proj")
"add_qkv_proj") +_ALLOWED_SUFFIXES_NUNCHAKU = { + "bias", + "proj_down", + "proj_up", + "qweight", + "smooth_factor", + "smooth_factor_orig", + "wscales", +} + +_QKV_NUNCHAKU_REGEX = re.compile( + rf"^(?P.*)\.(?:{'|'.join(map(re.escape, _QKV_ANCHORS_NUNCHAKU))})\.(?P.+)$" +) + + +def _pick_split_dim(t: torch.Tensor, suffix: str) -> int: + """ + Choose which dimension to split by 3. Heuristics: + - 1D -> dim 0 + - 2D -> prefer dim=1 for 'qweight' (common layout [*, 3*out_features]), + otherwise prefer dim=0 (common layout [3*out_features, *]). + - If preferred dim isn't divisible by 3, try the other; else error. + """ + shape = list(t.shape) + if len(shape) == 0: + raise ValueError("Cannot split a scalar into Q/K/V.") + + if len(shape) == 1: + dim = 0 + if shape[dim] % 3 == 0: + return dim + raise ValueError(f"1D tensor of length {shape[0]} not divisible by 3.") + + # len(shape)>= 2 + preferred = 1 if suffix == "qweight" else 0 + other = 0 if preferred == 1 else 1 + + if shape[preferred] % 3 == 0: + return preferred + if shape[other] % 3 == 0: + return other + + # Fall back: any dim divisible by 3 + for d, s in enumerate(shape): + if s % 3 == 0: + return d + + raise ValueError(f"None of the dims {shape} are divisible by 3 for suffix '{suffix}'.") + + +def _split_qkv(t: torch.Tensor, dim: int): + return torch.tensor_split(t, 3, dim=dim) + + +def _unpack_qkv_state_dict( + state_dict: dict, anchors=_QKV_ANCHORS_NUNCHAKU, allowed_suffixes=_ALLOWED_SUFFIXES_NUNCHAKU +): + """ + Convert fused QKV entries (e.g., '...to_qkv.bias', '...qkv_proj.wscales') into separate Q/K/V entries: + '...to_q.bias', '...to_k.bias', '...to_v.bias' '...to_q.wscales', '...to_k.wscales', '...to_v.wscales' + Returns a NEW dict; original is not modified. + + Only keys with suffix in `allowed_suffixes` are processed. Keys with non-divisible-by-3 tensors raise a ValueError.: + """ + anchors = tuple(anchors) + allowed_suffixes = set(allowed_suffixes) + + new_sd: dict = {} + sd_keys = list(state_dict.keys()) + for k in sd_keys: + m = _QKV_NUNCHAKU_REGEX.match(k) + v = state_dict.pop(k) + if m: + suffix = m.group("suffix") + if suffix not in allowed_suffixes: + # keep as-is if it's not one of the targeted suffixes + new_sd[k] = v + continue + + prefix = m.group("prefix") # everything before .to_qkv/.qkv_proj + # Decide split axis + split_dim = _pick_split_dim(v, suffix) + q, k_, vv = _split_qkv(v, dim=split_dim) + + # Build new keys + base_q = f"{prefix}.to_q.{suffix}" + base_k = f"{prefix}.to_k.{suffix}" + base_v = f"{prefix}.to_v.{suffix}" + + # Write into result dict + new_sd[base_q] = q + new_sd[base_k] = k_ + new_sd[base_v] = vv + else: + # not a fused qkv key + new_sd[k] = v + + return new_sd diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 8b48ba6b4873..0fed2d4f483f 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -297,6 +297,13 @@ def load_model_dict_into_meta( offload_index = offload_weight(param, param_name, offload_folder, offload_index) elif param_device == "cpu" and state_dict_index is not None: state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) + # This check below might be a bit counter-intuitive in nature. This is because we're checking if the param + # or its module is quantized and if so, we're proceeding with creating a quantized param. This is because + # of the way pre-trained models are loaded. 
+        # quantization layers are first injected. Hence, for a model that is either pre-quantized or supplemented
+        # with a `quantization_config` during `from_pretrained`, we expect `check_if_quantized_param` to return True.
+        # Then depending on the quantization backend being used, we run the actual quantization step under
+        # `create_quantized_param`.
         elif is_quantized and (
             hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
         ):
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
index ce214ae7bc17..a921888a71da 100644
--- a/src/diffusers/quantizers/auto.py
+++ b/src/diffusers/quantizers/auto.py
@@ -21,9 +21,11 @@
 
 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
 from .gguf import GGUFQuantizer
+from .nunchaku import NunchakuQuantizer
 from .quantization_config import (
     BitsAndBytesConfig,
     GGUFQuantizationConfig,
+    NunchakuConfig,
     QuantizationConfigMixin,
     QuantizationMethod,
     QuantoConfig,
@@ -39,6 +41,7 @@
     "gguf": GGUFQuantizer,
     "quanto": QuantoQuantizer,
     "torchao": TorchAoHfQuantizer,
+    "nunchaku": NunchakuQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -47,12 +50,13 @@
     "gguf": GGUFQuantizationConfig,
     "quanto": QuantoConfig,
     "torchao": TorchAoConfig,
+    "nunchaku": NunchakuConfig,
 }
 
 
 class DiffusersAutoQuantizer:
     """
-    The auto diffusers quantizer class that takes care of automatically instantiating to the correct 
+    The auto diffusers quantizer class that takes care of automatically instantiating to the correct
     `DiffusersQuantizer` given the `QuantizationConfig`.
     """
 
diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py
index aa5ebf5711a3..b0f6c8a2a95a 100644
--- a/src/diffusers/quantizers/gguf/gguf_quantizer.py
+++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -90,7 +90,7 @@ def check_quantized_param_shape(self, param_name, current_param, loaded_param):
     def check_if_quantized_param(
         self,
         model: "ModelMixin",
-        param_value: Union["GGUFParameter", "torch.Tensor"],
+        param_value: Union["torch.Tensor"],
         param_name: str,
         state_dict: Dict[str, Any],
         **kwargs,
diff --git a/src/diffusers/quantizers/nunchaku/__init__.py b/src/diffusers/quantizers/nunchaku/__init__.py
new file mode 100644
index 000000000000..759039054073
--- /dev/null
+++ b/src/diffusers/quantizers/nunchaku/__init__.py
@@ -0,0 +1 @@
+from .nunchaku_quantizer import NunchakuQuantizer
diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
new file mode 100644
index 000000000000..b2886b118de9
--- /dev/null
+++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
@@ -0,0 +1,174 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+from diffusers.utils.import_utils import is_nunchaku_version
+
+from ...utils import get_module_from_name, is_accelerate_available, is_nunchaku_available, is_torch_available, logging
+from ...utils.torch_utils import is_fp8_available
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+    from ...models.modeling_utils import ModelMixin
+
+
+if is_torch_available():
+    import torch
+
+if is_nunchaku_available():
+    from .utils import replace_with_nunchaku_linear
+
+logger = logging.get_logger(__name__)
+
+
+KEY_MAP = {
+    "lora_down": "proj_down",
+    "lora_up": "proj_up",
+    "smooth_orig": "smooth_factor_orig",
+    "smooth": "smooth_factor",
+}
+
+
+class NunchakuQuantizer(DiffusersQuantizer):
+    r"""
+    Diffusers Quantizer for Nunchaku (https://github.com/nunchaku-tech/nunchaku)
+    """
+
+    use_keep_in_fp32_modules = True
+    requires_calibration = False
+    required_packages = ["nunchaku", "accelerate"]
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+    def validate_environment(self, *args, **kwargs):
+        if not torch.cuda.is_available():
+            raise RuntimeError("No GPU found. A GPU is needed for nunchaku quantization.")
+
+        if not is_nunchaku_available():
+            raise ImportError(
+                "Loading a nunchaku quantized model requires the nunchaku library (follow https://nunchaku.tech/docs/nunchaku/installation/installation.html)"
+            )
+        if not is_nunchaku_version(">=", "0.3.1"):
+            raise ImportError(
+                "Loading a nunchaku quantized model requires `nunchaku>=0.3.1`. "
+                "Please upgrade your installation by following https://nunchaku.tech/docs/nunchaku/installation/installation.html."
+            )
+
+        if not is_accelerate_available():
+            raise ImportError(
+                "Loading a nunchaku quantized model requires the accelerate library (`pip install accelerate`)"
+            )
+
+        # TODO: check
+        # device_map = kwargs.get("device_map", None)
+        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+        #     raise ValueError(
+        #         "`device_map` for multi-GPU inference or CPU/disk offload is currently not supported with Diffusers and the nunchaku backend"
+        #     )
+
+    def check_if_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: "torch.Tensor",
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
+    ):
+        from nunchaku.models.linear import SVDQW4A4Linear
+
+        module, _ = get_module_from_name(model, param_name)
+        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
+            return True
+
+        return False
+
+    def create_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: "torch.Tensor",
+        param_name: str,
+        target_device: "torch.device",
+        *args,
+        **kwargs,
+    ):
+        """
+        Create a quantized parameter.
+        """
+        from nunchaku.models.linear import SVDQW4A4Linear
+
+        module, tensor_name = get_module_from_name(model, param_name)
+        if tensor_name not in module._parameters and tensor_name not in module._buffers:
+            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+        if isinstance(module, SVDQW4A4Linear):
+            module._parameters[tensor_name] = torch.nn.Parameter(param_value, requires_grad=False).to(target_device)
+
+    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+        return max_memory
+
+    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        precision = self.quantization_config.precision
+        expected_target_dtypes = [torch.int8]
+        if is_fp8_available():
+            expected_target_dtypes.append(torch.float8_e4m3fn)
+        if target_dtype not in expected_target_dtypes:
+            new_target_dtype = self.dtype_map[precision]
+
+            logger.info(f"target_dtype {target_dtype} is replaced by {new_target_dtype} for `nunchaku` quantization")
+            return new_target_dtype
+        else:
+            raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+        if torch_dtype is None:
+            # We force the `dtype` to be bfloat16, this is a requirement from `nunchaku`
+            logger.info(
+                "Overriding torch_dtype=%s with `torch_dtype=torch.bfloat16` due to "
+                "requirements of `nunchaku` to enable model loading in 4-bit. "
" + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.bfloat16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.bfloat16 + return torch_dtype + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + self.modules_to_not_convert = self.quantization_config.modules_to_not_convert + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + self.modules_to_not_convert.extend(keep_in_fp32_modules) + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys())> 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + self.modules_to_not_convert.extend(keys_on_cpu) + + model = replace_with_nunchaku_linear( + model, + modules_to_not_convert=self.modules_to_not_convert, + quantization_config=self.quantization_config, + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model, **kwargs): + return model + + @property + def is_serializable(self): + return False + + @property + def is_trainable(self): + return False diff --git a/src/diffusers/quantizers/nunchaku/utils.py b/src/diffusers/quantizers/nunchaku/utils.py new file mode 100644 index 000000000000..4d7eb1a021ea --- /dev/null +++ b/src/diffusers/quantizers/nunchaku/utils.py @@ -0,0 +1,80 @@ +import torch.nn as nn + +from ...utils import is_accelerate_available, is_nunchaku_available, logging + + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +logger = logging.get_logger(__name__) + + +def _replace_with_nunchaku_linear( + model, + svdq_linear_cls, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." 
+            ):
+                with init_empty_weights():
+                    in_features = module.in_features
+                    out_features = module.out_features
+
+                    model._modules[name] = svdq_linear_cls(
+                        in_features,
+                        out_features,
+                        rank=quantization_config.rank,
+                        bias=module.bias is not None,
+                        torch_dtype=module.weight.dtype,
+                    )
+                    has_been_replaced = True
+                    # Store the module class in case we need to transpose the weight later
+                    model._modules[name].source_cls = type(module)
+                    # Force requires grad to False to avoid unexpected errors
+                    model._modules[name].requires_grad_(False)
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = _replace_with_nunchaku_linear(
+                module,
+                svdq_linear_cls,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+                has_been_replaced=has_been_replaced,
+            )
+        # Remove the last key for recursion
+        current_key_name.pop(-1)
+    return model, has_been_replaced
+
+
+def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+    if is_nunchaku_available():
+        from nunchaku.models.linear import SVDQW4A4Linear
+
+    model, _ = _replace_with_nunchaku_linear(
+        model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
+    )
+
+    has_been_replaced = any(
+        isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
+    )
+    if not has_been_replaced:
+        logger.warning(
+            "You are loading your model in the SVDQuant method but no linear modules were found in your model."
+            " Please double check your model architecture, or submit an issue on github if you think this is"
+            " a bug."
+        )
+
+    return model
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 871faf076e5a..5168382e5bef 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum):
     GGUF = "gguf"
     TORCHAO = "torchao"
     QUANTO = "quanto"
+    NUNCHAKU = "nunchaku"
 
 
 if is_torchao_available():
@@ -724,3 +725,72 @@ def post_init(self):
         accepted_weights = ["float8", "int8", "int4", "int2"]
         if self.weights_dtype not in accepted_weights:
             raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
+
+
+class NunchakuConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class about all possible attributes and features that you can play with a model that has been
+    loaded using `nunchaku`.
+
+    Args:
+        TODO
+        modules_to_not_convert (`list`, *optional*, defaults to `None`):
+            The list of modules to not quantize, useful for quantizing models that explicitly require to have some
+            modules left in their original precision (e.g. `norm` layers in Qwen-Image).
+ """ + + def __init__( + self, + method: str = "svdquant", + weight_dtype: str = "int4", + weight_scale_dtype: str = None, + weight_group_size: int = 64, + activation_dtype: str = "int4", + activation_scale_dtype: str = None, + activation_group_size: int = 64, + rank: int = 32, + modules_to_not_convert: Optional[List[str]] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.NUNCHAKU + self.method = method + self.weight_dtype = weight_dtype + self.weight_scale_dtype = weight_scale_dtype + self.weight_group_size = weight_group_size + self.activation_dtype = activation_dtype + self.activation_scale_dtype = activation_scale_dtype + self.activation_group_size = activation_group_size + self.rank = rank + self.modules_to_not_convert = modules_to_not_convert + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct. Hardware checks were largely adapted from the official `nunchaku` + codebase. + """ + from ..utils.torch_utils import get_device + + device = get_device() + if isinstance(device, str): + device = torch.device(device) + capability = torch.cuda.get_device_capability(0 if device.index is None else device.index) + sm = f"{capability[0]}{capability[1]}" + if sm == "120": # you can only use the fp4 models + if self.weight_dtype != "fp4_e2m1_all": + raise ValueError('Please use "fp4" quantization for Blackwell GPUs.') + elif sm in ["75", "80", "86", "89"]: + if self.weight_dtype != "int4": + raise ValueError('Please use "int4" quantization for Turing, Ampere and Ada GPUs.') + else: + raise ValueError( + f"Unsupported GPU architecture {sm} due to the lack of 4-bit tensorcores. " + "Please use a Turing, Ampere, Ada or Blackwell GPU for this quantization configuration." + ) + + # TODO: should there be a check for rank? + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b27cf981edeb..12f948414c4d 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -89,6 +89,7 @@ is_matplotlib_available, is_nltk_available, is_note_seq_available, + is_nunchaku_available, is_onnx_available, is_opencv_available, is_optimum_quanto_available, diff --git a/src/diffusers/utils/dummy_nunchaku_objects.py b/src/diffusers/utils/dummy_nunchaku_objects.py new file mode 100644 index 000000000000..2de7cd7c0a69 --- /dev/null +++ b/src/diffusers/utils/dummy_nunchaku_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. 
+from ..utils import DummyObject, requires_backends
+
+
+class NunchakuConfig(metaclass=DummyObject):
+    _backends = ["nunchaku"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["nunchaku"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["nunchaku"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["nunchaku"])
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index ac209afb74a6..80fd2acbbf4b 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -217,6 +217,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
 _torchao_available, _torchao_version = _is_package_available("torchao")
 _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
 _optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
+_nunchaku_available, _nunchaku_version = _is_package_available("nunchaku", get_dist_name=True)
 _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
 _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
 _nltk_available, _nltk_version = _is_package_available("nltk")
@@ -363,6 +364,10 @@ def is_optimum_quanto_available():
     return _optimum_quanto_available
 
 
+def is_nunchaku_available():
+    return _nunchaku_available
+
+
 def is_timm_available():
     return _timm_available
 
@@ -816,7 +821,7 @@ def is_k_diffusion_version(operation: str, version: str):
 
 def is_optimum_quanto_version(operation: str, version: str):
     """
-    Compares the current Accelerate version to a given reference with an operation.
+    Compares the current quanto version to a given reference with an operation.
 
     Args:
         operation (`str`):
@@ -829,6 +834,21 @@ def is_optimum_quanto_version(operation: str, version: str):
     return compare_versions(parse(_optimum_quanto_version), operation, version)
 
 
+def is_nunchaku_version(operation: str, version: str):
+    """
+    Compares the current nunchaku version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _nunchaku_available:
+        return False
+    return compare_versions(parse(_nunchaku_version), operation, version)
+
+
 def is_xformers_version(operation: str, version: str):
     """
     Compares the current xformers version to a given reference with an operation.
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 5bc708a60c29..c1bf2aed7890 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -197,3 +197,7 @@ def device_synchronize(device_type: Optional[str] = None):
         device_type = get_device()
     device_mod = getattr(torch, device_type, torch.cuda)
     device_mod.synchronize()
+
+
+def is_fp8_available():
+    return getattr(torch, "float8_e4m3fn", None) is not None
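For reference, the snippet below sketches how the single-file path wired up in this patch is meant to be used once nunchaku is installed. It is not part of the diff: the Hub repository id and checkpoint file name are hypothetical placeholders, and the `NunchakuConfig` defaults are assumed to match how the checkpoint was quantized (e.g. `rank=32`, INT4 weights), so treat it as an illustration of the intended API rather than a tested example.

# Illustrative sketch only -- the checkpoint URL below is a placeholder, not a real artifact.
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, NunchakuConfig

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/<org>/<nunchaku-flux-checkpoint>.safetensors",  # hypothetical single-file checkpoint
    quantization_config=NunchakuConfig(rank=32),  # rank must match the SVDQuant low-rank branch in the checkpoint
    torch_dtype=torch.bfloat16,  # the quantizer forces bfloat16 for the non-quantized layers anyway
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe("a photo of a cat", num_inference_steps=28).images[0]

The flow mirrors the existing quantization backends: `from_single_file` converts the nunchaku checkpoint to diffusers key names, the quantizer swaps eligible `nn.Linear` modules for `SVDQW4A4Linear`, and the converted weights are then loaded onto the quantized modules.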
