[wip][quantization] incorporate nunchaku #12207
Conversation
Let me outline the stage we're currently at, as this will help us understand the current blockers:
- This is for quantizing a pre-trained, non-quantized model checkpoint, as opposed to trying to directly load a quantized checkpoint.
- This is where we're replacing the linear modules:
  `new_module = SVDQW4A4Linear.from_linear(...)`
  We're doing it here instead of in the typical replacement step because otherwise we won't have `qweight` and `wscales` in the pre-trained (non-quantized, as shown in the example) state dict. This will lead to errors.
- The above isn't uncommon. We do this for TorchAO as well.
- However, there doesn't seem to be a method in `nunchaku` that can quantize a pre-trained parameter. This is the current blocker. So, simply doing the following isn't expected to work:
  https://github.com/huggingface/diffusers/blob/nunchaku/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py#L110-L136
If you have suggestions, please LMK.
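For context, the module-replacement step mentioned above can be sketched roughly as follows. This is a minimal illustration, not the actual diffusers implementation: `FakeQuantLinear` is a hypothetical stand-in for nunchaku's `SVDQW4A4Linear`, and the traversal is simplified.

```python
import torch.nn as nn


class FakeQuantLinear(nn.Module):
    """Hypothetical stand-in for nunchaku's SVDQW4A4Linear (illustration only)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

    @classmethod
    def from_linear(cls, linear: nn.Linear):
        # A real implementation would also quantize/copy the weights here.
        return cls(linear.in_features, linear.out_features)


def replace_linears(module: nn.Module, modules_to_not_convert=(), prefix=""):
    # Recursively swap nn.Linear children for the quantized variant,
    # skipping any module whose qualified name is excluded.
    for name, child in module.named_children():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(child, nn.Linear) and full_name not in modules_to_not_convert:
            setattr(module, name, FakeQuantLinear.from_linear(child))
        else:
            replace_linears(child, modules_to_not_convert, full_name)
    return module


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
replace_linears(model, modules_to_not_convert={"2"})
assert isinstance(model[0], FakeQuantLinear)
assert isinstance(model[2], nn.Linear)  # excluded module stays as-is
```

The point of the issue above is that with nunchaku this swap has to happen *before* loading a pre-quantized state dict (so the `qweight`/`wscales` keys have somewhere to land), whereas for quantizing a fresh checkpoint there is currently no nunchaku API to quantize the pre-trained parameter after the swap.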
Discussed with @SunMarc internally. We will also try to first support loading pre-quantized checkpoints from https://huggingface.co/nunchaku-tech/nunchaku and see how it goes.
Tried a bit to load pre-quantized checkpoints. The current issues are:
- The pre-quantized checkpoint (example) has `mlp_fc*` keys which aren't present in our implementation for Flux. This needs to be accounted for.
- It uses horizontal fusion for attention in the checkpoints -- something we don't support in our implementation yet. This will also need to be accounted for.
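Both issues amount to a state-dict conversion before loading. A rough sketch of what such a conversion could look like; the key names (`mlp_fc1` → `ff.net.0.proj`, `qkv_proj` → `to_q`/`to_k`/`to_v`) and the fused-QKV layout are assumptions for illustration, not the actual checkpoint format, which would need to be verified against the real checkpoints.

```python
# Hypothetical rename table for illustration only.
RENAMES = {
    "mlp_fc1": "ff.net.0.proj",
    "mlp_fc2": "ff.net.2",
}


def split_rows(rows, parts):
    # Split a 2-D weight (list of rows) into equal chunks along dim 0.
    size = len(rows) // parts
    return [rows[i * size : (i + 1) * size] for i in range(parts)]


def convert_state_dict(sd):
    out = {}
    for key, value in sd.items():
        if "qkv_proj" in key:
            # De-fuse the horizontally fused attention projection into
            # separate q/k/v tensors along the output dimension.
            q, k, v = split_rows(value, 3)
            out[key.replace("qkv_proj", "to_q")] = q
            out[key.replace("qkv_proj", "to_k")] = k
            out[key.replace("qkv_proj", "to_v")] = v
            continue
        for old, new in RENAMES.items():
            key = key.replace(old, new)
        out[key] = value
    return out


sd = {
    "blocks.0.mlp_fc1.weight": [[0.0, 0.0]] * 4,
    "blocks.0.attn.qkv_proj.weight": [[0.0, 0.0]] * 6,
}
converted = convert_state_dict(sd)
assert "blocks.0.ff.net.0.proj.weight" in converted
assert len(converted["blocks.0.attn.to_q.weight"]) == 2
```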
Code
```python
from diffusers import DiffusionPipeline, FluxTransformer2DModel, NunchakuConfig
from nunchaku.models.linear import SVDQW4A4Linear
from safetensors import safe_open
from huggingface_hub import hf_hub_download
import torch


def modules_without_qweight(safetensors_path: str):
    no_qweight = set()
    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            if key.endswith(".weight"):
                # module name is everything except the last piece after "."
                module_name = ".".join(key.split(".")[:-1])
                no_qweight.add(module_name)
    return sorted(no_qweight)


ckpt_id = "black-forest-labs/FLUX.1-dev"
state_dict_path = hf_hub_download(
    repo_id="nunchaku-tech/nunchaku-flux.1-dev",
    filename="svdq-int4_r32-flux.1-dev.safetensors",
)
modules_to_not_convert = modules_without_qweight(state_dict_path)
# print(f"{modules_to_not_convert=}")

model = FluxTransformer2DModel.from_single_file(
    state_dict_path,
    config=ckpt_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=NunchakuConfig(
        weight_dtype="int4",
        weight_group_size=64,
        activation_dtype="int4",
        activation_group_size=64,
        modules_to_not_convert=modules_to_not_convert,
    ),
).to("cuda")
has_svd = any(isinstance(module, SVDQW4A4Linear) for _, module in model.named_modules())
assert has_svd

pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,
    guidance_scale=3.5,
    generator=torch.manual_seed(0),
).images[0]
image.save("nunchaku.png")
```
Cc: @SunMarc
What does this PR do?
Caution
Doesn't work yet.
Test code:
`diffusers-cli env`:
@lmxyy I am going to outline the stage we're currently at in this integration, as that will help us better understand the blockers.
Cc: @SunMarc