@@ -126,13 +126,20 @@ class PeftLoraLoaderMixinTests:
    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
    denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

-    def get_dummy_components(self, use_dora=False, lora_alpha=None):
+    cached_non_lora_output = None
+
+    def get_base_pipe_output(self):
+        if self.cached_non_lora_output is None:
+            self.cached_non_lora_output = self._compute_baseline_output()
+        return self.cached_non_lora_output
+
+    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
        if self.unet_kwargs and self.transformer_kwargs:
            raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
        if self.has_two_text_encoders and self.has_three_text_encoders:
            raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

-        scheduler_cls = self.scheduler_cls
+        scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls
        rank = 4
        lora_alpha = rank if lora_alpha is None else lora_alpha
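# Note (illustrative sketch, not part of the diff): `get_base_pipe_output` above is
# plain memoization. One subtlety: `cached_non_lora_output` is declared as a class
# attribute, but the assignment through `self` creates an instance attribute that
# shadows it, so each test instance keeps its own cache. A minimal self-contained
# sketch of the same pattern, with hypothetical names:
class BaselineCache:
    cached = None  # class-level default, shadowed per instance on first use

    def _compute(self):
        return sum(range(1_000_000))  # stand-in for an expensive pipeline call

    def get(self):
        if self.cached is None:
            self.cached = self._compute()  # creates an instance attribute
        return self.cached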
@@ -238,15 +245,16 @@ def get_dummy_inputs(self, with_generator=True):
        return noise, input_ids, pipeline_inputs

-    # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
-    def get_dummy_tokens(self):
-        max_seq_length = 77
-
-        inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
+    def _compute_baseline_output(self):
+        components, _, _ = self.get_dummy_components(self.scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)

-        prepared_inputs = {}
-        prepared_inputs["input_ids"] = inputs
-        return prepared_inputs
+        # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
+        # explicitly.
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+        return pipe(**inputs, generator=torch.manual_seed(0))[0]

    def _get_lora_state_dicts(self, modules_to_save):
        state_dicts = {}
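# Note (illustrative sketch, not part of the diff): `_compute_baseline_output` requests
# inputs without a `generator` and then seeds one explicitly, so every comparison in the
# suite runs against an output produced from an identical RNG state. A minimal sketch of
# why seeding the generator per call makes outputs reproducible:
import torch

def sample(generator=None):
    return torch.randn(2, 2, generator=generator)

a = sample(generator=torch.manual_seed(0))
b = sample(generator=torch.manual_seed(0))
assert torch.allclose(a, b)  # same seed, same output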
@@ -316,14 +324,8 @@ def test_simple_inference(self):
        """
        Tests a simple inference and makes sure it works as expected
        """
-        components, text_lora_config, _ = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        _, _, inputs = self.get_dummy_inputs()
-        output_no_lora = pipe(**inputs)[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()
+        assert output_no_lora.shape == self.output_shape

    def test_simple_inference_with_text_lora(self):
        """
@@ -336,9 +338,7 @@ def test_simple_inference_with_text_lora(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
+        output_no_lora = self.get_base_pipe_output()
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -414,9 +414,6 @@ def test_low_cpu_mem_usage_with_loading(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -466,8 +463,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -503,8 +499,7 @@ def test_simple_inference_with_text_lora_fused(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -534,8 +529,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -566,9 +560,6 @@ def test_simple_inference_with_text_lora_save_load(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -616,8 +607,7 @@ def test_simple_inference_with_partial_text_lora(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()

        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -666,9 +656,6 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -708,9 +695,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -747,9 +731,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
+        output_no_lora = self.get_base_pipe_output()
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -790,8 +772,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()

        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -825,8 +806,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()

        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -900,7 +880,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1024,7 +1004,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1080,7 +1060,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1240,7 +1220,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1331,7 +1311,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1551,7 +1531,6 @@ def test_get_list_adapters(self):
        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

-    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
    ):
@@ -1565,9 +1544,6 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1641,8 +1617,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
+        output_no_lora = self.get_base_pipe_output()

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1685,7 +1660,6 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
            "LoRA should change the output",
        )

-    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
        pipe = self.pipeline_class(**components)
@@ -1695,7 +1669,6 @@ def test_simple_inference_with_dora(self):
        output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_dora_lora.shape == self.output_shape)
-
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1783,7 +1756,6 @@ def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
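# Note (illustrative sketch, not part of the diff): the test above compiles the denoiser
# before running LoRA inference. `torch.compile` (PyTorch 2.0+) wraps a module so its
# forward pass is traced and optimized on first call; a minimal sketch on a toy module:
import torch

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)
# The first call triggers compilation; subsequent calls reuse the compiled graph.
out = compiled(torch.randn(1, 4))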
@@ -1820,7 +1792,7 @@ def test_logs_info_when_no_lora_keys_found(self):
        pipe.set_progress_bar_config(disable=None)

        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()

        no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
        logger = logging.get_logger("diffusers.loaders.peft")
@@ -1832,7 +1804,7 @@ def test_logs_info_when_no_lora_keys_found(self):
        denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
        self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
-        self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
+        self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5))

        # test only for text encoder
        for lora_module in self.pipeline_class._lora_loadable_modules:
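# Note (illustrative sketch, not part of the diff): the hunk above checks that loading a
# state dict with no recognizable LoRA keys logs a message while leaving the output
# unchanged. A generic stdlib sketch of the log-capture pattern such a test relies on:
import io
import logging

buffer = io.StringIO()
logger = logging.getLogger("example")
logger.addHandler(logging.StreamHandler(buffer))
logger.warning("No LoRA keys associated to ExampleModel")
assert buffer.getvalue().startswith("No LoRA keys")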
@@ -1864,9 +1836,7 @@ def test_set_adapters_match_attention_kwargs(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
+        output_no_lora = self.get_base_pipe_output()
        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        lora_scale = 0.5
@@ -2212,9 +2182,6 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
        pipe = self.pipeline_class(**components).to(torch_device)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
        pipe, _ = self.add_adapters_to_pipeline(
            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
        )
@@ -2260,7 +2227,7 @@ def test_inference_load_delete_load_adapters(self):
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.get_base_pipe_output()

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config)