use kernels to support _flash_hub attention backend #12318


Closed
ParagEkbote wants to merge 9 commits into huggingface:main from ParagEkbote:Add-FA2

Conversation

ParagEkbote (Contributor) commented on Sep 11, 2025 (edited):

What does this PR do?

As discussed in the issue, this PR adds support for the kernels-community/flash-attn kernel as the _flash_hub attention backend (a minimal usage sketch follows below). Could you please review?

Fixes #12308
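A minimal usage sketch of what the backend enables, distilled from the test script later in this thread; illustrative only, not necessarily the final API:

```python
import os

# The hub-kernel path is gated behind this env var; set it before importing diffusers.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route attention through the kernels-community/flash-attn kernel pulled from the Hub.
pipe.transformer.set_attention_backend("_flash_hub")
```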


Who can review?

@sayakpaul

sayakpaul (Member) commented:
Thanks for this PR. Could you update it with some code examples and results?


ParagEkbote (Author) commented on Sep 11, 2025 (edited):

Here is the test script, but I am unable to generate images.

```python
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"
# Debug: verify the env var is set
print(f"DIFFUSERS_ENABLE_HUB_KERNELS = {os.environ.get('DIFFUSERS_ENABLE_HUB_KERNELS')}")

import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Debug: check whether diffusers sees the env var
from diffusers.models.attention_dispatch import DIFFUSERS_ENABLE_HUB_KERNELS
print(f"Diffusers sees DIFFUSERS_ENABLE_HUB_KERNELS = {DIFFUSERS_ENABLE_HUB_KERNELS}")

# Load the pipeline with the transformer quantized to 4-bit (NF4)
model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer"],
    ),
).to("cuda")

# Route attention through the kernels-community/flash-attn hub kernel
pipe.transformer.set_attention_backend("_flash_hub")

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")
```

ParagEkbote (Author) commented:

I'm running into issues with some of the parameters; here is the traceback:

```
Traceback (most recent call last):
  File "/teamspace/studios/this_studio/diffusers/main.py", line 34, in <module>
    image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 944, in __call__
    noise_pred = self.transformer(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 720, in forward
    encoder_hidden_states, hidden_states = block(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 443, in forward
    attention_outputs = self.attn(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 342, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 116, in __call__
    hidden_states = dispatch_attention_fn(
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/attention_dispatch.py", line 304, in dispatch_attention_fn
    return backend_fn(**kwargs)
  File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/attention_dispatch.py", line 765, in _flash_attention_hub
    out = flash_attn_func_hub(
TypeError: flash_attn_func() got an unexpected keyword argument 'alibi_slopes'
```

The same error occurs with the dropout_p parameter as well. WDYT?
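For reference, a minimal workaround sketch, not the actual diffusers code: call_with_supported_kwargs is a hypothetical helper and flash_attn_func_hub stands in for the loaded hub kernel. The idea is to filter the dispatched kwargs down to what the hub kernel's flash_attn_func actually declares before calling it.

```python
import inspect

def call_with_supported_kwargs(fn, *args, **kwargs):
    # Keep only the keyword arguments that appear in fn's signature.
    # Caveat: this silently drops arguments such as dropout_p / alibi_slopes,
    # so it is only safe when they are 0.0 / None anyway.
    params = inspect.signature(fn).parameters
    supported = {k: v for k, v in kwargs.items() if k in params}
    return fn(*args, **supported)

# Hypothetical usage inside _flash_attention_hub:
# out = call_with_supported_kwargs(
#     flash_attn_func_hub, query, key, value, causal=False
# )
```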

cc: @sayakpaul

sayakpaul (Member) commented:

@ParagEkbote I think we can close this PR in favor of #12387. You're more than welcome to test that PR and let us know if you have any feedback.


ParagEkbote (Author) commented:

@sayakpaul Thanks for letting me know and for being a patient reviewer. Closing the PR.

Development

Successfully merging this pull request may close issue #12308: Support flash-attn kernel for non-Hopper GPUs
