
Commit baa9b58

sayakpaul and DN6 authored

[core] parallel loading of shards (#12028)

* checking.
* checking
* checking
* up
* up
* up
* Apply suggestions from code review

  Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up
* up
* fix
* review feedback.

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

1 parent da096a4 · commit baa9b58

File tree

10 files changed: +251 −66 lines changed

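What the commit does: when a checkpoint is sharded across several weight files, diffusers can now read the shards on a small thread pool instead of strictly one after another. Only six of the ten changed files appear below; the hunks that wire the new helpers into modeling_utils and define DEFAULT_HF_PARALLEL_LOADING_WORKERS are not shown. A hedged usage sketch follows; the HF_ENABLE_PARALLEL_LOADING environment variable is an assumption borrowed from the analogous transformers feature and is not confirmed by the hunks in this excerpt.

# Hedged usage sketch. The env-var gate below is assumed (it mirrors the
# analogous transformers feature) and is not visible in the hunks shown here.
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"  # assumed opt-in, set before import

import torch
from diffusers import DiffusionPipeline

# A sharded checkpoint is the target case: each shard file listed in the
# safetensors index can now be loaded by a separate worker thread.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)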

src/diffusers/loaders/single_file_model.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@
 if is_accelerate_available():
     from accelerate import dispatch_model, init_empty_weights

-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta

 if is_torch_version(">=", "1.9.0") and is_accelerate_available():
     _LOW_CPU_MEM_USAGE_DEFAULT = True

src/diffusers/loaders/single_file_utils.py

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@
 if is_accelerate_available():
     from accelerate import init_empty_weights

-    from ..models.modeling_utils import load_model_dict_into_meta
+    from ..models.model_loading_utils import load_model_dict_into_meta

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

src/diffusers/loaders/transformer_flux.py

Lines changed: 2 additions & 1 deletion

@@ -17,7 +17,8 @@
     ImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache

src/diffusers/loaders/transformer_sd3.py

Lines changed: 2 additions & 1 deletion

@@ -16,7 +16,8 @@

 from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
 from ..models.embeddings import IPAdapterTimeImageProjection
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
 from ..utils import is_accelerate_available, is_torch_version, logging
 from ..utils.torch_utils import empty_device_cache

src/diffusers/loaders/unet.py

Lines changed: 2 additions & 1 deletion

@@ -30,7 +30,8 @@
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.model_loading_utils import load_model_dict_into_meta
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
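The five loader diffs above are one mechanical change: load_model_dict_into_meta moved from modeling_utils to model_loading_utils, so every call site repoints that import while helpers such as _LOW_CPU_MEM_USAGE_DEFAULT and load_state_dict stay where they were. Plausibly this lets modeling_utils import the new shard helpers below without a circular import, though the commit message does not say so. The resulting locations, written as absolute imports:

# Import locations after this commit (module paths confirmed by the hunks above)
from diffusers.models.model_loading_utils import load_model_dict_into_meta
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict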

src/diffusers/models/model_loading_utils.py

Lines changed: 158 additions & 0 deletions

@@ -14,12 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
 import importlib
 import inspect
 import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -31,6 +33,7 @@

 from ..quantizers import DiffusersQuantizer
 from ..utils import (
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -310,6 +313,161 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index


+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
+    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the
+    model's parameters.
+
+    """
+    if model_to_load.device.type == "meta":
+        return False
+
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = next(iter(model_to_load.state_dict().keys()))
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    return False
+
+
+def _load_shard_file(
+    shard_file,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+    mismatched_keys = _find_mismatched_keys(
+        state_dict,
+        model_state_dict,
+        loaded_keys,
+        ignore_mismatched_sizes,
+    )
+    error_msgs = []
+    if low_cpu_mem_usage:
+        offload_index, state_dict_index = load_model_dict_into_meta(
+            model,
+            state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            unexpected_keys=unexpected_keys,
+            offload_folder=offload_folder,
+            offload_index=offload_index,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+        )
+    else:
+        assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _load_shard_files_with_threadpool(
+    shard_files,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    # Do not spawn any more workers than you need
+    num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+    mismatched_keys = []
+
+    load_one = functools.partial(
+        _load_shard_file,
+        model=model,
+        model_state_dict=model_state_dict,
+        device_map=device_map,
+        dtype=dtype,
+        hf_quantizer=hf_quantizer,
+        keep_in_fp32_modules=keep_in_fp32_modules,
+        dduf_entries=dduf_entries,
+        loaded_keys=loaded_keys,
+        unexpected_keys=unexpected_keys,
+        offload_index=offload_index,
+        offload_folder=offload_folder,
+        state_dict_index=state_dict_index,
+        state_dict_folder=state_dict_folder,
+        ignore_mismatched_sizes=ignore_mismatched_sizes,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+    )
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+            for future in as_completed(futures):
+                result = future.result()
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
+                pbar.update(1)
+
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
+def _find_mismatched_keys(
+    state_dict,
+    model_state_dict,
+    loaded_keys,
+    ignore_mismatched_sizes,
+):
+    mismatched_keys = []
+    if ignore_mismatched_sizes:
+        for checkpoint_key in loaded_keys:
+            model_key = checkpoint_key
+            # If the checkpoint is sharded, we may not have the key here.
+            if checkpoint_key not in state_dict:
+                continue
+
+            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+                mismatched_keys.append(
+                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                )
+                del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
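The heart of the change is _load_shard_files_with_threadpool: shard files fan out to at most DEFAULT_HF_PARALLEL_LOADING_WORKERS threads, and per-shard results fold back in as futures complete. Threads (rather than processes) are a reasonable fit here because shard reading is I/O-bound and tensor copies largely release the GIL. A self-contained toy of the same fan-out/fold-in pattern, with a fake loader standing in for load_state_dict (the names below are illustrative, not diffusers API):

# Toy version of the concurrency pattern used by _load_shard_files_with_threadpool.
# Only the structure matches the diff above; the loader and file names are fake.
import functools
from concurrent.futures import ThreadPoolExecutor, as_completed

MAX_WORKERS = 8  # stand-in for DEFAULT_HF_PARALLEL_LOADING_WORKERS

def load_one_shard(shard_file, prefix=""):
    # Real code would read tensors from disk and copy them into the model;
    # here each "shard" just reports what it loaded and any error messages.
    return [], [f"{prefix}{shard_file}"]

def load_all_shards(shard_files):
    error_msgs, loaded = [], []
    load_one = functools.partial(load_one_shard, prefix="ok:")
    # Do not spawn more workers than there are shards.
    num_workers = min(len(shard_files), MAX_WORKERS)
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(load_one, f) for f in shard_files]
        for future in as_completed(futures):  # results arrive in completion order
            _errors, _loaded = future.result()
            error_msgs += _errors
            loaded += _loaded
    return error_msgs, loaded

print(load_all_shards([f"model-{i:05d}-of-00003.safetensors" for i in range(1, 4)]))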
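_find_mismatched_keys is the guard that keeps ignore_mismatched_sizes safe under this scheme: a shape-mismatched key is recorded and deleted from the shard's state dict before any copy happens, so neither the meta-device path nor _load_state_dict_into_model ever sees it. A minimal restatement of that contract with dummy tensors (not the helper itself):

# Minimal restatement of the _find_mismatched_keys contract; the real helper
# also filters by loaded_keys and the ignore_mismatched_sizes flag.
import torch

model_state_dict = {"proj.weight": torch.zeros(4, 4)}
state_dict = {"proj.weight": torch.zeros(2, 4)}  # deliberately wrong shape

mismatched_keys = []
for checkpoint_key in list(state_dict):
    if (
        checkpoint_key in model_state_dict
        and state_dict[checkpoint_key].shape != model_state_dict[checkpoint_key].shape
    ):
        mismatched_keys.append(
            (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[checkpoint_key].shape)
        )
        del state_dict[checkpoint_key]  # the copy step never sees this key

print(mismatched_keys)  # [('proj.weight', torch.Size([2, 4]), torch.Size([4, 4]))]
print(state_dict)       # {} -- the mismatched entry was dropped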
