diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
index e46aa55ad82a..b9c5990f2435 100644
--- a/docs/source/en/api/pipelines/wan.md
+++ b/docs/source/en/api/pipelines/wan.md
@@ -333,6 +333,8 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip

 - Wan 2.1 and 2.2 support using [LightX2V LoRAs](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Lightx2v) to speed up inference. Using them on Wan 2.2 is slightly more involed. Refer to [this code snippet](https://github.com/huggingface/diffusers/pull/12040#issuecomment-3144185272) to learn more.

+- Wan 2.2 has two denoisers. By default, LoRAs are only loaded into the first denoiser. One can set `load_into_transformer_2=True` to load LoRAs into the second denoiser. Refer to [this example](https://github.com/huggingface/diffusers/pull/12074#issue-3292620048) and [this one](https://github.com/huggingface/diffusers/pull/12074#issuecomment-3155896144) to learn more.
+
 ## WanPipeline

 [[autodoc]] WanPipeline
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 3089086d5478..d18c82df4f62 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -754,7 +754,11 @@ def set_adapters(
         # Decompose weights into weights for denoiser and text encoders.
         _component_adapter_weights = {}
         for component in self._lora_loadable_modules:
-            model = getattr(self, component)
+            model = getattr(self, component, None)
+            # Guard against cases like Wan: Wan 2.1 and WanVACE have a single denoiser,
+            # whereas Wan 2.2 has two denoisers.
+            if model is None:
+                continue

             for adapter_name, weights in zip(adapter_names, adapter_weights):
                 if isinstance(weights, dict):
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index e7ed379e9cff..d1692bd61ba5 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1833,6 +1833,17 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
     )

+    def get_alpha_scales(down_weight, alpha_key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(alpha_key).item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in the forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
     for key in list(original_state_dict.keys()):
         if key.endswith((".diff", ".diff_b")) and "norm" in key:
             # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
@@ -1852,15 +1863,26 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.self_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn1.{c}.lora_A.weight"

-            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            original_key_B = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn1.{c}.lora_B.weight"
+
+            if has_alpha:
+                down_weight = original_state_dict.pop(original_key_A)
+                up_weight = original_state_dict.pop(original_key_B)
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] = down_weight * scale_down
+                converted_state_dict[converted_key_B] = up_weight * scale_up
+
+            else:
+                if original_key_A in original_state_dict:
+                    converted_state_dict[converted_key_A] = original_state_dict.pop(original_key_A)
+                if original_key_B in original_state_dict:
+                    converted_state_dict[converted_key_B] = original_state_dict.pop(original_key_B)

             original_key = f"blocks.{i}.self_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
@@ -1869,15 +1891,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):

         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up

             original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
             converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1886,15 +1917,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):

         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
-                if original_key in original_state_dict:
-                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+                alpha_key = f"blocks.{i}.cross_attn.{o}.alpha"
+                has_alpha = alpha_key in original_state_dict
+                original_key_A = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key_A = f"blocks.{i}.attn2.{c}.lora_A.weight"
+
+                original_key_B = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key_B = f"blocks.{i}.attn2.{c}.lora_B.weight"
+
+                if original_key_A in original_state_dict:
+                    down_weight = original_state_dict.pop(original_key_A)
+                    converted_state_dict[converted_key_A] = down_weight
+                if original_key_B in original_state_dict:
+                    up_weight = original_state_dict.pop(original_key_B)
+                    converted_state_dict[converted_key_B] = up_weight
+                if has_alpha:
+                    scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                    converted_state_dict[converted_key_A] *= scale_down
+                    converted_state_dict[converted_key_B] *= scale_up

                 original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
                 converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
@@ -1903,15 +1943,24 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):

         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
-            original_key = f"blocks.{i}.{o}.{lora_down_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_A.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
-
-            original_key = f"blocks.{i}.{o}.{lora_up_key}.weight"
-            converted_key = f"blocks.{i}.ffn.{c}.lora_B.weight"
-            if original_key in original_state_dict:
-                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+            alpha_key = f"blocks.{i}.{o}.alpha"
+            has_alpha = alpha_key in original_state_dict
+            original_key_A = f"blocks.{i}.{o}.{lora_down_key}.weight"
+            converted_key_A = f"blocks.{i}.ffn.{c}.lora_A.weight"
+
+            original_key_B = f"blocks.{i}.{o}.{lora_up_key}.weight"
+            converted_key_B = f"blocks.{i}.ffn.{c}.lora_B.weight"
+
+            if original_key_A in original_state_dict:
+                down_weight = original_state_dict.pop(original_key_A)
+                converted_state_dict[converted_key_A] = down_weight
+            if original_key_B in original_state_dict:
+                up_weight = original_state_dict.pop(original_key_B)
+                converted_state_dict[converted_key_B] = up_weight
+            if has_alpha:
+                scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
+                converted_state_dict[converted_key_A] *= scale_down
+                converted_state_dict[converted_key_B] *= scale_up

             original_key = f"blocks.{i}.{o}.diff_b"
             converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 97e5647360ee..572ace472f1d 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -5065,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
     """

-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
     transformer_name = TRANSFORMER_NAME

     @classmethod
@@ -5270,15 +5270,35 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2. "
+                    "Note that Wan2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
@@ -5668,15 +5688,35 @@ def load_lora_weights(
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2. "
+                    "Note that Wan2.1 models do not have a transformer_2 component. "
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel

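The `get_alpha_scales` helper added in `lora_conversion_utils.py` folds the non-diffusers checkpoint's `alpha / rank` factor into the converted `lora_A`/`lora_B` weights, splitting the factor between the down and up projections rather than applying it all to one of them. The standalone sketch below reproduces that arithmetic with hypothetical shapes and alpha, and takes `alpha` directly instead of popping it from a state dict.

```python
import torch


def get_alpha_scales(down_weight: torch.Tensor, alpha: float):
    # Mirrors the helper in the diff: fold `alpha / rank` into the weight pair,
    # splitting the factor between lora_A and lora_B.
    rank = down_weight.shape[0]
    scale = alpha / rank
    scale_down = scale
    scale_up = 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up


# Hypothetical rank-16 LoRA pair with alpha=8, i.e. an overall scale of 0.5.
rank, alpha = 16, 8.0
down = torch.randn(rank, 1024)  # lora_A / lora_down: (rank, in_features)
up = torch.randn(1024, rank)    # lora_B / lora_up:   (out_features, rank)

scale_down, scale_up = get_alpha_scales(down, alpha)
# The rescaled pair produces the same delta-W as scaling the product by alpha / rank.
assert torch.allclose(
    (up * scale_up) @ (down * scale_down), (alpha / rank) * (up @ down), atol=1e-4
)
```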