Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/gretl/data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,6 @@ void printv(const std::vector<StateBase>& v)

void DataStore::try_to_free(Int step)
{
// Don't try to free during destruction to avoid accessing freed memory
if (isDestroying_) {
return;
}

if (!is_persistent(step) && states_[step] && states_[step]->data_) {
if (usageCount_[step] == 0 && !active_[step] && states_[step]->data_.use_count() <= 1) {
states_[step]->primal() = nullptr;
Expand Down
13 changes: 4 additions & 9 deletions src/gretl/data_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,13 @@ class DataStore {
/// Without this, implicit reverse-declaration-order destruction would destroy
/// upstreamSteps_ before states_.
/// @brief virtual destructor
virtual ~DataStore()
{
// Set flag to prevent try_to_free() from accessing freed memory during destruction
isDestroying_ = true;
}
virtual ~DataStore() { lifetimeToken_.reset(); }

/// @brief create a new state in the graph, store it, return it
template <typename T, typename D>
State<T, D> create_state(const T& t, InitializeZeroDual<T, D> initial_zero_dual = [](const T&) { return D{}; })
{
State<T, D> state(this, states_.size(), std::make_shared<std::any>(t), initial_zero_dual);
State<T, D> state(this, lifetimeToken_, states_.size(), std::make_shared<std::any>(t), initial_zero_dual);
add_state(std::make_unique<State<T, D>>(state), {});
return state;
}
Expand Down Expand Up @@ -116,7 +112,7 @@ class DataStore {
{
gretl_assert(!upstreams.empty());
auto t = std::make_shared<std::any>(T{});
State<T, D> state(this, states_.size(), t, initial_zero_dual);
State<T, D> state(this, lifetimeToken_, states_.size(), t, initial_zero_dual);
add_state(std::make_unique<State<T, D>>(state), upstreams);
return state;
}
Expand Down Expand Up @@ -266,8 +262,7 @@ class DataStore {
/// @brief specifies if graph is in construction or back-prop mode. This is used for internal asserts.
bool stillConstructingGraph_ = true;

/// @brief flag to prevent accessing freed memory during destruction
bool isDestroying_ = false;
std::shared_ptr<void> lifetimeToken_ = std::make_shared<int>(0);

friend struct StateBase;

Expand Down
8 changes: 5 additions & 3 deletions src/gretl/state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ struct State : public StateBase {
{
gretl_assert(!upstreams.empty());
auto new_val = std::make_shared<std::any>(T{});
State<T, D> state(&data_store(), data_store().states_.size(), new_val, initialize_zero_dual_);
gretl_assert_msg(!data_.get()->lifetimeToken_.expired(), "Attempted to clone a state with an expired DataStore");
State<T, D> state(data_.get()->dataStore_, data_.get()->lifetimeToken_, data_store().states_.size(), new_val,
initialize_zero_dual_);
data_store().add_state(std::make_unique<State<T, D>>(state), upstreams);
return state;
}
Expand All @@ -83,9 +85,9 @@ struct State : public StateBase {
/// @param val type-erased value which is the data for the state
/// @param initialize_zero_dual std::function which takes a primal value type T, and returns a zeroed out, but memory
/// allocated dual type D
State(DataStore* store, size_t step, std::shared_ptr<std::any> val,
State(DataStore* store, std::weak_ptr<void> lifetimeToken, size_t step, std::shared_ptr<std::any> val,
const InitializeZeroDual<T, D>& initialize_zero_dual)
: StateBase(store, val), initialize_zero_dual_(initialize_zero_dual)
: StateBase(store, std::move(lifetimeToken), val), initialize_zero_dual_(initialize_zero_dual)
{
reset_step(static_cast<Int>(step));
}
Expand Down
50 changes: 39 additions & 11 deletions src/gretl/state_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,23 @@ namespace gretl {
/// @brief Container for DataStore, primal value and step index for a given state
struct StateData {
/// @brief constructor
StateData(DataStore* dataStore, std::shared_ptr<std::any> primal) : dataStore_(dataStore), primal_(primal) {}
StateData(DataStore* dataStore, std::weak_ptr<void> lifetimeToken, std::shared_ptr<std::any> primal)
: dataStore_(dataStore), lifetimeToken_(std::move(lifetimeToken)), primal_(primal)
{
}
DataStore* dataStore_; ///< datastore

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not store dataStore_ itself as a weak_ptr? That would eliminate the raw pointer and might solve this lifetime issue as well.

std::weak_ptr<void> lifetimeToken_; ///< datastore lifetime token
std::shared_ptr<std::any> primal_; ///< value, stores as shared_ptr to std::any
Int step_ = std::numeric_limits<Int>::max(); ///< step
};

/// @brief Baseclass for State. State stores type-erased value and step number in the graph.
struct StateBase {
/// @brief Construct state base from a data store and a type-erased value
StateBase(DataStore* store, const std::shared_ptr<std::any>& val) : data_(std::make_shared<StateData>(store, val)) {}
StateBase(DataStore* store, std::weak_ptr<void> lifetimeToken, const std::shared_ptr<std::any>& val)
: data_(std::make_shared<StateData>(store, std::move(lifetimeToken), val))
{
}

/// @brief copy operator
StateBase(const StateBase& oldState) { data_ = oldState.data_; }
Expand All @@ -40,12 +47,11 @@ struct StateBase {
data_ = oldState.data_;
return *this;
}
auto* dataStore = &data_store();
auto* oldDataStore = data_->dataStore_;
auto oldLifetimeToken = data_->lifetimeToken_;
Int s = step();
data_ = oldState.data_;
if (dataStore) {
dataStore->try_to_free(s);
}
try_to_free_if_live(oldDataStore, oldLifetimeToken, s);
return *this;
}

Expand All @@ -55,12 +61,11 @@ struct StateBase {
if (!data_) {
return;
}
auto* dataStore = &data_store();
auto* oldDataStore = data_->dataStore_;
auto oldLifetimeToken = data_->lifetimeToken_;
Int s = step();
data_ = nullptr;
if (dataStore) {
dataStore->try_to_free(s);
}
try_to_free_if_live(oldDataStore, oldLifetimeToken, s);
}

/// @brief get the underlying value
Expand Down Expand Up @@ -112,7 +117,12 @@ struct StateBase {
void evaluate_vjp();

/// @brief Datastore accessor
DataStore& data_store() const { return *data_->dataStore_; }
DataStore& data_store() const
{
// lock_data_store() yields nullptr when data_ is empty or the lifetime token has expired.
auto* dataStore = lock_data_store();
// Assert on the pointer before dereferencing it to form the returned reference.
gretl_assert_msg(dataStore, "Attempted to access an expired DataStore");
return *dataStore;
}

/// @brief Get step
Int step() const { return data_->step_; }
Expand All @@ -131,6 +141,24 @@ struct StateBase {
size_t wild_count() const { return static_cast<size_t>(data_.use_count()) - 1; }

protected:
/// @brief Ask the DataStore to try freeing a step, but only while its lifetime token has not expired
/// @param dataStore raw pointer to the store; dereferenced only when the token is still live
/// @param lifetimeToken weak handle whose expiry is used to decide whether the store may be touched
/// @param step step index forwarded to DataStore::try_to_free
static void try_to_free_if_live(DataStore* dataStore, const std::weak_ptr<void>& lifetimeToken, Int step)
{
  // An expired token means the store pointer must not be used: do nothing.
  if (lifetimeToken.expired()) {
    return;
  }
  dataStore->try_to_free(step);
}

/// @brief Return the stored DataStore pointer, or nullptr when there is no state data or the
/// lifetime token has expired
DataStore* lock_data_store() const
{
  // Short-circuit keeps the token check from running on an empty data_.
  const bool storeReachable = data_ && !data_->lifetimeToken_.expired();
  return storeReachable ? data_->dataStore_ : nullptr;
}

/// @brief state data which store step, and value information. The shared_ptr allows tracking of the number of
/// external usages of this state.
std::shared_ptr<StateData> data_;
Expand Down
83 changes: 83 additions & 0 deletions src/tests/test_gretl_robustness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,89 @@ TEST(DestructorTiming, InterleavedCreationAndDestruction)
EXPECT_NEAR(x0.get_dual(), 24.0, 1e-14);
}

TEST(DestructorTiming, StateOutlivesSharedOwner)
{
// Regression test for teardown ordering when the last DataStore owner is a
// shared_ptr held by another object, but a copied State survives longer.
//
// Before the lifetime-token fix, the escaped state's destructor would call
// try_to_free() through a dangling DataStore* after owner destruction. ASan
// reliably reported this as a heap-use-after-free at test teardown.
struct SharedOwner {
std::shared_ptr<DataStore> store = std::make_shared<DataStore>(std::make_unique<gretl::WangCheckpointStrategy>(3));
};

// Declared outside the inner scope so that it is destroyed after `owner`.
std::unique_ptr<State<double>> escaped;

{
auto owner = std::make_unique<SharedOwner>();
auto x0 = owner->store->create_state<double, double>(2.0);
auto y = gretl::axpb(3.0, x0, 1.0);

EXPECT_NEAR(y.get(), 7.0, 1e-14);

// Copy y into a heap-allocated State that survives the end of this scope.
escaped = std::make_unique<State<double>>(y);
}

// No explicit assertions needed here. The regression is in teardown:
// escaped is destroyed after owner/store are already gone.
SUCCEED();
}

// Checks that destroying the last external handle to a step that is marked
// inactive with a zero usage count clears that step's primal and dual in the store.
TEST(DestructorTiming, LastExternalHandleDestructionStillEvictsInactiveState)
{
DataStore store(std::make_unique<gretl::WangCheckpointStrategy>(3));
auto x0 = store.create_state<double, double>(2.0);

// Recorded inside the scope below so the step can still be inspected after `held` is gone.
gretl::Int step = 0;

{
auto held = gretl::axpb(3.0, x0, 1.0);
step = held.step();

// Materialize both primal and dual so the test can verify they are dropped.
EXPECT_NEAR(held.get(), 7.0, 1e-14);
EXPECT_NEAR(held.get_dual(), 0.0, 1e-14);

// Clear the active flag and usage count by hand so that only the outstanding
// handle `held` is expected to block eviction.
store.active_[step] = false;
store.usageCount_[step] = 0;

// With an external handle alive, try_to_free() should not evict yet.
store.try_to_free(step);
EXPECT_TRUE(store.states_[step]->primal());
EXPECT_TRUE(store.duals_[step]);
}

// Once the last external handle dies, its destructor should retry eviction.
EXPECT_FALSE(store.states_[step]->primal());
EXPECT_FALSE(store.duals_[step]);
}

// Checks that assigning a different state over the last external handle to a step
// that is marked inactive with a zero usage count clears the old step's primal and dual.
TEST(AssignmentOperator, ReassigningLastExternalHandleStillEvictsInactiveState)
{
DataStore store(std::make_unique<gretl::WangCheckpointStrategy>(3));
auto x0 = store.create_state<double, double>(2.0);

auto held = gretl::axpb(3.0, x0, 1.0);
// Remember the step now; `held` refers to a different state after the reassignment below.
gretl::Int old_step = held.step();

EXPECT_NEAR(held.get(), 7.0, 1e-14);
EXPECT_NEAR(held.get_dual(), 0.0, 1e-14);

// Clear the active flag and usage count by hand so that only the outstanding
// handle `held` is expected to block eviction.
store.active_[old_step] = false;
store.usageCount_[old_step] = 0;

// The old step is inactive, but the external handle still keeps it alive.
store.try_to_free(old_step);
EXPECT_TRUE(store.states_[old_step]->primal());
EXPECT_TRUE(store.duals_[old_step]);

// Reassignment drops the handle on old_step; the assertions below expect eviction to follow.
held = x0;

EXPECT_FALSE(store.states_[old_step]->primal());
EXPECT_FALSE(store.duals_[old_step]);
}

// ---------------------------------------------------------------------------
// Helper: high-resolution timer
// ---------------------------------------------------------------------------
Expand Down
Loading