JAX-NG is a modular JAX framework for second-order optimization of physics-informed neural networks (PINNs), with a focus on Gauss-Newton and natural-gradient-style methods that scale beyond the regime typically accessible to parameter-space solvers.
This repository accompanies the paper:
Dual Natural Gradient Descent for Scalable Training of Physics-Informed Neural Networks
Anas Jnini, Flavio Vella
Published in Transactions on Machine Learning Research (2025)
JAX-NG is designed as a research codebase for building and evaluating second-order PINN training pipelines across multiple PDE benchmarks, with reusable components for models, sampling, line search, optimization, training, and problem definitions.
- Modular JAX implementation for second-order PINN optimization
- Gauss-Newton solvers with automatic primal/dual system selection
- Problem implementations covering elliptic, fluid, and time-dependent PDEs
- Example scripts for reproducing supported experiments
Supported benchmark problems:
- Helmholtz
- Kovasznay
- KdV with windowed hard-IC ansatz
- KS1d with windowed hard-IC ansatz
- Stokes wedge with pressure anchor
- Beltrami 3D (space-time Navier-Stokes benchmark)
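The README does not define the windowed hard-IC ansatz. A common construction, shown here only as an assumption about the idea and not as JAX-NG's implementation, splits the time interval into consecutive windows. In each window [t_k, t_{k+1}] the solution is written as u(x, t) = u_k(x) + (t - t_k) N(x, t), where u_k is the initial condition (or the previous window's solution at t_k). The factor (t - t_k) makes the initial condition hold exactly, so no separate IC loss term is needed. The helper names below are hypothetical, and NumPy stands in for JAX so the sketch runs without the package.

```python
import numpy as np

def make_windows(t_start, t_end, n_windows):
    """Split [t_start, t_end] into n_windows consecutive, equal-length windows."""
    edges = np.linspace(t_start, t_end, n_windows + 1)
    return [(float(lo), float(hi)) for lo, hi in zip(edges[:-1], edges[1:])]

def hard_ic_ansatz(net, u_init, t0):
    """Wrap a raw model so that u(x, t0) == u_init(x) holds by construction.

    net(x, t) and u_init(x) are arbitrary callables (hypothetical stand-ins for
    a network and an initial condition); the factor (t - t0) vanishes at t0.
    """
    def u(x, t):
        return u_init(x) + (t - t0) * net(x, t)
    return u

# Usage: two windows on [0, 1] with placeholder "networks".
windows = make_windows(0.0, 1.0, 2)
u0 = lambda x: np.cos(np.pi * x)
u_a = hard_ic_ansatz(lambda x, t: np.sin(x) * t, u0, t0=windows[0][0])

# The second window starts from the first window's solution at its end time.
t_mid = windows[0][1]
u_b = hard_ic_ansatz(lambda x, t: np.ones_like(x), lambda x: u_a(x, t_mid), t0=t_mid)
```

Chaining windows this way keeps the solution continuous across window boundaries while each window is trained on a shorter time span.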
Example scripts include:
jax_ng/examples/helmholtz_gn.py
jax_ng/examples/kovasznay_gn.py
jax_ng/examples/kdv_windowed_gn.py
jax_ng/examples/beltrami_gn.py
jax_ng/
    models/       # activations, initialization, MLPs, jets
    samplers/     # box and triangle/wedge samplers
    linesearch/   # grid, Armijo, Wolfe, fixed-step rules
    optimizers/   # gauss_newton, multistage, windowed_gn, stokes_gn
    problems/     # helmholtz, kovasznay, kdv, ks1d, stokes_wedge, beltrami
    utils/        # trainer, metrics, checkpointing, plotting
    examples/     # runnable scripts
    tests/        # pytest suite
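The linesearch/ module lists grid, Armijo, Wolfe, and fixed-step rules. The NumPy sketch below shows textbook versions of the first two, given a loss, a current parameter vector, and a search direction such as a Gauss-Newton step. The geometric grid of step sizes, the constants, and the function names are assumptions for illustration; they are not the signatures behind JAX-NG's linesearch.build.

```python
import numpy as np

def grid_search(loss_fn, params, direction, n_steps=12):
    """Evaluate step sizes 1, 1/2, ..., 2**-(n_steps - 1) and keep the best one."""
    steps = 0.5 ** np.arange(n_steps)
    losses = [loss_fn(params + eta * direction) for eta in steps]
    return float(steps[int(np.argmin(losses))])

def armijo_backtracking(loss_fn, grad, params, direction,
                        eta0=1.0, c=1e-4, shrink=0.5, max_iter=30):
    """Shrink the step until the Armijo sufficient-decrease condition holds."""
    f0 = loss_fn(params)
    slope = float(grad @ direction)  # directional derivative, < 0 for descent
    eta = eta0
    for _ in range(max_iter):
        if loss_fn(params + eta * direction) <= f0 + c * eta * slope:
            return eta
        eta *= shrink
    # If max_iter is exhausted, return the very small step eta0 * shrink**max_iter.
    return eta

# Usage on a quadratic whose minimiser is p = (3, 3).
loss = lambda p: 0.5 * float(np.sum((p - 3.0) ** 2))
p0 = np.zeros(2)
g0 = p0 - 3.0    # gradient of the loss at p0
d = -4.0 * g0    # a descent direction that overshoots the minimiser
eta_grid = grid_search(loss, p0, d)
eta_armijo = armijo_backtracking(loss, g0, p0, d)
```

The grid rule costs a fixed number of loss evaluations per iteration, while Armijo backtracking stops at the first acceptable step.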
For a clean setup, we recommend creating a fresh Conda environment first:
conda create -n jax-ng python=3.11 -y
conda activate jax-ng
git clone https://github.com/HicrestLaboratory/JAX-NG.git
cd JAX-NG
pip install -e .
If the repository includes a requirements.txt file, install it before the editable package install:
pip install -r requirements.txt
pip install -e .
If you are using a CPU-only environment, a common setup is:
pip install --upgrade "jax[cpu]"
pip install -e .
Python imports follow the package namespace:
import jax_ng
Some examples expect external data files under ./data:
kdv.mat
ks_chaotic.mat
st_flow.csv
Make sure these files are available before running the corresponding scripts.
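A small pre-flight check can report which of these files are missing before a long run starts. The helper below is a hypothetical convenience, not part of jax_ng; it only uses the file names listed above.

```python
from pathlib import Path

# File names as listed in this README.
REQUIRED_DATA = ("kdv.mat", "ks_chaotic.mat", "st_flow.csv")

def missing_data_files(data_dir="./data", names=REQUIRED_DATA):
    """Return the expected data files that are not regular files in data_dir."""
    root = Path(data_dir)
    return [name for name in names if not (root / name).is_file()]

missing = missing_data_files()
print("missing:", missing if missing else "none")
```

Run it from the repository root, since ./data is resolved relative to the current working directory.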
From the repository root, the bundled examples can be launched with:
python -m jax_ng.examples.helmholtz_gn
python -m jax_ng.examples.kovasznay_gn
python -m jax_ng.examples.kdv_windowed_gn
python -m jax_ng.examples.ks1d_windowed_gn
python -m jax_ng.examples.stokes_gn
The KdV, KS1d, and Stokes examples require the corresponding files in ./data.
Install pytest if needed:
pip install pytest
Then run the full test suite from the repository root:
pytest
You can also target the package tests directly:
pytest jax_ng/tests
For more verbose output:
pytest -v
After installation, the following quick checks are useful:
python -c "import jax_ng; print('jax_ng import OK')"
python -c "import jax; print(jax.__version__)"
The snippet below illustrates the basic workflow for defining a problem, sampling collocation points, constructing a Gauss-Newton optimizer, and running training.
import jax
import jax.numpy as jnp
from jax import random
from jax_ng import linesearch, models, optimizers, samplers, utils
jax.config.update("jax_enable_x64", True)
class SimplePoisson1D:
    """1D Poisson problem u'' = f on [-1, 1] with exact solution u = sin(pi x)."""

    def exact_u(self, x):
        return jnp.sin(jnp.pi * x[0])

    def forcing(self, x):
        # f = u'' for u = sin(pi x)
        return -(jnp.pi ** 2) * jnp.sin(jnp.pi * x[0])

    def interior_res(self, params, x):
        # PDE residual: u''(x) - f(x)
        _, lap_u = models.jet_laplacian(params, x)
        return lap_u[0] - self.forcing(x)

    def boundary_res(self, params, x):
        # Dirichlet residual: u(x) - u_exact(x)
        u, _ = models.jet_laplacian(params, x)
        return u[0] - self.exact_u(x)

    def init_params(self, key):
        sizes = models.layer_sizes(input_dim=1, width=32, depth=3, output_dim=1)
        return models.glorot_init(sizes, key)
pde = SimplePoisson1D()
params = pde.init_params(random.PRNGKey(0))
sampler = lambda key: samplers.uniform_box(key, 256, 64, ((-1.0, 1.0),))
ls = linesearch.build("grid_search", n_steps=12)
opt = optimizers.GaussNewton(
interior_res_fn=pde.interior_res,
boundary_res_fn=pde.boundary_res,
sampler_fn=sampler,
linesearch_fn=ls,
solve_config=optimizers.SolveConfig(mode="auto", damping=1e-8),
)
trainer = utils.Trainer(opt, n_iters=300, log_interval=50)
params, history = trainer.run(params, random.PRNGKey(1))
JAX-NG supports automatic selection between primal and dual linear systems through:
optimizers.SolveConfig(mode="auto")
The current rule is:
- dual when N_params > N_residuals
- primal otherwise
You can also force the system explicitly with:
mode="dual"
mode="primal"
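The selection rule relies on a standard identity. For a Jacobian J of shape N_residuals x N_params, a residual vector r, and damping lam > 0, the damped Gauss-Newton step satisfies (J^T J + lam I)^{-1} J^T r = J^T (J J^T + lam I)^{-1} r. The primal form solves an N_params x N_params system and the dual form an N_residuals x N_residuals one, so the smaller system can be chosen. The dense NumPy sketch below checks this; the function gauss_newton_step is hypothetical, and reading SolveConfig's damping as lam is an assumption.

```python
import numpy as np

def gauss_newton_step(J, r, damping=1e-8, mode="auto"):
    """Step minimising 0.5*||r + J @ delta||^2 + 0.5*damping*||delta||^2.

    J has shape (N_residuals, N_params). "auto" mirrors the rule above: solve
    the N x N dual system when N_params > N_residuals, else the P x P primal one.
    """
    n_res, n_params = J.shape
    if mode == "auto":
        mode = "dual" if n_params > n_res else "primal"
    if mode == "dual":
        # (J J^T + damping I) y = r, then delta = -J^T y
        y = np.linalg.solve(J @ J.T + damping * np.eye(n_res), r)
        return -J.T @ y
    if mode == "primal":
        # (J^T J + damping I) delta = -J^T r
        return np.linalg.solve(J.T @ J + damping * np.eye(n_params), -J.T @ r)
    raise ValueError(f"unknown mode: {mode!r}")

# Usage: more parameters than residuals, so "auto" picks the dual system.
rng = np.random.default_rng(0)
J = rng.standard_normal((20, 50))   # 20 residuals, 50 parameters
r = rng.standard_normal(20)
step_primal = gauss_newton_step(J, r, damping=1e-2, mode="primal")
step_dual = gauss_newton_step(J, r, damping=1e-2, mode="dual")
```

At PINN scale one would normally avoid forming dense matrices this way; the sketch only illustrates why the two systems give the same step.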
This codebase is intended as a companion research repository for the paper and as a starting point for:
- reproducing reported experiments
- extending second-order PINN optimizers
- adding new PDE benchmarks and sampling schemes
- experimenting with primal (parameter-space) and dual (residual-space) Gauss-Newton formulations
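As an example of the kind of sampling scheme covered by the samplers/ module (box and triangle/wedge), the sketch below draws collocation points uniformly from a triangle using the standard reflection trick. It is a generic NumPy illustration with a hypothetical function name; JAX-NG's own triangle/wedge sampler may be implemented differently.

```python
import numpy as np

def sample_triangle(rng, n, a, b, c):
    """Draw n points uniformly from the triangle with vertices a, b, c."""
    a, b, c = (np.asarray(v, dtype=float) for v in (a, b, c))
    u = rng.random(n)
    v = rng.random(n)
    # Points with u + v > 1 fall outside the triangle; reflecting them back
    # through the diagonal of the unit square keeps the distribution uniform.
    flip = u + v > 1.0
    u[flip], v[flip] = 1.0 - u[flip], 1.0 - v[flip]
    return a + u[:, None] * (b - a) + v[:, None] * (c - a)

# Usage: 1000 interior points in the reference triangle (0,0), (1,0), (0,1).
rng = np.random.default_rng(0)
pts = sample_triangle(rng, 1000, (0.0, 0.0), (1.0, 0.0), (0.0, 1.0))
```

Unlike rejection sampling, the reflection keeps every draw, so exactly n points are produced per call.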
If you use this repository in academic work, please cite the accompanying paper:
@article{jnini2025dual,
title = {Dual Natural Gradient Descent for Scalable Training of Physics-Informed Neural Networks},
author = {Jnini, Anas and Vella, Flavio},
journal = {Transactions on Machine Learning Research},
year = {2025}
}
JAX-NG is an active research-oriented codebase. The repository currently supports the optimizers, problems, and examples listed above, with the structure intentionally kept modular for further extensions.