Leave some elements of jax array unchanged during diffeqsolve without computation #747

@BenjaminDAnjou

I have a differential system where the solution ys is a simple jax.Array. I need to repeatedly evolve this array using a sequence of diffeqsolve, each time evolving a different subarray based on a condition. The simplest way is to set the vector field to zero when I don't want an element to change. Here is an example:

```python
import diffrax as dx
import jax.numpy as jnp

ys = jnp.ones((10,))
thresh = jnp.linspace(0, 1, 10)

def vf(t, ys, args):
    thresh = args
    # Elements at or below their threshold get a zero derivative and stay constant
    return jnp.where(ys > thresh, -ys, jnp.zeros_like(ys))

term = dx.ODETerm(vf)
solver = dx.Tsit5()
stepsize_controller = dx.PIDController(rtol=1e-8, atol=1e-8)

sol = dx.diffeqsolve(term, solver, t0=0., t1=1., dt0=0.01, y0=ys, stepsize_controller=stepsize_controller, args=thresh)

# By default only t1 is saved, so sol.ys has shape (1, 10); take the final state
sol = dx.diffeqsolve(term, solver, t0=1., t1=2., dt0=0.01, y0=sol.ys[-1], stepsize_controller=stepsize_controller, args=jnp.flip(thresh))
```

This gives the desired result. However, it seems wasteful to evolve the elements for which the vector field is zero.

I could split out the relevant subarray and evolve only that part, but its size may vary between iterations, so I'd rather operate on the whole array.

Is there a smart way to run diffeqsolve on a subarray without propagating the rest, or at least propagating the rest at marginal cost? Or is setting the vector field to zero the best option?

Another option is to convert the array into a list and use equinox.partition. But I would like to avoid changing the structure if I can. Basically, my question is: is it possible to do something like equinox.partition at the jax.Array level?