Skip to content

Add MagCache inference acceleration for Wan2.2 (T2V + I2V)#433

Open
HadarIngonyama wants to merge 4 commits into
AI-Hypercomputer:mainfrom
HadarIngonyama:magcache_wan22_integration
Open

Add MagCache inference acceleration for Wan2.2 (T2V + I2V)#433
HadarIngonyama wants to merge 4 commits into
AI-Hypercomputer:mainfrom
HadarIngonyama:magcache_wan22_integration

Conversation

@HadarIngonyama

Copy link
Copy Markdown

Add MagCache inference acceleration for Wan2.2 (T2V + I2V)

Summary

This PR adds MagCache support to the Wan2.2 dual-transformer pipelines (both T2V and I2V), extending the existing Wan2.1 T2V MagCache support. MagCache skips the transformer blocks and reuses the cached block residual when the accumulated magnitude-ratio error stays below a threshold, using a precalibrated per-step mag_ratios_base curve so the skip schedule is deterministic (no data-dependent control flow, TPU/JIT friendly).

Measured speedups vs the dense render: ~1.82× for T2V and ~1.75× for I2V, with visually near-indistinguishable output.

What's included

  • Wan2.2 T2V (wan_pipeline_2_2.py): MagCache skip path for the dual transformer — a single interleaved mag_ratios_base curve spanning both the high-noise and low-noise phases, a per-phase forced-compute (retention) zone, and an explicit cached-residual reset at the high→low transformer boundary.
  • Wan2.2 I2V (wan_pipeline_i2v_2p2.py): the same skip path adapted for the image-conditioned pipeline (image condition concatenated with the latents, with the required BFHWC↔BCFHW transposes).
  • generate_wan.py: threads use_magcache / magcache_thresh / magcache_K / retention_ratio through to both 2.2 pipelines.
  • Configs:
    • base_wan_27b.yml (T2V): MagCache params + official mag_ratios_base, and flow_shift defaulted to 12.0 (see note below).
    • base_wan_i2v_27b.yml (I2V): MagCache params + official I2V-A14B mag_ratios_base, with boundary_ratio=0.900 to align the high→low switch with the curve (flow_shift stays at the I2V default of 5.0).
  • Tests (wan2_2_magcache_test.py): host-side validation/schedule/core tests plus a TPU-only end-to-end smoke test.
  • README: documents MagCache for Wan2.2 T2V and I2V, including the support matrix, config flags, sampling-shift requirement, and benchmark results.

Important: flow_shift alignment

mag_ratios_base is calibrated against where the high→low noise boundary lands, which flow_shift controls. Wan2.2 T2V requires flow_shift=12.0 (the official A14B sampling shift) — the previous default of 5.0 moved the boundary several steps out of phase, so MagCache skipped at the wrong steps and quality dropped. This PR sets the correct default, which also fixes the off-spec dense baseline. For I2V the official shift is 5.0, paired with boundary_ratio=0.900.

Results

Measured on a v7x (720×1280, 81 frames, 40 steps), reference = dense (use_magcache=False) render with the same seed/config:

Model Settings Speedup Steps skipped SSIM PSNR
Wan2.2 T2V flow_shift=12.0, thresh=0.04, K=2 ~1.82× 18/40 (360s→198s) ≈0.72 ≈21.8 dB
Wan2.2 I2V flow_shift=5.0, boundary_ratio=0.900, thresh=0.06, K=2 ~1.75× 17/40 (6.30→3.61 s/step) ≈0.91 ≈25.4 dB

The reference-based metrics mostly reflect trajectory divergence — caching nudges the sampler onto a different but equally plausible sample — rather than visible degradation; cached clips are visually hard to tell apart from dense. I2V scores higher because the image conditioning anchors the trajectory. Recalibrating mag_ratios_base for a specific dtype/attention kernel can tighten the metric gap further.

Usage

MagCache is one of several mutually-exclusive caching strategies (CFG Cache, SenCache, MagCache) — enable only one at a time.

# Wan2.2 T2V
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_27b.yml \
  use_magcache=True magcache_thresh=0.04 magcache_K=2 ...

# Wan2.2 I2V
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_i2v_27b.yml \
  use_magcache=True magcache_thresh=0.06 magcache_K=2 ...

Testing

  • wan2_2_magcache_test.py host-side tests pass (schedule/core logic).
  • End-to-end T2V and I2V runs validated on a v7x TPU; speedup and SSIM/PSNR numbers above were collected from those runs.

…ced-compute)

  zones and an explicit residual reset at the high->low transformer boundary,
  driven by a single interleaved mag_ratios_base curve spanning both phases
- generate_wan.py: pass use_magcache / magcache_thresh / magcache_K /
  retention_ratio through to the 2.2 pipeline
- base_wan_27b.yml: default flow_shift=12.0 (official A14B sampling shift; sets
  the high->low boundary the ratios are aligned to) + MagCache params and the
  official mag_ratios_base
- README: document MagCache for Wan2.2 (flow_shift requirement, ~1.82x speedup,
  SSIM/PSNR vs dense)
- tests: wan_mag_cache_test.py (host-side validation/schedule/core tests + a
  TPU-only end-to-end smoke test)
- wan_pipeline_i2v_2p2.py: MagCache skip path for the dual-transformer I2V
  pipeline, mirroring the T2V 2.2 logic (per-phase retention/forced-compute
  zones, residual reset at the high->low boundary, single interleaved
  mag_ratios_base curve) with I2V-specific handling for the image condition
  (concat with latents + BFHWC<->BCFHW transposes)
- generate_wan.py: pass use_magcache / magcache_thresh / magcache_K /
  retention_ratio through to the 2.2 I2V pipeline
- base_wan_i2v_27b.yml: MagCache params + the official I2V-A14B mag_ratios_base,
  and boundary_ratio=0.900 to align the high->low switch with the curve
  (flow_shift stays at the I2V default of 5.0)
- README: document MagCache for Wan2.2 I2V (settings + ~1.75x speedup,
  SSIM/PSNR vs dense) and add it to the caching support matrix
@HadarIngonyama HadarIngonyama requested a review from entrpn as a code owner June 29, 2026 19:27
@google-cla

google-cla Bot commented Jun 29, 2026

Copy link
Copy Markdown

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@Perseus14 Perseus14 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. I have added some comments. PTAL!

Please run a manual linting test.

pip install pylint pyink==23.10.0 pytype==2024.2.27
pyink src/maxdiffusion --check --diff --color --pyink-indentation=2 --line-length=125

Additionally could you also squash the commits?


cache_count = 0
for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can move jnp.array(scheduler_state.timesteps, dtype=jnp.int32) outside and save some compute

timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)
for step in range(num_inference_steps):
    t = timesteps[step]

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but note that for senchache and cfg cache in the same file the array is also inside the loop.

Comment thread README.md
Comment on lines +610 to +618
| **MagCache** | `use_magcache: True` | Wan 2.1 T2V, Wan 2.2 T2V/I2V | ~1.75–1.9x | [MagCache](https://github.com/Zehong-Ma/MagCache): skips the transformer blocks and reuses the cached block residual when the accumulated magnitude-ratio error stays below `magcache_thresh`, capped at `magcache_K` consecutive skips. Uses a precalibrated per-step `mag_ratios_base` curve, so the skip schedule is deterministic (no data-dependent control flow). |

For Wan 2.2 (dual-transformer), MagCache keeps a single `mag_ratios_base` curve spanning both the high-noise and low-noise phases, forces a full recompute for the first `retention_ratio` fraction of each phase, and resets the cached residual at the high→low boundary. The shipped `mag_ratios_base` in `base_wan_27b.yml` are seeded from the official Wan2.2 values; recalibrating them for your exact setup (model dtype / attention kernel) improves the speedup/quality trade-off.

> **Wan 2.2 T2V requires `flow_shift=12.0`** (the official A14B sampling shift; `base_wan_27b.yml` now defaults to it). `flow_shift` controls where the high→low noise boundary lands, which is the boundary the `mag_ratios_base` curve is calibrated against — a lower shift (e.g. the old `5.0`) moves the boundary several steps out of phase, so MagCache skips at the wrong steps and quality drops. This also corrects the off-spec dense baseline.

Measured on a v7x (Wan 2.2 A14B T2V, 720×1280, 81 frames, 40 steps, `flow_shift=12.0`, seeded ratios at `magcache_thresh=0.04`, `magcache_K=2`): **~1.82× speedup** (18/40 steps skipped, denoise 360s → 198s) at **SSIM ≈ 0.72 / PSNR ≈ 21.8 dB** versus the dense (`use_magcache=False`) render with the same seed/config. These reference-based metrics mostly reflect *trajectory divergence* — caching nudges the sampler onto a different but equally plausible sample — rather than visible degradation; the cached clips are visually hard to tell apart from dense. Recalibrating `mag_ratios_base` for your exact dtype/attention kernel tightens the metric gap further.

Wan 2.2 I2V (`base_wan_i2v_27b.yml`) uses the official I2V-A14B `mag_ratios_base` together with the I2V sampling settings (`flow_shift=5.0`, `boundary_ratio=0.900`) that the curve is aligned to. Measured on a v7x (Wan 2.2 A14B I2V, 720×1280, 81 frames, 40 steps, `magcache_thresh=0.06`, `magcache_K=2`): **~1.75× speedup** (17/40 steps skipped, 6.30s → 3.61s per step) at **SSIM ≈ 0.91 / PSNR ≈ 25.4 dB** versus the dense render. Fidelity is higher than T2V because the image conditioning anchors the sampling trajectory, leaving less room for divergence.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can reduce the text content here.

Minor nit: It is 7x and not v7x

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will do. I will change to 7x, but note that the whole readme has this mistake (and it exists also in other places in the repo)

@@ -0,0 +1,389 @@
"""

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please skip the heavy tests when running in github actions

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I already have.
The one heavy test — WanMagCacheSmokeTest (loads the 27B model and runs on TPU) — is already guarded with @pytest.mark.skipif(IN_GITHUB_ACTIONS, ...) at line 275, following the same pattern as wan_sen_cache_test.py / wan_cfg_cache_test.py.

The other three classes are intentionally left to run in CI since they're pure-host and don't need a TPU or weights:

  • WanMagCacheValidationTest — argument-validation only (call raises on bad config)
  • WanMagCacheScheduleTest — host-side skip-schedule math
  • WanMagCacheCoreTest — init_magcache / magcache_step unit tests

Model loading is inside the test method (not at import/collection time), so collection stays green in CI. Let me know if you were thinking of something beyond the skipif guard and I'll adjust.

)
return videos, time.perf_counter() - t0

def test_magcache_speedup_and_fidelity(self):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be better to use logs instead of print statements

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will fix it, but again, note that this is the case for other tests too - wan_sen_cache_test.py for example.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants