Skip to content

Fix soft cap in JAX Splash attention#4305

Merged
copybara-service[bot] merged 2 commits into
AI-Hypercomputer:mainfrom
huytransformer:htn/fix-jax-splash-soft-cap
Jun 30, 2026
Merged

Fix soft cap in JAX Splash attention#4305
copybara-service[bot] merged 2 commits into
AI-Hypercomputer:mainfrom
huytransformer:htn/fix-jax-splash-soft-cap

Conversation

@huytransformer

@huytransformer huytransformer commented Jun 30, 2026

Copy link
Copy Markdown
Collaborator

Description

Fixes attention logit soft cap in the JAX Splash attention path. Example, gemma2 configs set attn_logits_soft_cap, but the use_jax_splash path did not pass that value into flash_attention_block_masked.

Tests

Ran:

python3 -m py_compile \
  src/maxtext/layers/attention_op.py \
  src/maxtext/kernels/attention/jax_flash_attention.py \
  tests/unit/attention_test.py

python3 -m pytest tests/unit/attention_test.py::JaxFlashAttentionTest::test_flash_attention_block_masked_soft_cap

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 30, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

Comment thread tests/unit/kernels_test.py Outdated
@huytransformer huytransformer force-pushed the htn/fix-jax-splash-soft-cap branch from b147c4c to a56309f Compare June 30, 2026 17:31
@huytransformer huytransformer requested a review from xibinliu as a code owner June 30, 2026 17:31
@huytransformer huytransformer force-pushed the htn/fix-jax-splash-soft-cap branch from f8d7265 to c479d48 Compare June 30, 2026 17:40
@copybara-service copybara-service Bot merged commit 9ce291b into AI-Hypercomputer:main Jun 30, 2026
44 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants