[ExecuTorch][WebGPU] Add fused SDPA (sdpa_with_kv_cache) with dynamic input_pos#20086
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20086
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ⏳ 134 Pending, 1 Unrelated FailureAs of commit e29cfe2 with merge base 5526971 ( BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
@claude review |
|
Claude finished @JulianCloudNTH's task in 3m 59s —— View job Code Review: Fused SDPA (
|
SS-JIA
left a comment
There was a problem hiding this comment.
Review automatically exported from Phabricator review in Meta.
e41cf0e
into
gh/JulianCloudNTH/19/base
… input_pos Pull Request resolved: #20086 Adds the fused `sdpa_with_kv_cache` op (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), the `update_cache` op, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32). `input_pos` is consumed dynamically via the SymInt mechanism: the op reads `symint_buffer()` as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancing `input_pos`) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design. Tests live in the stacked test-suite diff above (clean op diff here). Authored with assistance from Claude. ghstack-source-id: 392609088 @exported-using-ghexport Differential Revision: [D107595125](https://our.internmc.facebook.com/intern/diff/D107595125/)
… input_pos (#20259) This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #20086 by @JulianCloudNTH ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/JulianCloudNTH/19/orig @diff-train-skip-merge --------- Co-authored-by: Julian Ng-Thow-Hing <juliannth@meta.com>
Stack from ghstack (oldest at bottom):
Adds the fused
sdpa_with_kv_cacheop (QK attention-weights, softmax, attention-output sub-kernels over the KV cache), composing the three enablers below it: the base graph's inter-dispatch buffer passing (scratch buffers + multi-pass execute), theupdate_cacheop, and the SymInt live-scalar mechanism. The QK/softmax/AV kernels mirror the Vulkan reference's flat-index/GQA/causal-mask math (NCHW, buffer-only, fp32).input_posis consumed dynamically via the SymInt mechanism: the op readssymint_buffer()as a uniform, sizes its scratch + dispatches for the max context length, and registers a resize hook so a single delegate runs an autoregressive decode loop (feed only the new token + advancinginput_pos) instead of a fixed baked position. Mirrors the Vulkan SymInt = live uniform-buffer design.Tests live in the stacked test-suite diff above (clean op diff here).
Authored with assistance from Claude.
@exported-using-ghexport
Differential Revision: D107595125
Differential Revision: D107595125