
refactor: backend_requirement + supported_compute_capability decorator for gemm #2000


Open · jimmyzho wants to merge 3 commits into flashinfer-ai:main from jimmyzho:bknd (base: main)

Conversation

@jimmyzho (Contributor) commented Oct 29, 2025 · edited by coderabbitai bot

📌 Description

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Broader FP8 support: added groupwise, blockscaled, DeepGEMM and batch/grouped FP8 GEMM paths across more GPU targets (SM100, SM103, SM120, SM121, TRTLLM).
    • New public entry points for multiple FP8 GEMM variants and backend-dispatched paths.
  • Improvements

    • Strengthened input validation with clear error handling for shapes, dtypes and contiguity.
    • Centralized backend capability checks and more informative failure reporting.

coderabbitai bot (Contributor) commented Oct 29, 2025 · edited

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Added explicit input validation and compute-capability guards for multiple FP8 GEMM/grouped/NT-groupwise paths. Converted in-function asserts to explicit checks and introduced backend_requirement decorators plus standalone problem-size validators to gate backend dispatch and return clear errors on invalid inputs.

Changes

Cohort / File(s) Summary
Grouped FP8 GEMM Validation (deep_gemm.py)
flashinfer/deep_gemm.py
Added _check_group_deepgemm_fp8_nt_contiguous_problem_size() and _check_m_grouped_fp8_gemm_nt_masked_problem_size() validators; annotated m_grouped_fp8_gemm_nt_contiguous() and m_grouped_fp8_gemm_nt_masked() with @backend_requirement; replaced runtime assert calls with explicit ValueError checks; imported supported_compute_capability and backend_requirement.
FP8 GEMM Backend & NT-Groupwise Paths (gemm.py)
flashinfer/gemm.py
Introduced backend requirement hooks (_cudnn_bmm_fp8_requirement, _cublas_bmm_fp8_requirement, _cutlass_bmm_fp8_requirement, _cutlass_gemm_fp8_nt_groupwise_requirement, _trtllm_gemm_fp8_nt_groupwise_requirement, etc.), many _check_* problem-size validators, and new public entry points (gemm_fp8_nt_groupwise, gemm_fp8_nt_blockscaled, group_gemm_fp8_nt_groupwise, group_gemm_mxfp8_mxfp4_nt_groupwise, batch/group deepgemm variants). Centralized validation and backend-dispatch logic; added capability annotations for SM gating.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller as Caller
    participant API as Public API<br/>(e.g., gemm_fp8 / m_grouped_fp8)
    participant Decorator as @backend_requirement
    participant Validator as _check_* validators
    participant Dispatcher as Backend Dispatcher
    participant Backend as Backend (cutlass/cudnn/cublas/trtllm)
    participant Error as Error
    Caller->>API: call FP8 GEMM entry
    API->>Decorator: decorated entry invoked
    Decorator->>Validator: run problem-size & capability checks
    Validator-->>Decorator: true / false
    alt checks pass
        Decorator->>Dispatcher: select backend
        Dispatcher->>Backend: invoke backend-specific implementation
        Backend-->>Dispatcher: result tensor
        Dispatcher-->>API: return result
        API-->>Caller: tensor
    else checks fail
        Decorator->>Error: raise ValueError with reason
        Error-->>Caller: exception
    end
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Areas needing extra attention:
    • Correctness and coverage of each _check_* validator (contiguity, dtype, shapes, scale granularity)
    • Proper wiring of @backend_requirement dispatch and return semantics
    • Capability annotations (supported_compute_capability) mapped to appropriate SMs/backends
    • Consistency between grouped/batched API signatures and backend expectations

Suggested reviewers

  • yzh119
  • cyx-6
  • wenscarl
  • bkryu
  • nvmbreughe

Poem

🐰 I hopped through tensors, scales in tow,

Guards on SMs now steady and slow,
FP8 paths checked from end to start,
Backends chosen with cautious heart,
Now GEMM hops clean, swift as a sparrow. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)

Description Check: ⚠️ Warning
Explanation: The pull request description is largely incomplete and does not provide meaningful content. While the repository's PR template is present with proper structure, the author has left all critical sections empty: the "📌 Description" section contains only an HTML comment placeholder with no explanation of what the changes do or why they are needed, and the "🔍 Related Issues" section is also empty. The checklist sections are present but lack substantive information about the pre-commit checks or tests performed, and no "Reviewer Notes" content is provided despite the placeholder being present. The raw summary reveals substantial refactoring across two files, with significant API additions and behavioral changes that should be explained in the PR description, but none of this context is captured.
Resolution: Fill in the "📌 Description" section with a clear explanation of what changes are being made and why, specifically describing the introduction of the @backend_requirement and @supported_compute_capability decorators, the problem-size validation functions, and the refactoring of the FP8 GEMM operations. Populate the "🔍 Related Issues" section with links to any related GitHub issues. Confirm the pre-commit and test checklist items by checking the relevant boxes, and optionally note which tests were run or any concerns reviewers should focus on in the "Reviewer Notes" section.

Docstring Coverage: ⚠️ Warning
Explanation: Docstring coverage is 27.59%, which is below the required threshold of 80.00%.
Resolution: Run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (1 passed)

Title Check: ✅ Passed
The pull request title "refactor: backend_requirement + supported_compute_capability decorator for gemm" clearly and specifically describes the main refactoring work. It accurately reflects the primary changes shown in the raw summary: introducing @backend_requirement and @supported_compute_capability decorators to wrap GEMM entry points, adding capability gating, and implementing problem-size validation functions. The title is concise, avoids generic terminology, and provides sufficient context for someone reviewing the commit history.


Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the General Matrix Multiply (GEMM) operations by implementing a new decorator-based system. The primary goal is to streamline and consolidate the validation of inputs and hardware compatibility requirements for various GEMM functions. By abstracting these checks into decorators and dedicated problem-size validation functions, the core logic of the GEMM implementations becomes cleaner and more focused, while ensuring consistent and informative error handling across the codebase.

Highlights

  • New Decorator System: Introduced @supported_compute_capability and @backend_requirement decorators to standardize and centralize input validation and hardware compatibility checks for GEMM operations.
  • Centralized Validation Logic: Migrated existing input validation logic, including shape, data type, contiguity, and compute capability checks, from function bodies into dedicated _check_*_problem_size functions. This improves code modularity and maintainability.
  • Improved Error Handling: Replaced assert statements with if ... raise ValueError for more robust and user-friendly error messages, providing clearer feedback when inputs do not meet requirements.
  • Backend-Specific Checks: Enabled granular, backend-specific validation logic through the backend_requirement decorator, allowing different backends (e.g., cuDNN, cuBLAS, CUTLASS, TRTLLM) to have their own specific prerequisite checks.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩

gemini-code-assist bot left a comment
Code Review

This pull request is a great refactoring that introduces @backend_requirement and @supported_compute_capability decorators to centralize validation logic for GEMM functions. This improves code clarity and maintainability by separating validation from the core logic. The conversion of assert statements to ValueError exceptions is also a good improvement for more robust error handling.

I've found a critical bug that would lead to a runtime error, and a couple of missing validation checks that were present in the original code. My review includes suggestions to fix these issues. Overall, the direction of the changes is excellent.

Comment on lines 2618 to 2632
if out is None:
if out_dtype is None:
out_dtype = torch.bfloat16
else:
if out_dtype is None:
out_dtype = out.dtype
if out.shape != (a.shape[0], n):
raise ValueError(f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}")
if out.dtype != out_dtype:
raise ValueError(f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}")

_validate_fp8_output_dtype(out_dtype)

n = b.shape[1]
k = b.shape[2]
gemini-code-assist bot commented Oct 29, 2025

critical

The variable n is used on line 2624 before it is assigned a value on line 2631. This will raise an UnboundLocalError if out is not None. The shape variables n and k should be defined before they are used in the shape checks.

Suggested change

-if out is None:
-    if out_dtype is None:
-        out_dtype = torch.bfloat16
-else:
-    if out_dtype is None:
-        out_dtype = out.dtype
-    if out.shape != (a.shape[0], n):
-        raise ValueError(f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}")
-    if out.dtype != out_dtype:
-        raise ValueError(f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}")
-
-_validate_fp8_output_dtype(out_dtype)
-
-n = b.shape[1]
-k = b.shape[2]
+n = b.shape[1]
+k = b.shape[2]
+
+if out is None:
+    if out_dtype is None:
+        out_dtype = torch.bfloat16
+else:
+    if out_dtype is None:
+        out_dtype = out.dtype
+    if out.shape != (a.shape[0], n):
+        raise ValueError(f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}")
+    if out.dtype != out_dtype:
+        raise ValueError(f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}")
+
+_validate_fp8_output_dtype(out_dtype)
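The scoping pitfall flagged in this comment is easy to reproduce in isolation (a standalone sketch, not the PR's code): because the name is assigned later in the same function body, Python treats it as a local for the entire function, so the earlier read raises UnboundLocalError instead of consulting any outer scope.

```python
def check(out):
    if out is None:
        out_dtype = "bf16"
    else:
        # BUG: `n` is read here but only assigned below, so it is an
        # unbound local at this point, mirroring the issue above.
        if len(out) != n:
            raise ValueError("shape mismatch")
    n = 4  # this assignment makes `n` local to the whole function
    return n

check(None)  # fine: the buggy branch is skipped
try:
    check([1, 2, 3, 4])
except UnboundLocalError as exc:
    print("caught:", exc)
```

This is why the fix moves the shape derivations above the branch that reads them.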

gemini-code-assist bot commented Oct 29, 2025

high

The check for positive dimensions n, k, and num_groups is missing. The original code had assert n > 0 and k > 0 and num_groups > 0. This check is important to prevent unexpected behavior with empty or invalid dimensions.

Suggested change

+if n <= 0 or k <= 0 or num_groups <= 0:
+    raise ValueError(f"n, k, and num_groups must be positive, but got n={n}, k={k}, num_groups={num_groups}")

raise ValueError(f"k must be a multiple of 16, but got {k}")

num_groups = m_indptr.shape[0] - 1

gemini-code-assist bot commented Oct 29, 2025

high

A check to ensure that the number of groups in b matches num_groups derived from m_indptr is missing. This was present in the original code (assert b.shape[0] == num_groups) and is important for correctness.

Suggested change

+if b.shape[0] != num_groups:
+    raise ValueError(f"The number of groups in b ({b.shape[0]}) does not match the number of groups in m_indptr ({num_groups}).")
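For context on why that check matters: m_indptr appears to act as a CSR-style row pointer over the stacked rows of A, so the group count it implies must agree with b's leading dimension. A sketch with illustrative values (not from the PR):

```python
m_indptr = [0, 3, 3, 8]   # 3 groups owning 3, 0, and 5 rows of A
num_groups = len(m_indptr) - 1
rows_per_group = [m_indptr[g + 1] - m_indptr[g] for g in range(num_groups)]

def groups_match(b_shape0: int) -> bool:
    # The suggested validation: b is packed per group, so its leading
    # dimension must equal the group count derived from m_indptr.
    return b_shape0 == num_groups
```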

coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can't be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer/deep_gemm.py (1)

1365-1410: Fix incorrect shape checks and over-strict dtype validation in contiguous checker.

  • num_groups != m__ is wrong: m_indices is per-row (length m), not per-group. This will falsely reject valid inputs.
  • Allow d.dtype to be either bf16 or float32, which the runtime supports.

Apply:

@@
- if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
- raise ValueError(f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}")
+ if m != m_ or k != k_ or n != n_ or m__ != m_:
+ raise ValueError(
+ f"Shape mismatch. m={m}, m_={m_}, k={k}, k_={k_}, n={n}, n_={n_}, |m_indices|={m__}"
+ )
@@
- if d.dtype != torch.bfloat16:
- raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+ if d.dtype not in (torch.bfloat16, torch.float):
+ raise ValueError(f"d must be bfloat16 or float32, got {d.dtype}")

Optionally, quiet Ruff warnings:

- a, sfa = a_fp8
- b, sfb = b_fp8
+ a, _ = a_fp8
+ b, _ = b_fp8
flashinfer/gemm.py (1)

2589-2665: Fix F821: n used before assignment in groupwise problem-size check.

Move n/k derivation before using n in output shape checks; keep dtype validation.

Apply:

@@
- if out is None:
- if out_dtype is None:
- out_dtype = torch.bfloat16
- else:
- if out_dtype is None:
- out_dtype = out.dtype
- if out.shape != (a.shape[0], n):
- raise ValueError(f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}")
- if out.dtype != out_dtype:
- raise ValueError(f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}")
- 
- _validate_fp8_output_dtype(out_dtype)
-
- n = b.shape[1]
- k = b.shape[2]
+ n = b.shape[1]
+ k = b.shape[2]
+
+ if out is None:
+ if out_dtype is None:
+ out_dtype = torch.bfloat16
+ else:
+ if out_dtype is None:
+ out_dtype = out.dtype
+ if out.shape != (a.shape[0], n):
+ raise ValueError(
+ f"Shape mismatch. out.shape={out.shape}, expected={(a.shape[0], n)}"
+ )
+ if out.dtype != out_dtype:
+ raise ValueError(f"dtype mismatch. out.dtype={out.dtype}, out_dtype={out_dtype}")
+
+ _validate_fp8_output_dtype(out_dtype)
🧹 Nitpick comments (4)
flashinfer/deep_gemm.py (2)

1428-1432: Nit: unused variable k_.

Use _ to silence RUF059.

- num_groups, n, k_ = b.shape
+ num_groups, n, _ = b.shape

1461-1514: Masked checker: align dtype, contiguity, and unused unpack.

  • Allow d.dtype to be bf16 or f32 to match runtime support.
  • Use _ for unused unpacked scale tensors.

Apply:

- a, sfa = a_fp8
- b, sfb = b_fp8
+ a, _ = a_fp8
+ b, _ = b_fp8
@@
- if d.dtype != torch.bfloat16:
- raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+ if d.dtype not in (torch.bfloat16, torch.float):
+ raise ValueError(f"d must be bfloat16 or float32, got {d.dtype}")
flashinfer/gemm.py (2)

2734-2734: Silence F841: unused local.

-num_groups = m_indptr.shape[0] - 1
+_num_groups = m_indptr.shape[0] - 1

2940-2940: Silence F841: unused local.

-num_groups = m_indptr.shape[0] - 1
+_num_groups = m_indptr.shape[0] - 1
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 159d0a0 and 92096da.

📂 Files selected for processing (2)
  • flashinfer/deep_gemm.py (4 hunks)
  • flashinfer/gemm.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/gemm.py (2)
flashinfer/utils.py (4)
  • supported_compute_capability (772-852)
  • backend_requirement (855-1028)
  • is_sm120a_supported (504-506)
  • is_sm121a_supported (509-511)
flashinfer/deep_gemm.py (2)
  • _check_group_deepgemm_fp8_nt_contiguous_problem_size (1366-1409)
  • _check_m_grouped_fp8_gemm_nt_masked_problem_size (1462-1509)
flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
  • ceil_div (575-586)
  • round_up (589-591)
  • supported_compute_capability (772-852)
  • backend_requirement (855-1028)
🪛 GitHub Actions: pre-commit

flashinfer/gemm.py

[error] 2624-2625: ruff (F821): Undefined name 'n'. Cannot determine type of 'n'.

[error] 2734, 2940: ruff (F841): Local variable 'num_groups' is assigned to but never used.

[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook.

flashinfer/deep_gemm.py

[error] 1411, 1511: mypy: Missing positional argument 'backend_checks' in call to 'backend_requirement'.

[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook.

🪛 Ruff (0.14.2)

flashinfer/gemm.py
  • ARG001 (unused function argument):
    • 2012-2018 and 2026-2032: A, B, A_scale, B_scale, dtype, out, backend
    • 2041-2045: A_scale, B_scale, dtype, out, backend
    • 2052-2058: A, B, A_scale, B_scale, out, backend
    • 2165-2174: a, b, a_scale, b_scale, mma_sm, scale_granularity_mnk, out, out_dtype, backend
    • 2185-2193: b, a_scale, b_scale, scale_major_mode, mma_sm, out, out_dtype, backend
    • 2206-2213: a_scale, b_scale, scale_major_mode, mma_sm, scale_granularity_mnk, out, backend
    • 2595: scale_granularity_mnk; 3010, 3154: out_dtype
  • TRY003 (avoid specifying long messages outside the exception class): 2048, 2177, 2196, 2198, 2216, 2219-2221, 2602-2614, 2625, 2627, 2635-2639, 2645-2647, 2793-2811, 2822, 2826, 2833, 2838, 2840, 2845, 2847
  • F821 (undefined name n): 2624, 2625

flashinfer/deep_gemm.py
  • ARG001 (unused function argument): 1371-1372 and 1468-1469: recipe, compiled_dims
  • TRY003 (avoid specifying long messages outside the exception class): 1379, 1381, 1384, 1395-1407, 1475, 1477, 1480, 1491-1507
  • RUF059 (unpacked variable never used): 1431: k_; 1482-1483: sfa, sfb (prefix with an underscore or any other dummy variable pattern)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (2)
flashinfer/gemm.py (2)

3012-3014: Forwarding check OK, but inherits dtype/shape constraints from deep_gemm.

Once deep_gemm checker is fixed to allow bf16/float32 and correct m_indices semantics, this indirection remains valid. No changes needed here.


3161-3164: Forwarding check OK (batch path).

Same note as above; relies on deep_gemm masked checker. No change required.

Comment on lines 1411 to 1420
@backend_requirement(
common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
)
def m_grouped_fp8_gemm_nt_contiguous(
coderabbitai bot commented Oct 29, 2025

⚠️ Potential issue | 🔴 Critical

Decorator misuse: missing backend_checks causes mypy error and runtime failure.

@backend_requirement requires a backend map. Also, calls pass no backend kwarg, so use a None key as the default backend and gate CC via a trivial requirement.

Apply:

-@backend_requirement(
- common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
-)
+@backend_requirement(
+ {None: _deepgemm_sm10x_requirement},
+ common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)

And add this requirement helper near the checker:

+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement(
+ a_fp8, b_fp8, d, m_indices, recipe=None, compiled_dims="nk"
+):
+ # CC gating only; shape/dtype validated by common_check.
+ return True

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 GitHub Actions: pre-commit

[error] 1411-1411: mypy: Missing positional argument 'backend_checks' in call to 'backend_requirement'.

🤖 Prompt for AI Agents
In flashinfer/deep_gemm.py around lines 1411-1414, the @backend_requirement
decorator is missing a backend_checks mapping which causes mypy/runtime failures
because callers pass no backend kwarg; add a trivial requirement helper (e.g., a
function that returns True or otherwise acts as a no-op requirement) near the
existing _check_group_deepgemm_fp8_nt_contiguous_problem_size checker, then
update the decorator call to include backend_checks with a None key mapped to
that trivial requirement (backend_checks={None: trivial_requirement}) so the
decorator has a default backend entry and CC gating remains intact.

coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can't be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/deep_gemm.py (1)

1366-1414: Add missing positive-dimension checks and silence unused params.

  • Validate n, k, num_groups > 0 to prevent empty/invalid shapes (was present before per prior review).
  • Prefix unused parameters to satisfy linters.

Apply:

-@supported_compute_capability([100, 103])
-def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
+@supported_compute_capability([100, 103])
+def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
     d: torch.Tensor,
     m_indices: torch.Tensor,
-    recipe: Optional[Tuple[int, int, int]] = None,
-    compiled_dims: str = "nk",
+    _recipe: Optional[Tuple[int, int, int]] = None,
+    _compiled_dims: str = "nk",
 ) -> bool:
@@
     m, k = a.shape
     num_groups, n, k_ = b.shape
     m_, n_ = d.shape
     m__ = m_indices.numel()

     # Type and shape checks
@@
     if m_indices.dtype != torch.int32:
         raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")

+    # Positive-dimension checks
+    if n <= 0 or k <= 0 or num_groups <= 0:
+        raise ValueError(
+            f"n, k, and num_groups must be positive, got n={n}, k={k}, num_groups={num_groups}"
+        )
     # D must be N-major
♻️ Duplicate comments (5)
flashinfer/gemm.py (3)

2557-2560: Critical: Empty backend map causes gemm_fp8_nt_blockscaled to always error.

This mirrors a previously reported issue. Add a default requirement with CC gating.

+@supported_compute_capability([100, 103, 120, 121])
+def _cutlass_gemm_fp8_nt_blockscaled_requirement(
+    a, b, a_scale, b_scale, scale_major_mode="MN", mma_sm=1, out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_gemm_fp8_nt_blockscaled_problem_size,
-)
+@backend_requirement(
+    {None: _cutlass_gemm_fp8_nt_blockscaled_requirement},
+    common_check=_check_gemm_fp8_nt_blockscaled_problem_size,
+)

2590-2655: Add missing groups check and validate out dtype early.

Ensure b.shape[0] matches num_groups derived from m_indptr; set/validate out_dtype like other paths.

@@
-    num_groups = m_indptr.shape[0] - 1
+    num_groups = m_indptr.shape[0] - 1
+    if b.shape[0] != num_groups:
+        raise ValueError(
+            f"b.shape[0] ({b.shape[0]}) must equal num_groups ({num_groups})"
+        )

2865-2868: Critical: group_gemm_mxfp8_mxfp4_nt_groupwise also uses {}.

Provide a default requirement with CC gating.

+@supported_compute_capability([100, 103, 110, 120, 121])
+def _cutlass_group_gemm_mxfp8_mxfp4_requirement(
+    a, b, a_scale, b_scale, m_indptr,
+    mma_sm=1, tile_m=128, tile_n=128, tile_k=128, swap_ab=True,
+    out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size,
-)
+@backend_requirement(
+    {None: _cutlass_group_gemm_mxfp8_mxfp4_requirement},
+    common_check=_check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size,
+)
flashinfer/deep_gemm.py (2)

1416-1420: Critical: backend_requirement with empty backend map always fails at runtime.

Passing {} and not supplying a backend kwarg makes the wrapper raise BackendSupportedError. Provide a default backend entry keyed by None and a trivial requirement function with proper CC gating.

Apply:

+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement(
+    a_fp8, b_fp8, d, m_indices, recipe=None, compiled_dims="nk"
+):
+    # CC gating only; shape/dtype validated by common_check.
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement},
+    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)

1529-1533: Critical: same empty backend map issue for masked path.

Mirror the contiguous fix by adding a default backend requirement with CC gating.

-@backend_requirement(
-    {},
-    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement},
+    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
+)
🧹 Nitpick comments (3)
flashinfer/deep_gemm.py (2)

1431-1438: Minor: avoid unused k_.

Use _ to silence the unused binding.

-    num_groups, n, k_ = b.shape
+    num_groups, n, _ = b.shape

1467-1527: LGTM with tiny lint nits.

Logic is correct and includes positivity checks. Prefix unused sfa/sfb to satisfy linters.

-    a, sfa = a_fp8
-    b, sfb = b_fp8
+    a, _sfa = a_fp8
+    b, _sfb = b_fp8
flashinfer/gemm.py (1)

2165-2235: NT-groupwise single-GEMM backend checks: solid, consider reusing for blockscaled.

The requirement/check split is good. If desired, mirror the CUTLASS CC list for the blockscaled helper to keep consistency.

πŸ“œ Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

πŸ“₯ Commits

Reviewing files that changed from the base of the PR and between 92096da and ad39f67.

πŸ“’ Files selected for processing (2)
  • flashinfer/deep_gemm.py (4 hunks)
  • flashinfer/gemm.py (9 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
  • ceil_div (575-586)
  • round_up (589-591)
  • supported_compute_capability (772-852)
  • backend_requirement (855-1028)
flashinfer/gemm.py (2)
flashinfer/utils.py (4)
  • supported_compute_capability (772-852)
  • backend_requirement (855-1028)
  • is_sm120a_supported (504-506)
  • is_sm121a_supported (509-511)
flashinfer/deep_gemm.py (2)
  • _check_group_deepgemm_fp8_nt_contiguous_problem_size (1367-1413)
  • _check_m_grouped_fp8_gemm_nt_masked_problem_size (1468-1526)
πŸͺ› Ruff (0.14.2)
flashinfer/deep_gemm.py

1372-1372: Unused function argument: recipe

(ARG001)


1373-1373: Unused function argument: compiled_dims

(ARG001)


1379-1379: Avoid specifying long messages outside the exception class

(TRY003)


1381-1381: Avoid specifying long messages outside the exception class

(TRY003)


1384-1386: Avoid specifying long messages outside the exception class

(TRY003)


1397-1399: Avoid specifying long messages outside the exception class

(TRY003)


1401-1401: Avoid specifying long messages outside the exception class

(TRY003)


1403-1403: Avoid specifying long messages outside the exception class

(TRY003)


1405-1405: Avoid specifying long messages outside the exception class

(TRY003)


1407-1407: Avoid specifying long messages outside the exception class

(TRY003)


1411-1411: Avoid specifying long messages outside the exception class

(TRY003)


1437-1437: Unpacked variable k_ is never used

(RUF059)


1474-1474: Unused function argument: recipe

(ARG001)


1475-1475: Unused function argument: compiled_dims

(ARG001)


1480-1480: Avoid specifying long messages outside the exception class

(TRY003)


1482-1482: Avoid specifying long messages outside the exception class

(TRY003)


1485-1487: Avoid specifying long messages outside the exception class

(TRY003)


1489-1489: Unpacked variable sfa is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1490-1490: Unpacked variable sfb is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1502-1504: Avoid specifying long messages outside the exception class

(TRY003)


1506-1508: Avoid specifying long messages outside the exception class

(TRY003)


1510-1512: Avoid specifying long messages outside the exception class

(TRY003)


1514-1514: Avoid specifying long messages outside the exception class

(TRY003)


1516-1516: Avoid specifying long messages outside the exception class

(TRY003)


1518-1518: Avoid specifying long messages outside the exception class

(TRY003)


1520-1520: Avoid specifying long messages outside the exception class

(TRY003)


1524-1524: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer/gemm.py

2011-2011: Unused function argument: A

(ARG001)


2012-2012: Unused function argument: B

(ARG001)


2013-2013: Unused function argument: A_scale

(ARG001)


2014-2014: Unused function argument: B_scale

(ARG001)


2015-2015: Unused function argument: dtype

(ARG001)


2016-2016: Unused function argument: out

(ARG001)


2017-2017: Unused function argument: backend

(ARG001)


2025-2025: Unused function argument: A

(ARG001)


2026-2026: Unused function argument: B

(ARG001)


2027-2027: Unused function argument: A_scale

(ARG001)


2028-2028: Unused function argument: B_scale

(ARG001)


2029-2029: Unused function argument: dtype

(ARG001)


2030-2030: Unused function argument: out

(ARG001)


2031-2031: Unused function argument: backend

(ARG001)


2040-2040: Unused function argument: A_scale

(ARG001)


2041-2041: Unused function argument: B_scale

(ARG001)


2042-2042: Unused function argument: dtype

(ARG001)


2043-2043: Unused function argument: out

(ARG001)


2044-2044: Unused function argument: backend

(ARG001)


2047-2047: Avoid specifying long messages outside the exception class

(TRY003)


2052-2052: Unused function argument: A

(ARG001)


2053-2053: Unused function argument: B

(ARG001)


2054-2054: Unused function argument: A_scale

(ARG001)


2055-2055: Unused function argument: B_scale

(ARG001)


2057-2057: Unused function argument: out

(ARG001)


2058-2058: Unused function argument: backend

(ARG001)


2166-2166: Unused function argument: a

(ARG001)


2167-2167: Unused function argument: b

(ARG001)


2168-2168: Unused function argument: a_scale

(ARG001)


2169-2169: Unused function argument: b_scale

(ARG001)


2171-2171: Unused function argument: mma_sm

(ARG001)


2172-2172: Unused function argument: scale_granularity_mnk

(ARG001)


2173-2173: Unused function argument: out

(ARG001)


2174-2174: Unused function argument: out_dtype

(ARG001)


2175-2175: Unused function argument: backend

(ARG001)


2178-2178: Avoid specifying long messages outside the exception class

(TRY003)


2186-2186: Unused function argument: b

(ARG001)


2187-2187: Unused function argument: a_scale

(ARG001)


2188-2188: Unused function argument: b_scale

(ARG001)


2189-2189: Unused function argument: scale_major_mode

(ARG001)


2190-2190: Unused function argument: mma_sm

(ARG001)


2192-2192: Unused function argument: out

(ARG001)


2193-2193: Unused function argument: out_dtype

(ARG001)


2194-2194: Unused function argument: backend

(ARG001)


2197-2197: Avoid specifying long messages outside the exception class

(TRY003)


2199-2199: Avoid specifying long messages outside the exception class

(TRY003)


2207-2207: Unused function argument: a_scale

(ARG001)


2208-2208: Unused function argument: b_scale

(ARG001)


2209-2209: Unused function argument: scale_major_mode

(ARG001)


2210-2210: Unused function argument: mma_sm

(ARG001)


2211-2211: Unused function argument: scale_granularity_mnk

(ARG001)


2212-2212: Unused function argument: out

(ARG001)


2214-2214: Unused function argument: backend

(ARG001)


2217-2217: Avoid specifying long messages outside the exception class

(TRY003)


2220-2222: Avoid specifying long messages outside the exception class

(TRY003)


2596-2596: Unused function argument: scale_granularity_mnk

(ARG001)


2603-2603: Avoid specifying long messages outside the exception class

(TRY003)


2605-2605: Avoid specifying long messages outside the exception class

(TRY003)


2607-2607: Avoid specifying long messages outside the exception class

(TRY003)


2609-2609: Avoid specifying long messages outside the exception class

(TRY003)


2611-2611: Avoid specifying long messages outside the exception class

(TRY003)


2613-2615: Avoid specifying long messages outside the exception class

(TRY003)


2617-2617: Avoid specifying long messages outside the exception class

(TRY003)


2630-2632: Avoid specifying long messages outside the exception class

(TRY003)


2634-2636: Avoid specifying long messages outside the exception class

(TRY003)


2641-2641: Avoid specifying long messages outside the exception class

(TRY003)


2643-2643: Avoid specifying long messages outside the exception class

(TRY003)


2645-2645: Avoid specifying long messages outside the exception class

(TRY003)


2651-2653: Avoid specifying long messages outside the exception class

(TRY003)


2798-2800: Avoid specifying long messages outside the exception class

(TRY003)


2802-2802: Avoid specifying long messages outside the exception class

(TRY003)


2804-2804: Avoid specifying long messages outside the exception class

(TRY003)


2806-2806: Avoid specifying long messages outside the exception class

(TRY003)


2808-2808: Avoid specifying long messages outside the exception class

(TRY003)


2810-2810: Avoid specifying long messages outside the exception class

(TRY003)


2812-2812: Avoid specifying long messages outside the exception class

(TRY003)


2814-2814: Avoid specifying long messages outside the exception class

(TRY003)


2816-2816: Avoid specifying long messages outside the exception class

(TRY003)


2818-2818: Avoid specifying long messages outside the exception class

(TRY003)


2829-2831: Avoid specifying long messages outside the exception class

(TRY003)


2835-2837: Avoid specifying long messages outside the exception class

(TRY003)


2844-2846: Avoid specifying long messages outside the exception class

(TRY003)


2851-2851: Avoid specifying long messages outside the exception class

(TRY003)


2853-2853: Avoid specifying long messages outside the exception class

(TRY003)


2858-2858: Avoid specifying long messages outside the exception class

(TRY003)


2860-2860: Avoid specifying long messages outside the exception class

(TRY003)


3022-3022: Unused function argument: out_dtype

(ARG001)


3172-3172: Unused function argument: out_dtype

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
πŸ”‡ Additional comments (2)
flashinfer/deep_gemm.py (1)

48-53: Imports look good.

The utils import now brings in supported_compute_capability and backend_requirement correctly.

flashinfer/gemm.py (1)

2009-2072: bmm_fp8 requirements mapping looks correct.

Backends and CC gating are wired; common check validates output dtype.

Comment on lines +2658 to +2661
@backend_requirement(
    {},
    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
)
Copy link
Contributor

@coderabbitai coderabbitai bot Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | πŸ”΄ Critical

Critical: group_gemm_fp8_nt_groupwise uses {} backend map.

Add a default backend requirement with CC gating; signature must match this API.

+@supported_compute_capability([100, 120, 121])
+def _cutlass_group_gemm_fp8_nt_groupwise_requirement(
+    a, b, a_scale, b_scale, m_indptr,
+    scale_granularity_mnk=(1, 128, 128),
+    scale_major_mode="MN",
+    mma_sm=1,
+    out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
-)
+@backend_requirement(
+    {None: _cutlass_group_gemm_fp8_nt_groupwise_requirement},
+    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
+)
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@backend_requirement(
    {},
    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
)
@supported_compute_capability([100, 120, 121])
def _cutlass_group_gemm_fp8_nt_groupwise_requirement(
    a, b, a_scale, b_scale, m_indptr,
    scale_granularity_mnk=(1, 128, 128),
    scale_major_mode="MN",
    mma_sm=1,
    out=None, out_dtype=None
):
    return True


@backend_requirement(
    {None: _cutlass_group_gemm_fp8_nt_groupwise_requirement},
    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
)

Comment on lines +3033 to +3036
@backend_requirement(
    {},
    common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size,
)
Copy link
Contributor

@coderabbitai coderabbitai bot Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | πŸ”΄ Critical

Critical: DeepGEMM group path requires a default backend requirement.

Add CC-gated default backend for SM100/103.

+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement_group(
+    a, b, a_scale, b_scale, m_indices, scale_granularity_mnk=(1, 128, 128),
+    out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement_group},
+    common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size,
+)

Committable suggestion skipped: line range outside the PR's diff.

πŸ€– Prompt for AI Agents
In flashinfer/gemm.py around lines 3033 to 3036, the @backend_requirement
decorator on the DeepGEMM group path is missing a CC-gated default backend; add
a default backend entry that targets SM100 and SM103 (e.g. a "default" key with
a CUDA backend CC filter for 100 and 103) so the decorator includes both the
existing specific requirements and a default backend gated to compute capability
100 and 103.

Comment on lines +3181 to +3184
@backend_requirement(
    {},
    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
)
Copy link
Contributor

@coderabbitai coderabbitai bot Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | πŸ”΄ Critical

Critical: DeepGEMM batch path also uses {}.

Add CC-gated default backend.

+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement_batch(
+    a, b, a_scale, b_scale, masked_m, expected_m,
+    scale_granularity_mnk=(1, 128, 128),
+    out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement_batch},
+    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
+)
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
@backend_requirement(
    {},
    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
)
@supported_compute_capability([100, 103])
def _deepgemm_sm10x_requirement_batch(
    a, b, a_scale, b_scale, masked_m, expected_m,
    scale_granularity_mnk=(1, 128, 128),
    out=None, out_dtype=None
):
    return True


@backend_requirement(
    {None: _deepgemm_sm10x_requirement_batch},
    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
)
πŸ€– Prompt for AI Agents
In flashinfer/gemm.py around lines 3181 to 3184, the @backend_requirement
decorator for the DeepGEMM batch path currently passes an empty dict for
backends ({}); change it to include a CC-gated default backend entry so the
batch path is enabled only when the compute-capability gate is satisfied. Update
the decorator's first argument to supply a default backend dict guarded by the
same CC check used elsewhere (i.e., add the CC-gated backend entry instead of
{}), keeping the existing common_check=_check_batch_deepgemm_fp8_nt_groupwise.

Copy link
Contributor Author

/bot run

Copy link
Collaborator

GitLab MR !97 has been created, and the CI pipeline #37488727 is currently running. I'll report back once the pipeline job completes.

Copy link
Collaborator

[FAILED] Pipeline #37488727: 1/17 passed



def m_grouped_fp8_gemm_nt_contiguous(
@supported_compute_capability([100, 103])
Copy link
Contributor

@nvmbreughe nvmbreughe Oct 29, 2025
edited

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, the compute capability will actually not be checked on the common_check function.
We would need to do one of the following:

  1. Redesign the decorator so that we can just add @supported_compute_capability to m_grouped_fp8_gemm_nt_contiguous directly; OR
  2. Change the decorator's implementations of wrapper and is_compute_capability_supported so that it also checks this on the common_check; OR
  3. Add a backend parameter to m_grouped_fp8_gemm_nt_contiguous. Though this is deepgemm, and I don't think we want to call this a backend for now.

My preference would go to option 2
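A minimal sketch of what option 2 could look like (names and semantics assumed for illustration; the real decorators live in flashinfer/utils.py): the wrapper reads the CC list that @supported_compute_capability attaches to the common_check itself and enforces it even when no backend kwarg is passed.

```python
def supported_compute_capability(ccs):
    # Attach the supported CC list as an attribute instead of wrapping.
    def deco(fn):
        fn._supported_ccs = set(ccs)
        return fn
    return deco


def _get_compute_capability():
    # Stand-in for querying the current CUDA device; hardcoded for the sketch.
    return 100


def backend_requirement(backend_checks, common_check=None):
    def deco(fn):
        def wrapper(*args, backend=None, **kwargs):
            if common_check is not None:
                # Option 2: honor the CC gate attached to the common_check.
                ccs = getattr(common_check, "_supported_ccs", None)
                if ccs is not None and _get_compute_capability() not in ccs:
                    raise RuntimeError(
                        f"compute capability {_get_compute_capability()} "
                        f"not in {sorted(ccs)}"
                    )
                common_check(*args, **kwargs)
            if backend_checks:
                check = backend_checks.get(backend)
                if check is None:
                    raise RuntimeError(f"unsupported backend: {backend!r}")
                check(*args, **kwargs)
            return fn(*args, **kwargs)
        return wrapper
    return deco


@supported_compute_capability([100, 103])
def _common_check(x):
    if x <= 0:
        raise ValueError("x must be positive")
    return True


@backend_requirement({}, common_check=_common_check)
def kernel(x):
    return x * 2


print(kernel(3))  # 6 on a CC-100 device in this sketch
```

With this shape, APIs that take no backend argument still get CC gating through their common_check, without needing a dummy None-keyed backend entry.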

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think option 2 also makes the most sense. Many of these APIs take no 'backend' arg, so we can't rely on checking @supported_compute_capability there alone. I can change this in a separate PR

Copy link
Contributor

@nvmbreughe nvmbreughe Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A separate PR is fine, since we wouldn't cause a regression (we didn't have CC checks before the current PR anyway).

impl(a, sfa, b, sfb, d, m_indices)


@supported_compute_capability([100, 103])
Copy link
Contributor

@nvmbreughe nvmbreughe Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, the compute capabilities will be ignored.

"cudnn": _cudnn_bmm_fp8_requirement,
"cublas": _cublas_bmm_fp8_requirement,
"cutlass": _cutlass_bmm_fp8_requirement,
"auto": _cublas_bmm_fp8_requirement, # cublas default
Copy link
Contributor

@nvmbreughe nvmbreughe Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto doesn't default to cublas. It actually sets backends = ["cutlass", "cublas", "cudnn"], which will usually be autotuned across backends.

That's not trivial to resolve. One way may be to just create an aggregate function "check_all_backends" that calls all three checker functions.

Note that the function that takes all these backends is fp8_gemm_sm100, which also does a bunch of checks. So perhaps we need to rely on those, since the checks here are very minimal anyway.

But this does remain a general issue: how do we isolate backend support checks when the auto backend is chosen? I think the behavior we want is to filter out unsuitable backends, rather than reject auto outright when one backend's check fails. Maybe this can be an additional "attribute" that the decorator adds for the special auto backend. To illustrate, we would have:
bmm_fp8.suitable_auto_backends() would return ["cutlass", "cublas", "cudnn"] if all checks pass, ["cutlass", "cublas"] if cudnn's check failed, and None if the selected backend was not "auto".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bmm_fp8.suitable_auto_backends() return ["cutlass", "cublas", "cudnn"] if all checks pass, or ["cutlass", "cublas"] if cudnn's check failed.

In this decorator, we could set an attribute with the final list of supported backends for the actual function to reference when passing into the autotuner. Is this what you had in mind?
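A hedged sketch of that idea (illustrative names only, not the actual FlashInfer decorator): the decorator exposes a suitable_auto_backends attribute that runs each per-backend check and returns only the backends that pass, which the autotuner could then consume instead of rejecting "auto" outright.

```python
def backend_requirement(backend_checks):
    def deco(fn):
        def wrapper(*args, backend="auto", **kwargs):
            if backend == "auto":
                ok = wrapper.suitable_auto_backends(*args, **kwargs)
                if not ok:
                    raise RuntimeError("no backend passed its requirement check")
                backend = ok[0]  # or hand the whole list to the autotuner
            else:
                backend_checks[backend](*args, **kwargs)
            return fn(*args, backend=backend, **kwargs)

        def suitable_auto_backends(*args, **kwargs):
            # Filter rather than reject: keep every backend whose check passes.
            passing = []
            for name, check in backend_checks.items():
                try:
                    check(*args, **kwargs)
                    passing.append(name)
                except Exception:
                    continue
            return passing

        wrapper.suitable_auto_backends = suitable_auto_backends
        return wrapper
    return deco


def _cutlass_check(n):
    if n % 16 != 0:
        raise ValueError("cutlass path needs n % 16 == 0")


def _cublas_check(n):
    pass  # always eligible in this sketch


@backend_requirement({"cutlass": _cutlass_check, "cublas": _cublas_check})
def bmm(n, backend="auto"):
    return backend  # stand-in for dispatching the real kernel


print(bmm.suitable_auto_backends(8))   # ['cublas']
print(bmm(16))                         # 'cutlass'
```

The attribute keeps per-backend checks reusable for both explicit backend selection and the auto path.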

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Reviewers

@nvmbreughe nvmbreughe nvmbreughe requested changes

@coderabbitai coderabbitai[bot] coderabbitai[bot] left review comments

@yzh119 yzh119 Awaiting requested review from yzh119 yzh119 is a code owner

@cyx-6 cyx-6 Awaiting requested review from cyx-6 cyx-6 is a code owner

@wenscarl wenscarl Awaiting requested review from wenscarl wenscarl is a code owner

@bkryu bkryu Awaiting requested review from bkryu bkryu is a code owner

+1 more reviewer

@gemini-code-assist gemini-code-assist[bot] gemini-code-assist[bot] left review comments

Reviewers whose approvals may not affect merge requirements

Requested changes must be addressed to merge this pull request.

Assignees

No one assigned

Labels

None yet

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.
