Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8f80dda

Browse files
sayakpaulyiyixuxu
andauthored
[tests] add tests for flux modular (t2i, i2i, kontext) (#12566)
* start flux modular tests. * up * add kontext * up * up * up * Update src/diffusers/modular_pipelines/flux/denoise.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * up * up --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent cdbf0ad commit 8f80dda

File tree

8 files changed

+152
-27
lines changed

8 files changed

+152
-27
lines changed

‎src/diffusers/modular_pipelines/components_manager.py‎

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,11 @@ def __call__(self, hooks, model_id, model, execution_device):
164164

165165
device_type = execution_device.type
166166
device_module = getattr(torch, device_type, torch.cuda)
167-
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
167+
try:
168+
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
169+
except AttributeError:
170+
raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")
171+
168172
mem_on_device = mem_on_device - self.memory_reserve_margin
169173
if current_module_size < mem_on_device:
170174
return []
@@ -699,6 +703,8 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
699703
if not is_accelerate_available():
700704
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
701705

706+
# TODO: add a warning if mem_get_info isn't available on `device`.
707+
702708
for name, component in self.components.items():
703709
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
704710
remove_hook_from_module(component, recurse=True)

‎src/diffusers/modular_pipelines/flux/before_denoise.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
598598
and getattr(block_state, "image_width", None) is not None
599599
):
600600
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
601-
image_latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
601+
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
602602
img_ids = FluxPipeline._prepare_latent_image_ids(
603603
None, image_latent_height // 2, image_latent_width // 2, device, dtype
604604
)

‎src/diffusers/modular_pipelines/flux/denoise.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
5959
),
6060
InputParam(
6161
"guidance",
62-
required=True,
62+
required=False,
6363
type_hint=torch.Tensor,
6464
description="Guidance scale as a tensor",
6565
),
@@ -141,7 +141,7 @@ def inputs(self) -> List[Tuple[str, Any]]:
141141
),
142142
InputParam(
143143
"guidance",
144-
required=True,
144+
required=False,
145145
type_hint=torch.Tensor,
146146
description="Guidance scale as a tensor",
147147
),

‎src/diffusers/modular_pipelines/flux/encoders.py‎

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def expected_components(self) -> List[ComponentSpec]:
9595
ComponentSpec(
9696
"image_processor",
9797
VaeImageProcessor,
98-
config=FrozenDict({"vae_scale_factor": 16}),
98+
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}),
9999
default_creation_method="from_config",
100100
),
101101
]
@@ -143,10 +143,6 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState):
143143
class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
144144
model_name = "flux-kontext"
145145

146-
def __init__(self, _auto_resize=True):
147-
self._auto_resize = _auto_resize
148-
super().__init__()
149-
150146
@property
151147
def description(self) -> str:
152148
return (
@@ -167,7 +163,7 @@ def expected_components(self) -> List[ComponentSpec]:
167163

168164
@property
169165
def inputs(self) -> List[InputParam]:
170-
return [InputParam("image")]
166+
return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]
171167

172168
@property
173169
def intermediate_outputs(self) -> List[OutputParam]:
@@ -195,7 +191,8 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState):
195191
img = images[0]
196192
image_height, image_width = components.image_processor.get_default_height_width(img)
197193
aspect_ratio = image_width / image_height
198-
if self._auto_resize:
194+
_auto_resize = block_state._auto_resize
195+
if _auto_resize:
199196
# Kontext is trained on specific resolutions, using one of them is recommended
200197
_, image_width, image_height = min(
201198
(abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS

‎src/diffusers/modular_pipelines/flux/inputs.py‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip
112112
block_state.prompt_embeds = block_state.prompt_embeds.view(
113113
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
114114
)
115+
pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
116+
block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
117+
block_state.batch_size * block_state.num_images_per_prompt, -1
118+
)
115119
self.set_block_state(state, block_state)
116120

117121
return components, state

‎tests/modular_pipelines/flux/__init__.py‎

Whitespace-only changes.
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import random
17+
import tempfile
18+
import unittest
19+
20+
import numpy as np
21+
import PIL
22+
import torch
23+
24+
from diffusers.image_processor import VaeImageProcessor
25+
from diffusers.modular_pipelines import (
26+
FluxAutoBlocks,
27+
FluxKontextAutoBlocks,
28+
FluxKontextModularPipeline,
29+
FluxModularPipeline,
30+
ModularPipeline,
31+
)
32+
33+
from ...testing_utils import floats_tensor, torch_device
34+
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
35+
36+
37+
class FluxModularTests:
38+
pipeline_class = FluxModularPipeline
39+
pipeline_blocks_class = FluxAutoBlocks
40+
repo = "hf-internal-testing/tiny-flux-modular"
41+
42+
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
43+
pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
44+
pipeline.load_components(torch_dtype=torch_dtype)
45+
return pipeline
46+
47+
def get_dummy_inputs(self, device, seed=0):
48+
if str(device).startswith("mps"):
49+
generator = torch.manual_seed(seed)
50+
else:
51+
generator = torch.Generator(device=device).manual_seed(seed)
52+
inputs = {
53+
"prompt": "A painting of a squirrel eating a burger",
54+
"generator": generator,
55+
"num_inference_steps": 2,
56+
"guidance_scale": 5.0,
57+
"height": 8,
58+
"width": 8,
59+
"max_sequence_length": 48,
60+
"output_type": "np",
61+
}
62+
return inputs
63+
64+
65+
class FluxModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
66+
params = frozenset(["prompt", "height", "width", "guidance_scale"])
67+
batch_params = frozenset(["prompt"])
68+
69+
70+
class FluxImg2ImgModularPipelineFastTests(FluxModularTests, ModularPipelineTesterMixin, unittest.TestCase):
71+
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
72+
batch_params = frozenset(["prompt", "image"])
73+
74+
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
75+
pipeline = super().get_pipeline(components_manager, torch_dtype)
76+
# Override `vae_scale_factor` here as currently, `image_processor` is initialized with
77+
# fixed constants instead of
78+
# https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
79+
pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
80+
return pipeline
81+
82+
def get_dummy_inputs(self, device, seed=0):
83+
inputs = super().get_dummy_inputs(device, seed)
84+
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
85+
image = image / 2 + 0.5
86+
inputs["image"] = image
87+
inputs["strength"] = 0.8
88+
inputs["height"] = 8
89+
inputs["width"] = 8
90+
return inputs
91+
92+
def test_save_from_pretrained(self):
93+
pipes = []
94+
base_pipe = self.get_pipeline().to(torch_device)
95+
pipes.append(base_pipe)
96+
97+
with tempfile.TemporaryDirectory() as tmpdirname:
98+
base_pipe.save_pretrained(tmpdirname)
99+
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
100+
pipe.load_components(torch_dtype=torch.float32)
101+
pipe.to(torch_device)
102+
pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
103+
104+
pipes.append(pipe)
105+
106+
image_slices = []
107+
for pipe in pipes:
108+
inputs = self.get_dummy_inputs(torch_device)
109+
image = pipe(**inputs, output="images")
110+
111+
image_slices.append(image[0, -3:, -3:, -1].flatten())
112+
113+
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
114+
115+
116+
class FluxKontextModularPipelineFastTests(FluxImg2ImgModularPipelineFastTests):
117+
pipeline_class = FluxKontextModularPipeline
118+
pipeline_blocks_class = FluxKontextAutoBlocks
119+
repo = "hf-internal-testing/tiny-flux-kontext-pipe"
120+
121+
def get_dummy_inputs(self, device, seed=0):
122+
inputs = super().get_dummy_inputs(device, seed)
123+
image = PIL.Image.new("RGB", (32, 32), 0)
124+
_ = inputs.pop("strength")
125+
inputs["image"] = image
126+
inputs["height"] = 8
127+
inputs["width"] = 8
128+
inputs["max_area"] = 8 * 8
129+
inputs["_auto_resize"] = False
130+
return inputs

‎tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py‎

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,12 @@
2121
import torch
2222
from PIL import Image
2323

24-
from diffusers import (
25-
ClassifierFreeGuidance,
26-
StableDiffusionXLAutoBlocks,
27-
StableDiffusionXLModularPipeline,
28-
)
24+
from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
2925
from diffusers.loaders import ModularIPAdapterMixin
3026

31-
from ...models.unets.test_models_unet_2d_condition import (
32-
create_ip_adapter_state_dict,
33-
)
34-
from ...testing_utils import (
35-
enable_full_determinism,
36-
floats_tensor,
37-
torch_device,
38-
)
39-
from ..test_modular_pipelines_common import (
40-
ModularPipelineTesterMixin,
41-
)
27+
from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
28+
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
29+
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
4230

4331

4432
enable_full_determinism()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /