Skip to content
7 changes: 6 additions & 1 deletion src/gretl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@ set(gretl_sources
about.cpp
data_store.cpp
state_base.cpp
vector_state.cpp)
vector_state.cpp
wang_checkpoint_strategy.cpp
strumm_walther_checkpoint_strategy.cpp)

set(gretl_headers
about.hpp
checkpoint.hpp
checkpoint_strategy.hpp
wang_checkpoint_strategy.hpp
strumm_walther_checkpoint_strategy.hpp
create_state.hpp
data_store.hpp
double_state.hpp
Expand Down
165 changes: 10 additions & 155 deletions src/gretl/checkpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@

#pragma once

#include <set>
#include <map>
#include <ostream>
#include <iostream>
#include <cassert>
#include <limits>
#include <functional>
#include <memory>

#include "checkpoint_strategy.hpp"

/// @brief gretl_assert that prints line and file info before throwing in release and halting in debug
#define gretl_assert(x) \
Expand All @@ -32,153 +35,21 @@

namespace gretl {

/// @brief checkpoint struct which tracks level and step per "Minimal Repetition Dynamic Checkpointing Algorithm for
/// Unsteady Adjoint Calculation", Wang, et al. , 2009.
struct Checkpoint {
size_t level; ///< level
size_t step; ///< step
static constexpr size_t infinity()
{
return std::numeric_limits<size_t>::max();
} ///< The largest possible step and level value
};

/// @brief comparison operator between two checkpoints to determine which is most disposable per the dynamic
/// checkpointing algorithm
inline bool operator<(const Checkpoint& a, const Checkpoint& b)
{
if (a.level == Checkpoint::infinity() && b.level == Checkpoint::infinity()) {
return a.step > b.step;
}
if (a.level == Checkpoint::infinity()) return false;
if (b.level == Checkpoint::infinity()) return true;
return a.step > b.step;
}

/// @brief output stream for a single checkpoint
inline std::ostream& operator<<(std::ostream& stream, const Checkpoint& p);

/// @brief CheckpointManager class which encapsulates the logic of when and which steps should be dynamically saved a
/// fetched
struct CheckpointManager {
static constexpr size_t invalidCheckpointIndex =
std::numeric_limits<size_t>::max(); ///< magic number of invalid checkpoint

/// @brief utilty for checking if an index is valid. There is a magic number, invalidCheckpointIndex, which
/// represents an invalid checkpoint
static bool valid_checkpoint_index(size_t i) { return i != invalidCheckpointIndex; }

/// @brief returns const_iterator to currently most dispensable checkpoint step
std::set<gretl::Checkpoint>::const_iterator most_dispensable() const
{
size_t maxHigherTimeLevel = 0;
for (auto rIter = cps.begin(); rIter != cps.end(); ++rIter) {
if (rIter->level < maxHigherTimeLevel) {
return rIter;
}
maxHigherTimeLevel = std::max(rIter->level, maxHigherTimeLevel);
}
return cps.end();
}

/// @brief this does multiple things
/// 1. it adds checkpoints into the database, and updates internal data structures
/// 2. it determines if a checkpoint needs to be removed
/// 3. if a checkpoint needs to be removed, it returns the index for that checkpoint
/// 4. otherwise, it returns zero
size_t add_checkpoint_and_get_index_to_remove(size_t step, bool persistent = false)
{
size_t levelupAmount = 1; //= relativeCost >= 2.0 ? 3 : 1;

Checkpoint nextStep{.level = levelupAmount - 1, .step = step};

size_t nextEraseStep = invalidCheckpointIndex;

// don't include persistent data in data quota. MRT, this might change
if (persistent) {
maxNumStates++;
nextStep.level = Checkpoint::infinity();
gretl_assert(cps.size() < maxNumStates);
}

if (cps.size() < maxNumStates) {
cps.insert(nextStep);
} else {
auto iterToMostDispensable = most_dispensable();
if (iterToMostDispensable != cps.end()) {
nextEraseStep = iterToMostDispensable->step;
cps.erase(iterToMostDispensable);
cps.insert(nextStep);
} else {
nextEraseStep = cps.begin()->step;
nextStep.level = cps.begin()->level + levelupAmount;

cps.erase(cps.begin());
cps.insert(nextStep);
}
}

return nextEraseStep;
}

/// @brief return largest currently checkpointed step
size_t last_checkpoint_step() const { return cps.begin()->step; }

/// @brief erase
bool erase_step(size_t stepIndex)
{
for (std::set<Checkpoint>::iterator it = cps.begin(); it != cps.end(); ++it) {
if (it->step == stepIndex) {
if (it->level != Checkpoint::infinity()) {
cps.erase(it);
return true;
}
}
}
return false;
}

/// @brief check if this step is currently checkpointed. This could potentially use performance optimization down the
/// way.
bool contains_step(size_t stepIndex) const
{
for (auto& c : cps) {
if (c.step == stepIndex) {
return true;
}
}
return false;
}

/// @brief erase all non persistent checkpoints
void reset()
{
for (auto cp_it = cps.begin(); cp_it != cps.end(); ++cp_it) {
if (cp_it->level == Checkpoint::infinity()) {
cps.erase(cps.begin(), cp_it);
break;
}
}
}

size_t maxNumStates = 20; ///< The max number of non-persistent, not-in-scope states stored by the CheckpointManager
std::set<Checkpoint> cps; ///< Vector of checkpoints
};

/// @brief interface to run forward with a linear graph, checkpoint, then automatically backpropagate the sensitivities
/// given the reverse_callback vjp.
/// @tparam T type of each state's data
/// @param numSteps number of forward iterations
/// @param storageSize maximum states to save in memory at a time
/// @param x initial condition
/// @param update_func function which evaluates the forward response
/// @param reverse_callback vjp function (action of Jacobian-transposed) to back propagate sensitivities
/// @param strategy checkpoint strategy (required)
/// @return
template <typename T>
T advance_and_reverse_steps(size_t numSteps, size_t storageSize, T x, std::function<T(size_t n, const T&)> update_func,
std::function<void(size_t n, const T&)> reverse_callback)
T advance_and_reverse_steps(size_t numSteps, T x, std::function<T(size_t n, const T&)> update_func,
std::function<void(size_t n, const T&)> reverse_callback,
std::unique_ptr<CheckpointStrategy> strategy)
{
gretl::CheckpointManager cps{.maxNumStates = storageSize, .cps{}};
CheckpointStrategy& cps = *strategy;
std::map<size_t, T> savedCps;
savedCps[0] = x;

Expand All @@ -204,6 +75,7 @@ T advance_and_reverse_steps(size_t numSteps, size_t storageSize, T x, std::funct
savedCps.erase(eraseStep);
}
savedCps[lastCp + 1] = x;
cps.record_recomputation();
}
reverse_callback(i, savedCps[i]);

Expand All @@ -214,21 +86,4 @@ T advance_and_reverse_steps(size_t numSteps, size_t storageSize, T x, std::funct
return xf;
}

/// @brief ostream operator for writing out checkpoint information
inline std::ostream& operator<<(std::ostream& stream, const Checkpoint& p)
{
return stream << " lvl=" << p.level << ", step=" << p.step;
}

/// @brief ostream operator for writing out information about the entire checkpoint manager to see the set of currently
/// checkpointed states
inline std::ostream& operator<<(std::ostream& stream, const CheckpointManager& set)
{
stream << "CHECKPOINTS: capacity = " << set.maxNumStates << std::endl;
for (const auto& s : set.cps) {
stream << s << "\n";
}
return stream;
}

} // namespace gretl
88 changes: 88 additions & 0 deletions src/gretl/checkpoint_strategy.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (c) Lawrence Livermore National Security, LLC and
// other Gretl Project Developers. See the top-level LICENSE file for
// details.
//
// SPDX-License-Identifier: (BSD-3-Clause)

/**
* @file checkpoint_strategy.hpp
* @brief Abstract interface for checkpoint eviction strategies.
*/

#pragma once

#include <cstddef>
#include <limits>
#include <memory>
#include <ostream>

namespace gretl {

/// @brief Performance counters for comparing checkpoint algorithms.
struct CheckpointMetrics {
size_t stores = 0; ///< Number of checkpoint store operations
size_t evictions = 0; ///< Number of checkpoint evictions
size_t recomputations = 0; ///< Forward re-evaluations triggered during reverse
};

/// @brief Abstract interface for checkpoint eviction strategies.
///
/// Implementations decide which step to evict when checkpoint capacity is
/// exceeded. The interface exposes only the operations that DataStore
/// requires, hiding all algorithm-specific data structures.
class CheckpointStrategy {
public:
static constexpr size_t invalidCheckpointIndex =
std::numeric_limits<size_t>::max(); ///< Magic number for invalid checkpoint

/// @brief Check if a checkpoint index is valid
static bool valid_checkpoint_index(size_t i) { return i != invalidCheckpointIndex; }

virtual ~CheckpointStrategy() = default;

/// @brief Add a checkpoint for the given step.
/// @param step The step index to checkpoint.
/// @param persistent If true, this checkpoint cannot be evicted.
/// @return The step index to evict, or invalidCheckpointIndex if none.
virtual size_t add_checkpoint_and_get_index_to_remove(size_t step, bool persistent = false) = 0;

/// @brief Return the step index of the earliest currently stored checkpoint.
virtual size_t last_checkpoint_step() const = 0;

/// @brief Remove the checkpoint at the given step.
/// @return true if a checkpoint was found and removed.
virtual bool erase_step(size_t stepIndex) = 0;

/// @brief Check if a checkpoint exists for the given step.
virtual bool contains_step(size_t stepIndex) const = 0;

/// @brief Clear all non-persistent checkpoints.
virtual void reset() = 0;

/// @brief Return the maximum number of non-persistent checkpoint slots.
virtual size_t capacity() const = 0;

/// @brief Return the current number of checkpoints (persistent + non-persistent).
virtual size_t size() const = 0;

/// @brief Print checkpoint state to the output stream.
virtual void print(std::ostream& os) const = 0;

/// @brief Return accumulated performance metrics.
virtual CheckpointMetrics metrics() const = 0;

/// @brief Reset accumulated performance metrics to zero.
virtual void reset_metrics() = 0;

/// @brief Record a forward recomputation (called by DataStore during fetch).
virtual void record_recomputation() = 0;
};

/// @brief ostream operator for CheckpointStrategy
inline std::ostream& operator<<(std::ostream& os, const CheckpointStrategy& s)
{
s.print(os);
return os;
}

} // namespace gretl
10 changes: 4 additions & 6 deletions src/gretl/create_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@ gretl::State<T, D> create_state_impl(
auto newState = state0.template create_state<T, D>(state_bases, zeroFunc);

newState.set_eval([eval](const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) {
const T e =
eval(inputs[0].get<typename State0::type>(), inputs[state_indices + 1].get<typename StatesN::type>()...);
output.set<T, D>(e);
T e = eval(inputs[0].get<typename State0::type>(), inputs[state_indices + 1].get<typename StatesN::type>()...);
output.set<T, D>(std::move(e));
});

newState.set_vjp([vjp](gretl::UpstreamStates& inputs, const gretl::DownstreamState& output) {
Expand Down Expand Up @@ -104,9 +103,8 @@ gretl::State<typename State0::type, typename State0::dual_type> clone_state_impl
auto newState = state0.clone(state_bases);

newState.set_eval([eval](const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) {
const T e =
eval(inputs[0].get<typename State0::type>(), inputs[state_indices + 1].get<typename StatesN::type>()...);
output.set<T, D>(e);
T e = eval(inputs[0].get<typename State0::type>(), inputs[state_indices + 1].get<typename StatesN::type>()...);
output.set<T, D>(std::move(e));
});

newState.set_vjp([vjp](gretl::UpstreamStates& inputs, const gretl::DownstreamState& output) {
Expand Down
Loading