-
Notifications
You must be signed in to change notification settings - Fork 82
Add MagCache inference acceleration for Wan2.2 (T2V + I2V) #433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,14 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache | ||
| from .wan_pipeline import ( | ||
| WanPipeline, | ||
| transformer_forward_pass, | ||
| transformer_forward_pass_full_cfg, | ||
| transformer_forward_pass_cfg_cache, | ||
| init_magcache, | ||
| magcache_step, | ||
| ) | ||
| from ...models.wan.transformers.transformer_wan import WanModel | ||
| from typing import List, Union, Optional | ||
| from ...pyconfig import HyperParameters | ||
|
|
@@ -118,13 +125,24 @@ def __call__( | |
| use_cfg_cache: bool = False, | ||
| use_sen_cache: bool = False, | ||
| use_kv_cache: bool = False, | ||
| use_magcache: bool = False, | ||
| magcache_thresh: float = 0.04, | ||
| magcache_K: int = 2, | ||
| retention_ratio: float = 0.2, | ||
| ): | ||
| config = getattr(self, "config", None) | ||
| if max_sequence_length is None: | ||
| max_sequence_length = getattr(config, "max_sequence_length", 512) | ||
|
|
||
| if use_cfg_cache and use_sen_cache: | ||
| raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.") | ||
| if sum([use_cfg_cache, use_sen_cache, use_magcache]) > 1: | ||
| raise ValueError("use_cfg_cache, use_sen_cache and use_magcache are mutually exclusive. Enable only one.") | ||
|
|
||
| if use_magcache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): | ||
| raise ValueError( | ||
| f"use_magcache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " | ||
| f"(got {guidance_scale_low}, {guidance_scale_high}). " | ||
| "MagCache reuses the cached residual across the doubled CFG batch, which must be enabled for both phases." | ||
| ) | ||
|
|
||
| if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): | ||
| raise ValueError( | ||
|
|
@@ -183,6 +201,10 @@ def __call__( | |
| use_sen_cache=use_sen_cache, | ||
| height=height, | ||
| use_kv_cache=use_kv_cache, | ||
| use_magcache=use_magcache, | ||
| magcache_thresh=magcache_thresh, | ||
| magcache_K=magcache_K, | ||
| retention_ratio=retention_ratio, | ||
| ) | ||
|
|
||
| t_denoise_start = time.perf_counter() | ||
|
|
@@ -233,6 +255,10 @@ def run_inference_2_2( | |
| height: int = 480, | ||
| config=None, | ||
| use_kv_cache: bool = False, | ||
| use_magcache: bool = False, | ||
| magcache_thresh: float = 0.04, | ||
| magcache_K: int = 2, | ||
| retention_ratio: float = 0.2, | ||
| ): | ||
| """Denoising loop for WAN 2.2 T2V with optional caching acceleration. | ||
|
|
||
|
|
@@ -279,6 +305,109 @@ def run_inference_2_2( | |
| high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) | ||
| kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined) | ||
|
|
||
| # ── MagCache path (Ma et al., https://github.com/Zehong-Ma/MagCache) ── | ||
| # Skips the transformer blocks on steps whose accumulated magnitude-ratio error | ||
| # stays below `magcache_thresh`, reusing the cached block residual instead. | ||
| # The skip schedule is fully static (the ratios are calibration constants), so | ||
| # the decision is made host-side and only a static `skip_blocks` bool crosses | ||
| # into the forward pass. Dual-transformer handling: one calibrated curve spans | ||
| # both phases (the official `mag_ratios` layout), with a forced-compute zone at | ||
| # the start of each phase and an explicit cache reset at the high→low boundary. | ||
| if use_magcache and do_classifier_free_guidance: | ||
| timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) | ||
| step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] | ||
| high_noise_steps = sum(step_uses_high) | ||
|
|
||
| mag_ratios_base = getattr(config, "mag_ratios_base", None) if config else None | ||
| if mag_ratios_base is None: | ||
| raise ValueError( | ||
| "use_magcache=True requires config.mag_ratios_base — the calibrated magnitude ratios " | ||
| "(interleaved cond/uncond, length num_inference_steps*2). Run the calibration pass or " | ||
| "use the published WAN 2.2 ratios." | ||
| ) | ||
|
|
||
| # Single state + single ratio curve spanning both phases (official layout). | ||
| magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) | ||
| accumulated_state = magcache_init[:6] | ||
| cached_residual = magcache_init[6] | ||
| mag_ratios = magcache_init[8] | ||
|
|
||
| # Forced-compute ("retention") zones, in step units: the first | ||
| # `retention_ratio` fraction of each phase always computes in full. This is | ||
| # also what flushes the stale residual right after the boundary. | ||
| high_warmup_end = int(high_noise_steps * retention_ratio) | ||
| low_warmup_end = high_noise_steps + int((num_inference_steps - high_noise_steps) * retention_ratio) | ||
|
|
||
| cache_count = 0 | ||
| for step in range(num_inference_steps): | ||
| t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can move
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, but note that for senchache and cfg cache in the same file the array is also inside the loop.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes those will be fixed in PR #432 |
||
|
|
||
| # Select transformer + guidance for this phase. | ||
| if step_uses_high[step]: | ||
| graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest | ||
| guidance_scale = guidance_scale_high | ||
| kv_cache = kv_cache_high | ||
| encoder_attention_mask = encoder_attention_mask_high | ||
| else: | ||
| graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest | ||
| guidance_scale = guidance_scale_low | ||
| kv_cache = kv_cache_low | ||
| encoder_attention_mask = encoder_attention_mask_low | ||
|
|
||
| # Boundary reset: the high-noise transformer's residual is meaningless to | ||
| # the low-noise transformer, so start the low phase with a fresh cache. | ||
| is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1] | ||
| if is_boundary: | ||
| cached_residual = None | ||
| accumulated_state = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base)[:6] | ||
|
|
||
| # Force a full compute inside either phase's warmup zone, at the boundary, | ||
| # or until we have a residual to reuse. | ||
| in_warmup = step < high_warmup_end or (high_noise_steps <= step < low_warmup_end) | ||
| force_compute = in_warmup or is_boundary or cached_residual is None | ||
|
|
||
| skip_blocks, accumulated_state = magcache_step( | ||
| step, | ||
| mag_ratios, | ||
| accumulated_state, | ||
| magcache_thresh, | ||
| magcache_K, | ||
| use_magcache=(not force_compute), | ||
| ) | ||
|
|
||
| latents_doubled = jnp.concatenate([latents] * 2) | ||
| timestep = jnp.broadcast_to(t, bsz * 2) | ||
| noise_pred, _, residual_x_cur = transformer_forward_pass( | ||
| graphdef, | ||
| state, | ||
| rest, | ||
| latents_doubled, | ||
| timestep, | ||
| prompt_embeds_combined, | ||
| do_classifier_free_guidance=True, | ||
| guidance_scale=guidance_scale, | ||
| skip_blocks=bool(skip_blocks), | ||
| cached_residual=cached_residual, | ||
| return_residual=True, | ||
| kv_cache=kv_cache, | ||
| rotary_emb=rotary_emb, | ||
| encoder_attention_mask=encoder_attention_mask, | ||
| ) | ||
|
|
||
| if skip_blocks: | ||
| cache_count += 1 | ||
| else: | ||
| cached_residual = residual_x_cur | ||
|
|
||
| latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() | ||
|
|
||
| max_logging.log( | ||
| f"[MagCache] Cached {cache_count}/{num_inference_steps} steps " | ||
| f"({100*cache_count/num_inference_steps:.1f}% cache ratio), " | ||
| f"high_noise_steps={high_noise_steps}, thresh={magcache_thresh}, K={magcache_K}" | ||
| ) | ||
| return latents | ||
|
|
||
| # ── SenCache path (arXiv:2602.24208) ── | ||
| if use_sen_cache and do_classifier_free_guidance: | ||
| timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can reduce the text content here.
Minor nit: It is 7x and not v7x
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will do. I will change to 7x, but note that the whole readme has this mistake (and it exists also in other places in the repo)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, we will put put a PR to fix those