From 2765ca758354050d1ce783294ba97db21b81746b Mon Sep 17 00:00:00 2001 From: Michael Tupek Date: Sun, 12 Apr 2026 00:18:42 -0600 Subject: [PATCH 1/8] Change tracking disable to assign no-op VJP instead of skipping graph registration --- src/gretl/data_store.hpp | 36 ++++++++++++++++++++++++++++++++++++ src/gretl/state.hpp | 6 +++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/gretl/data_store.hpp b/src/gretl/data_store.hpp index 5f27bf0..7dc25ab 100644 --- a/src/gretl/data_store.hpp +++ b/src/gretl/data_store.hpp @@ -80,6 +80,9 @@ class DataStore { { State state(this, states_.size(), std::make_shared(t), initial_zero_dual); add_state(std::make_unique>(state), {}); + if (!is_tracking()) { + state.set_vjp([](UpstreamStates&, const DownstreamState&) {}); + } return state; } @@ -118,6 +121,9 @@ class DataStore { auto t = std::make_shared(T{}); State state(this, states_.size(), t, initial_zero_dual); add_state(std::make_unique>(state), upstreams); + if (!is_tracking()) { + state.set_vjp([](UpstreamStates&, const DownstreamState&) {}); + } return state; } @@ -266,6 +272,15 @@ class DataStore { /// @brief specifies if graph is in construction or back-prop mode. This is used for internal asserts. bool stillConstructingGraph_ = true; + /// @brief flag to control whether states compute gradients (VJP) + bool is_tracking_enabled_ = true; + + /// @brief Query tracking status + bool is_tracking() const { return is_tracking_enabled_; } + + /// @brief Set tracking status + void set_tracking(bool enable) { is_tracking_enabled_ = enable; } + /// @brief flag to prevent accessing freed memory during destruction bool isDestroying_ = false; @@ -278,4 +293,25 @@ class DataStore { friend struct DownstreamState; }; +/// @brief RAII scoped guard to temporarily disable graph tracking for a DataStore. +/// When tracking is disabled, newly created states will receive a no-op VJP, +/// acting as "stop-gradient" nodes during back-propagation. They are still +/// added to the graph and managed by the checkpointer normally. +struct ScopedGraphDisable { + DataStore& store_; + bool previous_tracking_; + + explicit ScopedGraphDisable(DataStore& store) + : store_(store), previous_tracking_(store.is_tracking()) + { + store_.set_tracking(false); + } + + ~ScopedGraphDisable() { store_.set_tracking(previous_tracking_); } + + // Prevent copying + ScopedGraphDisable(const ScopedGraphDisable&) = delete; + ScopedGraphDisable& operator=(const ScopedGraphDisable&) = delete; +}; + } // namespace gretl diff --git a/src/gretl/state.hpp b/src/gretl/state.hpp index c32acbd..2d9538b 100644 --- a/src/gretl/state.hpp +++ b/src/gretl/state.hpp @@ -49,7 +49,11 @@ struct State : public StateBase { /// plus-equals into the upstream duals. void set_vjp(const std::function& v) { - data_store().vjps_[step()] = v; + if (!data_store().is_tracking()) { + data_store().vjps_[step()] = [](UpstreamStates&, const DownstreamState&) {}; + } else { + data_store().vjps_[step()] = v; + } } /// @brief Helper function to clone an existing state (keeping its type). From 0dce5f96a506e1223f4e494c3408973eefccc1b8 Mon Sep 17 00:00:00 2001 From: Michael Tupek Date: Sun, 12 Apr 2026 11:20:35 -0600 Subject: [PATCH 2/8] Rename tracking disable to gradients_enabled flag --- src/gretl/data_store.hpp | 35 +++--------- src/gretl/state.hpp | 2 +- src/tests/test_tracking_disable.cpp | 85 +++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+), 29 deletions(-) create mode 100644 src/tests/test_tracking_disable.cpp diff --git a/src/gretl/data_store.hpp b/src/gretl/data_store.hpp index 7dc25ab..3500a2f 100644 --- a/src/gretl/data_store.hpp +++ b/src/gretl/data_store.hpp @@ -80,7 +80,7 @@ class DataStore { { State state(this, states_.size(), std::make_shared(t), initial_zero_dual); add_state(std::make_unique>(state), {}); - if (!is_tracking()) { + if (!gradients_enabled()) { state.set_vjp([](UpstreamStates&, const DownstreamState&) {}); } return state; @@ -121,7 +121,7 @@ class DataStore { auto t = std::make_shared(T{}); State state(this, states_.size(), t, initial_zero_dual); add_state(std::make_unique>(state), upstreams); - if (!is_tracking()) { + if (!gradients_enabled()) { state.set_vjp([](UpstreamStates&, const DownstreamState&) {}); } return state; @@ -273,13 +273,13 @@ class DataStore { bool stillConstructingGraph_ = true; /// @brief flag to control whether states compute gradients (VJP) - bool is_tracking_enabled_ = true; + bool gradients_enabled_ = true; - /// @brief Query tracking status - bool is_tracking() const { return is_tracking_enabled_; } + /// @brief Query if gradients are enabled for newly created states + bool gradients_enabled() const { return gradients_enabled_; } - /// @brief Set tracking status - void set_tracking(bool enable) { is_tracking_enabled_ = enable; } + /// @brief Set whether gradients (VJPs) should be recorded for newly created states + void set_gradients_enabled(bool enable) { gradients_enabled_ = enable; } /// @brief flag to prevent accessing freed memory during destruction bool isDestroying_ = false; @@ -293,25 +293,4 @@ class DataStore { friend struct DownstreamState; }; -/// @brief RAII scoped guard to temporarily disable graph tracking for a DataStore. -/// When tracking is disabled, newly created states will receive a no-op VJP, -/// acting as "stop-gradient" nodes during back-propagation. They are still -/// added to the graph and managed by the checkpointer normally. -struct ScopedGraphDisable { - DataStore& store_; - bool previous_tracking_; - - explicit ScopedGraphDisable(DataStore& store) - : store_(store), previous_tracking_(store.is_tracking()) - { - store_.set_tracking(false); - } - - ~ScopedGraphDisable() { store_.set_tracking(previous_tracking_); } - - // Prevent copying - ScopedGraphDisable(const ScopedGraphDisable&) = delete; - ScopedGraphDisable& operator=(const ScopedGraphDisable&) = delete; -}; - } // namespace gretl diff --git a/src/gretl/state.hpp b/src/gretl/state.hpp index 2d9538b..8a926d8 100644 --- a/src/gretl/state.hpp +++ b/src/gretl/state.hpp @@ -49,7 +49,7 @@ struct State : public StateBase { /// plus-equals into the upstream duals. void set_vjp(const std::function& v) { - if (!data_store().is_tracking()) { + if (!data_store().gradients_enabled()) { data_store().vjps_[step()] = [](UpstreamStates&, const DownstreamState&) {}; } else { data_store().vjps_[step()] = v; diff --git a/src/tests/test_tracking_disable.cpp b/src/tests/test_tracking_disable.cpp new file mode 100644 index 0000000..9e23398 --- /dev/null +++ b/src/tests/test_tracking_disable.cpp @@ -0,0 +1,85 @@ +#include +#include "gretl/data_store.hpp" +#include "gretl/state.hpp" +#include "gretl/create_state.hpp" +#include "gretl/wang_checkpoint_strategy.hpp" + +using namespace gretl; + +TEST(GraphTracking, StopGradient) { + DataStore ds(10); + auto s1 = ds.create_state(2.0); // tracked, val=2 + auto s2 = ds.create_state(3.0); // tracked, val=3 + + // Create an untracked (stop-gradient) state manually using the ds toggle + ds.set_gradients_enabled(false); + + // This state is added to the graph but its VJP is replaced with a no-op + auto s3 = create_state( + [](const double&) { return 0.0; }, + [](const double& a, const double& b) { return a * b; }, + [](const double&, const double&, const double&, double&, double&, const double&) { + // This original VJP code would normally fail if we somehow reached it + // but we shouldn't reach it because it gets replaced with a no-op. + gretl_assert_msg(false, "VJP for stop-gradient node should never be called"); + }, + s1, s2); + + EXPECT_EQ(s3.get(), 6.0); + + // Re-enable gradients + ds.set_gradients_enabled(true); + + // Create a tracked state using s3 as an upstream + // s4 = s1 + s3 = 2.0 + 6.0 = 8.0 + auto s4 = create_state( + [](const double&) { return 0.0; }, + [](const double& a, const double& b) { return a + b; }, + [](const double&, const double&, const double&, double& a_bar, double& b_bar, const double& c_bar) { + a_bar += c_bar; + b_bar += c_bar; + }, + s1, s3); + + EXPECT_EQ(s4.get(), 8.0); + + auto obj = set_as_objective(s4); // derivative wrt s4 is 1.0 + + ds.finalize_graph(); + ds.back_prop(); + + // derivative of s4 wrt s1 directly is 1.0. + // derivative of s4 wrt s3 is 1.0. + // BUT since s3 is stop-gradient, its derivative is NOT passed back to s1 or s2. + // Normally, ds4/ds1 = 1 + ds3/ds1 = 1 + s2 = 1 + 3 = 4.0 + // With stop-gradient on s3, ds4/ds1 = 1.0. + + EXPECT_EQ(s1.get_dual(), 1.0); + + // s2 only affects s4 via s3. So its dual should be zero. + EXPECT_EQ(s2.get_dual(), 0.0); +} + + EXPECT_TRUE(ds.is_tracking()); + + // Create a tracked state using s3 as an upstream + // s4 = s1 + s3 = 2.0 + 6.0 = 8.0 + auto s4 = create_state( + [](const double&) { return 0.0; }, + [](const double& a, const double& b) { return a + b; }, + [](const double&, const double&, const double&, double& a_bar, double& b_bar, const double& c_bar) { + a_bar += c_bar; + b_bar += c_bar; + }, + s1, s3); + + EXPECT_EQ(s4.get(), 8.0); + + auto obj = set_as_objective(s4); // derivative wrt s4 is 1.0 + + ds.finalize_graph(); + ds.back_prop(); + + EXPECT_EQ(s1.get_dual(), 1.0); + EXPECT_EQ(s2.get_dual(), 0.0); +} From 4d075c4816cb09b19d03c0e6cd5800d952bb32de Mon Sep 17 00:00:00 2001 From: Michael Tupek Date: Sun, 12 Apr 2026 14:05:12 -0600 Subject: [PATCH 3/8] Optimize DataStore::reverse_state to avoid fetch_state_data for no-op VJPs --- src/gretl/data_store.cpp | 7 +- src/gretl/data_store.hpp | 1 + src/tests/CMakeLists.txt | 3 +- src/tests/test_tracking_disable.cpp | 100 ++++++++++------------------ 4 files changed, 45 insertions(+), 66 deletions(-) diff --git a/src/gretl/data_store.cpp b/src/gretl/data_store.cpp index 679092f..906b784 100644 --- a/src/gretl/data_store.cpp +++ b/src/gretl/data_store.cpp @@ -106,6 +106,7 @@ void DataStore::resize(Int newSize) upstreamSteps_.resize(newSize); evals_.resize(newSize); vjps_.resize(newSize); + requires_vjp_.resize(newSize); active_.resize(newSize); usageCount_.resize(newSize); lastStepUsed_.resize(newSize); @@ -133,11 +134,14 @@ void DataStore::reverse_state() checkpointStrategy_->erase_step(currentStep_ - 1); } --currentStep_; - if (!upstreamSteps_[currentStep_].empty()) { + if (requires_vjp_[currentStep_] && !upstreamSteps_[currentStep_].empty()) { fetch_state_data(currentStep_ - 1); vjp(*states_[currentStep_]); clear_usage(currentStep_); checkpointStrategy_->erase_step(currentStep_ - 1); + } else if (!upstreamSteps_[currentStep_].empty()) { + clear_usage(currentStep_); + checkpointStrategy_->erase_step(currentStep_ - 1); } } @@ -188,6 +192,7 @@ void DataStore::add_state(std::unique_ptr newState, const std::vector active_.push_back(true); lastStepUsed_.push_back(step); passthroughs_.push_back({}); + requires_vjp_.push_back(gradients_enabled()); bool persistent = upstreams.size() == 0; if (persistent) { diff --git a/src/gretl/data_store.hpp b/src/gretl/data_store.hpp index 3500a2f..8dd9a24 100644 --- a/src/gretl/data_store.hpp +++ b/src/gretl/data_store.hpp @@ -254,6 +254,7 @@ class DataStore { std::vector> upstreamSteps_; ///< upstream step dependencies for steps std::vector evals_; ///< forward evaluation functions for steps std::vector vjps_; ///< vector-jacobian product functions for steps + std::vector requires_vjp_; ///< flag to indicate if state requires VJP evaluation std::vector active_; ///< active status for steps std::vector usageCount_; ///< count how many times a step is used in some downstream still is the scope of the ///< checkpoint algorithm diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt index e6ae7da..a9be41e 100644 --- a/src/tests/CMakeLists.txt +++ b/src/tests/CMakeLists.txt @@ -10,7 +10,8 @@ set(gretl_test_sources test_gretl_dynamics.cpp test_gretl_graph.cpp test_gretl_robustness.cpp - test_persistent_scope.cpp) + test_persistent_scope.cpp + test_tracking_disable.cpp) foreach(test ${gretl_test_sources}) get_filename_component( test_name ${test} NAME_WE ) diff --git a/src/tests/test_tracking_disable.cpp b/src/tests/test_tracking_disable.cpp index 9e23398..8afade1 100644 --- a/src/tests/test_tracking_disable.cpp +++ b/src/tests/test_tracking_disable.cpp @@ -3,83 +3,55 @@ #include "gretl/state.hpp" #include "gretl/create_state.hpp" #include "gretl/wang_checkpoint_strategy.hpp" +#include "gretl/double_state.hpp" using namespace gretl; -TEST(GraphTracking, StopGradient) { - DataStore ds(10); - auto s1 = ds.create_state(2.0); // tracked, val=2 - auto s2 = ds.create_state(3.0); // tracked, val=3 +TEST(GraphTracking, PicardIterationNoOpVjp) { + DataStore ds(std::make_unique(3)); - // Create an untracked (stop-gradient) state manually using the ds toggle - ds.set_gradients_enabled(false); + // Parameter p + auto p = ds.create_state(0.1); - // This state is added to the graph but its VJP is replaced with a no-op - auto s3 = create_state( - [](const double&) { return 0.0; }, - [](const double& a, const double& b) { return a * b; }, - [](const double&, const double&, const double&, double&, double&, const double&) { - // This original VJP code would normally fail if we somehow reached it - // but we shouldn't reach it because it gets replaced with a no-op. - gretl_assert_msg(false, "VJP for stop-gradient node should never be called"); - }, - s1, s2); + // Initial guess x0 + auto x = ds.create_state(1.0); - EXPECT_EQ(s3.get(), 6.0); - - // Re-enable gradients + // Iterate 10 times without tracking gradients (stop-gradient nodes) + ds.set_gradients_enabled(false); + for (int i = 0; i < 10; ++i) { + x = create_state( + [](const double&) { return 0.0; }, + [](const double& x_val, const double& p_val) { + // simple iteration: x = x * 0.5 + p + return x_val * 0.5 + p_val; + }, + [](const double&, const double&, const double&, double&, double&, const double&) { + gretl_assert_msg(false, "VJP should not be called for stop-gradient nodes"); + }, + x, p); + } + + // Re-enable gradients for the final step ds.set_gradients_enabled(true); - // Create a tracked state using s3 as an upstream - // s4 = s1 + s3 = 2.0 + 6.0 = 8.0 - auto s4 = create_state( + // One final iteration to connect the parameter sensitivity + auto x_final = create_state( [](const double&) { return 0.0; }, - [](const double& a, const double& b) { return a + b; }, - [](const double&, const double&, const double&, double& a_bar, double& b_bar, const double& c_bar) { - a_bar += c_bar; - b_bar += c_bar; + [](const double& x_val, const double& p_val) { + return x_val * 0.5 + p_val; }, - s1, s3); - - EXPECT_EQ(s4.get(), 8.0); - - auto obj = set_as_objective(s4); // derivative wrt s4 is 1.0 - - ds.finalize_graph(); - ds.back_prop(); - - // derivative of s4 wrt s1 directly is 1.0. - // derivative of s4 wrt s3 is 1.0. - // BUT since s3 is stop-gradient, its derivative is NOT passed back to s1 or s2. - // Normally, ds4/ds1 = 1 + ds3/ds1 = 1 + s2 = 1 + 3 = 4.0 - // With stop-gradient on s3, ds4/ds1 = 1.0. - - EXPECT_EQ(s1.get_dual(), 1.0); - - // s2 only affects s4 via s3. So its dual should be zero. - EXPECT_EQ(s2.get_dual(), 0.0); -} - - EXPECT_TRUE(ds.is_tracking()); - - // Create a tracked state using s3 as an upstream - // s4 = s1 + s3 = 2.0 + 6.0 = 8.0 - auto s4 = create_state( - [](const double&) { return 0.0; }, - [](const double& a, const double& b) { return a + b; }, - [](const double&, const double&, const double&, double& a_bar, double& b_bar, const double& c_bar) { - a_bar += c_bar; - b_bar += c_bar; + [](const double& /*x_val*/, const double& /*p_val*/, const double& /*f_val*/, double& dx, double& dp, const double& df) { + dx += 0.5 * df; + dp += 1.0 * df; }, - s1, s3); + x, p); - EXPECT_EQ(s4.get(), 8.0); - - auto obj = set_as_objective(s4); // derivative wrt s4 is 1.0 - + auto obj = set_as_objective(x_final); ds.finalize_graph(); ds.back_prop(); - EXPECT_EQ(s1.get_dual(), 1.0); - EXPECT_EQ(s2.get_dual(), 0.0); + // Since x_final = 0.5 * x_10 + p, + // dx_final / dp = 1.0 (from the direct dependency of x_final on p) + // The dependency through x_10 is killed because x_10 has no-op VJP. + EXPECT_EQ(p.get_dual(), 1.0); } From f38d9a386478cf1084402d8f11cf0f1273bc1c43 Mon Sep 17 00:00:00 2001 From: mrtupek2 <125162206+mrtupek2@users.noreply.github.com> Date: Mon, 13 Apr 2026 12:00:13 -0600 Subject: [PATCH 4/8] Apply suggestions from code review Co-authored-by: mrtupek2 <125162206+mrtupek2@users.noreply.github.com> --- .github/workflows/ci-tests.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 80ac002..188fa58 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -38,10 +38,6 @@ jobs: host_config: llvm@19.1.1.cmake compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }} cmake_opts: "-DBUILD_SHARED_LIBS=ON" - - job_name: llvm@19.1.1, C++20, shared - host_config: llvm@19.1.1.cmake - compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }} - cmake_opts: "-DBUILD_SHARED_LIBS=ON -DBLT_CXX_STD=c++20" - job_name: gcc@14.2.0, shared host_config: gcc@14.2.0.cmake compiler_image: ${{ needs.set_image_vars.outputs.gcc_docker_image }} From c6f48bf5cff41514ea4821b49411f9c1b6ed9fcc Mon Sep 17 00:00:00 2001 From: mrtupek2 <125162206+mrtupek2@users.noreply.github.com> Date: Mon, 13 Apr 2026 12:03:03 -0600 Subject: [PATCH 5/8] Apply suggestion from @mrtupek2 --- .github/workflows/ci-tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 188fa58..0f5c9cb 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -38,7 +38,6 @@ jobs: host_config: llvm@19.1.1.cmake compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }} cmake_opts: "-DBUILD_SHARED_LIBS=ON" - - job_name: gcc@14.2.0, shared host_config: gcc@14.2.0.cmake compiler_image: ${{ needs.set_image_vars.outputs.gcc_docker_image }} cmake_opts: "-DBUILD_SHARED_LIBS=ON" From ebccfb9883a82ac9a17e3afdace019ce667ee710 Mon Sep 17 00:00:00 2001 From: Michael Tupek Date: Mon, 13 Apr 2026 12:03:36 -0600 Subject: [PATCH 6/8] Put it back. --- .github/workflows/ci-tests.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 0f5c9cb..80ac002 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -38,6 +38,11 @@ jobs: host_config: llvm@19.1.1.cmake compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }} cmake_opts: "-DBUILD_SHARED_LIBS=ON" + - job_name: llvm@19.1.1, C++20, shared + host_config: llvm@19.1.1.cmake + compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }} + cmake_opts: "-DBUILD_SHARED_LIBS=ON -DBLT_CXX_STD=c++20" + - job_name: gcc@14.2.0, shared host_config: gcc@14.2.0.cmake compiler_image: ${{ needs.set_image_vars.outputs.gcc_docker_image }} cmake_opts: "-DBUILD_SHARED_LIBS=ON" From c1d8af52ac7c26ddf30a52b8cbfd4ddf977c7a91 Mon Sep 17 00:00:00 2001 From: Michael Tupek Date: Mon, 13 Apr 2026 13:05:13 -0600 Subject: [PATCH 7/8] Slightly simplify picard. --- src/tests/test_tracking_disable.cpp | 99 +++++++++++++++++------------ 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/src/tests/test_tracking_disable.cpp b/src/tests/test_tracking_disable.cpp index 8afade1..2cd27ae 100644 --- a/src/tests/test_tracking_disable.cpp +++ b/src/tests/test_tracking_disable.cpp @@ -7,51 +7,66 @@ using namespace gretl; -TEST(GraphTracking, PicardIterationNoOpVjp) { - DataStore ds(std::make_unique(3)); - - // Parameter p - auto p = ds.create_state(0.1); - - // Initial guess x0 - auto x = ds.create_state(1.0); - - // Iterate 10 times without tracking gradients (stop-gradient nodes) - ds.set_gradients_enabled(false); - for (int i = 0; i < 10; ++i) { - x = create_state( - [](const double&) { return 0.0; }, - [](const double& x_val, const double& p_val) { - // simple iteration: x = x * 0.5 + p - return x_val * 0.5 + p_val; - }, - [](const double&, const double&, const double&, double&, double&, const double&) { - gretl_assert_msg(false, "VJP should not be called for stop-gradient nodes"); - }, - x, p); - } - - // Re-enable gradients for the final step - ds.set_gradients_enabled(true); - - // One final iteration to connect the parameter sensitivity - auto x_final = create_state( +namespace { + +State picard_step(const State& x, const State& p, bool vjp_implemented = true) +{ + if (vjp_implemented) { + return create_state( + [](const double&) { return 0.0; }, + [](const double& x_val, const double& p_val) { + // simple iteration: x = x * 0.5 + p + return x_val * 0.5 + p_val; + }, + [](const double& /*x_val*/, const double& /*p_val*/, const double& /*f_val*/, double& dx, double& dp, + const double& df) { + dx += 0.5 * df; + dp += 1.0 * df; + }, + x, p); + } else { + return create_state( [](const double&) { return 0.0; }, [](const double& x_val, const double& p_val) { - return x_val * 0.5 + p_val; + // simple iteration: x = x * 0.5 + p + return x_val * 0.5 + p_val; }, - [](const double& /*x_val*/, const double& /*p_val*/, const double& /*f_val*/, double& dx, double& dp, const double& df) { - dx += 0.5 * df; - dp += 1.0 * df; + [](const double&, const double&, const double&, double&, double&, const double&) { + gretl_assert_msg(false, "VJP should not be called for stop-gradient nodes"); }, x, p); - - auto obj = set_as_objective(x_final); - ds.finalize_graph(); - ds.back_prop(); - - // Since x_final = 0.5 * x_10 + p, - // dx_final / dp = 1.0 (from the direct dependency of x_final on p) - // The dependency through x_10 is killed because x_10 has no-op VJP. - EXPECT_EQ(p.get_dual(), 1.0); + } +} + +} // namespace + +TEST(GraphTracking, PicardIterationNoOpVjp) +{ + DataStore ds(std::make_unique(3)); + + // Parameter p + auto p = ds.create_state(0.1); + + // Initial guess x0 + auto x = ds.create_state(1.0); + + // Iterate 10 times without tracking gradients (stop-gradient nodes) + ds.set_gradients_enabled(false); + for (int i = 0; i < 10; ++i) { + x = picard_step(x, p, false); + } + + // Re-enable gradients for the final step + ds.set_gradients_enabled(true); + + // One final iteration to connect the parameter sensitivity + auto x_final = picard_step(x, p, true); + auto obj = set_as_objective(x_final); + ds.finalize_graph(); + ds.back_prop(); + + // Since x_final = 0.5 * x_10 + p, + // dx_final / dp = 1.0 (from the direct dependency of x_final on p) + // The dependency through x_10 is killed because x_10 has no-op VJP. + EXPECT_EQ(p.get_dual(), 1.0); } From 25378c3d5d0dcaf0b5489d61233aeab686a8711f Mon Sep 17 00:00:00 2001 From: Michael Tupek Date: Mon, 13 Apr 2026 12:48:57 -0700 Subject: [PATCH 8/8] Update style. --- src/tests/test_tracking_disable.cpp | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/tests/test_tracking_disable.cpp b/src/tests/test_tracking_disable.cpp index 2cd27ae..fd8f967 100644 --- a/src/tests/test_tracking_disable.cpp +++ b/src/tests/test_tracking_disable.cpp @@ -12,18 +12,17 @@ namespace { State picard_step(const State& x, const State& p, bool vjp_implemented = true) { if (vjp_implemented) { - return create_state( - [](const double&) { return 0.0; }, - [](const double& x_val, const double& p_val) { - // simple iteration: x = x * 0.5 + p - return x_val * 0.5 + p_val; - }, - [](const double& /*x_val*/, const double& /*p_val*/, const double& /*f_val*/, double& dx, double& dp, - const double& df) { - dx += 0.5 * df; - dp += 1.0 * df; - }, - x, p); + return create_state([](const double&) { return 0.0; }, + [](const double& x_val, const double& p_val) { + // simple iteration: x = x * 0.5 + p + return x_val * 0.5 + p_val; + }, + [](const double& /*x_val*/, const double& /*p_val*/, const double& /*f_val*/, + double& dx, double& dp, const double& df) { + dx += 0.5 * df; + dp += 1.0 * df; + }, + x, p); } else { return create_state( [](const double&) { return 0.0; },