Warp backend does not support gradient-based optimization (stepper returns zero gradients) #161

@Medyan-Naser

Description

With the Warp backend, gradients do not propagate through the LBM stepper: the backward pass returns zeros for the stepper's inputs. As a result, gradient descent and other gradient-based optimization methods cannot make progress.


The Problem

When computing gradients through an LBM simulation:

Component                          Gradient flow
------------------------------------------------
Macroscopic (rho, u)               Works
Stepper (collision + streaming)    Returns 0.0

This breaks inverse problems and differentiable physics with Warp.


Test Script

Run with: python examples/cfd/test_stepper_autodiff.py


Test Output

======================================================================
XLB STEPPER AUTODIFF TEST
======================================================================

This test checks if gradients propagate through the LBM stepper.
We run the SAME test on both JAX and Warp backends and compare.

----------------------------------------------------------------------
TEST CONFIGURATION
----------------------------------------------------------------------
  Grid shape:       (32, 32)
  Omega:            1.8
  Precision:        FP32FP32
  Boundary:         Periodic (no walls)
  Collision:        BGK
  Test:             Forward 1 step -> Compute rho -> Loss -> Backward

======================================================================
RESULTS: SIDE-BY-SIDE COMPARISON
======================================================================

Metric                                   WARP            JAX            
----------------------------------------------------------------------
Loss value                               1024.00         1024.00        
Gradient norm (through stepper)          0.00            192.00         

----------------------------------------------------------------------
GRADIENT FLOW ANALYSIS (Warp)
----------------------------------------------------------------------

  Checking gradients at each stage:

    1. loss.grad (seed)              : 1.0
    2. d(loss)/d(rho) gradient norm  : 0.00
    3. d(loss)/d(f_out) gradient norm: 192.00
    4. d(loss)/d(f_in) gradient norm : 0.00  <-- PROBLEM

  Gradient flows: loss -> rho -> f_out (Macroscopic works)
  Gradient STOPS: f_out -> f_in (Stepper broken)


======================================================================
SUMMARY
======================================================================

  WARP: Loss=1024.00, Gradient=0.00 --> BROKEN
  JAX:  Loss=1024.00, Gradient=192.00 --> OK

======================================================================
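
The reported numbers are consistent with a simple hand check, under assumptions the output above does not state: a D2Q9 lattice, a uniform initial state with rho = 1 and zero velocity, and loss = sum(rho^2). BGK collision and periodic streaming both conserve mass, so under those assumptions every input population has gradient 2, and the norm is 2 * sqrt(9 * 32 * 32) = 192. The following is a minimal pure-Python sketch that checks this with finite differences on a smaller 4x4 grid. The function names are illustrative and are not part of XLB.

```python
import math

# D2Q9 velocities and weights (assumed lattice; the issue only states a 2D grid)
C = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1), (1, 1), (-1, 1), (-1, -1), (1, -1)]
W = [4 / 9] + [1 / 9] * 4 + [1 / 36] * 4
OMEGA = 1.8  # relaxation rate from the test configuration


def step(f, nx, ny):
    """One BGK collision followed by periodic streaming; f is indexed f[i][x][y]."""
    out = [[[0.0] * ny for _ in range(nx)] for _ in range(9)]
    for x in range(nx):
        for y in range(ny):
            pops = [f[i][x][y] for i in range(9)]
            rho = sum(pops)
            ux = sum(p * c[0] for p, c in zip(pops, C)) / rho
            uy = sum(p * c[1] for p, c in zip(pops, C)) / rho
            usq = ux * ux + uy * uy
            for i, (cx, cy) in enumerate(C):
                cu = cx * ux + cy * uy
                feq = W[i] * rho * (1.0 + 3.0 * cu + 4.5 * cu * cu - 1.5 * usq)
                post = pops[i] - OMEGA * (pops[i] - feq)  # BGK relaxation
                out[i][(x + cx) % nx][(y + cy) % ny] = post  # push streaming
    return out


def loss(f, nx, ny):
    """Assumed loss: sum over cells of rho^2 after one step."""
    g = step(f, nx, ny)
    return sum(
        sum(g[i][x][y] for i in range(9)) ** 2 for x in range(nx) for y in range(ny)
    )


def uniform_state(nx, ny):
    """Equilibrium at rest with rho = 1 everywhere."""
    return [[[W[i]] * ny for _ in range(nx)] for i in range(9)]


def grad_norm(nx, ny, eps=1e-6):
    """Norm of d(loss)/d(f_in), estimated with central finite differences."""
    f = uniform_state(nx, ny)
    total = 0.0
    for i in range(9):
        for x in range(nx):
            for y in range(ny):
                base = f[i][x][y]
                f[i][x][y] = base + eps
                up = loss(f, nx, ny)
                f[i][x][y] = base - eps
                down = loss(f, nx, ny)
                f[i][x][y] = base
                total += ((up - down) / (2.0 * eps)) ** 2
    return math.sqrt(total)


if __name__ == "__main__":
    print(loss(uniform_state(4, 4), 4, 4))
    print(grad_norm(4, 4))
```

On the 4x4 grid the norm comes out close to 2 * sqrt(9 * 16) = 24. The same formula gives 192 on a 32x32 grid, which matches the JAX column, so a backend with a correct stepper adjoint should reproduce that value.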

Why This Happens

Warp's autodiff (recorded with wp.Tape) needs either:

  1. Kernels whose adjoint it can generate automatically, OR
  2. Manual adjoint functions registered with @wp.func_grad

XLB's stepper has patterns Warp cannot auto-differentiate:

# From xlb/operator/stepper/nse_stepper.py
if _boundary_id == wp.uint8(255):
    return  # Early return breaks autodiff

Warp does not raise an error when it cannot differentiate a kernel; the backward pass silently returns 0.0.

The Macroscopic operator works because it is a simple summation kernel. The Stepper (collision + streaming) has complex control flow that prevents automatic adjoint generation.


How to Fix (Future Work)

Add manual @wp.func_grad adjoint implementations for:

  1. xlb/operator/collision/bgk.py - warp_functional()
  2. xlb/operator/stream/stream.py - warp_functional()
  3. xlb/operator/equilibrium/*.py - warp_functional()

Workaround

Use the JAX backend for differentiable LBM applications.
