Add multi-column attn-out low projection kernel for small batches #399

Open: rwl4 wants to merge 1 commit into antirez:main from rwl4:attn-out-ext-kernel
rwl4 commented Jun 11, 2026

Previously, the batched attention-output low projection dispatched one (group, token) matvec per grid slice, so the out_a weights were re-read once per token. The new kernel_dsv4_attn_out_low_ext_q8_0_r1_{2..5} variants reuse the small-batch ext mat-vec body, in which one weight read serves every token column, and add a token-major dst mode so the out_b matmul consumes the result directly. The new path is engaged for 2..8-token batches, which today means the MTP batch-verify suffix passes; setting DS4_METAL_DISABLE_ATTN_OUT_LOW_EXT restores the per-token path.
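To make the weight-reuse idea concrete, here is a minimal CPU sketch in plain Python. It is not the Metal kernel from this PR: it ignores Q8_0 quantization, the group dimension, and threadgroup layout, and the function names `per_token_low_proj` and `multi_column_low_proj` are hypothetical. It only illustrates the two access patterns described above, counting how often each weight row is read, and writes the multi-column result in a token-major flat buffer.

```python
def per_token_low_proj(weights, tokens):
    """Baseline pattern: one matvec per token, so every token re-reads all weight rows."""
    weight_row_reads = 0
    out = []
    for x in tokens:
        y = []
        for row in weights:
            weight_row_reads += 1  # this row is fetched again for each token
            y.append(sum(w * v for w, v in zip(row, x)))
        out.append(y)
    return out, weight_row_reads


def multi_column_low_proj(weights, tokens):
    """Multi-column pattern: each weight row is read once and applied to every token column."""
    n_rows = len(weights)
    n_tokens = len(tokens)
    # Token-major destination: the outputs of token t are contiguous at dst[t * n_rows + r].
    dst = [0] * (n_tokens * n_rows)
    weight_row_reads = 0
    for r, row in enumerate(weights):
        weight_row_reads += 1  # single fetch shared by all token columns
        for t, x in enumerate(tokens):
            dst[t * n_rows + r] = sum(w * v for w, v in zip(row, x))
    return dst, weight_row_reads


if __name__ == "__main__":
    W = [[1, 2, 3, 4], [0, 1, 0, 1], [2, 0, 2, 0]]  # 3 output rows, 4 inputs
    X = [[1, 0, 1, 0], [1, 1, 1, 1], [0, 2, 0, 2]]  # 3 token columns
    ref, reads_ref = per_token_low_proj(W, X)
    dst, reads_new = multi_column_low_proj(W, X)
    print(reads_ref, reads_new)
    print([v for y in ref for v in y] == dst)
```

Both functions accumulate each dot product in the same order, so the results match exactly; only the number of weight-row reads differs (rows times tokens versus rows).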

Measured on M5 Max: the low_proj stage drops from 0.472 to 0.314 ms/layer in an 8-token verify pass, and a --mtp --mtp-draft 4 greedy run goes from 27.6 to 28.4 t/s, with output byte-identical to the old path at batch widths 2/4/6/8. The --long-context gate passes. Any future consumer of small verify batches gets the same per-pass saving.
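For a sense of scale, the relative changes implied by the figures above can be worked out as follows. This is only arithmetic on the numbers quoted in this description; the helper name `percent_change` is hypothetical.

```python
def percent_change(before, after):
    """Relative change from before to after, in percent."""
    return (after - before) / before * 100.0


# Figures quoted in the PR description (M5 Max).
low_proj = percent_change(0.472, 0.314)    # ms/layer, 8-token verify pass
throughput = percent_change(27.6, 28.4)    # t/s, --mtp --mtp-draft 4 greedy run

print(f"low_proj stage: {low_proj:+.1f}% ms/layer")
print(f"throughput:     {throughput:+.1f}% t/s")
```

The stage time falls by roughly a third, while end-to-end throughput rises by about 3%, which is what one would expect if low_proj is only one of several stages in each layer.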
