refactor: backend_requirement + supported_compute_capability decorator for gemm #2000
Note: Other AI code review bot(s) detected
CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough
Added explicit input validation and compute-capability guards for multiple FP8 GEMM/grouped/NT-groupwise paths. Converted in-function asserts to explicit checks and introduced backend_requirement decorators plus standalone problem-size validators to gate backend dispatch and return clear errors on invalid inputs.
Sequence Diagram(s)
 sequenceDiagram
 participant Caller as Caller
 participant API as Public API<br/>(e.g., gemm_fp8 / m_grouped_fp8)
 participant Decorator as @backend_requirement
 participant Validator as _check_* validators
 participant Dispatcher as Backend Dispatcher
 participant Backend as Backend (cutlass/cudnn/cublas/trtllm)
 participant Error as Error
 Caller->>API: call FP8 GEMM entry
 API->>Decorator: decorated entry invoked
 Decorator->>Validator: run problem-size & capability checks
 Validator-->>Decorator: true / false
 alt checks pass
 Decorator->>Dispatcher: select backend
 Dispatcher->>Backend: invoke backend-specific implementation
 Backend-->>Dispatcher: result tensor
 Dispatcher-->>API: return result
 API-->>Caller: tensor
 else checks fail
 Decorator->>Error: raise ValueError with reason
 Error-->>Caller: exception
 end
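The flow in the diagram can be approximated with a toy decorator. This is a simplified sketch, not flashinfer's actual backend_requirement implementation: the name toy_backend_requirement, its signature, the shape-tuple arguments, and the example checks are all hypothetical and chosen only to mirror the diagram's steps.

```python
import functools


def toy_backend_requirement(backend_checks, common_check=None):
    """Hypothetical stand-in: run validation before the wrapped function."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, backend=None, **kwargs):
            # Step 1: the requested backend must have a registered requirement.
            if backend not in backend_checks:
                raise ValueError(f"unsupported backend: {backend!r}")
            # Step 2: backend-independent problem-size check.
            if common_check is not None and not common_check(*args, **kwargs):
                raise ValueError("problem-size check failed")
            # Step 3: backend-specific requirement.
            if not backend_checks[backend](*args, **kwargs):
                raise ValueError(f"backend {backend!r} rejected the inputs")
            # Step 4: dispatch to the implementation.
            return func(*args, backend=backend, **kwargs)

        return wrapper

    return decorator


def _check_shapes(a_shape, b_shape):
    # Inner dimensions must agree for an (m, k) x (k, n) product.
    return a_shape[1] == b_shape[0]


def _cutlass_ok(a_shape, b_shape):
    # Placeholder requirement that accepts every input.
    return True


@toy_backend_requirement({"cutlass": _cutlass_ok}, common_check=_check_shapes)
def toy_gemm(a_shape, b_shape, backend=None):
    # Return only the output shape; no real computation happens here.
    return (a_shape[0], b_shape[1])
```

Calling toy_gemm((2, 3), (3, 4), backend="cutlass") passes both checks and reaches the implementation, while a mismatched inner dimension or an unregistered backend raises ValueError before dispatch.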
 Estimated code review effort: 4 (Complex) | ~45 minutes
 Pre-merge checks
 Failed checks (2 warnings)
 Passed checks (1 passed)
Summary of Changes
Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly refactors the General Matrix Multiply (GEMM) operations by implementing a new decorator-based system. The primary goal is to streamline and consolidate the validation of inputs and hardware compatibility requirements for various GEMM functions. By abstracting these checks into decorators and dedicated problem-size validation functions, the core logic of the GEMM implementations becomes cleaner and more focused, while ensuring consistent and informative error handling across the codebase.
Code Review
This pull request is a great refactoring that introduces @backend_requirement and @supported_compute_capability decorators to centralize validation logic for GEMM functions. This improves code clarity and maintainability by separating validation from the core logic. The conversion of assert statements to ValueError exceptions is also a good improvement for more robust error handling.
I've found a critical bug that would lead to a runtime error, and a couple of missing validation checks that were present in the original code. My review includes suggestions to fix these issues. Overall, the direction of the changes is excellent.
 
 
 flashinfer/gemm.py
 
 Outdated
 
 
The variable n is used on line 2624 before it is assigned a value on line 2631. This will raise an UnboundLocalError if out is not None. The shape variables n and k should be defined before they are used in the shape checks.
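The failure mode can be reproduced in isolation. The following is a minimal sketch, not the code from flashinfer/gemm.py: the function names and the use of plain shape tuples are hypothetical, and only the ordering of the assignment to n relative to its use is taken from the comment above.

```python
def check_out_shape(a_shape, b_shape, out_shape=None):
    if out_shape is not None:
        # Bug: n is read here but only assigned further down, so Python
        # treats it as an unbound local and raises UnboundLocalError.
        if out_shape != (a_shape[0], n):
            raise ValueError("Shape mismatch")
    n = b_shape[1]
    return n


def check_out_shape_fixed(a_shape, b_shape, out_shape=None):
    # Derive n before any check that reads it.
    n = b_shape[1]
    if out_shape is not None and out_shape != (a_shape[0], n):
        raise ValueError(
            f"Shape mismatch. out.shape={out_shape}, expected={(a_shape[0], n)}"
        )
    return n
```

The buggy version only fails when out_shape is supplied, which is why a test suite that never passes a preallocated output would not catch it.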
The check for positive dimensions n, k, and num_groups is missing. The original code had assert n > 0 and k > 0 and num_groups > 0. This check is important to prevent unexpected behavior with empty or invalid dimensions.
A check to ensure that the number of groups in b matches num_groups derived from m_indptr is missing. This was present in the original code (assert b.shape[0] == num_groups) and is important for correctness.
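The two missing checks from the comments above (positive dimensions and matching group count) could be restored as explicit ValueError checks along these lines. This is a sketch under assumptions: check_group_dims is a hypothetical helper that takes plain shape tuples instead of tensors, and it assumes b has shape (num_groups, n, k) and m_indptr has num_groups + 1 entries.

```python
def check_group_dims(b_shape, m_indptr_len):
    """Hypothetical helper mirroring the two original asserts."""
    # num_groups is derived from m_indptr, as in the original code.
    num_groups = m_indptr_len - 1
    _, n, k = b_shape

    # Replaces: assert n > 0 and k > 0 and num_groups > 0
    if n <= 0 or k <= 0 or num_groups <= 0:
        raise ValueError(
            f"n, k, and num_groups must be positive, "
            f"got n={n}, k={k}, num_groups={num_groups}"
        )

    # Replaces: assert b.shape[0] == num_groups
    if b_shape[0] != num_groups:
        raise ValueError(
            f"b.shape[0] ({b_shape[0]}) must equal num_groups ({num_groups})"
        )
    return True
```

Raising ValueError rather than asserting keeps the checks active when Python runs with assertions disabled (the -O flag).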
Actionable comments posted: 3
Caution
Some comments are outside the diff and can't be posted inline due to platform limitations.
Outside diff range comments (2)
flashinfer/deep_gemm.py (1)
1365-1410: Fix incorrect shape checks and over-strict dtype validation in contiguous checker.
- num_groups != m__ is wrong: m_indices is per-row (length m), not per-group. This will falsely reject valid inputs.
- Allow d.dtype to be either bf16 or float32, which the runtime supports.
Apply:
@@
-    if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
-        raise ValueError(f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}")
+    if m != m_ or k != k_ or n != n_ or m__ != m_:
+        raise ValueError(
+            f"Shape mismatch. m={m}, m_={m_}, k={k}, k_={k_}, n={n}, n_={n_}, |m_indices|={m__}"
+        )
@@
-    if d.dtype != torch.bfloat16:
-        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if d.dtype not in (torch.bfloat16, torch.float):
+        raise ValueError(f"d must be bfloat16 or float32, got {d.dtype}")
Optionally, quiet Ruff warnings:
-    a, sfa = a_fp8
-    b, sfb = b_fp8
+    a, _ = a_fp8
+    b, _ = b_fp8
flashinfer/gemm.py (1)
2589-2665: Fix F821: n used before assignment in groupwise problem-size check.
Move n/k derivation before using n in output shape checks; keep dtype validation.
Apply:
@@
-    if out is None:
-        if out_dtype is None:
-            out_dtype = torch.bfloat16
-    else:
-        if out_dtype is None:
-            out_dtype = out.dtype
-        if out.shape != (a.shape[0], n):
-            raise ValueError(f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}")
-        if out.dtype != out_dtype:
-            raise ValueError(f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}")
-
-    _validate_fp8_output_dtype(out_dtype)
-
-    n = b.shape[1]
-    k = b.shape[2]
+    n = b.shape[1]
+    k = b.shape[2]
+
+    if out is None:
+        if out_dtype is None:
+            out_dtype = torch.bfloat16
+    else:
+        if out_dtype is None:
+            out_dtype = out.dtype
+        if out.shape != (a.shape[0], n):
+            raise ValueError(
+                f"Shape mismatch. out.shape={out.shape}, expected={(a.shape[0], n)}"
+            )
+        if out.dtype != out_dtype:
+            raise ValueError(f"dtype mismatch. out.dtype={out.dtype}, out_dtype={out_dtype}")
+
+    _validate_fp8_output_dtype(out_dtype)
Nitpick comments (4)
flashinfer/deep_gemm.py (2)
1428-1432: Nit: unused variable k_.
Use _ to silence RUF059.
-    num_groups, n, k_ = b.shape
+    num_groups, n, _ = b.shape
1461-1514: Masked checker: align dtype, contiguity, and unused unpack.
- Allow d.dtype to be bf16 or f32 to match runtime support.
- Use _ for unused unpacked scale tensors.
Apply:
-    a, sfa = a_fp8
-    b, sfb = b_fp8
+    a, _ = a_fp8
+    b, _ = b_fp8
@@
-    if d.dtype != torch.bfloat16:
-        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if d.dtype not in (torch.bfloat16, torch.float):
+        raise ValueError(f"d must be bfloat16 or float32, got {d.dtype}")
flashinfer/gemm.py (2)
2734-2734: Silence F841: unused local.
-num_groups = m_indptr.shape[0] - 1
+_num_groups = m_indptr.shape[0] - 1
2940-2940: Silence F841: unused local.
-num_groups = m_indptr.shape[0] - 1
+_num_groups = m_indptr.shape[0] - 1
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Files selected for processing (2)
flashinfer/deep_gemm.py (4 hunks)
flashinfer/gemm.py (9 hunks)
Additional context used
Code graph analysis (2)
flashinfer/gemm.py (2)
flashinfer/utils.py (4)
supported_compute_capability (772-852)
backend_requirement (855-1028)
is_sm120a_supported (504-506)
is_sm121a_supported (509-511)
flashinfer/deep_gemm.py (2)
_check_group_deepgemm_fp8_nt_contiguous_problem_size (1366-1409)
_check_m_grouped_fp8_gemm_nt_masked_problem_size (1462-1509)
flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
ceil_div (575-586)
round_up (589-591)
supported_compute_capability (772-852)
backend_requirement (855-1028)
GitHub Actions: pre-commit
flashinfer/gemm.py
[error] 2624-2624: ruff (F821): Undefined name 'n'. Cannot determine type of 'n'.
[error] 2625-2625: ruff (F821): Undefined name 'n'. Cannot determine type of 'n'.
[error] 2734-2734: ruff (F841): Local variable 'num_groups' is assigned to but never used.
[error] 2940-2940: ruff (F841): Local variable 'num_groups' is assigned to but never used.
[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook.
flashinfer/deep_gemm.py
[error] 1411-1411: mypy: Missing positional argument 'backend_checks' in call to 'backend_requirement'.
[error] 1511-1511: mypy: Missing positional argument 'backend_checks' in call to 'backend_requirement'.
[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook.
Ruff (0.14.2)
flashinfer/gemm.py
2012-2012: Unused function argument: A
(ARG001)
2013-2013: Unused function argument: B
(ARG001)
2014-2014: Unused function argument: A_scale
(ARG001)
2015-2015: Unused function argument: B_scale
(ARG001)
2016-2016: Unused function argument: dtype
(ARG001)
2017-2017: Unused function argument: out
(ARG001)
2018-2018: Unused function argument: backend
(ARG001)
2026-2026: Unused function argument: A
(ARG001)
2027-2027: Unused function argument: B
(ARG001)
2028-2028: Unused function argument: A_scale
(ARG001)
2029-2029: Unused function argument: B_scale
(ARG001)
2030-2030: Unused function argument: dtype
(ARG001)
2031-2031: Unused function argument: out
(ARG001)
2032-2032: Unused function argument: backend
(ARG001)
2041-2041: Unused function argument: A_scale
(ARG001)
2042-2042: Unused function argument: B_scale
(ARG001)
2043-2043: Unused function argument: dtype
(ARG001)
2044-2044: Unused function argument: out
(ARG001)
2045-2045: Unused function argument: backend
(ARG001)
2048-2048: Avoid specifying long messages outside the exception class
(TRY003)
2052-2052: Unused function argument: A
(ARG001)
2053-2053: Unused function argument: B
(ARG001)
2054-2054: Unused function argument: A_scale
(ARG001)
2055-2055: Unused function argument: B_scale
(ARG001)
2057-2057: Unused function argument: out
(ARG001)
2058-2058: Unused function argument: backend
(ARG001)
2165-2165: Unused function argument: a
(ARG001)
2166-2166: Unused function argument: b
(ARG001)
2167-2167: Unused function argument: a_scale
(ARG001)
2168-2168: Unused function argument: b_scale
(ARG001)
2170-2170: Unused function argument: mma_sm
(ARG001)
2171-2171: Unused function argument: scale_granularity_mnk
(ARG001)
2172-2172: Unused function argument: out
(ARG001)
2173-2173: Unused function argument: out_dtype
(ARG001)
2174-2174: Unused function argument: backend
(ARG001)
2177-2177: Avoid specifying long messages outside the exception class
(TRY003)
2185-2185: Unused function argument: b
(ARG001)
2186-2186: Unused function argument: a_scale
(ARG001)
2187-2187: Unused function argument: b_scale
(ARG001)
2188-2188: Unused function argument: scale_major_mode
(ARG001)
2189-2189: Unused function argument: mma_sm
(ARG001)
2191-2191: Unused function argument: out
(ARG001)
2192-2192: Unused function argument: out_dtype
(ARG001)
2193-2193: Unused function argument: backend
(ARG001)
2196-2196: Avoid specifying long messages outside the exception class
(TRY003)
2198-2198: Avoid specifying long messages outside the exception class
(TRY003)
2206-2206: Unused function argument: a_scale
(ARG001)
2207-2207: Unused function argument: b_scale
(ARG001)
2208-2208: Unused function argument: scale_major_mode
(ARG001)
2209-2209: Unused function argument: mma_sm
(ARG001)
2210-2210: Unused function argument: scale_granularity_mnk
(ARG001)
2211-2211: Unused function argument: out
(ARG001)
2213-2213: Unused function argument: backend
(ARG001)
2216-2216: Avoid specifying long messages outside the exception class
(TRY003)
2219-2221: Avoid specifying long messages outside the exception class
(TRY003)
2595-2595: Unused function argument: scale_granularity_mnk
(ARG001)
2602-2602: Avoid specifying long messages outside the exception class
(TRY003)
2604-2604: Avoid specifying long messages outside the exception class
(TRY003)
2606-2606: Avoid specifying long messages outside the exception class
(TRY003)
2608-2608: Avoid specifying long messages outside the exception class
(TRY003)
2610-2610: Avoid specifying long messages outside the exception class
(TRY003)
2612-2612: Avoid specifying long messages outside the exception class
(TRY003)
2614-2614: Avoid specifying long messages outside the exception class
(TRY003)
2624-2624: Undefined name n
(F821)
2625-2625: Avoid specifying long messages outside the exception class
(TRY003)
2625-2625: Undefined name n
(F821)
2627-2627: Avoid specifying long messages outside the exception class
(TRY003)
2635-2635: Avoid specifying long messages outside the exception class
(TRY003)
2637-2637: Avoid specifying long messages outside the exception class
(TRY003)
2639-2639: Avoid specifying long messages outside the exception class
(TRY003)
2645-2647: Avoid specifying long messages outside the exception class
(TRY003)
2793-2793: Avoid specifying long messages outside the exception class
(TRY003)
2795-2795: Avoid specifying long messages outside the exception class
(TRY003)
2797-2797: Avoid specifying long messages outside the exception class
(TRY003)
2799-2799: Avoid specifying long messages outside the exception class
(TRY003)
2801-2801: Avoid specifying long messages outside the exception class
(TRY003)
2803-2803: Avoid specifying long messages outside the exception class
(TRY003)
2805-2805: Avoid specifying long messages outside the exception class
(TRY003)
2807-2807: Avoid specifying long messages outside the exception class
(TRY003)
2809-2809: Avoid specifying long messages outside the exception class
(TRY003)
2811-2811: Avoid specifying long messages outside the exception class
(TRY003)
2822-2822: Avoid specifying long messages outside the exception class
(TRY003)
2826-2826: Avoid specifying long messages outside the exception class
(TRY003)
2833-2833: Avoid specifying long messages outside the exception class
(TRY003)
2838-2838: Avoid specifying long messages outside the exception class
(TRY003)
2840-2840: Avoid specifying long messages outside the exception class
(TRY003)
2845-2845: Avoid specifying long messages outside the exception class
(TRY003)
2847-2847: Avoid specifying long messages outside the exception class
(TRY003)
3010-3010: Unused function argument: out_dtype
(ARG001)
3154-3154: Unused function argument: out_dtype
(ARG001)
flashinfer/deep_gemm.py
1371-1371: Unused function argument: recipe
(ARG001)
1372-1372: Unused function argument: compiled_dims
(ARG001)
1379-1379: Avoid specifying long messages outside the exception class
(TRY003)
1381-1381: Avoid specifying long messages outside the exception class
(TRY003)
1384-1384: Avoid specifying long messages outside the exception class
(TRY003)
1395-1395: Avoid specifying long messages outside the exception class
(TRY003)
1397-1397: Avoid specifying long messages outside the exception class
(TRY003)
1399-1399: Avoid specifying long messages outside the exception class
(TRY003)
1401-1401: Avoid specifying long messages outside the exception class
(TRY003)
1403-1403: Avoid specifying long messages outside the exception class
(TRY003)
1407-1407: Avoid specifying long messages outside the exception class
(TRY003)
1431-1431: Unpacked variable k_ is never used
(RUF059)
1468-1468: Unused function argument: recipe
(ARG001)
1469-1469: Unused function argument: compiled_dims
(ARG001)
1475-1475: Avoid specifying long messages outside the exception class
(TRY003)
1477-1477: Avoid specifying long messages outside the exception class
(TRY003)
1480-1480: Avoid specifying long messages outside the exception class
(TRY003)
1482-1482: Unpacked variable sfa is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1483-1483: Unpacked variable sfb is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1491-1491: Avoid specifying long messages outside the exception class
(TRY003)
1493-1493: Avoid specifying long messages outside the exception class
(TRY003)
1495-1495: Avoid specifying long messages outside the exception class
(TRY003)
1497-1497: Avoid specifying long messages outside the exception class
(TRY003)
1499-1499: Avoid specifying long messages outside the exception class
(TRY003)
1501-1501: Avoid specifying long messages outside the exception class
(TRY003)
1503-1503: Avoid specifying long messages outside the exception class
(TRY003)
1507-1507: Avoid specifying long messages outside the exception class
(TRY003)
Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
 
Additional comments (2)
flashinfer/gemm.py (2)
3012-3014: Forwarding check OK, but inherits dtype/shape constraints from deep_gemm.
Once deep_gemm checker is fixed to allow bf16/float32 and correct m_indices semantics, this indirection remains valid. No changes needed here.
3161-3164: Forwarding check OK (batch path).
Same note as above; relies on deep_gemm masked checker. No change required.
Decorator misuse: missing backend_checks causes mypy error and runtime failure.
@backend_requirement requires a backend map. Also, calls pass no backend kwarg, so use a None key as the default backend and gate CC via a trivial requirement.
Apply:
-@backend_requirement(
-    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement},
+    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)
And add this requirement helper near the checker:
+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement(
+    a_fp8, b_fp8, d, m_indices, recipe=None, compiled_dims="nk"
+):
+    # CC gating only; shape/dtype validated by common_check.
+    return True
Committable suggestion skipped: line range outside the PR's diff.
Tools
GitHub Actions: pre-commit
[error] 1411-1411: mypy: Missing positional argument 'backend_checks' in call to 'backend_requirement'.
Prompt for AI Agents
In flashinfer/deep_gemm.py around lines 1411-1414, the @backend_requirement
decorator is missing a backend_checks mapping which causes mypy/runtime failures
because callers pass no backend kwarg; add a trivial requirement helper (e.g., a
function that returns True or otherwise acts as a no-op requirement) near the
existing _check_group_deepgemm_fp8_nt_contiguous_problem_size checker, then
update the decorator call to include backend_checks with a None key mapped to
that trivial requirement (backend_checks={None: trivial_requirement}) so the
decorator has a default backend entry and CC gating remains intact.
92096da to ad39f67
Actionable comments posted: 3
Caution
Some comments are outside the diff and can't be posted inline due to platform limitations.
Outside diff range comments (1)
flashinfer/deep_gemm.py (1)
1366-1414: Add missing positive-dimension checks and silence unused params.
- Validate n, k, num_groups > 0 to prevent empty/invalid shapes (was present before per prior review).
- Prefix unused parameters to satisfy linters.
Apply:
-@supported_compute_capability([100, 103])
-def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
+@supported_compute_capability([100, 103])
+def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
     d: torch.Tensor,
     m_indices: torch.Tensor,
-    recipe: Optional[Tuple[int, int, int]] = None,
-    compiled_dims: str = "nk",
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
 ) -> bool:
@@
     m, k = a.shape
     num_groups, n, k_ = b.shape
     m_, n_ = d.shape
     m__ = m_indices.numel()
     # Type and shape checks
@@
     if m_indices.dtype != torch.int32:
         raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")
+    # Positive-dimension checks
+    if n <= 0 or k <= 0 or num_groups <= 0:
+        raise ValueError(
+            f"n, k, and num_groups must be positive, got n={n}, k={k}, num_groups={num_groups}"
+        )
     # D must be N-major
Duplicate comments (5)
flashinfer/gemm.py (3)
2557-2560: Critical: Empty backend map causes gemm_fp8_nt_blockscaled to always error.
This mirrors a previously reported issue. Add a default requirement with CC gating.
+@supported_compute_capability([100, 103, 120, 121])
+def _cutlass_gemm_fp8_nt_blockscaled_requirement(
+    a, b, a_scale, b_scale, scale_major_mode="MN", mma_sm=1, out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_gemm_fp8_nt_blockscaled_problem_size,
-)
+@backend_requirement(
+    {None: _cutlass_gemm_fp8_nt_blockscaled_requirement},
+    common_check=_check_gemm_fp8_nt_blockscaled_problem_size,
+)
2590-2655: Add missing groups check and validate out dtype early.
Ensure b.shape[0] matches num_groups derived from m_indptr; set/validate out_dtype like other paths.
@@
-    num_groups = m_indptr.shape[0] - 1
+    num_groups = m_indptr.shape[0] - 1
+    if b.shape[0] != num_groups:
+        raise ValueError(
+            f"b.shape[0] ({b.shape[0]}) must equal num_groups ({num_groups})"
+        )
2865-2868: Critical: group_gemm_mxfp8_mxfp4_nt_groupwise also uses {}.
Provide a default requirement with CC gating.
+@supported_compute_capability([100, 103, 110, 120, 121])
+def _cutlass_group_gemm_mxfp8_mxfp4_requirement(
+    a, b, a_scale, b_scale, m_indptr,
+    mma_sm=1, tile_m=128, tile_n=128, tile_k=128, swap_ab=True,
+    out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size,
-)
+@backend_requirement(
+    {None: _cutlass_group_gemm_mxfp8_mxfp4_requirement},
+    common_check=_check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size,
+)
flashinfer/deep_gemm.py (2)
1416-1420: Critical: backend_requirement with empty backend map always fails at runtime.
Passing {} and not supplying a backend kwarg makes the wrapper raise BackendSupportedError. Provide a default backend entry keyed by None and a trivial requirement function with proper CC gating.
Apply:
+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement(
+    a_fp8, b_fp8, d, m_indices, recipe=None, compiled_dims="nk"
+):
+    # CC gating only; shape/dtype validated by common_check.
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement},
+    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)
1529-1533: Critical: same empty backend map issue for masked path.
Mirror the contiguous fix by adding a default backend requirement with CC gating.
-@backend_requirement(
-    {},
-    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement},
+    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
+)
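Why an empty backend map fails while a None key works can be shown with a toy wrapper. This sketch assumes, based only on the review comments above, that the wrapper looks up the backend keyword (defaulting to None) in the map. The name toy_requirement and the LookupError are hypothetical; the real decorator is described as raising BackendSupportedError.

```python
def toy_requirement(backend_checks):
    """Hypothetical wrapper: look up the backend kwarg in the check map."""

    def decorator(func):
        def wrapper(*args, backend=None, **kwargs):
            # With an empty map, no key (not even None) is ever present.
            if backend not in backend_checks:
                raise LookupError(
                    f"no requirement registered for backend {backend!r}"
                )
            if not backend_checks[backend](*args, **kwargs):
                raise ValueError("requirement not met")
            return func(*args, **kwargs)

        return wrapper

    return decorator


@toy_requirement({})
def broken(x):
    # Unreachable: the empty map rejects every call.
    return x * 2


@toy_requirement({None: lambda x: True})
def fixed(x):
    # Reachable when the caller passes no backend kwarg.
    return x * 2
```

Under this assumption, broken(3) fails on every call regardless of input, whereas fixed(3) reaches the function body because the default backend None has a registered requirement.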
Nitpick comments (3)
flashinfer/deep_gemm.py (2)
1431-1438: Minor: avoid unused k_.
Use _ to silence the unused binding.
-    num_groups, n, k_ = b.shape
+    num_groups, n, _ = b.shape
1467-1527: LGTM with tiny lint nits.
Logic is correct and includes positivity checks. Prefix unused sfa/sfb to satisfy linters.
-    a, sfa = a_fp8
-    b, sfb = b_fp8
+    a, _sfa = a_fp8
+    b, _sfb = b_fp8
flashinfer/gemm.py (1)
2165-2235: NT-groupwise single-GEMM backend checks: solid, consider reusing for blockscaled.
The requirement/check split is good. If desired, mirror the CUTLASS CC list for the blockscaled helper to keep consistency.
Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
Files selected for processing (2)
flashinfer/deep_gemm.py (4 hunks)
flashinfer/gemm.py (9 hunks)
Additional context used
Code graph analysis (2)
flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
ceil_div (575-586)
round_up (589-591)
supported_compute_capability (772-852)
backend_requirement (855-1028)
flashinfer/gemm.py (2)
flashinfer/utils.py (4)
supported_compute_capability (772-852)
backend_requirement (855-1028)
is_sm120a_supported (504-506)
is_sm121a_supported (509-511)
flashinfer/deep_gemm.py (2)
_check_group_deepgemm_fp8_nt_contiguous_problem_size (1367-1413)
_check_m_grouped_fp8_gemm_nt_masked_problem_size (1468-1526)
Ruff (0.14.2)
flashinfer/deep_gemm.py
1372-1372: Unused function argument: recipe
(ARG001)
1373-1373: Unused function argument: compiled_dims
(ARG001)
1379-1379: Avoid specifying long messages outside the exception class
(TRY003)
1381-1381: Avoid specifying long messages outside the exception class
(TRY003)
1384-1386: Avoid specifying long messages outside the exception class
(TRY003)
1397-1399: Avoid specifying long messages outside the exception class
(TRY003)
1401-1401: Avoid specifying long messages outside the exception class
(TRY003)
1403-1403: Avoid specifying long messages outside the exception class
(TRY003)
1405-1405: Avoid specifying long messages outside the exception class
(TRY003)
1407-1407: Avoid specifying long messages outside the exception class
(TRY003)
1411-1411: Avoid specifying long messages outside the exception class
(TRY003)
1437-1437: Unpacked variable k_ is never used
(RUF059)
1474-1474: Unused function argument: recipe
(ARG001)
1475-1475: Unused function argument: compiled_dims
(ARG001)
1480-1480: Avoid specifying long messages outside the exception class
(TRY003)
1482-1482: Avoid specifying long messages outside the exception class
(TRY003)
1485-1487: Avoid specifying long messages outside the exception class
(TRY003)
1489-1489: Unpacked variable sfa is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1490-1490: Unpacked variable sfb is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1502-1504: Avoid specifying long messages outside the exception class
(TRY003)
1506-1508: Avoid specifying long messages outside the exception class
(TRY003)
1510-1512: Avoid specifying long messages outside the exception class
(TRY003)
1514-1514: Avoid specifying long messages outside the exception class
(TRY003)
1516-1516: Avoid specifying long messages outside the exception class
(TRY003)
1518-1518: Avoid specifying long messages outside the exception class
(TRY003)
1520-1520: Avoid specifying long messages outside the exception class
(TRY003)
1524-1524: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/gemm.py
2011-2011: Unused function argument: A
(ARG001)
2012-2012: Unused function argument: B
(ARG001)
2013-2013: Unused function argument: A_scale
(ARG001)
2014-2014: Unused function argument: B_scale
(ARG001)
2015-2015: Unused function argument: dtype
(ARG001)
2016-2016: Unused function argument: out
(ARG001)
2017-2017: Unused function argument: backend
(ARG001)
2025-2025: Unused function argument: A
(ARG001)
2026-2026: Unused function argument: B
(ARG001)
2027-2027: Unused function argument: A_scale
(ARG001)
2028-2028: Unused function argument: B_scale
(ARG001)
2029-2029: Unused function argument: dtype
(ARG001)
2030-2030: Unused function argument: out
(ARG001)
2031-2031: Unused function argument: backend
(ARG001)
2040-2040: Unused function argument: A_scale
(ARG001)
2041-2041: Unused function argument: B_scale
(ARG001)
2042-2042: Unused function argument: dtype
(ARG001)
2043-2043: Unused function argument: out
(ARG001)
2044-2044: Unused function argument: backend
(ARG001)
2047-2047: Avoid specifying long messages outside the exception class
(TRY003)
2052-2052: Unused function argument: A
(ARG001)
2053-2053: Unused function argument: B
(ARG001)
2054-2054: Unused function argument: A_scale
(ARG001)
2055-2055: Unused function argument: B_scale
(ARG001)
2057-2057: Unused function argument: out
(ARG001)
2058-2058: Unused function argument: backend
(ARG001)
2166-2166: Unused function argument: a
(ARG001)
2167-2167: Unused function argument: b
(ARG001)
2168-2168: Unused function argument: a_scale
(ARG001)
2169-2169: Unused function argument: b_scale
(ARG001)
2171-2171: Unused function argument: mma_sm
(ARG001)
2172-2172: Unused function argument: scale_granularity_mnk
(ARG001)
2173-2173: Unused function argument: out
(ARG001)
2174-2174: Unused function argument: out_dtype
(ARG001)
2175-2175: Unused function argument: backend
(ARG001)
2178-2178: Avoid specifying long messages outside the exception class
(TRY003)
2186-2186: Unused function argument: b
(ARG001)
2187-2187: Unused function argument: a_scale
(ARG001)
2188-2188: Unused function argument: b_scale
(ARG001)
2189-2189: Unused function argument: scale_major_mode
(ARG001)
2190-2190: Unused function argument: mma_sm
(ARG001)
2192-2192: Unused function argument: out
(ARG001)
2193-2193: Unused function argument: out_dtype
(ARG001)
2194-2194: Unused function argument: backend
(ARG001)
2197-2197: Avoid specifying long messages outside the exception class
(TRY003)
2199-2199: Avoid specifying long messages outside the exception class
(TRY003)
2207-2207: Unused function argument: a_scale
(ARG001)
2208-2208: Unused function argument: b_scale
(ARG001)
2209-2209: Unused function argument: scale_major_mode
(ARG001)
2210-2210: Unused function argument: mma_sm
(ARG001)
2211-2211: Unused function argument: scale_granularity_mnk
(ARG001)
2212-2212: Unused function argument: out
(ARG001)
2214-2214: Unused function argument: backend
(ARG001)
2217-2217: Avoid specifying long messages outside the exception class
(TRY003)
2220-2222: Avoid specifying long messages outside the exception class
(TRY003)
2596-2596: Unused function argument: scale_granularity_mnk
(ARG001)
2603-2603: Avoid specifying long messages outside the exception class
(TRY003)
2605-2605: Avoid specifying long messages outside the exception class
(TRY003)
2607-2607: Avoid specifying long messages outside the exception class
(TRY003)
2609-2609: Avoid specifying long messages outside the exception class
(TRY003)
2611-2611: Avoid specifying long messages outside the exception class
(TRY003)
2613-2615: Avoid specifying long messages outside the exception class
(TRY003)
2617-2617: Avoid specifying long messages outside the exception class
(TRY003)
2630-2632: Avoid specifying long messages outside the exception class
(TRY003)
2634-2636: Avoid specifying long messages outside the exception class
(TRY003)
2641-2641: Avoid specifying long messages outside the exception class
(TRY003)
2643-2643: Avoid specifying long messages outside the exception class
(TRY003)
2645-2645: Avoid specifying long messages outside the exception class
(TRY003)
2651-2653: Avoid specifying long messages outside the exception class
(TRY003)
2798-2800: Avoid specifying long messages outside the exception class
(TRY003)
2802-2802: Avoid specifying long messages outside the exception class
(TRY003)
2804-2804: Avoid specifying long messages outside the exception class
(TRY003)
2806-2806: Avoid specifying long messages outside the exception class
(TRY003)
2808-2808: Avoid specifying long messages outside the exception class
(TRY003)
2810-2810: Avoid specifying long messages outside the exception class
(TRY003)
2812-2812: Avoid specifying long messages outside the exception class
(TRY003)
2814-2814: Avoid specifying long messages outside the exception class
(TRY003)
2816-2816: Avoid specifying long messages outside the exception class
(TRY003)
2818-2818: Avoid specifying long messages outside the exception class
(TRY003)
2829-2831: Avoid specifying long messages outside the exception class
(TRY003)
2835-2837: Avoid specifying long messages outside the exception class
(TRY003)
2844-2846: Avoid specifying long messages outside the exception class
(TRY003)
2851-2851: Avoid specifying long messages outside the exception class
(TRY003)
2853-2853: Avoid specifying long messages outside the exception class
(TRY003)
2858-2858: Avoid specifying long messages outside the exception class
(TRY003)
2860-2860: Avoid specifying long messages outside the exception class
(TRY003)
3022-3022: Unused function argument: out_dtype
(ARG001)
3172-3172: Unused function argument: out_dtype
(ARG001)
Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
 
Additional comments (2)
flashinfer/deep_gemm.py (1)
48-53: Imports look good. The utils import now brings in supported_compute_capability and backend_requirement correctly.
flashinfer/gemm.py (1)
2009-2072: bmm_fp8 requirements mapping looks correct. Backends and CC gating are wired; common check validates output dtype.
Critical: group_gemm_fp8_nt_groupwise uses {} backend map.
Add a default backend requirement with CC gating; signature must match this API.
+@supported_compute_capability([100, 120, 121])
+def _cutlass_group_gemm_fp8_nt_groupwise_requirement(
+    a, b, a_scale, b_scale, m_indptr,
+    scale_granularity_mnk=(1, 128, 128),
+    scale_major_mode="MN",
+    mma_sm=1,
+    out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
-)
+@backend_requirement(
+    {None: _cutlass_group_gemm_fp8_nt_groupwise_requirement},
+    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
+)
Critical: DeepGEMM group path requires a default backend requirement.
Add CC-gated default backend for SM100/103.
+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement_group(
+    a, b, a_scale, b_scale, m_indices, scale_granularity_mnk=(1, 128, 128),
+    out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement_group},
+    common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size,
+)
Committable suggestion skipped: line range outside the PR's diff.
Prompt for AI Agents
In flashinfer/gemm.py around lines 3033 to 3036, the @backend_requirement
decorator on the DeepGEMM group path is missing a CC-gated default backend; add
a default backend entry that targets SM100 and SM103 (e.g. a "default" key with
a CUDA backend CC filter for 100 and 103) so the decorator includes both the
existing specific requirements and a default backend gated to compute capability
100 and 103.
Critical: DeepGEMM batch path also uses {}.
Add CC-gated default backend.
+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement_batch(
+    a, b, a_scale, b_scale, masked_m, expected_m,
+    scale_granularity_mnk=(1, 128, 128),
+    out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement_batch},
+    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
+)
Prompt for AI Agents
In flashinfer/gemm.py around lines 3181 to 3184, the @backend_requirement
decorator for the DeepGEMM batch path currently passes an empty dict for
backends ({}); change it to include a CC-gated default backend entry so the
batch path is enabled only when the compute-capability gate is satisfied. Update
the decorator's first argument to supply a default backend dict guarded by the
same CC check used elsewhere (i.e., add the CC-gated backend entry instead of
{}), keeping the existing common_check=_check_batch_deepgemm_fp8_nt_groupwise.
/bot run
[FAILED] Pipeline #37488727: 1/17 passed
Ah, the compute capability will not actually be checked on the common_check function.
We would need to do one of the following:
1. Redesign the decorator so that we can just add @supported_compute_capability to m_grouped_fp8_gemm_nt_contiguous directly; OR
2. Change the decorator's implementations of wrapper and is_compute_capability_supported so that they also check this on the common_check; OR
3. Add a backend parameter to m_grouped_fp8_gemm_nt_contiguous. Though this is DeepGEMM, and I don't think we want to call this a backend for now.

My preference would go to option 2.
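To make option 2 concrete, here is a minimal, self-contained sketch. Both decorators below are simplified stand-ins written for illustration, not FlashInfer's actual implementations; the `_supported_ccs` attribute and the injectable `get_cc` callable are assumptions (real code would query the device instead).

```python
import functools


def supported_compute_capability(ccs):
    """Hypothetical stand-in: tag a checker with the CCs it supports."""
    def decorator(fn):
        fn._supported_ccs = frozenset(ccs)
        return fn
    return decorator


def backend_requirement(backend_checks, common_check=None, get_cc=lambda: 100):
    """Hypothetical stand-in that also honors the CC tag on `common_check`."""
    def decorator(fn):
        def is_compute_capability_supported(cc, backend=None):
            # Gather every checker that applies: the common one plus the
            # backend-specific one (absent when the backend map is {}).
            checkers = [common_check, backend_checks.get(backend)]
            checkers = [c for c in checkers if c is not None]
            # A checker without a CC tag imposes no CC restriction.
            return all(cc in getattr(c, "_supported_ccs", {cc}) for c in checkers)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            cc = get_cc()
            backend = kwargs.get("backend")  # None for APIs without a backend arg
            if not is_compute_capability_supported(cc, backend):
                raise ValueError(f"{fn.__name__} does not support compute capability {cc}")
            if common_check is not None and not common_check(*args, **kwargs):
                raise ValueError(f"{fn.__name__}: problem size check failed")
            return fn(*args, **kwargs)

        wrapper.is_compute_capability_supported = is_compute_capability_supported
        return wrapper
    return decorator


@supported_compute_capability([100, 103])
def _check_problem_size(a, b):
    return len(a) == len(b)


@backend_requirement({}, common_check=_check_problem_size, get_cc=lambda: 100)
def grouped_gemm_sm100(a, b):
    return [x * y for x, y in zip(a, b)]


@backend_requirement({}, common_check=_check_problem_size, get_cc=lambda: 90)
def grouped_gemm_sm90(a, b):
    return [x * y for x, y in zip(a, b)]
```

With this shape, an API that has no `backend` argument and an empty backend map is still gated: `grouped_gemm_sm90` raises `ValueError` because the tag on the common check excludes compute capability 90.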
I think 2 also makes the most sense. In a lot of the APIs there is also no 'backend' arg to be passed in, so we can't only check @supported_compute_capability there. I can change this in a separate PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A separate PR is fine, since we wouldn't cause a regression (we didn't have CC checks before the current PR anyway).
same here, the compute capabilities will be ignored.
auto doesn't default to cublas. It actually sets backends = ["cutlass", "cublas", "cudnn"], which will usually be autotuned across backends.
That's not trivial to resolve. One way may be to just create an aggregate function "check_all_backends" that calls the 3 checker functions.
Note that the function that takes all these backends is fp8_gemm_sm100, which also does a bunch of checks. So perhaps we need to rely on those, since the checks here are very minimal anyway.
But this does remain a general issue: "How do we isolate backend support checks when the auto backend is chosen?" I think the behavior we want is "filter out unsuitable backends", rather than just rejecting auto when one of the backends fails. Maybe this can be an additional "attribute" that the decorator adds for the special auto backend. To illustrate, we would have:
bmm_fp8.suitable_auto_backends() would return ["cutlass", "cublas", "cudnn"] if all checks pass, or ["cutlass", "cublas"] if cudnn's check failed. bmm_fp8.suitable_auto_backends() would return None if the selected backend was not "auto".
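A rough sketch of the "filter out unsuitable backends" idea follows. Everything here is hypothetical: the simplified `backend_requirement`, the `_suitable` attribute, the checker functions, and `bmm_fp8_demo` are illustrative stand-ins, and the cuDNN restriction is invented purely to make one check fail.

```python
import functools


def backend_requirement(backend_checks, auto_backends=("cutlass", "cublas", "cudnn")):
    """Hypothetical stand-in: filters backends for "auto" instead of rejecting."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, backend="auto", **kwargs):
            if backend == "auto":
                # Keep only the backends whose own check passes.
                suitable = [b for b in auto_backends if backend_checks[b](*args, **kwargs)]
                if not suitable:
                    raise ValueError("no backend supports these inputs")
            else:
                if not backend_checks[backend](*args, **kwargs):
                    raise ValueError(f"backend {backend!r} does not support these inputs")
                suitable = None  # an explicit backend was selected
            wrapper._suitable = suitable
            return fn(*args, backend=backend, **kwargs)

        wrapper._suitable = None
        wrapper.suitable_auto_backends = lambda: wrapper._suitable
        return wrapper
    return decorator


def _cutlass_ok(a, b):
    return True


def _cublas_ok(a, b):
    return True


def _cudnn_ok(a, b):
    return len(a) % 4 == 0  # invented restriction, for illustration only


@backend_requirement({"cutlass": _cutlass_ok, "cublas": _cublas_ok, "cudnn": _cudnn_ok})
def bmm_fp8_demo(a, b, backend="auto"):
    # A real implementation would hand this list to the autotuner.
    return bmm_fp8_demo.suitable_auto_backends() or [backend]
```

Storing the result on the function object keeps the sketch short but is not thread-safe; passing the filtered list into the wrapped function as an argument would avoid shared state.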
bmm_fp8.suitable_auto_backends() return ["cutlass", "cublas", "cudnn"] if all checks pass, or ["cutlass", "cublas"] if cudnn's check failed.
In this decorator, we could set an attribute with the final list of supported backends for the actual function to reference when passing into the autotuner. Is this what you had in mind?
Description
Related Issues
Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran the hooks manually with pre-commit run --all-files and fixed any reported issues.
Tests
- Tests (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Improvements