Add small-batch quantized matvec kernel (qmv_wide)#3764
Speculative decoding (e.g. MTP) verifies a small batch of draft tokens at once, M of 2-8. That batch is too large for `qmv`, which re-reads the whole quantized weight per vector, and too small for `qmm`. `qmv_wide` dequantizes each weight group once and reuses it across the tile, so the weight read amortizes over the batch. The approach is adapted from llama.cpp's `kernel_mul_mv_ext`.

`qmv_wide` is selected for M in `[2, vector_limit)` and covers every quantization mode (affine, nvfp4, mxfp4, mxfp8), dtype, and batched weights. The fp modes use it on all GPU generations; affine is gated to gen-15+, where it overtakes `qmv` (the `qmv` entries below are affine on gen-14 staying on the old path).

Speedup over the per-vector `qmv` path (GPU kernel time) on the Gemma-4-12B MLP gate/up matmul `[15360x3840]` with bf16 activations, by M and quant mode. The benefit grows with M as the weight read amortizes:

| M=2 | int4 | int8 | nvfp4 | mxfp4 | mxfp8 |
| --- | --- | --- | --- | --- | --- |
| M2 Pro | qmv | qmv | 1.3x | 1.3x | 1.3x |
| M3 Ultra | 1.1x | 1.2x | 1.4x | 1.2x | 1.4x |
| M4 Pro | 1.2x | 1.0x | 1.4x | 1.2x | 1.4x |
| M5 Max | 1.3x | 1.2x | 1.6x | 1.4x | 1.3x |

| M=4 | int4 | int8 | nvfp4 | mxfp4 | mxfp8 |
| --- | --- | --- | --- | --- | --- |
| M2 Pro | qmv | qmv | 1.7x | 1.4x | 1.7x |
| M3 Ultra | 1.4x | 1.6x | 1.6x | 1.1x | 1.8x |
| M4 Pro | 1.4x | 1.6x | 1.7x | 1.2x | 2.0x |
| M5 Max | 1.6x | 1.7x | 2.0x | 1.3x | 1.6x |

| M=8 | int4 | int8 | nvfp4 | mxfp4 | mxfp8 |
| --- | --- | --- | --- | --- | --- |
| M2 Pro | qmv | qmv | 1.7x | 1.4x | 1.7x |
| M3 Ultra | 1.4x | 1.7x | 1.7x | 1.2x | 2.0x |
| M4 Pro | 1.4x | 1.8x | 1.7x | 1.2x | 2.2x |
| M5 Max | 1.6x | 1.8x | 2.0x | 1.3x | 1.8x |
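To make the amortization argument concrete, here is a minimal CPU sketch in C++ rather than Metal. It is not the kernel from this PR: the `ToyQuantWeight` layout, the `decode` rule, and the function names are all hypothetical, chosen only to contrast a per-vector path that decodes every weight M times with a wide path that decodes each group once and streams all M inputs over it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int kGroupSize = 32;  // assumed group size for this toy format

// Hypothetical quantized weight: one 4-bit code per byte (unpacked for clarity)
// and one float scale per group of kGroupSize codes along K.
struct ToyQuantWeight {
    int N;                            // output rows
    int K;                            // input columns, a multiple of kGroupSize
    std::vector<std::uint8_t> codes;  // N * K codes, row-major
    std::vector<float> scales;        // N * (K / kGroupSize) scales
};

// Assumed decode rule: map code 0..15 to the integer range -8..7.
inline float decode(std::uint8_t code) { return static_cast<float>(code) - 8.0f; }

// Per-vector path: each of the M inputs runs its own mat-vec, so every weight
// is decoded M times. `decodes` counts the decode operations.
std::vector<float> matvec_per_vector(const ToyQuantWeight& w, const std::vector<float>& x,
                                     int M, std::size_t& decodes) {
    assert(w.K % kGroupSize == 0);
    assert(x.size() == static_cast<std::size_t>(M) * w.K);
    const int groups = w.K / kGroupSize;
    std::vector<float> y(static_cast<std::size_t>(M) * w.N, 0.0f);
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < w.N; ++n) {
            for (int k = 0; k < w.K; ++k) {
                const float s = w.scales[n * groups + k / kGroupSize];
                const float wv = s * decode(w.codes[n * w.K + k]);
                ++decodes;
                y[m * w.N + n] += wv * x[m * w.K + k];
            }
        }
    }
    return y;
}

// Wide path: decode one group into a small tile once, then reuse that tile
// for all M input vectors before moving to the next group.
std::vector<float> matvec_wide(const ToyQuantWeight& w, const std::vector<float>& x,
                               int M, std::size_t& decodes) {
    assert(w.K % kGroupSize == 0);
    assert(x.size() == static_cast<std::size_t>(M) * w.K);
    const int groups = w.K / kGroupSize;
    std::vector<float> y(static_cast<std::size_t>(M) * w.N, 0.0f);
    float tile[kGroupSize];
    for (int n = 0; n < w.N; ++n) {
        for (int g = 0; g < groups; ++g) {
            const float s = w.scales[n * groups + g];
            const std::uint8_t* codes = &w.codes[n * w.K + g * kGroupSize];
            for (int i = 0; i < kGroupSize; ++i) {
                tile[i] = s * decode(codes[i]);
                ++decodes;
            }
            for (int m = 0; m < M; ++m) {
                const float* xv = &x[m * w.K + g * kGroupSize];
                float acc = 0.0f;
                for (int i = 0; i < kGroupSize; ++i) acc += tile[i] * xv[i];
                y[m * w.N + n] += acc;
            }
        }
    }
    return y;
}
```

In this sketch the per-vector path performs `M * N * K` decodes and the wide path performs `N * K`, which is the same amortization the description attributes to `qmv_wide`. A real GPU kernel would also have to account for threadgroup tiling and memory coalescing, which this sketch ignores.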
angeloskath
left a comment
Awesome speedup! This is also exceptionally written! Very clear and intuitive, nice!
I left a comment about an improvement but I am super excited to merge this. Fantastic work.
| const device uint16_t* wq = (const device uint16_t*)wg; | ||
| #pragma unroll | ||
| for (int i = 0; i < nf4; i++) { | ||
| w4[i] = s * |
This multiplication should move down after the dot product. We are now unnecessarily doing nf4 * 4 = 32 muls instead of just 1.
I wrote it real quick and got a pretty massive speedup on mxfp4 Qwen 9B with batch size 4, from 253 tps to 295 tps on the M5 Max.
| for (int j = 0; j < nf4; j++) { | ||
| acc += dot(w4[j], float4(xv4[j])); | ||
| } | ||
| result[v] += acc; |
| result[v] += acc; | |
| result[v] += s * acc; |
| const device uint16_t* wq = (const device uint16_t*)wg; | ||
| #pragma unroll | ||
| for (int i = 0; i < nf4; i++) { | ||
| w4[i] = s * |
| w4[i] = s * | |
| w4[i] = |
| } else { | ||
| #pragma unroll | ||
| for (int i = 0; i < nf4; i++) { | ||
| w4[i] = s * |
| w4[i] = s * | |
| w4[i] = |
`fp_qmv_wide` folded the per-group scale into every dequantized weight before the dot product. The dot is linear in the weights, so scaling the accumulated result instead is equivalent and replaces group_size scale multiplies per group with one per streamed vector.

Isolated speedup of this change over the previous `qmv_wide` (GPU kernel time) on the Gemma-4-12B gate/up matmul `[15360x3840]` with bf16 activations, by M and fp mode. Affine modes use a separate kernel and are unchanged; the win grows with M as the saved multiplies amortize:

| M=2 | nvfp4 | mxfp4 | mxfp8 |
| --- | --- | --- | --- |
| M2 Pro | 1.00x | 1.11x | 1.02x |
| M3 Ultra | 0.99x | 1.13x | 1.08x |
| M4 Pro | 1.02x | 1.10x | 1.03x |
| M5 Max | 1.16x | 1.23x | 1.05x |

| M=4 | nvfp4 | mxfp4 | mxfp8 |
| --- | --- | --- | --- |
| M2 Pro | 1.09x | 1.15x | 1.12x |
| M3 Ultra | 1.12x | 1.14x | 1.25x |
| M4 Pro | 1.10x | 1.16x | 1.14x |
| M5 Max | 1.14x | 1.19x | 1.09x |

| M=8 | nvfp4 | mxfp4 | mxfp8 |
| --- | --- | --- | --- |
| M2 Pro | 1.08x | 1.15x | 1.13x |
| M3 Ultra | 1.16x | 1.12x | 1.25x |
| M4 Pro | 1.14x | 1.16x | 1.18x |
| M5 Max | 1.10x | 1.16x | 1.08x |

`qmv_wide` with this change over the per-vector `qmv` path, refreshing the table from the previous commit (`qmv` = affine staying on the old path on gen-14):

| M=2 | int4 | int8 | nvfp4 | mxfp4 | mxfp8 |
| --- | --- | --- | --- | --- | --- |
| M2 Pro | qmv | qmv | 1.5x | 1.6x | 1.5x |
| M3 Ultra | 1.2x | 1.3x | 1.4x | 1.0x | 1.8x |
| M4 Pro | 1.2x | 1.0x | 1.4x | 1.0x | 1.5x |
| M5 Max | 1.4x | 1.2x | 2.0x | 1.9x | 1.6x |

| M=4 | int4 | int8 | nvfp4 | mxfp4 | mxfp8 |
| --- | --- | --- | --- | --- | --- |
| M2 Pro | qmv | qmv | 2.2x | 1.7x | 2.0x |
| M3 Ultra | 1.5x | 1.8x | 1.4x | 1.2x | 2.1x |
| M4 Pro | 1.5x | 1.7x | 1.3x | 1.3x | 2.2x |
| M5 Max | 1.7x | 1.8x | 2.5x | 1.8x | 2.2x |

| M=8 | int4 | int8 | nvfp4 | mxfp4 | mxfp8 |
| --- | --- | --- | --- | --- | --- |
| M2 Pro | qmv | qmv | 2.3x | 1.8x | 2.1x |
| M3 Ultra | 1.4x | 1.6x | 1.3x | 1.1x | 2.0x |
| M4 Pro | 1.5x | 1.7x | 1.3x | 1.1x | 2.2x |
| M5 Max | 1.6x | 1.6x | 2.2x | 1.7x | 2.1x |
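The linearity argument in this commit can be checked with a small standalone C++ sketch. It is a scalar CPU illustration only, not the Metal kernel; the function names are hypothetical. The first variant multiplies every weight by the scale `s` before the dot product, the second scales the accumulated sum once.

```cpp
#include <cassert>
#include <cstddef>

// Scale folded into each weight before the dot product:
// n multiplies by s for a group of n weights.
inline float dot_scale_inside(const float* w, const float* x, std::size_t n, float s) {
    assert(w != nullptr && x != nullptr);
    float acc = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        acc += (s * w[i]) * x[i];
    }
    return acc;
}

// Scale applied once to the accumulated dot product:
// sum_i (s * w[i]) * x[i] == s * sum_i w[i] * x[i] by linearity.
inline float dot_scale_after(const float* w, const float* x, std::size_t n, float s) {
    assert(w != nullptr && x != nullptr);
    float acc = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        acc += w[i] * x[i];
    }
    return s * acc;
}
```

The two forms are mathematically equal. In floating point they can differ in the last bits for a general scale, because the rounding happens at different points; when the scale is a power of two and nothing overflows or underflows, the multiply is exact and the results match bit for bit.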
Great suggestion - that brings another significant improvement. Benchmarking it relative to the original PR I got: M=2
M=4
M=8
And that brings the total improvement of the PR to: M=2
M=4
M=8
angeloskath
left a comment
Lovely, merging after tests pass.