
Add small-batch quantized matvec kernel (qmv_wide) #3764

Merged
angeloskath merged 2 commits into ml-explore:main from jessegross:qmv-wide
Jun 26, 2026

@jessegross

Contributor

Speculative decoding (e.g. MTP) verifies a small batch of draft tokens at once, with M between 2 and 8. That batch is too large for qmv, which re-reads the whole quantized weight for each vector, and too small for qmm. qmv_wide dequantizes each weight group once and reuses it across the tile, so the cost of the weight read is amortized over the batch. The approach is adapted from llama.cpp's kernel_mul_mv_ext.
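To make the reuse idea concrete, here is a minimal host-side sketch in Python rather than Metal. It is not the kernel from this PR: the toy integer dequantization, the `GROUP_SIZE` of 4, and all function names are assumptions chosen for illustration. It computes one output row for M input vectors two ways and counts how many times a weight group gets dequantized.

```python
from typing import List, Tuple

GROUP_SIZE = 4  # hypothetical group size, kept tiny for illustration


def dequantize(codes: List[int], scale: float) -> List[float]:
    """Toy dequantization: integer codes times a per-group scale."""
    return [scale * q for q in codes]


def row_per_vector(groups, scales, xs) -> Tuple[List[float], int]:
    """qmv-style: each input vector walks and dequantizes every group again."""
    dequants = 0
    out = []
    for x in xs:
        acc = 0.0
        for g, (codes, s) in enumerate(zip(groups, scales)):
            w = dequantize(codes, s)
            dequants += 1
            base = g * GROUP_SIZE
            acc += sum(wi * x[base + i] for i, wi in enumerate(w))
        out.append(acc)
    return out, dequants


def row_wide(groups, scales, xs) -> Tuple[List[float], int]:
    """qmv_wide-style: dequantize each group once, reuse it for all M vectors."""
    dequants = 0
    out = [0.0] * len(xs)
    for g, (codes, s) in enumerate(zip(groups, scales)):
        w = dequantize(codes, s)
        dequants += 1
        base = g * GROUP_SIZE
        for v, x in enumerate(xs):
            out[v] += sum(wi * x[base + i] for i, wi in enumerate(w))
    return out, dequants
```

With G groups and M vectors, the first variant dequantizes G * M times and the second G times, while both produce the same outputs.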

qmv_wide is selected for M in [2, vector_limit) and covers every quantization mode (affine, nvfp4, mxfp4, mxfp8), dtype, and batched weights. The fp modes use it on all GPU generations; affine is gated to gen-15+, where it overtakes qmv. Entries marked qmv in the tables below are affine on gen-14, which stays on the old path.
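The selection rule can be sketched as a small function. This is a reading of the paragraph above, not the actual dispatch code: the function name, the integer `gpu_gen` argument, and the fallbacks to qmv for M = 1 and to qmm for M at or above `vector_limit` are assumptions, and `vector_limit` is left as a parameter because its value is not stated here.

```python
FP_MODES = {"nvfp4", "mxfp4", "mxfp8"}


def select_kernel(M: int, mode: str, gpu_gen: int, vector_limit: int) -> str:
    """Pick a kernel name for a quantized matmul with M input vectors."""
    if M >= vector_limit:
        return "qmm"       # assumption: larger batches go to the matmul kernel
    if M < 2:
        return "qmv"       # assumption: a single vector keeps the per-vector path
    if mode in FP_MODES:
        return "qmv_wide"  # fp modes use qmv_wide on all GPU generations
    if gpu_gen >= 15:
        return "qmv_wide"  # affine is gated to gen-15+
    return "qmv"           # affine on gen-14 stays on the old path
```

For example, `select_kernel(4, "affine", 14, vector_limit=16)` returns `"qmv"`, which corresponds to the qmv entries in the M2 Pro rows; the `16` is a made-up limit for the example.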

Benchmarks

Speedup over the per-vector qmv path (GPU kernel time) on the Gemma-4-12B MLP gate/up matmul [15360x3840] with bf16 activations, by M and quant mode. The benefit grows with M as the weight read amortizes:

  M=2            int4   int8  nvfp4  mxfp4  mxfp8
  M2 Pro          qmv    qmv   1.3x   1.3x   1.3x
  M3 Ultra       1.1x   1.2x   1.4x   1.2x   1.4x
  M4 Pro         1.2x   1.0x   1.4x   1.2x   1.4x
  M5 Max         1.3x   1.2x   1.6x   1.4x   1.3x

  M=4            int4   int8  nvfp4  mxfp4  mxfp8
  M2 Pro          qmv    qmv   1.7x   1.4x   1.7x
  M3 Ultra       1.4x   1.6x   1.6x   1.1x   1.8x
  M4 Pro         1.4x   1.6x   1.7x   1.2x   2.0x
  M5 Max         1.6x   1.7x   2.0x   1.3x   1.6x

  M=8            int4   int8  nvfp4  mxfp4  mxfp8
  M2 Pro          qmv    qmv   1.7x   1.4x   1.7x
  M3 Ultra       1.4x   1.7x   1.7x   1.2x   2.0x
  M4 Pro         1.4x   1.8x   1.7x   1.2x   2.2x
  M5 Max         1.6x   1.8x   2.0x   1.3x   1.8x
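For a rough sense of why the weight read dominates at these sizes, a back-of-envelope traffic count for the benchmarked shape can help. This is a simplified model, not a measurement: it assumes a 4-bit mode, ignores scales and any cache reuse, and assumes the input vectors have length 3840. The names are hypothetical.

```python
ROWS, COLS = 15360, 3840  # weight shape from the benchmark
BITS_PER_WEIGHT = 4       # assumption: a 4-bit mode, per-group scales ignored
ACT_BYTES = 2             # bf16 activations

WEIGHT_BYTES = ROWS * COLS * BITS_PER_WEIGHT // 8


def bytes_read(M: int, reuse_weights: bool) -> int:
    """Bytes read for M input vectors under the simplified model."""
    acts = M * COLS * ACT_BYTES  # assumption: each input vector has COLS elements
    weights = WEIGHT_BYTES if reuse_weights else M * WEIGHT_BYTES
    return weights + acts


def traffic_ratio(M: int) -> float:
    """Per-vector traffic divided by reuse traffic; always below M."""
    return bytes_read(M, reuse_weights=False) / bytes_read(M, reuse_weights=True)
```

Under these assumptions the packed weights are 29,491,200 bytes, while eight bf16 input vectors total 61,440 bytes, so the ratio sits just under M. The measured speedups in the tables are well below that, so costs this model leaves out (dequantization arithmetic, caching, occupancy) presumably matter too.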

@angeloskath left a comment

Member

Awesome speedup! This is also exceptionally written! Very clear and intuitive, nice!

I left a comment about an improvement but I am super excited to merge this. Fantastic work.

const device uint16_t* wq = (const device uint16_t*)wg;
#pragma unroll
for (int i = 0; i < nf4; i++) {
w4[i] = s *

This multiplication should move down after the dot product. We are now unnecessarily doing nf4 * 4 = 32 muls instead of just 1.

I wrote it real quick and I get a pretty massive speedup on mxfp4 Qwen 9B with batch size 4, from 253 tps to 295 tps on the M5 Max.

for (int j = 0; j < nf4; j++) {
acc += dot(w4[j], float4(xv4[j]));
}
result[v] += acc;

Suggested change
- result[v] += acc;
+ result[v] += s * acc;

const device uint16_t* wq = (const device uint16_t*)wg;
#pragma unroll
for (int i = 0; i < nf4; i++) {
w4[i] = s *

Suggested change
- w4[i] = s *
+ w4[i] =

} else {
#pragma unroll
for (int i = 0; i < nf4; i++) {
w4[i] = s *

Suggested change
- w4[i] = s *
+ w4[i] =

fp_qmv_wide folded the per-group scale into every dequantized weight
before the dot product. The dot is linear in the weights, so scaling the
accumulated result instead is equivalent and replaces group_size scale
multiplies per group with one per streamed vector.
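The linearity argument can be checked with a small host-side sketch. It is an illustration in Python, not the Metal kernel: the function names are hypothetical, the group of 32 codes is chosen to match the 32 multiplies mentioned in the review, and only scale multiplies are counted.

```python
def dot_scale_each_weight(codes, xs, s):
    """Before: fold the scale into every dequantized weight, then dot."""
    scale_muls = 0
    acc = 0.0
    for q, x in zip(codes, xs):
        w = s * q          # one scale multiply per weight
        scale_muls += 1
        acc += w * x
    return acc, scale_muls


def dot_scale_result(codes, xs, s):
    """After: dot the unscaled weights, then scale the accumulated result once."""
    acc = 0.0
    for q, x in zip(codes, xs):
        acc += q * x
    return s * acc, 1      # a single scale multiply per group and vector
```

Both variants compute s * sum(q_i * x_i) mathematically. In floating point the results can differ in the last bits because the rounding order changes, so a comparison should use a tolerance rather than exact equality.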

Isolated speedup of this change over the previous qmv_wide (GPU kernel
time) on the Gemma-4-12B gate/up matmul [15360x3840] with bf16
activations, by M and fp mode. Affine modes use a separate kernel and
are unchanged; the win grows with M as the saved multiplies amortize:

  M=2           nvfp4  mxfp4  mxfp8
  M2 Pro        1.00x  1.11x  1.02x
  M3 Ultra      0.99x  1.13x  1.08x
  M4 Pro        1.02x  1.10x  1.03x
  M5 Max        1.16x  1.23x  1.05x

  M=4           nvfp4  mxfp4  mxfp8
  M2 Pro        1.09x  1.15x  1.12x
  M3 Ultra      1.12x  1.14x  1.25x
  M4 Pro        1.10x  1.16x  1.14x
  M5 Max        1.14x  1.19x  1.09x

  M=8           nvfp4  mxfp4  mxfp8
  M2 Pro        1.08x  1.15x  1.13x
  M3 Ultra      1.16x  1.12x  1.25x
  M4 Pro        1.14x  1.16x  1.18x
  M5 Max        1.10x  1.16x  1.08x

qmv_wide with this change over the per-vector qmv path, refreshing the
table from the previous commit (qmv = affine staying on the old path on
gen-14):

  M=2            int4   int8  nvfp4  mxfp4  mxfp8
  M2 Pro          qmv    qmv   1.5x   1.6x   1.5x
  M3 Ultra       1.2x   1.3x   1.4x   1.0x   1.8x
  M4 Pro         1.2x   1.0x   1.4x   1.0x   1.5x
  M5 Max         1.4x   1.2x   2.0x   1.9x   1.6x

  M=4            int4   int8  nvfp4  mxfp4  mxfp8
  M2 Pro          qmv    qmv   2.2x   1.7x   2.0x
  M3 Ultra       1.5x   1.8x   1.4x   1.2x   2.1x
  M4 Pro         1.5x   1.7x   1.3x   1.3x   2.2x
  M5 Max         1.7x   1.8x   2.5x   1.8x   2.2x

  M=8            int4   int8  nvfp4  mxfp4  mxfp8
  M2 Pro          qmv    qmv   2.3x   1.8x   2.1x
  M3 Ultra       1.4x   1.6x   1.3x   1.1x   2.0x
  M4 Pro         1.5x   1.7x   1.3x   1.1x   2.2x
  M5 Max         1.6x   1.6x   2.2x   1.7x   2.1x
@jessegross

Contributor Author

Great suggestion - that brings another significant improvement. Benchmarked relative to the original PR, the change alone gives the isolated speedups in the first set of tables above, and the second set of tables above shows the resulting total improvement of the PR.

@angeloskath left a comment

Lovely, merging after tests pass.

@angeloskath merged commit 548dd80 into ml-explore:main on Jun 26, 2026
16 checks passed
@jessegross deleted the qmv-wide branch on June 26, 2026 23:09