Commit d7a1a03

[docs] CP (#12331)

* init * feedback * feedback * feedback * feedback * feedback * feedback

1 parent b596545 commit d7a1a03

3 files changed: +63 -7 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 2 deletions

@@ -70,8 +70,6 @@
     title: Reduce memory usage
   - local: optimization/speed-memory-optims
     title: Compiling and offloading quantized models
-  - local: api/parallel
-    title: Parallel inference
   - title: Community optimizations
     sections:
     - local: optimization/pruna
@@ -282,6 +280,8 @@
     title: Outputs
   - local: api/quantization
     title: Quantization
+  - local: api/parallel
+    title: Parallel inference
   - title: Modular
     sections:
     - local: api/modular_diffusers/pipeline

docs/source/en/api/parallel.md

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. -->
 
 # Parallelism
 
-Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times.
+Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times. Refer to the [Distributed inference](../training/distributed_inference) guide to learn more.
 
 ## ParallelConfig
 
docs/source/en/training/distributed_inference.md

Lines changed: 60 additions & 4 deletions

@@ -226,8 +226,64 @@ with torch.no_grad():
     image[0].save("split_transformer.png")
 ```
 
-## Resources
+By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
 
-- Take a look at this [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for a minimal example of distributed inference with Accelerate.
-- For more details, check out Accelerate's [Distributed inference](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
-- The `device_map` argument assign models or an entire pipeline to devices. Refer to the [device placement](../using-diffusers/loading#device-placement) docs for more information.
+## Context parallelism
+
+[Context parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism) splits input sequences across multiple GPUs to reduce memory usage. Each GPU processes its own slice of the sequence.
+
+Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
+
+### Ring Attention
+
+Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.
+
+Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transformer model. The config supports the `ring_degree` argument that determines how many devices to use for Ring Attention.
+
+```py
+import torch
+from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig
+
+try:
+    torch.distributed.init_process_group("nccl")
+    rank = torch.distributed.get_rank()
+    device = torch.device("cuda", rank % torch.cuda.device_count())
+    torch.cuda.set_device(device)
+
+    transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
+    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
+    pipeline.transformer.set_attention_backend("flash")
+
+    prompt = """
+    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
+    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
+    """
+
+    # Must specify generator so all ranks start with same latents (or pass your own)
+    generator = torch.Generator().manual_seed(42)
+    image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
+
+    if rank == 0:
+        image.save("output.png")
+
+except Exception as e:
+    print(f"An error occurred: {e}")
+    torch.distributed.breakpoint()
+    raise
+
+finally:
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
+```
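The Ring Attention script above initializes a NCCL process group and reads its rank from the launcher, so it is meant to be run with one process per GPU, for example `torchrun --nproc_per_node=2` so the world size lines up with `ring_degree=2`. A minimal sanity check along those lines, which could sit right after `init_process_group` (the filename `cp_ring.py` below is a hypothetical example):

```py
# Hypothetical launch: torchrun --nproc_per_node=2 cp_ring.py
# Assumption: the number of launched processes should match the context-parallel
# degree requested in ContextParallelConfig(ring_degree=2).
import torch.distributed as dist

if dist.is_initialized():
    world_size = dist.get_world_size()
    assert world_size == 2, f"expected 2 processes for ring_degree=2, got {world_size}"
```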
+
+### Ulysses Attention
+
+[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its heads, then performs another all-to-all to regroup results by tokens for the next layer.
+
+[`ContextParallelConfig`] supports Ulysses Attention through the `ulysses_degree` argument. This determines how many devices to use for Ulysses Attention.
+
+Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
+
+```py
+pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
+```
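The one-liner above assumes `pipeline` already exists, constructed as in the Ring Attention example. A fuller sketch of the Ulysses path, assuming the same Qwen/Qwen-Image setup, two GPUs, and a `torchrun --nproc_per_node=2` launch (the filename `cp_ulysses.py` is hypothetical); the exact point at which `enable_parallelism` is called relative to device placement may need adjusting:

```py
# Hypothetical launch: torchrun --nproc_per_node=2 cp_ulysses.py
import torch
from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig

torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

try:
    transformer = AutoModel.from_pretrained(
        "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipeline = QwenImagePipeline.from_pretrained(
        "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda"
    )

    # Ulysses-style context parallelism; ulysses_degree is assumed to match the
    # number of launched processes (2 here)
    pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
    pipeline.transformer.set_attention_backend("flash")

    # Same seed on every rank so all ranks start from identical latents
    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        "cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California",
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if rank == 0:
        image.save("ulysses_output.png")
finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```

As in the Ring Attention example, only rank 0 writes the image so the processes do not race on the same output file.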
