I have a differential system where the solution `ys` is a plain `jax.Array`. I need to evolve this array repeatedly with a sequence of `diffeqsolve` calls, each time evolving a different subarray selected by a condition. The simplest way is to set the vector field to zero for the elements I don't want to change. Here is an example:
```python
import diffrax as dx
import jax.numpy as jnp

ys = jnp.ones((10,))
thresh = jnp.linspace(0, 1, 10)

def vf(t, ys, args):
    thresh = args
    # Only elements above their threshold decay; the rest have zero derivative.
    return jnp.where(ys > thresh, -ys, jnp.zeros_like(ys))

term = dx.ODETerm(vf)
solver = dx.Tsit5()
stepsize_controller = dx.PIDController(rtol=1e-8, atol=1e-8)
sol = dx.diffeqsolve(term, solver, t0=0., t1=1., dt0=0.01, y0=ys,
                     stepsize_controller=stepsize_controller, args=thresh)
# With the default SaveAt(t1=True), sol.ys has shape (1, 10); take the final state.
sol = dx.diffeqsolve(term, solver, t0=1., t1=2., dt0=0.01, y0=sol.ys[-1],
                     stepsize_controller=stepsize_controller, args=jnp.flip(thresh))
```
This gives the desired result. However, it seems wasteful to evolve the elements for which the vector field is zero.
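For reference, the zero-vector-field approach can be written generically by passing a boolean mask through `args`. This is a minimal sketch; `make_masked_vf` is a hypothetical helper, not part of Diffrax, and it assumes the mask is fixed for the duration of one solve (unlike the example above, where the condition `ys > thresh` is re-evaluated inside the vector field). Because the mask always has the same shape, changing which elements evolve between solves should not change any array shapes, so a jitted solve need not be retraced. It still evaluates the vector field on every element, so it does not remove the wasted work.

```python
import jax.numpy as jnp

def make_masked_vf(vf):
    """Wrap a vector field so that entries outside `mask` have zero derivative.

    Hypothetical helper: expects `args` to be a pair `(inner_args, mask)`,
    where `mask` is a boolean array with the same shape as `y`.
    """
    def masked_vf(t, y, args):
        inner_args, mask = args
        dy = vf(t, y, inner_args)
        # Frozen entries get an exactly-zero derivative, so they stay constant.
        return jnp.where(mask, dy, jnp.zeros_like(dy))
    return masked_vf
```

With Diffrax this would be used as `dx.ODETerm(make_masked_vf(base_vf))` with `args=(inner_args, mask)`, for some unmasked dynamics `base_vf`.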
I could try to split the relevant subarray and only evolve that part, but the size of that subarray could vary between iterations, so I'd rather operate on the whole array.
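If splitting is preferred, one way to keep shapes static despite a varying selection size is to pad the selection to a fixed maximum length. This is a minimal sketch, assuming an upper bound `max_size` on the number of selected elements is known in advance; `gather_subarray` and `scatter_subarray` are hypothetical helpers. `jnp.nonzero` with `size` and `fill_value` returns fixed-length indices, and `.at[...].set(..., mode="drop")` discards writes to out-of-bounds indices, which is used here to skip the padded slots.

```python
import jax.numpy as jnp

def gather_subarray(ys, mask, max_size):
    """Extract the entries selected by `mask`, padded to length `max_size`.

    `max_size` must be a static Python int (e.g. a static argument under jit).
    Padded slots reuse index 0 and are flagged as invalid.
    """
    idx = jnp.nonzero(mask, size=max_size, fill_value=0)[0]
    valid = jnp.arange(max_size) < jnp.sum(mask)
    return ys[idx], idx, valid

def scatter_subarray(ys, sub, idx, valid):
    """Write `sub` back into `ys` at `idx`, ignoring the padded slots."""
    # Send padded slots to an out-of-bounds index so mode="drop" skips them.
    safe_idx = jnp.where(valid, idx, ys.shape[0])
    return ys.at[safe_idx].set(sub, mode="drop")
```

The gathered `sub` would then be the `y0` of the solve. The padded slots are still integrated unless the vector field also zeroes them (for example using `valid`), so this only saves work when `max_size` is much smaller than the full array.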
Is there a smart way to run `diffeqsolve` on a subarray without propagating the rest, or at least propagating the rest at marginal cost? Or is setting the vector field to zero the best approach?
Another option is to convert the array into a list and use `equinox.partition`, but I would like to avoid changing the structure if I can. Basically, my question is: is it possible to do something like `equinox.partition` at the `jax.Array` level?
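An element-wise analogue of `equinox.partition`/`equinox.combine` can be sketched with a boolean mask. This is a minimal sketch; `partition_array` and `combine_array` are hypothetical names, not Equinox API. Since a `jax.Array` cannot hold `None` per element the way `equinox.partition` does per leaf, filtered-out entries are filled with zeros instead. Both halves keep the full shape, so this splits the data but does not reduce the cost of propagating it.

```python
import jax.numpy as jnp

def partition_array(ys, mask):
    """Split `ys` into (selected, rest), zero-filling the other entries."""
    zeros = jnp.zeros_like(ys)
    dynamic = jnp.where(mask, ys, zeros)
    static = jnp.where(mask, zeros, ys)
    return dynamic, static

def combine_array(dynamic, static, mask):
    """Inverse of `partition_array` for the same `mask`."""
    return jnp.where(mask, dynamic, static)
```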