Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ jobs:
host_config: llvm@19.1.1.cmake
compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }}
cmake_opts: "-DBUILD_SHARED_LIBS=ON"
- job_name: llvm@19.1.1, C++20, shared
host_config: llvm@19.1.1.cmake
compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }}
cmake_opts: "-DBUILD_SHARED_LIBS=ON -DBLT_CXX_STD=c++20"
- job_name: gcc@14.2.0, shared
host_config: gcc@14.2.0.cmake
compiler_image: ${{ needs.set_image_vars.outputs.gcc_docker_image }}
Expand Down
28 changes: 14 additions & 14 deletions src/gretl/data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ void DataStore::back_prop()
template <typename Func>
void for_each_active_upstream(const DataStore* dataStore, size_t step, const Func& func)
{
for (const auto& upstream : dataStore->upstreams_[step].states()) {
if (!dataStore->is_persistent(upstream.step_)) {
func(upstream.step_);
for (Int upstreamStep : dataStore->upstreamSteps_[step]) {
if (!dataStore->is_persistent(upstreamStep)) {
func(upstreamStep);
}
}
for (Int upstreamStepPassingThrough : dataStore->passthroughs_[step]) {
Expand Down Expand Up @@ -98,7 +98,7 @@ void DataStore::resize(Int newSize)
std::to_string(newSize) + std::string(" ") + std::to_string(currentStep_));
states_.resize(newSize);
duals_.resize(newSize);
upstreams_.resize(newSize);
upstreamSteps_.resize(newSize);
evals_.resize(newSize);
vjps_.resize(newSize);
active_.resize(newSize);
Expand All @@ -119,7 +119,7 @@ void DataStore::reset_for_backprop()

void DataStore::vjp(StateBase& state) { state.evaluate_vjp(); }

bool DataStore::is_persistent(Int step) const { return !upstreams_[step].size(); }
bool DataStore::is_persistent(Int step) const { return upstreamSteps_[step].empty(); }

void DataStore::reverse_state()
{
Expand All @@ -128,7 +128,7 @@ void DataStore::reverse_state()
checkpointStrategy_->erase_step(currentStep_ - 1);
}
--currentStep_;
if (upstreams_[currentStep_].size()) {
if (!upstreamSteps_[currentStep_].empty()) {
fetch_state_data(currentStep_ - 1);
vjp(*states_[currentStep_]);
clear_usage(currentStep_);
Expand Down Expand Up @@ -195,7 +195,7 @@ void DataStore::add_state(std::unique_ptr<StateBase> newState, const std::vector
Int upstreamStep = u.step();
upstreamSteps.push_back(upstreamStep);
}
upstreams_.emplace_back(*this, upstreamSteps);
upstreamSteps_.emplace_back(std::move(upstreamSteps));

for (auto& u : upstreams) {
Int upstreamStep = u.step();
Expand Down Expand Up @@ -225,13 +225,13 @@ void DataStore::add_state(std::unique_ptr<StateBase> newState, const std::vector
}
}

evals_.emplace_back([=](const UpstreamStates&, DownstreamState&) {
std::cout << "eval not implemented for step " << currentStep_ << std::endl;
evals_.emplace_back([step](const UpstreamStates&, DownstreamState&) {
std::cout << "eval not implemented for step " << step << std::endl;
gretl_assert(false);
});

vjps_.emplace_back([=](UpstreamStates&, const DownstreamState&) {
std::cout << "vjp not implemented for step " << currentStep_ << std::endl;
vjps_.emplace_back([step](UpstreamStates&, const DownstreamState&) {
std::cout << "vjp not implemented for step " << step << std::endl;
gretl_assert(false);
});

Expand All @@ -241,7 +241,7 @@ void DataStore::add_state(std::unique_ptr<StateBase> newState, const std::vector
++currentStep_;
gretl_assert(currentStep_ == states_.size());
gretl_assert(currentStep_ == duals_.size());
gretl_assert(currentStep_ == upstreams_.size());
gretl_assert(currentStep_ == upstreamSteps_.size());
gretl_assert(currentStep_ == passthroughs_.size());
gretl_assert(currentStep_ == active_.size());
gretl_assert(currentStep_ == usageCount_.size());
Expand Down Expand Up @@ -356,8 +356,8 @@ void DataStore::print_graph() const
std::cout << i << ", act: " << std::setw(3) << active_[i] << ":" << std::setw(3) << usageCount_[i] << ":"
<< std::setw(3) << states_[i]->data_.use_count() << ":" << std::setw(3)
<< (states_[i]->primal() != nullptr) << ", ups: ";
for (auto& v : upstreams_[i].states()) {
std::cout << v.step_ << " ";
for (Int upstreamStep : upstreamSteps_[i]) {
std::cout << upstreamStep << " ";
}
std::cout << ", pass: ";
for (auto& v : passthroughs_[i]) {
Expand Down
6 changes: 3 additions & 3 deletions src/gretl/data_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ class DataStore {
explicit DataStore(std::unique_ptr<CheckpointStrategy> strategy);

/// @brief virtual destructor. Must clear states_ first because StateBase
/// destructors call try_to_free() which accesses upstreams_ and other members.
/// destructors call try_to_free() which accesses upstreamSteps_ and other members.
/// Without this, implicit reverse-declaration-order destruction would destroy
/// upstreams_ before states_, causing use-after-free.
/// upstreamSteps_ before states_.
/// @brief virtual destructor
virtual ~DataStore()
{
Expand Down Expand Up @@ -242,7 +242,7 @@ class DataStore {

std::vector<std::unique_ptr<StateBase>> states_; ///< states for steps
std::vector<std::unique_ptr<std::any>> duals_; ///< duals for steps
std::vector<UpstreamStates> upstreams_; ///< upstreams dependencies for steps
std::vector<std::vector<Int>> upstreamSteps_; ///< upstream step dependencies for steps
std::vector<EvalT> evals_; ///< forward evaluation functions for steps
std::vector<VjpT> vjps_; ///< vector-jacobian product functions for steps
std::vector<bool> active_; ///< active status for steps
Expand Down
6 changes: 4 additions & 2 deletions src/gretl/state_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ namespace gretl {
void StateBase::evaluate_forward()
{
DownstreamState ds(&data_store(), step());
data_store().evals_[step()](data_store().upstreams_[step()], ds);
UpstreamStates upstreams(data_store(), data_store().upstreamSteps_[step()]);
data_store().evals_[step()](upstreams, ds);
data_store().erase_step_state_data(step());
}

void StateBase::evaluate_vjp()
{
const DownstreamState ds(&data_store(), step());
data_store().vjps_[step()](data_store().upstreams_[step()], ds);
UpstreamStates upstreams(data_store(), data_store().upstreamSteps_[step()]);
data_store().vjps_[step()](upstreams, ds);
}

} // namespace gretl
Loading