Skip to content
7 changes: 6 additions & 1 deletion src/gretl/data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ void DataStore::resize(Int newSize)
upstreamSteps_.resize(newSize);
evals_.resize(newSize);
vjps_.resize(newSize);
requires_vjp_.resize(newSize);
active_.resize(newSize);
usageCount_.resize(newSize);
lastStepUsed_.resize(newSize);
Expand Down Expand Up @@ -133,11 +134,14 @@ void DataStore::reverse_state()
checkpointStrategy_->erase_step(currentStep_ - 1);
}
--currentStep_;
if (!upstreamSteps_[currentStep_].empty()) {
if (requires_vjp_[currentStep_] && !upstreamSteps_[currentStep_].empty()) {
fetch_state_data(currentStep_ - 1);
vjp(*states_[currentStep_]);
clear_usage(currentStep_);
checkpointStrategy_->erase_step(currentStep_ - 1);
} else if (!upstreamSteps_[currentStep_].empty()) {
clear_usage(currentStep_);
checkpointStrategy_->erase_step(currentStep_ - 1);
}
}

Expand Down Expand Up @@ -183,6 +187,7 @@ void DataStore::add_state(std::unique_ptr<StateBase> newState, const std::vector
active_.push_back(true);
lastStepUsed_.push_back(step);
passthroughs_.push_back({});
requires_vjp_.push_back(gradients_enabled());

bool persistent = upstreams.size() == 0;
if (persistent) {
Expand Down
18 changes: 18 additions & 0 deletions src/gretl/data_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class DataStore {
{
State<T, D> state(this, lifetimeToken_, states_.size(), std::make_shared<std::any>(t), initial_zero_dual);
add_state(std::make_unique<State<T, D>>(state), {});
if (!gradients_enabled()) {
state.set_vjp([](UpstreamStates&, const DownstreamState&) {});
}
return state;
}

Expand Down Expand Up @@ -114,6 +117,9 @@ class DataStore {
auto t = std::make_shared<std::any>(T{});
State<T, D> state(this, lifetimeToken_, states_.size(), t, initial_zero_dual);
add_state(std::make_unique<State<T, D>>(state), upstreams);
if (!gradients_enabled()) {
state.set_vjp([](UpstreamStates&, const DownstreamState&) {});
}
return state;
}

Expand Down Expand Up @@ -244,6 +250,7 @@ class DataStore {
std::vector<std::vector<Int>> upstreamSteps_; ///< upstream step dependencies for steps
std::vector<EvalT> evals_; ///< forward evaluation functions for steps
std::vector<VjpT> vjps_; ///< vector-jacobian product functions for steps
std::vector<bool> requires_vjp_; ///< flag to indicate if state requires VJP evaluation
std::vector<bool> active_; ///< active status for steps
std::vector<Int> usageCount_; ///< count how many times a step is used in some downstream still is the scope of the
///< checkpoint algorithm
Expand All @@ -262,6 +269,17 @@ class DataStore {
/// @brief specifies if graph is in construction or back-prop mode. This is used for internal asserts.
bool stillConstructingGraph_ = true;

/// @brief flag to control whether states compute gradients (VJP)
bool gradients_enabled_ = true;

/// @brief Query if gradients are enabled for newly created states
bool gradients_enabled() const { return gradients_enabled_; }

/// @brief Set whether gradients (VJPs) should be recorded for newly created states
void set_gradients_enabled(bool enable) { gradients_enabled_ = enable; }

/// @brief flag to prevent accessing freed memory during destruction
bool isDestroying_ = false;
std::shared_ptr<void> lifetimeToken_ = std::make_shared<int>(0);

friend struct StateBase;
Expand Down
6 changes: 5 additions & 1 deletion src/gretl/state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ struct State : public StateBase {
/// plus-equals into the upstream duals.
void set_vjp(const std::function<void(UpstreamStates& upstreams, const DownstreamState& downstream)>& v)
{
data_store().vjps_[step()] = v;
if (!data_store().gradients_enabled()) {
data_store().vjps_[step()] = [](UpstreamStates&, const DownstreamState&) {};
} else {
data_store().vjps_[step()] = v;
}
}

/// @brief Helper function to clone an existing state (keeping its type).
Expand Down
3 changes: 2 additions & 1 deletion src/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ set(gretl_test_sources
test_gretl_dynamics.cpp
test_gretl_graph.cpp
test_gretl_robustness.cpp
test_persistent_scope.cpp)
test_persistent_scope.cpp
test_tracking_disable.cpp)

foreach(test ${gretl_test_sources})
get_filename_component( test_name ${test} NAME_WE )
Expand Down
71 changes: 71 additions & 0 deletions src/tests/test_tracking_disable.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#include <gtest/gtest.h>
#include "gretl/data_store.hpp"
#include "gretl/state.hpp"
#include "gretl/create_state.hpp"
#include "gretl/wang_checkpoint_strategy.hpp"
#include "gretl/double_state.hpp"

using namespace gretl;

namespace {

State<double> picard_step(const State<double>& x, const State<double>& p, bool vjp_implemented = true)
{
if (vjp_implemented) {
return create_state<double, double>([](const double&) { return 0.0; },
[](const double& x_val, const double& p_val) {
// simple iteration: x = x * 0.5 + p
return x_val * 0.5 + p_val;
},
[](const double& /*x_val*/, const double& /*p_val*/, const double& /*f_val*/,
double& dx, double& dp, const double& df) {
dx += 0.5 * df;
dp += 1.0 * df;
},
x, p);
} else {
return create_state<double, double>(
[](const double&) { return 0.0; },
[](const double& x_val, const double& p_val) {
// simple iteration: x = x * 0.5 + p
return x_val * 0.5 + p_val;
},
[](const double&, const double&, const double&, double&, double&, const double&) {
gretl_assert_msg(false, "VJP should not be called for stop-gradient nodes");
},
x, p);
}
}

} // namespace

TEST(GraphTracking, PicardIterationNoOpVjp)
{
DataStore ds(std::make_unique<WangCheckpointStrategy>(3));

// Parameter p
auto p = ds.create_state<double, double>(0.1);

// Initial guess x0
auto x = ds.create_state<double, double>(1.0);

// Iterate 10 times without tracking gradients (stop-gradient nodes)
ds.set_gradients_enabled(false);
for (int i = 0; i < 10; ++i) {
x = picard_step(x, p, false);
}

// Re-enable gradients for the final step
ds.set_gradients_enabled(true);

// One final iteration to connect the parameter sensitivity
auto x_final = picard_step(x, p, true);
auto obj = set_as_objective(x_final);
ds.finalize_graph();
ds.back_prop();

// Since x_final = 0.5 * x_10 + p,
// dx_final / dp = 1.0 (from the direct dependency of x_final on p)
// The dependency through x_10 is killed because x_10 has no-op VJP.
EXPECT_EQ(p.get_dual(), 1.0);
}
Loading