diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 188fa58..80ac002 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -38,6 +38,10 @@ jobs: host_config: llvm@19.1.1.cmake compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }} cmake_opts: "-DBUILD_SHARED_LIBS=ON" + - job_name: llvm@19.1.1, C++20, shared + host_config: llvm@19.1.1.cmake + compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }} + cmake_opts: "-DBUILD_SHARED_LIBS=ON -DBLT_CXX_STD=c++20" - job_name: gcc@14.2.0, shared host_config: gcc@14.2.0.cmake compiler_image: ${{ needs.set_image_vars.outputs.gcc_docker_image }} diff --git a/src/gretl/data_store.cpp b/src/gretl/data_store.cpp index 12288ac..bfbaaaf 100644 --- a/src/gretl/data_store.cpp +++ b/src/gretl/data_store.cpp @@ -29,9 +29,9 @@ void DataStore::back_prop() template void for_each_active_upstream(const DataStore* dataStore, size_t step, const Func& func) { - for (const auto& upstream : dataStore->upstreams_[step].states()) { - if (!dataStore->is_persistent(upstream.step_)) { - func(upstream.step_); + for (Int upstreamStep : dataStore->upstreamSteps_[step]) { + if (!dataStore->is_persistent(upstreamStep)) { + func(upstreamStep); } } for (Int upstreamStepPassingThrough : dataStore->passthroughs_[step]) { @@ -98,7 +98,7 @@ void DataStore::resize(Int newSize) std::to_string(newSize) + std::string(" ") + std::to_string(currentStep_)); states_.resize(newSize); duals_.resize(newSize); - upstreams_.resize(newSize); + upstreamSteps_.resize(newSize); evals_.resize(newSize); vjps_.resize(newSize); active_.resize(newSize); @@ -119,7 +119,7 @@ void DataStore::reset_for_backprop() void DataStore::vjp(StateBase& state) { state.evaluate_vjp(); } -bool DataStore::is_persistent(Int step) const { return !upstreams_[step].size(); } +bool DataStore::is_persistent(Int step) const { return upstreamSteps_[step].empty(); } void DataStore::reverse_state() { @@ -128,7 +128,7 @@ void DataStore::reverse_state() checkpointStrategy_->erase_step(currentStep_ - 1); } --currentStep_; - if (upstreams_[currentStep_].size()) { + if (!upstreamSteps_[currentStep_].empty()) { fetch_state_data(currentStep_ - 1); vjp(*states_[currentStep_]); clear_usage(currentStep_); @@ -195,7 +195,7 @@ void DataStore::add_state(std::unique_ptr newState, const std::vector Int upstreamStep = u.step(); upstreamSteps.push_back(upstreamStep); } - upstreams_.emplace_back(*this, upstreamSteps); + upstreamSteps_.emplace_back(std::move(upstreamSteps)); for (auto& u : upstreams) { Int upstreamStep = u.step(); @@ -225,13 +225,13 @@ void DataStore::add_state(std::unique_ptr newState, const std::vector } } - evals_.emplace_back([=](const UpstreamStates&, DownstreamState&) { - std::cout << "eval not implemented for step " << currentStep_ << std::endl; + evals_.emplace_back([step](const UpstreamStates&, DownstreamState&) { + std::cout << "eval not implemented for step " << step << std::endl; gretl_assert(false); }); - vjps_.emplace_back([=](UpstreamStates&, const DownstreamState&) { - std::cout << "vjp not implemented for step " << currentStep_ << std::endl; + vjps_.emplace_back([step](UpstreamStates&, const DownstreamState&) { + std::cout << "vjp not implemented for step " << step << std::endl; gretl_assert(false); }); @@ -241,7 +241,7 @@ void DataStore::add_state(std::unique_ptr newState, const std::vector ++currentStep_; gretl_assert(currentStep_ == states_.size()); gretl_assert(currentStep_ == duals_.size()); - gretl_assert(currentStep_ == upstreams_.size()); + gretl_assert(currentStep_ == upstreamSteps_.size()); gretl_assert(currentStep_ == passthroughs_.size()); gretl_assert(currentStep_ == active_.size()); gretl_assert(currentStep_ == usageCount_.size()); @@ -356,8 +356,8 @@ void DataStore::print_graph() const std::cout << i << ", act: " << std::setw(3) << active_[i] << ":" << std::setw(3) << usageCount_[i] << ":" << std::setw(3) << states_[i]->data_.use_count() << ":" << std::setw(3) << (states_[i]->primal() != nullptr) << ", ups: "; - for (auto& v : upstreams_[i].states()) { - std::cout << v.step_ << " "; + for (Int upstreamStep : upstreamSteps_[i]) { + std::cout << upstreamStep << " "; } std::cout << ", pass: "; for (auto& v : passthroughs_[i]) { diff --git a/src/gretl/data_store.hpp b/src/gretl/data_store.hpp index e85fede..9e9955f 100644 --- a/src/gretl/data_store.hpp +++ b/src/gretl/data_store.hpp @@ -61,9 +61,9 @@ class DataStore { explicit DataStore(std::unique_ptr strategy); /// @brief virtual destructor. Must clear states_ first because StateBase - /// destructors call try_to_free() which accesses upstreams_ and other members. + /// destructors call try_to_free() which accesses upstreamSteps_ and other members. /// Without this, implicit reverse-declaration-order destruction would destroy - /// upstreams_ before states_, causing use-after-free. + /// upstreamSteps_ before states_. /// @brief virtual destructor virtual ~DataStore() { @@ -242,7 +242,7 @@ class DataStore { std::vector> states_; ///< states for steps std::vector> duals_; ///< duals for steps - std::vector upstreams_; ///< upstreams dependencies for steps + std::vector> upstreamSteps_; ///< upstream step dependencies for steps std::vector evals_; ///< forward evaluation functions for steps std::vector vjps_; ///< vector-jacobian product functions for steps std::vector active_; ///< active status for steps diff --git a/src/gretl/state_base.cpp b/src/gretl/state_base.cpp index 1da29da..e9697f1 100644 --- a/src/gretl/state_base.cpp +++ b/src/gretl/state_base.cpp @@ -12,14 +12,16 @@ namespace gretl { void StateBase::evaluate_forward() { DownstreamState ds(&data_store(), step()); - data_store().evals_[step()](data_store().upstreams_[step()], ds); + UpstreamStates upstreams(data_store(), data_store().upstreamSteps_[step()]); + data_store().evals_[step()](upstreams, ds); data_store().erase_step_state_data(step()); } void StateBase::evaluate_vjp() { const DownstreamState ds(&data_store(), step()); - data_store().vjps_[step()](data_store().upstreams_[step()], ds); + UpstreamStates upstreams(data_store(), data_store().upstreamSteps_[step()]); + data_store().vjps_[step()](upstreams, ds); } } // namespace gretl