Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 9c3b58d

Browse files
authored
Handle deprecated transformer classes (#12517)
* update * update * update
1 parent 74b5fed commit 9c3b58d

File tree

3 files changed

+64
-1
lines changed

3 files changed

+64
-1
lines changed

‎src/diffusers/pipelines/pipeline_loading_utils.py‎

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ONNX_WEIGHTS_NAME,
3434
SAFETENSORS_WEIGHTS_NAME,
3535
WEIGHTS_NAME,
36+
_maybe_remap_transformers_class,
3637
deprecate,
3738
get_class_from_dynamic_module,
3839
is_accelerate_available,
@@ -356,6 +357,11 @@ def maybe_raise_or_warn(
356357
"""Simple helper method to raise or warn in case incorrect module has been passed"""
357358
if not is_pipeline_module:
358359
library = importlib.import_module(library_name)
360+
361+
# Handle deprecated Transformers classes
362+
if library_name == "transformers":
363+
class_name = _maybe_remap_transformers_class(class_name) or class_name
364+
359365
class_obj = getattr(library, class_name)
360366
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
361367

@@ -390,6 +396,11 @@ def simple_get_class_obj(library_name, class_name):
390396
class_obj = getattr(pipeline_module, class_name)
391397
else:
392398
library = importlib.import_module(library_name)
399+
400+
# Handle deprecated Transformers classes
401+
if library_name == "transformers":
402+
class_name = _maybe_remap_transformers_class(class_name) or class_name
403+
393404
class_obj = getattr(library, class_name)
394405

395406
return class_obj
@@ -416,6 +427,10 @@ def get_class_obj_and_candidates(
416427
# else we just import it from the library.
417428
library = importlib.import_module(library_name)
418429

430+
# Handle deprecated Transformers classes
431+
if library_name == "transformers":
432+
class_name = _maybe_remap_transformers_class(class_name) or class_name
433+
419434
class_obj = getattr(library, class_name)
420435
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
421436

‎src/diffusers/utils/__init__.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
WEIGHTS_INDEX_NAME,
3939
WEIGHTS_NAME,
4040
)
41-
from .deprecation_utils import deprecate
41+
from .deprecation_utils import _maybe_remap_transformers_class, deprecate
4242
from .doc_utils import replace_example_docstring
4343
from .dynamic_modules_utils import get_class_from_dynamic_module
4444
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video

‎src/diffusers/utils/deprecation_utils.py‎

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,54 @@
44

55
from packaging import version
66

7+
from ..utils import logging
8+
9+
10+
logger = logging.get_logger(__name__)
11+
12+
# Mapping for deprecated Transformers classes to their replacements
13+
# This is used to handle models that reference deprecated class names in their configs
14+
# Reference: https://github.com/huggingface/transformers/issues/40822
15+
# Format: {
16+
# "DeprecatedClassName": {
17+
# "new_class": "NewClassName",
18+
# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple
19+
# }
20+
# }
21+
_TRANSFORMERS_CLASS_REMAPPING = {
22+
"CLIPFeatureExtractor": {
23+
"new_class": "CLIPImageProcessor",
24+
"transformers_version": (">", "4.57.0"),
25+
},
26+
}
27+
28+
29+
def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
30+
"""
31+
Check if a Transformers class should be remapped to a newer version.
32+
33+
Args:
34+
class_name: The name of the class to check
35+
36+
Returns:
37+
The new class name if remapping should occur, None otherwise
38+
"""
39+
if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
40+
return None
41+
42+
from .import_utils import is_transformers_version
43+
44+
mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
45+
operation, required_version = mapping["transformers_version"]
46+
47+
# Only remap if the transformers version meets the requirement
48+
if is_transformers_version(operation, required_version):
49+
new_class = mapping["new_class"]
50+
logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
51+
return mapping["new_class"]
52+
53+
return None
54+
755

856
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
957
from .. import __version__

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /