diff --git a/src/gretl/data_store.cpp b/src/gretl/data_store.cpp index 679092f..7c03da4 100644 --- a/src/gretl/data_store.cpp +++ b/src/gretl/data_store.cpp @@ -165,11 +165,6 @@ void printv(const std::vector& v) void DataStore::try_to_free(Int step) { - // Don't try to free during destruction to avoid accessing freed memory - if (isDestroying_) { - return; - } - if (!is_persistent(step) && states_[step] && states_[step]->data_) { if (usageCount_[step] == 0 && !active_[step] && states_[step]->data_.use_count() <= 1) { states_[step]->primal() = nullptr; diff --git a/src/gretl/data_store.hpp b/src/gretl/data_store.hpp index 5f27bf0..b119cbb 100644 --- a/src/gretl/data_store.hpp +++ b/src/gretl/data_store.hpp @@ -68,17 +68,13 @@ class DataStore { /// Without this, implicit reverse-declaration-order destruction would destroy /// upstreamSteps_ before states_. /// @brief virtual destructor - virtual ~DataStore() - { - // Set flag to prevent try_to_free() from accessing freed memory during destruction - isDestroying_ = true; - } + virtual ~DataStore() { lifetimeToken_.reset(); } /// @brief create a new state in the graph, store it, return it template State create_state(const T& t, InitializeZeroDual initial_zero_dual = [](const T&) { return D{}; }) { - State state(this, states_.size(), std::make_shared(t), initial_zero_dual); + State state(this, lifetimeToken_, states_.size(), std::make_shared(t), initial_zero_dual); add_state(std::make_unique>(state), {}); return state; } @@ -116,7 +112,7 @@ class DataStore { { gretl_assert(!upstreams.empty()); auto t = std::make_shared(T{}); - State state(this, states_.size(), t, initial_zero_dual); + State state(this, lifetimeToken_, states_.size(), t, initial_zero_dual); add_state(std::make_unique>(state), upstreams); return state; } @@ -266,8 +262,7 @@ class DataStore { /// @brief specifies if graph is in construction or back-prop mode. This is used for internal asserts. bool stillConstructingGraph_ = true; - /// @brief flag to prevent accessing freed memory during destruction - bool isDestroying_ = false; + std::shared_ptr lifetimeToken_ = std::make_shared(0); friend struct StateBase; diff --git a/src/gretl/state.hpp b/src/gretl/state.hpp index c32acbd..a74aa61 100644 --- a/src/gretl/state.hpp +++ b/src/gretl/state.hpp @@ -60,7 +60,9 @@ struct State : public StateBase { { gretl_assert(!upstreams.empty()); auto new_val = std::make_shared(T{}); - State state(&data_store(), data_store().states_.size(), new_val, initialize_zero_dual_); + gretl_assert_msg(!data_.get()->lifetimeToken_.expired(), "Attempted to clone a state with an expired DataStore"); + State state(data_.get()->dataStore_, data_.get()->lifetimeToken_, data_store().states_.size(), new_val, + initialize_zero_dual_); data_store().add_state(std::make_unique>(state), upstreams); return state; } @@ -83,9 +85,9 @@ struct State : public StateBase { /// @param val type-erased value which is the data for the state /// @param initialize_zero_dual std::function which takes a primal value type T, and returns a zeroed out, but memory /// allocated dual type D - State(DataStore* store, size_t step, std::shared_ptr val, + State(DataStore* store, std::weak_ptr lifetimeToken, size_t step, std::shared_ptr val, const InitializeZeroDual& initialize_zero_dual) - : StateBase(store, val), initialize_zero_dual_(initialize_zero_dual) + : StateBase(store, std::move(lifetimeToken), val), initialize_zero_dual_(initialize_zero_dual) { reset_step(static_cast(step)); } diff --git a/src/gretl/state_base.hpp b/src/gretl/state_base.hpp index 4a74f52..0f5cf58 100644 --- a/src/gretl/state_base.hpp +++ b/src/gretl/state_base.hpp @@ -19,8 +19,12 @@ namespace gretl { /// @brief Constainer for DataStore, primal value and step index for a given state struct StateData { /// @brief constructor - StateData(DataStore* dataStore, std::shared_ptr primal) : dataStore_(dataStore), primal_(primal) {} + StateData(DataStore* dataStore, std::weak_ptr lifetimeToken, std::shared_ptr primal) + : dataStore_(dataStore), lifetimeToken_(std::move(lifetimeToken)), primal_(primal) + { + } DataStore* dataStore_; ///< datastore + std::weak_ptr lifetimeToken_; ///< datastore lifetime token std::shared_ptr primal_; ///< value, stores as shared_ptr to std::any Int step_ = std::numeric_limits::max(); ///< step }; @@ -28,7 +32,10 @@ struct StateData { /// @brief Baseclass for State. State stores type-erased value and step number in the graph. struct StateBase { /// @brief Construct state base from a date store and a type-erased values - StateBase(DataStore* store, const std::shared_ptr& val) : data_(std::make_shared(store, val)) {} + StateBase(DataStore* store, std::weak_ptr lifetimeToken, const std::shared_ptr& val) + : data_(std::make_shared(store, std::move(lifetimeToken), val)) + { + } /// @brief copy operator StateBase(const StateBase& oldState) { data_ = oldState.data_; } @@ -40,12 +47,11 @@ struct StateBase { data_ = oldState.data_; return *this; } - auto* dataStore = &data_store(); + auto* oldDataStore = data_->dataStore_; + auto oldLifetimeToken = data_->lifetimeToken_; Int s = step(); data_ = oldState.data_; - if (dataStore) { - dataStore->try_to_free(s); - } + try_to_free_if_live(oldDataStore, oldLifetimeToken, s); return *this; } @@ -55,12 +61,11 @@ struct StateBase { if (!data_) { return; } - auto* dataStore = &data_store(); + auto* oldDataStore = data_->dataStore_; + auto oldLifetimeToken = data_->lifetimeToken_; Int s = step(); data_ = nullptr; - if (dataStore) { - dataStore->try_to_free(s); - } + try_to_free_if_live(oldDataStore, oldLifetimeToken, s); } /// @brief get the underlying value @@ -112,7 +117,12 @@ struct StateBase { void evaluate_vjp(); /// @brief Datastore accessor - DataStore& data_store() const { return *data_->dataStore_; } + DataStore& data_store() const + { + auto* dataStore = lock_data_store(); + gretl_assert_msg(dataStore, "Attempted to access an expired DataStore"); + return *dataStore; + } /// @brief Get step Int step() const { return data_->step_; } @@ -131,6 +141,24 @@ struct StateBase { size_t wild_count() const { return static_cast(data_.use_count()) - 1; } protected: + static void try_to_free_if_live(DataStore* dataStore, const std::weak_ptr& lifetimeToken, Int step) + { + if (!lifetimeToken.expired()) { + dataStore->try_to_free(step); + } + } + + DataStore* lock_data_store() const + { + if (!data_) { + return nullptr; + } + if (data_->lifetimeToken_.expired()) { + return nullptr; + } + return data_->dataStore_; + } + /// @brief state data which store step, and value information. The shared_ptr allows tracking of the number of /// external usages of this state. std::shared_ptr data_; diff --git a/src/tests/test_gretl_robustness.cpp b/src/tests/test_gretl_robustness.cpp index 291e3b4..81987f4 100644 --- a/src/tests/test_gretl_robustness.cpp +++ b/src/tests/test_gretl_robustness.cpp @@ -1023,6 +1023,89 @@ TEST(DestructorTiming, InterleavedCreationAndDestruction) EXPECT_NEAR(x0.get_dual(), 24.0, 1e-14); } +TEST(DestructorTiming, StateOutlivesSharedOwner) +{ + // Regression test for teardown ordering when the last DataStore owner is a + // shared_ptr held by another object, but a copied State survives longer. + // + // Before the lifetime-token fix, the escaped state's destructor would call + // try_to_free() through a dangling DataStore* after owner destruction. ASan + // reliably reported this as a heap-use-after-free at test teardown. + struct SharedOwner { + std::shared_ptr store = std::make_shared(std::make_unique(3)); + }; + + std::unique_ptr> escaped; + + { + auto owner = std::make_unique(); + auto x0 = owner->store->create_state(2.0); + auto y = gretl::axpb(3.0, x0, 1.0); + + EXPECT_NEAR(y.get(), 7.0, 1e-14); + + escaped = std::make_unique>(y); + } + + // No explicit assertions needed here. The regression is in teardown: + // escaped is destroyed after owner/store are already gone. + SUCCEED(); +} + +TEST(DestructorTiming, LastExternalHandleDestructionStillEvictsInactiveState) +{ + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(2.0); + + gretl::Int step = 0; + + { + auto held = gretl::axpb(3.0, x0, 1.0); + step = held.step(); + + // Materialize both primal and dual so the test can verify they are dropped. + EXPECT_NEAR(held.get(), 7.0, 1e-14); + EXPECT_NEAR(held.get_dual(), 0.0, 1e-14); + + store.active_[step] = false; + store.usageCount_[step] = 0; + + // With an external handle alive, try_to_free() should not evict yet. + store.try_to_free(step); + EXPECT_TRUE(store.states_[step]->primal()); + EXPECT_TRUE(store.duals_[step]); + } + + // Once the last external handle dies, its destructor should retry eviction. + EXPECT_FALSE(store.states_[step]->primal()); + EXPECT_FALSE(store.duals_[step]); +} + +TEST(AssignmentOperator, ReassigningLastExternalHandleStillEvictsInactiveState) +{ + DataStore store(std::make_unique(3)); + auto x0 = store.create_state(2.0); + + auto held = gretl::axpb(3.0, x0, 1.0); + gretl::Int old_step = held.step(); + + EXPECT_NEAR(held.get(), 7.0, 1e-14); + EXPECT_NEAR(held.get_dual(), 0.0, 1e-14); + + store.active_[old_step] = false; + store.usageCount_[old_step] = 0; + + // The old step is inactive, but the external handle still keeps it alive. + store.try_to_free(old_step); + EXPECT_TRUE(store.states_[old_step]->primal()); + EXPECT_TRUE(store.duals_[old_step]); + + held = x0; + + EXPECT_FALSE(store.states_[old_step]->primal()); + EXPECT_FALSE(store.duals_[old_step]); +} + // --------------------------------------------------------------------------- // Helper: high-resolution timer // ---------------------------------------------------------------------------