Skip to content

Optimize Wan VAE: Add Temporal Chunking, On-TPU Postprocessing, Static JIT Support, and Configurable Dtypes#411

Open
Perseus14 wants to merge 1 commit into
mainfrom
wan_vae_opt
Open

Optimize Wan VAE: Add Temporal Chunking, On-TPU Postprocessing, Static JIT Support, and Configurable Dtypes#411
Perseus14 wants to merge 1 commit into
mainfrom
wan_vae_opt

Conversation

@Perseus14
Copy link
Copy Markdown
Collaborator

@Perseus14 Perseus14 commented May 18, 2026

Overview

This pull request introduces significant memory, compilation, and bandwidth optimizations for the Wan (T2V, I2V, Vace and Animate) VAE encoding and decoding execution graphs on TPUs. By enforcing static shapes and moving postprocessing operations on-device, this PR significantly improves scaling efficiency and execution robustness.

Key Changes & Optimizations

1. VAE Temporal Chunking & Static Compilation

  • Introduced configurable chunk size parameters (vae_decode_chunk and vae_encode_chunk) in the base configuration files.
  • Added static temporal zero-padding prior to jax.lax.scan loops and precise output trimming post-scan. This guarantees uniform static shapes across iterations, allowing seamless static JIT/XLA compilation without dynamic shape fallbacks or recompilation overhead.

2. On-TPU Output Quantization

  • Transitioned the final video post-processing stages (range rescaling to [0, 1], clipping, and scaling to uint8) directly into the JAX/TPU execution graph across all Wan pipelines (WanPipeline, WanPipelineI2V, and WanAnimatePipeline).
  • Gathering sharded arrays as uint8 instead of floating-point cuts device-to-host memory interconnect transfer bandwidth by up to 4x.
  • Completely eliminated the need for torch.Tensor conversions and PyTorch processor overhead during generation output assembly.

3. Distributed Sharding Robustness

  • Standardized VAE 1D spatial sharding rules to use P("redundant", None, None, "vae_spatial", None). Explicitly marking the batch axis as redundant prevents unintended cross-mesh synchronization across replicas under XLA SPMD partitioning.
  • Added defensive dimension existence checks when accessing mesh properties in attention_flax.py to support diverse hardware topologies gracefully.
  • Cleaned up Attention QKV splitting logic with direct matrix reshaping, establishing optimal contiguous memory alignment for dot-product attention kernels.

4. Configurable Activation & Weight Precision

  • Integrated robust type casting in pyconfig.py for user-defined numerical precision (vae_dtype, vae_weights_dtype, and scheduler_dtype), defaulting VAE runtime memory down to 16-bit precision to save 50% High Bandwidth Memory (HBM) capacity.

Performance

Note

Test Configuration: 720p | 81 frames | 40 steps
Hardware: TPU 7x-8
JAX Version: v0.10.0

Model Variant BaselIne Generation Time Current Generation Time
WAN2.2 T2V 132.2s 130.6s
WAN2.2 I2V 133.3s 132.3s
WAN2.1 T2V 132.1s 130.6s
WAN2.1 I2V 142.7s 141.6s

Conclusion: No visual change from baseline across all tested variants.

@Perseus14 Perseus14 requested a review from entrpn as a code owner May 18, 2026 10:02
@github-actions
Copy link
Copy Markdown

@Perseus14 Perseus14 self-assigned this May 18, 2026
@Perseus14 Perseus14 requested review from eltsai and ninatu May 18, 2026 10:03
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Comment thread src/maxdiffusion/configs/base_wan_14b.yml Outdated
Comment thread src/maxdiffusion/models/wan/autoencoder_kl_wan.py Outdated
@Perseus14 Perseus14 force-pushed the wan_vae_opt branch 3 times, most recently from 1091ff8 to 7406dd0 Compare May 18, 2026 16:05
@github-actions
Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

This PR introduces significant memory, compilation, and bandwidth optimizations for the Wan VAE on TPUs. By implementing temporal chunking, enforcing static shapes, and moving postprocessing operations directly into the JAX execution graph, the PR achieves better scaling efficiency and execution robustness.

🔍 General Feedback

  • Optimization: The transition to on-TPU postprocessing (quantizing to uint8 on-device) is an excellent optimization that significantly reduces host-device bandwidth.
  • Architectural Clarity: Standardizing the spatial sharding rules and using temporal chunking with static padding/trimming is a very high-quality improvement for JAX/XLA performance.
  • Breaking Change: Note that the pipeline's output has changed from floating-point values (typically [0, 1]) to uint8 values ([0, 255]), and it now retains the batch dimension. This is well-handled in the included utility updates but may affect external users.
  • Test Integrity: The dramatic increase in tolerance for the KV cache test (atol=180) is concerning and should be reviewed to ensure it's not masking a functional regression.


# Compare outputs
np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=0.7)
np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=180)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 The tolerance atol=180 is extremely high for uint8 image data (range 0-255). A difference of 180 represents over 70% of the possible pixel intensity range, which effectively means the regression test is no longer providing meaningful validation of the KV cache's numerical integrity.

If this change was prompted by moving postprocessing to the TPU, it's possible that the comparison is being made between float32 [0, 1] and uint8 [0, 255] or that there's a significant regression.

Please investigate the root cause of this divergence and use a more appropriate tolerance (e.g., atol=1 or 2 for uint8 data).

Suggested change
np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=180)
np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=2)

return next_feat_map, out_chunk

enc_feat_map, out_rest = jax.lax.scan(scan_fn, enc_feat_map, x_scannable)
# out_rest shape: (iter_-2, B, T', H, W, C) -> transpose back
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 To ensure optimal performance and avoid unnecessary cross-mesh synchronization within the jax.lax.scan loop, the next_feat_map (which serves as the scan carry) should explicitly have sharding constraints applied, similar to the implementation in the _decode method.

Suggested change
# out_rest shape: (iter_-2, B, T', H, W, C) -> transpose back
next_feat_map = jax.tree_util.tree_map(
lambda x: jax.lax.with_sharding_constraint(x, spatial_sharding) if isinstance(x, jax.Array) else x, next_feat_map
)
return next_feat_map, out_chunk

video = jax.experimental.multihost_utils.process_allgather(video, tiled=True)
video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16)
return self.video_processor.postprocess_video(video, output_type="np")
video = np.array(video)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 The removal of self.video_processor.postprocess_video and the transpose operation changes the return type and shape of the pipeline. Previously, it returned a list of numpy arrays or a 4D array (depending on the processor) in the format expected by most users. Now, it returns a 5D np.ndarray of shape (Batch, Time, Height, Width, Channels) with uint8 precision.

While this is more efficient for TPU-to-Host transfer, please ensure that this change in the public API is intentional and that all downstream consumers (like example scripts or notebook users) are aware that they may now need to index into the batch dimension (e.g., videos[0]) before processing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants