When using the Warp backend for differentiable LBM, gradients do not flow through the stepper. This prevents gradient descent and other optimization methods from working.
The Problem
When computing gradients through an LBM simulation:
| Component | Gradient Flow |
|---|---|
| Macroscopic (rho, u) | Works |
| Stepper (collision + streaming) | Returns 0.0 |
This breaks inverse problems and differentiable physics with Warp.
Test Script
Run with: python examples/cfd/test_stepper_autodiff.py
Test Output
======================================================================
XLB STEPPER AUTODIFF TEST
======================================================================
This test checks if gradients propagate through the LBM stepper.
We run the SAME test on both JAX and Warp backends and compare.
----------------------------------------------------------------------
TEST CONFIGURATION
----------------------------------------------------------------------
Grid shape: (32, 32)
Omega: 1.8
Precision: FP32FP32
Boundary: Periodic (no walls)
Collision: BGK
Test: Forward 1 step -> Compute rho -> Loss -> Backward
======================================================================
RESULTS: SIDE-BY-SIDE COMPARISON
======================================================================
Metric WARP JAX
----------------------------------------------------------------------
Loss value 1024.00 1024.00
Gradient norm (through stepper) 0.00 192.00
----------------------------------------------------------------------
GRADIENT FLOW ANALYSIS (Warp)
----------------------------------------------------------------------
Checking gradients at each stage:
1. loss.grad (seed) : 1.0
2. d(loss)/d(rho) gradient norm : 0.00
3. d(loss)/d(f_out) gradient norm: 192.00
4. d(loss)/d(f_in) gradient norm : 0.00 <-- PROBLEM
Gradient flows: loss -> rho -> f_out (Macroscopic works)
Gradient STOPS: f_out -> f_in (Stepper broken)
======================================================================
SUMMARY
======================================================================
WARP: Loss=1024.00, Gradient=0.00 --> BROKEN
JAX: Loss=1024.00, Gradient=192.00 --> OK
======================================================================
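The reported numbers can be reproduced by hand. The script's loss function is not shown in this issue. The arithmetic below assumes loss = sum(rho**2) over a uniform 32 x 32 state with rho = 1 on a D2Q9 lattice (9 populations per cell). That assumption was chosen because it matches both the reported loss of 1024 and the d(loss)/d(f_out) norm of 192.

```python
import math

# Assumed (not stated in the issue): loss = sum(rho**2), uniform rho = 1,
# and a D2Q9 lattice with q = 9 populations per cell.
nx, ny, q = 32, 32, 9
rho = 1.0

loss = nx * ny * rho**2                  # sum over all cells of rho^2
dloss_drho = 2.0 * rho                   # per-cell adjoint of rho
norm_rho = math.sqrt(nx * ny * dloss_drho**2)

# rho = sum_i f_i, so d(rho)/d(f_i) = 1 and every population inherits dloss_drho
norm_f_out = math.sqrt(nx * ny * q * dloss_drho**2)

print(loss, norm_rho, norm_f_out)  # 1024.0 64.0 192.0
```

Under the same assumption, the d(loss)/d(rho) norm would be 64 rather than the 0.00 printed at stage 2.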
Why This Happens
Warp's autodiff (wp.Tape) needs either:
- Simple kernels it can auto-differentiate, OR
- Manual @wp.func_grad adjoint functions
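A manual adjoint is not always complicated. Periodic streaming only permutes populations, so its adjoint is streaming in the opposite direction. The pure-Python sketch below is not Warp code and not XLB's stream.py; the D2Q9 velocity ordering and the function names are assumptions. It checks this with the standard dot-product test, <S f, g> = <f, S^T g>.

```python
import math
import random

# Assumed D2Q9 velocity set; XLB's ordering may differ.
C = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1), (1, 1), (-1, 1), (-1, -1), (1, -1)]


def zeros(nx, ny):
    return [[[0.0] * ny for _ in range(nx)] for _ in C]


def stream(f, nx, ny):
    """Periodic push streaming: population i moves from x to x + c_i."""
    out = zeros(nx, ny)
    for i, (cx, cy) in enumerate(C):
        for x in range(nx):
            for y in range(ny):
                out[i][(x + cx) % nx][(y + cy) % ny] = f[i][x][y]
    return out


def stream_adjoint(g, nx, ny):
    """Transpose of stream: pull each cotangent back from x + c_i to x."""
    out = zeros(nx, ny)
    for i, (cx, cy) in enumerate(C):
        for x in range(nx):
            for y in range(ny):
                out[i][x][y] = g[i][(x + cx) % nx][(y + cy) % ny]
    return out


def dot(a, b):
    """Inner product over all populations and cells."""
    return sum(p * q for ai, bi in zip(a, b) for ra, rb in zip(ai, bi) for p, q in zip(ra, rb))


random.seed(0)
nx, ny = 4, 3
f = [[[random.random() for _ in range(ny)] for _ in range(nx)] for _ in C]
g = [[[random.random() for _ in range(ny)] for _ in range(nx)] for _ in C]
lhs = dot(stream(f, nx, ny), g)
rhs = dot(f, stream_adjoint(g, nx, ny))
print(math.isclose(lhs, rhs, rel_tol=1e-12))  # True
```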
XLB's stepper has patterns Warp cannot auto-differentiate:
# From xlb/operator/stepper/nse_stepper.py
if _boundary_id == wp.uint8(255):
    return  # Early return breaks autodiff
When Warp cannot differentiate a kernel, it does not raise an error; the gradient is silently reported as 0.0.
The Macroscopic operator works because it is a simple summation kernel. The Stepper (collision + streaming) has complex control flow that prevents automatic adjoint generation.
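The silent-zero symptom can be pictured with a toy reverse-mode tape. This is an illustration under stated assumptions, not Warp's wp.Tape. The class ToyTape, the copy-only "stepper", and loss = rho**2 are all hypothetical. An op recorded without an adjoint is skipped during the backward pass, so its inputs end up with a zero gradient and no error is raised. The summation op, which stands in for Macroscopic, has a trivial adjoint that simply broadcasts the incoming gradient.

```python
class ToyTape:
    """Toy reverse-mode tape (illustration only; not how wp.Tape is implemented)."""

    def __init__(self):
        self.ops = []  # each entry: (adjoint_fn or None, input names, output name)

    def record(self, adjoint_fn, inputs, output):
        self.ops.append((adjoint_fn, inputs, output))

    def backward(self, seed):
        grads = dict(seed)
        for adjoint_fn, inputs, output in reversed(self.ops):
            if adjoint_fn is None:
                continue  # no adjoint available: skipped without any error
            out_bar = grads.get(output, 0.0)
            for name, in_bar in zip(inputs, adjoint_fn(out_bar)):
                grads[name] = grads.get(name, 0.0) + in_bar
        return grads


def run(step_adjoint):
    f_in = {"f_in_0": 0.5, "f_in_1": 0.25, "f_in_2": 0.25}
    tape = ToyTape()
    # Hypothetical "stepper": a plain copy f_out_i = f_in_i, recorded with the given adjoint.
    for i in range(3):
        tape.record(step_adjoint, [f"f_in_{i}"], f"f_out_{i}")
    rho = sum(f_in.values())  # since f_out == f_in, rho = sum_i f_out_i = 1.0
    # Summation op (the "Macroscopic" stand-in): its adjoint broadcasts rho_bar.
    tape.record(lambda bar: [bar, bar, bar], [f"f_out_{i}" for i in range(3)], "rho")
    tape.record(lambda bar: [2.0 * rho * bar], ["rho"], "loss")  # loss = rho**2
    grads = tape.backward({"loss": 1.0})
    return [grads.get(f"f_in_{i}", 0.0) for i in range(3)]


print(run(step_adjoint=None))               # [0.0, 0.0, 0.0]
print(run(step_adjoint=lambda bar: [bar]))  # [2.0, 2.0, 2.0]
```

With no adjoint on the "stepper", the gradient stops at f_out exactly as in the stage-by-stage output above. Supplying even a trivial adjoint restores the flow back to f_in.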
How to Fix (Future Work)
Add manual @wp.func_grad adjoint implementations for:
- xlb/operator/collision/bgk.py: warp_functional()
- xlb/operator/stream/stream.py: warp_functional()
- xlb/operator/equilibrium/*.py: warp_functional()
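For orientation, the sketch below shows what a hand-written adjoint for the BGK collision has to compute. It is written in pure Python for a single D2Q9 cell rather than as a Warp @wp.func_grad, and it is checked against central finite differences. The velocity ordering, the second-order equilibrium, and all function names are assumptions for illustration, not XLB's bgk.py.

```python
import math

# Assumed D2Q9 lattice (velocities and weights); XLB's ordering may differ.
C = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1), (1, 1), (-1, 1), (-1, -1), (1, -1)]
W = [4 / 9] + [1 / 9] * 4 + [1 / 36] * 4


def moments(f):
    rho = sum(f)
    jx = sum(fi * cx for fi, (cx, _) in zip(f, C))
    jy = sum(fi * cy for fi, (_, cy) in zip(f, C))
    return rho, jx, jy


def bgk(f, omega):
    """Single-cell BGK: f_post = f - omega * (f - feq), feq written in terms of rho and j = rho*u."""
    rho, jx, jy = moments(f)
    jj = jx * jx + jy * jy
    out = []
    for fi, w, (cx, cy) in zip(f, W, C):
        cj = cx * jx + cy * jy
        feq = w * (rho + 3 * cj + 4.5 * cj * cj / rho - 1.5 * jj / rho)
        out.append(fi - omega * (fi - feq))
    return out


def bgk_vjp(f, omega, g):
    """Hand-written adjoint: returns (d f_post / d f)^T g."""
    rho, jx, jy = moments(f)
    jj = jx * jx + jy * jy
    a = 0.0        # sum_i g_i * d feq_i / d rho
    bx = by = 0.0  # sum_i g_i * d feq_i / d j
    for gi, w, (cx, cy) in zip(g, W, C):
        cj = cx * jx + cy * jy
        a += gi * w * (1 - 4.5 * cj * cj / rho**2 + 1.5 * jj / rho**2)
        coef = 3 + 9 * cj / rho
        bx += gi * w * (coef * cx - 3 * jx / rho)
        by += gi * w * (coef * cy - 3 * jy / rho)
    # rho = sum_k f_k and j = sum_k f_k c_k, so d rho/d f_k = 1 and d j/d f_k = c_k
    return [(1 - omega) * gk + omega * (a + bx * cx + by * cy) for gk, (cx, cy) in zip(g, C)]


def fd_vjp(f, omega, g, eps=1e-6):
    """Central finite-difference reference for the same vector-Jacobian product."""
    out = []
    for k in range(len(f)):
        fp = list(f)
        fp[k] += eps
        fm = list(f)
        fm[k] -= eps
        dp, dm = bgk(fp, omega), bgk(fm, omega)
        out.append(sum(gi * (p - m) / (2 * eps) for gi, p, m in zip(g, dp, dm)))
    return out


f = [0.4, 0.12, 0.1, 0.11, 0.09, 0.03, 0.025, 0.028, 0.03]
g = [0.3, -0.1, 0.5, 0.2, -0.4, 0.1, 0.0, 0.7, -0.2]
omega = 1.8
analytic = bgk_vjp(f, omega, g)
numeric = fd_vjp(f, omega, g)
print(all(math.isclose(x, y, abs_tol=1e-6) for x, y in zip(analytic, numeric)))  # True
```

Because feq depends on f only through rho and j, the adjoint reduces to one scalar and one vector accumulation per cell. As a further check, BGK conserves mass, so an all-ones cotangent maps back to all ones.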
Workaround
Use the JAX backend for differentiable LBM applications.