
Argument mismatch and hard-coded return_all_eval #195

@cantabile-kwok

Description


There are mismatched arguments in problems.ODEProblem.odeint, and return_all_eval is hard-coded in the adjoint path.
My torchdyn version is 1.0.3.
Steps to Reproduce
I wanted to see how many steps the adaptive dopri5 solver took, so I looked for the return_all_eval argument mentioned in issue #131. The NeuralODE class does not provide such a keyword argument, so after digging into the source code a little, I decided to pass args={'return_all_eval': True}. However, this still does not give the desired result. The code snippet is:

```python
from torchdyn.core import NeuralODE
import torch
import torch.nn as nn


class VectorField(nn.Module):
    def __init__(self):
        super(VectorField, self).__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, x):
        # Print every time the solver evaluates the vector field.
        print(f"In VectorField, t is fed as {t}")
        return self.net(t + x)


vf = VectorField()
ode = NeuralODE(vf, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
time = torch.linspace(0, 1, 10)
initial = torch.randn(16, 20, 2)
# Attempt to enable return_all_eval through the args dict.
eval_time, sol = ode(initial, time, args={'return_all_eval': True})
print(sol.shape)
```

I then found that the return_all_eval key is never actually passed on to the numerics.odeint.odeint function as its return_all_eval parameter. The signature of that function is:

```python
def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, nn.Module], atol:float=1e-3, rtol:float=1e-3,
           t_stops:Union[List, Tensor, None]=None, verbose:bool=False, interpolator:Union[str, Callable, None]=None, return_all_eval:bool=False,
           save_at:Union[List, Tensor]=(), args:Dict={}, seminorm:Tuple[bool, Union[int, None]]=(False, None)) -> Tuple[Tensor, Tensor]:
```
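Because return_all_eval and args are separate parameters in this signature, a return_all_eval key inside the args dict does not set the flag. A minimal sketch in plain Python makes this visible; odeint_stub is hypothetical, copies only the parameter list above, and does no solving.

```python
from typing import Callable, Dict, Optional, Tuple, Union


def odeint_stub(f: Optional[Callable], x, t_span, solver, atol: float = 1e-3, rtol: float = 1e-3,
                t_stops=None, verbose: bool = False, interpolator=None,
                return_all_eval: bool = False, save_at=(), args: Optional[Dict] = None,
                seminorm: Tuple[bool, Union[int, None]] = (False, None)) -> Dict:
    # Hypothetical stand-in with the same parameter list as odeint above;
    # it just reports how the two relevant parameters were bound.
    args = {} if args is None else args
    return {"return_all_eval": return_all_eval, "args": args}


# Putting the flag inside `args` only changes the contents of the `args` dict...
via_args = odeint_stub(None, 0.0, [0.0, 1.0], "dopri5", args={"return_all_eval": True})
print(via_args)   # {'return_all_eval': False, 'args': {'return_all_eval': True}}

# ...while the keyword itself is what actually sets the option.
via_kwarg = odeint_stub(None, 0.0, [0.0, 1.0], "dopri5", return_all_eval=True)
print(via_kwarg)  # {'return_all_eval': True, 'args': {}}
```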

So return_all_eval is an explicit parameter of odeint in its own right, separate from args, but in numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward it is hard-coded as False:

```python
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
    t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
                                False, maxiter, fine_steps, save_at)
    ctx.save_for_backward(sol, t_sol)
    return t_sol, sol
```

So there is basically no way to switch it on short of changing the source code.
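One possible fix is to forward the user's flag instead of the literal False. The sketch below assumes forward is a closure over the solver options (the excerpt uses solver, atol, rtol, maxiter and others without receiving them as arguments). Every name in it, including generic_odeint_stub and make_forward, is hypothetical, and the internal step times are placeholders rather than solver output.

```python
from typing import Callable, List, Sequence

# Hypothetical internal step times; placeholders, not real solver output.
FAKE_INTERNAL_STEPS = [0.0, 0.13, 0.31, 0.58, 1.0]


def generic_odeint_stub(t_span: Sequence[float], return_all_eval: bool) -> List[float]:
    # Stand-in for the solver call: report every internal step only when asked.
    return list(FAKE_INTERNAL_STEPS) if return_all_eval else list(t_span)


def make_forward(return_all_eval: bool) -> Callable[[Sequence[float]], List[float]]:
    # The flag is captured from the enclosing scope, like the other solver
    # options, rather than being replaced by a hard-coded False.
    def forward(t_span: Sequence[float]) -> List[float]:
        return generic_odeint_stub(t_span, return_all_eval)

    return forward


print(make_forward(False)([0.0, 1.0]))  # [0.0, 1.0]
print(make_forward(True)([0.0, 1.0]))   # [0.0, 0.13, 0.31, 0.58, 1.0]
```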

The second problem is an argument mismatch in numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward. It is called from odeint as

```python
return self._autograd_func()(self.vf_params, x, t_span, save_at, args)
```

and these positional arguments do not line up with the signature of that forward function: save_at is received by the B parameter (whose purpose I do not understand), and the args dict is received by save_at. This has not caused any problems in my code so far, but I don't believe it is the intended behavior. I suggest someone debug this part of the code more thoroughly.
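The shift can be reproduced with inspect.signature(...).bind from the standard library. The forward below is a hypothetical stub with the parameter list from the excerpt minus ctx, which is assumed to be supplied automatically (as torch.autograd.Function.apply does); the bound values are placeholders.

```python
import inspect


def forward(vf_params, x, t_span, B=None, save_at=()):
    # Parameter list copied from the forward excerpt, without `ctx`.
    return None


sig = inspect.signature(forward)
save_at, args = (0.5,), {"some_key": 1}

# Positional call as written in odeint: the 4th and 5th arguments shift.
bound = sig.bind("vf_params", "x", "t_span", save_at, args)
print(bound.arguments["B"])        # (0.5,)
print(bound.arguments["save_at"])  # {'some_key': 1}

# Passing an explicit placeholder for B puts save_at back in its own slot.
fixed = sig.bind("vf_params", "x", "t_span", None, save_at)
print(fixed.arguments["save_at"])  # (0.5,)
```

Passing None for B is only one way to realign save_at; the forward signature has no parameter for args at all, so handling that dict would require changing forward itself.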

Screenshots
There is a traceback that shows the problem.

Expected behavior

The return_all_eval option should be settable by the user and should control whether the ODE solver returns all of its evaluation time points.
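To make the expected semantics concrete, here is a minimal sketch in plain Python of an adaptive solver with a return_all_eval switch. It uses a Heun-Euler embedded pair instead of dopri5 to stay short, and it only mirrors the behavior described above (endpoints by default, every accepted step when the flag is on); it is not torchdyn's implementation, and heun_euler_adaptive is a hypothetical name.

```python
import math
from typing import Callable, List, Tuple


def heun_euler_adaptive(
    f: Callable[[float, float], float],
    x0: float,
    t0: float,
    t1: float,
    atol: float = 1e-4,
    rtol: float = 1e-4,
    return_all_eval: bool = False,
) -> Tuple[List[float], List[float]]:
    """Integrate dx/dt = f(t, x) from t0 to t1 with an adaptive Heun-Euler pair.

    With return_all_eval=False only the start and end points are returned;
    with True, every accepted step is returned, so len(ts) - 1 is the number
    of accepted steps.
    """
    t, x, h = t0, x0, (t1 - t0) / 10
    ts, xs = [t0], [x0]
    while t < t1:
        h = min(h, t1 - t)
        k1 = f(t, x)
        k2 = f(t + h, x + h * k1)
        x_high = x + h * (k1 + k2) / 2  # Heun step (order 2)
        x_low = x + h * k1  # Euler step (order 1), used only for the error estimate
        err = abs(x_high - x_low)
        tol = atol + rtol * max(abs(x), abs(x_high))
        if err <= tol:
            # Accept the step and keep the higher-order solution.
            t, x = min(t + h, t1), x_high
            if return_all_eval:
                ts.append(t)
                xs.append(x)
        # Rescale h with the 1/2 exponent for a first-order error estimate,
        # a safety factor and clipping; rejected steps are retried smaller.
        factor = 0.9 * math.sqrt(tol / err) if err > 0 else 2.0
        h *= min(2.0, max(0.2, factor))
    if not return_all_eval:
        ts.append(t)
        xs.append(x)
    return ts, xs


def decay(t: float, x: float) -> float:
    return -x  # dx/dt = -x, exact solution x(t) = exp(-t)


ts_end, xs_end = heun_euler_adaptive(decay, 1.0, 0.0, 1.0)
ts_all, xs_all = heun_euler_adaptive(decay, 1.0, 0.0, 1.0, return_all_eval=True)
print(len(ts_end))      # 2
print(len(ts_all) - 1)  # number of accepted steps
```

In this sketch rejected attempts are not counted as steps, and each attempt costs two evaluations of f.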
Also, documentation on the meaning of these arguments and the functionality they provide is largely missing. For example, I did not realize there was a way to return all the evaluation timestamps until I found that GitHub issue.

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests