FhG-IISB/foundax

A small Equinox-based collection of JAX models for operator learning and PDE surrogates: a handful of core architectures, the KAN family, and wrappers around eight vendored foundation models. Plays nicely with jNO.

Early days — APIs may shift between minor versions.

Install

pip install foundax

Development setup uses pixi — see CONTRIBUTING.md.

Integration With jNO

import foundax as fx
import jno
import optax

net = jno.nn.wrap(fx.poseidon.T(num_channels=5, num_out_channels=1))
net.optimizer(
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(
            learning_rate=optax.schedules.warmup_cosine_decay_schedule(
                init_value=1e-7,
                peak_value=1e-3,
                warmup_steps=500,
                decay_steps=10000,
                end_value=1e-6,
            ),
            weight_decay=1e-4,
        ),
    )
)
net.initialize('./poseidonT.eqx')
# param_mask is not defined in this snippet; create it before this call.
net.mask(param_mask).lora(rank=4)
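
To see what the learning-rate schedule in this snippet does, here is a minimal pure-Python sketch of a linear warmup followed by a cosine decay. It follows our understanding of optax's warmup_cosine_decay_schedule, in which decay_steps counts the whole schedule including the warmup; the function name warmup_cosine_lr is ours and is not part of optax, jNO, or foundax.

```python
import math


def warmup_cosine_lr(step, init_value=1e-7, peak_value=1e-3,
                     warmup_steps=500, decay_steps=10_000, end_value=1e-6):
    """Hypothetical re-implementation, for illustration only."""
    if step < warmup_steps:
        # Linear ramp from init_value to peak_value.
        return init_value + (peak_value - init_value) * step / warmup_steps
    # Cosine decay over the remaining steps, clamped once the schedule ends.
    cosine_steps = decay_steps - warmup_steps
    progress = min(step - warmup_steps, cosine_steps) / cosine_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine


# Print the learning rate at a few representative steps.
for s in (0, 250, 500, 5_250, 10_000):
    print(s, warmup_cosine_lr(s))
```

The rate starts at init_value, reaches peak_value at step 500, and settles at end_value from step 10000 onwards.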

Supported architectures

Full list with paper references: docs/architectures.md. Numerical parity against the upstream PyTorch classes is documented in the parity verification table: the maximum absolute difference is ≤ 2e-4 for the 6 architectures with comparable references; WNO uses a different wavelet algorithm. Reproduce any entry with pixi run verify-<name>.
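
The parity criterion quoted here boils down to a maximum absolute difference between two sets of outputs. A minimal sketch of that check follows, assuming both outputs have already been flattened to lists of floats. The helper max_abs_diff and the sample values are ours; this is not the repository's verification script.

```python
def max_abs_diff(a, b):
    """Largest elementwise |a_i - b_i| between two equal-length sequences."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return max(abs(x - y) for x, y in zip(a, b))


# Made-up numbers standing in for flattened model outputs.
jax_out = [0.1000, -0.2500, 1.2000]
torch_out = [0.1001, -0.2499, 1.2000]

TOLERANCE = 2e-4  # threshold quoted in the text
diff = max_abs_diff(jax_out, torch_out)
print(f"max abs diff = {diff:.2e}, within tolerance: {diff <= TOLERANCE}")
```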

Core architectures

| Family | Constructors | Reference |
| --- | --- | --- |
| Linear / MLP | fx.linear, fx.mlp | |
| Fourier Neural Operator | fx.fno1d/2d/3d | Li et al. 2020, arXiv:2010.08895 |
| U-Net | fx.unet1d/2d/3d | Ronneberger et al. 2015, arXiv:1505.04597 |
| Generic transformer | fx.transformer | Vaswani et al. 2017, arXiv:1706.03762 |
| DeepONet | fx.deeponet | Lu et al. 2019, arXiv:1910.03193 |
| Convolutional Neural Operator | fx.cno2d | Raonić et al. 2023, arXiv:2302.01178 |
| Multigrid Neural Operator | fx.mgno1d/2d | He et al. 2023, arXiv:2310.19809 |
| Geometry-aware FNO | fx.geofno | Li et al. 2022, arXiv:2207.05209 |
| Point-Cloud Neural Operator | fx.pcno | PKU-CMEGroup/NeuralOperator |
| Position-induced Transformer | fx.pit | Chen & Wu 2024, arXiv:2405.09285 |
| PointNet | fx.pointnet | Qi et al. 2017, arXiv:1612.00593 |
| GNOT family | fx.gnot, fx.cgptno, fx.moegptno | Hao et al., ICML 2023, arXiv:2302.14376 |
| Diffusion Transformer (DiT) | fx.dit2d/3d | Peebles & Xie 2022, arXiv:2212.09748 |
| Factorized FNO | fx.ffno2d/3d | Tran et al. 2023, arXiv:2111.13802 |
| Wavelet Neural Operator | fx.wno1d/2d/3d | Tripura & Chakraborty 2022, arXiv:2205.02191 |
| Transolver | fx.transolver, fx.transolver2d/3d | Wu et al., ICML 2024, arXiv:2402.02366 |
| Spherical FNO | fx.sfno2d | Bonev et al., ICML 2023, arXiv:2306.03838 |
| GAOT | fx.gaot.S/M/L, fx.gaot_S/M/L | Gao et al., NeurIPS 2025, arXiv:2505.18781 |

Kolmogorov–Arnold Networks

| Factory | Basis | Reference |
| --- | --- | --- |
| fx.kan | B-spline + SiLU residual | Liu et al. 2024, arXiv:2404.19756 |
| fx.kan.efficient | B-spline (memory-optimised) | Blealtan 2024, github.com/Blealtan/efficient-kan |
| fx.kan.fast | Gaussian RBF | Li 2024, arXiv:2405.06721 |
| fx.kan.fourier | sin/cos series | GistNoesis 2024, github.com/GistNoesis/FourierKAN |
| fx.kan.chebyshev | Chebyshev T_n | SS 2024, arXiv:2405.07200 |
| fx.kan.jacobi | Jacobi P_n^(α,β) | Aghaei 2024 (fKAN), arXiv:2406.07456 |
| fx.kan.legendre | Legendre P_n | Seydi 2024, arXiv:2406.02583 |
| fx.kan.wavelet | Mexican hat / Morlet / Shannon / DoG | Bozorgasl & Chen 2024 (Wav-KAN), arXiv:2405.12832 |
| fx.kan.taylor | Truncated power series | github.com/Muyuzhierchengse/TaylorKAN |
| fx.kan.hermite | Hermite He_n | Seydi 2024, arXiv:2406.02583 |
| fx.kan.laguerre | Laguerre L_n | Seydi 2024, arXiv:2406.02583 |
| fx.kan.bernstein | Bernstein polynomials | Seydi 2024, arXiv:2406.02583 |
| fx.kan.relu | (ReLU·ReLU)^order on a grid | Qiu et al. 2024, arXiv:2406.02075 |
| fx.kan.rational | Padé-style rational Chebyshev | Aghaei 2024 (rKAN), arXiv:2406.14495 |
| fx.kan.sinc | sinc basis on a grid | Yu et al. 2024 (SincKAN), arXiv:2410.04096 |
| fx.kan.gram | Orthonormal Legendre (Gram limit) | Seydi 2024, arXiv:2406.02583 |
| fx.kan.bsrbf | B-spline + RBF concatenation | Ta 2024 (BSRBF-KAN), arXiv:2406.11173 |

Foundation-model wrappers

| Namespace | Variants | Backbone | Reference |
| --- | --- | --- | --- |
| fx.poseidon | T, B, L | ScOT (Swin operator transformer) | Herde et al. 2024, arXiv:2405.19101 |
| fx.morph | Ti, S, M, L | ViT3D regression | Rautela et al. 2025, arXiv:2509.21670 |
| fx.mpp | Ti, S, B, L | AViT (axial ViT) | McCabe et al., NeurIPS 2024, openreview/DKSI3bULiZ |
| fx.walrus | base | Encoder-processor-decoder (1.29B) | McCabe et al. 2025, arXiv:2511.15684 |
| fx.bcat | base | Block-causal transformer | Liu et al. 2025, arXiv:2501.18972 |
| fx.pdeformer2 | small, base, fast | Graphormer + INR | Ye et al. 2025, arXiv:2507.15409 |
| fx.dpot | Ti, S, M, L, H | DPOTNet (AFNO) | Hao et al., ICML 2024, arXiv:2403.03542 |
| fx.prose | fd_1to1, fd_2to1, ode_2to1, pde_2to1 | Seq-to-seq transformer | Liu et al. 2023, arXiv:2309.16816; follow-up Sun et al. 2024, arXiv:2404.12355 |
| fx.timesfm | small | Decoder-only transformer (time-series, 200M, Flax NNX wrap; jit + fine-tuning) | Das et al. 2024, arXiv:2310.10688 |

Pretrained weights keep their upstream licenses — see THIRD_PARTY_LICENSES.

Quick Start

import foundax as fx

# Core models
model = fx.mlp(in_features=2, output_dim=1, hidden_dims=64, num_layers=3)
model = fx.fno2d(in_features=1, hidden_channels=32, n_modes=16)
model = fx.unet2d(in_channels=1, out_channels=1)
model = fx.deeponet(branch_type="mlp", trunk_type="mlp")

# KAN family (one of 17 variants)
model = fx.kan.fast(in_features=2, output_dim=1, hidden_dims=64, num_layers=3)

# Foundation wrappers (namespace style)
model = fx.poseidon.T()           # T/B/L
model = fx.morph.S()              # Ti/S/M/L
model = fx.mpp.B(n_states=12)     # Ti/S/B/L
model = fx.walrus.base()
model = fx.bcat.base()
model = fx.pdeformer2.small()     # small/base/fast
model = fx.dpot.Ti()              # Ti/S/M/L/H
model, variables = fx.prose.fd_1to1()

# TimesFM 2.5 — time-series foundation model (Flax NNX wrap, 200M params)
# - input/output: channel-last unbatched (context, 1) → (horizon, 1);
#   also accepts (B, context, 1) → (B, horizon, 1)
# - horizon fixed at construction so JIT can specialise on it
# - context must be a multiple of 32 (input patch size)
import jax.numpy as jnp
model = fx.timesfm.small(horizon=24)    # downloads checkpoint from HF on first call
y = model(jnp.zeros((512, 1)))          # (24, 1)
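
Because the context length must be a multiple of 32, a series of arbitrary length needs trimming (or padding) before it reaches the model. One option, sketched here, is to keep only the most recent samples that fill a whole number of patches. The helper trim_to_patch_multiple is hypothetical and not part of foundax; whether trimming or padding suits your data is a modelling choice.

```python
PATCH = 32  # input patch size stated in the comments above


def trim_to_patch_multiple(series, patch=PATCH):
    """Keep the most recent samples so that len(result) % patch == 0."""
    usable = (len(series) // patch) * patch
    if usable == 0:
        raise ValueError(f"need at least {patch} samples, got {len(series)}")
    return series[len(series) - usable:]


history = list(range(530))          # 530 samples: not a multiple of 32
context = trim_to_patch_multiple(history)
print(len(context), context[0], context[-1])
```

The same slicing should work on a NumPy or JAX array of shape (n, 1), since it only touches the leading axis.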

TimesFM extras (JIT, fine-tuning)

import equinox as eqx, optax, jax.numpy as jnp
import foundax as fx

model = fx.timesfm.small(horizon=24)

# JIT — compiles once per input shape
fast = eqx.filter_jit(model)
y = fast(jnp.zeros((512, 1)))

# Fine-tuning — gradients flow through the full ~231M-param state
optimizer = optax.adamw(1e-5)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def step(model, opt_state, x, y):
    def loss_fn(m): return jnp.mean((m(x) - y) ** 2)
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    return eqx.apply_updates(model, updates), opt_state, loss

For .eqx checkpoint serialisation (and the training/inference helpers built on top of it) see jNO: jno.nn.wrap(fx.timesfm.small(horizon=24)).initialize('./timesfm.eqx').

Caveats:

  • jax.vmap over the wrapper is not supported — the upstream decode() uses nnx.scan for the per-layer carry, which conflicts with vmap's trace level. Use the explicit (B, context, 1) batched form instead (JIT specialises per batch shape; equally efficient).
  • TimesFM is strictly univariate — the channel axis is always size 1.
  • Each distinct horizon triggers a separate JIT trace; build one model per horizon you care about.
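
Since the horizon is fixed at construction, code that needs several horizons can build each model once and reuse it. A minimal sketch using functools.lru_cache follows. make_model_cache and the stand-in builder are hypothetical; in real use the builder would be something like lambda h: fx.timesfm.small(horizon=h).

```python
from functools import lru_cache


def make_model_cache(build):
    """Return a getter that calls build(horizon) at most once per horizon."""
    @lru_cache(maxsize=None)
    def get(horizon):
        return build(horizon)
    return get


# Stand-in builder so the sketch runs without downloading a checkpoint.
calls = []


def fake_build(horizon):
    calls.append(horizon)
    return f"model(horizon={horizon})"


get_model = make_model_cache(fake_build)
get_model(24)
get_model(24)   # served from the cache, fake_build is not called again
get_model(96)
print(calls)
```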

Composable Pipe API

Wrap any model or layer with fx.block() and chain them with |. Channel mismatches are caught at construction time with a clear error message.

import jax
import foundax as fx

ks = jax.random.split(jax.random.PRNGKey(0), 8)

# ── Build a 2-D FNO-style pipeline from individual spectral layers ──────────
lift    = fx.block(fx.layers.SpectralBlock2d(1,  32, n_modes=16, key=ks[0]), name="lift")
s1      = fx.block(fx.layers.SpectralBlock2d(32, 32, n_modes=16, key=ks[1]))
s2      = fx.block(fx.layers.SpectralBlock2d(32, 32, n_modes=16, key=ks[2]))
s3      = fx.block(fx.layers.SpectralBlock2d(32, 32, n_modes=16, key=ks[3]))
project = fx.block(fx.layers.SpectralBlock2d(32,  1, n_modes=16, key=ks[4]), name="project")

model = lift | s1 | s2 | s3 | project   # Pipe of 5 blocks

# ── Existing full models work as blocks too ──────────────────────────────────
encoder = fx.block(fx.fno2d(in_features=3, hidden_channels=32, n_modes=16, key=ks[5]))
decoder = fx.block(fx.layers.SpectralBlock2d(32, 1, n_modes=16, key=ks[6]))

model = encoder | decoder

# ── Multi-input combinators (DeepONet-style) ─────────────────────────────────
branch = (
    fx.block(fx.layers.SpectralBlock1d(1, 32, n_modes=16, key=ks[0]))
    | fx.block(fx.mlp(in_features=32, output_dim=64, hidden_dims=64, key=ks[1]))
)
trunk = fx.block(fx.mlp(in_features=2, output_dim=64, hidden_dims=64, key=ks[2]))

model = fx.dot(branch, trunk)   # branch(u) · trunk(y)  →  (N_pts,)

# Also available: fx.add(a, b)  — elementwise sum of two branches
#                 fx.cat(a, b)  — concatenate outputs along the channel axis

# ── All pipe models are plain Equinox modules ────────────────────────────────
import equinox as eqx, optax, jax.numpy as jnp

opt   = optax.adam(1e-3)
state = opt.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def step(model, state, u, y, target):
    loss, grads = eqx.filter_value_and_grad(
        lambda m: jnp.mean((m(u, y) - target) ** 2)
    )(model)
    updates, state = opt.update(grads, state, eqx.filter(model, eqx.is_array))
    return eqx.apply_updates(model, updates), state, loss
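
To make the construction-time channel check concrete, here is a toy sketch of how a | operator can validate channel counts before any data flows. ToyBlock and ToyPipe are invented for this illustration and say nothing about how fx.block or foundax's pipes are actually implemented.

```python
class ToyBlock:
    """A named stage that maps in_ch channels to out_ch channels."""

    def __init__(self, in_ch, out_ch, name="block"):
        self.in_ch, self.out_ch, self.name = in_ch, out_ch, name

    def __or__(self, other):
        return ToyPipe([self]) | other


class ToyPipe:
    """An ordered chain of ToyBlocks, validated as it is built."""

    def __init__(self, blocks):
        self.blocks = list(blocks)

    def __or__(self, other):
        nxt = other.blocks if isinstance(other, ToyPipe) else [other]
        last, first = self.blocks[-1], nxt[0]
        if last.out_ch != first.in_ch:
            raise ValueError(
                f"channel mismatch: '{last.name}' outputs {last.out_ch} "
                f"but '{first.name}' expects {first.in_ch}"
            )
        return ToyPipe(self.blocks + nxt)


lift = ToyBlock(1, 32, name="lift")
mid = ToyBlock(32, 32, name="mid")
project = ToyBlock(32, 1, name="project")

pipe = lift | mid | project           # fine: 1 -> 32 -> 32 -> 1
print([b.name for b in pipe.blocks])

try:
    lift | project | mid              # 'project' outputs 1, 'mid' expects 32
except ValueError as err:
    print(err)
```

Failing while the pipeline is assembled, before any array is passed through, means the error names the two offending blocks directly.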

Citation

If you use foundax in academic work, the accompanying paper is the jNO preprint (arXiv:2605.10159). A machine-readable CITATION.cff is provided.

AI Disclosure

Parts of this codebase — including model ports, tests, and documentation — were developed with the assistance of AI coding tools. All contributions are reviewed and tested to the best of our ability, but mistakes may remain; please open an issue if you spot one.

License

EPL-2.0 — see LICENSE. Vendored foundation-model code and pretrained weights keep their original licenses (see THIRD_PARTY_LICENSES); Poseidon weights are non-commercial.
