From fe2d7ad0a6f1d66670aef5cb2a799628e3bc5511 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20B=C3=B6sch?= <48126478+boeschf@users.noreply.github.com> Date: Wed, 4 Dec 2019 09:48:36 +0100 Subject: [PATCH 1/3] contexts: parallel, transport and overall --- include/ghex/threads/context.hpp | 197 +++++++++++++++++++ include/ghex/threads/mutex.hpp | 61 ++++++ include/ghex/transport_layer/config.hpp | 17 ++ include/ghex/transport_layer/context.hpp | 171 ++++++++++++++++ include/ghex/transport_layer/mpi/context.hpp | 46 +++++ tests/transport/test_low_level.cpp | 8 +- 6 files changed, 499 insertions(+), 1 deletion(-) create mode 100644 include/ghex/threads/context.hpp create mode 100644 include/ghex/threads/mutex.hpp create mode 100644 include/ghex/transport_layer/config.hpp create mode 100644 include/ghex/transport_layer/context.hpp create mode 100644 include/ghex/transport_layer/mpi/context.hpp diff --git a/include/ghex/threads/context.hpp b/include/ghex/threads/context.hpp new file mode 100644 index 0000000..d066b57 --- /dev/null +++ b/include/ghex/threads/context.hpp @@ -0,0 +1,197 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_THREADS_CONTEXT_CONTEXT_HPP +#define INCLUDED_GHEX_THREADS_CONTEXT_CONTEXT_HPP + +#include +#include +#include "./mutex.hpp" + +namespace gridtools { + namespace ghex { + namespace threads { + + template + using void_return_type = typename std::enable_if< + std::is_same,void>::value, + void>::type; + + template + using return_type = typename std::enable_if< + !std::is_same,void>::value, + boost::callable_traits::return_type_t>::type; + +#ifndef GHEX_THREAD_SINGLE + struct context + { + public: // member types + using id_type = int; + + class token + { + private: // members + id_type m_id; + int m_epoch = 0; + bool m_selected = false; + + friend context; + + token(id_type id, int epoch) noexcept + : m_id(id), m_epoch(epoch), m_selected(id==0?true:false) + {} + + public: // ctors + token(const token&) = delete; + token(token&&) = default; + + public: // member functions + id_type id() const noexcept { return m_id; } + }; + + using mutex_type = atomic_mutex; + using lock_type = lock_guard; + + private: // members + const int m_num_threads; + std::atomic m_ids; + mutable volatile int m_epoch; + mutable std::atomic b_count; + mutable mutex_type m_mutex; + + public: // ctors + context(int num_threads) noexcept + : m_num_threads(num_threads) + , m_ids(0) + , m_epoch(0) + , b_count(0) + {} + + context(const context&) = delete; + context(context&&) = delete; + + public: // public member functions + inline token get_token() noexcept + { + return {(int)m_ids++,0}; + } + + inline void barrier(token& t) const noexcept + { + int expected = b_count; + while (!b_count.compare_exchange_weak(expected, expected+1, std::memory_order_relaxed)) + expected = b_count; + t.m_epoch ^= 1; + t.m_selected = (expected?false:true); + if (expected == m_num_threads-1) + { + b_count.store(0); + m_epoch ^= 1; + } + while(t.m_epoch != m_epoch) {} + } + + template + inline void single(token& t, F && f) const noexcept + { + if (t.m_selected) { + f(); + } + } + + template + inline void master(token& t, F && f) const noexcept + { + if (t.m_id == 0) { + f(); + } + } + + template + /*inline + typename std::enable_if< + std::is_same,void>::value, + void>::type*/ + inline void_return_type critical(F && f) const noexcept + { + lock_type l(m_mutex); + f(); + } + + template + inline + typename std::enable_if< + !std::is_same,void>::value, + boost::callable_traits::return_type_t>::type + critical(F && f) const noexcept + { + lock_type l(m_mutex); + return f(); + } + }; +#else + struct context + { + public: // member types + using id_type = int; + + class token + { + private: // members + id_type m_id; + + friend context; + + token(id_type id) noexcept + : m_id(id) + {} + + public: // ctors + token(const token&) = delete; + token(token&&) = default; + + public: // member functions + id_type id() const noexcept { return m_id; } + }; + + using mutex_type = atomic_mutex; + using lock_type = lock_guard; + + private: // members + + public: // ctors + context(int=1) noexcept + {} + + context(const context&) = delete; + context(context&&) = delete; + + public: // public member functions + inline token get_token() noexcept { return {0}; } + + inline void barrier(token& t) noexcept {} + + template + inline void single(token& t, F && f) const noexcept { f(); } + + template + inline void master(token& t, F && f) const noexcept { f(); } + + template + inline void critical(F && f) const noexcept { f(); } + }; +#endif + + } // namespace threads + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_GHEX_THREADS_CONTEXT_CONTEXT_HPP */ + diff --git a/include/ghex/threads/mutex.hpp b/include/ghex/threads/mutex.hpp new file mode 100644 index 0000000..2d9dec8 --- /dev/null +++ b/include/ghex/threads/mutex.hpp @@ -0,0 +1,61 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_THREADS_MUTEX_HPP +#define INCLUDED_GHEX_THREADS_MUTEX_HPP + +#include +#include + +namespace gridtools { + namespace ghex { + namespace threads { + + class atomic_mutex + { + private: // members + std::atomic m_flag; + public: + atomic_mutex() noexcept : m_flag(0) {} + atomic_mutex(const atomic_mutex&) = delete; + atomic_mutex(atomic_mutex&&) = delete; + + inline bool try_lock() noexcept + { + bool expected = false; + return m_flag.compare_exchange_weak(expected, true, std::memory_order_relaxed); + } + + inline bool try_unlock() noexcept + { + bool expected = true; + return m_flag.compare_exchange_weak(expected, false, std::memory_order_relaxed); + } + + inline void lock() noexcept + { + while (!try_lock()) {} + } + + inline void unlock() noexcept + { + while (!try_unlock()) {} + } + }; + + template + using lock_guard = std::lock_guard; + + } // namespace threads + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_GHEX_THREADS_MUTEX_HPP */ + diff --git a/include/ghex/transport_layer/config.hpp b/include/ghex/transport_layer/config.hpp new file mode 100644 index 0000000..8e304a7 --- /dev/null +++ b/include/ghex/transport_layer/config.hpp @@ -0,0 +1,17 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_TL_CONFIG_HPP +#define INCLUDED_TL_CONFIG_HPP + + + +#endif /* INCLUDED_TL_CONFIG_HPP */ + diff --git a/include/ghex/transport_layer/context.hpp b/include/ghex/transport_layer/context.hpp new file mode 100644 index 0000000..deef272 --- /dev/null +++ b/include/ghex/transport_layer/context.hpp @@ -0,0 +1,171 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_TL_CONTEXT_HPP +#define INCLUDED_GHEX_TL_CONTEXT_HPP + +#include +#include +#include "../threads/context.hpp" +//#include "../common/moved_bit.hpp" +#include "./config.hpp" + +namespace gridtools { + namespace ghex { + namespace tl { + + template + class transport_context; + + template + class context; + + class parallel_context; + + class mpi_world + { + private: + MPI_Comm m_comm; + int m_rank; + int m_size; + bool m_owning = false; + + mpi_world(int& argc, char**& argv) + { +#if defined(GHEX_THREAD_SINGLE) + MPI_Init(&argc, &argv); +#elif defined(GHEX_MPI_USE_GHEX_LOCKS) + int provided; + MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided); +#else + int provided; + MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); +#endif + m_comm = MPI_COMM_WORLD; + MPI_Comm_rank(m_comm, &m_rank); + MPI_Comm_size(m_comm, &m_size); + m_owning = true; + } + + mpi_world(MPI_Comm comm) + : m_comm(comm) + , m_rank{ [comm]() { int r; MPI_Comm_rank(comm, &r); return r; }() } + , m_size{ [comm]() { int r; MPI_Comm_size(comm, &r); return r; }() } + {} + + mpi_world(const mpi_world&) = delete; + mpi_world(mpi_world&&) = delete; + + friend class parallel_context; + + public: + ~mpi_world() + { + if (m_owning) MPI_Finalize(); + } + + public: + inline int rank() const noexcept { return m_rank; } + inline int size() const noexcept { return m_size; } + operator MPI_Comm() const noexcept { return m_comm; } + }; + + class parallel_context + { + public: // members + using thread_context_type = ::gridtools::ghex::threads::context; + using thread_token = typename thread_context_type::token; + + private: + mpi_world m_world; + public: + thread_context_type m_thread_context; + + public: + template + parallel_context(int num_threads, int& argc, char**& argv, Args&&...) noexcept + : m_world(argc,argv) + , m_thread_context(num_threads) + {} + + template + parallel_context(int num_threads, MPI_Comm comm, Args&&...) noexcept + : m_world(comm) + , m_thread_context(num_threads) + {} + + parallel_context(const parallel_context&) = delete; + parallel_context(parallel_context&&) = delete; + + //template + //friend class context; + + public: + // thread-safe + const mpi_world& world() const { return m_world; } + // thread-safe + const thread_context_type& thread_context() const { return m_thread_context; } + // thread-safe + void barrier(thread_token& t) const + { + m_thread_context.barrier(t); + m_thread_context.single(t, [this]() { MPI_Barrier(m_world.m_comm); } ); + m_thread_context.barrier(t); + } + }; + + template + class context + { + public: // member types + using tag = TransportTag; + using transport_context_type = transport_context; + using communicator_type = typename transport_context_type::communicator_type; + using thread_token = parallel_context::thread_token; + + private: + std::unique_ptr m_pc; + transport_context_type m_tc; + + public: + template + context(int num_threads, Args&&... args) + : m_pc{std::make_unique(num_threads, std::forward(args)...)} + , m_tc{*m_pc, std::forward(args)...} + {} + + /*template + context(int& argc, char**& argv, MPI_Comm comm, int num_threads=1, M, Args&&... args) + : m_pc{std::make_unique(comm, num_threads)} + , m_tc{*m_pc, argc, argv, std::forward(args)...} + {}*/ + + public: + // thread-safe + communicator_type get_communicator(const thread_token& t) + { + return m_pc->m_thread_context.critical( + [this,&t]() mutable { return m_tc.get_communicator(t.id()); } + ); + } + + // thread-safe + thread_token get_token() noexcept + { + return m_pc->m_thread_context.get_token(); + } + }; + + } // namespace tl + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_CONTEXT_HPP */ + diff --git a/include/ghex/transport_layer/mpi/context.hpp b/include/ghex/transport_layer/mpi/context.hpp new file mode 100644 index 0000000..520f8d9 --- /dev/null +++ b/include/ghex/transport_layer/mpi/context.hpp @@ -0,0 +1,46 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_TL_MPI_CONTEXT_HPP +#define INCLUDED_TL_MPI_CONTEXT_HPP + +#include "../context.hpp" +#include "./communicator.hpp" + +namespace gridtools { + namespace ghex { + namespace tl { + + template<> + struct transport_context + { + using communicator_type = communicator; + + parallel_context& m_pc; + + template + transport_context(parallel_context& pc, Args&&...) + : m_pc(pc) + {} + + communicator_type get_communicator(int) const + { + return {(MPI_Comm)(m_pc.world())}; + } + + }; + + } + } +} + +#endif /* INCLUDED_TL_MPI_CONTEXT_HPP */ + + diff --git a/tests/transport/test_low_level.cpp b/tests/transport/test_low_level.cpp index 5ddda43..291e1de 100644 --- a/tests/transport/test_low_level.cpp +++ b/tests/transport/test_low_level.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -16,7 +17,12 @@ int rank; */ void test1() { - gridtools::ghex::tl::communicator sr; + //gridtools::ghex::tl::communicator sr; + + gridtools::ghex::tl::context context(1,MPI_COMM_WORLD); + + auto token = context.get_token(); + auto sr = context.get_communicator(token); std::vector smsg = {1,2,3,4,5,6,7,8,9,10}; std::vector rmsg(10); From 886ca3487bdd18813bfcf5c9dc74e6d44d6ceef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20B=C3=B6sch?= <48126478+boeschf@users.noreply.github.com> Date: Wed, 4 Dec 2019 11:27:56 +0100 Subject: [PATCH 2/3] friend classes --- include/ghex/threads/context.hpp | 10 +---- include/ghex/transport_layer/context.hpp | 46 ++++++++++---------- include/ghex/transport_layer/mpi/context.hpp | 2 + 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/include/ghex/threads/context.hpp b/include/ghex/threads/context.hpp index d066b57..15cad73 100644 --- a/include/ghex/threads/context.hpp +++ b/include/ghex/threads/context.hpp @@ -115,10 +115,6 @@ namespace gridtools { } template - /*inline - typename std::enable_if< - std::is_same,void>::value, - void>::type*/ inline void_return_type critical(F && f) const noexcept { lock_type l(m_mutex); @@ -126,11 +122,7 @@ namespace gridtools { } template - inline - typename std::enable_if< - !std::is_same,void>::value, - boost::callable_traits::return_type_t>::type - critical(F && f) const noexcept + inline return_type critical(F && f) const noexcept { lock_type l(m_mutex); return f(); diff --git a/include/ghex/transport_layer/context.hpp b/include/ghex/transport_layer/context.hpp index deef272..a61d70c 100644 --- a/include/ghex/transport_layer/context.hpp +++ b/include/ghex/transport_layer/context.hpp @@ -24,9 +24,6 @@ namespace gridtools { template class transport_context; - template - class context; - class parallel_context; class mpi_world @@ -43,10 +40,18 @@ namespace gridtools { MPI_Init(&argc, &argv); #elif defined(GHEX_MPI_USE_GHEX_LOCKS) int provided; - MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided); + int res = MPI_Init_thread(&argc, &argv, MPI_THREAD_SERIALIZED, &provided); + if (res == MPI_ERR_OTHER) + throw std::runtime_error("MPI init failed"); + if (provided < MPI_THREAD_SERIALIZED) + throw std::runtime_error("MPI does not support required threading level"); #else int provided; - MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); + int res = MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided); + if (res == MPI_ERR_OTHER) + throw std::runtime_error("MPI init failed"); + if (provided < MPI_THREAD_MULTIPLE) + throw std::runtime_error("MPI does not support required threading level"); #endif m_comm = MPI_COMM_WORLD; MPI_Comm_rank(m_comm, &m_rank); @@ -77,6 +82,7 @@ namespace gridtools { operator MPI_Comm() const noexcept { return m_comm; } }; + class parallel_context { public: // members @@ -85,10 +91,9 @@ namespace gridtools { private: mpi_world m_world; - public: thread_context_type m_thread_context; - public: + private: template parallel_context(int num_threads, int& argc, char**& argv, Args&&...) noexcept : m_world(argc,argv) @@ -104,8 +109,9 @@ namespace gridtools { parallel_context(const parallel_context&) = delete; parallel_context(parallel_context&&) = delete; - //template - //friend class context; + // forward declaration + template + friend class context; public: // thread-safe @@ -131,37 +137,31 @@ namespace gridtools { using thread_token = parallel_context::thread_token; private: - std::unique_ptr m_pc; - transport_context_type m_tc; + std::unique_ptr m_parallel_context; + transport_context_type m_transport_context; public: template context(int num_threads, Args&&... args) - : m_pc{std::make_unique(num_threads, std::forward(args)...)} - , m_tc{*m_pc, std::forward(args)...} + : m_parallel_context{new parallel_context{num_threads, std::forward(args)...}} + , m_transport_context{*m_parallel_context, std::forward(args)...} {} - /*template - context(int& argc, char**& argv, MPI_Comm comm, int num_threads=1, M, Args&&... args) - : m_pc{std::make_unique(comm, num_threads)} - , m_tc{*m_pc, argc, argv, std::forward(args)...} - {}*/ - public: // thread-safe communicator_type get_communicator(const thread_token& t) { - return m_pc->m_thread_context.critical( - [this,&t]() mutable { return m_tc.get_communicator(t.id()); } + return m_parallel_context->m_thread_context.critical( + [this,&t]() mutable { return m_transport_context.get_communicator(t.id()); } ); } // thread-safe thread_token get_token() noexcept { - return m_pc->m_thread_context.get_token(); + return m_parallel_context->m_thread_context.get_token(); } - }; + }; } // namespace tl } // namespace ghex diff --git a/include/ghex/transport_layer/mpi/context.hpp b/include/ghex/transport_layer/mpi/context.hpp index 520f8d9..1700542 100644 --- a/include/ghex/transport_layer/mpi/context.hpp +++ b/include/ghex/transport_layer/mpi/context.hpp @@ -37,6 +37,8 @@ namespace gridtools { }; + using mpi_context = context; + } } } From cd38b12ee9d33b3c04130b5f6c349648fccfbff0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20B=C3=B6sch?= <48126478+boeschf@users.noreply.github.com> Date: Wed, 4 Dec 2019 14:46:14 +0100 Subject: [PATCH 3/3] renamed to primitives --- include/ghex/threads/atomic/mutex.hpp | 62 ++++++ include/ghex/threads/atomic/primitives.hpp | 193 +++++++++++++++++++ include/ghex/threads/context.hpp | 189 ------------------ include/ghex/threads/mutex.hpp | 61 ------ include/ghex/transport_layer/context.hpp | 48 ++--- include/ghex/transport_layer/mpi/context.hpp | 14 +- tests/transport/test_low_level.cpp | 4 +- 7 files changed, 289 insertions(+), 282 deletions(-) create mode 100644 include/ghex/threads/atomic/mutex.hpp create mode 100644 include/ghex/threads/atomic/primitives.hpp delete mode 100644 include/ghex/threads/context.hpp delete mode 100644 include/ghex/threads/mutex.hpp diff --git a/include/ghex/threads/atomic/mutex.hpp b/include/ghex/threads/atomic/mutex.hpp new file mode 100644 index 0000000..4eee98d --- /dev/null +++ b/include/ghex/threads/atomic/mutex.hpp @@ -0,0 +1,62 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_THREADS_MUTEX_HPP +#define INCLUDED_GHEX_THREADS_MUTEX_HPP + +#include +#include + +namespace gridtools { + namespace ghex { + namespace threads { + namespace atomic { + + class atomic_mutex + { + private: // members + std::atomic m_flag; + public: + atomic_mutex() noexcept : m_flag(0) {} + atomic_mutex(const atomic_mutex&) = delete; + atomic_mutex(atomic_mutex&&) = delete; + + inline bool try_lock() noexcept + { + bool expected = false; + return m_flag.compare_exchange_weak(expected, true, std::memory_order_relaxed); + } + + inline bool try_unlock() noexcept + { + bool expected = true; + return m_flag.compare_exchange_weak(expected, false, std::memory_order_relaxed); + } + + inline void lock() noexcept + { + while (!try_lock()) {} + } + + inline void unlock() noexcept + { + while (!try_unlock()) {} + } + }; + + template + using lock_guard = std::lock_guard; + } // namespace atomic + } // namespace threads + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_GHEX_THREADS_MUTEX_HPP */ + diff --git a/include/ghex/threads/atomic/primitives.hpp b/include/ghex/threads/atomic/primitives.hpp new file mode 100644 index 0000000..abda84b --- /dev/null +++ b/include/ghex/threads/atomic/primitives.hpp @@ -0,0 +1,193 @@ +/* + * GridTools + * + * Copyright (c) 2014-2019, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + * + */ +#ifndef INCLUDED_GHEX_THREADS_ATOMIC_PRIMITIVES_HPP +#define INCLUDED_GHEX_THREADS_ATOMIC_PRIMITIVES_HPP + +#include +#include +#include "./mutex.hpp" + +namespace gridtools { + namespace ghex { + namespace threads { + namespace atomic { + + template + using void_return_type = typename std::enable_if< + std::is_same,void>::value, + void>::type; + + template + using return_type = typename std::enable_if< + !std::is_same,void>::value, + boost::callable_traits::return_type_t>::type; + +#ifndef GHEX_THREAD_SINGLE + struct primitives + { + public: // member types + using id_type = int; + + class token + { + private: // members + id_type m_id; + int m_epoch = 0; + bool m_selected = false; + + friend primitives; + + token(id_type id, int epoch) noexcept + : m_id(id), m_epoch(epoch), m_selected(id==0?true:false) + {} + + public: // ctors + token(const token&) = delete; + token(token&&) = default; + + public: // member functions + id_type id() const noexcept { return m_id; } + }; + + using mutex_type = atomic_mutex; + using lock_type = lock_guard; + + private: // members + const int m_num_threads; + std::atomic m_ids; + mutable volatile int m_epoch; + mutable std::atomic b_count; + mutable mutex_type m_mutex; + + public: // ctors + primitives(int num_threads) noexcept + : m_num_threads(num_threads) + , m_ids(0) + , m_epoch(0) + , b_count(0) + {} + + primitives(const primitives&) = delete; + primitives(primitives&&) = delete; + + public: // public member functions + inline token get_token() noexcept + { + return {(int)m_ids++,0}; + } + + inline void barrier(token& t) const noexcept + { + int expected = b_count; + while (!b_count.compare_exchange_weak(expected, expected+1, std::memory_order_relaxed)) + expected = b_count; + t.m_epoch ^= 1; + t.m_selected = (expected?false:true); + if (expected == m_num_threads-1) + { + b_count.store(0); + m_epoch ^= 1; + } + while(t.m_epoch != m_epoch) {} + } + + template + inline void single(token& t, F && f) const noexcept + { + if (t.m_selected) { + f(); + } + } + + template + inline void master(token& t, F && f) const noexcept + { + if (t.m_id == 0) { + f(); + } + } + + template + inline void_return_type critical(F && f) const noexcept + { + lock_type l(m_mutex); + f(); + } + + template + inline return_type critical(F && f) const noexcept + { + lock_type l(m_mutex); + return f(); + } + }; +#else + struct primitives + { + public: // member types + using id_type = int; + + class token + { + private: // members + id_type m_id; + + friend primitives; + + token(id_type id) noexcept + : m_id(id) + {} + + public: // ctors + token(const token&) = delete; + token(token&&) = default; + + public: // member functions + id_type id() const noexcept { return m_id; } + }; + + using mutex_type = atomic_mutex; + using lock_type = lock_guard; + + private: // members + + public: // ctors + primitives(int=1) noexcept + {} + + primitives(const primitives&) = delete; + primitives(primitives&&) = delete; + + public: // public member functions + inline token get_token() noexcept { return {0}; } + + inline void barrier(token& t) noexcept {} + + template + inline void single(token& t, F && f) const noexcept { f(); } + + template + inline void master(token& t, F && f) const noexcept { f(); } + + template + inline void_return_type critical(F && f) const noexcept { f(); } + + template + inline return_type critical(F && f) const noexcept { return f(); } + }; +#endif + } // namespace atomic + } // namespace threads + } // namespace ghex +} // namespace gridtools + +#endif /* INCLUDED_GHEX_THREADS_ATOMIC_PRIMITIVES_HPP */ + diff --git a/include/ghex/threads/context.hpp b/include/ghex/threads/context.hpp deleted file mode 100644 index 15cad73..0000000 --- a/include/ghex/threads/context.hpp +++ /dev/null @@ -1,189 +0,0 @@ -/* - * GridTools - * - * Copyright (c) 2014-2019, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - * - */ -#ifndef INCLUDED_GHEX_THREADS_CONTEXT_CONTEXT_HPP -#define INCLUDED_GHEX_THREADS_CONTEXT_CONTEXT_HPP - -#include -#include -#include "./mutex.hpp" - -namespace gridtools { - namespace ghex { - namespace threads { - - template - using void_return_type = typename std::enable_if< - std::is_same,void>::value, - void>::type; - - template - using return_type = typename std::enable_if< - !std::is_same,void>::value, - boost::callable_traits::return_type_t>::type; - -#ifndef GHEX_THREAD_SINGLE - struct context - { - public: // member types - using id_type = int; - - class token - { - private: // members - id_type m_id; - int m_epoch = 0; - bool m_selected = false; - - friend context; - - token(id_type id, int epoch) noexcept - : m_id(id), m_epoch(epoch), m_selected(id==0?true:false) - {} - - public: // ctors - token(const token&) = delete; - token(token&&) = default; - - public: // member functions - id_type id() const noexcept { return m_id; } - }; - - using mutex_type = atomic_mutex; - using lock_type = lock_guard; - - private: // members - const int m_num_threads; - std::atomic m_ids; - mutable volatile int m_epoch; - mutable std::atomic b_count; - mutable mutex_type m_mutex; - - public: // ctors - context(int num_threads) noexcept - : m_num_threads(num_threads) - , m_ids(0) - , m_epoch(0) - , b_count(0) - {} - - context(const context&) = delete; - context(context&&) = delete; - - public: // public member functions - inline token get_token() noexcept - { - return {(int)m_ids++,0}; - } - - inline void barrier(token& t) const noexcept - { - int expected = b_count; - while (!b_count.compare_exchange_weak(expected, expected+1, std::memory_order_relaxed)) - expected = b_count; - t.m_epoch ^= 1; - t.m_selected = (expected?false:true); - if (expected == m_num_threads-1) - { - b_count.store(0); - m_epoch ^= 1; - } - while(t.m_epoch != m_epoch) {} - } - - template - inline void single(token& t, F && f) const noexcept - { - if (t.m_selected) { - f(); - } - } - - template - inline void master(token& t, F && f) const noexcept - { - if (t.m_id == 0) { - f(); - } - } - - template - inline void_return_type critical(F && f) const noexcept - { - lock_type l(m_mutex); - f(); - } - - template - inline return_type critical(F && f) const noexcept - { - lock_type l(m_mutex); - return f(); - } - }; -#else - struct context - { - public: // member types - using id_type = int; - - class token - { - private: // members - id_type m_id; - - friend context; - - token(id_type id) noexcept - : m_id(id) - {} - - public: // ctors - token(const token&) = delete; - token(token&&) = default; - - public: // member functions - id_type id() const noexcept { return m_id; } - }; - - using mutex_type = atomic_mutex; - using lock_type = lock_guard; - - private: // members - - public: // ctors - context(int=1) noexcept - {} - - context(const context&) = delete; - context(context&&) = delete; - - public: // public member functions - inline token get_token() noexcept { return {0}; } - - inline void barrier(token& t) noexcept {} - - template - inline void single(token& t, F && f) const noexcept { f(); } - - template - inline void master(token& t, F && f) const noexcept { f(); } - - template - inline void critical(F && f) const noexcept { f(); } - }; -#endif - - } // namespace threads - } // namespace ghex -} // namespace gridtools - -#endif /* INCLUDED_GHEX_THREADS_CONTEXT_CONTEXT_HPP */ - diff --git a/include/ghex/threads/mutex.hpp b/include/ghex/threads/mutex.hpp deleted file mode 100644 index 2d9dec8..0000000 --- a/include/ghex/threads/mutex.hpp +++ /dev/null @@ -1,61 +0,0 @@ -/* - * GridTools - * - * Copyright (c) 2014-2019, ETH Zurich - * All rights reserved. - * - * Please, refer to the LICENSE file in the root directory. - * SPDX-License-Identifier: BSD-3-Clause - * - */ -#ifndef INCLUDED_GHEX_THREADS_MUTEX_HPP -#define INCLUDED_GHEX_THREADS_MUTEX_HPP - -#include -#include - -namespace gridtools { - namespace ghex { - namespace threads { - - class atomic_mutex - { - private: // members - std::atomic m_flag; - public: - atomic_mutex() noexcept : m_flag(0) {} - atomic_mutex(const atomic_mutex&) = delete; - atomic_mutex(atomic_mutex&&) = delete; - - inline bool try_lock() noexcept - { - bool expected = false; - return m_flag.compare_exchange_weak(expected, true, std::memory_order_relaxed); - } - - inline bool try_unlock() noexcept - { - bool expected = true; - return m_flag.compare_exchange_weak(expected, false, std::memory_order_relaxed); - } - - inline void lock() noexcept - { - while (!try_lock()) {} - } - - inline void unlock() noexcept - { - while (!try_unlock()) {} - } - }; - - template - using lock_guard = std::lock_guard; - - } // namespace threads - } // namespace ghex -} // namespace gridtools - -#endif /* INCLUDED_GHEX_THREADS_MUTEX_HPP */ - diff --git a/include/ghex/transport_layer/context.hpp b/include/ghex/transport_layer/context.hpp index a61d70c..4d01e1a 100644 --- a/include/ghex/transport_layer/context.hpp +++ b/include/ghex/transport_layer/context.hpp @@ -13,7 +13,6 @@ #include #include -#include "../threads/context.hpp" //#include "../common/moved_bit.hpp" #include "./config.hpp" @@ -21,11 +20,9 @@ namespace gridtools { namespace ghex { namespace tl { - template + template class transport_context; - class parallel_context; - class mpi_world { private: @@ -68,6 +65,8 @@ namespace gridtools { mpi_world(const mpi_world&) = delete; mpi_world(mpi_world&&) = delete; + // forward declaration + template friend class parallel_context; public: @@ -82,68 +81,69 @@ namespace gridtools { operator MPI_Comm() const noexcept { return m_comm; } }; - + template// = ::gridtools::ghex::threads::atomic::primitives> class parallel_context { public: // members - using thread_context_type = ::gridtools::ghex::threads::context; - using thread_token = typename thread_context_type::token; + using thread_primitives_type = ThreadPrimitives; + using thread_token = typename thread_primitives_type::token; private: - mpi_world m_world; - thread_context_type m_thread_context; + mpi_world m_world; + thread_primitives_type m_thread_primitives; private: template parallel_context(int num_threads, int& argc, char**& argv, Args&&...) noexcept : m_world(argc,argv) - , m_thread_context(num_threads) + , m_thread_primitives(num_threads) {} template parallel_context(int num_threads, MPI_Comm comm, Args&&...) noexcept : m_world(comm) - , m_thread_context(num_threads) + , m_thread_primitives(num_threads) {} parallel_context(const parallel_context&) = delete; parallel_context(parallel_context&&) = delete; // forward declaration - template + template friend class context; public: // thread-safe const mpi_world& world() const { return m_world; } // thread-safe - const thread_context_type& thread_context() const { return m_thread_context; } + const thread_primitives_type& thread_primitives() const { return m_thread_primitives; } // thread-safe void barrier(thread_token& t) const { - m_thread_context.barrier(t); - m_thread_context.single(t, [this]() { MPI_Barrier(m_world.m_comm); } ); - m_thread_context.barrier(t); + m_thread_primitives.barrier(t); + m_thread_primitives.single(t, [this]() { MPI_Barrier(m_world.m_comm); } ); + m_thread_primitives.barrier(t); } }; - template + template class context { public: // member types using tag = TransportTag; - using transport_context_type = transport_context; + using transport_context_type = transport_context; using communicator_type = typename transport_context_type::communicator_type; - using thread_token = parallel_context::thread_token; + using parallel_context_type = parallel_context; + using thread_token = typename parallel_context_type::thread_token; private: - std::unique_ptr m_parallel_context; - transport_context_type m_transport_context; + std::unique_ptr m_parallel_context; + transport_context_type m_transport_context; public: template context(int num_threads, Args&&... args) - : m_parallel_context{new parallel_context{num_threads, std::forward(args)...}} + : m_parallel_context{new parallel_context_type{num_threads, std::forward(args)...}} , m_transport_context{*m_parallel_context, std::forward(args)...} {} @@ -151,7 +151,7 @@ namespace gridtools { // thread-safe communicator_type get_communicator(const thread_token& t) { - return m_parallel_context->m_thread_context.critical( + return m_parallel_context->m_thread_primitives.critical( [this,&t]() mutable { return m_transport_context.get_communicator(t.id()); } ); } @@ -159,7 +159,7 @@ namespace gridtools { // thread-safe thread_token get_token() noexcept { - return m_parallel_context->m_thread_context.get_token(); + return m_parallel_context->m_thread_primitives.get_token(); } }; diff --git a/include/ghex/transport_layer/mpi/context.hpp b/include/ghex/transport_layer/mpi/context.hpp index 1700542..f511e06 100644 --- a/include/ghex/transport_layer/mpi/context.hpp +++ b/include/ghex/transport_layer/mpi/context.hpp @@ -18,26 +18,26 @@ namespace gridtools { namespace ghex { namespace tl { - template<> - struct transport_context + template + struct transport_context { using communicator_type = communicator; - parallel_context& m_pc; + parallel_context& m_parallel_context; template - transport_context(parallel_context& pc, Args&&...) - : m_pc(pc) + transport_context(parallel_context& pc, Args&&...) + : m_parallel_context(pc) {} communicator_type get_communicator(int) const { - return {(MPI_Comm)(m_pc.world())}; + return {(MPI_Comm)(m_parallel_context.world())}; } }; - using mpi_context = context; + //using mpi_context = context; } } diff --git a/tests/transport/test_low_level.cpp b/tests/transport/test_low_level.cpp index 291e1de..95e5a06 100644 --- a/tests/transport/test_low_level.cpp +++ b/tests/transport/test_low_level.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -19,9 +20,10 @@ int rank; void test1() { //gridtools::ghex::tl::communicator sr; - gridtools::ghex::tl::context context(1,MPI_COMM_WORLD); + gridtools::ghex::tl::context context(1,MPI_COMM_WORLD); auto token = context.get_token(); + EXPECT_TRUE(token.id() == 0); auto sr = context.get_communicator(token); std::vector smsg = {1,2,3,4,5,6,7,8,9,10};