use kernels to support _flash_hub attention backend #12318
Conversation
Thanks for this PR. Could you update it with some code examples and results?
This is the test script, but it fails to generate images:
```python
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

# Debug: Verify the env var is set
print(f"DIFFUSERS_ENABLE_HUB_KERNELS = {os.environ.get('DIFFUSERS_ENABLE_HUB_KERNELS')}")

import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Debug: Check if diffusers sees the env var
from diffusers.models.attention_dispatch import DIFFUSERS_ENABLE_HUB_KERNELS
print(f"Diffusers sees DIFFUSERS_ENABLE_HUB_KERNELS = {DIFFUSERS_ENABLE_HUB_KERNELS}")

# ✅ 3. Load pipeline with quantization
model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer"],
    ),
).to("cuda")

pipe.transformer.set_attention_backend("_flash_hub")

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")
```
I'm running into issues with some of the parameters; here is the traceback:
Traceback (most recent call last):
File "/teamspace/studios/this_studio/diffusers/main.py", line 34, in <module>
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/teamspace/studios/this_studio/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 944, in __call__
noise_pred = self.transformer(
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 720, in forward
encoder_hidden_states, hidden_states = block(
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 443, in forward
attention_outputs = self.attn(
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 342, in forward
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 116, in __call__
hidden_states = dispatch_attention_fn(
File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/attention_dispatch.py", line 304, in dispatch_attention_fn
return backend_fn(**kwargs)
File "/teamspace/studios/this_studio/diffusers/src/diffusers/models/attention_dispatch.py", line 765, in _flash_attention_hub
out = flash_attn_func_hub(
TypeError: flash_attn_func() got an unexpected keyword argument 'alibi_slopes'
The same error occurs with the dropout_p parameter as well. WDYT?
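For reference, the traceback suggests the hub kernel's flash_attn_func exposes a narrower signature than the built-in flash-attn path, so extra kwargs like alibi_slopes and dropout_p are rejected. One possible guard is to filter kwargs against the target function's actual signature before dispatching; the helper name and call site below are hypothetical, just a sketch rather than anything in this PR:

```python
import inspect

def call_with_supported_kwargs(fn, *args, **kwargs):
    # Drop kwargs (e.g. alibi_slopes, dropout_p) that the resolved hub
    # kernel's flash_attn_func does not declare in its signature.
    params = inspect.signature(fn).parameters
    filtered = {k: v for k, v in kwargs.items() if k in params}
    return fn(*args, **filtered)

# Hypothetical call site, mirroring where _flash_attention_hub fails:
# out = call_with_supported_kwargs(flash_attn_func_hub, q, k, v, causal=is_causal)
```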
cc: @sayakpaul
@ParagEkbote I think we can close this PR in favor of #12387. You're more than welcome to test that PR and let us know if you have any feedback.
@sayakpaul Thanks for letting me know and being a patient reviewer. Closing the PR.
What does this PR do?
As discussed in the issue, this PR adds support for the kernels-community/flash-attn kernel. Could you please review?
Fixes #12308
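For reference, a minimal sketch of how the flash-attn kernel can be fetched from the Hub with the kernels library; the flash_attn_func attribute is an assumption based on the upstream flash-attn API, and the actual wiring inside diffusers' attention dispatch may differ:

```python
# Sketch only: load the community flash-attn kernel via the `kernels` package.
from kernels import get_kernel

flash_attn_hub = get_kernel("kernels-community/flash-attn")

# Assumed usage: the loaded module exposes flash_attn_func, which the
# "_flash_hub" attention backend would dispatch to for q/k/v tensors.
# out = flash_attn_hub.flash_attn_func(q, k, v, causal=True)
```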
Before submitting
- Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@sayakpaul