Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f20aba3

Browse files
sayakpaulDN6
andauthored
[GGUF] feat: support loading diffusers format gguf checkpoints. (#11684)
* feat: support loading diffusers format gguf checkpoints. * update * update * qwen --------- Co-authored-by: DN6 <dhruv.nair@gmail.com>
1 parent ccf2c31 commit f20aba3

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

‎src/diffusers/loaders/single_file_model.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,17 @@
153153
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
154154
"default_subfolder": "transformer",
155155
},
156+
"QwenImageTransformer2DModel": {
157+
"checkpoint_mapping_fn": lambda x: x,
158+
"default_subfolder": "transformer",
159+
},
156160
}
157161

158162

163+
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
164+
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
165+
166+
159167
def _get_single_file_loadable_mapping_class(cls):
160168
diffusers_module = importlib.import_module(__name__.split(".")[0])
161169
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
@@ -381,19 +389,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
381389
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
382390
diffusers_model_config.update(model_kwargs)
383391

392+
ctx = init_empty_weights if is_accelerate_available() else nullcontext
393+
with ctx():
394+
model = cls.from_config(diffusers_model_config)
395+
384396
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
385-
diffusers_format_checkpoint = checkpoint_mapping_fn(
386-
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
387-
)
397+
398+
if _should_convert_state_dict_to_diffusers(model.state_dict(), checkpoint):
399+
diffusers_format_checkpoint = checkpoint_mapping_fn(
400+
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
401+
)
402+
else:
403+
diffusers_format_checkpoint = checkpoint
404+
388405
if not diffusers_format_checkpoint:
389406
raise SingleFileComponentError(
390407
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
391408
)
392-
393-
ctx = init_empty_weights if is_accelerate_available() else nullcontext
394-
with ctx():
395-
model = cls.from_config(diffusers_model_config)
396-
397409
# Check if `_keep_in_fp32_modules` is not None
398410
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
399411
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")

‎src/diffusers/loaders/single_file_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6161

6262
CHECKPOINT_KEY_NAMES = {
63+
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
6364
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
6465
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
6566
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",

‎tests/quantization/gguf/test_gguf.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def _check_for_gguf_linear(model):
212212

213213
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
214214
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
215+
diffusers_ckpt_path = "https://huggingface.co/sayakpaul/flux-diffusers-gguf/blob/main/model-Q4_0.gguf"
215216
torch_dtype = torch.bfloat16
216217
model_cls = FluxTransformer2DModel
217218
expected_memory_use_in_gb = 5
@@ -296,6 +297,16 @@ def test_pipeline_inference(self):
296297
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
297298
assert max_diff < 1e-4
298299

300+
def test_loading_gguf_diffusers_format(self):
301+
model = self.model_cls.from_single_file(
302+
self.diffusers_ckpt_path,
303+
subfolder="transformer",
304+
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
305+
config="black-forest-labs/FLUX.1-dev",
306+
)
307+
model.to("cuda")
308+
model(**self.get_dummy_inputs())
309+
299310

300311
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
301312
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /