Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ class SpMM25D {
/// 3-D process grid only
template<ttg::ExecutionSpace Space_>
class MultiplyAdd : public TT<Key<3>, std::tuple<Out<Key<2>, Blk>, Out<Key<3>, Blk>>, MultiplyAdd<Space_>,
ttg::typelist<const Blk, const Blk, Blk>> {
ttg::typelist<const Blk, const Blk, Blk>, Space_> {
static constexpr const bool is_device_space = (Space_ != ttg::ExecutionSpace::Host);
using task_return_type = std::conditional_t<is_device_space, ttg::device::Task, void>;

Expand Down
4 changes: 3 additions & 1 deletion ttg/ttg/madness/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

namespace ttg_madness {

template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs = ttg::typelist<>>
template <typename keyT, typename output_terminalsT, typename derivedT,
typename input_valueTs = ttg::typelist<>,
ttg::ExecutionSpace Space = ttg::ExecutionSpace::Host>
class TT;

/// \internal the OG name
Expand Down
5 changes: 3 additions & 2 deletions ttg/ttg/madness/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ namespace ttg_madness {
/// values
/// flowing into this TT; a const type indicates nonmutating (read-only) use, nonconst type
/// indicates mutating use (e.g. the corresponding input can be used as scratch, moved-from, etc.)
template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs>
class TT : public ttg::TTBase, public ::madness::WorldObject<TT<keyT, output_terminalsT, derivedT, input_valueTs>> {
template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs, ttg::ExecutionSpace Space>
class TT : public ttg::TTBase, public ::madness::WorldObject<TT<keyT, output_terminalsT, derivedT, input_valueTs, Space>> {
static_assert(Space == ttg::ExecutionSpace::Host, "MADNESS backend only supports Host Execution Space");
static_assert(ttg::meta::is_typelist_v<input_valueTs>,
"The fourth template for ttg::TT must be a ttg::typelist containing the input types");
using input_tuple_type = ttg::meta::typelist_to_tuple_t<input_valueTs>;
Expand Down
7 changes: 1 addition & 6 deletions ttg/ttg/make_tt.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CallableWrapTT
: public TT<
keyT, output_terminalsT,
CallableWrapTT<funcT, returnT, funcT_receives_input_tuple, funcT_receives_outterm_tuple, space, keyT, output_terminalsT, input_valuesT...>,
ttg::typelist<input_valuesT...>> {
ttg::typelist<input_valuesT...>, space> {
using baseT = typename CallableWrapTT::ttT;

using input_values_tuple_type = typename baseT::input_values_tuple_type;
Expand All @@ -44,11 +44,6 @@ class CallableWrapTT
void;
#endif // TTG_HAVE_COROUTINE

public:
static constexpr bool have_cuda_op = (space == ttg::ExecutionSpace::CUDA);
static constexpr bool have_hip_op = (space == ttg::ExecutionSpace::HIP);
static constexpr bool have_level_zero_op = (space == ttg::ExecutionSpace::L0);

protected:

template<typename ReturnT>
Expand Down
4 changes: 3 additions & 1 deletion ttg/ttg/parsec/fwd.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ extern "C" struct parsec_context_s;

namespace ttg_parsec {

template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs = ttg::typelist<>>
template <typename keyT, typename output_terminalsT, typename derivedT,
typename input_valueTs = ttg::typelist<>,
ttg::ExecutionSpace Space = ttg::ExecutionSpace::Host>
class TT;

/// \internal the OG name
Expand Down
12 changes: 6 additions & 6 deletions ttg/ttg/parsec/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ namespace ttg_parsec {
template<ttg::ExecutionSpace Space>
parsec_hook_return_t invoke_op() {
if constexpr (Space == ttg::ExecutionSpace::Host) {
return TT::template static_op<Space>(&this->parsec_task);
return TT::static_op(&this->parsec_task);
} else {
return TT::template device_static_op<Space>(&this->parsec_task);
return TT::device_static_op(&this->parsec_task);
}
}

Expand All @@ -263,7 +263,7 @@ namespace ttg_parsec {
if constexpr (Space == ttg::ExecutionSpace::Host) {
return PARSEC_HOOK_RETURN_DONE;
} else {
return TT::template device_static_evaluate<Space>(&this->parsec_task);
return TT::device_static_evaluate(&this->parsec_task);
}
}

Expand Down Expand Up @@ -310,9 +310,9 @@ namespace ttg_parsec {
template<ttg::ExecutionSpace Space>
parsec_hook_return_t invoke_op() {
if constexpr (Space == ttg::ExecutionSpace::Host) {
return TT::template static_op<Space>(&this->parsec_task);
return TT::static_op(&this->parsec_task);
} else {
return TT::template device_static_op<Space>(&this->parsec_task);
return TT::device_static_op(&this->parsec_task);
}
}

Expand All @@ -321,7 +321,7 @@ namespace ttg_parsec {
if constexpr (Space == ttg::ExecutionSpace::Host) {
return PARSEC_HOOK_RETURN_DONE;
} else {
return TT::template device_static_evaluate<Space>(&this->parsec_task);
return TT::device_static_evaluate(&this->parsec_task);
}
}

Expand Down
68 changes: 26 additions & 42 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,9 @@ namespace ttg_parsec {
#endif // TTG_USE_USER_TERMDET
}

template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs = ttg::typelist<>>
void register_tt_profiling(const TT<keyT, output_terminalsT, derivedT, input_valueTs> *t) {
template <typename keyT, typename output_terminalsT, typename derivedT,
typename input_valueTs = ttg::typelist<>, ttg::ExecutionSpace Space>
void register_tt_profiling(const TT<keyT, output_terminalsT, derivedT, input_valueTs, Space> *t) {
#if defined(PARSEC_PROF_TRACE)
std::stringstream ss;
build_composite_name_rec(t->ttg_ptr(), ss);
Expand Down Expand Up @@ -1181,7 +1182,7 @@ namespace ttg_parsec {

} // namespace detail

template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs>
template <typename keyT, typename output_terminalsT, typename derivedT, typename input_valueTs, ttg::ExecutionSpace Space>
class TT : public ttg::TTBase, detail::ParsecTTBase {
private:
/// preconditions
Expand Down Expand Up @@ -1218,29 +1219,17 @@ namespace ttg_parsec {
public:
/// @return true if this TT's execution space is CUDA (Space == ttg::ExecutionSpace::CUDA)
static constexpr bool derived_has_cuda_op() {
if constexpr (ttg::meta::is_detected_v<have_cuda_op_non_type_t, derivedT>) {
return derivedT::have_cuda_op;
} else {
return false;
}
return Space == ttg::ExecutionSpace::CUDA;
}

/// @return true if this TT's execution space is HIP (Space == ttg::ExecutionSpace::HIP)
static constexpr bool derived_has_hip_op() {
if constexpr (ttg::meta::is_detected_v<have_hip_op_non_type_t, derivedT>) {
return derivedT::have_hip_op;
} else {
return false;
}
return Space == ttg::ExecutionSpace::HIP;
}

/// @return true if this TT's execution space is Level Zero (Space == ttg::ExecutionSpace::L0)
static constexpr bool derived_has_level_zero_op() {
if constexpr (ttg::meta::is_detected_v<have_level_zero_op_non_type_t, derivedT>) {
return derivedT::have_level_zero_op;
} else {
return false;
}
return Space == ttg::ExecutionSpace::L0;
}

/// @return true if the TT supports device execution
Expand Down Expand Up @@ -1355,18 +1344,17 @@ namespace ttg_parsec {
/// dispatches a call to derivedT::op
/// @return void if called a synchronous function, or ttg::coroutine_handle<> if called a coroutine (if non-null,
/// points to the suspended coroutine)
template <ttg::ExecutionSpace Space, typename... Args>
template <typename... Args>
auto op(Args &&...args) {
derivedT *derived = static_cast<derivedT *>(this);
//if constexpr (Space == ttg::ExecutionSpace::Host) {
using return_type = decltype(derived->op(std::forward<Args>(args)...));
if constexpr (std::is_same_v<return_type,void>) {
derived->op(std::forward<Args>(args)...);
return;
}
else {
return derived->op(std::forward<Args>(args)...);
}
using return_type = decltype(derived->op(std::forward<Args>(args)...));
if constexpr (std::is_same_v<return_type,void>) {
derived->op(std::forward<Args>(args)...);
return;
}
else {
return derived->op(std::forward<Args>(args)...);
}
}

template <std::size_t i, typename terminalT, typename Key>
Expand Down Expand Up @@ -1419,7 +1407,6 @@ namespace ttg_parsec {
/**
* Submit callback called by PaRSEC once all input transfers have completed.
*/
template <ttg::ExecutionSpace Space>
static int device_static_submit(parsec_device_gpu_module_t *gpu_device,
parsec_gpu_task_t *gpu_task,
parsec_gpu_exec_stream_t *gpu_stream) {
Expand Down Expand Up @@ -1461,7 +1448,7 @@ namespace ttg_parsec {
#endif // defined(PARSEC_HAVE_DEV_LEVEL_ZERO_SUPPORT) && defined(TTG_HAVE_LEVEL_ZERO)

/* Here we call back into the coroutine again after the transfers have completed */
static_op<Space>(&task->parsec_task);
static_op(&task->parsec_task);

ttg::device::detail::reset_current();

Expand Down Expand Up @@ -1503,7 +1490,6 @@ namespace ttg_parsec {
return rc;
}

template <ttg::ExecutionSpace Space>
static parsec_hook_return_t device_static_evaluate(parsec_task_t* parsec_task) {

task_t *task = (task_t*)parsec_task;
Expand All @@ -1518,7 +1504,7 @@ namespace ttg_parsec {
gpu_task->task_type = 0; // user task
gpu_task->last_data_check_epoch = std::numeric_limits<uint64_t>::max(); // used internally
gpu_task->pushout = 0;
gpu_task->submit = &TT::device_static_submit<Space>;
gpu_task->submit = &TT::device_static_submit;

// one way to force the task device
// currently this will probably break all of PaRSEC if this hint
Expand All @@ -1536,7 +1522,7 @@ namespace ttg_parsec {
task->dev_ptr->task_class = *task->parsec_task.task_class;

// first invocation of the coroutine to get the coroutine handle
static_op<Space>(parsec_task);
static_op(parsec_task);

/* when we come back here, the flows in gpu_task are set (see register_device_memory) */

Expand Down Expand Up @@ -1586,7 +1572,6 @@ namespace ttg_parsec {

}

template <ttg::ExecutionSpace Space>
static parsec_hook_return_t device_static_op(parsec_task_t* parsec_task) {
static_assert(derived_has_device_op());

Expand Down Expand Up @@ -1658,7 +1643,6 @@ namespace ttg_parsec {
}
#endif // TTG_HAVE_DEVICE

template <ttg::ExecutionSpace Space>
static parsec_hook_return_t static_op(parsec_task_t *parsec_task) {

task_t *task = (task_t*)parsec_task;
Expand All @@ -1684,14 +1668,14 @@ namespace ttg_parsec {

if constexpr (!ttg::meta::is_void_v<keyT> && !ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
auto input = make_tuple_of_ref_from_array(task, std::make_index_sequence<numinvals>{});
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(task->key, std::move(input), obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(task->key, std::move(input), obj->output_terminals));
} else if constexpr (!ttg::meta::is_void_v<keyT> && ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(task->key, obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(task->key, obj->output_terminals));
} else if constexpr (ttg::meta::is_void_v<keyT> && !ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
auto input = make_tuple_of_ref_from_array(task, std::make_index_sequence<numinvals>{});
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(std::move(input), obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(std::move(input), obj->output_terminals));
} else if constexpr (ttg::meta::is_void_v<keyT> && ttg::meta::is_empty_tuple_v<input_values_tuple_type>) {
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(obj->output_terminals));
} else {
ttg::abort();
}
Expand Down Expand Up @@ -1767,7 +1751,6 @@ namespace ttg_parsec {
return PARSEC_HOOK_RETURN_DONE;
}

template <ttg::ExecutionSpace Space>
static parsec_hook_return_t static_op_noarg(parsec_task_t *parsec_task) {
task_t *task = static_cast<task_t*>(parsec_task);

Expand All @@ -1783,9 +1766,9 @@ namespace ttg_parsec {
assert(detail::parsec_ttg_caller == NULL);
detail::parsec_ttg_caller = task;
if constexpr (!ttg::meta::is_void_v<keyT>) {
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(task->key, obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(task->key, obj->output_terminals));
} else if constexpr (ttg::meta::is_void_v<keyT>) {
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->template op<Space>(obj->output_terminals));
TTG_PROCESS_TT_OP_RETURN(suspended_task_address, task->coroutine_id, baseobj->op(obj->output_terminals));
} else // unreachable
ttg:: abort();
detail::parsec_ttg_caller = NULL;
Expand Down Expand Up @@ -4385,6 +4368,7 @@ namespace ttg_parsec {
return ttg::device::Device(dm(key), ttg::ExecutionSpace::L0);
} else {
throw std::runtime_error("Unknown device type!");
return ttg::device::Device{};
}
};
}
Expand Down