Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d45199a

Browse files
dg845github-actions[bot]a-r-r-o-w
authored
Implement Frequency-Decoupled Guidance (FDG) as a Guider (#11976)
* Initial commit implementing frequency-decoupled guidance (FDG) as a guider * Update FrequencyDecoupledGuidance docstring to describe FDG * Update project so that it accepts any number of non-batch dims * Change guidance_scale and other params to accept a list of params for each freq level * Add comment with Laplacian pyramid shapes * Add function to import_utils to check if the kornia package is available * Only import from kornia if package is available * Fix bug: use pred_cond/uncond in freq space rather than data space * Allow guidance rescaling to be done in data space or frequency space (speculative) * Add kornia install instructions to kornia import error message * Add config to control whether operations are upcast to fp64 * Add parallel_weights recommended values to docstring * Apply style fixes * make fix-copies --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Aryan <aryan@huggingface.co>
1 parent 0611631 commit d45199a

File tree

6 files changed

+352
-0
lines changed

6 files changed

+352
-0
lines changed

‎src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
"AutoGuidance",
140140
"ClassifierFreeGuidance",
141141
"ClassifierFreeZeroStarGuidance",
142+
"FrequencyDecoupledGuidance",
142143
"PerturbedAttentionGuidance",
143144
"SkipLayerGuidance",
144145
"SmoothedEnergyGuidance",
@@ -804,6 +805,7 @@
804805
AutoGuidance,
805806
ClassifierFreeGuidance,
806807
ClassifierFreeZeroStarGuidance,
808+
FrequencyDecoupledGuidance,
807809
PerturbedAttentionGuidance,
808810
SkipLayerGuidance,
809811
SmoothedEnergyGuidance,

‎src/diffusers/guiders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .auto_guidance import AutoGuidance
2323
from .classifier_free_guidance import ClassifierFreeGuidance
2424
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
25+
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
2526
from .perturbed_attention_guidance import PerturbedAttentionGuidance
2627
from .skip_layer_guidance import SkipLayerGuidance
2728
from .smoothed_energy_guidance import SmoothedEnergyGuidance
@@ -32,6 +33,7 @@
3233
AutoGuidance,
3334
ClassifierFreeGuidance,
3435
ClassifierFreeZeroStarGuidance,
36+
FrequencyDecoupledGuidance,
3537
PerturbedAttentionGuidance,
3638
SkipLayerGuidance,
3739
SmoothedEnergyGuidance,
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
17+
18+
import torch
19+
20+
from ..configuration_utils import register_to_config
21+
from ..utils import is_kornia_available
22+
from .guider_utils import BaseGuidance, rescale_noise_cfg
23+
24+
25+
if TYPE_CHECKING:
26+
from ..modular_pipelines.modular_pipeline import BlockState
27+
28+
29+
_CAN_USE_KORNIA = is_kornia_available()
30+
31+
32+
if _CAN_USE_KORNIA:
33+
from kornia.geometry import pyrup as upsample_and_blur_func
34+
from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
35+
else:
36+
upsample_and_blur_func = None
37+
build_laplacian_pyramid_func = None
38+
39+
40+
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
41+
"""
42+
Project vector v0 onto vector v1, returning the parallel and orthogonal components of v0. Implementation from paper
43+
(Algorithm 2).
44+
"""
45+
# v0 shape: [B, ...]
46+
# v1 shape: [B, ...]
47+
# Assume first dim is a batch dim and all other dims are channel or "spatial" dims
48+
all_dims_but_first = list(range(1, len(v0.shape)))
49+
if upcast_to_double:
50+
dtype = v0.dtype
51+
v0, v1 = v0.double(), v1.double()
52+
v1 = torch.nn.functional.normalize(v1, dim=all_dims_but_first)
53+
v0_parallel = (v0 * v1).sum(dim=all_dims_but_first, keepdim=True) * v1
54+
v0_orthogonal = v0 - v0_parallel
55+
if upcast_to_double:
56+
v0_parallel = v0_parallel.to(dtype)
57+
v0_orthogonal = v0_orthogonal.to(dtype)
58+
return v0_parallel, v0_orthogonal
59+
60+
61+
def build_image_from_pyramid(pyramid: List[torch.Tensor]) -> torch.Tensor:
62+
"""
63+
Recovers the data space latents from the Laplacian pyramid frequency space. Implementation from the paper
64+
(Algorihtm 2).
65+
"""
66+
# pyramid shapes: [[B, C, H, W], [B, C, H/2, W/2], ...]
67+
img = pyramid[-1]
68+
for i in range(len(pyramid) - 2, -1, -1):
69+
img = upsample_and_blur_func(img) + pyramid[i]
70+
return img
71+
72+
73+
class FrequencyDecoupledGuidance(BaseGuidance):
74+
"""
75+
Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713
76+
77+
FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
78+
quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
79+
conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
80+
how CFG works, you can check out the CFG guider.)
81+
82+
FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
83+
using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
84+
separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
85+
frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
86+
to form the final FDG prediction.
87+
88+
For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
89+
diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
90+
sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
91+
the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
92+
example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).
93+
94+
As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
95+
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
96+
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
97+
98+
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
99+
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
100+
101+
Args:
102+
guidance_scales (`List[float]`, defaults to `[10.0, 5.0]`):
103+
The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
104+
frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
105+
values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
106+
image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
107+
lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
108+
descending order).
109+
guidance_rescale (`float` or `List[float]`, defaults to `0.0`):
110+
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
111+
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
112+
Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
113+
`guidance_scales`.
114+
parallel_weights (`float` or `List[float]`, *optional*):
115+
Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
116+
set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
117+
(that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
118+
recommended. If a list is supplied, it should be the same length as `guidance_scales`.
119+
use_original_formulation (`bool`, defaults to `False`):
120+
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
121+
we use the diffusers-native implementation that has been in the codebase for a long time. See
122+
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
123+
start (`float` or `List[float]`, defaults to `0.0`):
124+
The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
125+
should be the same length as `guidance_scales`.
126+
stop (`float` or `List[float]`, defaults to `1.0`):
127+
The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
128+
should be the same length as `guidance_scales`.
129+
guidance_rescale_space (`str`, defaults to `"data"`):
130+
Whether to performance guidance rescaling in `"data"` space (after the full FDG update in data space) or in
131+
`"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
132+
speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
133+
will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
134+
upcast_to_double (`bool`, defaults to `True`):
135+
Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
136+
float64 when performing guidance. This may result in better performance at the cost of increased runtime.
137+
"""
138+
139+
_input_predictions = ["pred_cond", "pred_uncond"]
140+
141+
@register_to_config
142+
def __init__(
143+
self,
144+
guidance_scales: Union[List[float], Tuple[float]] = [10.0, 5.0],
145+
guidance_rescale: Union[float, List[float], Tuple[float]] = 0.0,
146+
parallel_weights: Optional[Union[float, List[float], Tuple[float]]] = None,
147+
use_original_formulation: bool = False,
148+
start: Union[float, List[float], Tuple[float]] = 0.0,
149+
stop: Union[float, List[float], Tuple[float]] = 1.0,
150+
guidance_rescale_space: str = "data",
151+
upcast_to_double: bool = True,
152+
):
153+
if not _CAN_USE_KORNIA:
154+
raise ImportError(
155+
"The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
156+
"it depends is not available in the current environment. You can install `kornia` with `pip install "
157+
"kornia`."
158+
)
159+
160+
# Set start to earliest start for any freq component and stop to latest stop for any freq component
161+
min_start = start if isinstance(start, float) else min(start)
162+
max_stop = stop if isinstance(stop, float) else max(stop)
163+
super().__init__(min_start, max_stop)
164+
165+
self.guidance_scales = guidance_scales
166+
self.levels = len(guidance_scales)
167+
168+
if isinstance(guidance_rescale, float):
169+
self.guidance_rescale = [guidance_rescale] * self.levels
170+
elif len(guidance_rescale) == self.levels:
171+
self.guidance_rescale = guidance_rescale
172+
else:
173+
raise ValueError(
174+
f"`guidance_rescale` has length {len(guidance_rescale)} but should have the same length as "
175+
f"`guidance_scales` ({len(self.guidance_scales)})"
176+
)
177+
# Whether to perform guidance rescaling in frequency space (right after the CFG update) or data space (after
178+
# transforming from frequency space back to data space)
179+
if guidance_rescale_space not in ["data", "freq"]:
180+
raise ValueError(
181+
f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
182+
)
183+
self.guidance_rescale_space = guidance_rescale_space
184+
185+
if parallel_weights is None:
186+
# Use normal CFG shift (equal weights for parallel and orthogonal components)
187+
self.parallel_weights = [1.0] * self.levels
188+
elif isinstance(parallel_weights, float):
189+
self.parallel_weights = [parallel_weights] * self.levels
190+
elif len(parallel_weights) == self.levels:
191+
self.parallel_weights = parallel_weights
192+
else:
193+
raise ValueError(
194+
f"`parallel_weights` has length {len(parallel_weights)} but should have the same length as "
195+
f"`guidance_scales` ({len(self.guidance_scales)})"
196+
)
197+
198+
self.use_original_formulation = use_original_formulation
199+
self.upcast_to_double = upcast_to_double
200+
201+
if isinstance(start, float):
202+
self.guidance_start = [start] * self.levels
203+
elif len(start) == self.levels:
204+
self.guidance_start = start
205+
else:
206+
raise ValueError(
207+
f"`start` has length {len(start)} but should have the same length as `guidance_scales` "
208+
f"({len(self.guidance_scales)})"
209+
)
210+
if isinstance(stop, float):
211+
self.guidance_stop = [stop] * self.levels
212+
elif len(stop) == self.levels:
213+
self.guidance_stop = stop
214+
else:
215+
raise ValueError(
216+
f"`stop` has length {len(stop)} but should have the same length as `guidance_scales` "
217+
f"({len(self.guidance_scales)})"
218+
)
219+
220+
def prepare_inputs(
221+
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
222+
) -> List["BlockState"]:
223+
if input_fields is None:
224+
input_fields = self._input_fields
225+
226+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
227+
data_batches = []
228+
for i in range(self.num_conditions):
229+
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
230+
data_batches.append(data_batch)
231+
return data_batches
232+
233+
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
234+
pred = None
235+
236+
if not self._is_fdg_enabled():
237+
pred = pred_cond
238+
else:
239+
# Apply the frequency transform (e.g. Laplacian pyramid) to the conditional and unconditional predictions.
240+
pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
241+
pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)
242+
243+
# From high frequencies to low frequencies, following the paper implementation
244+
pred_guided_pyramid = []
245+
parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
246+
for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
247+
if self._is_fdg_enabled_for_level(level):
248+
# Get the cond/uncond preds (in freq space) at the current frequency level
249+
pred_cond_freq = pred_cond_pyramid[level]
250+
pred_uncond_freq = pred_uncond_pyramid[level]
251+
252+
shift = pred_cond_freq - pred_uncond_freq
253+
254+
# Apply parallel weights, if used (1.0 corresponds to using the normal CFG shift)
255+
if not math.isclose(parallel_weight, 1.0):
256+
shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
257+
shift = parallel_weight * shift_parallel + shift_orthogonal
258+
259+
# Apply CFG update for the current frequency level
260+
pred = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
261+
pred = pred + guidance_scale * shift
262+
263+
if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
264+
pred = rescale_noise_cfg(pred, pred_cond_freq, guidance_rescale)
265+
266+
# Add the current FDG guided level to the FDG prediction pyramid
267+
pred_guided_pyramid.append(pred)
268+
else:
269+
# Add the current pred_cond_pyramid level as the "non-FDG" prediction
270+
pred_guided_pyramid.append(pred_cond_freq)
271+
272+
# Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
273+
pred = build_image_from_pyramid(pred_guided_pyramid)
274+
275+
# If rescaling in data space, use the first elem of self.guidance_rescale as the "global" rescale value
276+
# across all freq levels
277+
if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
278+
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])
279+
280+
return pred, {}
281+
282+
@property
283+
def is_conditional(self) -> bool:
284+
return self._count_prepared == 1
285+
286+
@property
287+
def num_conditions(self) -> int:
288+
num_conditions = 1
289+
if self._is_fdg_enabled():
290+
num_conditions += 1
291+
return num_conditions
292+
293+
def _is_fdg_enabled(self) -> bool:
294+
if not self._enabled:
295+
return False
296+
297+
is_within_range = True
298+
if self._num_inference_steps is not None:
299+
skip_start_step = int(self._start * self._num_inference_steps)
300+
skip_stop_step = int(self._stop * self._num_inference_steps)
301+
is_within_range = skip_start_step <= self._step < skip_stop_step
302+
303+
is_close = False
304+
if self.use_original_formulation:
305+
is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
306+
else:
307+
is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)
308+
309+
return is_within_range and not is_close
310+
311+
def _is_fdg_enabled_for_level(self, level: int) -> bool:
312+
if not self._enabled:
313+
return False
314+
315+
is_within_range = True
316+
if self._num_inference_steps is not None:
317+
skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
318+
skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
319+
is_within_range = skip_start_step <= self._step < skip_stop_step
320+
321+
is_close = False
322+
if self.use_original_formulation:
323+
is_close = math.isclose(self.guidance_scales[level], 0.0)
324+
else:
325+
is_close = math.isclose(self.guidance_scales[level], 1.0)
326+
327+
return is_within_range and not is_close

‎src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
is_k_diffusion_available,
8383
is_k_diffusion_version,
8484
is_kernels_available,
85+
is_kornia_available,
8586
is_librosa_available,
8687
is_matplotlib_available,
8788
is_nltk_available,

‎src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs):
6262
requires_backends(cls, ["torch"])
6363

6464

65+
class FrequencyDecoupledGuidance(metaclass=DummyObject):
66+
_backends = ["torch"]
67+
68+
def __init__(self, *args, **kwargs):
69+
requires_backends(self, ["torch"])
70+
71+
@classmethod
72+
def from_config(cls, *args, **kwargs):
73+
requires_backends(cls, ["torch"])
74+
75+
@classmethod
76+
def from_pretrained(cls, *args, **kwargs):
77+
requires_backends(cls, ["torch"])
78+
79+
6580
class PerturbedAttentionGuidance(metaclass=DummyObject):
6681
_backends = ["torch"]
6782

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /