Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions examples/cfd/flow_past_cylinder_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
"""
Flow past a cylinder in a 2D channel (D2Q9, Re=100).

Classic bluff-body channel flow: Poiseuille inflow passes a circular cylinder and
exits through an open outlet. Used for wake structure, vortex shedding, and drag/lift
coefficients at moderate Reynolds number.

Physical parameters
-------------------
* Reynolds number ``Re = 100`` (based on cylinder diameter and ``prescribed_vel``).
* Relaxation rate ``omega = 1 / (3 * nu + 0.5)`` with
``nu = prescribed_vel * diam / Re``.
* Reference inlet speed ``prescribed_vel = 0.003 * (reference_diam / diam)`` so
resolution can be changed while keeping similar lattice dynamics.

Domain and geometry
-------------------
* Channel size: ``(22 * diam, 4.1 * diam)`` lattice nodes in x and y.
* Cylinder: center ``(2 * diam, 2 * diam)``, radius ``diam / 2``.
* Default ``diam = 80`` → grid ``(1760, 328)`` and ~2.7M time steps to
``100 * diam / prescribed_vel``. Use a smaller ``diam`` (e.g. 20) for quick tests.

Boundary conditions
-------------------
* **Inlet (left):** ``RegularizedBC`` with parabolic Poiseuille profile;
peak speed ``u_peak = 1.5 * prescribed_vel``.
* **Outlet (right):** ``ExtrapolationOutflowBC``.
* **Top / bottom:** ``FullwayBounceBackBC`` (no-slip channel walls).
* **Cylinder:** ``HalfwayBounceBackBC`` (no-slip body).

Compute backends
----------------
* **WARP (default):** Recommended for large 2D GPU runs.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The list of backends excludes Neon. Is there a specific reason?

* **JAX:** Set ``compute_backend = ComputeBackend.JAX``. On NVIDIA Ampere and
newer GPUs, keep the TF32 overrides at the top of this file (or export
``NVIDIA_TF32_OVERRIDE=0`` before launch) so ``jnp.tensordot`` in equilibrium
and macroscopic operators use full FP32 accuracy.

Post-processing uses ``Macroscopic`` on JAX (populations are converted from Warp
when needed). PNG snapshots are written with prefix ``flow_past_cylinder_2d``.

Forces
------
After step ``> 0.5 * num_steps``, ``MomentumTransfer`` reports ``CD`` and ``CL``
(drag along x, lift along y), normalized by ``prescribed_vel**2 * diam``. Running
maxima ``CD_max`` and ``CL_max`` are printed at the end.

Usage
-----
From the repository root (with ``PYTHONPATH`` pointing at this repo)::

python3 examples/cfd/flow_past_cylinder_2d.py

For JAX on GPU with full FP32 matmul precision::

NVIDIA_TF32_OVERRIDE=0 python3 examples/cfd/flow_past_cylinder_2d.py
"""

import os
import time

# Full FP32 matmul on NVIDIA GPUs (Macroscopic and other JAX operators use tensordot).
os.environ.setdefault("NVIDIA_TF32_OVERRIDE", "0")

import jax

jax.config.update("jax_default_matmul_precision", "highest")

import jax.numpy as jnp
import numpy as np
import warp as wp
import xlb
from xlb.compute_backend import ComputeBackend
from xlb.grid import grid_factory
from xlb.operator.boundary_condition import (
ExtrapolationOutflowBC,
HalfwayBounceBackBC,
RegularizedBC,
)
from xlb.operator.force.momentum_transfer import MomentumTransfer
from xlb.operator.macroscopic import Macroscopic
from xlb.operator.stepper import IncompressibleNavierStokesStepper
from xlb.precision_policy import PrecisionPolicy
from xlb.utils import save_image, warp_array_to_jax

# -------------------------- Simulation Setup --------------------------

diam = 80 # Cylinder diameter in lattice units; reduce (e.g. 20) for faster runs
reference_diam = 80
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think some comments would help highlight the need for a second variable for the diameter.


Re = 100.0
scale_factor = reference_diam / diam
prescribed_vel = 0.003 * scale_factor
visc = prescribed_vel * diam / Re
omega = 1.0 / (3.0 * visc + 0.5)

grid_shape = (int(22 * diam), int(4.1 * diam))
cylinder_center = (2.0 * diam, 2.0 * diam)
cylinder_radius = diam / 2.0
Comment on lines +97 to +99
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd move these after line 88 to keep all geometric parameters in the same place.


compute_backend = ComputeBackend.JAX
precision_policy = PrecisionPolicy.FP32FP32
velocity_set = xlb.velocity_set.D2Q9(precision_policy=precision_policy, compute_backend=compute_backend)

characteristic_time = prescribed_vel / diam
num_steps = int(100 / characteristic_time)
post_process_interval = max(1, int(500 / scale_factor))

u_peak = 1.5 * prescribed_vel # Poiseuille peak (1.5 × mean inlet speed)


def bc_profile():
"""Parabolic inlet profile: u_x = 4 u_peak / d² (y d - y²), u_y = 0."""
channel_height = float(grid_shape[1] - 1)

if compute_backend == ComputeBackend.JAX:

def bc_profile_jax():
y = jnp.arange(grid_shape[1], dtype=jnp.float32)
u_x = jnp.maximum(
0.0,
4.0 * u_peak / channel_height**2 * (y * channel_height - y**2),
)
u_y = jnp.zeros_like(u_x)
return jnp.stack([u_x, u_y])

return bc_profile_jax

wp_dtype = precision_policy.compute_precision.wp_dtype
ch = wp_dtype(channel_height)
four = wp_dtype(4.0)
u_peak_wp = wp_dtype(u_peak)

@wp.func
def bc_profile_warp(index: wp.vec3i):
y = wp_dtype(index[1])
u_x = four * u_peak_wp / (ch * ch) * (y * ch - y * y)
return wp.vec(wp.max(wp_dtype(0.0), u_x), length=1)

return bc_profile_warp


def main() -> None:
xlb.init(
velocity_set=velocity_set,
default_backend=compute_backend,
default_precision_policy=precision_policy,
)

grid = grid_factory(grid_shape, compute_backend=compute_backend)

box = grid.bounding_box_indices()
box_no_edge = grid.bounding_box_indices(remove_edges=True)
inlet = box_no_edge["left"]
outlet = box_no_edge["right"]
walls = [box["bottom"][i] + box["top"][i] for i in range(velocity_set.d)]
walls = np.unique(np.array(walls), axis=-1).tolist()

x = np.arange(grid_shape[0])
y = np.arange(grid_shape[1])
X, Y = np.meshgrid(x, y, indexing="ij")
cx, cy = cylinder_center
cyl_idx = np.where((X - cx) ** 2 + (Y - cy) ** 2 <= cylinder_radius**2)
cylinder = [tuple(cyl_idx[i].tolist()) for i in range(velocity_set.d)]

bc_inlet = RegularizedBC("velocity", profile=bc_profile(), indices=inlet)
bc_walls = HalfwayBounceBackBC(indices=walls)
bc_outlet = ExtrapolationOutflowBC(indices=outlet)
bc_cylinder = HalfwayBounceBackBC(indices=cylinder)
boundary_conditions = [bc_walls, bc_inlet, bc_outlet, bc_cylinder]

stepper = IncompressibleNavierStokesStepper(
grid=grid,
boundary_conditions=boundary_conditions,
collision_type="BGK",
)
f_0, f_1, bc_mask, missing_mask = stepper.prepare_fields()

macro = Macroscopic(
compute_backend=ComputeBackend.JAX,
precision_policy=precision_policy,
velocity_set=xlb.velocity_set.D2Q9(
precision_policy=precision_policy,
compute_backend=ComputeBackend.JAX,
),
)
momentum_transfer = MomentumTransfer(bc_cylinder, compute_backend=compute_backend)

force_stats = {"CL_max": 0.0, "CD_max": 0.0}

def post_process(step: int, f_0, f_1) -> None:
wp.synchronize()

if step > 0.5 * num_steps:
boundary_force = momentum_transfer(f_0, f_1, bc_mask, missing_mask)
drag = boundary_force[0]
lift = boundary_force[1]
cd = 2.0 * drag / (prescribed_vel**2 * diam)
cl = 2.0 * lift / (prescribed_vel**2 * diam)
force_stats["CL_max"] = max(force_stats["CL_max"], float(cl))
force_stats["CD_max"] = max(force_stats["CD_max"], float(cd))
print(f"step={step:7d}, CL={cl: .6f}, CD={cd: .6f}, CL_max={force_stats['CL_max']: .6f}, CD_max={force_stats['CD_max']: .6f}")

if not isinstance(f_0, jnp.ndarray):
# Warp pads 2D domains with a singleton z dimension
f_0 = warp_array_to_jax(f_0)[..., 0]
wp.synchronize()

_, u = macro(f_0)
u_magnitude = jnp.sqrt(u[0] ** 2 + u[1] ** 2)

save_image(u_magnitude, timestep=step, prefix="flow_past_cylinder_2d")
print(f"Post-processed step {step}: saved velocity magnitude (prefix=flow_past_cylinder_2d)")

print(
f"grid_shape={grid_shape}, Re={Re}, omega={omega:.6f}, prescribed_vel={prescribed_vel}, num_steps={num_steps}, backend={compute_backend.name}"
)

start_time = time.time()
for step in range(num_steps):
f_0, f_1 = stepper(f_0, f_1, bc_mask, missing_mask, omega, step)
f_0, f_1 = f_1, f_0

if step % post_process_interval == 0 or step == num_steps - 1:
post_process(step, f_0, f_1)
elapsed = time.time() - start_time
print(f"Completed step {step}. Elapsed for last chunk: {elapsed:.6f} s.")
start_time = time.time()

print(f"Final CL_max={force_stats['CL_max']:.6f}, CD_max={force_stats['CD_max']:.6f}")


if __name__ == "__main__":
main()
29 changes: 20 additions & 9 deletions xlb/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ def register_backend(cls, backend_name):
"""

def decorator(func):
subclass_name = func.__qualname__.split(".")[0]
signature = inspect.signature(func)
unwrapped = inspect.unwrap(func)
qualname = unwrapped.__qualname__
subclass_name = qualname.rsplit(".", 1)[0] if "." in qualname else qualname
signature = inspect.signature(unwrapped)
key = (subclass_name, backend_name, str(signature))
cls._backends[key] = func
return func
Expand All @@ -98,24 +100,26 @@ def __call__(self, *args, callback=None, **kwargs):
------
NotImplementedError
If no implementation is registered for the active backend.
Exception
If all candidate implementations raise errors.
RuntimeError
If all candidate implementations raise errors (chained from the last exception).
"""
method_candidates = [
(key, method) for key, method in self._backends.items() if key[0] == self.__class__.__name__ and key[1] == self.compute_backend
]
if not method_candidates:
supported = [key for key in self._backends.keys() if key[0] == self.__class__.__name__]
raise NotImplementedError(
f"No implementation found for operator {self.__class__.__name__} with backend {self.compute_backend}. "
f"No implementation found for operator {type(self).__qualname__} with backend {self.compute_backend}. "
f"Available implementations: {supported}"
)

bound_arguments = None
key = None
last_key = None
last_method = None
error = None
traceback_str = None
for key, backend_method in method_candidates:
last_key = key
last_method = backend_method
try:
# This attempts to bind the provided args and kwargs to the compute_backend method's signature
bound_arguments = inspect.signature(backend_method).bind(self, *args, **kwargs)
Expand All @@ -129,8 +133,15 @@ def __call__(self, *args, callback=None, **kwargs):
error = e
traceback_str = traceback.format_exc()
continue # This skips to the next candidate if binding fails
method_candidates = [(key, method) for key, method in self._backends.items() if key[1] == self.compute_backend]
raise Exception(f"Error captured for backend with key {key} for operator {self.__class__.__name__}: {error}\n {traceback_str}")

impl_qualname = inspect.unwrap(last_method).__qualname__ if last_method is not None else "unknown"
registered_class = last_key[0] if last_key is not None else "unknown"
instance_class = type(self).__qualname__
if instance_class != registered_class:
instance_note = f", instance={instance_class}"
else:
instance_note = ""
raise RuntimeError(f"{impl_qualname} failed for backend {self.compute_backend}{instance_note}: {error}\n{traceback_str}") from error

@property
def supported_compute_backend(self):
Expand Down
Loading