Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions diffsol/examples/bdf_tstop_reproducer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
use diffsol::{
BdfState, NalgebraContext, NalgebraLU, NalgebraMat, NalgebraVec, OdeBuilder, OdeSolverMethod,
OdeSolverStopReason, VectorHost,
};

type M = NalgebraMat<f64>;
type LS = NalgebraLU<f64>;

const K: f64 = 0.1;
const T_END: f64 = 10.0;
const N_TARGETS: usize = 100;
const H0: f64 = 1e-3;
const RTOL: f64 = 1e-6;
const ATOL: f64 = 1e-8;

fn make_problem(
t0: f64,
y0: f64,
) -> diffsol::OdeSolverProblem<
impl diffsol::OdeEquationsImplicit<M = M, V = NalgebraVec<f64>, T = f64, C = NalgebraContext>,
> {
OdeBuilder::<M>::new()
.t0(t0)
.h0(H0)
.rtol(RTOL)
.atol([ATOL])
.p([K])
.rhs_implicit(
|x: &NalgebraVec<f64>, p: &NalgebraVec<f64>, _t: f64, y: &mut NalgebraVec<f64>| {
y.as_mut_slice()[0] = -p.as_slice()[0] * x.as_slice()[0];
},
|_: &NalgebraVec<f64>,
p: &NalgebraVec<f64>,
_t: f64,
v: &NalgebraVec<f64>,
jv: &mut NalgebraVec<f64>| {
jv.as_mut_slice()[0] = -p.as_slice()[0] * v.as_slice()[0];
},
)
.init(
move |_p: &NalgebraVec<f64>, _t: f64, y: &mut NalgebraVec<f64>| {
y.as_mut_slice()[0] = y0;
},
1,
)
.build()
.unwrap()
}

fn target(i: usize) -> f64 {
i as f64 * T_END / N_TARGETS as f64
}

fn single_sweep() -> usize {
let problem = make_problem(0.0, 100.0);
let mut solver = problem.bdf::<LS>().unwrap();
while solver.state().t < T_END {
solver.step().unwrap();
}
let _ = solver.interpolate(T_END).unwrap();
solver.get_statistics().number_of_steps
}

fn cold_restart_each_target() -> usize {
let mut steps = 0;
let mut t0 = 0.0;
let mut y0 = 100.0;
for i in 1..=N_TARGETS {
let t = target(i);
let problem = make_problem(t0, y0);
let mut solver = problem.bdf::<LS>().unwrap();
solver.set_stop_time(t).unwrap();
loop {
if solver.step().unwrap() == OdeSolverStopReason::TstopReached {
break;
}
}
steps += solver.get_statistics().number_of_steps;
t0 = t;
y0 = solver.state().y.as_slice()[0];
}
steps
}

fn checkpoint_restart_with_tstop() -> usize {
let mut steps = 0;
let mut t0 = 0.0;
let mut y0 = 100.0;
let mut checkpoint: Option<BdfState<NalgebraVec<f64>>> = None;
for i in 1..=N_TARGETS {
let t = target(i);
let problem = make_problem(t0, y0);
let mut solver = if let Some(state) = checkpoint.take() {
problem.bdf_solver::<LS>(state).unwrap()
} else {
problem.bdf::<LS>().unwrap()
};
solver.set_stop_time(t).unwrap();
loop {
if solver.step().unwrap() == OdeSolverStopReason::TstopReached {
break;
}
}
steps += solver.get_statistics().number_of_steps;
t0 = t;
y0 = solver.state().y.as_slice()[0];
checkpoint = Some(solver.checkpoint());
}
steps
}

fn checkpoint_restart_with_interpolation() -> usize {
let mut steps = 0;
let mut t0 = 0.0;
let mut y0 = 100.0;
let mut checkpoint: Option<BdfState<NalgebraVec<f64>>> = None;
for i in 1..=N_TARGETS {
let t = target(i);
let problem = make_problem(t0, y0);
let mut solver = if let Some(state) = checkpoint.take() {
problem.bdf_solver::<LS>(state).unwrap()
} else {
problem.bdf::<LS>().unwrap()
};
while solver.state().t < t {
solver.step().unwrap();
}
let y = solver.interpolate(t).unwrap();
steps += solver.get_statistics().number_of_steps;
t0 = t;
y0 = y.as_slice()[0];
checkpoint = Some(solver.checkpoint());
}
steps
}

fn main() {
println!("BDF tstop checkpoint/restart reproducer");
println!("ODE: dy/dt = -0.1 y, t in [0, 10], 100 output targets");
println!("rtol={RTOL}, atol={ATOL}, h0={H0}");
println!();
println!("{:<48}{}", "single sweep", single_sweep());
println!(
"{:<48}{}",
"cold restart at each target",
cold_restart_each_target()
);
println!(
"{:<48}{}",
"checkpoint/restart + set_stop_time",
checkpoint_restart_with_tstop()
);
println!(
"{:<48}{}",
"checkpoint/restart + overshoot/interpolate",
checkpoint_restart_with_interpolation()
);
}
42 changes: 41 additions & 1 deletion diffsol/src/ode_solver/bdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ pub struct Bdf<
statistics: BdfStatistics,
state: BdfState<Eqn::V, M>,
tstop: Option<Eqn::T>,
h_before_tstop: Option<Eqn::T>,
root_finder: Option<RootFinder<Eqn::V>>,
is_state_modified: bool,
jacobian_update: JacobianUpdate<Eqn::T>,
Expand Down Expand Up @@ -209,6 +210,7 @@ where
statistics: BdfStatistics::default(),
state: self.state.clone(),
tstop: self.tstop,
h_before_tstop: self.h_before_tstop,
root_finder: self.root_finder.clone(),
is_state_modified: self.is_state_modified,
jacobian_update: self.jacobian_update.clone(),
Expand Down Expand Up @@ -363,6 +365,7 @@ where
statistics: BdfStatistics::default(),
state,
tstop: None,
h_before_tstop: None,
root_finder,
is_state_modified,
jacobian_update: JacobianUpdate::new(&problem.ode_options),
Expand Down Expand Up @@ -687,6 +690,9 @@ where
* (abs(state.t) + abs(state.h));
if abs(state.t - tstop) <= troundoff {
self.tstop = None;
if let Some(h) = self.h_before_tstop.take() {
self.restore_step_size_after_tstop(h)?;
}
return Ok(Some(OdeSolverStopReason::TstopReached));
} else if (state.h > M::T::zero() && tstop < state.t - troundoff)
|| (state.h < M::T::zero() && tstop > state.t + troundoff)
Expand All @@ -696,6 +702,7 @@ where
state_time: state.t.to_f64().unwrap(),
};
self.tstop = None;
self.h_before_tstop = None;

return Err(DiffsolError::from(error));
}
Expand All @@ -709,12 +716,23 @@ where
tstop.to_f64().unwrap()
);
let factor = (tstop - state.t) / state.h;
if self.h_before_tstop.is_none() {
self.h_before_tstop = Some(state.h);
}
// update step size ignoring the possible "step size too small" error
let _ = self._update_step_size(factor);
}
Ok(None)
}

fn restore_step_size_after_tstop(&mut self, h: Eqn::T) -> Result<(), DiffsolError> {
let factor = h / self.state.h;
self._update_step_size(factor)?;
self.state.dy.copy_from_view(&self.state.diff.column(1));
self.state.dy *= scale(Eqn::T::one() / self.state.h);
Ok(())
}

fn initialise_to_first_order(&mut self) {
self.state.n_equal_steps = 0;

Expand Down Expand Up @@ -1623,7 +1641,7 @@ mod test {
},
scale, BdfState, ConstantOp, Context, DenseMatrix, FaerLU, FaerMat, FaerSparseLU,
FaerSparseMat, MatrixCommon, NalgebraLU, NalgebraVec, OdeBuilder, OdeEquations,
OdeSolverMethod, Op, Vector, VectorHost, VectorView,
OdeSolverMethod, OdeSolverStopReason, Op, Vector, VectorHost, VectorView,
};

type M = NalgebraMat<f64>;
Expand Down Expand Up @@ -1781,6 +1799,28 @@ mod test {
);
}

#[test]
fn bdf_tstop_restores_pre_truncation_step_size() {
let (problem, _) = exponential_decay_problem::<M>(false);
let mut solver = problem.bdf::<LS>().unwrap();
let h_before = solver.state().h;
let tstop = solver.state().t + h_before / 10.0;

solver.set_stop_time(tstop).unwrap();
loop {
if solver.step().unwrap() == OdeSolverStopReason::TstopReached {
break;
}
}

assert!(
solver.state().h > h_before * 0.9,
"tstop truncation should not persist as the next proposed step size: before={}, after={}",
h_before,
solver.state().h
);
}

#[test]
fn bdf_test_faer_exponential_decay() {
type M = FaerMat<f64>;
Expand Down