Commit 11c27c8

DN6 authored and sayakpaul committed
Fix typos (#9739)
* update * update * update * update * update * update
1 parent a2591a6 commit 11c27c8

File tree

2 files changed: +95 -2 lines changed

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md

Lines changed: 20 additions & 0 deletions

````diff
@@ -313,6 +313,26 @@ image = pipe("a picture of a cat holding a sign that says hello world").images[0]
 image.save('sd3-single-file-t5-fp8.png')
 ```
 
+### Loading the single file checkpoint for the Stable Diffusion 3.5 Transformer Model
+
+```python
+import torch
+from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline
+
+transformer = SD3Transformer2DModel.from_single_file(
+    "https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo/blob/main/sd3.5_large.safetensors",
+    torch_dtype=torch.bfloat16,
+)
+pipe = StableDiffusion3Pipeline.from_pretrained(
+    "stabilityai/stable-diffusion-3.5-large",
+    transformer=transformer,
+    torch_dtype=torch.bfloat16,
+)
+pipe.enable_model_cpu_offload()
+image = pipe("a cat holding a sign that says hello world").images[0]
+image.save("sd35.png")
+```
+
 ## StableDiffusion3Pipeline
 
 [[autodoc]] StableDiffusion3Pipeline
````
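The added example loads only the transformer from a single `.safetensors` file and assembles the remaining components (text encoders, VAE, scheduler) from the diffusers-format repository. For comparison, a minimal sketch (not part of this commit) of the one-step route through the pipeline-level `from_single_file`; it assumes the checkpoint bundles everything the pipeline needs, otherwise the text encoders must be passed explicitly:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Sketch only: one-step single-file loading. Assumes the checkpoint at this
# URL carries the components the pipeline needs; if not, pass text_encoder /
# text_encoder_2 / text_encoder_3 explicitly.
pipe = StableDiffusion3Pipeline.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
image = pipe("a cat holding a sign that says hello world").images[0]
```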

src/diffusers/loaders/single_file_utils.py

Lines changed: 75 additions & 2 deletions

```diff
@@ -75,6 +75,7 @@
     "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
     "stable_cascade_stage_c": "clip_txt_mapper.weight",
     "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+    "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
     "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
     "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
     "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -113,6 +114,9 @@
     "sd3": {
         "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
     },
+    "sd35_large": {
+        "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+    },
     "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
     "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
     "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
```
```diff
@@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "stable_cascade_stage_b"
 
-    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
         model_type = "sd3"
 
+    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
+        model_type = "sd35_large"
+
     elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
         if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
             model_type = "animatediff_scribble"
```
```diff
@@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim):
     return new_weight
 
 
+def get_attn2_layers(state_dict):
+    attn2_layers = []
+    for key in state_dict.keys():
+        if "attn2." in key:
+            # Extract the layer number from the key
+            layer_num = int(key.split(".")[1])
+            attn2_layers.append(layer_num)
+
+    return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+    caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+    return caption_projection_dim
+
+
 def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
     converted_state_dict = {}
     keys = list(checkpoint.keys())
```
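The two new helpers replace hard-coded assumptions: `get_attn2_layers` collects the indices of joint blocks that carry a second, sample-only attention stream (keys containing `attn2.`), and `get_caption_projection_dim` reads the transformer width off `context_embedder.weight` rather than pinning it to 1536. A toy usage sketch with made-up shapes:

```python
import torch

# Hypothetical mini state dict (the "model.diffusion_model." prefix already
# stripped); shapes here are invented for illustration.
state_dict = {
    "context_embedder.weight": torch.zeros(2432, 4096),  # [caption_projection_dim, text_embed_dim]
    "joint_blocks.0.x_block.attn2.qkv.weight": torch.zeros(3 * 2432, 2432),
    "joint_blocks.1.x_block.attn2.qkv.weight": torch.zeros(3 * 2432, 2432),
}

print(get_attn2_layers(state_dict))            # (0, 1)
print(get_caption_projection_dim(state_dict))  # 2432
```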
```diff
@@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
     num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1  # noqa: C401
-    caption_projection_dim = 1536
+    dual_attention_layers = get_attn2_layers(checkpoint)
+
+    caption_projection_dim = get_caption_projection_dim(checkpoint)
+    has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
 
     # Positional and patch embeddings.
     converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
```
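The converter now probes the checkpoint for three properties instead of assuming SD3-medium defaults: which blocks are dual-attention blocks, how wide the caption projection is, and whether queries and keys carry RMS-norm weights (`ln_q`). A sketch of the same probing folded into one helper, built on the functions added above; treat the returned field names as assumptions, not the exact diffusers config keys:

```python
def sketch_sd3_config(checkpoint):
    # Sketch only: derive hyperparameters from a de-prefixed checkpoint.
    # The dict keys below are assumptions, not read from this diff.
    num_layers = max(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k) + 1
    return {
        "num_layers": num_layers,
        "caption_projection_dim": get_caption_projection_dim(checkpoint),
        "dual_attention_layers": get_attn2_layers(checkpoint),  # () when no attn2 keys exist
        "qk_norm": "rms_norm" if any("ln_q" in k for k in checkpoint) else None,
    }
```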
```diff
@@ -1735,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
 
+        # qk norm
+        if has_qk_norm:
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+            )
+
         # output projections.
         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.attn.proj.weight"
```
```diff
@@ -1750,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
             f"joint_blocks.{i}.context_block.attn.proj.bias"
         )
 
+        if i in dual_attention_layers:
+            # Q, K, V
+            sample_q2, sample_k2, sample_v2 = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+            )
+            sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+                checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+            # qk norm
+            if has_qk_norm:
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+                )
+                converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+                    f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+                )
+
+            # output projections.
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.weight"
+            )
+            converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+                f"joint_blocks.{i}.x_block.attn2.proj.bias"
+            )
+
         # norms.
         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
```
