diff --git a/src/gretl/CMakeLists.txt b/src/gretl/CMakeLists.txt index 2214177..12e9d4c 100644 --- a/src/gretl/CMakeLists.txt +++ b/src/gretl/CMakeLists.txt @@ -16,11 +16,16 @@ set(gretl_sources about.cpp data_store.cpp state_base.cpp - vector_state.cpp) + vector_state.cpp + wang_checkpoint_strategy.cpp + strumm_walther_checkpoint_strategy.cpp) set(gretl_headers about.hpp checkpoint.hpp + checkpoint_strategy.hpp + wang_checkpoint_strategy.hpp + strumm_walther_checkpoint_strategy.hpp create_state.hpp data_store.hpp double_state.hpp diff --git a/src/gretl/checkpoint.hpp b/src/gretl/checkpoint.hpp index 7851cbe..5c9bb51 100644 --- a/src/gretl/checkpoint.hpp +++ b/src/gretl/checkpoint.hpp @@ -10,12 +10,15 @@ #pragma once -#include #include #include #include #include #include +#include +#include + +#include "checkpoint_strategy.hpp" /// @brief gretl_assert that prints line and file info before throwing in release and halting in debug #define gretl_assert(x) \ @@ -32,153 +35,21 @@ namespace gretl { -/// @brief checkpoint struct which tracks level and step per "Minimal Repetition Dynamic Checkpointing Algorithm for -/// Unsteady Adjoint Calculation", Wang, et al. , 2009. -struct Checkpoint { - size_t level; ///< level - size_t step; ///< step - static constexpr size_t infinity() - { - return std::numeric_limits::max(); - } ///< The largest possible step and level value -}; - -/// @brief comparison operator between two checkpoints to determine which is most disposable per the dynamic -/// checkpointing algorithm -inline bool operator<(const Checkpoint& a, const Checkpoint& b) -{ - if (a.level == Checkpoint::infinity() && b.level == Checkpoint::infinity()) { - return a.step > b.step; - } - if (a.level == Checkpoint::infinity()) return false; - if (b.level == Checkpoint::infinity()) return true; - return a.step > b.step; -} - -/// @brief output stream for a single checkpoint -inline std::ostream& operator<<(std::ostream& stream, const Checkpoint& p); - -/// @brief CheckpointManager class which encapsulates the logic of when and which steps should be dynamically saved a -/// fetched -struct CheckpointManager { - static constexpr size_t invalidCheckpointIndex = - std::numeric_limits::max(); ///< magic number of invalid checkpoint - - /// @brief utilty for checking if an index is valid. There is a magic number, invalidCheckpointIndex, which - /// represents an invalid checkpoint - static bool valid_checkpoint_index(size_t i) { return i != invalidCheckpointIndex; } - - /// @brief returns const_iterator to currently most dispensable checkpoint step - std::set::const_iterator most_dispensable() const - { - size_t maxHigherTimeLevel = 0; - for (auto rIter = cps.begin(); rIter != cps.end(); ++rIter) { - if (rIter->level < maxHigherTimeLevel) { - return rIter; - } - maxHigherTimeLevel = std::max(rIter->level, maxHigherTimeLevel); - } - return cps.end(); - } - - /// @brief this does multiple things - /// 1. it adds checkpoints into the database, and updates internal data structures - /// 2. it determines if a checkpoint needs to be removed - /// 3. if a checkpoint needs to be removed, it returns the index for that checkpoint - /// 4. otherwise, it returns zero - size_t add_checkpoint_and_get_index_to_remove(size_t step, bool persistent = false) - { - size_t levelupAmount = 1; //= relativeCost >= 2.0 ? 3 : 1; - - Checkpoint nextStep{.level = levelupAmount - 1, .step = step}; - - size_t nextEraseStep = invalidCheckpointIndex; - - // don't include persistent data in data quota. MRT, this might change - if (persistent) { - maxNumStates++; - nextStep.level = Checkpoint::infinity(); - gretl_assert(cps.size() < maxNumStates); - } - - if (cps.size() < maxNumStates) { - cps.insert(nextStep); - } else { - auto iterToMostDispensable = most_dispensable(); - if (iterToMostDispensable != cps.end()) { - nextEraseStep = iterToMostDispensable->step; - cps.erase(iterToMostDispensable); - cps.insert(nextStep); - } else { - nextEraseStep = cps.begin()->step; - nextStep.level = cps.begin()->level + levelupAmount; - - cps.erase(cps.begin()); - cps.insert(nextStep); - } - } - - return nextEraseStep; - } - - /// @brief return largest currently checkpointed step - size_t last_checkpoint_step() const { return cps.begin()->step; } - - /// @brief erase - bool erase_step(size_t stepIndex) - { - for (std::set::iterator it = cps.begin(); it != cps.end(); ++it) { - if (it->step == stepIndex) { - if (it->level != Checkpoint::infinity()) { - cps.erase(it); - return true; - } - } - } - return false; - } - - /// @brief check if this step is currently checkpointed. This could potentially use performance optimization down the - /// way. - bool contains_step(size_t stepIndex) const - { - for (auto& c : cps) { - if (c.step == stepIndex) { - return true; - } - } - return false; - } - - /// @brief erase all non persistent checkpoints - void reset() - { - for (auto cp_it = cps.begin(); cp_it != cps.end(); ++cp_it) { - if (cp_it->level == Checkpoint::infinity()) { - cps.erase(cps.begin(), cp_it); - break; - } - } - } - - size_t maxNumStates = 20; ///< The max number of non-persistent, not-in-scope states stored by the CheckpointManager - std::set cps; ///< Vector of checkpoints -}; - /// @brief interface to run forward with a linear graph, checkpoint, then automatically backpropagate the sensitivities /// given the reverse_callback vjp. /// @tparam T type of each state's data /// @param numSteps number of forward iterations -/// @param storageSize maximum states to save in memory at a time /// @param x initial condition /// @param update_func function which evaluates the forward response /// @param reverse_callback vjp function (action of Jacobian-transposed) to back propagate sensitivities +/// @param strategy checkpoint strategy (required) /// @return template -T advance_and_reverse_steps(size_t numSteps, size_t storageSize, T x, std::function update_func, - std::function reverse_callback) +T advance_and_reverse_steps(size_t numSteps, T x, std::function update_func, + std::function reverse_callback, + std::unique_ptr strategy) { - gretl::CheckpointManager cps{.maxNumStates = storageSize, .cps{}}; + CheckpointStrategy& cps = *strategy; std::map savedCps; savedCps[0] = x; @@ -204,6 +75,7 @@ T advance_and_reverse_steps(size_t numSteps, size_t storageSize, T x, std::funct savedCps.erase(eraseStep); } savedCps[lastCp + 1] = x; + cps.record_recomputation(); } reverse_callback(i, savedCps[i]); @@ -214,21 +86,4 @@ T advance_and_reverse_steps(size_t numSteps, size_t storageSize, T x, std::funct return xf; } -/// @brief ostream operator for writing out checkpoint information -inline std::ostream& operator<<(std::ostream& stream, const Checkpoint& p) -{ - return stream << " lvl=" << p.level << ", step=" << p.step; -} - -/// @brief ostream operator for writing out information about the entire checkpoint manager to see the set of currently -/// checkpointed states -inline std::ostream& operator<<(std::ostream& stream, const CheckpointManager& set) -{ - stream << "CHECKPOINTS: capacity = " << set.maxNumStates << std::endl; - for (const auto& s : set.cps) { - stream << s << "\n"; - } - return stream; -} - } // namespace gretl diff --git a/src/gretl/checkpoint_strategy.hpp b/src/gretl/checkpoint_strategy.hpp new file mode 100644 index 0000000..6c3bcab --- /dev/null +++ b/src/gretl/checkpoint_strategy.hpp @@ -0,0 +1,88 @@ +// Copyright (c) Lawrence Livermore National Security, LLC and +// other Gretl Project Developers. See the top-level LICENSE file for +// details. +// +// SPDX-License-Identifier: (BSD-3-Clause) + +/** + * @file checkpoint_strategy.hpp + * @brief Abstract interface for checkpoint eviction strategies. + */ + +#pragma once + +#include +#include +#include +#include + +namespace gretl { + +/// @brief Performance counters for comparing checkpoint algorithms. +struct CheckpointMetrics { + size_t stores = 0; ///< Number of checkpoint store operations + size_t evictions = 0; ///< Number of checkpoint evictions + size_t recomputations = 0; ///< Forward re-evaluations triggered during reverse +}; + +/// @brief Abstract interface for checkpoint eviction strategies. +/// +/// Implementations decide which step to evict when checkpoint capacity is +/// exceeded. The interface exposes only the operations that DataStore +/// requires, hiding all algorithm-specific data structures. +class CheckpointStrategy { + public: + static constexpr size_t invalidCheckpointIndex = + std::numeric_limits::max(); ///< Magic number for invalid checkpoint + + /// @brief Check if a checkpoint index is valid + static bool valid_checkpoint_index(size_t i) { return i != invalidCheckpointIndex; } + + virtual ~CheckpointStrategy() = default; + + /// @brief Add a checkpoint for the given step. + /// @param step The step index to checkpoint. + /// @param persistent If true, this checkpoint cannot be evicted. + /// @return The step index to evict, or invalidCheckpointIndex if none. + virtual size_t add_checkpoint_and_get_index_to_remove(size_t step, bool persistent = false) = 0; + + /// @brief Return the step index of the earliest currently stored checkpoint. + virtual size_t last_checkpoint_step() const = 0; + + /// @brief Remove the checkpoint at the given step. + /// @return true if a checkpoint was found and removed. + virtual bool erase_step(size_t stepIndex) = 0; + + /// @brief Check if a checkpoint exists for the given step. + virtual bool contains_step(size_t stepIndex) const = 0; + + /// @brief Clear all non-persistent checkpoints. + virtual void reset() = 0; + + /// @brief Return the maximum number of non-persistent checkpoint slots. + virtual size_t capacity() const = 0; + + /// @brief Return the current number of checkpoints (persistent + non-persistent). + virtual size_t size() const = 0; + + /// @brief Print checkpoint state to the output stream. + virtual void print(std::ostream& os) const = 0; + + /// @brief Return accumulated performance metrics. + virtual CheckpointMetrics metrics() const = 0; + + /// @brief Reset accumulated performance metrics to zero. + virtual void reset_metrics() = 0; + + /// @brief Record a forward recomputation (called by DataStore during fetch). + virtual void record_recomputation() = 0; +}; + +/// @brief ostream operator for CheckpointStrategy +inline std::ostream& operator<<(std::ostream& os, const CheckpointStrategy& s) +{ + s.print(os); + return os; +} + +} // namespace gretl diff --git a/src/gretl/create_state.hpp b/src/gretl/create_state.hpp index 6c37b6a..a96cd8f 100644 --- a/src/gretl/create_state.hpp +++ b/src/gretl/create_state.hpp @@ -41,9 +41,8 @@ gretl::State create_state_impl( auto newState = state0.template create_state(state_bases, zeroFunc); newState.set_eval([eval](const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) { - const T e = - eval(inputs[0].get(), inputs[state_indices + 1].get()...); - output.set(e); + T e = eval(inputs[0].get(), inputs[state_indices + 1].get()...); + output.set(std::move(e)); }); newState.set_vjp([vjp](gretl::UpstreamStates& inputs, const gretl::DownstreamState& output) { @@ -104,9 +103,8 @@ gretl::State clone_state_impl auto newState = state0.clone(state_bases); newState.set_eval([eval](const gretl::UpstreamStates& inputs, gretl::DownstreamState& output) { - const T e = - eval(inputs[0].get(), inputs[state_indices + 1].get()...); - output.set(e); + T e = eval(inputs[0].get(), inputs[state_indices + 1].get()...); + output.set(std::move(e)); }); newState.set_vjp([vjp](gretl::UpstreamStates& inputs, const gretl::DownstreamState& output) { diff --git a/src/gretl/data_store.cpp b/src/gretl/data_store.cpp index bf1563d..12288ac 100644 --- a/src/gretl/data_store.cpp +++ b/src/gretl/data_store.cpp @@ -12,7 +12,10 @@ namespace gretl { -DataStore::DataStore(size_t maxStates) : checkpointManager_{.maxNumStates = maxStates, .cps{}} { currentStep_ = 0; } +DataStore::DataStore(std::unique_ptr strategy) : checkpointStrategy_(std::move(strategy)) +{ + currentStep_ = 0; +} void DataStore::back_prop() { @@ -65,7 +68,7 @@ void DataStore::reset() } duals_[stepToClear] = nullptr; } - checkpointManager_.reset(); + checkpointStrategy_->reset(); currentStep_ = num_persistent; } @@ -77,9 +80,13 @@ void DataStore::reset_graph() if (is_persistent(stepToClear)) { num_persistent++; } + duals_[stepToClear] = nullptr; } + // Restore currentStep_ before resize, since back_prop() decrements it to 0 + // but resize() asserts newSize <= currentStep_. + currentStep_ = static_cast(states_.size()); resize(num_persistent); - checkpointManager_.reset(); + checkpointStrategy_->reset(); stillConstructingGraph_ = true; } @@ -118,14 +125,14 @@ void DataStore::reverse_state() { // must erase the final step in the cp manager before we get started if (currentStep_ == states_.size()) { - checkpointManager_.erase_step(currentStep_ - 1); + checkpointStrategy_->erase_step(currentStep_ - 1); } --currentStep_; if (upstreams_[currentStep_].size()) { fetch_state_data(currentStep_ - 1); vjp(*states_[currentStep_]); clear_usage(currentStep_); - checkpointManager_.erase_step(currentStep_ - 1); + checkpointStrategy_->erase_step(currentStep_ - 1); } } @@ -153,6 +160,11 @@ void printv(const std::vector& v) void DataStore::try_to_free(Int step) { + // Don't try to free during destruction to avoid accessing freed memory + if (isDestroying_) { + return; + } + if (!is_persistent(step) && states_[step] && states_[step]->data_) { if (usageCount_[step] == 0 && !active_[step] && states_[step]->data_.use_count() <= 1) { states_[step]->primal() = nullptr; @@ -174,7 +186,7 @@ void DataStore::add_state(std::unique_ptr newState, const std::vector bool persistent = upstreams.size() == 0; if (persistent) { - checkpointManager_.add_checkpoint_and_get_index_to_remove(step, persistent); + checkpointStrategy_->add_checkpoint_and_get_index_to_remove(step, persistent); } std::vector upstreamSteps; @@ -240,11 +252,11 @@ void DataStore::add_state(std::unique_ptr newState, const std::vector void DataStore::fetch_state_data(Int stepIndex) { gretl_assert_msg(!stillConstructingGraph_, "not allowed to fetch state before the graph is constructed"); - Int lastCheckpoint = static_cast(checkpointManager_.last_checkpoint_step()); + Int lastCheckpoint = static_cast(checkpointStrategy_->last_checkpoint_step()); if (lastCheckpoint > stepIndex) { print("An issue was found when fetching a previous states data\n"); print_graph(); - std::cout << checkpointManager_ << std::endl; + checkpointStrategy_->print(std::cout); } gretl_assert_msg(lastCheckpoint <= stepIndex, std::string("last checkpoint cannot be ahead of the currently requested step ") + @@ -266,6 +278,7 @@ void DataStore::fetch_state_data(Int stepIndex) erase_step_state_data(iEval); } else { states_[iEval]->evaluate_forward(); + checkpointStrategy_->record_recomputation(); } gretl_assert(check_validity()); @@ -275,8 +288,8 @@ void DataStore::fetch_state_data(Int stepIndex) void DataStore::erase_step_state_data(Int step) { if (!is_persistent(step)) { - size_t stepToErase = checkpointManager_.add_checkpoint_and_get_index_to_remove(step); - if (checkpointManager_.valid_checkpoint_index(stepToErase)) { + size_t stepToErase = checkpointStrategy_->add_checkpoint_and_get_index_to_remove(step); + if (CheckpointStrategy::valid_checkpoint_index(stepToErase)) { active_[stepToErase] = false; try_to_free(static_cast(stepToErase)); for_each_active_upstream(this, stepToErase, [&](Int upstream) { @@ -299,13 +312,7 @@ bool DataStore::check_validity() const // we are allowed to be saving an extra step here at the end for (size_t i = 0; i < currentStep_; ++i) { if (active_[i]) { - bool cp_has_i = false; - for (auto& cp : checkpointManager_.cps) { - if (cp.step == i) { - cp_has_i = true; - break; - } - } + bool cp_has_i = checkpointStrategy_->contains_step(i); if (!cp_has_i) { gretl::print("step", i, "not consistent with checkpoint manager"); valid = false; @@ -358,7 +365,7 @@ void DataStore::print_graph() const } std::cout << std::endl; } - // std::cout << checkpointManager_ << std::endl; + // checkpointStrategy_->print(std::cout); } } // namespace gretl diff --git a/src/gretl/data_store.hpp b/src/gretl/data_store.hpp index bf6c25a..e85fede 100644 --- a/src/gretl/data_store.hpp +++ b/src/gretl/data_store.hpp @@ -16,7 +16,10 @@ #include #include #include +#include +#include #include "checkpoint.hpp" +#include "checkpoint_strategy.hpp" #include "print_utils.hpp" #ifdef __GNUG__ @@ -52,13 +55,21 @@ struct defaultInitializeZeroDual { /// checkpointing state information, and its backpropagated sensitivities class DataStore { public: - /// @brief Constructor - /// @param maxStates maximum number of states the users is allowing to be allocated for the dynamic checkpointing. - /// This does not include persistent states, nor states held in scope by the user. - DataStore(size_t maxStates); - + /// @brief Constructor requiring a checkpoint strategy. + /// @param strategy a checkpoint strategy implementation (e.g., WangCheckpointStrategy, + /// StrummWaltherCheckpointStrategy) + explicit DataStore(std::unique_ptr strategy); + + /// @brief virtual destructor. Must clear states_ first because StateBase + /// destructors call try_to_free() which accesses upstreams_ and other members. + /// Without this, implicit reverse-declaration-order destruction would destroy + /// upstreams_ before states_, causing use-after-free. /// @brief virtual destructor - virtual ~DataStore() {} + virtual ~DataStore() + { + // Set flag to prevent try_to_free() from accessing freed memory during destruction + isDestroying_ = true; + } /// @brief create a new state in the graph, store it, return it template @@ -155,26 +166,21 @@ class DataStore { return *tptr; } - /// @brief Set primal value + /// @brief Set primal value (forwarding version: moves rvalues, copies lvalues) /// @param step step /// @param t value of type T to set primal to template - void set_primal(Int step, const T& t) + void set_primal(Int step, T&& t) { - T* tptr = std::any_cast(any_primal(step).get()); + using U = std::decay_t; + U* tptr = std::any_cast(any_primal(step).get()); if (!tptr) { gretl_assert(!stillConstructingGraph_); - // MRT, debug reverse pass here - // if (usageCount_[step] != 1) { - // print("step", step); - // print_graph(); - // } - // gretl_assert(usageCount_[step] == 1); - any_primal(step) = std::make_shared(t); + any_primal(step) = std::make_shared(std::forward(t)); return; } gretl_assert(tptr); - *tptr = t; + *tptr = std::forward(t); } /// @brief Get dual value @@ -249,7 +255,7 @@ class DataStore { ///< eventually used in some future step as an upstream /// container which track the states in the graph with allocated data - CheckpointManager checkpointManager_; + std::unique_ptr checkpointStrategy_; /// step counter Int currentStep_; @@ -257,6 +263,9 @@ class DataStore { /// @brief specifies if graph is in construction or back-prop mode. This is used for internal asserts. bool stillConstructingGraph_ = true; + /// @brief flag to prevent accessing freed memory during destruction + bool isDestroying_ = false; + friend struct StateBase; template diff --git a/src/gretl/state.hpp b/src/gretl/state.hpp index 848ac1a..c32acbd 100644 --- a/src/gretl/state.hpp +++ b/src/gretl/state.hpp @@ -27,9 +27,12 @@ struct State : public StateBase { using type = T; ///< type using dual_type = D; ///< dual_type - /// @brief Set primal value of correct type + /// @brief Set primal value of correct type (copy) inline void set(const T& t) { data_store().set_primal(step(), t); } + /// @brief Set primal value of correct type (move) + inline void set(T&& t) { data_store().set_primal(step(), std::move(t)); } + /// @brief Get primal value of correct type inline const T& get() const { return data_store().template get_primal(step()); } @@ -49,17 +52,14 @@ struct State : public StateBase { data_store().vjps_[step()] = v; } - /// @brief Helper function to clone an existing state (keeping its type) + /// @brief Helper function to clone an existing state (keeping its type). + /// Allocates default-constructed primal storage; finalize() will overwrite it + /// via evaluate_forward(), so copying the source primal is unnecessary. /// @param upstreams The upstream dependencies for this new state State clone(const std::vector& upstreams) const { gretl_assert(!upstreams.empty()); - auto primal_ptr = primal().get(); - gretl_assert(primal_ptr); - std::shared_ptr new_val; - if (primal_ptr) { - new_val = std::make_shared(*std::any_cast(primal_ptr)); - } + auto new_val = std::make_shared(T{}); State state(&data_store(), data_store().states_.size(), new_val, initialize_zero_dual_); data_store().add_state(std::make_unique>(state), upstreams); return state; diff --git a/src/gretl/strumm_walther_checkpoint_strategy.cpp b/src/gretl/strumm_walther_checkpoint_strategy.cpp new file mode 100644 index 0000000..8a85f0b --- /dev/null +++ b/src/gretl/strumm_walther_checkpoint_strategy.cpp @@ -0,0 +1,192 @@ +// Copyright (c) Lawrence Livermore National Security, LLC and +// other Gretl Project Developers. See the top-level LICENSE file for +// details. +// +// SPDX-License-Identifier: (BSD-3-Clause) + +#include "strumm_walther_checkpoint_strategy.hpp" +#include +#include +#include + +namespace gretl { + +StrummWaltherCheckpointStrategy::StrummWaltherCheckpointStrategy(size_t maxStates) : maxNumSlots_(maxStates) {} + +size_t StrummWaltherCheckpointStrategy::find_dispensable() const +{ + // Weight-based dispensability (analogous to Wang's most_dispensable): + // Iterate from highest step to lowest. Track the maximum weight seen. + // A slot is "dispensable" if its weight is LESS than the running maximum — + // it sits behind a more important (higher-weight) checkpoint. + // + // Enhancement over Wang: when multiple slots share the same dispensable + // weight, choose the one whose removal minimizes the increase in total + // recomputation cost (gap_left * gap_right). This spacing-aware tiebreaker + // can outperform Wang's arbitrary "first found" selection. + + // First pass: find the dispensable weight threshold + size_t maxWeight = 0; + size_t dispensableWeight = std::numeric_limits::max(); + for (size_t i = slots_.size(); i > 0; --i) { + size_t idx = i - 1; + if (slots_[idx].persistent) continue; + + if (slots_[idx].weight < maxWeight) { + dispensableWeight = slots_[idx].weight; + break; + } + maxWeight = std::max(maxWeight, slots_[idx].weight); + } + + if (dispensableWeight == std::numeric_limits::max()) { + return slots_.size(); // none found + } + + // Second pass: among all slots at dispensableWeight, pick the one with + // minimum gap_left * gap_right (minimum delta recomputation cost). + size_t bestIdx = slots_.size(); + size_t bestProduct = std::numeric_limits::max(); + + for (size_t i = 0; i < slots_.size(); ++i) { + if (slots_[i].persistent) continue; + if (slots_[i].weight != dispensableWeight) continue; + + // Check that this slot is actually in a dispensable position + // (there must be a higher-weight slot at a higher step) + bool hasHigherWeightAfter = false; + for (size_t j = i + 1; j < slots_.size(); ++j) { + if (!slots_[j].persistent && slots_[j].weight > dispensableWeight) { + hasHigherWeightAfter = true; + break; + } + } + if (!hasHigherWeightAfter) continue; + + size_t leftStep = (i > 0) ? slots_[i - 1].step : 0; + size_t rightStep = (i + 1 < slots_.size()) ? slots_[i + 1].step : slots_.back().step + 1; + + size_t gapLeft = slots_[i].step - leftStep; + size_t gapRight = rightStep - slots_[i].step; + size_t product = gapLeft * gapRight; + + if (product < bestProduct) { + bestProduct = product; + bestIdx = i; + } + } + + return bestIdx; +} + +size_t StrummWaltherCheckpointStrategy::find_rightmost_nonpersistent() const +{ + for (size_t i = slots_.size(); i > 0; --i) { + if (!slots_[i - 1].persistent) { + return i - 1; + } + } + return slots_.size(); +} + +size_t StrummWaltherCheckpointStrategy::add_checkpoint_and_get_index_to_remove(size_t step, bool persistent) +{ + size_t nextEraseStep = invalidCheckpointIndex; + + size_t newWeight = 0; + + if (persistent) { + maxNumSlots_++; + assert(slots_.size() < maxNumSlots_); + } + + if (slots_.size() < maxNumSlots_) { + // Space available — insert directly + } else { + // At capacity — must evict + size_t dispensableIdx = find_dispensable(); + + if (dispensableIdx < slots_.size()) { + // Found a dispensable slot: evict it, new checkpoint gets weight 0 + nextEraseStep = slots_[dispensableIdx].step; + slots_.erase(slots_.begin() + static_cast(dispensableIdx)); + } else { + // No dispensable slot (all weights equal): evict the rightmost + // non-persistent and PROMOTE the replacement to a higher weight. + // This is the key self-organizing mechanism: it creates a weight + // hierarchy that forces future evictions to target older, lower-weight + // checkpoints, producing near-logarithmic checkpoint distributions. + size_t rightmostIdx = find_rightmost_nonpersistent(); + assert(rightmostIdx < slots_.size()); + newWeight = slots_[rightmostIdx].weight + 1; + nextEraseStep = slots_[rightmostIdx].step; + slots_.erase(slots_.begin() + static_cast(rightmostIdx)); + } + } + + // Insert new slot in sorted order + Slot newSlot{step, persistent, persistent ? std::numeric_limits::max() : newWeight}; + auto it = std::lower_bound(slots_.begin(), slots_.end(), step, [](const Slot& s, size_t st) { return s.step < st; }); + slots_.insert(it, newSlot); + + metrics_.stores++; + if (valid_checkpoint_index(nextEraseStep)) { + metrics_.evictions++; + } + + return nextEraseStep; +} + +size_t StrummWaltherCheckpointStrategy::last_checkpoint_step() const +{ + assert(!slots_.empty()); + return slots_.back().step; +} + +bool StrummWaltherCheckpointStrategy::erase_step(size_t stepIndex) +{ + for (auto it = slots_.begin(); it != slots_.end(); ++it) { + if (it->step == stepIndex) { + if (!it->persistent) { + slots_.erase(it); + return true; + } + } + } + return false; +} + +bool StrummWaltherCheckpointStrategy::contains_step(size_t stepIndex) const +{ + for (const auto& s : slots_) { + if (s.step == stepIndex) { + return true; + } + } + return false; +} + +void StrummWaltherCheckpointStrategy::reset() +{ + slots_.erase(std::remove_if(slots_.begin(), slots_.end(), [](const Slot& s) { return !s.persistent; }), slots_.end()); +} + +size_t StrummWaltherCheckpointStrategy::capacity() const { return maxNumSlots_; } + +size_t StrummWaltherCheckpointStrategy::size() const { return slots_.size(); } + +void StrummWaltherCheckpointStrategy::print(std::ostream& os) const +{ + os << "CHECKPOINTS (StrummWalther): capacity = " << maxNumSlots_ << std::endl; + for (const auto& s : slots_) { + os << " step=" << s.step << " weight=" << s.weight << (s.persistent ? " (persistent)" : "") << "\n"; + } +} + +CheckpointMetrics StrummWaltherCheckpointStrategy::metrics() const { return metrics_; } + +void StrummWaltherCheckpointStrategy::reset_metrics() { metrics_ = {}; } + +void StrummWaltherCheckpointStrategy::record_recomputation() { metrics_.recomputations++; } + +} // namespace gretl diff --git a/src/gretl/strumm_walther_checkpoint_strategy.hpp b/src/gretl/strumm_walther_checkpoint_strategy.hpp new file mode 100644 index 0000000..901809f --- /dev/null +++ b/src/gretl/strumm_walther_checkpoint_strategy.hpp @@ -0,0 +1,75 @@ +// Copyright (c) Lawrence Livermore National Security, LLC and +// other Gretl Project Developers. See the top-level LICENSE file for +// details. +// +// SPDX-License-Identifier: (BSD-3-Clause) + +/** + * @file strumm_walther_checkpoint_strategy.hpp + * @brief Stumm & Walther 2010 "Online r=2" checkpointing strategy. + * + * Reference: Philipp Stumm and Andrea Walther, "New Algorithms for Optimal + * Online Checkpointing", SIAM J. Sci. Comput., 32(2), 836-854, 2010. + * DOI: 10.1137/080742439 + */ + +#pragma once + +#include "checkpoint_strategy.hpp" +#include +#include + +namespace gretl { + +/// @brief Stumm & Walther 2010 "Online r=2" checkpointing strategy. +/// +/// Unlike the Wang algorithm which uses levels to determine dispensability, +/// this algorithm maintains checkpoints with approximately uniform spacing +/// relative to the current step count. When at capacity, the eviction +/// candidate is the non-persistent checkpoint whose removal results in the +/// smallest maximum gap between remaining checkpoints. +/// +/// Key properties: +/// - No level concept; eviction is based on spacing analysis +/// - Works online: total number of steps need not be known a priori +/// - Achieves near-optimal checkpoint distribution for unknown-length runs +class StrummWaltherCheckpointStrategy final : public CheckpointStrategy { + public: + /// @brief Construct with a given number of non-persistent checkpoint slots. + explicit StrummWaltherCheckpointStrategy(size_t maxStates); + + size_t add_checkpoint_and_get_index_to_remove(size_t step, bool persistent = false) override; + size_t last_checkpoint_step() const override; + bool erase_step(size_t stepIndex) override; + bool contains_step(size_t stepIndex) const override; + void reset() override; + size_t capacity() const override; + size_t size() const override; + void print(std::ostream& os) const override; + CheckpointMetrics metrics() const override; + void reset_metrics() override; + void record_recomputation() override; + + private: + /// @brief A checkpoint slot: stores step, persistent flag, and weight. + struct Slot { + size_t step; + bool persistent; + size_t weight; ///< Importance weight; increases via promotion (like Wang levels) + }; + + /// @brief Find a "dispensable" slot using weight-based priority. + /// Iterates from highest to lowest step; a slot is dispensable if its + /// weight is less than the running maximum weight seen so far. + /// @return Index of the dispensable slot, or slots_.size() if none found. + size_t find_dispensable() const; + + /// @brief Find the index of the rightmost non-persistent slot. + size_t find_rightmost_nonpersistent() const; + + size_t maxNumSlots_; + std::vector slots_; ///< Sorted by step number + CheckpointMetrics metrics_; +}; + +} // namespace gretl diff --git a/src/gretl/upstream_state.hpp b/src/gretl/upstream_state.hpp index 8d4f456..c5235e4 100644 --- a/src/gretl/upstream_state.hpp +++ b/src/gretl/upstream_state.hpp @@ -82,11 +82,18 @@ struct DownstreamState { /// @param step step DownstreamState(DataStore* s, Int step) : dataStore_(s), step_(step) {} - /// @brief set underlying value + /// @brief set underlying value (copy) template void set(const T& t) { - return dataStore_->set_primal(step_, t); + dataStore_->set_primal(step_, t); + } + + /// @brief set underlying value (move) + template > + void set(T&& t) + { + dataStore_->set_primal(step_, std::forward(t)); } /// @brief get underlying value diff --git a/src/gretl/vector_state.cpp b/src/gretl/vector_state.cpp index 063de3e..aa60377 100644 --- a/src/gretl/vector_state.cpp +++ b/src/gretl/vector_state.cpp @@ -18,10 +18,9 @@ VectorState testing_update(const VectorState& a) auto& b_ = downstream; const Vector& A = a_.get(); - size_t sz = A.size(); - Vector B(sz); - for (size_t i = 0; i < sz; ++i) { - B[i] = A[i] / 3.0 + 2.0; + Vector B(A); // copy-construct (avoids zero-init of Vector(sz)) + for (auto& v : B) { + v = v / 3.0 + 2.0; } b_.set(std::move(B)); }); @@ -44,7 +43,7 @@ VectorState operator+(const VectorState& a, const VectorState& b) VectorState c = a.clone({a, b}); c.set_eval([](const UpstreamStates& upstreams, DownstreamState& downstream) { - Vector C = upstreams[0].get(); // just making a copy + Vector C = upstreams[0].get(); const Vector& B = upstreams[1].get(); size_t sz = C.size(); for (size_t i = 0; i < sz; ++i) { @@ -76,8 +75,9 @@ VectorState operator*(const VectorState& a, double b) VectorState c = a.clone({a}); c.set_eval([b](const UpstreamStates& upstreams, DownstreamState& downstream) { - Vector C = upstreams[0].get(); - for (auto&& v : C) { + const Vector& A = upstreams[0].get(); + Vector C(A); // copy-construct (avoids zero-init of Vector(sz)) + for (auto& v : C) { v *= b; } downstream.set(std::move(C)); @@ -102,8 +102,8 @@ State inner_product(const VectorState& a, const VectorState& b) c.set_eval([](const UpstreamStates& upstreams, DownstreamState& downstream) { double prod = 0.0; - auto A = upstreams[0].get(); - auto B = upstreams[1].get(); + const auto& A = upstreams[0].get(); + const auto& B = upstreams[1].get(); size_t sz = get_same_size({&A, &B}); for (size_t i = 0; i < sz; ++i) { prod += A[i] * B[i]; @@ -140,9 +140,10 @@ VectorState operator*(const VectorState& a, const VectorState& b) VectorState c = a.clone({a, b}); c.set_eval([](const UpstreamStates& upstreams, DownstreamState& downstream) { - Vector C = upstreams[0].get(); + const Vector& A = upstreams[0].get(); const Vector& B = upstreams[1].get(); - size_t sz = get_same_size({&B, &C}); + size_t sz = get_same_size({&A, &B}); + Vector C(A); // copy-construct (avoids zero-init of Vector(sz)) for (size_t i = 0; i < sz; ++i) { C[i] *= B[i]; } diff --git a/src/gretl/wang_checkpoint_strategy.cpp b/src/gretl/wang_checkpoint_strategy.cpp new file mode 100644 index 0000000..c3b7da8 --- /dev/null +++ b/src/gretl/wang_checkpoint_strategy.cpp @@ -0,0 +1,120 @@ +// Copyright (c) Lawrence Livermore National Security, LLC and +// other Gretl Project Developers. See the top-level LICENSE file for +// details. +// +// SPDX-License-Identifier: (BSD-3-Clause) + +#include "wang_checkpoint_strategy.hpp" +#include +#include + +namespace gretl { + +WangCheckpointStrategy::WangCheckpointStrategy(size_t maxStates) : maxNumStates_(maxStates) {} + +std::set::const_iterator +WangCheckpointStrategy::most_dispensable() const +{ + size_t maxHigherTimeLevel = 0; + for (auto rIter = cps_.begin(); rIter != cps_.end(); ++rIter) { + if (rIter->level < maxHigherTimeLevel) { + return rIter; + } + maxHigherTimeLevel = std::max(rIter->level, maxHigherTimeLevel); + } + return cps_.end(); +} + +size_t WangCheckpointStrategy::add_checkpoint_and_get_index_to_remove(size_t step, bool persistent) +{ + size_t levelupAmount = 1; + + Checkpoint nextStep{.level = levelupAmount - 1, .step = step}; + + size_t nextEraseStep = invalidCheckpointIndex; + + if (persistent) { + maxNumStates_++; + nextStep.level = Checkpoint::infinity(); + assert(cps_.size() < maxNumStates_); + } + + if (cps_.size() < maxNumStates_) { + cps_.insert(nextStep); + } else { + auto iterToMostDispensable = most_dispensable(); + if (iterToMostDispensable != cps_.end()) { + nextEraseStep = iterToMostDispensable->step; + cps_.erase(iterToMostDispensable); + cps_.insert(nextStep); + } else { + nextEraseStep = cps_.begin()->step; + nextStep.level = cps_.begin()->level + levelupAmount; + + cps_.erase(cps_.begin()); + cps_.insert(nextStep); + } + } + + metrics_.stores++; + if (valid_checkpoint_index(nextEraseStep)) { + metrics_.evictions++; + } + + return nextEraseStep; +} + +size_t WangCheckpointStrategy::last_checkpoint_step() const { return cps_.begin()->step; } + +bool WangCheckpointStrategy::erase_step(size_t stepIndex) +{ + for (auto it = cps_.begin(); it != cps_.end(); ++it) { + if (it->step == stepIndex) { + if (it->level != Checkpoint::infinity()) { + cps_.erase(it); + return true; + } + } + } + return false; +} + +bool WangCheckpointStrategy::contains_step(size_t stepIndex) const +{ + for (const auto& c : cps_) { + if (c.step == stepIndex) { + return true; + } + } + return false; +} + +void WangCheckpointStrategy::reset() +{ + for (auto cp_it = cps_.begin(); cp_it != cps_.end(); ++cp_it) { + if (cp_it->level == Checkpoint::infinity()) { + cps_.erase(cps_.begin(), cp_it); + break; + } + } +} + +size_t WangCheckpointStrategy::capacity() const { return maxNumStates_; } + +size_t WangCheckpointStrategy::size() const { return cps_.size(); } + +void WangCheckpointStrategy::print(std::ostream& os) const +{ + os << "CHECKPOINTS (Wang): capacity = " << maxNumStates_ << std::endl; + for (const auto& s : cps_) { + os << " lvl=" << s.level << ", step=" << s.step << "\n"; + } +} + +CheckpointMetrics WangCheckpointStrategy::metrics() const { return metrics_; } + +void WangCheckpointStrategy::reset_metrics() { metrics_ = {}; } + +void WangCheckpointStrategy::record_recomputation() { metrics_.recomputations++; } + +} // namespace gretl diff --git a/src/gretl/wang_checkpoint_strategy.hpp b/src/gretl/wang_checkpoint_strategy.hpp new file mode 100644 index 0000000..4c98d2c --- /dev/null +++ b/src/gretl/wang_checkpoint_strategy.hpp @@ -0,0 +1,72 @@ +// Copyright (c) Lawrence Livermore National Security, LLC and +// other Gretl Project Developers. See the top-level LICENSE file for +// details. +// +// SPDX-License-Identifier: (BSD-3-Clause) + +/** + * @file wang_checkpoint_strategy.hpp + * @brief Wang et al. 2009 "Minimal Repetition Dynamic Checkpointing" strategy. + */ + +#pragma once + +#include "checkpoint_strategy.hpp" +#include + +namespace gretl { + +/// @brief Wang et al. 2009 "Minimal Repetition Dynamic Checkpointing" +/// +/// Uses a level-based priority scheme where each checkpoint has a level +/// that determines its dispensability. The "most dispensable" checkpoint +/// is the one whose level drops below a previously seen higher level +/// when iterating the ordered set. +class WangCheckpointStrategy final : public CheckpointStrategy { + public: + /// @brief Construct with a given number of non-persistent checkpoint slots. + explicit WangCheckpointStrategy(size_t maxStates); + + size_t add_checkpoint_and_get_index_to_remove(size_t step, bool persistent = false) override; + size_t last_checkpoint_step() const override; + bool erase_step(size_t stepIndex) override; + bool contains_step(size_t stepIndex) const override; + void reset() override; + size_t capacity() const override; + size_t size() const override; + void print(std::ostream& os) const override; + CheckpointMetrics metrics() const override; + void reset_metrics() override; + void record_recomputation() override; + + private: + /// @brief Checkpoint with level for eviction priority (Wang-specific). + struct Checkpoint { + size_t level; ///< level + size_t step; ///< step + static constexpr size_t infinity() { return std::numeric_limits::max(); } + }; + + /// @brief Comparison operator for ordering checkpoints in the set. + /// Persistent checkpoints (infinity level) sort last; among others, higher step first. + struct CheckpointCompare { + bool operator()(const Checkpoint& a, const Checkpoint& b) const + { + if (a.level == Checkpoint::infinity() && b.level == Checkpoint::infinity()) { + return a.step > b.step; + } + if (a.level == Checkpoint::infinity()) return false; + if (b.level == Checkpoint::infinity()) return true; + return a.step > b.step; + } + }; + + /// @brief Find the most dispensable checkpoint per the Wang algorithm. + std::set::const_iterator most_dispensable() const; + + size_t maxNumStates_; + std::set cps_; + CheckpointMetrics metrics_; +}; + +} // namespace gretl diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index 92725bb..e6ae7da 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -6,8 +6,11 @@ set(gretl_test_sources test_gretl_checkpoint.cpp + test_gretl_checkpoint_compare.cpp test_gretl_dynamics.cpp - test_gretl_graph.cpp) + test_gretl_graph.cpp + test_gretl_robustness.cpp + test_persistent_scope.cpp) foreach(test ${gretl_test_sources}) get_filename_component( test_name ${test} NAME_WE ) diff --git a/src/tests/test_gretl_checkpoint.cpp b/src/tests/test_gretl_checkpoint.cpp index 94eb6b7..93fc9ce 100644 --- a/src/tests/test_gretl_checkpoint.cpp +++ b/src/tests/test_gretl_checkpoint.cpp @@ -9,6 +9,9 @@ #include #include "gtest/gtest.h" #include "gretl/checkpoint.hpp" +#include "gretl/checkpoint_strategy.hpp" +#include "gretl/wang_checkpoint_strategy.hpp" +#include "gretl/strumm_walther_checkpoint_strategy.hpp" #include "gretl/state.hpp" #include "gretl/data_store.hpp" @@ -20,6 +23,8 @@ double advance_solution(double x) return x / 3.0 + 2.0; } +// ---------- Original non-parameterized tests (backward compat) ---------- + struct CheckpointFixture : public ::testing::Test { static constexpr size_t S = 6; // max saved states static constexpr size_t N = 10; // run states @@ -42,29 +47,29 @@ TEST_F(CheckpointFixture, Procedural) std::vector states = get_full_state_hist(x0); std::vector reverseStates(N + 1); - gretl::CheckpointManager checkpointManager{.maxNumStates = S, .cps{}}; + gretl::WangCheckpointStrategy checkpointStrategy(S); std::map savedCheckpoints; savedCheckpoints[0] = x0; bool persistentCheckpoint = true; - checkpointManager.add_checkpoint_and_get_index_to_remove(0, persistentCheckpoint); + checkpointStrategy.add_checkpoint_and_get_index_to_remove(0, persistentCheckpoint); for (size_t i = 0; i < N; ++i) { const auto& xPrev = savedCheckpoints[i]; auto x = advance_solution(xPrev); - size_t stepToErase = checkpointManager.add_checkpoint_and_get_index_to_remove(i + 1); - if (checkpointManager.valid_checkpoint_index(stepToErase)) { + size_t stepToErase = checkpointStrategy.add_checkpoint_and_get_index_to_remove(i + 1); + if (gretl::CheckpointStrategy::valid_checkpoint_index(stepToErase)) { savedCheckpoints.erase(stepToErase); } savedCheckpoints[i + 1] = x; } for (size_t i_rev = N; i_rev + 1 > 0; --i_rev) { - for (size_t i = checkpointManager.last_checkpoint_step(); i < i_rev; ++i) { + for (size_t i = checkpointStrategy.last_checkpoint_step(); i < i_rev; ++i) { const auto& xPrev = savedCheckpoints[i]; auto x = advance_solution(xPrev); - size_t stepToErase = checkpointManager.add_checkpoint_and_get_index_to_remove(i + 1); - if (checkpointManager.valid_checkpoint_index(stepToErase)) { + size_t stepToErase = checkpointStrategy.add_checkpoint_and_get_index_to_remove(i + 1); + if (gretl::CheckpointStrategy::valid_checkpoint_index(stepToErase)) { savedCheckpoints.erase(stepToErase); } savedCheckpoints[i + 1] = x; @@ -72,7 +77,7 @@ TEST_F(CheckpointFixture, Procedural) reverseStates[i_rev] = savedCheckpoints[i_rev]; - checkpointManager.erase_step(i_rev); + checkpointStrategy.erase_step(i_rev); savedCheckpoints.erase(i_rev); } @@ -92,7 +97,7 @@ TEST_F(CheckpointFixture, Functional) std::vector reverseStates(N + 1); double xf = gretl::advance_and_reverse_steps( - N, S, x0, + N, x0, [&](size_t n, const double& x) { // update function advanceStates[n] = x; @@ -101,7 +106,8 @@ TEST_F(CheckpointFixture, Functional) [&](size_t n, const double& x) { // callback on reverse pass for computing reverse sensitivities reverseStates[n] = x; - }); + }, + std::make_unique(S)); advanceStates[N] = xf; @@ -128,14 +134,7 @@ gretl::State advance_solution(const gretl::State& a) auto a_ = upstreams[0]; const auto b_ = downstream; auto bBar = b_.get_dual(); - // if (!a_.dual_valid()) { - // this case is really only for performance optimization - // get dual will sometimes have to re-evaluate primal - // for linear operators, this is unneeded, we just tell the dual its data and size - // a_.set_dual(bBar / 3.0); - //} else { a_.get_dual() += bBar / 3.0; - //} }); return b.finalize(); @@ -149,7 +148,7 @@ TEST_F(CheckpointFixture, Automated) std::vector reverseStates(N + 1); std::vector advanceStates(N + 1); - gretl::DataStore dataStore(S); + gretl::DataStore dataStore(std::make_unique(S)); gretl::State X = dataStore.create_state(x); advanceStates[0] = X.get(); @@ -179,3 +178,176 @@ TEST_F(CheckpointFixture, Automated) std::cout << "total eval count = " << count << std::endl; count = 0; } + +// ---------- Parameterized tests across checkpoint strategies ---------- + +enum class StrategyType +{ + Wang, + StrummWalther +}; + +std::string strategy_name(StrategyType t) +{ + switch (t) { + case StrategyType::Wang: + return "Wang"; + case StrategyType::StrummWalther: + return "StrummWalther"; + } + return "Unknown"; +} + +std::unique_ptr make_strategy(StrategyType t, size_t slots) +{ + switch (t) { + case StrategyType::Wang: + return std::make_unique(slots); + case StrategyType::StrummWalther: + return std::make_unique(slots); + } + return nullptr; +} + +struct CheckpointStrategyTest : public ::testing::TestWithParam { + static constexpr size_t S = 6; + static constexpr size_t N = 10; + + std::vector get_full_state_hist(double x0) + { + std::vector states(N + 1); + states[0] = x0; + for (size_t n = 0; n < N; ++n) { + states[n + 1] = advance_solution(states[n]); + } + return states; + } +}; + +TEST_P(CheckpointStrategyTest, Procedural) +{ + double x0 = 0.0; + + std::vector states = get_full_state_hist(x0); + std::vector reverseStates(N + 1); + + auto strategy = make_strategy(GetParam(), S); + std::map savedCheckpoints; + + savedCheckpoints[0] = x0; + + strategy->add_checkpoint_and_get_index_to_remove(0, true); + for (size_t i = 0; i < N; ++i) { + const auto& xPrev = savedCheckpoints[i]; + auto x = advance_solution(xPrev); + size_t stepToErase = strategy->add_checkpoint_and_get_index_to_remove(i + 1); + if (gretl::CheckpointStrategy::valid_checkpoint_index(stepToErase)) { + savedCheckpoints.erase(stepToErase); + } + savedCheckpoints[i + 1] = x; + } + + for (size_t i_rev = N; i_rev + 1 > 0; --i_rev) { + while (strategy->last_checkpoint_step() < i_rev) { + size_t lastCp = strategy->last_checkpoint_step(); + const auto& xPrev = savedCheckpoints[lastCp]; + auto x = advance_solution(xPrev); + size_t stepToErase = strategy->add_checkpoint_and_get_index_to_remove(lastCp + 1); + if (gretl::CheckpointStrategy::valid_checkpoint_index(stepToErase)) { + savedCheckpoints.erase(stepToErase); + } + savedCheckpoints[lastCp + 1] = x; + } + + reverseStates[i_rev] = savedCheckpoints[i_rev]; + + strategy->erase_step(i_rev); + savedCheckpoints.erase(i_rev); + } + + for (size_t n = 0; n < N + 1; ++n) { + ASSERT_EQ(states[n], reverseStates[n]) << strategy_name(GetParam()) << " step " << n << "\n"; + } + + auto m = strategy->metrics(); + std::cout << strategy_name(GetParam()) << " procedural: stores=" << m.stores << " evictions=" << m.evictions + << " eval_count=" << count << std::endl; + count = 0; +} + +TEST_P(CheckpointStrategyTest, Functional) +{ + double x0 = 0.0; + + std::vector states = get_full_state_hist(x0); + std::vector advanceStates(N + 1); + std::vector reverseStates(N + 1); + + auto strategy = make_strategy(GetParam(), S); + + double xf = gretl::advance_and_reverse_steps( + N, x0, + [&](size_t n, const double& x) { + advanceStates[n] = x; + return advance_solution(x); + }, + [&](size_t n, const double& x) { reverseStates[n] = x; }, std::move(strategy)); + + advanceStates[N] = xf; + + for (size_t n = 0; n < N + 1; ++n) { + ASSERT_EQ(states[n], advanceStates[n]) << strategy_name(GetParam()) << " step " << n << "\n"; + ASSERT_EQ(states[n], reverseStates[n]) << strategy_name(GetParam()) << " step " << n << "\n"; + } + + std::cout << strategy_name(GetParam()) << " functional: eval_count=" << count << std::endl; + count = 0; +} + +TEST_P(CheckpointStrategyTest, Automated) +{ + double x = 0.0; + + std::vector states = get_full_state_hist(x); + std::vector reverseStates(N + 1); + std::vector advanceStates(N + 1); + + auto strategy = make_strategy(GetParam(), S); + gretl::DataStore dataStore(std::move(strategy)); + gretl::State X = dataStore.create_state(x); + + advanceStates[0] = X.get(); + for (size_t n = 0; n < N; ++n) { + X = advance_solution(X); + advanceStates[n + 1] = X.get(); + } + + X = set_as_objective(X); + dataStore.stillConstructingGraph_ = false; + + reverseStates[N] = X.get(); + EXPECT_EQ(X.get_dual(), 1.0); + for (size_t n = N; n > 0; --n) { + dataStore.reverse_state(); + auto restoredState = static_cast(n - 1); + reverseStates[n - 1] = dataStore.get_primal(restoredState); + double dual_val = dataStore.get_dual(restoredState); + ASSERT_NEAR(dual_val, std::pow(1. / 3., (N - n + 1)), 1e-14); + } + + for (size_t n = 0; n < N + 1; ++n) { + ASSERT_EQ(states[n], advanceStates[n]) << strategy_name(GetParam()) << " step " << n << "\n"; + ASSERT_EQ(states[n], reverseStates[n]) << strategy_name(GetParam()) << " step " << n << "\n"; + } + + auto m = dataStore.checkpointStrategy_->metrics(); + std::cout << strategy_name(GetParam()) << " automated: stores=" << m.stores << " evictions=" << m.evictions + << " recomps=" << m.recomputations << " eval_count=" << count << std::endl; + count = 0; +} + +INSTANTIATE_TEST_SUITE_P(AllStrategies, CheckpointStrategyTest, + ::testing::Values(StrategyType::Wang, StrategyType::StrummWalther), + [](const ::testing::TestParamInfo& param_info) { + return strategy_name(param_info.param); + }); diff --git a/src/tests/test_gretl_checkpoint_compare.cpp b/src/tests/test_gretl_checkpoint_compare.cpp new file mode 100644 index 0000000..018b27f --- /dev/null +++ b/src/tests/test_gretl_checkpoint_compare.cpp @@ -0,0 +1,178 @@ +// Copyright (c) Lawrence Livermore National Security, LLC and +// other Gretl Project Developers. See the top-level LICENSE file for +// details. +// +// SPDX-License-Identifier: (BSD-3-Clause) + +/// @file test_gretl_checkpoint_compare.cpp +/// @brief Side-by-side comparison of Wang and StrummWalther checkpointing strategies. + +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "gretl/checkpoint.hpp" +#include "gretl/checkpoint_strategy.hpp" +#include "gretl/wang_checkpoint_strategy.hpp" +#include "gretl/strumm_walther_checkpoint_strategy.hpp" +#include "gretl/state.hpp" +#include "gretl/data_store.hpp" + +namespace { + +double forward_step(double x) { return x / 3.0 + 2.0; } + +gretl::State forward_step_state(const gretl::State& a) +{ + auto b = a.clone({a}); + + b.set_eval([](const gretl::UpstreamStates& upstreams, gretl::DownstreamState& downstream) { + downstream.set(forward_step(upstreams[0].get())); + }); + + b.set_vjp([](gretl::UpstreamStates& upstreams, const gretl::DownstreamState& downstream) { + upstreams[0].get_dual() += downstream.get_dual() / 3.0; + }); + + return b.finalize(); +} + +struct AlgorithmResult { + std::string name; + gretl::CheckpointMetrics metrics; + double gradient; +}; + +AlgorithmResult run_procedural_test(std::unique_ptr strategy, const std::string& name, + size_t N) +{ + double x0 = 0.0; + std::map savedCps; + savedCps[0] = x0; + double x = x0; + + strategy->add_checkpoint_and_get_index_to_remove(0, true); + for (size_t i = 0; i < N; ++i) { + x = forward_step(savedCps[i]); + size_t eraseStep = strategy->add_checkpoint_and_get_index_to_remove(i + 1); + if (gretl::CheckpointStrategy::valid_checkpoint_index(eraseStep)) { + savedCps.erase(eraseStep); + } + savedCps[i + 1] = x; + } + + double grad = 1.0; + for (size_t i_rev = N; i_rev + 1 > 0; --i_rev) { + while (strategy->last_checkpoint_step() < i_rev) { + size_t lastCp = strategy->last_checkpoint_step(); + x = forward_step(savedCps[lastCp]); + size_t eraseStep = strategy->add_checkpoint_and_get_index_to_remove(lastCp + 1); + if (gretl::CheckpointStrategy::valid_checkpoint_index(eraseStep)) { + savedCps.erase(eraseStep); + } + savedCps[lastCp + 1] = x; + strategy->record_recomputation(); + } + grad *= 1.0 / 3.0; // derivative of forward_step + + strategy->erase_step(i_rev); + savedCps.erase(i_rev); + } + + return {name, strategy->metrics(), grad}; +} + +AlgorithmResult run_datastore_test(std::unique_ptr strategy, const std::string& name, + size_t N) +{ + gretl::DataStore dataStore(std::move(strategy)); + gretl::State X = dataStore.create_state(0.0); + + for (size_t n = 0; n < N; ++n) { + X = forward_step_state(X); + } + + X = set_as_objective(X); + dataStore.stillConstructingGraph_ = false; + dataStore.back_prop(); + + double grad = dataStore.get_dual(0); + return {name, dataStore.checkpointStrategy_->metrics(), grad}; +} + +} // namespace + +TEST(CheckpointCompare, ProceduralComparison) +{ + struct Config { + size_t N; + size_t budget; + }; + + std::vector configs = {{10, 3}, {10, 5}, {10, 6}, {10, 8}, {20, 3}, {20, 5}, + {20, 8}, {20, 10}, {50, 3}, {50, 5}, {50, 10}, {50, 20}, + {100, 3}, {100, 5}, {100, 10}, {100, 20}, {200, 5}, {200, 10}, + {200, 20}, {500, 5}, {500, 10}, {500, 20}, {1000, 10}, {1000, 20}, + {1000, 50}, {5000, 10}, {5000, 50}, {5000, 100}, {5000, 200}, {5000, 500}}; + + std::cout << "\n--- Procedural Checkpoint Algorithm Comparison ---\n"; + std::cout << std::setw(6) << "N" << std::setw(8) << "Budget" << " | " << std::setw(10) << "Algorithm" << std::setw(10) + << "stores" << std::setw(10) << "evictions" << std::setw(12) << "recomps" << std::setw(14) << "ratio(r/N)" + << "\n"; + std::cout << std::string(72, '-') << "\n"; + + for (const auto& cfg : configs) { + auto wang_result = run_procedural_test(std::make_unique(cfg.budget), "Wang", cfg.N); + auto r2_result = run_procedural_test(std::make_unique(cfg.budget), + "StrummWalther", cfg.N); + + ASSERT_NEAR(wang_result.gradient, r2_result.gradient, 1e-14) + << "Gradient mismatch at N=" << cfg.N << " budget=" << cfg.budget; + + for (const auto& r : {wang_result, r2_result}) { + std::cout << std::setw(6) << cfg.N << std::setw(8) << cfg.budget << " | " << std::setw(10) << r.name + << std::setw(10) << r.metrics.stores << std::setw(10) << r.metrics.evictions << std::setw(12) + << r.metrics.recomputations << std::setw(14) << std::fixed << std::setprecision(3) + << static_cast(r.metrics.recomputations) / static_cast(cfg.N) << "\n"; + } + } + std::cout << std::endl; +} + +TEST(CheckpointCompare, DataStoreComparison) +{ + struct Config { + size_t N; + size_t budget; + }; + + std::vector configs = {{10, 3}, {10, 5}, {10, 6}, {10, 8}, {20, 3}, {20, 5}, {20, 8}, + {20, 10}, {50, 5}, {50, 10}, {50, 20}, {100, 5}, {100, 10}, {100, 20}, + {200, 5}, {200, 10}, {200, 20}, {500, 10}, {500, 20}, {1000, 10}, {1000, 20}, + {1000, 50}, {5000, 10}, {5000, 50}, {5000, 100}, {5000, 200}, {5000, 500}}; + + std::cout << "\n--- DataStore Checkpoint Algorithm Comparison ---\n"; + std::cout << std::setw(6) << "N" << std::setw(8) << "Budget" << " | " << std::setw(10) << "Algorithm" << std::setw(10) + << "stores" << std::setw(10) << "evictions" << std::setw(12) << "recomps" << std::setw(14) << "ratio(r/N)" + << "\n"; + std::cout << std::string(72, '-') << "\n"; + + for (const auto& cfg : configs) { + auto wang_result = run_datastore_test(std::make_unique(cfg.budget), "Wang", cfg.N); + auto r2_result = run_datastore_test(std::make_unique(cfg.budget), + "StrummWalther", cfg.N); + + double expected_grad = std::pow(1.0 / 3.0, cfg.N); + ASSERT_NEAR(wang_result.gradient, expected_grad, 1e-14) << "Wang gradient wrong at N=" << cfg.N; + ASSERT_NEAR(r2_result.gradient, expected_grad, 1e-14) << "StrummWalther gradient wrong at N=" << cfg.N; + + for (const auto& r : {wang_result, r2_result}) { + std::cout << std::setw(6) << cfg.N << std::setw(8) << cfg.budget << " | " << std::setw(10) << r.name + << std::setw(10) << r.metrics.stores << std::setw(10) << r.metrics.evictions << std::setw(12) + << r.metrics.recomputations << std::setw(14) << std::fixed << std::setprecision(3) + << static_cast(r.metrics.recomputations) / static_cast(cfg.N) << "\n"; + } + } + std::cout << std::endl; +} diff --git a/src/tests/test_gretl_dynamics.cpp b/src/tests/test_gretl_dynamics.cpp index 17dcb8a..53ee7ba 100644 --- a/src/tests/test_gretl_dynamics.cpp +++ b/src/tests/test_gretl_dynamics.cpp @@ -12,6 +12,7 @@ #include #include "gtest/gtest.h" #include "gretl/checkpoint.hpp" +#include "gretl/wang_checkpoint_strategy.hpp" #include "gretl/state.hpp" #include "gretl/test_utils.hpp" #include "gretl/vector_state.hpp" @@ -79,7 +80,7 @@ State state_rate_equation(const State& state, const Param& params, [[maybe_unuse class MeshFixture : public ::testing::Test { public: - void SetUp() { dataStore = std::make_shared(20); } + void SetUp() { dataStore = std::make_shared(std::make_unique(20)); } Param::type params_data{1.3, 3.5, 1.1, 0.0}; State::type state0_data{1.7, 1.1, 0.1}; diff --git a/src/tests/test_gretl_graph.cpp b/src/tests/test_gretl_graph.cpp index 55edfb0..c5fd2b4 100644 --- a/src/tests/test_gretl_graph.cpp +++ b/src/tests/test_gretl_graph.cpp @@ -11,6 +11,7 @@ #include "gtest/gtest.h" #include "gretl/vector_state.hpp" #include "gretl/data_store.hpp" +#include "gretl/wang_checkpoint_strategy.hpp" #include "gretl/test_utils.hpp" using gretl::print; @@ -33,7 +34,7 @@ TEST(Graph, NonlinearGraphGradients) std::vector dataB = {1.7, 1.1}; std::vector dataZ = {-0.7, 3.1}; - gretl::DataStore dataStore(3); + gretl::DataStore dataStore(std::make_unique(3)); auto a = dataStore.create_state(dataA, gretl::vec::initialize_zero_dual); auto b = dataStore.create_state(dataB, gretl::vec::initialize_zero_dual); @@ -83,7 +84,7 @@ TEST(Graph, LinearGraphGradients) { std::vector dataA = {1.2, 3.2}; - gretl::DataStore dataStore(6); + gretl::DataStore dataStore(std::make_unique(6)); auto initial = dataStore.create_state(dataA, gretl::vec::initialize_zero_dual); auto a = gretl::copy(initial); @@ -107,7 +108,7 @@ TEST(Graph, LargeNonlinearGraphGradients) std::vector dataB = {0.6, 0.87}; std::vector dataC = {-0.8, 0.32}; - gretl::DataStore dataStore(3); + gretl::DataStore dataStore(std::make_unique(3)); auto a = dataStore.create_state(dataA, gretl::vec::initialize_zero_dual); auto b = dataStore.create_state(dataB, gretl::vec::initialize_zero_dual); diff --git a/src/tests/test_gretl_robustness.cpp b/src/tests/test_gretl_robustness.cpp new file mode 100644 index 0000000..291e3b4 --- /dev/null +++ b/src/tests/test_gretl_robustness.cpp @@ -0,0 +1,1705 @@ +// Copyright (c) Lawrence Livermore National Security, LLC and +// other Gretl Project Developers. See the top-level LICENSE file for +// details. +// +// SPDX-License-Identifier: (BSD-3-Clause) + +// +// Stress tests for state lifecycle, try_to_free, usageCount, checkpoint +// eviction, scope-based external reference tracking, and performance scaling. +// + +#include +#include +#include +#include +#include +#include +#include "gtest/gtest.h" +#include "gretl/data_store.hpp" +#include "gretl/wang_checkpoint_strategy.hpp" +#include "gretl/state.hpp" +#include "gretl/double_state.hpp" +#include "gretl/vector_state.hpp" +#include "gretl/test_utils.hpp" + +using gretl::DataStore; +using gretl::State; +using gretl::VectorState; + +// --------------------------------------------------------------------------- +// Helpers: build graph steps in sub-functions to stress scope/lifetime +// --------------------------------------------------------------------------- + +// Build a chain of N steps in a sub-function and return only the final state. +// All intermediate State objects go out of scope when this function returns, +// which triggers destructors -> try_to_free for each intermediate. +static State build_chain_in_subfunc(const State& x0, int N) +{ + State x = x0; + for (int i = 0; i < N; ++i) { + // x = 0.5*x + 1.0 => after N steps: x = x0/2^N + 2*(1 - 1/2^N) + x = gretl::axpb(0.5, x, 1.0); + } + return x; +} + +// Build a diamond: two branches from x0 that rejoin. +// x0 -> a = 2*x0+1 +// x0 -> b = 3*x0-1 +// c = a + b = 5*x0 +// All intermediates (a, b) go out of scope. +static State build_diamond_in_subfunc(const State& x0) +{ + auto a = gretl::axpb(2.0, x0, 1.0); + auto b = gretl::axpb(3.0, x0, -1.0); + return a + b; // 5*x0 +} + +// Build a fan-out: x0 is used as upstream by many independent states. +// Returns the sum of all of them. +static State build_fanout_in_subfunc(const State& x0, int fanWidth) +{ + State accum = gretl::axpb(1.0, x0, 0.0); // copy of x0 + for (int i = 1; i < fanWidth; ++i) { + auto branch = gretl::axpb(1.0, x0, 0.0); // another copy + accum = accum + branch; + } + return accum; // = fanWidth * x0 +} + +// Nested sub-function: calls another sub-function, introducing two levels +// of scope nesting. +static State build_nested_subfuncs(const State& x0) +{ + auto mid = build_chain_in_subfunc(x0, 3); // 3-step chain + auto out = build_chain_in_subfunc(mid, 3); // another 3-step chain + return out; +} + +// Build a chain that saves intermediates into a user-held vector, +// simulating the pattern of "holding states in scope externally." +static State build_chain_holding_intermediates(const State& x0, int N, std::vector>& held) +{ + State x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.5, x, 1.0); + held.push_back(x); // external reference keeps use_count > 1 + } + return x; +} + +// --------------------------------------------------------------------------- +// TEST SUITE: ScopeLifetime +// States created in sub-functions going out of scope during graph construction +// --------------------------------------------------------------------------- + +TEST(ScopeLifetime, ChainInSubfunc_SmallBudget) +{ + // Very tight checkpoint budget (2), long chain built entirely in a sub-func. + // Intermediates go out of scope on return. + DataStore store(std::make_unique(2)); + auto x0 = store.create_state(3.0); + auto xN = build_chain_in_subfunc(x0, 20); + + double expected = 3.0; + for (int i = 0; i < 20; ++i) expected = 0.5 * expected + 1.0; + EXPECT_NEAR(xN.get(), expected, 1e-12); + + gretl::set_as_objective(xN); + store.back_prop(); + + // df/dx0 = (0.5)^20 + EXPECT_NEAR(x0.get_dual(), std::pow(0.5, 20), 1e-12); +} + +TEST(ScopeLifetime, ChainInSubfunc_TinyBudget) +{ + // Budget of 1 (absolute minimum for non-persistent checkpoints) + DataStore store(std::make_unique(1)); + auto x0 = store.create_state(5.0); + auto xN = build_chain_in_subfunc(x0, 10); + + double expected = 5.0; + for (int i = 0; i < 10; ++i) expected = 0.5 * expected + 1.0; + EXPECT_NEAR(xN.get(), expected, 1e-12); + + gretl::set_as_objective(xN); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.5, 10), 1e-12); +} + +TEST(ScopeLifetime, DiamondInSubfunc) +{ + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(2.0); + auto result = build_diamond_in_subfunc(x0); + + EXPECT_NEAR(result.get(), 5.0 * 2.0, 1e-14); + + gretl::set_as_objective(result); + store.back_prop(); + + // d(5*x0)/dx0 = 5 + EXPECT_NEAR(x0.get_dual(), 5.0, 1e-14); +} + +TEST(ScopeLifetime, NestedSubfuncs) +{ + // Two levels of sub-function scope nesting + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(4.0); + auto result = build_nested_subfuncs(x0); + + // 6 steps of x -> 0.5*x + 1.0 + double expected = 4.0; + for (int i = 0; i < 6; ++i) expected = 0.5 * expected + 1.0; + EXPECT_NEAR(result.get(), expected, 1e-12); + + gretl::set_as_objective(result); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.5, 6), 1e-12); +} + +TEST(ScopeLifetime, FanoutInSubfunc) +{ + DataStore store(std::make_unique(4)); + auto x0 = store.create_state(3.0); + auto result = build_fanout_in_subfunc(x0, 5); + + EXPECT_NEAR(result.get(), 5.0 * 3.0, 1e-14); + + gretl::set_as_objective(result); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), 5.0, 1e-14); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: ExternalReferences +// States held externally while graph operations proceed +// --------------------------------------------------------------------------- + +TEST(ExternalReferences, HeldIntermediatesPreventPrematureFreeing) +{ + // Hold all intermediates in a vector -- use_count stays > 1 for each. + // This should prevent try_to_free from deallocating them prematurely. + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(2.0); + + std::vector> held; + auto xN = build_chain_holding_intermediates(x0, 8, held); + + // Verify intermediates are still accessible + double expected = 2.0; + for (int i = 0; i < 8; ++i) { + expected = 0.5 * expected + 1.0; + EXPECT_NEAR(held[static_cast(i)].get(), expected, 1e-12) << "intermediate " << i; + } + + gretl::set_as_objective(xN); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.5, 8), 1e-12); +} + +TEST(ExternalReferences, HeldIntermediatesThenDropped) +{ + // Hold intermediates, run backprop, then drop them. + // This tests the destructor path when states go out of scope + // after the graph has been back-propagated. + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(2.0); + + { + std::vector> held; + auto xN = build_chain_holding_intermediates(x0, 6, held); + + gretl::set_as_objective(xN); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.5, 6), 1e-12); + // held goes out of scope here, destructors fire for all held states + } + // If we get here without crash/assert, the destructors handled + // the post-backprop state correctly. + SUCCEED(); +} + +TEST(ExternalReferences, CopyStateAcrossScopes) +{ + // Create a state in one scope, copy it to another, let original go out of scope. + DataStore store(std::make_unique(4)); + auto x0 = store.create_state(7.0); + + // Build inner in a lambda that returns it, so the original goes out of scope + auto outer = [&]() { + auto inner = gretl::axpb(2.0, x0, 3.0); // 2*7+3 = 17 + return inner; + }(); + + auto result = gretl::axpb(3.0, outer, 0.0); // 3*17 = 51 + EXPECT_NEAR(result.get(), 51.0, 1e-14); + + gretl::set_as_objective(result); + store.back_prop(); + + // d(3*(2*x0+3))/dx0 = 6 + EXPECT_NEAR(x0.get_dual(), 6.0, 1e-14); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: AssignmentOperator +// Tests for the StateBase assignment operator and its try_to_free calls +// --------------------------------------------------------------------------- + +TEST(AssignmentOperator, ReassignMidGraph) +{ + // Reassign a local state variable multiple times in graph construction. + // Each assignment triggers try_to_free on the old step. + DataStore store(std::make_unique(4)); + auto x0 = store.create_state(1.0); + + auto a = gretl::axpb(2.0, x0, 0.0); // step 1: 2.0 + a = gretl::axpb(3.0, x0, 0.0); // step 2: 3.0, old step 1 freed + a = gretl::axpb(4.0, a, 0.0); // step 3: 12.0, old step 2 is upstream so NOT freed + auto result = gretl::axpb(1.0, a, 0.0); // step 4: 12.0 + + EXPECT_NEAR(result.get(), 12.0, 1e-14); + + gretl::set_as_objective(result); + store.back_prop(); + + // d(4*3*x0)/dx0 = 12 + EXPECT_NEAR(x0.get_dual(), 12.0, 1e-14); +} + +TEST(AssignmentOperator, ReassignInLoop) +{ + // Classic pattern: `x = f(x)` in a loop. Each iteration reassigns + // the local variable, old step must be freed properly. + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(1.0); + + auto x = gretl::axpb(1.0, x0, 0.0); // copy + int N = 15; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.9, x, 0.1); // x = 0.9*x + 0.1 + } + + double expected = 1.0; + for (int i = 0; i < N; ++i) expected = 0.9 * expected + 0.1; + EXPECT_NEAR(x.get(), expected, 1e-10); + + gretl::set_as_objective(x); + store.back_prop(); + + // df/dx0 = 0.9^N (chain rule through linear maps) + // The copy step adds a factor of 1.0 + EXPECT_NEAR(x0.get_dual(), std::pow(0.9, N), 1e-10); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: CheckpointEviction +// Stress the checkpoint manager with various budget/graph combinations +// --------------------------------------------------------------------------- + +TEST(CheckpointEviction, LongChainMinimalBudget) +{ + // 50 steps with budget of 2: forces many recomputations. + int N = 50; + DataStore store(std::make_unique(2)); + auto x0 = store.create_state(1.0); + + auto x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.95, x, 0.05); + } + + gretl::set_as_objective(x); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.95, N), 1e-8); +} + +TEST(CheckpointEviction, LongChainLargeBudget) +{ + // Same chain, but with generous budget. Exercises different checkpoint + // decisions (most states fit in memory). + int N = 50; + DataStore store(std::make_unique(60)); + auto x0 = store.create_state(1.0); + + auto x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.95, x, 0.05); + } + + gretl::set_as_objective(x); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.95, N), 1e-10); +} + +TEST(CheckpointEviction, MediumChainExactBudget) +{ + // Budget = N: every state fits. Edge case for checkpoint manager. + int N = 10; + DataStore store(std::make_unique(static_cast(N))); + auto x0 = store.create_state(2.0); + + auto x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.8, x, 0.3); + } + + gretl::set_as_objective(x); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.8, N), 1e-12); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: DAGTopology +// Non-linear graph topologies: diamonds, fan-in, fan-out, skip connections +// --------------------------------------------------------------------------- + +TEST(DAGTopology, DiamondDependency) +{ + // x0 -> a, x0 -> b, c = a*b (diamond) + DataStore store(std::make_unique(4)); + auto x0 = store.create_state(3.0); + + auto a = gretl::axpb(2.0, x0, 1.0); // 2*3+1 = 7 + auto b = gretl::axpb(3.0, x0, -1.0); // 3*3-1 = 8 + auto c = a * b; // 7*8 = 56 + + EXPECT_NEAR(c.get(), 56.0, 1e-14); + + gretl::set_as_objective(c); + store.back_prop(); + + // dc/dx0 = dc/da * da/dx0 + dc/db * db/dx0 + // = b * 2 + a * 3 + // = 8*2 + 7*3 = 16 + 21 = 37 + EXPECT_NEAR(x0.get_dual(), 37.0, 1e-13); +} + +TEST(DAGTopology, MultiInputDiamond) +{ + // x0, y0 -> a = x0+y0, b = x0*y0, c = a+b + DataStore store(std::make_unique(5)); + auto x0 = store.create_state(2.0); + auto y0 = store.create_state(3.0); + + auto a = x0 + y0; // 5 + auto b = x0 * y0; // 6 + auto c = a + b; // 11 + + EXPECT_NEAR(c.get(), 11.0, 1e-14); + + gretl::set_as_objective(c); + store.back_prop(); + + // dc/dx0 = 1 + y0 = 4 + // dc/dy0 = 1 + x0 = 3 + EXPECT_NEAR(x0.get_dual(), 4.0, 1e-14); + EXPECT_NEAR(y0.get_dual(), 3.0, 1e-14); +} + +TEST(DAGTopology, SkipConnection) +{ + // x0 -> a -> b -> c, but also x0 -> c directly (skip connection) + DataStore store(std::make_unique(5)); + auto x0 = store.create_state(2.0); + + auto a = gretl::axpb(2.0, x0, 0.0); // 4 + auto b = gretl::axpb(3.0, a, 0.0); // 12 + auto c = b + x0; // 12 + 2 = 14 + + EXPECT_NEAR(c.get(), 14.0, 1e-14); + + gretl::set_as_objective(c); + store.back_prop(); + + // dc/dx0 = dc/db * db/da * da/dx0 + 1 = 1*3*2 + 1 = 7 + EXPECT_NEAR(x0.get_dual(), 7.0, 1e-14); +} + +TEST(DAGTopology, WideFanoutThenMerge) +{ + // x0 fans out to 10 branches, all merge back together by summation. + // Stresses usageCount tracking on x0. + int W = 10; + DataStore store(std::make_unique(static_cast(W + 2))); + auto x0 = store.create_state(1.5); + + State sum = gretl::axpb(1.0, x0, 0.0); + for (int i = 1; i < W; ++i) { + auto branch = gretl::axpb(static_cast(i + 1), x0, 0.0); + sum = sum + branch; + } + // sum = x0 + 2*x0 + 3*x0 + ... + W*x0 = x0 * W*(W+1)/2 + double coeff = static_cast(W * (W + 1)) / 2.0; + EXPECT_NEAR(sum.get(), coeff * 1.5, 1e-12); + + gretl::set_as_objective(sum); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), coeff, 1e-12); +} + +TEST(DAGTopology, SkipConnectionTightBudget) +{ + // Skip connection with minimal checkpoint budget. + // x0 is used both early and late in the graph. + DataStore store(std::make_unique(2)); + auto x0 = store.create_state(2.0); + + auto a = gretl::axpb(2.0, x0, 0.0); + auto b = gretl::axpb(3.0, a, 0.0); + auto c = gretl::axpb(4.0, b, 0.0); + // Now use x0 again (skip connection) — x0 must be available + auto d = c + x0; + + EXPECT_NEAR(d.get(), 4.0 * 3.0 * 2.0 * 2.0 + 2.0, 1e-14); + + gretl::set_as_objective(d); + store.back_prop(); + + // dd/dx0 = 4*3*2 + 1 = 25 + EXPECT_NEAR(x0.get_dual(), 25.0, 1e-14); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: ResetAndRerun +// Test reset() and re-evaluation patterns +// --------------------------------------------------------------------------- + +TEST(ResetAndRerun, ResetAndRerunGraph) +{ + // Build graph, backprop, reset, change persistent input, re-evaluate, backprop again. + DataStore store(std::make_unique(5)); + auto x0 = store.create_state(2.0); + + auto a = gretl::axpb(3.0, x0, 1.0); // 7 + auto b = gretl::axpb(2.0, a, -1.0); // 13 + gretl::set_as_objective(b); + store.back_prop(); + + // db/dx0 = 2*3 = 6 + EXPECT_NEAR(x0.get_dual(), 6.0, 1e-14); + + // Reset and re-run with a different x0 + store.reset(); + x0.set(5.0); + store.reset_for_backprop(); + b.set_dual(1.0); + store.back_prop(); + + // After reset+re-eval: a = 3*5+1=16, b = 2*16-1=31 + EXPECT_NEAR(b.get(), 31.0, 1e-14); + EXPECT_NEAR(x0.get_dual(), 6.0, 1e-14); // gradient is independent of x0 for linear graph +} + +TEST(ResetAndRerun, ResetGraphAndRebuild) +{ + // Previously crashed: after back_prop(), currentStep_ was 0 but resize() + // asserted newSize <= currentStep_. Fixed by restoring currentStep_ first. + DataStore store(std::make_unique(5)); + auto x0 = store.create_state(2.0); + + // First graph: x0 -> 3*x0+1 + { + auto a = gretl::axpb(3.0, x0, 1.0); + gretl::set_as_objective(a); + store.back_prop(); + EXPECT_NEAR(x0.get_dual(), 3.0, 1e-14); + } + + // Reset and build a completely different graph + store.reset_graph(); + + // Second graph: x0 -> 5*x0+2 + { + auto b = gretl::axpb(5.0, x0, 2.0); + gretl::set_as_objective(b); + store.back_prop(); + EXPECT_NEAR(x0.get_dual(), 5.0, 1e-14); + } +} + +// --------------------------------------------------------------------------- +// TEST SUITE: NonlinearStress +// Nonlinear operations that stress checkpoint recomputation correctness +// --------------------------------------------------------------------------- + +TEST(NonlinearStress, MultiplyChainTightBudget) +{ + // x = x * x iteratively (squaring). Very sensitive to recomputation errors + // because the function is nonlinear. + DataStore store(std::make_unique(2)); + auto x0 = store.create_state(1.1); + + auto x = gretl::axpb(1.0, x0, 0.0); + int N = 5; + for (int i = 0; i < N; ++i) { + x = x * x; // squaring + } + + // x0^(2^N) + double expected = std::pow(1.1, std::pow(2.0, N)); + EXPECT_NEAR(x.get(), expected, 1e-6); + + gretl::set_as_objective(x); + store.back_prop(); + + // d/dx0(x0^(2^N)) = 2^N * x0^(2^N - 1) + double pow2N = std::pow(2.0, N); + double expectedGrad = pow2N * std::pow(1.1, pow2N - 1); + EXPECT_NEAR(x0.get_dual(), expectedGrad, expectedGrad * 1e-6); +} + +TEST(NonlinearStress, MixedLinearNonlinear) +{ + // Alternating linear and nonlinear ops. + DataStore store(std::make_unique(4)); + auto x0 = store.create_state(0.5); + auto y0 = store.create_state(0.3); + + auto a = x0 + y0; // 0.8 + auto b = a * x0; // 0.8 * 0.5 = 0.4 + auto c = b + y0; // 0.4 + 0.3 = 0.7 + auto d = c * a; // 0.7 * 0.8 = 0.56 + + EXPECT_NEAR(d.get(), 0.56, 1e-14); + + gretl::set_as_objective(d); + store.back_prop(); + + // Numerical gradient check via finite differences + double eps = 1e-7; + + // Perturb x0 + { + double x0v = 0.5, y0v = 0.3; + auto f = [&](double x) { + double a_ = x + y0v; + double b_ = a_ * x; + double c_ = b_ + y0v; + return c_ * a_; + }; + double fd = (f(x0v + eps) - f(x0v - eps)) / (2.0 * eps); + EXPECT_NEAR(x0.get_dual(), fd, 1e-5); + } + + // Perturb y0 + { + double x0v = 0.5, y0v = 0.3; + auto f = [&](double y) { + double a_ = x0v + y; + double b_ = a_ * x0v; + double c_ = b_ + y; + return c_ * a_; + }; + double fd = (f(y0v + eps) - f(y0v - eps)) / (2.0 * eps); + EXPECT_NEAR(y0.get_dual(), fd, 1e-5); + } +} + +// --------------------------------------------------------------------------- +// TEST SUITE: LargeGraphStress +// Push the limits with large graphs and various budget ratios +// --------------------------------------------------------------------------- + +TEST(LargeGraphStress, Chain100Budget3) +{ + int N = 100; + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(1.0); + + auto x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.99, x, 0.01); + } + + gretl::set_as_objective(x); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.99, N), 1e-8); +} + +TEST(LargeGraphStress, Chain200Budget5) +{ + int N = 200; + DataStore store(std::make_unique(5)); + auto x0 = store.create_state(1.0); + + auto x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.99, x, 0.01); + } + + gretl::set_as_objective(x); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.99, N), 1e-6); +} + +TEST(LargeGraphStress, NonlinearChain50Budget2) +{ + // Nonlinear chain with very tight budget: exercises checkpoint correctness + // for nonlinear functions where recomputation must match original. + int N = 50; + DataStore store(std::make_unique(2)); + auto x0 = store.create_state(0.5); + + auto x = gretl::axpb(1.0, x0, 0.0); + for (int i = 0; i < N; ++i) { + // x = x^2 + 0.1 (bounded for x0=0.5) + auto xsq = x * x; + x = gretl::axpb(1.0, xsq, 0.1); + } + + gretl::set_as_objective(x); + store.back_prop(); + + // Verify gradient via finite differences + double eps = 1e-7; + auto eval_chain = [&](double x0v) { + double xv = x0v; + for (int i = 0; i < N; ++i) { + xv = xv * xv + 0.1; + } + return xv; + }; + double fd = (eval_chain(0.5 + eps) - eval_chain(0.5 - eps)) / (2.0 * eps); + // Use relative tolerance since values can be large + if (std::abs(fd) > 1e-10) { + EXPECT_NEAR(x0.get_dual() / fd, 1.0, 1e-3); + } +} + +// --------------------------------------------------------------------------- +// TEST SUITE: VectorStateStress +// Same patterns but with vector states, exercising the initialize_zero_dual +// --------------------------------------------------------------------------- + +TEST(VectorStateStress, ChainInSubfunc) +{ + DataStore store(std::make_unique(3)); + std::vector data = {1.0, 2.0, 3.0}; + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + + // Build chain: x = 0.5*x + auto x = gretl::copy(x0); + for (int i = 0; i < 10; ++i) { + x = x * 0.5; + } + + auto norm = gretl::inner_product(x, x); + gretl::set_as_objective(norm); + store.back_prop(); + + // norm = (x0/2^10)^2 summed = sum(x0_i^2) / 2^20 + // d(norm)/dx0_i = 2*x0_i / 2^20 + double factor = 1.0 / std::pow(2.0, 20); + for (size_t i = 0; i < 3; ++i) { + EXPECT_NEAR(x0.get_dual()[i], 2.0 * data[i] * factor, 1e-12); + } +} + +TEST(VectorStateStress, DiamondWithVectors) +{ + DataStore store(std::make_unique(5)); + std::vector dataA = {1.0, 2.0}; + std::vector dataB = {3.0, 4.0}; + + auto a = store.create_state(dataA, gretl::vec::initialize_zero_dual); + auto b = store.create_state(dataB, gretl::vec::initialize_zero_dual); + + auto c = a + b; // {4, 6} + auto d = a * b; // {3, 8} + auto e = c + d; // {7, 14} + auto f = gretl::inner_product(e, e); // 49+196=245 + + gretl::set_as_objective(f); + store.back_prop(); + + // Check via finite differences + double eps = 1e-7; + auto eval_f = [](double a0, double a1, double b0, double b1) { + double e0 = (a0 + b0) + (a0 * b0); + double e1 = (a1 + b1) + (a1 * b1); + return e0 * e0 + e1 * e1; + }; + + double df_da0 = (eval_f(1.0 + eps, 2.0, 3.0, 4.0) - eval_f(1.0 - eps, 2.0, 3.0, 4.0)) / (2.0 * eps); + double df_da1 = (eval_f(1.0, 2.0 + eps, 3.0, 4.0) - eval_f(1.0, 2.0 - eps, 3.0, 4.0)) / (2.0 * eps); + double df_db0 = (eval_f(1.0, 2.0, 3.0 + eps, 4.0) - eval_f(1.0, 2.0, 3.0 - eps, 4.0)) / (2.0 * eps); + + EXPECT_NEAR(a.get_dual()[0], df_da0, 1e-5); + EXPECT_NEAR(a.get_dual()[1], df_da1, 1e-5); + EXPECT_NEAR(b.get_dual()[0], df_db0, 1e-5); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: MultiPersistentState +// Multiple persistent states with various dependency patterns +// --------------------------------------------------------------------------- + +TEST(MultiPersistentState, ThreeInputsDeepGraph) +{ + DataStore store(std::make_unique(4)); + auto x = store.create_state(1.0); + auto y = store.create_state(2.0); + auto z = store.create_state(3.0); + + // Deep graph mixing all three inputs + auto a = x + y; // 3 + auto b = a * z; // 9 + auto c = b + x; // 10 + auto d = c * y; // 20 + auto e = d + z; // 23 + auto f = e * x; // 23 + + gretl::set_as_objective(f); + store.back_prop(); + + // Numerical gradient check + double eps = 1e-7; + auto eval = [](double xv, double yv, double zv) { + double a_ = xv + yv; + double b_ = a_ * zv; + double c_ = b_ + xv; + double d_ = c_ * yv; + double e_ = d_ + zv; + return e_ * xv; + }; + + double df_dx = (eval(1.0 + eps, 2.0, 3.0) - eval(1.0 - eps, 2.0, 3.0)) / (2.0 * eps); + double df_dy = (eval(1.0, 2.0 + eps, 3.0) - eval(1.0, 2.0 - eps, 3.0)) / (2.0 * eps); + double df_dz = (eval(1.0, 2.0, 3.0 + eps) - eval(1.0, 2.0, 3.0 - eps)) / (2.0 * eps); + + EXPECT_NEAR(x.get_dual(), df_dx, 1e-5); + EXPECT_NEAR(y.get_dual(), df_dy, 1e-5); + EXPECT_NEAR(z.get_dual(), df_dz, 1e-5); +} + +TEST(MultiPersistentState, RepeatedUseOfAllInputs) +{ + // All three persistent inputs used at multiple points in the graph. + // Exercises passthrough and lastStepUsed tracking. + DataStore store(std::make_unique(3)); + auto x = store.create_state(0.5); + auto y = store.create_state(0.3); + auto z = store.create_state(0.7); + + auto a = x * y; // early use of x, y + auto b = a + z; // early use of z + auto c = b * x; // x used again (skip) + auto d = c + y; // y used again (skip) + auto e = d * z; // z used again (skip) + + gretl::set_as_objective(e); + store.back_prop(); + + double eps = 1e-7; + auto eval = [](double xv, double yv, double zv) { + double a_ = xv * yv; + double b_ = a_ + zv; + double c_ = b_ * xv; + double d_ = c_ + yv; + return d_ * zv; + }; + + EXPECT_NEAR(x.get_dual(), (eval(0.5 + eps, 0.3, 0.7) - eval(0.5 - eps, 0.3, 0.7)) / (2 * eps), 1e-5); + EXPECT_NEAR(y.get_dual(), (eval(0.5, 0.3 + eps, 0.7) - eval(0.5, 0.3 - eps, 0.7)) / (2 * eps), 1e-5); + EXPECT_NEAR(z.get_dual(), (eval(0.5, 0.3, 0.7 + eps) - eval(0.5, 0.3, 0.7 - eps)) / (2 * eps), 1e-5); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: EdgeCases +// Boundary conditions and unusual patterns +// --------------------------------------------------------------------------- + +TEST(EdgeCases, SingleStepGraph) +{ + // Minimal graph: one persistent state, one derived state. + DataStore store(std::make_unique(1)); + auto x0 = store.create_state(5.0); + auto y = gretl::axpb(2.0, x0, 3.0); + + EXPECT_NEAR(y.get(), 13.0, 1e-14); + + gretl::set_as_objective(y); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), 2.0, 1e-14); +} + +TEST(EdgeCases, TwoStepGraph) +{ + DataStore store(std::make_unique(1)); + auto x0 = store.create_state(5.0); + auto a = gretl::axpb(2.0, x0, 1.0); + auto b = gretl::axpb(3.0, a, -1.0); + + EXPECT_NEAR(b.get(), 3.0 * (2.0 * 5.0 + 1.0) - 1.0, 1e-14); + + gretl::set_as_objective(b); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), 6.0, 1e-14); +} + +TEST(EdgeCases, StateUsedOnceVsManyTimes) +{ + // Compare: state used as upstream once vs. multiple times in the same operation's dependency. + DataStore store(std::make_unique(4)); + auto x0 = store.create_state(3.0); + + // x0 * x0 = x0^2 (x0 used twice as upstream) + auto sq = x0 * x0; + EXPECT_NEAR(sq.get(), 9.0, 1e-14); + + gretl::set_as_objective(sq); + store.back_prop(); + + // d(x0^2)/dx0 = 2*x0 = 6 + EXPECT_NEAR(x0.get_dual(), 6.0, 1e-14); +} + +TEST(EdgeCases, DeepChainSingleBudget) +{ + // Budget of exactly 1 with a deep chain. + // This is the absolute minimum and forces full recomputation on every reverse step. + int N = 30; + DataStore store(std::make_unique(1)); + auto x0 = store.create_state(1.0); + + auto x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.9, x, 0.1); + } + + gretl::set_as_objective(x); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), std::pow(0.9, N), 1e-8); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: TemporaryStateInSubFunction +// Specifically exercise the pattern of creating states in called functions +// where the State objects are temporaries that go out of scope. +// --------------------------------------------------------------------------- + +// Helper: creates a temporary state, uses it, returns something derived from it +static State create_use_and_discard(const State& input, double scale) +{ + auto temp = gretl::axpb(scale, input, 0.0); // temp goes out of scope after return + auto temp2 = gretl::axpb(0.5, temp, 0.0); // temp2 goes out of scope too + return gretl::axpb(1.0, temp2, 1.0); // return final +} + +TEST(TemporaryStateInSubFunction, BasicPattern) +{ + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(4.0); + + // Call sub-function 3 times in sequence + auto r1 = create_use_and_discard(x0, 2.0); // (2*4)*0.5 + 1 = 5 + auto r2 = create_use_and_discard(r1, 3.0); // (3*5)*0.5 + 1 = 8.5 + auto r3 = create_use_and_discard(r2, 1.0); // (1*8.5)*0.5 + 1 = 5.25 + + EXPECT_NEAR(r3.get(), 5.25, 1e-14); + + gretl::set_as_objective(r3); + store.back_prop(); + + // Each call: f(x) = scale*x*0.5 + 1, df/dx = scale*0.5 + // Chain: dr3/dx0 = (1*0.5) * (3*0.5) * (2*0.5) = 0.5 * 1.5 * 1.0 = 0.75 + EXPECT_NEAR(x0.get_dual(), 0.75, 1e-12); +} + +TEST(TemporaryStateInSubFunction, LoopedSubFuncCalls) +{ + DataStore store(std::make_unique(2)); + auto x0 = store.create_state(2.0); + + auto x = x0; + int N = 10; + for (int i = 0; i < N; ++i) { + x = create_use_and_discard(x, 1.0); + // Each: x -> 1.0*x*0.5 + 1 = 0.5*x + 1 + } + + double expected = 2.0; + for (int i = 0; i < N; ++i) expected = 0.5 * expected + 1.0; + EXPECT_NEAR(x.get(), expected, 1e-10); + + gretl::set_as_objective(x); + store.back_prop(); + + // Each sub-function has 3 graph steps internally but df/dx = 0.5 per call + // Total: 0.5^N + EXPECT_NEAR(x0.get_dual(), std::pow(0.5, N), 1e-10); +} + +TEST(TemporaryStateInSubFunction, MixedScopeTempsAndPersisted) +{ + // Some states held externally, others are temporaries in sub-functions. + DataStore store(std::make_unique(4)); + auto x0 = store.create_state(3.0); + + auto held = gretl::axpb(2.0, x0, 0.0); // 6.0, held in this scope + + // Sub-function creates and discards temporaries + auto fromSub = create_use_and_discard(held, 1.0); // 0.5*6 + 1 = 4.0 + + // Use both held and fromSub + auto result = held + fromSub; // 6 + 4 = 10 + + EXPECT_NEAR(result.get(), 10.0, 1e-14); + + gretl::set_as_objective(result); + store.back_prop(); + + // dresult/dx0 = dheld/dx0 + dfromSub/dx0 + // = 2 + (0.5 * 2) = 2 + 1 = 3 + EXPECT_NEAR(x0.get_dual(), 3.0, 1e-14); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: DestructorTiming +// Focus on ordering of destructor calls relative to graph state +// --------------------------------------------------------------------------- + +TEST(DestructorTiming, StateDestroyedAfterBackprop) +{ + // State objects going out of scope after back_prop has completed. + // Their destructors call try_to_free, which should handle the + // post-backprop state gracefully. + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(1.0); + + { + auto a = gretl::axpb(2.0, x0, 1.0); + auto b = gretl::axpb(3.0, a, -1.0); + auto c = b + a; + + gretl::set_as_objective(c); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), 8.0, 1e-14); // dc/dx0 = 3*2 + 2 = 8 + // a, b, c go out of scope here — destructors must not crash + } + + SUCCEED(); +} + +TEST(DestructorTiming, InterleavedCreationAndDestruction) +{ + // Create states, let some go out of scope, create more. + DataStore store(std::make_unique(5)); + auto x0 = store.create_state(1.0); + + // Use a lambda to create a state derived from an intermediate that goes out of scope + auto result = [&]() { + auto a = gretl::axpb(2.0, x0, 0.0); // 2.0 + return gretl::axpb(3.0, a, 0.0); // 6.0 + // a goes out of scope — but returned state still depends on it in the graph + }(); + + auto final_result = gretl::axpb(4.0, result, 0.0); // 24.0 + EXPECT_NEAR(final_result.get(), 24.0, 1e-14); + + gretl::set_as_objective(final_result); + store.back_prop(); + + EXPECT_NEAR(x0.get_dual(), 24.0, 1e-14); +} + +// --------------------------------------------------------------------------- +// Helper: high-resolution timer +// --------------------------------------------------------------------------- +static double elapsed_ms(std::chrono::steady_clock::time_point start) +{ + auto end = std::chrono::steady_clock::now(); + return std::chrono::duration(end - start).count(); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: PerformanceScaling +// Measure how construction + backprop time scales with graph size and budget. +// These tests print timing info and verify correctness. They expose +// super-linear scaling in passthrough tracking, checkpoint search, and +// recomputation overhead. +// --------------------------------------------------------------------------- + +// Count forward evaluations to measure recomputation overhead +static int g_eval_count = 0; + +static State counted_step(const State& x) +{ + auto y = x.clone({x}); + + y.set_eval([](const gretl::UpstreamStates& ups, gretl::DownstreamState& ds) { + ++g_eval_count; + ds.set(0.99 * ups[0].get() + 0.01); + }); + + y.set_vjp([](gretl::UpstreamStates& ups, const gretl::DownstreamState& ds) { + ups[0].get_dual() += 0.99 * ds.get_dual(); + }); + + return y.finalize(); +} + +TEST(PerformanceScaling, LinearChainRecomputationCount) +{ + // Measure how many forward evaluations happen during backprop + // for various chain lengths and checkpoint budgets. + // Optimal (Wang 2009) is O(N * log(N) / S) recomputations for S checkpoints. + struct Config { + int N; + size_t budget; + }; + std::vector configs = { + {50, 3}, {50, 5}, {50, 10}, {50, 50}, {200, 3}, {200, 5}, {200, 10}, {200, 50}, {500, 5}, {500, 10}, {500, 50}, + }; + + std::cout << "\n--- Recomputation counts: N steps, S budget, fwd_evals (backprop), ratio=evals/N ---\n"; + + for (auto& cfg : configs) { + DataStore store(std::make_unique(cfg.budget)); + auto x0 = store.create_state(1.0); + + g_eval_count = 0; + auto x = x0; + for (int i = 0; i < cfg.N; ++i) { + x = counted_step(x); + } + int fwd_evals = g_eval_count; + + g_eval_count = 0; + gretl::set_as_objective(x); + store.back_prop(); + int backprop_evals = g_eval_count; + + double ratio = static_cast(backprop_evals) / cfg.N; + std::cout << " N=" << cfg.N << " S=" << cfg.budget << " fwd=" << fwd_evals << " back=" << backprop_evals + << " ratio=" << ratio << "\n"; + + // Gradient should still be correct + EXPECT_NEAR(x0.get_dual(), std::pow(0.99, cfg.N), std::pow(0.99, cfg.N) * 1e-6); + + // Sanity: backprop evals should be >= N-1 (minimum: everything in memory) + // and forward pass should be exactly N + EXPECT_EQ(fwd_evals, cfg.N); + } + std::cout << "---\n"; +} + +TEST(PerformanceScaling, ConstructionTimeScaling) +{ + // Measure graph construction time as N grows. + // Key concern: passthrough loop in add_state is O(distance_to_last_use) + // per upstream, which can make construction O(N^2) for skip connections. + std::cout << "\n--- Construction time scaling (linear chain, no skip connections) ---\n"; + + std::vector sizes = {100, 500, 1000, 2000, 5000}; + double prev_ms = 0; + + for (int N : sizes) { + DataStore store(std::make_unique(10)); + auto x0 = store.create_state(1.0); + + auto start = std::chrono::steady_clock::now(); + auto x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.99, x, 0.01); + } + double ms = elapsed_ms(start); + + double ratio = (prev_ms > 0) ? ms / prev_ms : 0; + std::cout << " N=" << N << " construct=" << ms << "ms" << (prev_ms > 0 ? " ratio=" + std::to_string(ratio) : "") + << "\n"; + prev_ms = ms; + + // Should still produce correct results + gretl::set_as_objective(x); + store.back_prop(); + EXPECT_NEAR(x0.get_dual(), std::pow(0.99, N), std::pow(0.99, N) * 1e-4); + } + std::cout << "---\n"; +} + +TEST(PerformanceScaling, BackpropTimeVsBudget) +{ + // Fixed graph size, vary budget. Measures the tradeoff between + // memory (checkpoint slots) and recomputation time. + int N = 500; + std::cout << "\n--- Backprop time vs budget (N=" << N << ") ---\n"; + + std::vector budgets = {2, 3, 5, 10, 20, 50, 100, 500}; + + for (size_t budget : budgets) { + DataStore store(std::make_unique(budget)); + auto x0 = store.create_state(1.0); + + auto x = x0; + for (int i = 0; i < N; ++i) { + x = counted_step(x); + } + + g_eval_count = 0; + auto start = std::chrono::steady_clock::now(); + gretl::set_as_objective(x); + store.back_prop(); + double ms = elapsed_ms(start); + + std::cout << " budget=" << budget << " backprop=" << ms << "ms" + << " recomps=" << g_eval_count << "\n"; + + EXPECT_NEAR(x0.get_dual(), std::pow(0.99, N), std::pow(0.99, N) * 1e-4); + } + std::cout << "---\n"; +} + +TEST(PerformanceScaling, SkipConnectionPassthroughOverhead) +{ + // Build a chain where the initial persistent state is used at every step + // (persistent states skip the passthrough loop, so this should be fast). + // Then build a chain where a NON-persistent state at step 1 is used + // at every step (creating long passthrough lists). + // Compare construction times to measure passthrough overhead. + int N = 1000; + std::cout << "\n--- Skip connection passthrough overhead (N=" << N << ") ---\n"; + + // Case 1: persistent x0 used at every step (no passthrough overhead) + { + DataStore store(std::make_unique(static_cast(N + 2))); + auto x0 = store.create_state(1.0); + + auto start = std::chrono::steady_clock::now(); + auto x = gretl::axpb(1.0, x0, 0.0); + for (int i = 1; i < N; ++i) { + x = x + x0; // persistent upstream: no passthroughs + } + double ms = elapsed_ms(start); + std::cout << " persistent_skip: " << ms << "ms\n"; + + gretl::set_as_objective(x); + store.back_prop(); + // x = N*x0, dx/dx0 = N + EXPECT_NEAR(x0.get_dual(), static_cast(N), 1e-6); + } + + // Case 2: non-persistent step 1 used at every step (passthrough overhead) + { + DataStore store(std::make_unique(static_cast(N + 2))); + auto x0 = store.create_state(1.0); + auto base = gretl::axpb(1.0, x0, 0.0); // step 1 (non-persistent) + + auto start = std::chrono::steady_clock::now(); + auto x = gretl::axpb(1.0, base, 0.0); + for (int i = 1; i < N; ++i) { + x = x + base; // non-persistent upstream: passthroughs grow linearly + } + double ms = elapsed_ms(start); + std::cout << " nonpersist_skip: " << ms << "ms\n"; + + gretl::set_as_objective(x); + store.back_prop(); + EXPECT_NEAR(x0.get_dual(), static_cast(N), 1e-6); + } + + std::cout << "---\n"; +} + +TEST(PerformanceScaling, CheckpointSetOperationOverhead) +{ + // The CheckpointManager uses std::set. + // erase_step() does a linear scan O(S) per call. + // contains_step() also does O(S) linear scan. + // For large budgets, this could become a bottleneck. + // Measure backprop time with a huge budget to isolate this cost. + std::cout << "\n--- Checkpoint set overhead (large budget) ---\n"; + + std::vector sizes = {100, 500, 1000, 2000}; + + for (int N : sizes) { + // Budget = N (everything fits, no recomputation) + DataStore store(std::make_unique(static_cast(N))); + auto x0 = store.create_state(1.0); + + auto x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.99, x, 0.01); + } + + auto start = std::chrono::steady_clock::now(); + gretl::set_as_objective(x); + store.back_prop(); + double ms = elapsed_ms(start); + + std::cout << " N=S=" << N << " backprop=" << ms << "ms\n"; + + EXPECT_NEAR(x0.get_dual(), std::pow(0.99, N), std::pow(0.99, N) * 1e-4); + } + std::cout << "---\n"; +} + +TEST(PerformanceScaling, LargeVectorStateScaling) +{ + // Measure performance with large vector states to see if the + // type-erased std::any copies dominate runtime. + std::cout << "\n--- Large vector state scaling ---\n"; + + std::vector vec_sizes = {10, 100, 1000, 10000}; + int N = 100; + + for (size_t S : vec_sizes) { + std::vector data(S, 1.0); + + DataStore store(std::make_unique(10)); + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + + auto start = std::chrono::steady_clock::now(); + auto x = gretl::copy(x0); + for (int i = 0; i < N; ++i) { + x = x * 0.99; + } + auto norm = gretl::inner_product(x, x); + gretl::set_as_objective(norm); + store.back_prop(); + double ms = elapsed_ms(start); + + std::cout << " vec_size=" << S << " N=" << N << " total=" << ms << "ms\n"; + + // norm = S * (0.99^N)^2 + double expected_norm = static_cast(S) * std::pow(0.99, 2 * N); + EXPECT_NEAR(norm.get(), expected_norm, expected_norm * 1e-6); + } + std::cout << "---\n"; +} + +TEST(PerformanceScaling, WideFanoutScaling) +{ + // Measure how fan-out width affects construction and backprop time. + // Each branch adds passthroughs and usageCount updates. + std::cout << "\n--- Wide fan-out scaling ---\n"; + + std::vector widths = {5, 10, 20, 50, 100}; + + for (int W : widths) { + DataStore store(std::make_unique(static_cast(2 * W + 5))); + auto x0 = store.create_state(1.0); + + auto start = std::chrono::steady_clock::now(); + State sum = gretl::axpb(1.0, x0, 0.0); + for (int i = 1; i < W; ++i) { + auto branch = gretl::axpb(static_cast(i + 1), x0, 0.0); + sum = sum + branch; + } + double construct_ms = elapsed_ms(start); + + start = std::chrono::steady_clock::now(); + gretl::set_as_objective(sum); + store.back_prop(); + double backprop_ms = elapsed_ms(start); + + double coeff = static_cast(W * (W + 1)) / 2.0; + std::cout << " W=" << W << " construct=" << construct_ms << "ms" + << " backprop=" << backprop_ms << "ms\n"; + + EXPECT_NEAR(x0.get_dual(), coeff, 1e-8); + } + std::cout << "---\n"; +} + +TEST(PerformanceScaling, DeepDAGWithMultipleInputs) +{ + // A realistic-ish DAG pattern: two inputs, alternating linear operations, + // deep chain. Measures scaling with depth. + // Uses bounded linear ops (scale < 1) to avoid overflow. + std::cout << "\n--- Deep DAG with 2 inputs ---\n"; + + std::vector depths = {50, 100, 200, 500, 1000}; + + for (int N : depths) { + DataStore store(std::make_unique(10)); + auto x = store.create_state(0.5); + auto y = store.create_state(0.3); + + auto start = std::chrono::steady_clock::now(); + auto a = x + y; // 0.8 + auto b = gretl::axpb(0.5, x, 0.0) + gretl::axpb(0.3, y, 0.0); // 0.34 + for (int i = 0; i < N; ++i) { + auto c = gretl::axpb(0.6, a, 0.0) + gretl::axpb(0.3, b, 0.0); // bounded + b = gretl::axpb(0.3, a, 0.0) + gretl::axpb(0.5, b, 0.0); // bounded + a = c; + } + auto result = a + b; + double construct_ms = elapsed_ms(start); + + start = std::chrono::steady_clock::now(); + gretl::set_as_objective(result); + store.back_prop(); + double backprop_ms = elapsed_ms(start); + + std::cout << " depth=" << N << " construct=" << construct_ms << "ms" + << " backprop=" << backprop_ms << "ms\n"; + + // Verify finite gradients (bounded recurrence) + EXPECT_TRUE(std::isfinite(x.get_dual())); + EXPECT_TRUE(std::isfinite(y.get_dual())); + } + std::cout << "---\n"; +} + +TEST(PerformanceScaling, RepeatedResetAndBackprop) +{ + // Measure the cost of reset + re-evaluation cycles. + // This is the pattern for iterative optimization. + int N = 100; + int iters = 20; + DataStore store(std::make_unique(10)); + auto x0 = store.create_state(1.0); + + auto x = x0; + for (int i = 0; i < N; ++i) { + x = gretl::axpb(0.99, x, 0.01); + } + + std::cout << "\n--- Repeated reset+backprop (N=" << N << ", iters=" << iters << ") ---\n"; + + gretl::set_as_objective(x); + store.back_prop(); + double grad0 = x0.get_dual(); + + auto start = std::chrono::steady_clock::now(); + for (int iter = 0; iter < iters; ++iter) { + store.reset(); + store.reset_for_backprop(); + x.set_dual(1.0); + store.back_prop(); + } + double ms = elapsed_ms(start); + + std::cout << " total=" << ms << "ms" + << " per_iter=" << ms / iters << "ms\n"; + std::cout << "---\n"; + + // Gradient should be unchanged each time (linear graph) + EXPECT_NEAR(x0.get_dual(), grad0, 1e-12); +} + +// --------------------------------------------------------------------------- +// TEST SUITE: VectorBottleneck +// Detailed timing breakdown for State> to identify where +// the remaining performance bottlenecks are after Phase 1 (move overloads). +// --------------------------------------------------------------------------- + +// Helper: a vector scale operation using const ref + move (best practice) +static VectorState vec_scale_move(const VectorState& a, double s) +{ + VectorState b = a.clone({a}); + + b.set_eval([s](const gretl::UpstreamStates& upstreams, gretl::DownstreamState& downstream) { + const gretl::Vector& A = upstreams[0].get(); + gretl::Vector C(A); // copy-construct (avoids zero-init of Vector(sz)) + for (auto& v : C) { + v *= s; + } + downstream.set(std::move(C)); // move into primal + }); + + b.set_vjp([s](gretl::UpstreamStates& upstreams, const gretl::DownstreamState& downstream) { + const gretl::Vector& Cbar = downstream.get_dual(); + gretl::Vector& Abar = upstreams[0].get_dual(); + for (size_t i = 0; i < Abar.size(); ++i) { + Abar[i] += s * Cbar[i]; + } + }); + + return b.finalize(); +} + +// Helper: a vector scale using old pattern (copy input + copy output) +static VectorState vec_scale_copy(const VectorState& a, double s) +{ + VectorState b = a.clone({a}); + + b.set_eval([s](const gretl::UpstreamStates& upstreams, gretl::DownstreamState& downstream) { + gretl::Vector C = upstreams[0].get(); // copy input + for (auto& v : C) { + v *= s; + } + downstream.set(C); // copy output (no move) + }); + + b.set_vjp([s](gretl::UpstreamStates& upstreams, const gretl::DownstreamState& downstream) { + const gretl::Vector& Cbar = downstream.get_dual(); + gretl::Vector& Abar = upstreams[0].get_dual(); + for (size_t i = 0; i < Abar.size(); ++i) { + Abar[i] += s * Cbar[i]; + } + }); + + return b.finalize(); +} + +TEST(VectorBottleneck, ProfileByPhase) +{ + // Break down total time into construction vs backprop for varying vector sizes. + // Chain length fixed at N=100, budget=10. + int N = 100; + size_t budget = 10; + + std::cout << "\n--- Vector bottleneck profile (N=" << N << ", budget=" << budget << ") ---\n"; + std::cout << " vec_size | construct_ms | backprop_ms | total_ms | bytes_per_vec\n"; + + std::vector vec_sizes = {100, 1000, 10000, 50000, 100000}; + + for (size_t S : vec_sizes) { + gretl::Vector data(S, 1.0); + + DataStore store(std::make_unique(budget)); + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + + auto start = std::chrono::steady_clock::now(); + auto x = gretl::copy(x0); + for (int i = 0; i < N; ++i) { + x = vec_scale_move(x, 0.99); + } + double construct_ms = elapsed_ms(start); + + auto norm = gretl::inner_product(x, x); + + start = std::chrono::steady_clock::now(); + gretl::set_as_objective(norm); + store.back_prop(); + double backprop_ms = elapsed_ms(start); + + size_t bytes = S * sizeof(double); + std::cout << " " << std::setw(8) << S << " | " << std::setw(12) << construct_ms << " | " << std::setw(11) + << backprop_ms << " | " << std::setw(8) << (construct_ms + backprop_ms) << " | " << std::setw(13) << bytes + << "\n"; + + // Verify correctness + double expected_norm = static_cast(S) * std::pow(0.99, 2 * N); + EXPECT_NEAR(norm.get(), expected_norm, expected_norm * 1e-6); + } + std::cout << "---\n"; +} + +TEST(VectorBottleneck, MoveVsCopy) +{ + // Compare the move-enabled eval path vs the copy path. + // This directly measures the benefit of Phase 1 move overloads. + int N = 100; + size_t budget = 10; + + std::cout << "\n--- Move vs Copy comparison (N=" << N << ", budget=" << budget << ") ---\n"; + std::cout << " vec_size | move_total_ms | copy_total_ms | speedup\n"; + + std::vector vec_sizes = {1000, 10000, 50000}; + + for (size_t S : vec_sizes) { + gretl::Vector data(S, 1.0); + + // Move path + double move_ms; + { + DataStore store(std::make_unique(budget)); + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + + auto start = std::chrono::steady_clock::now(); + auto x = gretl::copy(x0); + for (int i = 0; i < N; ++i) { + x = vec_scale_move(x, 0.99); + } + auto norm = gretl::inner_product(x, x); + gretl::set_as_objective(norm); + store.back_prop(); + move_ms = elapsed_ms(start); + + double expected = static_cast(S) * std::pow(0.99, 2 * N); + EXPECT_NEAR(norm.get(), expected, expected * 1e-6); + } + + // Copy path + double copy_ms; + { + DataStore store(std::make_unique(budget)); + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + + auto start = std::chrono::steady_clock::now(); + auto x = gretl::copy(x0); + for (int i = 0; i < N; ++i) { + x = vec_scale_copy(x, 0.99); + } + auto norm = gretl::inner_product(x, x); + gretl::set_as_objective(norm); + store.back_prop(); + copy_ms = elapsed_ms(start); + + double expected = static_cast(S) * std::pow(0.99, 2 * N); + EXPECT_NEAR(norm.get(), expected, expected * 1e-6); + } + + double speedup = copy_ms / move_ms; + std::cout << " " << std::setw(8) << S << " | " << std::setw(13) << move_ms << " | " << std::setw(13) << copy_ms + << " | " << std::setw(7) << speedup << "x\n"; + } + std::cout << "---\n"; +} + +TEST(VectorBottleneck, CloneOverhead) +{ + // Measure the cost of clone() for vector states. clone() always copies + // the primal via make_shared(*any_cast(...)), so this is a + // remaining copy that move overloads don't help. + std::cout << "\n--- Clone overhead for vector states ---\n"; + std::cout << " vec_size | clone_100x_ms | per_clone_us\n"; + + std::vector vec_sizes = {100, 1000, 10000, 50000}; + int N = 100; + + for (size_t S : vec_sizes) { + gretl::Vector data(S, 1.0); + + DataStore store(std::make_unique(static_cast(N + 5))); + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + + auto start = std::chrono::steady_clock::now(); + auto x = x0; + for (int i = 0; i < N; ++i) { + // clone is called inside vec_scale_move via clone({a}) + x = vec_scale_move(x, 1.0); + } + double total_ms = elapsed_ms(start); + double per_clone_us = total_ms / N * 1000.0; + + std::cout << " " << std::setw(8) << S << " | " << std::setw(13) << total_ms << " | " << std::setw(12) + << per_clone_us << "\n"; + } + std::cout << "---\n"; +} + +TEST(VectorBottleneck, CheckpointRecomputeVsNoRecompute) +{ + // Compare budget=N (all in memory, no recomputation) vs budget=5 (heavy recomputation). + // This isolates the cost of checkpoint-driven recomputation for large vector states. + int N = 100; + + std::cout << "\n--- Checkpoint recompute cost for vectors (N=" << N << ") ---\n"; + std::cout << " vec_size | budget=N_ms | budget=5_ms | recomp_overhead\n"; + + std::vector vec_sizes = {1000, 10000, 50000}; + + for (size_t S : vec_sizes) { + gretl::Vector data(S, 1.0); + + // budget = N (no recomputation) + double no_recomp_ms; + { + DataStore store(std::make_unique(static_cast(N + 5))); + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + auto x = gretl::copy(x0); + for (int i = 0; i < N; ++i) { + x = vec_scale_move(x, 0.99); + } + auto norm = gretl::inner_product(x, x); + gretl::set_as_objective(norm); + + auto start = std::chrono::steady_clock::now(); + store.back_prop(); + no_recomp_ms = elapsed_ms(start); + } + + // budget = 5 (heavy recomputation) + double recomp_ms; + { + DataStore store(std::make_unique(5)); + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + auto x = gretl::copy(x0); + for (int i = 0; i < N; ++i) { + x = vec_scale_move(x, 0.99); + } + auto norm = gretl::inner_product(x, x); + gretl::set_as_objective(norm); + + auto start = std::chrono::steady_clock::now(); + store.back_prop(); + recomp_ms = elapsed_ms(start); + } + + double overhead = recomp_ms / no_recomp_ms; + std::cout << " " << std::setw(8) << S << " | " << std::setw(11) << no_recomp_ms << " | " << std::setw(11) + << recomp_ms << " | " << std::setw(15) << overhead << "x\n"; + } + std::cout << "---\n"; +} + +TEST(VectorBottleneck, GetPrimalCopyCost) +{ + // Measure the cost of get_primal (which returns const ref, no copy) + // vs the copy that happens inside eval when reading upstream.get() + // followed by modification. This isolates the read side. + int N = 200; + size_t S = 10000; + + std::cout << "\n--- get_primal read cost (N=" << N << ", vec_size=" << S << ") ---\n"; + + gretl::Vector data(S, 1.0); + + // Pattern 1: Read-only (inner_product reads but doesn't copy vectors) + double readonly_ms; + { + DataStore store(std::make_unique(static_cast(N + 5))); + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + auto x = gretl::copy(x0); + // Chain of inner_products: each reads 2 vectors but writes 1 double + std::vector> norms; + for (int i = 0; i < N; ++i) { + norms.push_back(gretl::inner_product(x, x)); + } + // Sum all norms via a chain + auto total = norms[0]; + for (int i = 1; i < N; ++i) { + total = total + norms[static_cast(i)]; + } + gretl::set_as_objective(total); + auto start = std::chrono::steady_clock::now(); + store.back_prop(); + readonly_ms = elapsed_ms(start); + } + + // Pattern 2: Read + copy + write (scale operation copies the vector) + double readwrite_ms; + { + DataStore store(std::make_unique(static_cast(N + 5))); + auto x0 = store.create_state(data, gretl::vec::initialize_zero_dual); + auto x = gretl::copy(x0); + for (int i = 0; i < N; ++i) { + x = vec_scale_move(x, 1.0); // copies upstream vector, scales, moves result + } + auto norm = gretl::inner_product(x, x); + gretl::set_as_objective(norm); + auto start = std::chrono::steady_clock::now(); + store.back_prop(); + readwrite_ms = elapsed_ms(start); + } + + std::cout << " read-only (inner_product chain): " << readonly_ms << "ms\n"; + std::cout << " read+copy+write (scale chain): " << readwrite_ms << "ms\n"; + std::cout << " copy+write overhead: " << (readwrite_ms - readonly_ms) << "ms" + << " (" << (readwrite_ms / readonly_ms) << "x)\n"; + std::cout << "---\n"; +} diff --git a/src/tests/test_persistent_scope.cpp b/src/tests/test_persistent_scope.cpp new file mode 100644 index 0000000..f1199ff --- /dev/null +++ b/src/tests/test_persistent_scope.cpp @@ -0,0 +1,112 @@ +// Copyright (c) Lawrence Livermore National Security, LLC and +// other Gretl Project Developers. See the top-level LICENSE file for +// details. +// +// SPDX-License-Identifier: (BSD-3-Clause) + +#include +#include +#include +#include "gtest/gtest.h" +#include "gretl/vector_state.hpp" +#include "gretl/data_store.hpp" +#include "gretl/wang_checkpoint_strategy.hpp" +#include "gretl/test_utils.hpp" + +using gretl::print; + +// Helper function that creates the initial persistent state but doesn't return it +// This tests that the DataStore properly tracks persistent states even when +// the local State<> object goes out of scope +void create_initial_state(gretl::DataStore& dataStore, const std::vector& data) +{ + // Create persistent state - it goes into dataStore's persistent list + auto initial = dataStore.create_state(data, gretl::vec::initialize_zero_dual); + + // State object goes out of scope here, but the underlying StateData + // should still be tracked by DataStore since it's persistent +} + +TEST(PersistentScope, InitialStateGoesOutOfScope) +{ + std::vector dataA = {1.5, 2.5, 3.5}; + + gretl::DataStore dataStore(std::make_unique(10)); + + // Create the initial persistent state in a function - it goes out of scope + create_initial_state(dataStore, dataA); + + // Now create some computation steps + // The first state created after this should be at step 1 + std::vector dataB = {0.5, 0.5, 0.5}; + auto b = dataStore.create_state(dataB, gretl::vec::initialize_zero_dual); + + // Do some operations + auto c = b + b; // 2*b + auto d = c + b; // 3*b + auto e = d + c; // 5*b + + // Verify the computation + for (size_t i = 0; i < 3; ++i) { + EXPECT_NEAR(e.get()[i], 5.0 * dataB[i], 1e-14); + } + + // Set up for backpropagation + auto qoi = gretl::inner_product(e, e); + gretl::set_as_objective(qoi); + + dataStore.back_prop(); + + // Verify gradients + // qoi = e·e = (5b)·(5b) = 25(b·b) + // dqoi/db = 50b + for (size_t i = 0; i < 3; ++i) { + EXPECT_NEAR(b.get_dual()[i], 50.0 * dataB[i], 1e-13); + } + + // The test passes if we get here without ASAN errors + std::cout << "Test passed - initial persistent state properly managed" << std::endl; +} + +TEST(PersistentScope, MultipleStatesGoOutOfScope) +{ + std::vector data1 = {1.0, 2.0}; + std::vector data2 = {3.0, 4.0}; + std::vector data3 = {5.0, 6.0}; + + gretl::DataStore dataStore(std::make_unique(10)); + + // Create multiple persistent states that go out of scope + { + auto s1 = dataStore.create_state(data1, gretl::vec::initialize_zero_dual); + auto s2 = dataStore.create_state(data2, gretl::vec::initialize_zero_dual); + auto s3 = dataStore.create_state(data3, gretl::vec::initialize_zero_dual); + // All three go out of scope here + } + + // Now do some computation with new states + std::vector dataX = {0.1, 0.2}; + auto x = dataStore.create_state(dataX, gretl::vec::initialize_zero_dual); + auto y = x + x; + auto z = y * x; + + // Verify computation + for (size_t i = 0; i < 2; ++i) { + EXPECT_NEAR(y.get()[i], 2.0 * dataX[i], 1e-14); + EXPECT_NEAR(z.get()[i], 2.0 * dataX[i] * dataX[i], 1e-14); + } + + auto qoi = gretl::inner_product(z, z); + gretl::set_as_objective(qoi); + + dataStore.back_prop(); + + // qoi = z·z = (2x²)·(2x²) = 4(x²)·(x²) = 4x⁴ + // dqoi/dx = 16x³ + for (size_t i = 0; i < 2; ++i) { + double xi = dataX[i]; + EXPECT_NEAR(x.get_dual()[i], 16.0 * xi * xi * xi, 1e-13); + } + + std::cout << "Test passed - multiple persistent states properly managed" << std::endl; +}