diff --git a/include/ghex/threads/atomic/mutex.hpp b/include/ghex/threads/atomic/mutex.hpp new file mode 100644 index 0000000..4eee98d --- /dev/null +++ b/include/ghex/threads/atomic/mutex.hpp @@ -0,0 +1,62 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_THREADS_MUTEX_HPP +#define INCLUDED_GHEX_THREADS_MUTEX_HPP + +#include +#include + +namespace gridtools { + namespace ghex { + namespace threads { + namespace atomic { + + class atomic_mutex + { + private: // members + std::atomic m_flag; + public: + atomic_mutex() noexcept : m_flag(0) {} + atomic_mutex(const atomic_mutex&) = delete; + atomic_mutex(atomic_mutex&&) = delete; + + inline bool try_lock() noexcept + { + bool expected = false; + return m_flag.compare_exchange_weak(expected, true, std::memory_order_relaxed); + } + + inline bool try_unlock() noexcept + { + bool expected = true; + return m_flag.compare_exchange_weak(expected, false, std::memory_order_relaxed); + } + + inline void lock() noexcept + { + while (!try_lock()) {} + } + + inline void unlock() noexcept + { + while (!try_unlock()) {} + } + }; + + template + using lock_guard = std::lock_guard; + } // namespace atomic + } // namespace threads + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_GHEX_THREADS_MUTEX_HPP */ + diff --git a/include/ghex/threads/atomic/primitives.hpp b/include/ghex/threads/atomic/primitives.hpp new file mode 100644 index 0000000..abda84b --- /dev/null +++ b/include/ghex/threads/atomic/primitives.hpp @@ -0,0 +1,193 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_THREADS_ATOMIC_PRIMITIVES_HPP +#define INCLUDED_GHEX_THREADS_ATOMIC_PRIMITIVES_HPP + +#include +#include +#include "./mutex.hpp" + +namespace gridtools { + namespace ghex { + namespace threads { + namespace atomic { + + template + using void_return_type = typename std::enable_if< + std::is_same,void>::value, + void>::type; + + template + using return_type = typename std::enable_if< + !std::is_same,void>::value, + boost::callable_traits::return_type_t>::type; + +#ifndef GHEX_THREAD_SINGLE + struct primitives + { + public: // member types + using id_type = int; + + class token + { + private: // members + id_type m_id; + int m_epoch = 0; + bool m_selected = false; + + friend primitives; + + token(id_type id, int epoch) noexcept + : m_id(id), m_epoch(epoch), m_selected(id==0?true:false) + {} + + public: // ctors + token(const token&) = delete; + token(token&&) = default; + + public: // member functions + id_type id() const noexcept { return m_id; } + }; + + using mutex_type = atomic_mutex; + using lock_type = lock_guard; + + private: // members + const int m_num_threads; + std::atomic m_ids; + mutable volatile int m_epoch; + mutable std::atomic b_count; + mutable mutex_type m_mutex; + + public: // ctors + primitives(int num_threads) noexcept + : m_num_threads(num_threads) + , m_ids(0) + , m_epoch(0) + , b_count(0) + {} + + primitives(const primitives&) = delete; + primitives(primitives&&) = delete; + + public: // public member functions + inline token get_token() noexcept + { + return {(int)m_ids++,0}; + } + + inline void barrier(token& t) const noexcept + { + int expected = b_count; + while (!b_count.compare_exchange_weak(expected, expected+1, std::memory_order_relaxed)) + expected = b_count; + t.m_epoch ^= 1; + t.m_selected = (expected?false:true); + if (expected == m_num_threads-1) + { + b_count.store(0); + m_epoch ^= 1; + } + while(t.m_epoch != m_epoch) {} + } + + template + inline void single(token& t, F && f) const noexcept + { + if (t.m_selected) { + f(); + } + } + + template + inline void master(token& t, F && f) const noexcept + { + if (t.m_id == 0) { + f(); + } + } + + template + inline void_return_type critical(F && f) const noexcept + { + lock_type l(m_mutex); + f(); + } + + template + inline return_type critical(F && f) const noexcept + { + lock_type l(m_mutex); + return f(); + } + }; +#else + struct primitives + { + public: // member types + using id_type = int; + + class token + { + private: // members + id_type m_id; + + friend primitives; + + token(id_type id) noexcept + : m_id(id) + {} + + public: // ctors + token(const token&) = delete; + token(token&&) = default; + + public: // member functions + id_type id() const noexcept { return m_id; } + }; + + using mutex_type = atomic_mutex; + using lock_type = lock_guard; + + private: // members + + public: // ctors + primitives(int=1) noexcept + {} + + primitives(const primitives&) = delete; + primitives(primitives&&) = delete; + + public: // public member functions + inline token get_token() noexcept { return {0}; } + + inline void barrier(token& t) noexcept {} + + template + inline void single(token& t, F && f) const noexcept { f(); } + + template + inline void master(token& t, F && f) const noexcept { f(); } + + template + inline void_return_type critical(F && f) const noexcept { f(); } + + template + inline return_type critical(F && f) const noexcept { return f(); } + }; +#endif + } // namespace atomic + } // namespace threads + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_GHEX_THREADS_ATOMIC_PRIMITIVES_HPP */ + diff --git a/include/ghex/transport_layer/config.hpp b/include/ghex/transport_layer/config.hpp new file mode 100644 index 0000000..8e304a7 --- /dev/null +++ b/include/ghex/transport_layer/config.hpp @@ -0,0 +1,17 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_TL_CONFIG_HPP +#define INCLUDED_TL_CONFIG_HPP + + + +#endif /* INCLUDED_TL_CONFIG_HPP */ + diff --git a/include/ghex/transport_layer/context.hpp b/include/ghex/transport_layer/context.hpp new file mode 100644 index 0000000..4d01e1a --- /dev/null +++ b/include/ghex/transport_layer/context.hpp @@ -0,0 +1,171 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_TL_CONTEXT_HPP +#define INCLUDED_GHEX_TL_CONTEXT_HPP + +#include +#include +//#include "../common/moved_bit.hpp" +#include "./config.hpp" + +namespace gridtools { + namespace ghex { + namespace tl { + + template + class transport_context; + + class mpi_world + { + private: + MPI_Comm m_comm; + int m_rank; + int m_size; + bool m_owning = false; + + mpi_world(int& argc, char**& argv) + { +#if defined(GHEX_THREAD_SINGLE) + MPI_Init(&argc, &argv); +#elif defined(GHEX_MPI_USE_GHEX_LOCKS) + int provided; + int res = MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided); + if (res == MPI_ERR_OTHER) + throw std::runtime_error("MPI init failed"); + if (provided < MPI_THREAD_SERIALIZED) + throw std::runtime_error("MPI does not support required threading level"); +#else + int provided; + int res = MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); + if (res == MPI_ERR_OTHER) + throw std::runtime_error("MPI init failed"); + if (provided < MPI_THREAD_MULTIPLE) + throw std::runtime_error("MPI does not support required threading level"); +#endif + m_comm = MPI_COMM_WORLD; + MPI_Comm_rank(m_comm, &m_rank); + MPI_Comm_size(m_comm, &m_size); + m_owning = true; + } + + mpi_world(MPI_Comm comm) + : m_comm(comm) + , m_rank{ [comm]() { int r; MPI_Comm_rank(comm, &r); return r; }() } + , m_size{ [comm]() { int r; MPI_Comm_size(comm, &r); return r; }() } + {} + + mpi_world(const mpi_world&) = delete; + mpi_world(mpi_world&&) = delete; + + // forward declaration + template + friend class parallel_context; + + public: + ~mpi_world() + { + if (m_owning) MPI_Finalize(); + } + + public: + inline int rank() const noexcept { return m_rank; } + inline int size() const noexcept { return m_size; } + operator MPI_Comm() const noexcept { return m_comm; } + }; + + template// = ::gridtools::ghex::threads::atomic::primitives> + class parallel_context + { + public: // members + using thread_primitives_type = ThreadPrimitives; + using thread_token = typename thread_primitives_type::token; + + private: + mpi_world m_world; + thread_primitives_type m_thread_primitives; + + private: + template + parallel_context(int num_threads, int& argc, char**& argv, Args&&...) noexcept + : m_world(argc,argv) + , m_thread_primitives(num_threads) + {} + + template + parallel_context(int num_threads, MPI_Comm comm, Args&&...) noexcept + : m_world(comm) + , m_thread_primitives(num_threads) + {} + + parallel_context(const parallel_context&) = delete; + parallel_context(parallel_context&&) = delete; + + // forward declaration + template + friend class context; + + public: + // thread-safe + const mpi_world& world() const { return m_world; } + // thread-safe + const thread_primitives_type& thread_primitives() const { return m_thread_primitives; } + // thread-safe + void barrier(thread_token& t) const + { + m_thread_primitives.barrier(t); + m_thread_primitives.single(t, [this]() { MPI_Barrier(m_world.m_comm); } ); + m_thread_primitives.barrier(t); + } + }; + + template + class context + { + public: // member types + using tag = TransportTag; + using transport_context_type = transport_context; + using communicator_type = typename transport_context_type::communicator_type; + using parallel_context_type = parallel_context; + using thread_token = typename parallel_context_type::thread_token; + + private: + std::unique_ptr m_parallel_context; + transport_context_type m_transport_context; + + public: + template + context(int num_threads, Args&&... args) + : m_parallel_context{new parallel_context_type{num_threads, std::forward(args)...}} + , m_transport_context{*m_parallel_context, std::forward(args)...} + {} + + public: + // thread-safe + communicator_type get_communicator(const thread_token& t) + { + return m_parallel_context->m_thread_primitives.critical( + [this,&t]() mutable { return m_transport_context.get_communicator(t.id()); } + ); + } + + // thread-safe + thread_token get_token() noexcept + { + return m_parallel_context->m_thread_primitives.get_token(); + } + }; + + } // namespace tl + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_CONTEXT_HPP */ + diff --git a/include/ghex/transport_layer/mpi/context.hpp b/include/ghex/transport_layer/mpi/context.hpp new file mode 100644 index 0000000..f511e06 --- /dev/null +++ b/include/ghex/transport_layer/mpi/context.hpp @@ -0,0 +1,48 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_TL_MPI_CONTEXT_HPP +#define INCLUDED_TL_MPI_CONTEXT_HPP + +#include "../context.hpp" +#include "./communicator.hpp" + +namespace gridtools { + namespace ghex { + namespace tl { + + template + struct transport_context + { + using communicator_type = communicator; + + parallel_context& m_parallel_context; + + template + transport_context(parallel_context& pc, Args&&...) + : m_parallel_context(pc) + {} + + communicator_type get_communicator(int) const + { + return {(MPI_Comm)(m_parallel_context.world())}; + } + + }; + + //using mpi_context = context; + + } + } +} + +#endif /* INCLUDED_TL_MPI_CONTEXT_HPP */ + + diff --git a/tests/transport/test_low_level.cpp b/tests/transport/test_low_level.cpp index 5ddda43..95e5a06 100644 --- a/tests/transport/test_low_level.cpp +++ b/tests/transport/test_low_level.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include #include #include @@ -16,7 +18,13 @@ int rank; */ void test1() { - gridtools::ghex::tl::communicator sr; + //gridtools::ghex::tl::communicator sr; + + gridtools::ghex::tl::context context(1,MPI_COMM_WORLD); + + auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); + auto sr = context.get_communicator(token); std::vector smsg = {1,2,3,4,5,6,7,8,9,10}; std::vector rmsg(10);