
[wip][quantization] incorporate nunchaku #12207


Draft · sayakpaul wants to merge 13 commits into main from nunchaku

Conversation

sayakpaul (Member) commented Aug 21, 2025 · edited

What does this PR do?

Caution

Doesn't work yet.

Test code:

from diffusers import DiffusionPipeline, AutoModel, NunchakuConfig
import torch

ckpt_id = "black-forest-labs/FLUX.1-dev"

model = AutoModel.from_pretrained(
    ckpt_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=NunchakuConfig(),
)
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=torch.bfloat16
)
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,
    guidance_scale=3.5,
    generator=torch.manual_seed(0),
).images[0]
image.save("nunchaku.png")
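
Once this path works, a quick sanity check could look like the following (a minimal sketch, assuming the config swaps nn.Linear layers for nunchaku's SVDQW4A4Linear, as the pre-quantized experiment further down also checks):

# Sketch: confirm that linear layers in the transformer were actually replaced
# by nunchaku's SVDQW4A4Linear after quantized loading. Assumes `model` is the
# transformer loaded above and nunchaku is installed.
from nunchaku.models.linear import SVDQW4A4Linear

quantized_layers = [
    name for name, module in model.named_modules()
    if isinstance(module, SVDQW4A4Linear)
]
print(f"found {len(quantized_layers)} quantized linear layers")
assert quantized_layers, "no nunchaku layers found -- quantization did not happen"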

diffusers-cli env:

- 🤗 Diffusers version: 0.36.0.dev0
- Platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.12
- PyTorch version (GPU?): 2.8.0.dev20250626+cu126 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.34.4
- Transformers version: 4.53.2
- Accelerate version: 1.10.0.dev0
- PEFT version: 0.17.0
- Bitsandbytes version: 0.46.0
- Safetensors version: 0.5.3
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

@lmxyy I am going to outline the stage we're currently at in this integration, as that will help us better understand the blockers.

Cc: @SunMarc

sayakpaul (Member, Author) commented:

Let me outline the stage we're currently at, as this will help us understand the current blockers:

This is for quantizing a pre-trained non-quantized model checkpoint as opposed to trying to directly load a quantized checkpoint.

If you have suggestions, please LMK.
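
For reference, here is a minimal, purely illustrative sketch of the module-swap pattern that such on-the-fly quantization generally follows in a quantizer. StubW4A4Linear is a hypothetical stand-in, not nunchaku's actual SVDQW4A4Linear API, and the helper name is invented for the sketch:

# Illustrative sketch only -- not this PR's implementation. It shows the
# generic "replace nn.Linear with a backend-specific layer" pattern that a
# quantizer follows when converting a non-quantized checkpoint on the fly.
import torch
import torch.nn as nn


class StubW4A4Linear(nn.Module):
    """Hypothetical stand-in for a nunchaku-style low-bit linear layer."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        # A real backend would pack int4 weights plus SVD low-rank factors here;
        # we keep the original fp weights so the sketch stays runnable.
        self.weight = nn.Parameter(linear.weight.detach().clone(), requires_grad=False)
        self.bias = None if linear.bias is None else nn.Parameter(
            linear.bias.detach().clone(), requires_grad=False
        )

    def forward(self, x):
        return torch.nn.functional.linear(x, self.weight, self.bias)


def replace_linears(module: nn.Module, modules_to_not_convert=(), prefix=""):
    # Recursively swap every nn.Linear (except the excluded ones) for the stub.
    for name, child in list(module.named_children()):
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(child, nn.Linear) and full_name not in modules_to_not_convert:
            setattr(module, name, StubW4A4Linear(child))
        else:
            replace_linears(child, modules_to_not_convert, full_name)
    return module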

sayakpaul (Member, Author) commented:

Discussed with @SunMarc internally. We will also try to first support pre-quantized checkpoints from https://huggingface.co/nunchaku-tech/nunchaku and see how it goes.

sayakpaul (Member, Author) commented:

Tried loading pre-quantized checkpoints for a bit. The current issues are:

  • The pre-quantized checkpoint (example) has mlp_fc* keys which aren't present in our Flux implementation. This needs to be accounted for.
  • It uses horizontal fusion for attention in the checkpoints -- something we don't support in our implementation yet. This will also need to be accounted for (see the remapping sketch after the code below).
Code
from diffusers import DiffusionPipeline, FluxTransformer2DModel, NunchakuConfig
from nunchaku.models.linear import SVDQW4A4Linear
from safetensors import safe_open
from huggingface_hub import hf_hub_download
import torch


def modules_without_qweight(safetensors_path: str):
    # Collect modules that still have a plain ".weight" entry (i.e., were not
    # quantized to "qweight") so they can be excluded from conversion.
    no_qweight = set()
    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            if key.endswith(".weight"):
                # module name is everything except the last piece after "."
                module_name = ".".join(key.split(".")[:-1])
                no_qweight.add(module_name)
    return sorted(no_qweight)


ckpt_id = "black-forest-labs/FLUX.1-dev"
state_dict_path = hf_hub_download(
    repo_id="nunchaku-tech/nunchaku-flux.1-dev",
    filename="svdq-int4_r32-flux.1-dev.safetensors",
)
modules_to_not_convert = modules_without_qweight(state_dict_path)
# print(f"{modules_to_not_convert=}")

model = FluxTransformer2DModel.from_single_file(
    state_dict_path,
    config=ckpt_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    quantization_config=NunchakuConfig(
        weight_dtype="int4",
        weight_group_size=64,
        activation_dtype="int4",
        activation_group_size=64,
        modules_to_not_convert=modules_to_not_convert,
    ),
).to("cuda")

has_svd = any(isinstance(module, SVDQW4A4Linear) for _, module in model.named_modules())
assert has_svd

pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, transformer=model, torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=50,
    guidance_scale=3.5,
    generator=torch.manual_seed(0),
).images[0]
image.save("nunchaku.png")
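
Both issues above essentially amount to a state-dict remap at load time. Below is a minimal, purely illustrative sketch of what such a remap could look like; the key names on both sides (mlp_fc1/qkv_proj on the nunchaku side, ff.net/to_q on the diffusers side) are assumptions for illustration, and a real implementation would also have to split the packed int4 qweight and scale tensors per component rather than plain fp weights.

# Illustrative sketch only: remap nunchaku-style keys to diffusers-style keys
# and de-fuse a horizontally fused QKV projection. Key names are assumed for
# the sake of the example and do not reflect the actual checkpoint layout.
import torch


def remap_keys(state_dict: dict) -> dict:
    remapped = {}
    for key, value in state_dict.items():
        new_key = key.replace("mlp_fc1", "ff.net.0.proj").replace("mlp_fc2", "ff.net.2")
        if "qkv_proj" in new_key and value.ndim == 2:
            # Split the fused projection into separate q/k/v weights.
            q, k, v = torch.chunk(value, 3, dim=0)
            remapped[new_key.replace("qkv_proj", "to_q")] = q
            remapped[new_key.replace("qkv_proj", "to_k")] = k
            remapped[new_key.replace("qkv_proj", "to_v")] = v
        else:
            remapped[new_key] = value
    return remapped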

Cc: @SunMarc


lmxyy (Contributor) commented Aug 22, 2025

