Commit 6549b04

[docs] AutoPipeline (#12160)
* refresh
* feedback
* feedback
* supported models
* fix
1 parent 130fd8d commit 6549b04

1 file changed: +29 −85 lines
docs/source/en/tutorials/autopipeline.md

````diff
@@ -12,112 +12,56 @@ specific language governing permissions and limitations under the License.
 
 # AutoPipeline
 
-Diffusers provides many pipelines for basic tasks like generating images, videos, audio, and inpainting. On top of these, there are specialized pipelines for adapters and features like upscaling, super-resolution, and more. Different pipeline classes can even use the same checkpoint because they share the same pretrained model! With so many different pipelines, it can be overwhelming to know which pipeline class to use.
+[AutoPipeline](../api/models/auto_model) is a *task-and-model* pipeline that automatically selects the correct pipeline subclass based on the task. It handles the complexity of loading different pipeline subclasses without needing to know the specific pipeline subclass name.
 
-The [AutoPipeline](../api/pipelines/auto_pipeline) class is designed to simplify the variety of pipelines in Diffusers. It is a generic *task-first* pipeline that lets you focus on a task ([`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]) without needing to know the specific pipeline class. The [AutoPipeline](../api/pipelines/auto_pipeline) automatically detects the correct pipeline class to use.
+This is unlike [`DiffusionPipeline`], a *model-only* pipeline that automatically selects the pipeline subclass based on the model.
 
-For example, let's use the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint.
-
-Under the hood, [AutoPipeline](../api/pipelines/auto_pipeline):
-
-1. Detects a `"stable-diffusion"` class from the [model_index.json](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0/blob/main/model_index.json) file.
-2. Depending on the task you're interested in, it loads the [`StableDiffusionPipeline`], [`StableDiffusionImg2ImgPipeline`], or [`StableDiffusionInpaintPipeline`]. Any parameter (`strength`, `num_inference_steps`, etc.) you would pass to these specific pipelines can also be passed to the [AutoPipeline](../api/pipelines/auto_pipeline).
-
-<hfoptions id="autopipeline">
-<hfoption id="text-to-image">
+[`AutoPipelineForImage2Image`] returns a specific pipeline subclass (for example, [`StableDiffusionXLImg2ImgPipeline`]), which can only be used for image-to-image tasks.
 
 ```py
-from diffusers import AutoPipelineForText2Image
 import torch
-
-pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
-    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-prompt = "cinematic photo of Godzilla eating sushi with a cat in a izakaya, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(37)
-image = pipe_txt2img(prompt, generator=generator).images[0]
-image
-```
-
-<div class="flex justify-center">
-<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png"/>
-</div>
-
-</hfoption>
-<hfoption id="image-to-image">
-
-```py
 from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import load_image
-import torch
-
-pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
-    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png")
-
-prompt = "cinematic photo of Godzilla eating burgers with a cat in a fast food restaurant, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(53)
-image = pipe_img2img(prompt, image=init_image, generator=generator).images[0]
-image
-```
-
-Notice how the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint is used for both text-to-image and image-to-image tasks? To save memory and avoid loading the checkpoint twice, use the [`~DiffusionPipeline.from_pipe`] method.
 
-```py
-pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_txt2img).to("cuda")
-image = pipeline(prompt, image=init_image, generator=generator).images[0]
-image
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
+)
+print(pipeline)
+"StableDiffusionXLImg2ImgPipeline {
+  "_class_name": "StableDiffusionXLImg2ImgPipeline",
+  ...
+"
 ```
 
-You can learn more about the [`~DiffusionPipeline.from_pipe`] method in the [Reuse a pipeline](../using-diffusers/loading#reuse-a-pipeline) guide.
-
-<div class="flex justify-center">
-<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png"/>
-</div>
-
-</hfoption>
-<hfoption id="inpainting">
+Loading the same model with [`DiffusionPipeline`] returns the [`StableDiffusionXLPipeline`] subclass. It can be used for text-to-image, image-to-image, or inpainting tasks depending on the inputs.
 
 ```py
-from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
 import torch
+from diffusers import DiffusionPipeline
 
-pipeline = AutoPipelineForInpainting.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-mask.png")
-
-prompt = "cinematic photo of a owl, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(38)
-image = pipeline(prompt, image=init_image, mask_image=mask_image, generator=generator, strength=0.4).images[0]
-image
+pipeline = DiffusionPipeline.from_pretrained(
+    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
+)
+print(pipeline)
+"StableDiffusionXLPipeline {
+  "_class_name": "StableDiffusionXLPipeline",
+  ...
+"
 ```
 
-<div class="flex justify-center">
-<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png"/>
-</div>
+Check the [mappings](https://github.com/huggingface/diffusers/blob/130fd8df54f24ffb006d84787b598d8adc899f23/src/diffusers/pipelines/auto_pipeline.py#L114) to see whether a model is supported.
 
-</hfoption>
-</hfoptions>
-
-## Unsupported checkpoints
-
-The [AutoPipeline](../api/pipelines/auto_pipeline) supports [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [Stable Diffusion XL](../api/pipelines/stable_diffusion/stable_diffusion_xl), [ControlNet](../api/pipelines/controlnet), [Kandinsky 2.1](../api/pipelines/kandinsky.md), [Kandinsky 2.2](../api/pipelines/kandinsky_v22), and [DeepFloyd IF](../api/pipelines/deepfloyd_if) checkpoints.
-
-If you try to load an unsupported checkpoint, you'll get an error.
+Trying to load an unsupported model returns an error.
 
 ```py
-from diffusers import AutoPipelineForImage2Image
 import torch
+from diffusers import AutoPipelineForImage2Image
 
 pipeline = AutoPipelineForImage2Image.from_pretrained(
-    "openai/shap-e-img2img", torch_dtype=torch.float16, use_safetensors=True
+    "openai/shap-e-img2img", torch_dtype=torch.float16,
 )
 "ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
 ```
+
+There are three types of [AutoPipeline](../api/models/auto_model) classes: [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]. Each of these classes has a predefined mapping linking a pipeline to its task-specific subclass.
+
+When [`~AutoPipelineForText2Image.from_pretrained`] is called, it extracts the class name from the `model_index.json` file and selects the appropriate pipeline subclass for the task based on the mapping.
````
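The two-step selection described in those closing paragraphs can be observed directly. Below is a minimal sketch of it; the public `DiffusionPipeline.load_config` helper is used here as a stand-in for the internal read of `model_index.json` (the actual lookup goes through the pipeline mappings linked in the doc), and the checkpoint is the one used throughout the tutorial.

```py
import torch
from diffusers import AutoPipelineForImage2Image, DiffusionPipeline

# Step 1: read model_index.json; its _class_name field identifies the model's
# pipeline class.
config = DiffusionPipeline.load_config("RunDiffusion/Juggernaut-XL-v9")
print(config["_class_name"])  # StableDiffusionXLPipeline

# Step 2: the task class maps that class to its task-specific subclass.
pipeline = AutoPipelineForImage2Image.from_pretrained(
    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16,
)
print(pipeline.__class__.__name__)  # StableDiffusionXLImg2ImgPipeline
```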

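Because the three task classes resolve to subclasses of the same underlying model, a pipeline loaded for one task can be converted to another without loading the checkpoint a second time. A minimal sketch using the [`~DiffusionPipeline.from_pipe`] method, which reuses the already-loaded components:

```py
import torch
from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image

# Load the checkpoint once for text-to-image...
pipe_t2i = AutoPipelineForText2Image.from_pretrained(
    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16,
)
print(pipe_t2i.__class__.__name__)  # StableDiffusionXLPipeline

# ...then reuse its components for image-to-image, without re-downloading weights.
pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i)
print(pipe_i2i.__class__.__name__)  # StableDiffusionXLImg2ImgPipeline
```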