A small Equinox-based collection of JAX models for operator learning and PDE surrogates: a handful of core architectures, the KAN family, and wrappers around eight vendored foundation models. Plays nicely with jNO.
Early days — APIs may shift between minor versions.
pip install foundax

Development setup uses pixi; see CONTRIBUTING.md.
Integration With jNO
import foundax as fx
import jno
import optax
net = jno.nn.wrap(fx.poseidon.T(num_channels=5, num_out_channels=1))
net.optimizer(
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(
            learning_rate=optax.schedules.warmup_cosine_decay_schedule(
                init_value=1e-7,
                peak_value=1e-3,
                warmup_steps=500,
                decay_steps=10000,
                end_value=1e-6,
            ),
            weight_decay=1e-4,
        ),
    )
)
net.initialize('./poseidonT.eqx')
net.mask(param_mask).lora(rank=4)  # param_mask must be defined by the caller; it is not created in this snippet

Full list with paper references: docs/architectures.md. Numerical parity against the upstream PyTorch classes is documented in the parity verification table: max abs diff ≤ 2e-4 for the 6 architectures with comparable references (WNO uses a different wavelet algorithm). Reproduce with pixi run verify-<name>.

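The parity criterion above is a maximum absolute difference between the output of the JAX port and the output of the upstream PyTorch class on the same input. The following is a minimal sketch of that comparison, assuming both outputs have already been converted to NumPy arrays. The helper name max_abs_diff and the stand-in arrays are illustrative and are not taken from the project's verify scripts.

```python
import numpy as np

TOLERANCE = 2e-4  # threshold quoted in the parity statement above


def max_abs_diff(jax_out, torch_out):
    """Largest elementwise absolute difference between two same-shaped outputs."""
    a = np.asarray(jax_out, dtype=np.float64)
    b = np.asarray(torch_out, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return float(np.max(np.abs(a - b)))


# Stand-in data (hypothetical): a reference output and a copy shifted by 1e-5.
rng = np.random.default_rng(0)
reference = rng.standard_normal((4, 64, 64)).astype(np.float32)
ported = reference + np.float32(1e-5)

diff = max_abs_diff(ported, reference)
passes = diff <= TOLERANCE
```

Comparing in float64 keeps the subtraction itself from adding rounding error to the reported difference.
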
| Family | Constructors | Reference |
|---|---|---|
| Linear / MLP | fx.linear, fx.mlp | - |
| Fourier Neural Operator | fx.fno1d/2d/3d | Li et al. 2020, arXiv:2010.08895 |
| U-Net | fx.unet1d/2d/3d | Ronneberger et al. 2015, arXiv:1505.04597 |
| Generic transformer | fx.transformer | Vaswani et al. 2017, arXiv:1706.03762 |
| DeepONet | fx.deeponet | Lu et al. 2019, arXiv:1910.03193 |
| Convolutional Neural Operator | fx.cno2d | Raonić et al. 2023, arXiv:2302.01178 |
| Multigrid Neural Operator | fx.mgno1d/2d | He et al. 2023, arXiv:2310.19809 |
| Geometry-aware FNO | fx.geofno | Li et al. 2022, arXiv:2207.05209 |
| Point-Cloud Neural Operator | fx.pcno | PKU-CMEGroup/NeuralOperator |
| Position-induced Transformer | fx.pit | Chen & Wu 2024, arXiv:2405.09285 |
| PointNet | fx.pointnet | Qi et al. 2017, arXiv:1612.00593 |
| GNOT family | fx.gnot, fx.cgptno, fx.moegptno | Hao et al., ICML 2023, arXiv:2302.14376 |
| Diffusion Transformer (DiT) | fx.dit2d/3d | Peebles & Xie 2022, arXiv:2212.09748 |
| Factorized FNO | fx.ffno2d/3d | Tran et al. 2023, arXiv:2111.13802 |
| Wavelet Neural Operator | fx.wno1d/2d/3d | Tripura & Chakraborty 2022, arXiv:2205.02191 |
| Transolver | fx.transolver, fx.transolver2d/3d | Wu et al., ICML 2024, arXiv:2402.02366 |
| Spherical FNO | fx.sfno2d | Bonev et al., ICML 2023, arXiv:2306.03838 |
| GAOT | fx.gaot.S/M/L, fx.gaot_S/M/L | Gao et al., NeurIPS 2025, arXiv:2505.18781 |

| Factory | Basis | Reference |
|---|---|---|
| fx.kan | B-spline + SiLU residual | Liu et al. 2024, arXiv:2404.19756 |
| fx.kan.efficient | B-spline (memory-optimised) | Blealtan 2024, github.com/Blealtan/efficient-kan |
| fx.kan.fast | Gaussian RBF | Li 2024, arXiv:2405.06721 |
| fx.kan.fourier | sin/cos series | GistNoesis 2024, github.com/GistNoesis/FourierKAN |
| fx.kan.chebyshev | Chebyshev T_n | SS 2024, arXiv:2405.07200 |
| fx.kan.jacobi | Jacobi P_n^(α,β) | Aghaei 2024 (fKAN), arXiv:2406.07456 |
| fx.kan.legendre | Legendre P_n | Seydi 2024, arXiv:2406.02583 |
| fx.kan.wavelet | Mexican hat / Morlet / Shannon / DoG | Bozorgasl & Chen 2024 (Wav-KAN), arXiv:2405.12832 |
| fx.kan.taylor | Truncated power series | github.com/Muyuzhierchengse/TaylorKAN |
| fx.kan.hermite | Hermite He_n | Seydi 2024, arXiv:2406.02583 |
| fx.kan.laguerre | Laguerre L_n | Seydi 2024, arXiv:2406.02583 |
| fx.kan.bernstein | Bernstein polynomials | Seydi 2024, arXiv:2406.02583 |
| fx.kan.relu | (ReLU·ReLU)^order on a grid | Qiu et al. 2024, arXiv:2406.02075 |
| fx.kan.rational | Padé-style rational Chebyshev | Aghaei 2024 (rKAN), arXiv:2406.14495 |
| fx.kan.sinc | sinc basis on a grid | Yu et al. 2024 (SincKAN), arXiv:2410.04096 |
| fx.kan.gram | Orthonormal Legendre (Gram limit) | Seydi 2024, arXiv:2406.02583 |
| fx.kan.bsrbf | B-spline + RBF concatenation | Ta 2024 (BSRBF-KAN), arXiv:2406.11173 |

| Namespace | Variants | Backbone | Reference |
|---|---|---|---|
| fx.poseidon | T, B, L | ScOT (Swin operator transformer) | Herde et al. 2024, arXiv:2405.19101 |
| fx.morph | Ti, S, M, L | ViT3D regression | Rautela et al. 2025, arXiv:2509.21670 |
| fx.mpp | Ti, S, B, L | AViT (axial ViT) | McCabe et al., NeurIPS 2024, openreview/DKSI3bULiZ |
| fx.walrus | base | Encoder-processor-decoder (1.29B) | McCabe et al. 2025, arXiv:2511.15684 |
| fx.bcat | base | Block-causal transformer | Liu et al. 2025, arXiv:2501.18972 |
| fx.pdeformer2 | small, base, fast | Graphormer + INR | Ye et al. 2025, arXiv:2507.15409 |
| fx.dpot | Ti, S, M, L, H | DPOTNet (AFNO) | Hao et al., ICML 2024, arXiv:2403.03542 |
| fx.prose | fd_1to1, fd_2to1, ode_2to1, pde_2to1 | Seq-to-seq transformer | Liu et al. 2023, arXiv:2309.16816; follow-up Sun et al. 2024, arXiv:2404.12355 |
| fx.timesfm | small | Decoder-only transformer (time-series, 200M, Flax NNX wrap; jit + fine-tuning) | Das et al. 2024, arXiv:2310.10688 |

Pretrained weights keep their upstream licenses — see THIRD_PARTY_LICENSES.
import foundax as fx
# Core models
model = fx.mlp(in_features=2, output_dim=1, hidden_dims=64, num_layers=3)
model = fx.fno2d(in_features=1, hidden_channels=32, n_modes=16)
model = fx.unet2d(in_channels=1, out_channels=1)
model = fx.deeponet(branch_type="mlp", trunk_type="mlp")
# KAN family (one of 17 variants)
model = fx.kan.fast(in_features=2, output_dim=1, hidden_dims=64, num_layers=3)
# Foundation wrappers (namespace style)
model = fx.poseidon.T() # T/B/L
model = fx.morph.S() # Ti/S/M/L
model = fx.mpp.B(n_states=12) # Ti/S/B/L
model = fx.walrus.base()
model = fx.bcat.base()
model = fx.pdeformer2.small() # small/base/fast
model = fx.dpot.Ti() # Ti/S/M/L/H
model, variables = fx.prose.fd_1to1()
# TimesFM 2.5 — time-series foundation model (Flax NNX wrap, 200M params)
# - input/output: channel-last unbatched (context, 1) → (horizon, 1);
# also accepts (B, context, 1) → (B, horizon, 1)
# - horizon fixed at construction so JIT can specialise on it
# - context must be a multiple of 32 (input patch size)
import jax.numpy as jnp
model = fx.timesfm.small(horizon=24) # downloads checkpoint from HF on first call
y = model(jnp.zeros((512, 1))) # (24, 1)

import equinox as eqx
import optax
import jax.numpy as jnp
import foundax as fx
model = fx.timesfm.small(horizon=24)
# JIT — compiles once per input shape
fast = eqx.filter_jit(model)
y = fast(jnp.zeros((512, 1)))
# Fine-tuning — gradients flow through the full ~231M-param state
optimizer = optax.adamw(1e-5)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
@eqx.filter_jit
def step(model, opt_state, x, y):
    def loss_fn(m):
        return jnp.mean((m(x) - y) ** 2)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    return eqx.apply_updates(model, updates), opt_state, loss

For .eqx checkpoint serialisation (and the training/inference helpers built on top of it), see jNO: jno.nn.wrap(fx.timesfm.small(horizon=24)).initialize('./timesfm.eqx').

Caveats:
- jax.vmap over the wrapper is not supported: the upstream decode() uses nnx.scan for the per-layer carry, which conflicts with vmap's trace level. Use the explicit (B, context, 1) batched form instead (JIT specialises per batch shape; equally efficient).
- TimesFM is strictly univariate: the channel axis is always size 1.
- Each distinct horizon triggers a separate JIT trace; build one model per horizon you care about.

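Two of the constraints above (context length a multiple of 32, and batching through an explicit leading axis instead of vmap) can be met with a small preprocessing step. The sketch below is one way to do it with NumPy. It assumes that dropping the oldest samples is acceptable for your data. The helper name to_batched_context is hypothetical and is not part of foundax.

```python
import numpy as np

PATCH = 32  # input patch size stated in the notes above


def to_batched_context(series_list, patch=PATCH):
    """Trim each 1-D series to a shared length that is a multiple of `patch`
    (keeping the most recent samples) and stack into (B, context, 1)."""
    shortest = min(len(s) for s in series_list)
    context = (shortest // patch) * patch
    if context == 0:
        raise ValueError(f"every series needs at least {patch} samples")
    trimmed = [np.asarray(s, dtype=np.float32)[-context:] for s in series_list]
    # Stack to (B, context), then add the size-1 channel axis.
    return np.stack(trimmed)[..., None]


# Three series of different lengths; the shortest (512) sets the context.
batch = to_batched_context([np.arange(530.0), np.arange(600.0), np.arange(512.0)])
```

Under those assumptions, the resulting array has the (B, context, 1) layout described above and could be converted with jnp.asarray before being passed to the model. Padding instead of trimming is an alternative, but it adds values that were never observed.
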
Wrap any model or layer with fx.block() and chain them with |.
Channel mismatches are caught at construction time with a clear error message.
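To illustrate what a construction-time channel check means, here is a toy pure-Python sketch of a | operator that validates channel counts when two blocks are chained, before any data flows through. The Block class, its attributes, and the error text are assumptions made for illustration. They do not describe how fx.block is implemented.

```python
class Block:
    """Toy stand-in for a wrapped layer that records its channel counts."""

    def __init__(self, fn, in_channels, out_channels, name=None):
        self.fn = fn
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.name = name or fn.__name__

    def __or__(self, other):
        # Validate when the pipe is built, before any data is passed through.
        if self.out_channels != other.in_channels:
            raise ValueError(
                f"channel mismatch: '{self.name}' outputs {self.out_channels} "
                f"channels but '{other.name}' expects {other.in_channels}"
            )
        return Block(
            lambda x: other.fn(self.fn(x)),
            self.in_channels,
            other.out_channels,
            name=f"{self.name}|{other.name}",
        )

    def __call__(self, x):
        return self.fn(x)


def widen(x):
    """1 -> 2 channels: duplicate the single channel."""
    return [x[0], x[0]]


def collapse(x):
    """2 -> 1 channels: sum the two channels."""
    return [x[0] + x[1]]


pipe = Block(widen, 1, 2) | Block(collapse, 2, 1)
result = pipe([3.0])  # [6.0]

# Chaining a 2-channel output into a 1-channel input fails immediately.
try:
    Block(widen, 1, 2) | Block(widen, 1, 2)
    error = None
except ValueError as exc:
    error = str(exc)
```

Checking at composition time means a wiring mistake shows up as a readable error when the pipeline is assembled, not as a shape error inside a traced or jitted call.
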
import jax
import foundax as fx
ks = jax.random.split(jax.random.PRNGKey(0), 8)
# ── Build a 2-D FNO-style pipeline from individual spectral layers ──────────
lift = fx.block(fx.layers.SpectralBlock2d(1, 32, n_modes=16, key=ks[0]), name="lift")
s1 = fx.block(fx.layers.SpectralBlock2d(32, 32, n_modes=16, key=ks[1]))
s2 = fx.block(fx.layers.SpectralBlock2d(32, 32, n_modes=16, key=ks[2]))
s3 = fx.block(fx.layers.SpectralBlock2d(32, 32, n_modes=16, key=ks[3]))
project = fx.block(fx.layers.SpectralBlock2d(32, 1, n_modes=16, key=ks[4]), name="project")
model = lift | s1 | s2 | s3 | project # Pipe of 5 blocks
# ── Existing full models work as blocks too ──────────────────────────────────
encoder = fx.block(fx.fno2d(in_features=3, hidden_channels=32, n_modes=16, key=ks[5]))
decoder = fx.block(fx.layers.SpectralBlock2d(32, 1, n_modes=16, key=ks[6]))
model = encoder | decoder
# ── Multi-input combinators (DeepONet-style) ─────────────────────────────────
branch = (
    fx.block(fx.layers.SpectralBlock1d(1, 32, n_modes=16, key=ks[0]))
    | fx.block(fx.mlp(in_features=32, output_dim=64, hidden_dims=64, key=ks[1]))
)
trunk = fx.block(fx.mlp(in_features=2, output_dim=64, hidden_dims=64, key=ks[2]))
model = fx.dot(branch, trunk) # branch(u) · trunk(y) → (N_pts,)
# Also available: fx.add(a, b) — elementwise sum of two branches
# fx.cat(a, b) — concatenate outputs along the channel axis
# ── All pipe models are plain Equinox modules ────────────────────────────────
import equinox as eqx, optax, jax.numpy as jnp
opt = optax.adam(1e-3)
state = opt.init(eqx.filter(model, eqx.is_array))
@eqx.filter_jit
def step(model, state, u, y, target):
    loss, grads = eqx.filter_value_and_grad(
        lambda m: jnp.mean((m(u, y) - target) ** 2)
    )(model)
    updates, state = opt.update(grads, state, eqx.filter(model, eqx.is_array))
    return eqx.apply_updates(model, updates), state, loss

If you use foundax in academic work, the accompanying paper is the jNO preprint (arXiv:2605.10159). A machine-readable CITATION.cff is provided.

Parts of this codebase — including model ports, tests, and documentation — were developed with the assistance of AI coding tools. All contributions are reviewed and tested to the best of our ability, but mistakes may remain; please open an issue if you spot one.
EPL-2.0 — see LICENSE. Vendored foundation-model code and pretrained weights keep their original licenses (see THIRD_PARTY_LICENSES); Poseidon weights are non-commercial.
