Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ jobs:
host_config: llvm@19.1.1.cmake
compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }}
cmake_opts: "-DBUILD_SHARED_LIBS=ON"
- job_name: llvm@19.1.1, C++20, shared
host_config: llvm@19.1.1.cmake
compiler_image: ${{ needs.set_image_vars.outputs.clang_docker_image }}
cmake_opts: "-DBUILD_SHARED_LIBS=ON -DBLT_CXX_STD=c++20"
- job_name: gcc@14.2.0, shared
host_config: gcc@14.2.0.cmake
compiler_image: ${{ needs.set_image_vars.outputs.gcc_docker_image }}
Expand Down
1 change: 0 additions & 1 deletion src/gretl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ set(gretl_headers
state_base.hpp
state.hpp
test_utils.hpp
upstream_state.hpp
vector_state.hpp)

blt_add_library(NAME gretl
Expand Down
1 change: 0 additions & 1 deletion src/gretl/create_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#pragma once

#include "data_store.hpp"
#include "upstream_state.hpp"
#include <functional>

namespace gretl {
Expand Down
6 changes: 3 additions & 3 deletions src/gretl/data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
//
// SPDX-License-Identifier: (BSD-3-Clause)

#include <any>
#include "data_store.hpp"
#include "state.hpp"
#include <any>
#include <iostream>
#include <iomanip>

Expand Down Expand Up @@ -225,12 +225,12 @@ void DataStore::add_state(std::unique_ptr<StateBase> newState, const std::vector
}
}

evals_.emplace_back([=](const UpstreamStates&, DownstreamState&) {
evals_.emplace_back([this](const UpstreamStates&, DownstreamState&) {
std::cout << "eval not implemented for step " << currentStep_ << std::endl;
gretl_assert(false);
});

vjps_.emplace_back([=](UpstreamStates&, const DownstreamState&) {
vjps_.emplace_back([this](UpstreamStates&, const DownstreamState&) {
std::cout << "vjp not implemented for step " << currentStep_ << std::endl;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems I should be handling error here better. but that is for another day.

gretl_assert(false);
});
Expand Down
122 changes: 120 additions & 2 deletions src/gretl/data_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,96 @@ namespace gretl {

using Int = unsigned int; ///< gretl Int type

class DataStore;

struct StateBase;

template <typename T, typename D = T>
struct State;

struct UpstreamStates;
/// @brief UpstreamState is a wrapper for a states. Its used in external-facing interfaces to ensure const correctness
/// for users to encourage correct usage.
struct UpstreamState {
Int step_; ///< step
DataStore* dataStore_; ///< datastore

/// @brief get underlying value
template <typename T>
const T& get() const;

/// @brief get underlying dual value
template <typename D, typename T>
D& get_dual() const;
};

/// @brief UpstreamStates is a wrapper for a vector of states. Its used in external-facing interfaces to ensure const
/// correctness for users to encourage correct usage.
struct UpstreamStates {
/// @brief Default constructor to use in std containers
UpstreamStates() = default;

/// @brief Constructor for upstream states
/// @param store datastore
/// @param steps vector of upstream steps
UpstreamStates(DataStore& store, std::vector<Int> steps)
{
for (Int s : steps) {
states_.push_back({s, &store});
}
}

/// @brief Accessor for individual upstream states
/// @param index index
template <typename IntT>
const UpstreamState& operator[](IntT index) const
{
return states_[static_cast<size_t>(index)];
}

/// @brief Accessor for individual upstream states
/// @param index index
const UpstreamState& operator[](Int index) const { return states_[index]; }

/// @brief Number of upstream states
Int size() const { return static_cast<Int>(states_.size()); }

/// @brief Vector of upstream step indices
const std::vector<UpstreamState>& states() const { return states_; }

private:
std::vector<UpstreamState> states_; ///< states
};

/// @brief DownstreamState is a wrapper for a state. Its used in external-facing interfaces to ensure const correctness
/// for users to encourage correct usage.
struct DownstreamState {
/// @brief Constructor
/// @param s datastore
/// @param step step
DownstreamState(DataStore* s, Int step) : dataStore_(s), step_(step) {}

/// @brief set underlying value (copy)
template <typename T, typename D = T>
void set(const T& t);

/// @brief set underlying value (move)
template <typename T, typename D = std::decay_t<T>>
void set(T&& t);

/// @brief get underlying value
template <typename T, typename D = T>
const T& get() const;

/// @brief get underlying dual value
template <typename D, typename T = D>
const D& get_dual() const;

struct DownstreamState;
friend class DataStore;

private:
DataStore* dataStore_; ///< datastore
Int step_; ///< step
};

/// @brief ZeroDual function type
template <typename T, typename D = T>
Expand Down Expand Up @@ -275,4 +357,40 @@ class DataStore {
friend struct DownstreamState;
};

template <typename T>
const T& UpstreamState::get() const
{
return dataStore_->get_primal<T>(step_);
}

template <typename D, typename T>
D& UpstreamState::get_dual() const
{
return dataStore_->get_dual<D, T>(step_);
}

template <typename T, typename D>
void DownstreamState::set(const T& t)
{
dataStore_->set_primal(step_, t);
}

template <typename T, typename D>
void DownstreamState::set(T&& t)
{
dataStore_->set_primal(step_, std::forward<T>(t));
}

template <typename T, typename D>
const T& DownstreamState::get() const
{
return dataStore_->get_primal<T>(step_);
}

template <typename D, typename T>
const D& DownstreamState::get_dual() const
{
return dataStore_->get_dual<D, T>(step_);
}

} // namespace gretl
2 changes: 0 additions & 2 deletions src/gretl/state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
#pragma once

#include <functional>
#include "upstream_state.hpp"
#include "state_base.hpp"
#include "upstream_state.hpp"

namespace gretl {

Expand Down
1 change: 0 additions & 1 deletion src/gretl/state_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: (BSD-3-Clause)

#include "state_base.hpp"
#include "upstream_state.hpp"

namespace gretl {

Expand Down
120 changes: 0 additions & 120 deletions src/gretl/upstream_state.hpp

This file was deleted.

Loading