Skip to content

[ExecuTorch][WebGPU] Add 4-bit weight-only quantized linear (et_vk.linear_q4gsw)#20262

Merged
JulianCloudNTH merged 5 commits into
gh/JulianCloudNTH/20/origfrom
gh/JulianCloudNTH/23/orig
Jun 13, 2026
Merged

[ExecuTorch][WebGPU] Add 4-bit weight-only quantized linear (et_vk.linear_q4gsw)#20262
JulianCloudNTH merged 5 commits into
gh/JulianCloudNTH/20/origfrom
gh/JulianCloudNTH/23/orig

Conversation

@pytorchbot

Copy link
Copy Markdown
Collaborator

This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #20226 by @JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/23/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/23/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/20/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/23/orig
Differential Revision: D108312283
@diff-train-skip-merge

…ation)

Pull Request resolved: #20201

Backend-agnostic GPU-timestamp infrastructure, split out so the general implementation is foundational (below SDPA) while the SDPA-specific dispatch labeling stays above the SDPA op. Composed of: `WebGPUQueryPool`, a faithful re-port of Vulkan's `vkapi::QueryPool` (`backends/vulkan/runtime/vk_api/QueryPool.{h,cpp}`) — same `ShaderDuration` data model and ticks->ns conversion; three deviations are forced by the WebGPU API (per-dispatch bracketing via a compute-pass `timestampWrites` descriptor since there is no mid-encoder `writeTimestamp`; readback via `resolveQuerySet` + buffer map rather than host-side `vkGetQueryPoolResults`; the `TimestampQuery` capability requested as an explicit device feature, fail-open if the adapter lacks it). `WebGPUDevice` gains timestamp-feature detection, and `WebGPUGraph` gains a per-dispatch `kernel_name` label plus `execute()` bracketing of each compute pass when the pool is active. Opt-in via the `WEBGPU_TIMESTAMP_QUERY` env var; off by default, so the production `execute()` path is byte-identical. The SDPA per-kernel labeling lives in the companion "for SDPA" diff above the SDPA op.

Co-authored with Claude.
ghstack-source-id: 392093395
@exported-using-ghexport

Differential Revision: [D108188287](https://our.internmc.facebook.com/intern/diff/D108188287/)
… input_pos

Pull Request resolved: #20086

Adds the fused `sdpa_with_kv_cache` op (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the `update_cache` op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32).

`input_pos` is consumed dynamically via the SymInt mechanism: the op reads `symint_buffer()` as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancing `input_pos`) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design.

Tests live in the stacked test-suite diff above (clean op diff here).

Authored with assistance from Claude.
ghstack-source-id: 392609088
@exported-using-ghexport

Differential Revision: [D107595125](https://our.internmc.facebook.com/intern/diff/D107595125/)
…-graph KV cache

Pull Request resolved: #20087

Adds the WebGPU SDPA test coverage as its own diff, stacked on the SDPA op (which already carries the dynamic-`input_pos` consumption) and the SymInt mechanism below it: multi-step prefill->mt->decode replay, runtime-dynamic `input_pos` (autoregressive decode), and an in-graph mutable KV cache, each compared against a torch `F.scaled_dot_product_attention` golden.

- `test/ops/sdpa/test_sdpa.py`: `ReplaySeq`/`REPLAY_SEQS` + per-step replay export/golden; `DynamicSdpaModule` + `export_dynamic_decode` (one `.pte`, `input_pos` supplied at runtime as a SymInt); `DecodeCacheModule` + `export_incache_decode` (KV cache as `register_buffer` mutable buffers, so the cache persists in-graph and forward() feeds only the new token + `input_pos`).
- `test/test_webgpu_native.cpp`: `test_sdpa_replay`, `test_sdpa_dynamic_decode` (+ negative control: a pinned `input_pos` diverges), `test_sdpa_incache_decode` (+ static control: a fresh Module per step diverges, proving in-graph accumulation is real), `test_symint_roundtrip`, `test_resize_hook`; shared per-element tolerance `sdpa_within_tol` (abs 1e-4 OR rel 1e-3).
- `test/test_build_webgpu.sh`: export the replay / dynamic / in-graph-cache models for the native test.
Authored with assistance from Claude.

ghstack-source-id: 392255556
@exported-using-ghexport

Differential Revision: [D107595144](https://our.internmc.facebook.com/intern/diff/D107595144/)
…near_q4gsw)

Pull Request resolved: #20226

Adds the `et_vk.linear_q4gsw` operator (4-bit groupwise-symmetric weight-only linear) to the WebGPU backend: dequantize the packed int4 weight in WGSL (`(q-8)*scale`) and accumulate an fp32 matmul, consuming the serialized `[N, K/2]` uint8 weight directly (no prepack), one workgroup per output row. Mirrors the Vulkan reference (`backends/vulkan/.../impl/QuantizedLinear.cpp`). The dispatch carries a `linear_q4gsw` label for GPU-timestamp-query profiling (mirroring the SDPA kernels). The numerical test suite is in the stacked test diff.
ghstack-source-id: 392908894
@exported-using-ghexport

Differential Revision: [D108312283](https://our.internmc.facebook.com/intern/diff/D108312283/)
@pytorch-bot

pytorch-bot Bot commented Jun 13, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20262

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 13, 2026
@JulianCloudNTH JulianCloudNTH self-requested a review June 13, 2026 00:33
@JulianCloudNTH JulianCloudNTH merged commit a73b187 into gh/JulianCloudNTH/20/orig Jun 13, 2026
154 of 176 checks passed
@JulianCloudNTH JulianCloudNTH deleted the gh/JulianCloudNTH/23/orig branch June 13, 2026 01:09
@JulianCloudNTH JulianCloudNTH restored the gh/JulianCloudNTH/23/orig branch June 13, 2026 01:17
@JulianCloudNTH JulianCloudNTH deleted the gh/JulianCloudNTH/23/orig branch June 13, 2026 01:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants