From adbdb923ec386dface63af44ee7baddfb292dcb6 Mon Sep 17 00:00:00 2001 From: HadarIngonyama Date: Mon, 29 Jun 2026 20:41:53 +0300 Subject: [PATCH 1/4] - run_inference_2_2: MagCache skip path with per-phase retention (forced-compute) zones and an explicit residual reset at the high->low transformer boundary, driven by a single interleaved mag_ratios_base curve spanning both phases - generate_wan.py: pass use_magcache / magcache_thresh / magcache_K / retention_ratio through to the 2.2 pipeline - base_wan_27b.yml: default flow_shift=12.0 (official A14B sampling shift; sets the high->low boundary the ratios are aligned to) + MagCache params and the official mag_ratios_base - README: document MagCache for Wan2.2 (flow_shift requirement, ~1.82x speedup, SSIM/PSNR vs dense) - tests: wan_mag_cache_test.py (host-side validation/schedule/core tests + a TPU-only end-to-end smoke test) --- README.md | 17 +- src/maxdiffusion/configs/base_wan_27b.yml | 22 +- src/maxdiffusion/generate_wan.py | 6 +- .../pipelines/wan/wan_pipeline_2_2.py | 137 +++++- .../tests/wan/wan2_2_magcache_test.py | 389 ++++++++++++++++++ 5 files changed, 563 insertions(+), 8 deletions(-) create mode 100644 src/maxdiffusion/tests/wan/wan2_2_magcache_test.py diff --git a/README.md b/README.md index 5ddcc323..63d57960 100755 --- a/README.md +++ b/README.md @@ -607,6 +607,13 @@ To generate images, run the following command: | --- | --- | --- | --- | --- | | **CFG Cache** | `use_cfg_cache: True` | Wan 2.1 T2V, Wan 2.2 T2V/I2V | ~1.2x | FasterCache-style: caches the unconditional branch and applies FFT frequency-domain compensation on skipped steps. | | **SenCache** | `use_sen_cache: True` | Wan 2.2 T2V/I2V | ~1.4x | Sensitivity-Aware Caching ([arXiv:2602.24208](https://arxiv.org/abs/2602.24208)): predicts output change via first-order sensitivity S = α_x·‖Δx‖ + α_t·\|Δt\|. Skips the full CFG forward pass when predicted change is below tolerance ε. | + | **MagCache** | `use_magcache: True` | Wan 2.1 T2V, Wan 2.2 T2V | ~1.8–1.9x | [MagCache](https://github.com/Zehong-Ma/MagCache): skips the transformer blocks and reuses the cached block residual when the accumulated magnitude-ratio error stays below `magcache_thresh`, capped at `magcache_K` consecutive skips. Uses a precalibrated per-step `mag_ratios_base` curve, so the skip schedule is deterministic (no data-dependent control flow). | + + For Wan 2.2 (dual-transformer), MagCache keeps a single `mag_ratios_base` curve spanning both the high-noise and low-noise phases, forces a full recompute for the first `retention_ratio` fraction of each phase, and resets the cached residual at the high→low boundary. The shipped `mag_ratios_base` in `base_wan_27b.yml` are seeded from the official Wan2.2 values; recalibrating them for your exact setup (model dtype / attention kernel) improves the speedup/quality trade-off. + + > **Wan 2.2 T2V requires `flow_shift=12.0`** (the official A14B sampling shift; `base_wan_27b.yml` now defaults to it). `flow_shift` controls where the high→low noise boundary lands, which is the boundary the `mag_ratios_base` curve is calibrated against — a lower shift (e.g. the old `5.0`) moves the boundary several steps out of phase, so MagCache skips at the wrong steps and quality drops. This also corrects the off-spec dense baseline. + + Measured on a v7x (Wan 2.2 A14B T2V, 720×1280, 81 frames, 40 steps, `flow_shift=12.0`, seeded ratios at `magcache_thresh=0.04`, `magcache_K=2`): **~1.82× speedup** (18/40 steps skipped, denoise 360s → 198s) at **SSIM ≈ 0.72 / PSNR ≈ 21.8 dB** versus the dense (`use_magcache=False`) render with the same seed/config. These reference-based metrics mostly reflect *trajectory divergence* — caching nudges the sampler onto a different but equally plausible sample — rather than visible degradation; the cached clips are visually hard to tell apart from dense. Recalibrating `mag_ratios_base` for your exact dtype/attention kernel tightens the metric gap further. To enable a caching mechanism, set the corresponding flag in your config YAML or pass it as a command-line override: @@ -622,6 +629,14 @@ To generate images, run the following command: src/maxdiffusion/configs/base_wan_i2v_27b.yml \ use_cfg_cache=True \ ... + + # Example: enable MagCache for Wan 2.2 T2V + python src/maxdiffusion/generate_wan.py \ + src/maxdiffusion/configs/base_wan_27b.yml \ + use_magcache=True \ + magcache_thresh=0.04 \ + magcache_K=2 \ + ... ``` ### Ring Attention @@ -819,4 +834,4 @@ This script will automatically format your code with `pyink` and help you identi The full suite of -end-to end tests is in `tests` and `src/maxdiffusion/tests`. We run them with a nightly cadance. ## Profiling -To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md). +To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md). \ No newline at end of file diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 19153da1..d3734c3f 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -336,7 +336,11 @@ do_classifier_free_guidance: True height: 720 width: 1280 num_frames: 81 -flow_shift: 5.0 +# Official Wan2.2 T2V A14B sampling shift (wan_t2v_A14B.py: sample_shift=12.0). This sets where the +# high->low transformer boundary falls in the step schedule (~step 26 of 40 at boundary_ratio 0.875), +# which must match the schedule the seeded mag_ratios_base was calibrated at. (A smaller shift moves +# the boundary earlier, which both degrades the base sample and misaligns MagCache's skip schedule.) +flow_shift: 12.0 # Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py # guidance scale factor for low noise transformer @@ -361,6 +365,20 @@ use_kv_cache: False # when predicted output change (based on accumulated latent/timestep drift) is small use_sen_cache: False +# MagCache (https://github.com/Zehong-Ma/MagCache) — skip transformer blocks when +# the accumulated magnitude-ratio error stays below `magcache_thresh`, reusing the +# cached block residual. `magcache_K` caps consecutive skips; `retention_ratio` is +# the fraction at the start of each phase that always computes in full. +use_magcache: False +magcache_thresh: 0.04 +magcache_K: 2 +retention_ratio: 0.2 +# Calibrated average magnitude ratios, interleaved [cond, uncond, ...], length +# num_inference_steps*2 (= 80 for 40 steps). Single curve spanning both phases; +# the dip near the middle marks the high->low boundary. Seeded from the official +# WAN 2.2 T2V ratios — recalibrate for this setup to tune the speedup/quality. +mag_ratios_base: [1.0, 1.0, 1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181] + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 40 @@ -422,4 +440,4 @@ enable_ssim: False # ML Diagnostics settings enable_ml_diagnostics: False profiler_gcs_path: "" -enable_ondemand_xprof: False +enable_ondemand_xprof: False \ No newline at end of file diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index a30c4d63..d5eaf4d3 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -157,6 +157,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): use_cfg_cache=config.use_cfg_cache, use_sen_cache=config.use_sen_cache, use_kv_cache=config.use_kv_cache, + use_magcache=config.use_magcache, + magcache_thresh=config.magcache_thresh, + magcache_K=config.magcache_K, + retention_ratio=config.retention_ratio, ) else: raise ValueError(f"Unsupported model_name for T2V in config: {model_key}") @@ -384,4 +388,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": with transformer_engine_context(): - app.run(main) + app.run(main) \ No newline at end of file diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 00d11f96..077fbfb2 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache +from .wan_pipeline import ( + WanPipeline, + transformer_forward_pass, + transformer_forward_pass_full_cfg, + transformer_forward_pass_cfg_cache, + init_magcache, + magcache_step, +) from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional from ...pyconfig import HyperParameters @@ -118,13 +125,24 @@ def __call__( use_cfg_cache: bool = False, use_sen_cache: bool = False, use_kv_cache: bool = False, + use_magcache: bool = False, + magcache_thresh: float = 0.04, + magcache_K: int = 2, + retention_ratio: float = 0.2, ): config = getattr(self, "config", None) if max_sequence_length is None: max_sequence_length = getattr(config, "max_sequence_length", 512) - if use_cfg_cache and use_sen_cache: - raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.") + if sum([use_cfg_cache, use_sen_cache, use_magcache]) > 1: + raise ValueError("use_cfg_cache, use_sen_cache and use_magcache are mutually exclusive. Enable only one.") + + if use_magcache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): + raise ValueError( + f"use_magcache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " + f"(got {guidance_scale_low}, {guidance_scale_high}). " + "MagCache reuses the cached residual across the doubled CFG batch, which must be enabled for both phases." + ) if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): raise ValueError( @@ -183,6 +201,10 @@ def __call__( use_sen_cache=use_sen_cache, height=height, use_kv_cache=use_kv_cache, + use_magcache=use_magcache, + magcache_thresh=magcache_thresh, + magcache_K=magcache_K, + retention_ratio=retention_ratio, ) t_denoise_start = time.perf_counter() @@ -233,6 +255,10 @@ def run_inference_2_2( height: int = 480, config=None, use_kv_cache: bool = False, + use_magcache: bool = False, + magcache_thresh: float = 0.04, + magcache_K: int = 2, + retention_ratio: float = 0.2, ): """Denoising loop for WAN 2.2 T2V with optional caching acceleration. @@ -279,6 +305,109 @@ def run_inference_2_2( high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined) + # ── MagCache path (Ma et al., https://github.com/Zehong-Ma/MagCache) ── + # Skips the transformer blocks on steps whose accumulated magnitude-ratio error + # stays below `magcache_thresh`, reusing the cached block residual instead. + # The skip schedule is fully static (the ratios are calibration constants), so + # the decision is made host-side and only a static `skip_blocks` bool crosses + # into the forward pass. Dual-transformer handling: one calibrated curve spans + # both phases (the official `mag_ratios` layout), with a forced-compute zone at + # the start of each phase and an explicit cache reset at the high→low boundary. + if use_magcache and do_classifier_free_guidance: + timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) + step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] + high_noise_steps = sum(step_uses_high) + + mag_ratios_base = getattr(config, "mag_ratios_base", None) if config else None + if mag_ratios_base is None: + raise ValueError( + "use_magcache=True requires config.mag_ratios_base — the calibrated magnitude ratios " + "(interleaved cond/uncond, length num_inference_steps*2). Run the calibration pass or " + "use the published WAN 2.2 ratios." + ) + + # Single state + single ratio curve spanning both phases (official layout). + magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) + accumulated_state = magcache_init[:6] + cached_residual = magcache_init[6] + mag_ratios = magcache_init[8] + + # Forced-compute ("retention") zones, in step units: the first + # `retention_ratio` fraction of each phase always computes in full. This is + # also what flushes the stale residual right after the boundary. + high_warmup_end = int(high_noise_steps * retention_ratio) + low_warmup_end = high_noise_steps + int((num_inference_steps - high_noise_steps) * retention_ratio) + + cache_count = 0 + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + + # Select transformer + guidance for this phase. + if step_uses_high[step]: + graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high + else: + graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + encoder_attention_mask = encoder_attention_mask_low + + # Boundary reset: the high-noise transformer's residual is meaningless to + # the low-noise transformer, so start the low phase with a fresh cache. + is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1] + if is_boundary: + cached_residual = None + accumulated_state = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)[:6] + + # Force a full compute inside either phase's warmup zone, at the boundary, + # or until we have a residual to reuse. + in_warmup = step < high_warmup_end or (high_noise_steps <= step < low_warmup_end) + force_compute = in_warmup or is_boundary or cached_residual is None + + skip_blocks, accumulated_state = magcache_step( + step, + mag_ratios, + accumulated_state, + magcache_thresh, + magcache_K, + use_magcache=(not force_compute), + ) + + latents_doubled = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, _, residual_x_cur = transformer_forward_pass( + graphdef, + state, + rest, + latents_doubled, + timestep, + prompt_embeds_combined, + do_classifier_free_guidance=True, + guidance_scale=guidance_scale, + skip_blocks=bool(skip_blocks), + cached_residual=cached_residual, + return_residual=True, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + + if skip_blocks: + cache_count += 1 + else: + cached_residual = residual_x_cur + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + max_logging.log( + f"[MagCache] Cached {cache_count}/{num_inference_steps} steps " + f"({100*cache_count/num_inference_steps:.1f}% cache ratio), " + f"high_noise_steps={high_noise_steps}, thresh={magcache_thresh}, K={magcache_K}" + ) + return latents + # ── SenCache path (arXiv:2602.24208) ── if use_sen_cache and do_classifier_free_guidance: timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) @@ -711,4 +840,4 @@ def scan_body(carry, t): if profiler: latents.block_until_ready() profiler.stop() - return latents + return latents \ No newline at end of file diff --git a/src/maxdiffusion/tests/wan/wan2_2_magcache_test.py b/src/maxdiffusion/tests/wan/wan2_2_magcache_test.py new file mode 100644 index 00000000..954e0f87 --- /dev/null +++ b/src/maxdiffusion/tests/wan/wan2_2_magcache_test.py @@ -0,0 +1,389 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import time +import unittest + +import numpy as np +import pytest +from absl.testing import absltest + +from maxdiffusion.pipelines.wan.wan_pipeline import init_magcache, magcache_step +from maxdiffusion.pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +class WanMagCacheValidationTest(unittest.TestCase): + """Tests that use_magcache validation raises correct errors.""" + + def _make_pipeline(self): + return WanPipeline2_2.__new__(WanPipeline2_2) + + def test_magcache_with_both_scales_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline(prompt=["test"], guidance_scale_low=1.0, guidance_scale_high=1.0, use_magcache=True) + self.assertIn("use_magcache", str(ctx.exception)) + + def test_magcache_with_low_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline(prompt=["test"], guidance_scale_low=0.5, guidance_scale_high=4.0, use_magcache=True) + self.assertIn("use_magcache", str(ctx.exception)) + + def test_magcache_with_high_scale_low_raises(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline(prompt=["test"], guidance_scale_low=3.0, guidance_scale_high=1.0, use_magcache=True) + self.assertIn("use_magcache", str(ctx.exception)) + + def test_magcache_mutually_exclusive_with_cfg_cache(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale_low=3.0, + guidance_scale_high=4.0, + use_cfg_cache=True, + use_magcache=True, + ) + self.assertIn("mutually exclusive", str(ctx.exception)) + + def test_magcache_mutually_exclusive_with_sen_cache(self): + pipeline = self._make_pipeline() + with self.assertRaises(ValueError) as ctx: + pipeline( + prompt=["test"], + guidance_scale_low=3.0, + guidance_scale_high=4.0, + use_sen_cache=True, + use_magcache=True, + ) + self.assertIn("mutually exclusive", str(ctx.exception)) + + def test_magcache_with_valid_scales_no_validation_error(self): + """Both guidance_scales > 1.0 should pass validation (may fail later without model).""" + pipeline = self._make_pipeline() + try: + pipeline(prompt=["test"], guidance_scale_low=3.0, guidance_scale_high=4.0, use_magcache=True) + except ValueError as e: + if "use_magcache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + def test_no_magcache_with_low_scales_no_error(self): + """use_magcache=False should never raise our ValueError.""" + pipeline = self._make_pipeline() + try: + pipeline(prompt=["test"], guidance_scale_low=0.5, guidance_scale_high=0.5, use_magcache=False) + except ValueError as e: + if "use_magcache" in str(e): + self.fail(f"Unexpected validation error: {e}") + except Exception: + pass + + +class WanMagCacheScheduleTest(unittest.TestCase): + """Tests the MagCache dual-phase schedule (retention zones + boundary reset). + + Mirrors the deterministic, host-side schedule in run_inference_2_2's MagCache + branch: forced-compute (retention) zone at the start of EACH phase, an explicit + cache reset at the high->low boundary, and global-step indexing into mag_ratios. + The actual skip decision (magcache_step) is exercised separately below. + """ + + def _get_magcache_schedule(self, num_inference_steps, retention_ratio=0.2, boundary_ratio=0.875, num_train_timesteps=1000): + boundary = boundary_ratio * num_train_timesteps + timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps, dtype=np.int32) + step_uses_high = [bool(timesteps[s] >= boundary) for s in range(num_inference_steps)] + high_noise_steps = sum(step_uses_high) + + high_warmup_end = int(high_noise_steps * retention_ratio) + low_warmup_end = high_noise_steps + int((num_inference_steps - high_noise_steps) * retention_ratio) + + force_compute, is_boundary_list = [], [] + cached_residual_is_none = True # no residual until the first compute + for step in range(num_inference_steps): + is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1] + if is_boundary: + cached_residual_is_none = True # boundary reset + in_warmup = step < high_warmup_end or (high_noise_steps <= step < low_warmup_end) + forced = in_warmup or is_boundary or cached_residual_is_none + force_compute.append(forced) + is_boundary_list.append(is_boundary) + cached_residual_is_none = False # a residual exists after any step (computed or reused) + + return { + "step_uses_high": step_uses_high, + "high_noise_steps": high_noise_steps, + "high_warmup_end": high_warmup_end, + "low_warmup_end": low_warmup_end, + "force_compute": force_compute, + "is_boundary": is_boundary_list, + } + + def test_first_step_always_forced(self): + """Step 0 has no cached residual yet, so it must compute.""" + sched = self._get_magcache_schedule(40) + self.assertTrue(sched["force_compute"][0]) + + def test_high_phase_warmup_forced(self): + """The first retention_ratio fraction of the high-noise phase is forced.""" + sched = self._get_magcache_schedule(40) + self.assertTrue(all(sched["force_compute"][: sched["high_warmup_end"]])) + + def test_low_phase_warmup_forced(self): + """The first retention_ratio fraction of the low-noise phase (post-boundary) is forced.""" + sched = self._get_magcache_schedule(40) + high, low_end = sched["high_noise_steps"], sched["low_warmup_end"] + self.assertTrue(all(sched["force_compute"][high:low_end]), "Low-phase warmup zone must be forced") + + def test_boundary_is_forced(self): + """Every high<->low transition step must compute (residual reset).""" + sched = self._get_magcache_schedule(40) + suh = sched["step_uses_high"] + for s in range(1, 40): + if suh[s] != suh[s - 1]: + self.assertTrue(sched["is_boundary"][s], f"step {s} should be a boundary") + self.assertTrue(sched["force_compute"][s], f"Boundary step {s} must be forced") + + def test_exactly_one_boundary(self): + """A monotone timestep schedule crosses the boundary exactly once.""" + sched = self._get_magcache_schedule(40) + self.assertEqual(sum(sched["is_boundary"]), 1) + + def test_cacheable_window_exists(self): + """With enough steps, some steps are eligible to skip (not forced).""" + sched = self._get_magcache_schedule(40) + self.assertGreater(sum(not f for f in sched["force_compute"]), 0, "Expected some cacheable steps") + + def test_warmup_zones_scale_with_steps(self): + for n in [20, 40, 80]: + sched = self._get_magcache_schedule(n) + self.assertTrue(all(sched["force_compute"][: sched["high_warmup_end"]]), f"high warmup forced @ {n}") + self.assertTrue( + all(sched["force_compute"][sched["high_noise_steps"] : sched["low_warmup_end"]]), + f"low warmup forced @ {n}", + ) + + def test_global_step_indexing_in_bounds(self): + """mag_ratios is indexed by GLOBAL step as [2*step] (cond) / [2*step+1] (uncond).""" + n = 40 + mag_ratios = np.ones(2 * n) + for step in range(n): + _ = mag_ratios[step * 2] + _ = mag_ratios[step * 2 + 1] # must not raise + self.assertEqual(len(mag_ratios), 2 * n) + + +class WanMagCacheCoreTest(unittest.TestCase): + """Pure-host tests for init_magcache / magcache_step (no TPU, CI-safe). + + These confirm the skip schedule is deterministic given constant mag_ratios, + which is the property that lets MagCache live in the host-side denoise loop. + """ + + def test_init_passthrough_when_double_length(self): + """A curve already of length 2*steps is used verbatim (no interpolation).""" + n = 5 + base = list(np.linspace(1.0, 0.9, 2 * n)) + out = init_magcache(n, 0.2, base) + mag_ratios = out[8] + self.assertEqual(len(mag_ratios), 2 * n) + np.testing.assert_allclose(mag_ratios, np.array(base)) + + def test_init_interpolates_when_mismatched(self): + """A shorter curve is nearest-interpolated up to length 2*steps.""" + n = 40 + base = list(np.linspace(1.0, 0.8, 2 * 20)) # 20-step curve, run is 40 steps + out = init_magcache(n, 0.2, base) + self.assertEqual(len(out[8]), 2 * n) + + def test_init_skip_warmup(self): + self.assertEqual(init_magcache(40, 0.2, list(np.ones(80)))[7], int(40 * 0.2)) + + def test_disabled_never_skips(self): + """use_magcache=False forces a compute and leaves accumulators untouched.""" + n = 10 + mag_ratios = np.ones(2 * n) + state = init_magcache(n, 0.2, list(mag_ratios))[:6] + skip, new_state = magcache_step(3, mag_ratios, state, magcache_thresh=1.0, magcache_K=99, use_magcache=False) + self.assertFalse(skip) + self.assertEqual(new_state, state) + + def test_skips_when_under_threshold(self): + """Ratios ~1.0 with a generous threshold should skip.""" + n = 10 + mag_ratios = np.ones(2 * n) + state = init_magcache(n, 0.2, list(mag_ratios))[:6] + skip, _ = magcache_step(3, mag_ratios, state, magcache_thresh=0.5, magcache_K=99, use_magcache=True) + self.assertTrue(skip) + + def test_resets_when_over_threshold(self): + """A ratio far from 1.0 exceeds the error budget -> no skip + accumulators reset.""" + n = 10 + mag_ratios = np.full(2 * n, 0.5) # err = |1 - 0.5| = 0.5 per step + state = init_magcache(n, 0.2, list(mag_ratios))[:6] + skip, new_state = magcache_step(3, mag_ratios, state, magcache_thresh=0.04, magcache_K=99, use_magcache=True) + self.assertFalse(skip) + self.assertEqual(new_state, (1.0, 1.0, 0.0, 0.0, 0, 0)) + + def test_K_caps_consecutive_skips(self): + """Even with err=0 (ratio 1.0), no more than magcache_K consecutive skips.""" + n = 20 + K = 2 + mag_ratios = np.ones(2 * n) + state = init_magcache(n, 0.0, list(mag_ratios))[:6] + consecutive = 0 + max_consecutive = 0 + for step in range(n): + skip, state = magcache_step(step, mag_ratios, state, magcache_thresh=1.0, magcache_K=K, use_magcache=True) + if skip: + consecutive += 1 + max_consecutive = max(max_consecutive, consecutive) + else: + consecutive = 0 + self.assertLessEqual(max_consecutive, K, f"skipped {max_consecutive} in a row, K={K}") + + def test_requires_both_cond_and_uncond_under_threshold(self): + """If the uncond branch blows the budget, the step is not skipped even if cond is fine.""" + n = 10 + mag_ratios = np.ones(2 * n) + mag_ratios[3 * 2 + 1] = 0.5 # uncond at step 3 is far from 1.0 + state = init_magcache(n, 0.2, list(mag_ratios))[:6] + skip, _ = magcache_step(3, mag_ratios, state, magcache_thresh=0.04, magcache_K=99, use_magcache=True) + self.assertFalse(skip) + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Requires TPU v7-8 and model weights") +class WanMagCacheSmokeTest(unittest.TestCase): + """End-to-end smoke test: MagCache should be faster with PSNR >= 30 dB, SSIM >= 0.95. + + Runs on TPU (WAN 2.2 27B T2V, 720p). Skipped in CI (GitHub Actions) — run with: + python -m pytest src/maxdiffusion/tests/wan/wan_mag_cache_test.py::WanMagCacheSmokeTest -v + """ + + @classmethod + def setUpClass(cls): + from maxdiffusion import pyconfig + from maxdiffusion.checkpointing.wan_checkpointer_2_2 import WanCheckpointer2_2 + + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "..", "configs", "base_wan_27b.yml"), + "num_inference_steps=40", + "height=720", + "width=1280", + "num_frames=81", + "fps=16", + "guidance_scale_low=3.0", + "guidance_scale_high=4.0", + "boundary_ratio=0.875", + "flow_shift=12.0", + "seed=118445", + "attention=flash", + "remat_policy=FULL", + "allow_split_physical_axes=True", + "skip_jax_distributed_system=True", + "weights_dtype=bfloat16", + "activations_dtype=bfloat16", + "per_device_batch_size=0.125", + "ici_data_parallelism=1", + "ici_fsdp_parallelism=1", + "ici_context_parallelism=8", + "ici_tensor_parallelism=1", + "flash_min_seq_length=0", + ], + unittest=True, + ) + cls.config = pyconfig.config + checkpoint_loader = WanCheckpointer2_2(config=cls.config) + cls.pipeline, _, _ = checkpoint_loader.load_checkpoint() + + cls.prompt = [cls.config.prompt] * cls.config.global_batch_size_to_train_on + cls.negative_prompt = [cls.config.negative_prompt] * cls.config.global_batch_size_to_train_on + + for use_cache in [False, True]: # warm up both XLA code paths + cls.pipeline( + prompt=cls.prompt, + negative_prompt=cls.negative_prompt, + height=cls.config.height, + width=cls.config.width, + num_frames=cls.config.num_frames, + num_inference_steps=cls.config.num_inference_steps, + guidance_scale_low=cls.config.guidance_scale_low, + guidance_scale_high=cls.config.guidance_scale_high, + use_magcache=use_cache, + ) + + def _run_pipeline(self, use_magcache): + t0 = time.perf_counter() + videos = self.pipeline( + prompt=self.prompt, + negative_prompt=self.negative_prompt, + height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_inference_steps=self.config.num_inference_steps, + guidance_scale_low=self.config.guidance_scale_low, + guidance_scale_high=self.config.guidance_scale_high, + use_magcache=use_magcache, + ) + return videos, time.perf_counter() - t0 + + def test_magcache_speedup_and_fidelity(self): + videos_baseline, t_baseline = self._run_pipeline(use_magcache=False) + videos_cached, t_cached = self._run_pipeline(use_magcache=True) + + speedup = t_baseline / t_cached + print(f"Baseline: {t_baseline:.2f}s, MagCache: {t_cached:.2f}s, Speedup: {speedup:.3f}x") + self.assertGreater(speedup, 1.0, f"MagCache should be faster. Speedup={speedup:.3f}x") + + v1 = np.array(videos_baseline[0], dtype=np.float64) + v2 = np.array(videos_cached[0], dtype=np.float64) + + mse = np.mean((v1 - v2) ** 2) + psnr = 10.0 * np.log10(1.0 / mse) if mse > 0 else float("inf") + print(f"PSNR: {psnr:.2f} dB") + self.assertGreaterEqual(psnr, 30.0, f"PSNR={psnr:.2f} dB < 30 dB") + + C1, C2 = 0.01**2, 0.03**2 + ssim_scores = [] + for f in range(v1.shape[0]): + mu1, mu2 = np.mean(v1[f]), np.mean(v2[f]) + sigma1_sq, sigma2_sq = np.var(v1[f]), np.var(v2[f]) + sigma12 = np.mean((v1[f] - mu1) * (v2[f] - mu2)) + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)) + ssim_scores.append(float(ssim)) + mean_ssim = np.mean(ssim_scores) + print(f"SSIM: mean={mean_ssim:.4f}, min={np.min(ssim_scores):.4f}") + self.assertGreaterEqual(mean_ssim, 0.95, f"Mean SSIM={mean_ssim:.4f} < 0.95") + + @classmethod + def tearDownClass(cls): + del cls.pipeline + import gc + + gc.collect() + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 007e041ddc9a077b6be7274ca2a3b8825a17a0ea Mon Sep 17 00:00:00 2001 From: HadarIngonyama Date: Mon, 29 Jun 2026 21:49:17 +0300 Subject: [PATCH 2/4] Add MagCache support for Wan2.2 I2V - wan_pipeline_i2v_2p2.py: MagCache skip path for the dual-transformer I2V pipeline, mirroring the T2V 2.2 logic (per-phase retention/forced-compute zones, residual reset at the high->low boundary, single interleaved mag_ratios_base curve) with I2V-specific handling for the image condition (concat with latents + BFHWC<->BCFHW transposes) - generate_wan.py: pass use_magcache / magcache_thresh / magcache_K / retention_ratio through to the 2.2 I2V pipeline - base_wan_i2v_27b.yml: MagCache params + the official I2V-A14B mag_ratios_base, and boundary_ratio=0.900 to align the high->low switch with the curve (flow_shift stays at the I2V default of 5.0) - README: document MagCache for Wan2.2 I2V (settings + ~1.75x speedup, SSIM/PSNR vs dense) and add it to the caching support matrix --- README.md | 28 ++-- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 22 ++- src/maxdiffusion/generate_wan.py | 4 + .../pipelines/wan/wan_pipeline_i2v_2p2.py | 149 +++++++++++++++++- 4 files changed, 190 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 63d57960..605a1533 100755 --- a/README.md +++ b/README.md @@ -607,7 +607,7 @@ To generate images, run the following command: | --- | --- | --- | --- | --- | | **CFG Cache** | `use_cfg_cache: True` | Wan 2.1 T2V, Wan 2.2 T2V/I2V | ~1.2x | FasterCache-style: caches the unconditional branch and applies FFT frequency-domain compensation on skipped steps. | | **SenCache** | `use_sen_cache: True` | Wan 2.2 T2V/I2V | ~1.4x | Sensitivity-Aware Caching ([arXiv:2602.24208](https://arxiv.org/abs/2602.24208)): predicts output change via first-order sensitivity S = α_x·‖Δx‖ + α_t·\|Δt\|. Skips the full CFG forward pass when predicted change is below tolerance ε. | - | **MagCache** | `use_magcache: True` | Wan 2.1 T2V, Wan 2.2 T2V | ~1.8–1.9x | [MagCache](https://github.com/Zehong-Ma/MagCache): skips the transformer blocks and reuses the cached block residual when the accumulated magnitude-ratio error stays below `magcache_thresh`, capped at `magcache_K` consecutive skips. Uses a precalibrated per-step `mag_ratios_base` curve, so the skip schedule is deterministic (no data-dependent control flow). | + | **MagCache** | `use_magcache: True` | Wan 2.1 T2V, Wan 2.2 T2V/I2V | ~1.75–1.9x | [MagCache](https://github.com/Zehong-Ma/MagCache): skips the transformer blocks and reuses the cached block residual when the accumulated magnitude-ratio error stays below `magcache_thresh`, capped at `magcache_K` consecutive skips. Uses a precalibrated per-step `mag_ratios_base` curve, so the skip schedule is deterministic (no data-dependent control flow). | For Wan 2.2 (dual-transformer), MagCache keeps a single `mag_ratios_base` curve spanning both the high-noise and low-noise phases, forces a full recompute for the first `retention_ratio` fraction of each phase, and resets the cached residual at the high→low boundary. The shipped `mag_ratios_base` in `base_wan_27b.yml` are seeded from the official Wan2.2 values; recalibrating them for your exact setup (model dtype / attention kernel) improves the speedup/quality trade-off. @@ -615,6 +615,8 @@ To generate images, run the following command: Measured on a v7x (Wan 2.2 A14B T2V, 720×1280, 81 frames, 40 steps, `flow_shift=12.0`, seeded ratios at `magcache_thresh=0.04`, `magcache_K=2`): **~1.82× speedup** (18/40 steps skipped, denoise 360s → 198s) at **SSIM ≈ 0.72 / PSNR ≈ 21.8 dB** versus the dense (`use_magcache=False`) render with the same seed/config. These reference-based metrics mostly reflect *trajectory divergence* — caching nudges the sampler onto a different but equally plausible sample — rather than visible degradation; the cached clips are visually hard to tell apart from dense. Recalibrating `mag_ratios_base` for your exact dtype/attention kernel tightens the metric gap further. +Wan 2.2 I2V (`base_wan_i2v_27b.yml`) uses the official I2V-A14B `mag_ratios_base` together with the I2V sampling settings (`flow_shift=5.0`, `boundary_ratio=0.900`) that the curve is aligned to. Measured on a v7x (Wan 2.2 A14B I2V, 720×1280, 81 frames, 40 steps, `magcache_thresh=0.06`, `magcache_K=2`): **~1.75× speedup** (17/40 steps skipped, 6.30s → 3.61s per step) at **SSIM ≈ 0.91 / PSNR ≈ 25.4 dB** versus the dense render. Fidelity is higher than T2V because the image conditioning anchors the sampling trajectory, leaving less room for divergence. + To enable a caching mechanism, set the corresponding flag in your config YAML or pass it as a command-line override: ```bash @@ -630,14 +632,22 @@ To generate images, run the following command: use_cfg_cache=True \ ... - # Example: enable MagCache for Wan 2.2 T2V - python src/maxdiffusion/generate_wan.py \ - src/maxdiffusion/configs/base_wan_27b.yml \ - use_magcache=True \ - magcache_thresh=0.04 \ - magcache_K=2 \ - ... - ``` +# Example: enable MagCache for Wan 2.2 T2V +python src/maxdiffusion/generate_wan.py \ + src/maxdiffusion/configs/base_wan_27b.yml \ + use_magcache=True \ + magcache_thresh=0.04 \ + magcache_K=2 \ + ... + +# Example: enable MagCache for Wan 2.2 I2V +python src/maxdiffusion/generate_wan.py \ + src/maxdiffusion/configs/base_wan_i2v_27b.yml \ + use_magcache=True \ + magcache_thresh=0.06 \ + magcache_K=2 \ + ... +``` ### Ring Attention We added ring attention support for Wan models. Below are the stats for one `720p` (81 frames) video generation (with CFG DP): diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index be7cfaef..5f2e8c88 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -337,7 +337,7 @@ width: 1280 num_frames: 81 flow_shift: 5.0 -# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py +# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_i2v_A14B.py # guidance scale factor for low noise transformer guidance_scale_low: 3.0 @@ -346,8 +346,10 @@ guidance_scale_high: 4.0 # The timestep threshold. If `t` is at or above this value, # the `high_noise_model` is considered as the required model. -# timestep to switch between low noise and high noise transformer -boundary_ratio: 0.875 +# timestep to switch between low noise and high noise transformer. +# Official Wan2.2 I2V-A14B uses boundary=0.900 (vs 0.875 for T2V); this sets +# the high->low expert switch the MagCache ratios below are calibrated against. +boundary_ratio: 0.900 # Diffusion CFG cache (FasterCache-style) use_cfg_cache: False @@ -359,6 +361,20 @@ use_kv_cache: False # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) use_sen_cache: False +# ── MagCache (Ma et al., https://github.com/Zehong-Ma/MagCache) ── +# Skips the transformer blocks on steps whose accumulated magnitude-ratio error +# stays below `magcache_thresh`, reusing the cached block residual. Official +# Wan2.2 I2V recommendation: thresh 0.06, K 2, retention 0.1-0.2. +use_magcache: False +magcache_thresh: 0.06 +magcache_K: 2 +retention_ratio: 0.2 +# Average magnitude ratios, interleaved [cond, uncond, ...], length +# num_inference_steps*2 (= 80 for 40 steps). Single curve spanning both phases; +# the dip near step 15 marks the high->low boundary (boundary=0.900, flow_shift=5). +# Official Wan2.2 I2V-A14B ratios (upstream MagCache, WanI2V branch). +mag_ratios_base: [1.0, 1.0, 0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902] + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 40 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index d5eaf4d3..ebf7b3c9 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -124,6 +124,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): use_cfg_cache=config.use_cfg_cache, use_sen_cache=config.use_sen_cache, use_kv_cache=config.use_kv_cache, + use_magcache=config.use_magcache, + magcache_thresh=config.magcache_thresh, + magcache_K=config.magcache_K, + retention_ratio=config.retention_ratio, ) else: raise ValueError(f"Unsupported model_name for I2V in config: {model_key}") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index f071c231..865a0467 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -14,7 +14,14 @@ from maxdiffusion.image_processor import PipelineImageInput from maxdiffusion import max_logging -from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache +from .wan_pipeline import ( + WanPipeline, + transformer_forward_pass, + transformer_forward_pass_full_cfg, + transformer_forward_pass_cfg_cache, + init_magcache, + magcache_step, +) from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional, Tuple from ...pyconfig import HyperParameters @@ -188,14 +195,35 @@ def __call__( use_cfg_cache: bool = False, use_sen_cache: bool = False, use_kv_cache: bool = False, + use_magcache: bool = False, + magcache_thresh: Optional[float] = None, + magcache_K: Optional[int] = None, + retention_ratio: Optional[float] = None, ): config = getattr(self, "config", None) if max_sequence_length is None: max_sequence_length = getattr(config, "max_sequence_length", 512) + if magcache_thresh is None: + magcache_thresh = getattr(config, "magcache_thresh", 0.06) + if magcache_K is None: + magcache_K = getattr(config, "magcache_K", 2) + if retention_ratio is None: + retention_ratio = getattr(config, "retention_ratio", 0.2) + + if sum([use_cfg_cache, use_sen_cache, use_magcache]) > 1: + raise ValueError("use_cfg_cache, use_sen_cache and use_magcache are mutually exclusive. Enable only one.") + if use_cfg_cache and use_sen_cache: raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.") + if use_magcache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): + raise ValueError( + f"use_magcache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " + f"(got {guidance_scale_low}, {guidance_scale_high}). " + "MagCache requires classifier-free guidance to be enabled for both transformer phases." + ) + if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): raise ValueError( f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " @@ -309,6 +337,10 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): image_embeds=image_embeds, use_cfg_cache=use_cfg_cache, use_sen_cache=use_sen_cache, + use_magcache=use_magcache, + magcache_thresh=magcache_thresh, + magcache_K=magcache_K, + retention_ratio=retention_ratio, height=height, config=self.config, use_kv_cache=use_kv_cache, @@ -366,6 +398,10 @@ def run_inference_2_2_i2v( scheduler_state, use_cfg_cache: bool = False, use_sen_cache: bool = False, + use_magcache: bool = False, + magcache_thresh: float = 0.06, + magcache_K: int = 2, + retention_ratio: float = 0.2, height: int = 480, config=None, use_kv_cache: bool = False, @@ -404,6 +440,117 @@ def run_inference_2_2_i2v( prompt_embeds_combined, image_embeds_combined ) + # ── MagCache path (Ma et al., https://github.com/Zehong-Ma/MagCache) ── + # Skips the transformer blocks on steps whose accumulated magnitude-ratio error + # stays below `magcache_thresh`, reusing the cached block residual instead. The + # skip schedule is fully static (the ratios are calibration constants), so the + # decision is made host-side and only a static `skip_blocks` bool crosses into + # the forward pass. Dual-transformer handling: one calibrated curve spans both + # phases (official `mag_ratios` layout), with a forced-compute zone at the start + # of each phase and an explicit cache reset at the high->low boundary. + if use_magcache and do_classifier_free_guidance: + timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) + step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] + high_noise_steps = sum(step_uses_high) + + mag_ratios_base = getattr(config, "mag_ratios_base", None) if config else None + if mag_ratios_base is None: + raise ValueError( + "use_magcache=True requires config.mag_ratios_base — the magnitude ratios " + "(interleaved cond/uncond, length num_inference_steps*2). Run the calibration " + "pass or use the published WAN 2.2 I2V ratios." + ) + + # Single state + single ratio curve spanning both phases (official layout). + magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) + accumulated_state = magcache_init[:6] + cached_residual = magcache_init[6] + mag_ratios = magcache_init[8] + + # Forced-compute ("retention") zones, in step units: the first + # `retention_ratio` fraction of each phase always computes in full. This also + # flushes the stale residual right after the boundary. + high_warmup_end = int(high_noise_steps * retention_ratio) + low_warmup_end = high_noise_steps + int((num_inference_steps - high_noise_steps) * retention_ratio) + + condition_doubled = jnp.concatenate([condition] * 2) + + cache_count = 0 + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + + # Select transformer + guidance for this phase. + if step_uses_high[step]: + graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high + else: + graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + encoder_attention_mask = encoder_attention_mask_low + + # Boundary reset: the high-noise transformer's residual is meaningless to + # the low-noise transformer, so start the low phase with a fresh cache. + is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1] + if is_boundary: + cached_residual = None + accumulated_state = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)[:6] + + # Force a full compute inside either phase's warmup zone, at the boundary, + # or until we have a residual to reuse. + in_warmup = step < high_warmup_end or (high_noise_steps <= step < low_warmup_end) + force_compute = in_warmup or is_boundary or cached_residual is None + + skip_blocks, accumulated_state = magcache_step( + step, + mag_ratios, + accumulated_state, + magcache_thresh, + magcache_K, + use_magcache=(not force_compute), + ) + + # I2V input: concat the image-conditioning latents on the channel axis, then + # BFHWC -> BCFHW for the transformer (mirrors the dense/SenCache paths). + latents_doubled = jnp.concatenate([latents, latents], axis=0) + latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1) + latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, _, residual_x_cur = transformer_forward_pass( + graphdef, + state, + rest, + latent_model_input, + timestep, + prompt_embeds_combined, + do_classifier_free_guidance=True, + guidance_scale=guidance_scale, + encoder_hidden_states_image=image_embeds_combined, + skip_blocks=bool(skip_blocks), + cached_residual=cached_residual, + return_residual=True, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) + + if skip_blocks: + cache_count += 1 + else: + cached_residual = residual_x_cur + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + max_logging.log( + f"[MagCache] Cached {cache_count}/{num_inference_steps} steps " + f"({100*cache_count/num_inference_steps:.1f}% cache ratio), " + f"high_noise_steps={high_noise_steps}, thresh={magcache_thresh}, K={magcache_K}" + ) + return latents + # ── SenCache path (arXiv:2602.24208) ── if use_sen_cache and do_classifier_free_guidance: timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) From 56621c663ff1b9a01f0ec2b6ee998513cc3358ef Mon Sep 17 00:00:00 2001 From: HadarIngonyama Date: Mon, 29 Jun 2026 21:58:35 +0300 Subject: [PATCH 3/4] fix linting --- src/maxdiffusion/generate_wan.py | 2 +- src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py | 2 +- src/maxdiffusion/tests/wan/wan2_2_magcache_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index ebf7b3c9..dd3e5a71 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -392,4 +392,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": with transformer_engine_context(): - app.run(main) \ No newline at end of file + app.run(main) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 077fbfb2..6fb30279 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -840,4 +840,4 @@ def scan_body(carry, t): if profiler: latents.block_until_ready() profiler.stop() - return latents \ No newline at end of file + return latents diff --git a/src/maxdiffusion/tests/wan/wan2_2_magcache_test.py b/src/maxdiffusion/tests/wan/wan2_2_magcache_test.py index 954e0f87..31034889 100644 --- a/src/maxdiffusion/tests/wan/wan2_2_magcache_test.py +++ b/src/maxdiffusion/tests/wan/wan2_2_magcache_test.py @@ -386,4 +386,4 @@ def tearDownClass(cls): if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 944f130e98be55ef4e4083e24612e57dbac860c5 Mon Sep 17 00:00:00 2001 From: HadarIngonyama Date: Mon, 29 Jun 2026 22:22:19 +0300 Subject: [PATCH 4/4] rebase --- src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index 865a0467..a056809c 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -214,9 +214,6 @@ def __call__( if sum([use_cfg_cache, use_sen_cache, use_magcache]) > 1: raise ValueError("use_cfg_cache, use_sen_cache and use_magcache are mutually exclusive. Enable only one.") - if use_cfg_cache and use_sen_cache: - raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.") - if use_magcache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): raise ValueError( f"use_magcache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "