Skip to content

[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos#20259

Merged
JulianCloudNTH merged 5 commits into
mainfrom
gh/JulianCloudNTH/19/orig
Jun 13, 2026
Merged

[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos#20259
JulianCloudNTH merged 5 commits into
mainfrom
gh/JulianCloudNTH/19/orig

Conversation

@pytorchbot

Copy link
Copy Markdown
Collaborator

This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #20086 by @JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/orig

@diff-train-skip-merge

…ation)

Pull Request resolved: #20201

Backend-agnostic GPU-timestamp infrastructure, split out so the general implementation is foundational (below SDPA) while the SDPA-specific dispatch labeling stays above the SDPA op. Composed of: `WebGPUQueryPool`, a faithful re-port of Vulkan's `vkapi::QueryPool` (`backends/vulkan/runtime/vk_api/QueryPool.{h,cpp}`) — same `ShaderDuration` data model and ticks->ns conversion; three deviations are forced by the WebGPU API (per-dispatch bracketing via a compute-pass `timestampWrites` descriptor since there is no mid-encoder `writeTimestamp`; readback via `resolveQuerySet` + buffer map rather than host-side `vkGetQueryPoolResults`; the `TimestampQuery` capability requested as an explicit device feature, fail-open if the adapter lacks it). `WebGPUDevice` gains timestamp-feature detection, and `WebGPUGraph` gains a per-dispatch `kernel_name` label plus `execute()` bracketing of each compute pass when the pool is active. Opt-in via the `WEBGPU_TIMESTAMP_QUERY` env var; off by default, so the production `execute()` path is byte-identical. The SDPA per-kernel labeling lives in the companion "for SDPA" diff above the SDPA op.

Co-authored with Claude.
ghstack-source-id: 392975889
@exported-using-ghexport

Differential Revision: [D108188287](https://our.internmc.facebook.com/intern/diff/D108188287/)
… input_pos

Pull Request resolved: #20086

Adds the fused `sdpa_with_kv_cache` op (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the `update_cache` op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32).

`input_pos` is consumed dynamically via the SymInt mechanism: the op reads `symint_buffer()` as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancing `input_pos`) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design.

Tests live in the stacked test-suite diff above (clean op diff here).

Authored with assistance from Claude.
ghstack-source-id: 392609088
@exported-using-ghexport

Differential Revision: [D107595125](https://our.internmc.facebook.com/intern/diff/D107595125/)
@pytorch-bot

pytorch-bot Bot commented Jun 12, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20259

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

⏳ No Failures, 7 Pending

As of commit 97b7196 with merge base e4f434c (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 12, 2026
@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@JulianCloudNTH JulianCloudNTH self-requested a review June 13, 2026 00:36
…-graph KV cache (#20260)

This PR was created by the merge bot to help merge the original PR into
the main branch.
ghstack PR number: #20087 by
@JulianCloudNTH
^ Please use this as the source of truth for the PR details, comments,
and reviews
ghstack PR base:
https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/20/base
ghstack PR head:
https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/20/head
Merge bot PR base:
https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/orig
Merge bot PR head:
https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/20/orig

@diff-train-skip-merge

---------

Co-authored-by: Julian Ng-Thow-Hing <juliannth@meta.com>
@JulianCloudNTH JulianCloudNTH merged commit 96a64ec into main Jun 13, 2026
156 of 179 checks passed
@JulianCloudNTH JulianCloudNTH deleted the gh/JulianCloudNTH/19/orig branch June 13, 2026 01:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants