Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

[core] support flash attention through kernels #12387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
sayakpaul wants to merge 4 commits into main
base: main
Choose a base branch
Loading
from fa-hub
Open

[core] support flash attention through kernels #12387

sayakpaul wants to merge 4 commits into main from fa-hub

Conversation

Copy link
Member

@sayakpaul sayakpaul commented Sep 25, 2025

What does this PR do?

Follow-up of #12236.

Testing code:

import torch
from diffusers import FluxPipeline
model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
 model_id, torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer.set_attention_backend("flash_hub")
pipe.transformer.compile(fullgraph=True)
prompt = "A cat holding a sign that says 'hello world'"
with torch._dynamo.config.patch(error_on_recompile=True):
 image = pipe(
 prompt, num_inference_steps=28, guidance_scale=4.0, generator=torch.manual_seed(0)
 ).images[0]
 image.save("output.png")

Tip

Works with torch.compile fullgraph compatibility.

I have tested the code on H100 and A100, and it works.

Fotosmile reacted with rocket emoji
@sayakpaul sayakpaul added the performance Anything related to performance improvements, profiling and benchmarking label Sep 25, 2025
# `flash-attn`
FLASH = "flash"
FLASH_VARLEN = "flash_varlen"
FLASH_HUB = "flash_hub"
Copy link
Member Author

@sayakpaul sayakpaul Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flash Attention is stable. So, we don't have to mark it private like FA3.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link

@MekkCyber MekkCyber left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very cool integration 🔥 ! I just left some nits

Comment on lines +85 to +88
fa3_interface_hub = _get_fa3_from_hub()
flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
fa_interface_hub = _get_fa_from_hub()
flash_attn_func_hub = fa_interface_hub.flash_attn_func
Copy link

@MekkCyber MekkCyber Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we fetching both kernels here ?

Copy link
Member Author

@sayakpaul sayakpaul Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of the way APIs for attention backends are designed and also to support torch.compile with fullgraph traceability (when possible).

We will let it grow a bit and upon feedback, we can revisit how to better deal with this.

MekkCyber reacted with thumbs up emoji
FLASH = "flash"
FLASH_VARLEN = "flash_varlen"
FLASH_HUB = "flash_hub"
# FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet.
Copy link

@MekkCyber MekkCyber Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this related to the kernel or it just needs more time to be integrated ?

Copy link
Member Author

@sayakpaul sayakpaul Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have models that use varlen.

MekkCyber reacted with thumbs up emoji
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sayakpaul qwen image uses varlen. also, native fused qkv+mlp attn requires varlen function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Reviewers

@DN6 DN6 Awaiting requested review from DN6

2 more reviewers

@bghira bghira bghira left review comments

@MekkCyber MekkCyber MekkCyber left review comments

Reviewers whose approvals may not affect merge requirements

At least 1 approving review is required to merge this pull request.

Assignees
No one assigned
Labels
performance Anything related to performance improvements, profiling and benchmarking
Projects
None yet
Milestone
No milestone
Development

Successfully merging this pull request may close these issues.

AltStyle によって変換されたページ (->オリジナル) /