29 changes: 27 additions & 2 deletions README.md
@@ -607,6 +607,15 @@ To generate images, run the following command:
| --- | --- | --- | --- | --- |
| **CFG Cache** | `use_cfg_cache: True` | Wan 2.1 T2V, Wan 2.2 T2V/I2V | ~1.2x | FasterCache-style: caches the unconditional branch and applies FFT frequency-domain compensation on skipped steps. |
| **SenCache** | `use_sen_cache: True` | Wan 2.2 T2V/I2V | ~1.4x | Sensitivity-Aware Caching ([arXiv:2602.24208](https://arxiv.org/abs/2602.24208)): predicts output change via first-order sensitivity S = α_x·‖Δx‖ + α_t·\|Δt\|. Skips the full CFG forward pass when predicted change is below tolerance ε. |
| **MagCache** | `use_magcache: True` | Wan 2.1 T2V, Wan 2.2 T2V/I2V | ~1.75–1.9x | [MagCache](https://github.com/Zehong-Ma/MagCache): skips the transformer blocks and reuses the cached block residual when the accumulated magnitude-ratio error stays below `magcache_thresh`, capped at `magcache_K` consecutive skips. Uses a precalibrated per-step `mag_ratios_base` curve, so the skip schedule is deterministic (no data-dependent control flow). |

For Wan 2.2 (dual-transformer), MagCache keeps a single `mag_ratios_base` curve spanning both the high-noise and low-noise phases, forces a full recompute for the first `retention_ratio` fraction of each phase, and resets the cached residual at the high→low boundary. The shipped `mag_ratios_base` in `base_wan_27b.yml` are seeded from the official Wan2.2 values; recalibrating them for your exact setup (model dtype / attention kernel) improves the speedup/quality trade-off.

> **Wan 2.2 T2V requires `flow_shift=12.0`** (the official A14B sampling shift; `base_wan_27b.yml` now defaults to it). `flow_shift` controls where the high→low noise boundary lands, which is the boundary the `mag_ratios_base` curve is calibrated against — a lower shift (e.g. the old `5.0`) moves the boundary several steps out of phase, so MagCache skips at the wrong steps and quality drops. This also corrects the off-spec dense baseline.
Measured on a v7x (Wan 2.2 A14B T2V, 720×1280, 81 frames, 40 steps, `flow_shift=12.0`, seeded ratios at `magcache_thresh=0.04`, `magcache_K=2`): **~1.82× speedup** (18/40 steps skipped, denoise 360s → 198s) at **SSIM ≈ 0.72 / PSNR ≈ 21.8 dB** versus the dense (`use_magcache=False`) render with the same seed/config. These reference-based metrics mostly reflect *trajectory divergence* — caching nudges the sampler onto a different but equally plausible sample — rather than visible degradation; the cached clips are visually hard to tell apart from dense. Recalibrating `mag_ratios_base` for your exact dtype/attention kernel tightens the metric gap further.

Wan 2.2 I2V (`base_wan_i2v_27b.yml`) uses the official I2V-A14B `mag_ratios_base` together with the I2V sampling settings (`flow_shift=5.0`, `boundary_ratio=0.900`) that the curve is aligned to. Measured on a v7x (Wan 2.2 A14B I2V, 720×1280, 81 frames, 40 steps, `magcache_thresh=0.06`, `magcache_K=2`): **~1.75× speedup** (17/40 steps skipped, 6.30s → 3.61s per step) at **SSIM ≈ 0.91 / PSNR ≈ 25.4 dB** versus the dense render. Fidelity is higher than T2V because the image conditioning anchors the sampling trajectory, leaving less room for divergence.
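
As a rough illustration of the skip rule described in the MagCache row above, here is a minimal host-side sketch in plain Python. It follows the accumulate-and-threshold idea of the upstream MagCache reference code; it is not the `magcache_step` function from this PR, whose body is not shown in this diff. The name `plan_skips`, the single-phase setup, and the use of one ratio per step (instead of the interleaved cond/uncond pairs) are assumptions made for illustration only.

```python
def plan_skips(ratios, thresh, max_skips, retention_ratio):
    """Return one bool per step; True means 'reuse the cached residual'.

    Hypothetical helper: `ratios` holds one precalibrated magnitude ratio
    per step, so the whole plan is known before sampling starts.
    """
    num_steps = len(ratios)
    warmup_end = int(num_steps * retention_ratio)  # forced full-compute zone
    acc_ratio, acc_err, acc_skips = 1.0, 0.0, 0
    plan = []
    for step, ratio in enumerate(ratios):
        skip = False
        if step >= warmup_end:
            acc_ratio *= ratio                 # drift since the last full compute
            acc_skips += 1
            acc_err += abs(1.0 - acc_ratio)    # accumulated magnitude-ratio error
            skip = acc_err < thresh and acc_skips <= max_skips
        if not skip:
            # A full compute refreshes the residual, so the accumulators restart.
            acc_ratio, acc_err, acc_skips = 1.0, 0.0, 0
        plan.append(skip)
    return plan


# Made-up ratios, chosen only to exercise the rule (not calibrated values).
example_plan = plan_skips(
    [1.0, 1.0, 0.99, 0.99, 0.99, 0.99, 0.6, 0.99],
    thresh=0.04,
    max_skips=2,
    retention_ratio=0.25,
)
print(example_plan)
```

With these made-up ratios the plan is `[False, False, True, True, False, True, False, True]`: two warmup steps, two skips, a recompute once the accumulated error and the skip cap are exceeded, and another recompute at the sharp ratio drop. Because the inputs are all constants, the same plan is produced on every run.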
Comment on lines +610 to +618

Collaborator

I think we can reduce the text content here.

Minor nit: It is 7x and not v7x

Author

Will do. I will change it to 7x, but note that the whole README has this mistake (and it also exists in other places in the repo).

Collaborator

Thanks, we will put up a PR to fix those.


To enable a caching mechanism, set the corresponding flag in your config YAML or pass it as a command-line override:

@@ -622,7 +631,23 @@ To generate images, run the following command:
src/maxdiffusion/configs/base_wan_i2v_27b.yml \
use_cfg_cache=True \
...
```

# Example: enable MagCache for Wan 2.2 T2V
python src/maxdiffusion/generate_wan.py \
src/maxdiffusion/configs/base_wan_27b.yml \
use_magcache=True \
magcache_thresh=0.04 \
magcache_K=2 \
...

# Example: enable MagCache for Wan 2.2 I2V
python src/maxdiffusion/generate_wan.py \
src/maxdiffusion/configs/base_wan_i2v_27b.yml \
use_magcache=True \
magcache_thresh=0.06 \
magcache_K=2 \
...
```

### Ring Attention
We added ring attention support for Wan models. Below are the stats for one `720p` (81 frames) video generation (with CFG DP):
@@ -819,4 +844,4 @@ This script will automatically format your code with `pyink` and help you identi
The full suite of end-to-end tests is in `tests` and `src/maxdiffusion/tests`. We run them on a nightly cadence.
## Profiling
To learn how to enable ML Diagnostics and XProf profiling for your runs, please see our [ML Diagnostics Guide](docs/profiling.md).
22 changes: 20 additions & 2 deletions src/maxdiffusion/configs/base_wan_27b.yml
@@ -336,7 +336,11 @@ do_classifier_free_guidance: True
height: 720
width: 1280
num_frames: 81
flow_shift: 5.0
# Official Wan2.2 T2V A14B sampling shift (wan_t2v_A14B.py: sample_shift=12.0). This sets where the
# high->low transformer boundary falls in the step schedule (~step 26 of 40 at boundary_ratio 0.875),
# which must match the schedule the seeded mag_ratios_base was calibrated at. (A smaller shift moves
# the boundary earlier, which both degrades the base sample and misaligns MagCache's skip schedule.)
flow_shift: 12.0
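
To make the link between `flow_shift` and the boundary step concrete, the following sketch counts how many steps land in the high-noise phase. It assumes a linear sigma grid from 1 toward 0 and the usual flow-matching shift `shift * s / (1 + (shift - 1) * s)`; the scheduler in this repo may quantize timesteps or use slightly different endpoints, so treat the numbers as approximate. `high_noise_step_count` is a hypothetical helper, not a function in this codebase.

```python
def high_noise_step_count(num_steps, flow_shift, boundary_ratio):
    """Count steps whose shifted sigma is at or above the boundary."""
    count = 0
    for i in range(num_steps):
        sigma = 1.0 - i / num_steps  # assumed linear grid: 1.0, ..., 1/num_steps
        shifted = flow_shift * sigma / (1.0 + (flow_shift - 1.0) * sigma)
        if shifted >= boundary_ratio:
            count += 1
    return count


# T2V with the new default, T2V with the old default, and the I2V settings.
t2v_new = high_noise_step_count(40, 12.0, 0.875)
t2v_old = high_noise_step_count(40, 5.0, 0.875)
i2v = high_noise_step_count(40, 5.0, 0.900)
print(t2v_new, t2v_old, i2v)
```

Under these assumptions the boundary falls after 26 steps at `flow_shift=12.0` but after 17 steps at `flow_shift=5.0`, which is the misalignment the comment above warns about. The I2V settings give 15, in line with the dip position documented in `base_wan_i2v_27b.yml`.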

# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py
# guidance scale factor for low noise transformer
@@ -361,6 +365,20 @@ use_kv_cache: False
# when predicted output change (based on accumulated latent/timestep drift) is small
use_sen_cache: False

# MagCache (https://github.com/Zehong-Ma/MagCache) — skip transformer blocks when
# the accumulated magnitude-ratio error stays below `magcache_thresh`, reusing the
# cached block residual. `magcache_K` caps consecutive skips; `retention_ratio` is
# the fraction at the start of each phase that always computes in full.
use_magcache: False
magcache_thresh: 0.04
magcache_K: 2
retention_ratio: 0.2
# Calibrated average magnitude ratios, interleaved [cond, uncond, ...], length
# num_inference_steps*2 (= 80 for 40 steps). Single curve spanning both phases;
# the dip near the middle marks the high->low boundary. Seeded from the official
# WAN 2.2 T2V ratios — recalibrate for this setup to tune the speedup/quality.
mag_ratios_base: [1.0, 1.0, 1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
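
The interleaved layout can be checked with a few lines of Python before a run. This is a sketch under the assumptions stated in the comment above (`[cond, uncond, ...]`, length `2 * num_inference_steps`); `split_mag_ratios` and `largest_drop_step` are hypothetical helpers, and the short list below is made up for illustration rather than taken from a calibration.

```python
def split_mag_ratios(mag_ratios_base, num_inference_steps):
    """Validate the length and de-interleave into (cond, uncond) lists."""
    expected = 2 * num_inference_steps
    if len(mag_ratios_base) != expected:
        raise ValueError(
            f"mag_ratios_base has {len(mag_ratios_base)} entries, expected {expected}"
        )
    return mag_ratios_base[0::2], mag_ratios_base[1::2]


def largest_drop_step(ratios):
    """Return the step with the largest drop from its predecessor.

    For a curve spanning both phases this is a simple way to locate the
    dip that marks the high->low boundary.
    """
    drops = [ratios[i - 1] - ratios[i] for i in range(1, len(ratios))]
    return 1 + drops.index(max(drops))


# Made-up 6-step curve with a dip at step 3.
toy = [1.0, 1.0, 0.99, 0.99, 0.98, 0.98, 0.70, 0.71, 0.97, 0.97, 0.90, 0.90]
cond, uncond = split_mag_ratios(toy, num_inference_steps=6)
dip = largest_drop_step(cond)
print(cond, dip)
```

If the detected dip does not coincide with the first low-noise step of the schedule, the curve and the sampling settings are likely out of phase.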

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 40
@@ -422,4 +440,4 @@ enable_ssim: False
# ML Diagnostics settings
enable_ml_diagnostics: False
profiler_gcs_path: ""
enable_ondemand_xprof: False
22 changes: 19 additions & 3 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
@@ -337,7 +337,7 @@ width: 1280
num_frames: 81
flow_shift: 5.0

# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_t2v_A14B.py
# Reference for below guidance scale and boundary values: https://github.com/Wan-Video/Wan2.2/blob/main/wan/configs/wan_i2v_A14B.py
# guidance scale factor for low noise transformer
guidance_scale_low: 3.0

@@ -346,8 +346,10 @@ guidance_scale_high: 4.0

# The timestep threshold. If `t` is at or above this value,
# the `high_noise_model` is considered as the required model.
# timestep to switch between low noise and high noise transformer
boundary_ratio: 0.875
# timestep to switch between low noise and high noise transformer.
# Official Wan2.2 I2V-A14B uses boundary=0.900 (vs 0.875 for T2V); this sets
# the high->low expert switch the MagCache ratios below are calibrated against.
boundary_ratio: 0.900

# Diffusion CFG cache (FasterCache-style)
use_cfg_cache: False
@@ -359,6 +361,20 @@ use_kv_cache: False
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208)
use_sen_cache: False

# ── MagCache (Ma et al., https://github.com/Zehong-Ma/MagCache) ──
# Skips the transformer blocks on steps whose accumulated magnitude-ratio error
# stays below `magcache_thresh`, reusing the cached block residual. Official
# Wan2.2 I2V recommendation: thresh 0.06, K 2, retention 0.1-0.2.
use_magcache: False
magcache_thresh: 0.06
magcache_K: 2
retention_ratio: 0.2
# Average magnitude ratios, interleaved [cond, uncond, ...], length
# num_inference_steps*2 (= 80 for 40 steps). Single curve spanning both phases;
# the dip near step 15 marks the high->low boundary (boundary=0.900, flow_shift=5).
# Official Wan2.2 I2V-A14B ratios (upstream MagCache, WanI2V branch).
mag_ratios_base: [1.0, 1.0, 0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902]

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 40
8 changes: 8 additions & 0 deletions src/maxdiffusion/generate_wan.py
@@ -124,6 +124,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
use_cfg_cache=config.use_cfg_cache,
use_sen_cache=config.use_sen_cache,
use_kv_cache=config.use_kv_cache,
use_magcache=config.use_magcache,
magcache_thresh=config.magcache_thresh,
magcache_K=config.magcache_K,
retention_ratio=config.retention_ratio,
)
else:
raise ValueError(f"Unsupported model_name for I2V in config: {model_key}")
@@ -157,6 +161,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
use_cfg_cache=config.use_cfg_cache,
use_sen_cache=config.use_sen_cache,
use_kv_cache=config.use_kv_cache,
use_magcache=config.use_magcache,
magcache_thresh=config.magcache_thresh,
magcache_K=config.magcache_K,
retention_ratio=config.retention_ratio,
)
else:
raise ValueError(f"Unsupported model_name for T2V in config: {model_key}")
135 changes: 132 additions & 3 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
@@ -12,7 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
from .wan_pipeline import (
WanPipeline,
transformer_forward_pass,
transformer_forward_pass_full_cfg,
transformer_forward_pass_cfg_cache,
init_magcache,
magcache_step,
)
from ...models.wan.transformers.transformer_wan import WanModel
from typing import List, Union, Optional
from ...pyconfig import HyperParameters
@@ -118,13 +125,24 @@ def __call__(
use_cfg_cache: bool = False,
use_sen_cache: bool = False,
use_kv_cache: bool = False,
use_magcache: bool = False,
magcache_thresh: float = 0.04,
magcache_K: int = 2,
retention_ratio: float = 0.2,
):
config = getattr(self, "config", None)
if max_sequence_length is None:
max_sequence_length = getattr(config, "max_sequence_length", 512)

if use_cfg_cache and use_sen_cache:
raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
if sum([use_cfg_cache, use_sen_cache, use_magcache]) > 1:
raise ValueError("use_cfg_cache, use_sen_cache and use_magcache are mutually exclusive. Enable only one.")

if use_magcache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
raise ValueError(
f"use_magcache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
f"(got {guidance_scale_low}, {guidance_scale_high}). "
"MagCache reuses the cached residual across the doubled CFG batch, which must be enabled for both phases."
)

if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
raise ValueError(
@@ -183,6 +201,10 @@ def __call__(
use_sen_cache=use_sen_cache,
height=height,
use_kv_cache=use_kv_cache,
use_magcache=use_magcache,
magcache_thresh=magcache_thresh,
magcache_K=magcache_K,
retention_ratio=retention_ratio,
)

t_denoise_start = time.perf_counter()
@@ -233,6 +255,10 @@ def run_inference_2_2(
height: int = 480,
config=None,
use_kv_cache: bool = False,
use_magcache: bool = False,
magcache_thresh: float = 0.04,
magcache_K: int = 2,
retention_ratio: float = 0.2,
):
"""Denoising loop for WAN 2.2 T2V with optional caching acceleration.

@@ -279,6 +305,109 @@ def run_inference_2_2(
high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest)
kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined)

# ── MagCache path (Ma et al., https://github.com/Zehong-Ma/MagCache) ──
# Skips the transformer blocks on steps whose accumulated magnitude-ratio error
# stays below `magcache_thresh`, reusing the cached block residual instead.
# The skip schedule is fully static (the ratios are calibration constants), so
# the decision is made host-side and only a static `skip_blocks` bool crosses
# into the forward pass. Dual-transformer handling: one calibrated curve spans
# both phases (the official `mag_ratios` layout), with a forced-compute zone at
# the start of each phase and an explicit cache reset at the high→low boundary.
if use_magcache and do_classifier_free_guidance:
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
high_noise_steps = sum(step_uses_high)

mag_ratios_base = getattr(config, "mag_ratios_base", None) if config else None
if mag_ratios_base is None:
raise ValueError(
"use_magcache=True requires config.mag_ratios_base — the calibrated magnitude ratios "
"(interleaved cond/uncond, length num_inference_steps*2). Run the calibration pass or "
"use the published WAN 2.2 ratios."
)

# Single state + single ratio curve spanning both phases (official layout).
magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)
accumulated_state = magcache_init[:6]
cached_residual = magcache_init[6]
mag_ratios = magcache_init[8]

# Forced-compute ("retention") zones, in step units: the first
# `retention_ratio` fraction of each phase always computes in full. This is
# also what flushes the stale residual right after the boundary.
high_warmup_end = int(high_noise_steps * retention_ratio)
low_warmup_end = high_noise_steps + int((num_inference_steps - high_noise_steps) * retention_ratio)

cache_count = 0
for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]

Collaborator

We can move `jnp.array(scheduler_state.timesteps, dtype=jnp.int32)` outside the loop and save some compute:

timesteps = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)
for step in range(num_inference_steps):
    t = timesteps[step]

Author

I agree, but note that for SenCache and CFG cache in the same file, the array is also built inside the loop.

Collaborator

Yes, those will be fixed in PR #432.


# Select transformer + guidance for this phase.
if step_uses_high[step]:
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
guidance_scale = guidance_scale_high
kv_cache = kv_cache_high
encoder_attention_mask = encoder_attention_mask_high
else:
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
guidance_scale = guidance_scale_low
kv_cache = kv_cache_low
encoder_attention_mask = encoder_attention_mask_low

# Boundary reset: the high-noise transformer's residual is meaningless to
# the low-noise transformer, so start the low phase with a fresh cache.
is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
if is_boundary:
cached_residual = None
accumulated_state = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)[:6]

# Force a full compute inside either phase's warmup zone, at the boundary,
# or until we have a residual to reuse.
in_warmup = step < high_warmup_end or (high_noise_steps <= step < low_warmup_end)
force_compute = in_warmup or is_boundary or cached_residual is None

skip_blocks, accumulated_state = magcache_step(
step,
mag_ratios,
accumulated_state,
magcache_thresh,
magcache_K,
use_magcache=(not force_compute),
)

latents_doubled = jnp.concatenate([latents] * 2)
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, _, residual_x_cur = transformer_forward_pass(
graphdef,
state,
rest,
latents_doubled,
timestep,
prompt_embeds_combined,
do_classifier_free_guidance=True,
guidance_scale=guidance_scale,
skip_blocks=bool(skip_blocks),
cached_residual=cached_residual,
return_residual=True,
kv_cache=kv_cache,
rotary_emb=rotary_emb,
encoder_attention_mask=encoder_attention_mask,
)

if skip_blocks:
cache_count += 1
else:
cached_residual = residual_x_cur

latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

max_logging.log(
f"[MagCache] Cached {cache_count}/{num_inference_steps} steps "
f"({100*cache_count/num_inference_steps:.1f}% cache ratio), "
f"high_noise_steps={high_noise_steps}, thresh={magcache_thresh}, K={magcache_K}"
)
return latents
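
For reference, the forced-compute zones used in the MagCache path above can be reproduced in isolation. The sketch below mirrors the `high_warmup_end` / `low_warmup_end` arithmetic from this diff; `retention_zones` is a hypothetical standalone helper, and the 26/14 phase split is an assumed example rather than a value read from the scheduler.

```python
def retention_zones(step_uses_high, retention_ratio):
    """Return the step indices that always compute in full.

    `step_uses_high` is one bool per step, True while the high-noise
    transformer is active (assumed to be a prefix of the schedule).
    """
    num_steps = len(step_uses_high)
    high_noise_steps = sum(step_uses_high)
    high_warmup_end = int(high_noise_steps * retention_ratio)
    low_warmup_end = high_noise_steps + int(
        (num_steps - high_noise_steps) * retention_ratio
    )
    return [
        s
        for s in range(num_steps)
        if s < high_warmup_end or high_noise_steps <= s < low_warmup_end
    ]


# Assumed split: 26 high-noise steps followed by 14 low-noise steps.
forced = retention_zones([True] * 26 + [False] * 14, retention_ratio=0.2)
print(forced)
```

For this split the forced steps are `[0, 1, 2, 3, 4, 26, 27]`. Note that `int()` truncates, so a short phase combined with a small `retention_ratio` can yield an empty warmup zone; in the loop above the boundary step is still recomputed because `is_boundary` forces it.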

# ── SenCache path (arXiv:2602.24208) ──
if use_sen_cache and do_classifier_free_guidance:
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)