Commit 8d1de40

Authored by linoytsaban, sayakpaul, github-actions[bot], and linoy
[Wan 2.2 LoRA] add support for 2nd transformer lora loading + wan 2.2 lightx2v lora (#12074)
* add alpha
* load into 2nd transformer
* Update src/diffusers/loaders/lora_conversion_utils.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update src/diffusers/loaders/lora_conversion_utils.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* pr comments
* pr comments
* pr comments
* fix
* fix
* Apply style fixes
* fix copies
* fix
* fix copies
* Update src/diffusers/loaders/lora_pipeline.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* revert change
* revert change
* fix copies
* up
* fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: linoy <linoy@hf.co>
1 parent 8cc528c commit 8d1de40


4 files changed: +150 −55 lines changed


docs/source/en/api/pipelines/wan.md

Lines changed: 2 additions & 0 deletions
@@ -333,6 +333,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
 
 - Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involved. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.
 
+- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. Set `load_into_transformer_2=True` to load LoRAs into the second denoiser. Refer to [this example](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) and [this one](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) to learn more.
+
 ## WanPipeline
 
 [[autodoc]] WanPipeline
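For illustration, here is a minimal usage sketch of the new `load_into_transformer_2` flag (not part of the commit itself). The checkpoint id and LoRA filename are assumptions; the same LoRA is loaded once per denoiser.

```python
# Hedged sketch, not from this commit: load a LightX2V-style LoRA into both Wan 2.2 denoisers.
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers",  # assumed Wan 2.2 checkpoint with transformer + transformer_2
    torch_dtype=torch.bfloat16,
)

# Default behavior: the LoRA is loaded into the first denoiser (pipe.transformer).
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="lightx2v_lora.safetensors",  # hypothetical filename, for illustration only
    adapter_name="lightx2v",
)

# New in this PR: load the same LoRA into the second denoiser (pipe.transformer_2).
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="lightx2v_lora.safetensors",  # hypothetical filename, for illustration only
    adapter_name="lightx2v_2",
    load_into_transformer_2=True,
)
```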

src/diffusers/loaders/lora_base.py

Lines changed: 5 additions & 1 deletion
@@ -754,7 +754,11 @@ def set_adapters(
         # Decompose weights into weights for denoiser and text encoders.
         _component_adapter_weights = {}
         for component in self._lora_loadable_modules:
-            model = getattr(self, component)
+            model = getattr(self, component, None)
+            # To guard for cases like Wan. In Wan 2.1 and WanVACE, we have a single denoiser,
+            # whereas in Wan 2.2, we have two denoisers.
+            if model is None:
+                continue
 
             for adapter_name, weights in zip(adapter_names, adapter_weights):
                 if isinstance(weights, dict):
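To see why the `getattr(..., None)` guard matters, here is a small standalone sketch (my own illustration, not diffusers code): a class that lists both denoisers in `_lora_loadable_modules` but, like Wan 2.1, only has one of them.

```python
# Standalone illustration of the guard added above: components listed in
# `_lora_loadable_modules` that a given pipeline does not have are skipped
# instead of raising AttributeError.
class FakeWanPipeline:
    _lora_loadable_modules = ["transformer", "transformer_2"]

    def __init__(self, has_second_denoiser: bool):
        self.transformer = object()  # stand-in for the first denoiser
        if has_second_denoiser:
            self.transformer_2 = object()  # stand-in for the second denoiser (Wan 2.2 only)

    def loadable_components(self):
        found = []
        for component in self._lora_loadable_modules:
            model = getattr(self, component, None)
            if model is None:  # e.g. Wan 2.1 / WanVACE: no transformer_2
                continue
            found.append(component)
        return found


print(FakeWanPipeline(has_second_denoiser=False).loadable_components())  # ['transformer']
print(FakeWanPipeline(has_second_denoiser=True).loadable_components())   # ['transformer', 'transformer_2']
```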

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 84 additions & 35 deletions
@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
     )
 
+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
     for key in list(original_state_dict.keys()):
         if key.endswith((".diff", ".diff_b")) and "norm" in key:
             # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"
 
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)
 
             original_key = f"blocks.{i}.self_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up
 
                 original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                 converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up
 
             original_key = f"blocks.{i}.{o}.diff_b"
             converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"

src/diffusers/loaders/lora_pipeline.py

Lines changed: 59 additions & 19 deletions
@@ -5065,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and [`WanImageToVideoPipeline`].
     """
 
-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
     transformer_name = TRANSFORMER_NAME
 
     @classmethod
@@ -5270,15 +5270,35 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2. "
+                    "Note that Wan 2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
@@ -5668,15 +5688,35 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2. "
+                    "Note that Wan 2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
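As a usage note (a sketch, not from the commit): on a Wan 2.1 pipeline, which has no `transformer_2`, passing the new kwarg raises the `AttributeError` added above, while omitting it preserves the previous behavior. The checkpoint and LoRA ids below are assumptions.

```python
# Hedged sketch of the failure mode guarded against above; ids are illustrative assumptions.
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16)

try:
    # Wan 2.1 pipelines expose only pipe.transformer, so this raises.
    pipe.load_lora_weights("some-user/some-wan-lora", load_into_transformer_2=True)  # hypothetical LoRA repo
except AttributeError as err:
    print(err)

# Without the kwarg, behavior is unchanged: the LoRA is loaded into pipe.transformer.
pipe.load_lora_weights("some-user/some-wan-lora")  # hypothetical LoRA repo
```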
