Skip to content

NODE Fails with "Detected differentiation of a custom_jvp function with respect to a closed-over value" #41

Description

@adam-hartshorne

The simple MVE NODE shown below produces the following error,

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 161, in
main()
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 142, in main
loss, model, opt_state = make_step(_ts, yi, model, opt_state)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_jit.py", line 239, in call
return self._call(False, args, kwargs)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_module.py", line 1093, in call
return self.func(self.self, *args, **kwargs)
File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_jit.py", line 212, in _call
out = self._cached(dynamic_donate, dynamic_nodonate, static)
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 128, in make_step
loss, grads = grad_loss(model, ti, yi)
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 123, in grad_loss
y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
File "/media/adam/shared_drive/PycharmProjects/deer_test/deer_ode_solver.py", line 53, in call
res = solve_ivp(self.func, y0, tpts[..., None], None, tpts, method=solve_ivp.DEER())
File "/home/adam/Downloads/deer-mfk/deer/fsolve_ivp.py", line 80, in solve_ivp
return method.compute(func, y0, xinp, params, tpts)
File "/home/adam/Downloads/deer-mfk/deer/fsolve_ivp.py", line 127, in compute
result = deer_iteration(

jax._src.interpreters.ad.CustomJVPException: Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn't supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters. Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule.

import time
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.test_util
import matplotlib.pyplot as plt
import optax
from deer import solve_ivp

# enable jax x64 for this test
jax.config.update("jax_enable_x64", True)

dtype = jnp.float64
npts = 10

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size+1,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, y, t, args=None):
        # concatenate the t and the y
        y = jnp.concatenate([y, jnp.full((1,), t)], axis=-1)
        return self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func
    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        tpts = jnp.linspace(0, 1.0, npts, dtype=dtype)  # (ntpts,)
        res = solve_ivp(self.func, y0, tpts[..., None], None, tpts, method=solve_ivp.DEER())
        return res.value

def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys


def get_data(dataset_size, *, key):
    ts = jnp.linspace(0, 10, 100)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    return ts, ys

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size


def main(
    dataset_size=256,
    batch_size=32,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralODE(data_size, width_size, depth, key=model_key)


    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        return jnp.mean((yi - y_pred) ** 2)

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adabelief(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step, (yi,) in zip(
            range(steps), dataloader((_ys,), batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(_ts, yi, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 1], c="dodgerblue")
        model_y = model(ts, ys[0, 0])
        plt.plot(ts, model_y[:, 0], c="crimson", label="Model")
        plt.plot(ts, model_y[:, 1], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, model

if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions