[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos#20259
Conversation
…ation) Pull Request resolved: #20201 Backend-agnostic GPU-timestamp infrastructure, split out so the general implementation is foundational (below SDPA) while the SDPA-specific dispatch labeling stays above the SDPA op. Composed of: `WebGPUQueryPool`, a faithful re-port of Vulkan's `vkapi::QueryPool` (`backends/vulkan/runtime/vk_api/QueryPool.{h,cpp}`) — same `ShaderDuration` data model and ticks->ns conversion; three deviations are forced by the WebGPU API (per-dispatch bracketing via a compute-pass `timestampWrites` descriptor since there is no mid-encoder `writeTimestamp`; readback via `resolveQuerySet` + buffer map rather than host-side `vkGetQueryPoolResults`; the `TimestampQuery` capability requested as an explicit device feature, fail-open if the adapter lacks it). `WebGPUDevice` gains timestamp-feature detection, and `WebGPUGraph` gains a per-dispatch `kernel_name` label plus `execute()` bracketing of each compute pass when the pool is active. Opt-in via the `WEBGPU_TIMESTAMP_QUERY` env var; off by default, so the production `execute()` path is byte-identical. The SDPA per-kernel labeling lives in the companion "for SDPA" diff above the SDPA op. Co-authored with Claude. ghstack-source-id: 392975889 @exported-using-ghexport Differential Revision: [D108188287](https://our.internmc.facebook.com/intern/diff/D108188287/)
… input_pos Pull Request resolved: #20086 Adds the fused `sdpa_with_kv_cache` op (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the `update_cache` op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32). `input_pos` is consumed dynamically via the SymInt mechanism: the op reads `symint_buffer()` as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancing `input_pos`) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design. Tests live in the stacked test-suite diff above (clean op diff here). Authored with assistance from Claude. ghstack-source-id: 392609088 @exported-using-ghexport Differential Revision: [D107595125](https://our.internmc.facebook.com/intern/diff/D107595125/)
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20259
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ⏳ No Failures, 7 PendingAs of commit 97b7196 with merge base e4f434c ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
…-graph KV cache (#20260) This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #20087 by @JulianCloudNTH ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/20/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/20/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/20/orig @diff-train-skip-merge --------- Co-authored-by: Julian Ng-Thow-Hing <juliannth@meta.com>
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #20086 by @JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/orig
@diff-train-skip-merge