diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 188fa58..80ac002 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -38,6 +38,10 @@ jobs: host_config: llvm@19.1.1.cmake compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }} cmake_opts: "-DBUILD_SHARED_LIBS=ON" + - job_name: llvm@19.1.1, C++20, shared + host_config: llvm@19.1.1.cmake + compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }} + cmake_opts: "-DBUILD_SHARED_LIBS=ON -DBLT_CXX_STD=c++20" - job_name: gcc@14.2.0, shared host_config: gcc@14.2.0.cmake compiler_image: ${{ needs.set_image_vars.outputs.gcc_docker_image }} diff --git a/src/gretl/CMakeLists.txt b/src/gretl/CMakeLists.txt index 12e9d4c..85b0be5 100644 --- a/src/gretl/CMakeLists.txt +++ b/src/gretl/CMakeLists.txt @@ -34,7 +34,6 @@ set(gretl_headers state_base.hpp state.hpp test_utils.hpp - upstream_state.hpp vector_state.hpp) blt_add_library(NAME gretl diff --git a/src/gretl/create_state.hpp b/src/gretl/create_state.hpp index a96cd8f..2aa3f0e 100644 --- a/src/gretl/create_state.hpp +++ b/src/gretl/create_state.hpp @@ -11,7 +11,6 @@ #pragma once #include "data_store.hpp" -#include "upstream_state.hpp" #include namespace gretl { diff --git a/src/gretl/data_store.cpp b/src/gretl/data_store.cpp index 12288ac..9cb4b9c 100644 --- a/src/gretl/data_store.cpp +++ b/src/gretl/data_store.cpp @@ -4,9 +4,9 @@ // // SPDX-License-Identifier: (BSD-3-Clause) -#include #include "data_store.hpp" #include "state.hpp" +#include #include #include @@ -225,12 +225,12 @@ void DataStore::add_state(std::unique_ptr newState, const std::vector } } - evals_.emplace_back([=](const UpstreamStates&, DownstreamState&) { + evals_.emplace_back([this](const UpstreamStates&, DownstreamState&) { std::cout << "eval not implemented for step " << currentStep_ << std::endl; gretl_assert(false); }); - vjps_.emplace_back([=](UpstreamStates&, const DownstreamState&) { + vjps_.emplace_back([this](UpstreamStates&, const DownstreamState&) { std::cout << "vjp not implemented for step " << currentStep_ << std::endl; gretl_assert(false); }); diff --git a/src/gretl/data_store.hpp b/src/gretl/data_store.hpp index e85fede..1acd78d 100644 --- a/src/gretl/data_store.hpp +++ b/src/gretl/data_store.hpp @@ -31,14 +31,96 @@ namespace gretl { using Int = unsigned int; ///< gretl Int type +class DataStore; + struct StateBase; template struct State; -struct UpstreamStates; +/// @brief UpstreamState is a wrapper for a states. Its used in external-facing interfaces to ensure const correctness +/// for users to encourage correct usage. +struct UpstreamState { + Int step_; ///< step + DataStore* dataStore_; ///< datastore + + /// @brief get underlying value + template + const T& get() const; + + /// @brief get underlying dual value + template + D& get_dual() const; +}; + +/// @brief UpstreamStates is a wrapper for a vector of states. Its used in external-facing interfaces to ensure const +/// correctness for users to encourage correct usage. +struct UpstreamStates { + /// @brief Default constructor to use in std containers + UpstreamStates() = default; + + /// @brief Constructor for upstream states + /// @param store datastore + /// @param steps vector of upstream steps + UpstreamStates(DataStore& store, std::vector steps) + { + for (Int s : steps) { + states_.push_back({s, &store}); + } + } + + /// @brief Accessor for individual upstream states + /// @param index index + template + const UpstreamState& operator[](IntT index) const + { + return states_[static_cast(index)]; + } + + /// @brief Accessor for individual upstream states + /// @param index index + const UpstreamState& operator[](Int index) const { return states_[index]; } + + /// @brief Number of upstream states + Int size() const { return static_cast(states_.size()); } + + /// @brief Vector of upstream step indices + const std::vector& states() const { return states_; } + + private: + std::vector states_; ///< states +}; + +/// @brief DownstreamState is a wrapper for a state. Its used in external-facing interfaces to ensure const correctness +/// for users to encourage correct usage. +struct DownstreamState { + /// @brief Constructor + /// @param s datastore + /// @param step step + DownstreamState(DataStore* s, Int step) : dataStore_(s), step_(step) {} + + /// @brief set underlying value (copy) + template + void set(const T& t); + + /// @brief set underlying value (move) + template > + void set(T&& t); + + /// @brief get underlying value + template + const T& get() const; + + /// @brief get underlying dual value + template + const D& get_dual() const; -struct DownstreamState; + friend class DataStore; + + private: + DataStore* dataStore_; ///< datastore + Int step_; ///< step +}; /// @brief ZeroDual function type template @@ -275,4 +357,40 @@ class DataStore { friend struct DownstreamState; }; +template +const T& UpstreamState::get() const +{ + return dataStore_->get_primal(step_); +} + +template +D& UpstreamState::get_dual() const +{ + return dataStore_->get_dual(step_); +} + +template +void DownstreamState::set(const T& t) +{ + dataStore_->set_primal(step_, t); +} + +template +void DownstreamState::set(T&& t) +{ + dataStore_->set_primal(step_, std::forward(t)); +} + +template +const T& DownstreamState::get() const +{ + return dataStore_->get_primal(step_); +} + +template +const D& DownstreamState::get_dual() const +{ + return dataStore_->get_dual(step_); +} + } // namespace gretl diff --git a/src/gretl/state.hpp b/src/gretl/state.hpp index c32acbd..ee78d35 100644 --- a/src/gretl/state.hpp +++ b/src/gretl/state.hpp @@ -11,9 +11,7 @@ #pragma once #include -#include "upstream_state.hpp" #include "state_base.hpp" -#include "upstream_state.hpp" namespace gretl { diff --git a/src/gretl/state_base.cpp b/src/gretl/state_base.cpp index 1da29da..b1f09e1 100644 --- a/src/gretl/state_base.cpp +++ b/src/gretl/state_base.cpp @@ -5,7 +5,6 @@ // SPDX-License-Identifier: (BSD-3-Clause) #include "state_base.hpp" -#include "upstream_state.hpp" namespace gretl { diff --git a/src/gretl/upstream_state.hpp b/src/gretl/upstream_state.hpp deleted file mode 100644 index c5235e4..0000000 --- a/src/gretl/upstream_state.hpp +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) Lawrence Livermore National Security, LLC and -// other Gretl Project Developers. See the top-level LICENSE file for -// details. -// -// SPDX-License-Identifier: (BSD-3-Clause) - -/** - * @file upstream_state.hpp - */ - -#pragma once - -#include -#include "data_store.hpp" - -namespace gretl { - -/// @brief UpstreamState is a wrapper for a states. Its used in external-facing interfaces to ensure const correctness -/// for users to encourage correct usage. -struct UpstreamState { - Int step_; ///< step - DataStore* dataStore_; ///< datastore - - /// @brief get underlying value - template - const T& get() const - { - return dataStore_->get_primal(step_); - } - - /// @brief get underlying dual value - template - D& get_dual() const - { - return dataStore_->get_dual(step_); - } -}; - -/// @brief UpstreamStates is a wrapper for a vector of states. Its used in external-facing interfaces to ensure const -/// correctness for users to encourage correct usage. -struct UpstreamStates { - /// @brief Default constructor to use in std containers - UpstreamStates() {} - - /// @brief Constructor for upstream states - /// @param store datastore - /// @param steps vector of upstream steps - UpstreamStates(DataStore& store, std::vector steps) - { - for (Int s : steps) { - states_.push_back({s, &store}); - } - } - - /// @brief Accessor for individual upstream states - /// @param index index - template - const UpstreamState& operator[](IntT index) const - { - return states_[static_cast(index)]; - } - - /// @brief Accessor for individual upstream states - /// @param index index - const UpstreamState& operator[](Int index) const { return states_[index]; } - - /// @brief Number of upstream states - Int size() const { return static_cast(states_.size()); } - - /// @brief Vector of upstream step indices - const std::vector& states() const { return states_; } - - private: - std::vector states_; ///< states -}; - -/// @brief DownstreamState is a wrapper for a state. Its used in external-facing interfaces to ensure const correctness -/// for users to encourage correct usage. -struct DownstreamState { - /// @brief Constructor - /// @param s datastore - /// @param step step - DownstreamState(DataStore* s, Int step) : dataStore_(s), step_(step) {} - - /// @brief set underlying value (copy) - template - void set(const T& t) - { - dataStore_->set_primal(step_, t); - } - - /// @brief set underlying value (move) - template > - void set(T&& t) - { - dataStore_->set_primal(step_, std::forward(t)); - } - - /// @brief get underlying value - template - const T& get() const - { - return dataStore_->get_primal(step_); - } - - /// @brief get underlying dual value - template - const D& get_dual() const - { - return dataStore_->get_dual(step_); - } - - friend class DataStore; - - private: - DataStore* dataStore_; ///< datastore - Int step_; ///< step -}; - -} // namespace gretl