diff --git a/src/gretl/data_store.cpp b/src/gretl/data_store.cpp index 7c03da4..47d2ed7 100644 --- a/src/gretl/data_store.cpp +++ b/src/gretl/data_store.cpp @@ -106,6 +106,7 @@ void DataStore::resize(Int newSize) upstreamSteps_.resize(newSize); evals_.resize(newSize); vjps_.resize(newSize); + requires_vjp_.resize(newSize); active_.resize(newSize); usageCount_.resize(newSize); lastStepUsed_.resize(newSize); @@ -133,11 +134,14 @@ void DataStore::reverse_state() checkpointStrategy_->erase_step(currentStep_ - 1); } --currentStep_; - if (!upstreamSteps_[currentStep_].empty()) { + if (requires_vjp_[currentStep_] && !upstreamSteps_[currentStep_].empty()) { fetch_state_data(currentStep_ - 1); vjp(*states_[currentStep_]); clear_usage(currentStep_); checkpointStrategy_->erase_step(currentStep_ - 1); + } else if (!upstreamSteps_[currentStep_].empty()) { + clear_usage(currentStep_); + checkpointStrategy_->erase_step(currentStep_ - 1); } } @@ -183,6 +187,7 @@ void DataStore::add_state(std::unique_ptr newState, const std::vector active_.push_back(true); lastStepUsed_.push_back(step); passthroughs_.push_back({}); + requires_vjp_.push_back(gradients_enabled()); bool persistent = upstreams.size() == 0; if (persistent) { diff --git a/src/gretl/data_store.hpp b/src/gretl/data_store.hpp index b119cbb..868839d 100644 --- a/src/gretl/data_store.hpp +++ b/src/gretl/data_store.hpp @@ -76,6 +76,9 @@ class DataStore { { State state(this, lifetimeToken_, states_.size(), std::make_shared(t), initial_zero_dual); add_state(std::make_unique>(state), {}); + if (!gradients_enabled()) { + state.set_vjp([](UpstreamStates&, const DownstreamState&) {}); + } return state; } @@ -114,6 +117,9 @@ class DataStore { auto t = std::make_shared(T{}); State state(this, lifetimeToken_, states_.size(), t, initial_zero_dual); add_state(std::make_unique>(state), upstreams); + if (!gradients_enabled()) { + state.set_vjp([](UpstreamStates&, const DownstreamState&) {}); + } return state; } @@ -244,6 +250,7 @@ class DataStore { std::vector> upstreamSteps_; ///< upstream step dependencies for steps std::vector evals_; ///< forward evaluation functions for steps std::vector vjps_; ///< vector-jacobian product functions for steps + std::vector requires_vjp_; ///< flag to indicate if state requires VJP evaluation std::vector active_; ///< active status for steps std::vector usageCount_; ///< count how many times a step is used in some downstream still is the scope of the ///< checkpoint algorithm @@ -262,6 +269,17 @@ class DataStore { /// @brief specifies if graph is in construction or back-prop mode. This is used for internal asserts. bool stillConstructingGraph_ = true; + /// @brief flag to control whether states compute gradients (VJP) + bool gradients_enabled_ = true; + + /// @brief Query if gradients are enabled for newly created states + bool gradients_enabled() const { return gradients_enabled_; } + + /// @brief Set whether gradients (VJPs) should be recorded for newly created states + void set_gradients_enabled(bool enable) { gradients_enabled_ = enable; } + + /// @brief flag to prevent accessing freed memory during destruction + bool isDestroying_ = false; std::shared_ptr lifetimeToken_ = std::make_shared(0); friend struct StateBase; diff --git a/src/gretl/state.hpp b/src/gretl/state.hpp index a74aa61..401b11b 100644 --- a/src/gretl/state.hpp +++ b/src/gretl/state.hpp @@ -49,7 +49,11 @@ struct State : public StateBase { /// plus-equals into the upstream duals. void set_vjp(const std::function& v) { - data_store().vjps_[step()] = v; + if (!data_store().gradients_enabled()) { + data_store().vjps_[step()] = [](UpstreamStates&, const DownstreamState&) {}; + } else { + data_store().vjps_[step()] = v; + } } /// @brief Helper function to clone an existing state (keeping its type). diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index e6ae7da..a9be41e 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -10,7 +10,8 @@ set(gretl_test_sources test_gretl_dynamics.cpp test_gretl_graph.cpp test_gretl_robustness.cpp - test_persistent_scope.cpp) + test_persistent_scope.cpp + test_tracking_disable.cpp) foreach(test ${gretl_test_sources}) get_filename_component( test_name ${test} NAME_WE ) diff --git a/src/tests/test_tracking_disable.cpp b/src/tests/test_tracking_disable.cpp new file mode 100644 index 0000000..fd8f967 --- /dev/null +++ b/src/tests/test_tracking_disable.cpp @@ -0,0 +1,71 @@ +#include +#include "gretl/data_store.hpp" +#include "gretl/state.hpp" +#include "gretl/create_state.hpp" +#include "gretl/wang_checkpoint_strategy.hpp" +#include "gretl/double_state.hpp" + +using namespace gretl; + +namespace { + +State picard_step(const State& x, const State& p, bool vjp_implemented = true) +{ + if (vjp_implemented) { + return create_state([](const double&) { return 0.0; }, + [](const double& x_val, const double& p_val) { + // simple iteration: x = x * 0.5 + p + return x_val * 0.5 + p_val; + }, + [](const double& /*x_val*/, const double& /*p_val*/, const double& /*f_val*/, + double& dx, double& dp, const double& df) { + dx += 0.5 * df; + dp += 1.0 * df; + }, + x, p); + } else { + return create_state( + [](const double&) { return 0.0; }, + [](const double& x_val, const double& p_val) { + // simple iteration: x = x * 0.5 + p + return x_val * 0.5 + p_val; + }, + [](const double&, const double&, const double&, double&, double&, const double&) { + gretl_assert_msg(false, "VJP should not be called for stop-gradient nodes"); + }, + x, p); + } +} + +} // namespace + +TEST(GraphTracking, PicardIterationNoOpVjp) +{ + DataStore ds(std::make_unique(3)); + + // Parameter p + auto p = ds.create_state(0.1); + + // Initial guess x0 + auto x = ds.create_state(1.0); + + // Iterate 10 times without tracking gradients (stop-gradient nodes) + ds.set_gradients_enabled(false); + for (int i = 0; i < 10; ++i) { + x = picard_step(x, p, false); + } + + // Re-enable gradients for the final step + ds.set_gradients_enabled(true); + + // One final iteration to connect the parameter sensitivity + auto x_final = picard_step(x, p, true); + auto obj = set_as_objective(x_final); + ds.finalize_graph(); + ds.back_prop(); + + // Since x_final = 0.5 * x_10 + p, + // dx_final / dp = 1.0 (from the direct dependency of x_final on p) + // The dependency through x_10 is killed because x_10 has no-op VJP. + EXPECT_EQ(p.get_dual(), 1.0); +}