Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 814d710

Browse files
[tests] cache non lora pipeline outputs. (#12298)
* cache non lora pipeline outputs. * up * up * up * up * Revert "up" This reverts commit 772c32e. * up * Revert "up" This reverts commit cca03df. * up * up * add . * up * up * up * up * up * up
1 parent cc5b31f commit 814d710

File tree

4 files changed

+42
-83
lines changed

4 files changed

+42
-83
lines changed

‎tests/lora/test_lora_layers_cogview4.py‎

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,6 @@ def test_simple_inference_save_pretrained(self):
129129
pipe.set_progress_bar_config(disable=None)
130130
_, _, inputs = self.get_dummy_inputs(with_generator=False)
131131

132-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
133-
self.assertTrue(output_no_lora.shape == self.output_shape)
134-
135132
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
136133

137134
with tempfile.TemporaryDirectory() as tmpdirname:

‎tests/lora/test_lora_layers_flux.py‎

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,6 @@ def test_with_alpha_in_state_dict(self):
122122
pipe.set_progress_bar_config(disable=None)
123123
_, _, inputs = self.get_dummy_inputs(with_generator=False)
124124

125-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
126-
self.assertTrue(output_no_lora.shape == self.output_shape)
127-
128125
pipe.transformer.add_adapter(denoiser_lora_config)
129126
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
130127

@@ -170,8 +167,7 @@ def test_lora_expansion_works_for_absent_keys(self):
170167
pipe.set_progress_bar_config(disable=None)
171168
_, _, inputs = self.get_dummy_inputs(with_generator=False)
172169

173-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
174-
self.assertTrue(output_no_lora.shape == self.output_shape)
170+
output_no_lora = self.get_base_pipe_output()
175171

176172
# Modify the config to have a layer which won't be present in the second LoRA we will load.
177173
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -218,9 +214,7 @@ def test_lora_expansion_works_for_extra_keys(self):
218214
pipe = pipe.to(torch_device)
219215
pipe.set_progress_bar_config(disable=None)
220216
_, _, inputs = self.get_dummy_inputs(with_generator=False)
221-
222-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
223-
self.assertTrue(output_no_lora.shape == self.output_shape)
217+
output_no_lora = self.get_base_pipe_output()
224218

225219
# Modify the config to have a layer which won't be present in the first LoRA we will load.
226220
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -329,6 +323,7 @@ def get_dummy_inputs(self, with_generator=True):
329323
noise = floats_tensor((batch_size, num_channels) + sizes)
330324
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
331325

326+
np.random.seed(0)
332327
pipeline_inputs = {
333328
"prompt": "A painting of a squirrel eating a burger",
334329
"control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),

‎tests/lora/test_lora_layers_wanvace.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def test_lora_exclude_modules_wanvace(self):
169169
pipe = self.pipeline_class(**components).to(torch_device)
170170
_, _, inputs = self.get_dummy_inputs(with_generator=False)
171171

172-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
172+
output_no_lora = self.get_base_pipe_output()
173173
self.assertTrue(output_no_lora.shape == self.output_shape)
174174

175175
# only supported for `denoiser` now

‎tests/lora/utils.py‎

Lines changed: 38 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,20 @@ class PeftLoraLoaderMixinTests:
126126
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
127127
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
128128

129-
def get_dummy_components(self, use_dora=False, lora_alpha=None):
129+
cached_non_lora_output = None
130+
131+
def get_base_pipe_output(self):
132+
if self.cached_non_lora_output is None:
133+
self.cached_non_lora_output = self._compute_baseline_output()
134+
return self.cached_non_lora_output
135+
136+
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
130137
if self.unet_kwargs and self.transformer_kwargs:
131138
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
132139
if self.has_two_text_encoders and self.has_three_text_encoders:
133140
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
134141

135-
scheduler_cls = self.scheduler_cls
142+
scheduler_cls = scheduler_clsifscheduler_clsisnotNoneelseself.scheduler_cls
136143
rank = 4
137144
lora_alpha = rank if lora_alpha is None else lora_alpha
138145

@@ -238,15 +245,16 @@ def get_dummy_inputs(self, with_generator=True):
238245

239246
return noise, input_ids, pipeline_inputs
240247

241-
# Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
242-
defget_dummy_tokens(self):
243-
max_seq_length = 77
244-
245-
inputs=torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
248+
def_compute_baseline_output(self):
249+
components, _, _=self.get_dummy_components(self.scheduler_cls)
250+
pipe = self.pipeline_class(**components)
251+
pipe=pipe.to(torch_device)
252+
pipe.set_progress_bar_config(disable=None)
246253

247-
prepared_inputs = {}
248-
prepared_inputs["input_ids"] = inputs
249-
return prepared_inputs
254+
# Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
255+
# explicitly.
256+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
257+
return pipe(**inputs, generator=torch.manual_seed(0))[0]
250258

251259
def _get_lora_state_dicts(self, modules_to_save):
252260
state_dicts = {}
@@ -316,14 +324,8 @@ def test_simple_inference(self):
316324
"""
317325
Tests a simple inference and makes sure it works as expected
318326
"""
319-
components, text_lora_config, _ = self.get_dummy_components()
320-
pipe = self.pipeline_class(**components)
321-
pipe = pipe.to(torch_device)
322-
pipe.set_progress_bar_config(disable=None)
323-
324-
_, _, inputs = self.get_dummy_inputs()
325-
output_no_lora = pipe(**inputs)[0]
326-
self.assertTrue(output_no_lora.shape == self.output_shape)
327+
output_no_lora = self.get_base_pipe_output()
328+
assert output_no_lora.shape == self.output_shape
327329

328330
def test_simple_inference_with_text_lora(self):
329331
"""
@@ -336,9 +338,7 @@ def test_simple_inference_with_text_lora(self):
336338
pipe.set_progress_bar_config(disable=None)
337339
_, _, inputs = self.get_dummy_inputs(with_generator=False)
338340

339-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
340-
self.assertTrue(output_no_lora.shape == self.output_shape)
341-
341+
output_no_lora = self.get_base_pipe_output()
342342
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
343343

344344
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -414,9 +414,6 @@ def test_low_cpu_mem_usage_with_loading(self):
414414
pipe.set_progress_bar_config(disable=None)
415415
_, _, inputs = self.get_dummy_inputs(with_generator=False)
416416

417-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
418-
self.assertTrue(output_no_lora.shape == self.output_shape)
419-
420417
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
421418

422419
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -466,8 +463,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
466463
pipe.set_progress_bar_config(disable=None)
467464
_, _, inputs = self.get_dummy_inputs(with_generator=False)
468465

469-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
470-
self.assertTrue(output_no_lora.shape == self.output_shape)
466+
output_no_lora = self.get_base_pipe_output()
471467

472468
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
473469

@@ -503,8 +499,7 @@ def test_simple_inference_with_text_lora_fused(self):
503499
pipe.set_progress_bar_config(disable=None)
504500
_, _, inputs = self.get_dummy_inputs(with_generator=False)
505501

506-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
507-
self.assertTrue(output_no_lora.shape == self.output_shape)
502+
output_no_lora = self.get_base_pipe_output()
508503

509504
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
510505

@@ -534,8 +529,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
534529
pipe.set_progress_bar_config(disable=None)
535530
_, _, inputs = self.get_dummy_inputs(with_generator=False)
536531

537-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
538-
self.assertTrue(output_no_lora.shape == self.output_shape)
532+
output_no_lora = self.get_base_pipe_output()
539533

540534
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
541535

@@ -566,9 +560,6 @@ def test_simple_inference_with_text_lora_save_load(self):
566560
pipe.set_progress_bar_config(disable=None)
567561
_, _, inputs = self.get_dummy_inputs(with_generator=False)
568562

569-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
570-
self.assertTrue(output_no_lora.shape == self.output_shape)
571-
572563
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
573564

574565
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -616,8 +607,7 @@ def test_simple_inference_with_partial_text_lora(self):
616607
pipe.set_progress_bar_config(disable=None)
617608
_, _, inputs = self.get_dummy_inputs(with_generator=False)
618609

619-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
620-
self.assertTrue(output_no_lora.shape == self.output_shape)
610+
output_no_lora = self.get_base_pipe_output()
621611

622612
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
623613

@@ -666,9 +656,6 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
666656
pipe.set_progress_bar_config(disable=None)
667657
_, _, inputs = self.get_dummy_inputs(with_generator=False)
668658

669-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
670-
self.assertTrue(output_no_lora.shape == self.output_shape)
671-
672659
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
673660
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
674661

@@ -708,9 +695,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
708695
pipe.set_progress_bar_config(disable=None)
709696
_, _, inputs = self.get_dummy_inputs(with_generator=False)
710697

711-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
712-
self.assertTrue(output_no_lora.shape == self.output_shape)
713-
714698
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
715699

716700
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -747,9 +731,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
747731
pipe.set_progress_bar_config(disable=None)
748732
_, _, inputs = self.get_dummy_inputs(with_generator=False)
749733

750-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
751-
self.assertTrue(output_no_lora.shape == self.output_shape)
752-
734+
output_no_lora = self.get_base_pipe_output()
753735
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
754736

755737
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -790,8 +772,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
790772
pipe.set_progress_bar_config(disable=None)
791773
_, _, inputs = self.get_dummy_inputs(with_generator=False)
792774

793-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
794-
self.assertTrue(output_no_lora.shape == self.output_shape)
775+
output_no_lora = self.get_base_pipe_output()
795776

796777
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
797778

@@ -825,8 +806,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
825806
pipe.set_progress_bar_config(disable=None)
826807
_, _, inputs = self.get_dummy_inputs(with_generator=False)
827808

828-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
829-
self.assertTrue(output_no_lora.shape == self.output_shape)
809+
output_no_lora = self.get_base_pipe_output()
830810

831811
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
832812

@@ -900,7 +880,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
900880
pipe.set_progress_bar_config(disable=None)
901881
_, _, inputs = self.get_dummy_inputs(with_generator=False)
902882

903-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
883+
output_no_lora = self.get_base_pipe_output()
904884

905885
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
906886
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1024,7 +1004,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
10241004
pipe.set_progress_bar_config(disable=None)
10251005
_, _, inputs = self.get_dummy_inputs(with_generator=False)
10261006

1027-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
1007+
output_no_lora = self.get_base_pipe_output()
10281008

10291009
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
10301010
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1080,7 +1060,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
10801060
pipe.set_progress_bar_config(disable=None)
10811061
_, _, inputs = self.get_dummy_inputs(with_generator=False)
10821062

1083-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
1063+
output_no_lora = self.get_base_pipe_output()
10841064

10851065
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
10861066
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1240,7 +1220,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
12401220
pipe.set_progress_bar_config(disable=None)
12411221
_, _, inputs = self.get_dummy_inputs(with_generator=False)
12421222

1243-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
1223+
output_no_lora = self.get_base_pipe_output()
12441224

12451225
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
12461226
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1331,7 +1311,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
13311311
pipe.set_progress_bar_config(disable=None)
13321312
_, _, inputs = self.get_dummy_inputs(with_generator=False)
13331313

1334-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
1314+
output_no_lora = self.get_base_pipe_output()
13351315

13361316
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
13371317
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1551,7 +1531,6 @@ def test_get_list_adapters(self):
15511531

15521532
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
15531533

1554-
@require_peft_version_greater(peft_version="0.6.2")
15551534
def test_simple_inference_with_text_lora_denoiser_fused_multi(
15561535
self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
15571536
):
@@ -1565,9 +1544,6 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
15651544
pipe.set_progress_bar_config(disable=None)
15661545
_, _, inputs = self.get_dummy_inputs(with_generator=False)
15671546

1568-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
1569-
self.assertTrue(output_no_lora.shape == self.output_shape)
1570-
15711547
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
15721548
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
15731549
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1641,8 +1617,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
16411617
pipe.set_progress_bar_config(disable=None)
16421618
_, _, inputs = self.get_dummy_inputs(with_generator=False)
16431619

1644-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
1645-
self.assertTrue(output_no_lora.shape == self.output_shape)
1620+
output_no_lora = self.get_base_pipe_output()
16461621

16471622
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
16481623
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1685,7 +1660,6 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
16851660
"LoRA should change the output",
16861661
)
16871662

1688-
@require_peft_version_greater(peft_version="0.9.0")
16891663
def test_simple_inference_with_dora(self):
16901664
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
16911665
pipe = self.pipeline_class(**components)
@@ -1695,7 +1669,6 @@ def test_simple_inference_with_dora(self):
16951669

16961670
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
16971671
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
1698-
16991672
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
17001673

17011674
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1783,7 +1756,6 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
17831756
pipe = pipe.to(torch_device)
17841757
pipe.set_progress_bar_config(disable=None)
17851758
_, _, inputs = self.get_dummy_inputs(with_generator=False)
1786-
17871759
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
17881760

17891761
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
@@ -1820,7 +1792,7 @@ def test_logs_info_when_no_lora_keys_found(self):
18201792
pipe.set_progress_bar_config(disable=None)
18211793

18221794
_, _, inputs = self.get_dummy_inputs(with_generator=False)
1823-
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
1795+
output_no_lora = self.get_base_pipe_output()
18241796

18251797
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
18261798
logger = logging.get_logger("diffusers.loaders.peft")
@@ -1832,7 +1804,7 @@ def test_logs_info_when_no_lora_keys_found(self):
18321804

18331805
denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
18341806
self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
1835-
self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
1807+
self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
18361808

18371809
# test only for text encoder
18381810
for lora_module in self.pipeline_class._lora_loadable_modules:
@@ -1864,9 +1836,7 @@ def test_set_adapters_match_attention_kwargs(self):
18641836
pipe.set_progress_bar_config(disable=None)
18651837
_, _, inputs = self.get_dummy_inputs(with_generator=False)
18661838

1867-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
1868-
self.assertTrue(output_no_lora.shape == self.output_shape)
1869-
1839+
output_no_lora = self.get_base_pipe_output()
18701840
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
18711841

18721842
lora_scale = 0.5
@@ -2212,9 +2182,6 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
22122182
pipe = self.pipeline_class(**components).to(torch_device)
22132183
_, _, inputs = self.get_dummy_inputs(with_generator=False)
22142184

2215-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
2216-
self.assertTrue(output_no_lora.shape == self.output_shape)
2217-
22182185
pipe, _ = self.add_adapters_to_pipeline(
22192186
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
22202187
)
@@ -2260,7 +2227,7 @@ def test_inference_load_delete_load_adapters(self):
22602227
pipe.set_progress_bar_config(disable=None)
22612228
_, _, inputs = self.get_dummy_inputs(with_generator=False)
22622229

2263-
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
2230+
output_no_lora = self.get_base_pipe_output()
22642231

22652232
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
22662233
pipe.text_encoder.add_adapter(text_lora_config)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /