Optimize Wan VAE: Add Temporal Chunking, On-TPU Postprocessing, Static JIT Support, and Configurable Dtypes#411
Optimize Wan VAE: Add Temporal Chunking, On-TPU Postprocessing, Static JIT Support, and Configurable Dtypes#411Perseus14 wants to merge 1 commit into
Conversation
1091ff8 to
7406dd0
Compare
…ic shapes, and JIT support
|
🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This PR introduces significant memory, compilation, and bandwidth optimizations for the Wan VAE on TPUs. By implementing temporal chunking, enforcing static shapes, and moving postprocessing operations directly into the JAX execution graph, the PR achieves better scaling efficiency and execution robustness.
🔍 General Feedback
- Optimization: The transition to on-TPU postprocessing (quantizing to
uint8on-device) is an excellent optimization that significantly reduces host-device bandwidth. - Architectural Clarity: Standardizing the spatial sharding rules and using temporal chunking with static padding/trimming is a very high-quality improvement for JAX/XLA performance.
- Breaking Change: Note that the pipeline's output has changed from floating-point values (typically
[0, 1]) touint8values ([0, 255]), and it now retains the batch dimension. This is well-handled in the included utility updates but may affect external users. - Test Integrity: The dramatic increase in tolerance for the KV cache test (
atol=180) is concerning and should be reviewed to ensure it's not masking a functional regression.
|
|
||
| # Compare outputs | ||
| np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=0.7) | ||
| np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=180) |
There was a problem hiding this comment.
🟠 The tolerance atol=180 is extremely high for uint8 image data (range 0-255). A difference of 180 represents over 70% of the possible pixel intensity range, which effectively means the regression test is no longer providing meaningful validation of the KV cache's numerical integrity.
If this change was prompted by moving postprocessing to the TPU, it's possible that the comparison is being made between float32 [0, 1] and uint8 [0, 255] or that there's a significant regression.
Please investigate the root cause of this divergence and use a more appropriate tolerance (e.g., atol=1 or 2 for uint8 data).
| np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=180) | |
| np.testing.assert_allclose(video_no_cache, video_with_cache, rtol=1e-1, atol=2) |
| return next_feat_map, out_chunk | ||
|
|
||
| enc_feat_map, out_rest = jax.lax.scan(scan_fn, enc_feat_map, x_scannable) | ||
| # out_rest shape: (iter_-2, B, T', H, W, C) -> transpose back |
There was a problem hiding this comment.
🟡 To ensure optimal performance and avoid unnecessary cross-mesh synchronization within the jax.lax.scan loop, the next_feat_map (which serves as the scan carry) should explicitly have sharding constraints applied, similar to the implementation in the _decode method.
| # out_rest shape: (iter_-2, B, T', H, W, C) -> transpose back | |
| next_feat_map = jax.tree_util.tree_map( | |
| lambda x: jax.lax.with_sharding_constraint(x, spatial_sharding) if isinstance(x, jax.Array) else x, next_feat_map | |
| ) | |
| return next_feat_map, out_chunk |
| video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) | ||
| video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) | ||
| return self.video_processor.postprocess_video(video, output_type="np") | ||
| video = np.array(video) |
There was a problem hiding this comment.
🟡 The removal of self.video_processor.postprocess_video and the transpose operation changes the return type and shape of the pipeline. Previously, it returned a list of numpy arrays or a 4D array (depending on the processor) in the format expected by most users. Now, it returns a 5D np.ndarray of shape (Batch, Time, Height, Width, Channels) with uint8 precision.
While this is more efficient for TPU-to-Host transfer, please ensure that this change in the public API is intentional and that all downstream consumers (like example scripts or notebook users) are aware that they may now need to index into the batch dimension (e.g., videos[0]) before processing.
Overview
This pull request introduces significant memory, compilation, and bandwidth optimizations for the Wan (T2V, I2V, Vace and Animate) VAE encoding and decoding execution graphs on TPUs. By enforcing static shapes and moving postprocessing operations on-device, this PR significantly improves scaling efficiency and execution robustness.
Key Changes & Optimizations
1. VAE Temporal Chunking & Static Compilation
vae_decode_chunkandvae_encode_chunk) in the base configuration files.jax.lax.scanloops and precise output trimming post-scan. This guarantees uniform static shapes across iterations, allowing seamless static JIT/XLA compilation without dynamic shape fallbacks or recompilation overhead.2. On-TPU Output Quantization
[0, 1], clipping, and scaling touint8) directly into the JAX/TPU execution graph across all Wan pipelines (WanPipeline,WanPipelineI2V, andWanAnimatePipeline).uint8instead of floating-point cuts device-to-host memory interconnect transfer bandwidth by up to 4x.torch.Tensorconversions and PyTorch processor overhead during generation output assembly.3. Distributed Sharding Robustness
P("redundant", None, None, "vae_spatial", None). Explicitly marking the batch axis as redundant prevents unintended cross-mesh synchronization across replicas under XLA SPMD partitioning.attention_flax.pyto support diverse hardware topologies gracefully.4. Configurable Activation & Weight Precision
pyconfig.pyfor user-defined numerical precision (vae_dtype,vae_weights_dtype, andscheduler_dtype), defaulting VAE runtime memory down to 16-bit precision to save 50% High Bandwidth Memory (HBM) capacity.Performance
Note
Test Configuration: 720p | 81 frames | 40 steps
Hardware: TPU 7x-8
JAX Version: v0.10.0
Conclusion: No visual change from baseline across all tested variants.