diff --git a/examples/spmm/spmm.cc b/examples/spmm/spmm.cc index f29ae9088e..e4778ab3ed 100644 --- a/examples/spmm/spmm.cc +++ b/examples/spmm/spmm.cc @@ -575,7 +575,7 @@ class SpMM25D { /// 3-D process grid only template class MultiplyAdd : public TT, std::tuple, Blk>, Out, Blk>>, MultiplyAdd, - ttg::typelist> { + ttg::typelist, Space_> { static constexpr const bool is_device_space = (Space_ != ttg::ExecutionSpace::Host); using task_return_type = std::conditional_t; diff --git a/ttg/ttg/madness/fwd.h b/ttg/ttg/madness/fwd.h index b164a6f50b..600e4f16af 100644 --- a/ttg/ttg/madness/fwd.h +++ b/ttg/ttg/madness/fwd.h @@ -9,7 +9,9 @@ namespace ttg_madness { - template > + template , + ttg::ExecutionSpace Space = ttg::ExecutionSpace::Host> class TT; /// \internal the OG name diff --git a/ttg/ttg/madness/ttg.h b/ttg/ttg/madness/ttg.h index 69ceafb49b..7997939c47 100644 --- a/ttg/ttg/madness/ttg.h +++ b/ttg/ttg/madness/ttg.h @@ -190,8 +190,9 @@ namespace ttg_madness { /// values /// flowing into this TT; a const type indicates nonmutating (read-only) use, nonconst type /// indicates mutating use (e.g. the corresponding input can be used as scratch, moved-from, etc.) - template - class TT : public ttg::TTBase, public ::madness::WorldObject> { + template + class TT : public ttg::TTBase, public ::madness::WorldObject> { + static_assert(Space == ttg::ExecutionSpace::Host, "MADNESS backend only supports Host Execution Space"); static_assert(ttg::meta::is_typelist_v, "The fourth template for ttg::TT must be a ttg::typelist containing the input types"); using input_tuple_type = ttg::meta::typelist_to_tuple_t; diff --git a/ttg/ttg/make_tt.h b/ttg/ttg/make_tt.h index f46cbd7c8c..3cabea508a 100644 --- a/ttg/ttg/make_tt.h +++ b/ttg/ttg/make_tt.h @@ -17,7 +17,7 @@ class CallableWrapTT : public TT< keyT, output_terminalsT, CallableWrapTT, - ttg::typelist> { + ttg::typelist, space> { using baseT = typename CallableWrapTT::ttT; using input_values_tuple_type = typename baseT::input_values_tuple_type; @@ -44,11 +44,6 @@ class CallableWrapTT void; #endif // TTG_HAVE_COROUTINE -public: - static constexpr bool have_cuda_op = (space == ttg::ExecutionSpace::CUDA); - static constexpr bool have_hip_op = (space == ttg::ExecutionSpace::HIP); - static constexpr bool have_level_zero_op = (space == ttg::ExecutionSpace::L0); - protected: template diff --git a/ttg/ttg/parsec/fwd.h b/ttg/ttg/parsec/fwd.h index c7e0b5551f..6ac6618671 100644 --- a/ttg/ttg/parsec/fwd.h +++ b/ttg/ttg/parsec/fwd.h @@ -13,7 +13,9 @@ extern "C" struct parsec_context_s; namespace ttg_parsec { - template > + template , + ttg::ExecutionSpace Space = ttg::ExecutionSpace::Host> class TT; /// \internal the OG name diff --git a/ttg/ttg/parsec/task.h b/ttg/ttg/parsec/task.h index f29ca8ecb5..65712dec1c 100644 --- a/ttg/ttg/parsec/task.h +++ b/ttg/ttg/parsec/task.h @@ -252,9 +252,9 @@ namespace ttg_parsec { template parsec_hook_return_t invoke_op() { if constexpr (Space == ttg::ExecutionSpace::Host) { - return TT::template static_op(&this->parsec_task); + return TT::static_op(&this->parsec_task); } else { - return TT::template device_static_op(&this->parsec_task); + return TT::device_static_op(&this->parsec_task); } } @@ -263,7 +263,7 @@ namespace ttg_parsec { if constexpr (Space == ttg::ExecutionSpace::Host) { return PARSEC_HOOK_RETURN_DONE; } else { - return TT::template device_static_evaluate(&this->parsec_task); + return TT::device_static_evaluate(&this->parsec_task); } } @@ -310,9 +310,9 @@ namespace ttg_parsec { template parsec_hook_return_t invoke_op() { if constexpr (Space == ttg::ExecutionSpace::Host) { - return TT::template static_op(&this->parsec_task); + return TT::static_op(&this->parsec_task); } else { - return TT::template device_static_op(&this->parsec_task); + return TT::device_static_op(&this->parsec_task); } } @@ -321,7 +321,7 @@ namespace ttg_parsec { if constexpr (Space == ttg::ExecutionSpace::Host) { return PARSEC_HOOK_RETURN_DONE; } else { - return TT::template device_static_evaluate(&this->parsec_task); + return TT::device_static_evaluate(&this->parsec_task); } } diff --git a/ttg/ttg/parsec/ttg.h b/ttg/ttg/parsec/ttg.h index ae04194e0f..c21df816f3 100644 --- a/ttg/ttg/parsec/ttg.h +++ b/ttg/ttg/parsec/ttg.h @@ -512,8 +512,9 @@ namespace ttg_parsec { #endif // TTG_USE_USER_TERMDET } - template > - void register_tt_profiling(const TT *t) { + template , ttg::ExecutionSpace Space> + void register_tt_profiling(const TT *t) { #if defined(PARSEC_PROF_TRACE) std::stringstream ss; build_composite_name_rec(t->ttg_ptr(), ss); @@ -1181,7 +1182,7 @@ namespace ttg_parsec { } // namespace detail - template + template class TT : public ttg::TTBase, detail::ParsecTTBase { private: /// preconditions @@ -1218,29 +1219,17 @@ namespace ttg_parsec { public: /// @return true if derivedT::have_cuda_op exists and is defined to true static constexpr bool derived_has_cuda_op() { - if constexpr (ttg::meta::is_detected_v) { - return derivedT::have_cuda_op; - } else { - return false; - } + return Space == ttg::ExecutionSpace::CUDA; } /// @return true if derivedT::have_hip_op exists and is defined to true static constexpr bool derived_has_hip_op() { - if constexpr (ttg::meta::is_detected_v) { - return derivedT::have_hip_op; - } else { - return false; - } + return Space == ttg::ExecutionSpace::HIP; } /// @return true if derivedT::have_hip_op exists and is defined to true static constexpr bool derived_has_level_zero_op() { - if constexpr (ttg::meta::is_detected_v) { - return derivedT::have_level_zero_op; - } else { - return false; - } + return Space == ttg::ExecutionSpace::L0; } /// @return true if the TT supports device execution @@ -1355,18 +1344,17 @@ namespace ttg_parsec { /// dispatches a call to derivedT::op /// @return void if called a synchronous function, or ttg::coroutine_handle<> if called a coroutine (if non-null, /// points to the suspended coroutine) - template + template auto op(Args &&...args) { derivedT *derived = static_cast(this); - //if constexpr (Space == ttg::ExecutionSpace::Host) { - using return_type = decltype(derived->op(std::forward(args)...)); - if constexpr (std::is_same_v) { - derived->op(std::forward(args)...); - return; - } - else { - return derived->op(std::forward(args)...); - } + using return_type = decltype(derived->op(std::forward(args)...)); + if constexpr (std::is_same_v) { + derived->op(std::forward(args)...); + return; + } + else { + return derived->op(std::forward(args)...); + } } template @@ -1419,7 +1407,6 @@ namespace ttg_parsec { /** * Submit callback called by PaRSEC once all input transfers have completed. */ - template static int device_static_submit(parsec_device_gpu_module_t *gpu_device, parsec_gpu_task_t *gpu_task, parsec_gpu_exec_stream_t *gpu_stream) { @@ -1461,7 +1448,7 @@ namespace ttg_parsec { #endif // defined(PARSEC_HAVE_DEV_LEVEL_ZERO_SUPPORT) && defined(TTG_HAVE_LEVEL_ZERO) /* Here we call back into the coroutine again after the transfers have completed */ - static_op(&task->parsec_task); + static_op(&task->parsec_task); ttg::device::detail::reset_current(); @@ -1503,7 +1490,6 @@ namespace ttg_parsec { return rc; } - template static parsec_hook_return_t device_static_evaluate(parsec_task_t* parsec_task) { task_t *task = (task_t*)parsec_task; @@ -1518,7 +1504,7 @@ namespace ttg_parsec { gpu_task->task_type = 0; // user task gpu_task->last_data_check_epoch = std::numeric_limits::max(); // used internally gpu_task->pushout = 0; - gpu_task->submit = &TT::device_static_submit; + gpu_task->submit = &TT::device_static_submit; // one way to force the task device // currently this will probably break all of PaRSEC if this hint @@ -1536,7 +1522,7 @@ namespace ttg_parsec { task->dev_ptr->task_class = *task->parsec_task.task_class; // first invocation of the coroutine to get the coroutine handle - static_op(parsec_task); + static_op(parsec_task); /* when we come back here, the flows in gpu_task are set (see register_device_memory) */ @@ -1586,7 +1572,6 @@ namespace ttg_parsec { } - template static parsec_hook_return_t device_static_op(parsec_task_t* parsec_task) { static_assert(derived_has_device_op()); @@ -1658,7 +1643,6 @@ namespace ttg_parsec { } #endif // TTG_HAVE_DEVICE - template static parsec_hook_return_t static_op(parsec_task_t *parsec_task) { task_t *task = (task_t*)parsec_task; @@ -1684,14 +1668,14 @@ namespace ttg_parsec { if constexpr (!ttg::meta::is_void_v && !ttg::meta::is_empty_tuple_v) { auto input = make_tuple_of_ref_from_array(task, std::make_index_sequence{}); - TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op(task->key, std::move(input), obj->output_terminals)); + TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(task->key, std::move(input), obj->output_terminals)); } else if constexpr (!ttg::meta::is_void_v && ttg::meta::is_empty_tuple_v) { - TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op(task->key, obj->output_terminals)); + TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(task->key, obj->output_terminals)); } else if constexpr (ttg::meta::is_void_v && !ttg::meta::is_empty_tuple_v) { auto input = make_tuple_of_ref_from_array(task, std::make_index_sequence{}); - TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op(std::move(input), obj->output_terminals)); + TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(std::move(input), obj->output_terminals)); } else if constexpr (ttg::meta::is_void_v && ttg::meta::is_empty_tuple_v) { - TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op(obj->output_terminals)); + TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(obj->output_terminals)); } else { ttg::abort(); } @@ -1767,7 +1751,6 @@ namespace ttg_parsec { return PARSEC_HOOK_RETURN_DONE; } - template static parsec_hook_return_t static_op_noarg(parsec_task_t *parsec_task) { task_t *task = static_cast(parsec_task); @@ -1783,9 +1766,9 @@ namespace ttg_parsec { assert(detail::parsec_ttg_caller == NULL); detail::parsec_ttg_caller = task; if constexpr (!ttg::meta::is_void_v) { - TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op(task->key, obj->output_terminals)); + TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(task->key, obj->output_terminals)); } else if constexpr (ttg::meta::is_void_v) { - TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op(obj->output_terminals)); + TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(obj->output_terminals)); } else // unreachable ttg:: abort(); detail::parsec_ttg_caller = NULL; @@ -4385,6 +4368,7 @@ namespace ttg_parsec { return ttg::device::Device(dm(key), ttg::ExecutionSpace::L0); } else { throw std::runtime_error("Unknown device type!"); + return ttg::device::Device{}; } }; }