diff --git a/diffsol/examples/bdf_tstop_reproducer.rs b/diffsol/examples/bdf_tstop_reproducer.rs new file mode 100644 index 00000000..0f0743d4 --- /dev/null +++ b/diffsol/examples/bdf_tstop_reproducer.rs @@ -0,0 +1,158 @@ +use diffsol::{ + BdfState, NalgebraContext, NalgebraLU, NalgebraMat, NalgebraVec, OdeBuilder, OdeSolverMethod, + OdeSolverStopReason, VectorHost, +}; + +type M = NalgebraMat; +type LS = NalgebraLU; + +const K: f64 = 0.1; +const T_END: f64 = 10.0; +const N_TARGETS: usize = 100; +const H0: f64 = 1e-3; +const RTOL: f64 = 1e-6; +const ATOL: f64 = 1e-8; + +fn make_problem( + t0: f64, + y0: f64, +) -> diffsol::OdeSolverProblem< + impl diffsol::OdeEquationsImplicit, T = f64, C = NalgebraContext>, +> { + OdeBuilder::::new() + .t0(t0) + .h0(H0) + .rtol(RTOL) + .atol([ATOL]) + .p([K]) + .rhs_implicit( + |x: &NalgebraVec, p: &NalgebraVec, _t: f64, y: &mut NalgebraVec| { + y.as_mut_slice()[0] = -p.as_slice()[0] * x.as_slice()[0]; + }, + |_: &NalgebraVec, + p: &NalgebraVec, + _t: f64, + v: &NalgebraVec, + jv: &mut NalgebraVec| { + jv.as_mut_slice()[0] = -p.as_slice()[0] * v.as_slice()[0]; + }, + ) + .init( + move |_p: &NalgebraVec, _t: f64, y: &mut NalgebraVec| { + y.as_mut_slice()[0] = y0; + }, + 1, + ) + .build() + .unwrap() +} + +fn target(i: usize) -> f64 { + i as f64 * T_END / N_TARGETS as f64 +} + +fn single_sweep() -> usize { + let problem = make_problem(0.0, 100.0); + let mut solver = problem.bdf::().unwrap(); + while solver.state().t < T_END { + solver.step().unwrap(); + } + let _ = solver.interpolate(T_END).unwrap(); + solver.get_statistics().number_of_steps +} + +fn cold_restart_each_target() -> usize { + let mut steps = 0; + let mut t0 = 0.0; + let mut y0 = 100.0; + for i in 1..=N_TARGETS { + let t = target(i); + let problem = make_problem(t0, y0); + let mut solver = problem.bdf::().unwrap(); + solver.set_stop_time(t).unwrap(); + loop { + if solver.step().unwrap() == OdeSolverStopReason::TstopReached { + break; + } + } + steps += solver.get_statistics().number_of_steps; + t0 = t; + y0 = solver.state().y.as_slice()[0]; + } + steps +} + +fn checkpoint_restart_with_tstop() -> usize { + let mut steps = 0; + let mut t0 = 0.0; + let mut y0 = 100.0; + let mut checkpoint: Option>> = None; + for i in 1..=N_TARGETS { + let t = target(i); + let problem = make_problem(t0, y0); + let mut solver = if let Some(state) = checkpoint.take() { + problem.bdf_solver::(state).unwrap() + } else { + problem.bdf::().unwrap() + }; + solver.set_stop_time(t).unwrap(); + loop { + if solver.step().unwrap() == OdeSolverStopReason::TstopReached { + break; + } + } + steps += solver.get_statistics().number_of_steps; + t0 = t; + y0 = solver.state().y.as_slice()[0]; + checkpoint = Some(solver.checkpoint()); + } + steps +} + +fn checkpoint_restart_with_interpolation() -> usize { + let mut steps = 0; + let mut t0 = 0.0; + let mut y0 = 100.0; + let mut checkpoint: Option>> = None; + for i in 1..=N_TARGETS { + let t = target(i); + let problem = make_problem(t0, y0); + let mut solver = if let Some(state) = checkpoint.take() { + problem.bdf_solver::(state).unwrap() + } else { + problem.bdf::().unwrap() + }; + while solver.state().t < t { + solver.step().unwrap(); + } + let y = solver.interpolate(t).unwrap(); + steps += solver.get_statistics().number_of_steps; + t0 = t; + y0 = y.as_slice()[0]; + checkpoint = Some(solver.checkpoint()); + } + steps +} + +fn main() { + println!("BDF tstop checkpoint/restart reproducer"); + println!("ODE: dy/dt = -0.1 y, t in [0, 10], 100 output targets"); + println!("rtol={RTOL}, atol={ATOL}, h0={H0}"); + println!(); + println!("{:<48}{}", "single sweep", single_sweep()); + println!( + "{:<48}{}", + "cold restart at each target", + cold_restart_each_target() + ); + println!( + "{:<48}{}", + "checkpoint/restart + set_stop_time", + checkpoint_restart_with_tstop() + ); + println!( + "{:<48}{}", + "checkpoint/restart + overshoot/interpolate", + checkpoint_restart_with_interpolation() + ); +} diff --git a/diffsol/src/ode_solver/bdf.rs b/diffsol/src/ode_solver/bdf.rs index 7c95e205..d268e45c 100644 --- a/diffsol/src/ode_solver/bdf.rs +++ b/diffsol/src/ode_solver/bdf.rs @@ -156,6 +156,7 @@ pub struct Bdf< statistics: BdfStatistics, state: BdfState, tstop: Option, + h_before_tstop: Option, root_finder: Option>, is_state_modified: bool, jacobian_update: JacobianUpdate, @@ -209,6 +210,7 @@ where statistics: BdfStatistics::default(), state: self.state.clone(), tstop: self.tstop, + h_before_tstop: self.h_before_tstop, root_finder: self.root_finder.clone(), is_state_modified: self.is_state_modified, jacobian_update: self.jacobian_update.clone(), @@ -363,6 +365,7 @@ where statistics: BdfStatistics::default(), state, tstop: None, + h_before_tstop: None, root_finder, is_state_modified, jacobian_update: JacobianUpdate::new(&problem.ode_options), @@ -687,6 +690,9 @@ where * (abs(state.t) + abs(state.h)); if abs(state.t - tstop) <= troundoff { self.tstop = None; + if let Some(h) = self.h_before_tstop.take() { + self.restore_step_size_after_tstop(h)?; + } return Ok(Some(OdeSolverStopReason::TstopReached)); } else if (state.h > M::T::zero() && tstop < state.t - troundoff) || (state.h < M::T::zero() && tstop > state.t + troundoff) @@ -696,6 +702,7 @@ where state_time: state.t.to_f64().unwrap(), }; self.tstop = None; + self.h_before_tstop = None; return Err(DiffsolError::from(error)); } @@ -709,12 +716,23 @@ where tstop.to_f64().unwrap() ); let factor = (tstop - state.t) / state.h; + if self.h_before_tstop.is_none() { + self.h_before_tstop = Some(state.h); + } // update step size ignoring the possible "step size too small" error let _ = self._update_step_size(factor); } Ok(None) } + fn restore_step_size_after_tstop(&mut self, h: Eqn::T) -> Result<(), DiffsolError> { + let factor = h / self.state.h; + self._update_step_size(factor)?; + self.state.dy.copy_from_view(&self.state.diff.column(1)); + self.state.dy *= scale(Eqn::T::one() / self.state.h); + Ok(()) + } + fn initialise_to_first_order(&mut self) { self.state.n_equal_steps = 0; @@ -1623,7 +1641,7 @@ mod test { }, scale, BdfState, ConstantOp, Context, DenseMatrix, FaerLU, FaerMat, FaerSparseLU, FaerSparseMat, MatrixCommon, NalgebraLU, NalgebraVec, OdeBuilder, OdeEquations, - OdeSolverMethod, Op, Vector, VectorHost, VectorView, + OdeSolverMethod, OdeSolverStopReason, Op, Vector, VectorHost, VectorView, }; type M = NalgebraMat; @@ -1781,6 +1799,28 @@ mod test { ); } + #[test] + fn bdf_tstop_restores_pre_truncation_step_size() { + let (problem, _) = exponential_decay_problem::(false); + let mut solver = problem.bdf::().unwrap(); + let h_before = solver.state().h; + let tstop = solver.state().t + h_before / 10.0; + + solver.set_stop_time(tstop).unwrap(); + loop { + if solver.step().unwrap() == OdeSolverStopReason::TstopReached { + break; + } + } + + assert!( + solver.state().h > h_before * 0.9, + "tstop truncation should not persist as the next proposed step size: before={}, after={}", + h_before, + solver.state().h + ); + } + #[test] fn bdf_test_faer_exponential_decay() { type M = FaerMat;