diff --git a/.cscs-ci/container/build.Containerfile b/.cscs-ci/container/build.Containerfile index 3128b739..3932328f 100644 --- a/.cscs-ci/container/build.Containerfile +++ b/.cscs-ci/container/build.Containerfile @@ -16,6 +16,7 @@ RUN spack -e ci build-env ghex -- \ -DGHEX_USE_BUNDLED_GTEST=ON \ -DGHEX_USE_BUNDLED_OOMPH=OFF \ -DGHEX_USE_BUNDLED_GRIDTOOLS=OFF \ + -DCMAKE_CUDA_ARCHITECTURES=90 \ -DGHEX_USE_GPU=ON \ -DGHEX_GPU_TYPE=NVIDIA \ -DGHEX_BUILD_PYTHON_BINDINGS=ON \ diff --git a/.cscs-ci/default.yaml b/.cscs-ci/default.yaml index 029f2e73..2899dc58 100644 --- a/.cscs-ci/default.yaml +++ b/.cscs-ci/default.yaml @@ -2,9 +2,9 @@ include: - remote: 'https://gitlab.com/cscs-ci/recipes/-/raw/master/templates/v2/.ci-ext.yml' variables: - BASE_IMAGE: jfrog.svc.cscs.ch/docker-group-csstaff/alps-images/ngc-pytorch:26.01-py3-alps4-dev + BASE_IMAGE: jfrog.svc.cscs.ch/docker-group-csstaff/alps-images/ngc-pytorch:26.01-py3-alps3 SPACK_SHA: v1.1.1 - SPACK_PACKAGES_SHA: a010c65289743f900bdbbfb840e4d1876c24e93f # develop on 2025-05-08 + SPACK_PACKAGES_SHA: 8ea120fe82c02737dddef32451edf88929f266ff # https://github.com/msimberg/spack-packages/tree/oomph-nccl FF_TIMESTAMPS: true .build_deps_template: @@ -25,6 +25,12 @@ variables: reports: dotenv: base-${BACKEND}.env +build_deps_nccl: + variables: + BACKEND: nccl + extends: + - .build_deps_template + build_deps_mpi: extends: .build_deps_template variables: @@ -53,6 +59,14 @@ build_deps_libfabric: reports: dotenv: build-${BACKEND}.env +build_nccl: + extends: .build_template + variables: + BACKEND: nccl + needs: + - job: build_deps_nccl + artifacts: true + build_mpi: extends: .build_template variables: @@ -81,7 +95,7 @@ build_libfabric: extends: .container-runner-clariden-gh200 variables: SLURM_GPUS_PER_TASK: 1 - SLURM_TIMELIMIT: '5:00' + SLURM_TIMELIMIT: '15:00' SLURM_PARTITION: normal SLURM_MPI_TYPE: pmix SLURM_NETWORK: disable_rdzv_get @@ -90,13 +104,15 @@ build_libfabric: PMIX_MCA_psec: native PMIX_MCA_gds: "^shmem2" USE_MPI: NO + NCCL_DEBUG: trace + NCCL_PXN_DISABLE: 1 .test_serial_template: extends: .test_template_base variables: SLURM_NTASKS: 1 script: - - spack -e ci build-env ghex -- ctest --test-dir /ghex/build -L "serial" --output-on-failure --timeout 60 --parallel 8 + - spack -e ci build-env ghex -- ctest --test-dir /ghex/build -L "serial" --verbose --output-on-failure --timeout 600 --parallel 8 .test_parallel_template: extends: .test_template_base @@ -105,12 +121,18 @@ build_libfabric: # writing inside the container. - if [[ "${SLURM_LOCALID}" == 0 ]]; then rm -rf /ghex/build/Testing; mkdir /tmp/Testing; ln -s /tmp/Testing /ghex/build/Testing; fi - until [[ -L /ghex/build/Testing ]]; do sleep 1; done - - spack -e ci build-env ghex -- ctest --test-dir /ghex/build -L "parallel-ranks-${SLURM_NTASKS}" --output-on-failure --timeout 60 + - spack -e ci build-env ghex -- ctest --test-dir /ghex/build -L "parallel-ranks-${SLURM_NTASKS}" --verbose --output-on-failure --timeout 600 .test_parallel_job: extends: .test_parallel_template image: $BUILD_IMAGE +.test_parallel_nccl: + extends: .test_parallel_job + needs: + - job: build_nccl + artifacts: true + .test_parallel_mpi: extends: .test_parallel_job needs: @@ -129,6 +151,28 @@ build_libfabric: # - job: build_libfabric # artifacts: true +test_serial_nccl: + extends: .test_serial_template + needs: + - job: build_nccl + artifacts: true + image: $BUILD_IMAGE + +test_parallel_2_nccl: + extends: .test_parallel_nccl + variables: + SLURM_NTASKS: 2 + +test_parallel_4_nccl: + extends: .test_parallel_nccl + variables: + SLURM_NTASKS: 4 + +test_parallel_6_nccl: + extends: .test_parallel_nccl + variables: + SLURM_NTASKS: 6 + test_serial_mpi: extends: .test_serial_template needs: diff --git a/.cscs-ci/spack/libfabric.yaml b/.cscs-ci/spack/libfabric.yaml index d2e71c17..adbe173e 100644 --- a/.cscs-ci/spack/libfabric.yaml +++ b/.cscs-ci/spack/libfabric.yaml @@ -4,3 +4,8 @@ spack: view: false concretizer: unify: true + packages: + oomph: + require: '@git.365fab9af9538694668fc8516750bfb0e96b6be2=main' + package_attributes: + git: "https://github.com/msimberg/oomph.git" diff --git a/.cscs-ci/spack/mpi.yaml b/.cscs-ci/spack/mpi.yaml index 552c95e5..7960e676 100644 --- a/.cscs-ci/spack/mpi.yaml +++ b/.cscs-ci/spack/mpi.yaml @@ -4,3 +4,8 @@ spack: view: false concretizer: unify: true + packages: + oomph: + require: '@git.365fab9af9538694668fc8516750bfb0e96b6be2=main' + package_attributes: + git: "https://github.com/msimberg/oomph.git" diff --git a/.cscs-ci/spack/nccl.yaml b/.cscs-ci/spack/nccl.yaml new file mode 100644 index 00000000..c6fb8206 --- /dev/null +++ b/.cscs-ci/spack/nccl.yaml @@ -0,0 +1,11 @@ +spack: + specs: + - ghex@master backend=nccl +cuda cuda_arch=90a +python + view: false + concretizer: + unify: true + packages: + oomph: + require: '@git.3d7f65888e10ba1689257184005b12e53c907738=main' + package_attributes: + git: "https://github.com/msimberg/oomph.git" diff --git a/.cscs-ci/spack/ucx.yaml b/.cscs-ci/spack/ucx.yaml index dc9220f4..b7bebf8d 100644 --- a/.cscs-ci/spack/ucx.yaml +++ b/.cscs-ci/spack/ucx.yaml @@ -4,3 +4,8 @@ spack: view: false concretizer: unify: true + packages: + oomph: + require: '@git.365fab9af9538694668fc8516750bfb0e96b6be2=main' + package_attributes: + git: "https://github.com/msimberg/oomph.git" diff --git a/.gitignore b/.gitignore index 68ca0053..3d4dbcec 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,4 @@ doc_src/_build/ __pycache__ _build .venv*/ +_nix_build/ diff --git a/CMakeLists.txt b/CMakeLists.txt index f38bc54c..e5fdedce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -172,6 +172,8 @@ if(GHEX_USE_BUNDLED_OOMPH) set_target_properties(oomph_libfabric PROPERTIES INSTALL_RPATH "${rpath_origin}") elseif (GHEX_TRANSPORT_BACKEND STREQUAL "UCX") set_target_properties(oomph_ucx PROPERTIES INSTALL_RPATH "${rpath_origin}") + elseif (GHEX_TRANSPORT_BACKEND STREQUAL "NCCL") + set_target_properties(oomph_nccl PROPERTIES INSTALL_RPATH "${rpath_origin}") else() set_target_properties(oomph_mpi PROPERTIES INSTALL_RPATH "${rpath_origin}") endif() diff --git a/README.md b/README.md index c9e8e8cd..017fe6fa 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ make test | `GHEX_PYTHON_LIB_PATH=` | `` | `${CMAKE_INSTALL_PREFIX}/` | Installation directory for GHEX's Python package | `GHEX_WITH_TESTING=` | `{ON, OFF}` | `OFF` | Build unit tests | `GHEX_USE_XPMEM=` | `{ON, OFF}` | `OFF` | Use Xpmem -| `GHEX_TRANSPORT_BACKEND=` | `{MPI, UCX, LIBFABRIC}` | `MPI` | Choose transport backend +| `GHEX_TRANSPORT_BACKEND=` | `{MPI, UCX, LIBFABRIC, NCCL}` | `MPI` | Choose transport backend ### Pip Install @@ -75,7 +75,7 @@ python -m pip install ghex | `GHEX_USE_GPU=` | `{ON, OFF}` | `OFF` | Enable GPU | `GHEX_GPU_TYPE=` | `{AUTO, NVIDIA, AMD}` | `AUTO` | Choose GPU type | `GHEX_GPU_ARCH=` | list of archs | `"60;70;75;80"`/ `"gfx900;gfx906"` | GPU architecture -| `GHEX_TRANSPORT_BACKEND=` | `{MPI, UCX, LIBFABRIC}` | `MPI` | Choose transport backend +| `GHEX_TRANSPORT_BACKEND=` | `{MPI, UCX, LIBFABRIC, NCCL}` | `MPI` | Choose transport backend ## Contributing diff --git a/cmake/ghex_external_dependencies.cmake b/cmake/ghex_external_dependencies.cmake index 32c40fe4..3f1ed57e 100644 --- a/cmake/ghex_external_dependencies.cmake +++ b/cmake/ghex_external_dependencies.cmake @@ -43,8 +43,8 @@ endif() # --------------------------------------------------------------------- # oomph setup # --------------------------------------------------------------------- -set(GHEX_TRANSPORT_BACKEND "MPI" CACHE STRING "Choose the backend type: MPI | UCX | LIBFABRIC") -set_property(CACHE GHEX_TRANSPORT_BACKEND PROPERTY STRINGS "MPI" "UCX" "LIBFABRIC") +set(GHEX_TRANSPORT_BACKEND "MPI" CACHE STRING "Choose the backend type: MPI | UCX | LIBFABRIC | NCCL") +set_property(CACHE GHEX_TRANSPORT_BACKEND PROPERTY STRINGS "MPI" "UCX" "LIBFABRIC" "NCCL") cmake_dependent_option(GHEX_USE_BUNDLED_OOMPH "Use bundled oomph." ON "GHEX_USE_BUNDLED_LIBS" OFF) if(GHEX_USE_BUNDLED_OOMPH) set(OOMPH_GIT_SUBMODULE OFF CACHE BOOL "") @@ -53,6 +53,11 @@ if(GHEX_USE_BUNDLED_OOMPH) set(OOMPH_WITH_LIBFABRIC ON CACHE BOOL "Build with LIBFABRIC backend") elseif(GHEX_TRANSPORT_BACKEND STREQUAL "UCX") set(OOMPH_WITH_UCX ON CACHE BOOL "Build with UCX backend") + elseif(GHEX_TRANSPORT_BACKEND STREQUAL "NCCL") + set(OOMPH_WITH_NCCL ON CACHE BOOL "Build with NCCL backend") + if(NOT GHEX_USE_GPU) + message(FATAL_ERROR "GHEX_TRANSPORT_BACKEND=NCCL requires GHEX_USE_GPU=ON but GHEX_USE_GPU=OFF") + endif() endif() if(GHEX_USE_GPU) set(HWMALLOC_ENABLE_DEVICE ON CACHE BOOL "True if GPU support shall be enabled") @@ -70,6 +75,9 @@ if(GHEX_USE_BUNDLED_OOMPH) if(TARGET oomph_ucx) add_library(oomph::oomph_ucx ALIAS oomph_ucx) endif() + if(TARGET oomph_nccl) + add_library(oomph::oomph_nccl ALIAS oomph_nccl) + endif() if(TARGET oomph_libfabric) add_library(oomph::oomph_libfabric ALIAS oomph_libfabric) endif() @@ -82,6 +90,8 @@ function(ghex_link_to_oomph target) target_link_libraries(${target} PRIVATE oomph::oomph_libfabric) elseif (GHEX_TRANSPORT_BACKEND STREQUAL "UCX") target_link_libraries(${target} PRIVATE oomph::oomph_ucx) + elseif (GHEX_TRANSPORT_BACKEND STREQUAL "NCCL") + target_link_libraries(${target} PRIVATE oomph::oomph_nccl) else() target_link_libraries(${target} PRIVATE oomph::oomph_mpi) endif() diff --git a/ext/gridtools b/ext/gridtools index 1141a348..5fb48c4d 160000 --- a/ext/gridtools +++ b/ext/gridtools @@ -1 +1 @@ -Subproject commit 1141a3489346087821b90eeec805ffc0cd2c7676 +Subproject commit 5fb48c4dfa8db88ae84304ff18fd37eb0e5f5298 diff --git a/ext/oomph b/ext/oomph index 96963ecd..3d7f6588 160000 --- a/ext/oomph +++ b/ext/oomph @@ -1 +1 @@ -Subproject commit 96963ecde75d8b3b81a204cf9096c7dd7040576f +Subproject commit 3d7f65888e10ba1689257184005b12e53c907738 diff --git a/include/ghex/communication_object.hpp b/include/ghex/communication_object.hpp index a011b803..77efa2a5 100644 --- a/include/ghex/communication_object.hpp +++ b/include/ghex/communication_object.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2026, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -24,7 +24,7 @@ #endif #include #include -#include +#include #include namespace ghex @@ -274,12 +274,28 @@ class communication_object { complete_schedule_exchange(); prepare_exchange_buffers(buffer_infos...); - post_recvs(); - pack_and_send(); + + // Use new code path for NCCL (requires start_group/end_group), old path for UCX/MPI + const char* backend_name = m_comm.get_transport_option("name"); + if (std::strcmp(backend_name, "nccl") == 0) + { + pack(); + m_comm.start_group(); + post_recvs(); + post_sends(); + m_comm.end_group(); + unpack(); + } + else + { + post_recvs(); + this->pack_and_send(); + } + return {this}; } -#if defined(GHEX_CUDACC) // TODO +#if defined(GHEX_CUDACC) /** @brief Start a synchronized exchange. * * This function is similar to `exchange()` but it has some important (semantic) @@ -303,19 +319,26 @@ class communication_object [[nodiscard]] handle_type schedule_exchange(cudaStream_t stream, buffer_info_type... buffer_infos) { - // make sure that the previous exchange has finished and free memory complete_schedule_exchange(); - - // allocate memory, probably for the receiving buffers prepare_exchange_buffers(buffer_infos...); + schedule_sync_pack(stream); - // set up the receives, and also install the call backs that will then do the unpacking - post_recvs(); - - // NOTE: The function will wait until the sends have been concluded, so it is not - // fully asynchronous. Changing that might be hard because this might lead - // to race conditions somewhere else, but it ensures that progress is made. - pack_and_send(stream); + // Use new code path for NCCL (requires start_group/end_group), old path for UCX/MPI + const char* backend_name = m_comm.get_transport_option("name"); + if (std::strcmp(backend_name, "nccl") == 0) + { + pack(); + m_comm.start_group(); + post_recvs(); + post_sends(); + m_comm.end_group(); + unpack(); + } + else + { + post_recvs(); + pack_and_send(stream); + } return {this}; } @@ -326,8 +349,24 @@ class communication_object { complete_schedule_exchange(); prepare_exchange_buffers(std::make_pair(std::move(first), std::move(last))); - post_recvs(); - pack_and_send(stream); + schedule_sync_pack(stream); + + // Use new code path for NCCL (requires start_group/end_group), old path for UCX/MPI + const char* backend_name = m_comm.get_transport_option("name"); + if (std::strcmp(backend_name, "nccl") == 0) + { + pack(); + m_comm.start_group(); + post_recvs(); + post_sends(); + m_comm.end_group(); + unpack(); + } + else + { + post_recvs(); + pack_and_send(stream); + } return {this}; } @@ -361,7 +400,7 @@ class communication_object Iterator0 last0, Iterator1 first1, Iterator1 last1, Iterators... iters) { static_assert(sizeof...(Iterators) % 2 == 0, - "need even number of iteratiors: (begin,end) pairs"); + "need even number of iterators: (begin, end) pairs"); // call helper function to turn iterators into pairs of iterators return exchange_make_pairs(std::make_index_sequence<2 + sizeof...(iters) / 2>(), first0, last0, first1, last1, iters...); @@ -384,8 +423,24 @@ class communication_object { complete_schedule_exchange(); prepare_exchange_buffers(iter_pairs...); - post_recvs(); - pack_and_send(); + + // Use new code path for NCCL (requires start_group/end_group), old path for UCX/MPI + const char* backend_name = m_comm.get_transport_option("name"); + if (std::strcmp(backend_name, "nccl") == 0) + { + pack(); + m_comm.start_group(); + post_recvs(); + post_sends(); + m_comm.end_group(); + unpack(); + } + else + { + post_recvs(); + this->pack_and_send(); + } + return {this}; } @@ -421,11 +476,18 @@ class communication_object handle_type> exchange_u(Iterator first, Iterator last) { + if (std::string(m_comm.get_transport_option("name")) == "nccl") + { + throw std::runtime_error( + "GHEX: optimized regular-grid GPU exchange is not supported with NCCL backend"); + } using gpu_mem_t = buffer_memory; using field_type = std::remove_reference_tget_field())>; using value_type = typename field_type::value_type; + complete_schedule_exchange(); prepare_exchange_buffers(std::make_pair(first, last)); + // post recvs auto& gpu_mem = std::get(m_mem); for (auto& p0 : gpu_mem.recv_memory) @@ -544,19 +606,13 @@ class communication_object }); } - /** \brief Non synchronizing version of `post_recvs()`. - * - * Create the receives requests and also _register_ the unpacker - * callbacks. The function will return after the receives calls - * have been posted. - */ - void post_recvs() + void pack() { for_each(m_mem, [this](std::size_t, auto& m) { using arch_type = typename std::remove_reference_t::arch_type; - for (auto& p0 : m.recv_memory) + for (auto& p0 : m.send_memory) { const auto device_id = p0.first; for (auto& p1 : p0.second) @@ -568,33 +624,19 @@ class communication_object || p1.second.buffer.device_id() != device_id #endif ) + { p1.second.buffer = arch_traits::make_message(m_comm, p1.second.size, device_id); - auto ptr = &p1.second; - // use callbacks for unpacking - // TODO: Reserve space in vector? - m_recv_reqs.push_back( - m_comm.recv(p1.second.buffer, p1.second.rank, p1.second.tag, - [ptr](context::message_type& m, context::rank_type, - context::tag_type) - { - device::guard g(m); - packer::unpack(*ptr, g.data()); - })); + } + + device::guard g(p1.second.buffer); + packer::pack(p1.second, g.data()); } } } }); } - /** \brief Non synchronizing variant of `pack_and_send()`. - * - * The function will collect copy the halos into a continuous buffers - * and send them to the destination. - * It is important that the function will start packing immediately - * and only return once the packing has been completed and the sending - * request has been posted. - */ void pack_and_send() { for_each(m_mem, @@ -607,15 +649,6 @@ class communication_object } #ifdef GHEX_CUDACC - /** \brief Synchronizing variant of `pack_and_send()`. - * - * As its non synchronizing version, the function packs the halos into - * continuous buffers and starts sending them. The main difference is - * that the function will not pack immediately, instead it will wait - * until all work, that has been submitted to `stream` has finished. - * However, the function will not return until the sending has been - * initiated (subject to change). - */ void pack_and_send(cudaStream_t sync_stream) { for_each(m_mem, @@ -649,6 +682,176 @@ class communication_object } #endif + void post_sends() + { + for_each(m_mem, + [this](std::size_t, auto& map) + { +#ifdef GHEX_CUDACC + // If a communicator isn't stream-aware and we're dealing with GPU memory, we wait + // for each packing kernel to finish and trigger the send as soon as possible. if a + // communicator is stream-aware or we're dealing with CPU memory we trigger sends + // immediately (for stream-aware GPU memory the packing has been scheduled on a + // stream and for CPU memory the packing is blocking and has already completed). + using arch_type = typename std::remove_reference_t::arch_type; + if (!m_comm.is_stream_aware() && std::is_same_v) + { + using send_buffer_type = + typename std::remove_reference_t::send_buffer_type; + using future_type = device::future; + std::vector stream_futures; + + for (auto& p0 : map.send_memory) + { + for (auto& p1 : p0.second) + { + if (p1.second.size > 0u) + { + stream_futures.push_back( + future_type{&(p1.second), p1.second.m_stream}); + } + } + } + + await_futures(stream_futures, + [this](send_buffer_type* b) + { + m_send_reqs.push_back(m_comm.send(b->buffer, b->rank, b->tag, + [](context::message_type&, context::rank_type, context::tag_type) { + })); + }); + } + else +#endif + { + for (auto& p0 : map.send_memory) + { + for (auto& p1 : p0.second) + { + if (p1.second.size > 0u) + { + auto& ptr = p1.second; + assert(ptr.buffer); + m_send_reqs.push_back(m_comm.send( + ptr.buffer, ptr.rank, ptr.tag, + [](context::message_type&, context::rank_type, + context::tag_type) {} +#ifdef GHEX_CUDACC + , + static_cast(p1.second.m_stream.get()) +#endif + )); + } + } + } + } + }); + } + + /** \brief Posts receives without blocking. + * + * Creates messages and posts receives for all memory types. Returns + * immediately after posting receives without waiting for receives to + * complete. + */ + void post_recvs() + { + for_each(m_mem, + [this](std::size_t, auto& m) + { + using arch_type = typename std::remove_reference_t::arch_type; + for (auto& p0 : m.recv_memory) + { + const auto device_id = p0.first; + for (auto& p1 : p0.second) + { + if (p1.second.size > 0u) + { + if (!p1.second.buffer || p1.second.buffer.size() != p1.second.size +#if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE) + || p1.second.buffer.device_id() != device_id +#endif + ) + { + p1.second.buffer = arch_traits::make_message(m_comm, + p1.second.size, device_id); + } + + auto ptr = &p1.second; + + // If a communicator is stream-aware and we're dealing with GPU memory + // unpacking will be triggered separately by scheduling it on the same + // stream as the receive. If a communicator isn't stream-aware or we're + // dealing with CPU memory (for which unpacking doesn't happen on a + // stream) we do unpacking in a callback so that it can be triggered as + // soon as possible instead of having to wait for all receives to + // complete before starting any unpacking. + if (m_comm.is_stream_aware() && std::is_same_v) + { + m_recv_reqs.push_back(m_comm.recv( + ptr->buffer, ptr->rank, ptr->tag, + [](context::message_type&, context::rank_type, + context::tag_type) {} +#if defined(GHEX_CUDACC) + , + static_cast(p1.second.m_stream.get()) +#endif + )); + } + else + { + m_recv_reqs.push_back(m_comm.recv( + ptr->buffer, ptr->rank, ptr->tag, + [ptr](context::message_type& m, context::rank_type, + context::tag_type) + { + device::guard g(m); + packer::unpack(*ptr, g.data()); + } +#if defined(GHEX_CUDACC) + , + static_cast(p1.second.m_stream.get()) +#endif + )); + } + } + } + } + }); + } + + /** \brief Trigger unpacking. + * + * In cases where unpacking can be done without callbacks (stream-aware communicator, GPU + * memory) trigger unpacking. In other cases this is a no-op. + */ + void unpack() + { + for_each(m_mem, + [this](std::size_t, auto& m) + { + using arch_type = typename std::remove_reference_t::arch_type; + // If a communicator is stream-aware and we're dealing with GPU memory we can + // schedule the unpacking without waiting for receives. In all other cases unpacking + // is added as callbacks to the receives (see post_recvs()). + if (m_comm.is_stream_aware() && std::is_same_v) + { + for (auto& p0 : m.recv_memory) + { + for (auto& p1 : p0.second) + { + if (p1.second.size > 0u) + { + auto ptr = &p1.second; + device::guard g(ptr->buffer); + packer::unpack(*ptr, g.data()); + } + } + } + } + }); + } + private: // wait functions void progress() { @@ -683,11 +886,11 @@ class communication_object void wait() { - // TODO: This function has a big overlap with `is_read()` should it be implemented + // TODO: This function has a big overlap with `is_ready()` should it be implemented // in terms of it, i.e. something like `while(!is_read()) {};`? if (!m_valid) return; - // wait for data to arrive (unpack callback will be invoked) + m_comm.wait_all(); #ifdef GHEX_CUDACC if (has_scheduled_exchange()) @@ -716,11 +919,23 @@ class communication_object { if (!m_valid) return; - // Wait for data to arrive, needed to make progress. - m_comm.wait_all(); - - // Schedule a wait. - schedule_sync_streams(stream); + // If communicator isn't stream-aware we need to explicitly wait for requests to make sure + // callbacks for unpacking are triggered. If we have CPU memory with a stream-aware + // communicator we also need wait for requests to make sure the blocking unpacking callback + // is called for the CPU communication. + // + // The additional synchronization when CPU memory is involved is a pessimization that could + // theoretically be avoided by separately tracking CPU and GPU memory communication, and + // only waiting for the CPU requests. However, in practice e.g. with NCCL, the communication + // with CPU and GPU memory happens in one NCCL group so waiting for a CPU request means + // waiting for all communication anyway. CPU memory communication with NCCL also only works + // on unified memory architectures. One should avoid communicating CPU and GPU + // memory with the same communicator. + using cpu_mem_t = buffer_memory; + auto& m = std::get(m_mem); + if (!m_comm.is_stream_aware() || !m.recv_memory.empty()) { m_comm.wait_all(); } + + schedule_sync_unpack(stream); // NOTE: We do not call `clear()` here, because the memory might still be // in use. Instead we call `clear()` in the next `schedule_exchange()` call. @@ -747,9 +962,40 @@ class communication_object } } - // Actual implementation of the scheduled wait, for more information, - // see description of the `communication_handle::schedule_wait()`. - void schedule_sync_streams(cudaStream_t stream) + // Add a dependency on the given stream streams such that packing happens + // after work on the given stream has completed, without blocking. + void schedule_sync_pack(cudaStream_t stream) + { + for_each(m_mem, + [&, this](std::size_t, auto& m) + { + using arch_type = typename std::remove_reference_t::arch_type; + if constexpr (std::is_same_v) + { + auto& e = m_event_pool.get_event(); + e.record(stream); + + for (auto& p0 : m.send_memory) + { + for (auto& p1 : p0.second) + { + if (p1.second.size > 0u) + { + // Make sure stream used for packing synchronizes with the + // given stream. + GHEX_CHECK_CUDA_RESULT( + cudaStreamWaitEvent(p1.second.m_stream.get(), e.get(), 0)); + } + } + } + } + }); + } + + // Add a dependency on the unpacking streams such that any work that happens + // on the given stream happens after unpacking has completed, without + // blocking. + void schedule_sync_unpack(cudaStream_t stream) { // NOTE: We only iterate over the receive buffers because `pack_and_send()` will // wait until the sending has been completed. Thus if we are here, the sending @@ -762,13 +1008,14 @@ class communication_object { if (p1.second.size > 0u) { - // Instead of doing a blocking wait, create events on each unpacking - // stream and make `stream` wait on that event. This ensures that - // nothing that will be submitted to `stream` after this function - // starts before the unpacking has finished. - cudaEvent_t& e = m_event_pool.get_event().get(); - GHEX_CHECK_CUDA_RESULT(cudaEventRecord(e, p1.second.m_stream.get())); - GHEX_CHECK_CUDA_RESULT(cudaStreamWaitEvent(stream, e, 0)); + // Instead of doing a blocking wait, create events on each + // unpacking stream and make `stream` wait on that event. + // This ensures that nothing that will be submitted to + // `stream` after this function starts before the unpacking + // has finished. + auto& e = m_event_pool.get_event(); + e.record(p1.second.m_stream.get()); + GHEX_CHECK_CUDA_RESULT(cudaStreamWaitEvent(stream, e.get(), 0)); } } } @@ -780,7 +1027,7 @@ class communication_object // last event function. // TODO: Find out what happens to the event if `stream` is destroyed. assert(m_active_scheduled_exchange == nullptr); - GHEX_CHECK_CUDA_RESULT(cudaEventRecord(m_last_scheduled_exchange.get(), stream)); + m_last_scheduled_exchange.record(stream); m_active_scheduled_exchange = &m_last_scheduled_exchange; } #endif @@ -818,9 +1065,6 @@ class communication_object // important: does not deallocate the memory void clear() { -#ifdef GHEX_CUDACC - assert(!has_scheduled_exchange()); -#endif m_valid = false; m_send_reqs.clear(); m_recv_reqs.clear(); diff --git a/include/ghex/device/cuda/event.hpp b/include/ghex/device/cuda/event.hpp index b35c8ee8..44e533bf 100644 --- a/include/ghex/device/cuda/event.hpp +++ b/include/ghex/device/cuda/event.hpp @@ -26,6 +26,7 @@ struct cuda_event { cudaEvent_t m_event; ghex::util::moved_bit m_moved; + bool m_recorded; cuda_event() : cuda_event(cudaEventDisableTiming) @@ -46,6 +47,29 @@ struct cuda_event operator bool() const noexcept { return !m_moved; } + //! Records an event. + void record(cudaStream_t stream) + { + assert(!m_moved); + GHEX_CHECK_CUDA_RESULT(cudaEventRecord(m_event, stream)); + m_recorded = true; + } + + //! Returns `true` if an event has been recorded and the event is ready. + bool is_ready() const + { + if (m_moved || !m_recorded) { return false; } + + cudaError_t res = cudaEventQuery(m_event); + if (res == cudaSuccess) { return true; } + else if (res == cudaErrorNotReady) { return false; } + else + { + GHEX_CHECK_CUDA_RESULT(res); + return false; + } + } + cudaEvent_t& get() noexcept { assert(!m_moved); diff --git a/include/ghex/device/cuda/runtime.hpp b/include/ghex/device/cuda/runtime.hpp index 4cc1aed2..7189e243 100644 --- a/include/ghex/device/cuda/runtime.hpp +++ b/include/ghex/device/cuda/runtime.hpp @@ -17,41 +17,44 @@ #include /* GridTools cuda -> hip translations */ -#define cudaDeviceProp hipDeviceProp_t -#define cudaDeviceSynchronize hipDeviceSynchronize -#define cudaErrorInvalidValue hipErrorInvalidValue -#define cudaError_t hipError_t -#define cudaEventCreate hipEventCreate -#define cudaEventDestroy hipEventDestroy -#define cudaEventElapsedTime hipEventElapsedTime -#define cudaEventRecord hipEventRecord -#define cudaEventSynchronize hipEventSynchronize -#define cudaEvent_t hipEvent_t -#define cudaFree hipFree -#define cudaFreeHost hipFreeHost -#define cudaGetDevice hipGetDevice -#define cudaGetDeviceCount hipGetDeviceCount -#define cudaGetDeviceProperties hipGetDeviceProperties -#define cudaGetErrorName hipGetErrorName -#define cudaGetErrorString hipGetErrorString -#define cudaGetLastError hipGetLastError -#define cudaMalloc hipMalloc -#define cudaMallocHost hipMallocHost -#define cudaMallocManaged hipMallocManaged -#define cudaMemAttachGlobal hipMemAttachGlobal -#define cudaMemcpy hipMemcpy -#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost -#define cudaMemcpyHostToDevice hipMemcpyHostToDevice -#define cudaMemoryTypeDevice hipMemoryTypeDevice -#define cudaPointerAttributes hipPointerAttribute_t -#define cudaPointerGetAttributes hipPointerGetAttributes -#define cudaSetDevice hipSetDevice -#define cudaStreamCreate hipStreamCreate -#define cudaStreamDestroy hipStreamDestroy -#define cudaStreamSynchronize hipStreamSynchronize -#define cudaStreamWaitEvent hipStreamWaitEvent -#define cudaStream_t hipStream_t -#define cudaSuccess hipSuccess +#define cudaDeviceGetStreamPriorityRange hipDeviceGetStreamPriorityRange +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaErrorInvalidValue hipErrorInvalidValue +#define cudaErrorNotReady hipErrorNotReady +#define cudaError_t hipError_t +#define cudaEventCreate hipEventCreate +#define cudaEventDestroy hipEventDestroy +#define cudaEventElapsedTime hipEventElapsedTime +#define cudaEventRecord hipEventRecord +#define cudaEventSynchronize hipEventSynchronize +#define cudaEvent_t hipEvent_t +#define cudaFree hipFree +#define cudaFreeHost hipFreeHost +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorName hipGetErrorName +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaMalloc hipMalloc +#define cudaMallocHost hipMallocHost +#define cudaMallocManaged hipMallocManaged +#define cudaMemAttachGlobal hipMemAttachGlobal +#define cudaMemcpy hipMemcpy +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemoryTypeDevice hipMemoryTypeDevice +#define cudaPointerAttributes hipPointerAttribute_t +#define cudaPointerGetAttributes hipPointerGetAttributes +#define cudaSetDevice hipSetDevice +#define cudaStreamCreate hipStreamCreate +#define cudaStreamCreateWithPriority hipStreamCreateWithPriority +#define cudaStreamDestroy hipStreamDestroy +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent hipStreamWaitEvent +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess /* additional cuda -> hip translations */ #define cudaEventCreateWithFlags hipEventCreateWithFlags diff --git a/include/ghex/device/cuda/stream.hpp b/include/ghex/device/cuda/stream.hpp index 7ade4771..9bdda487 100644 --- a/include/ghex/device/cuda/stream.hpp +++ b/include/ghex/device/cuda/stream.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2026, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -27,7 +27,14 @@ struct stream cudaStream_t m_stream; ghex::util::moved_bit m_moved; - stream(){GHEX_CHECK_CUDA_RESULT(cudaStreamCreateWithFlags(&m_stream, cudaStreamNonBlocking))} + stream() + { + int least_priority, greatest_priority; + GHEX_CHECK_CUDA_RESULT( + cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)) + GHEX_CHECK_CUDA_RESULT( + cudaStreamCreateWithPriority(&m_stream, cudaStreamNonBlocking, greatest_priority)) + } stream(const stream&) = delete; stream& operator=(const stream&) = delete; @@ -60,7 +67,6 @@ struct stream void sync() { - // busy wait here assert(!m_moved); GHEX_CHECK_CUDA_RESULT(cudaStreamSynchronize(m_stream)) } diff --git a/include/ghex/device/event.hpp b/include/ghex/device/event.hpp index ecd4ae1c..f1239296 100644 --- a/include/ghex/device/event.hpp +++ b/include/ghex/device/event.hpp @@ -27,8 +27,7 @@ struct cuda_event cuda_event& operator=(cuda_event&&) noexcept = default; ~cuda_event() noexcept = default; - // By returning `true` we emulate the behaviour of a - // CUDA `stream` that has been moved. + // By returning `true` we emulate a valid (non-moved) CUDA event. constexpr operator bool() const noexcept { return true; } }; diff --git a/include/ghex/packer.hpp b/include/ghex/packer.hpp index 81a15c88..2e3abc53 100644 --- a/include/ghex/packer.hpp +++ b/include/ghex/packer.hpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2026, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -28,6 +28,13 @@ namespace ghex template struct packer { + template + static void pack(Buffer& buffer, unsigned char* data) + { + for (const auto& fb : buffer.field_infos) + fb.call_back(data + fb.offset, *fb.index_container, nullptr); + } + template static void pack(Map& map, Requests& send_reqs, Communicator& comm) { @@ -117,6 +124,14 @@ pack_kernel_u(device::kernel_argument args) template<> struct packer { + template + static void pack(Buffer& buffer, unsigned char* data) + { + auto& stream = buffer.m_stream; + for (const auto& fb : buffer.field_infos) + fb.call_back(data + fb.offset, *fb.index_container, static_cast(&stream.get())); + } + template static void pack(Map& map, Requests& send_reqs, Communicator& comm) { @@ -171,7 +186,7 @@ struct packer { auto& stream = buffer.m_stream; for (const auto& fb : buffer.field_infos) - fb.call_back(data + fb.offset, *fb.index_container, (void*)(&stream.get())); + fb.call_back(data + fb.offset, *fb.index_container, static_cast(&stream.get())); } template diff --git a/include/ghex/unstructured/communication_object_ipr.hpp b/include/ghex/unstructured/communication_object_ipr.hpp index 67d77e6c..056944f1 100644 --- a/include/ghex/unstructured/communication_object_ipr.hpp +++ b/include/ghex/unstructured/communication_object_ipr.hpp @@ -150,8 +150,10 @@ class communication_object_ipr handle exchange() { + m_comm.start_group(); post_recvs(); pack(); + m_comm.end_group(); //while (m_status->num_completed < m_status->num_total) { m_status->comm.progress(); } return {m_status.get()}; } diff --git a/shell.nix b/shell.nix new file mode 100644 index 00000000..ab113878 --- /dev/null +++ b/shell.nix @@ -0,0 +1,42 @@ +{ pkgs ? import { } }: +let + python = pkgs.python3.withPackages (ps: [ + ps.nanobind + ps.mpi4py + ps.numpy + ps.pytest + ps.pytest-mpi + ]); +in +pkgs.mkShell { + name = "ghex-dev"; + + nativeBuildInputs = [ + pkgs.cmake + pkgs.ninja + pkgs.gcc + pkgs.git + python + ]; + + buildInputs = [ + pkgs.openmpi + pkgs.boost + pkgs.numactl + ]; + + # cmake's Python test infrastructure creates a venv and pip-installs + # mpi4py into it. That venv needs to find the Nix-provided MPI libs. + shellHook = '' + export LD_LIBRARY_PATH=${pkgs.openmpi}/lib:${pkgs.numactl}/lib''${LD_LIBRARY_PATH:+:}$LD_LIBRARY_PATH + + echo "ghex dev shell" + echo " cmake -B build -S . -G Ninja \\" + echo " -DGHEX_USE_BUNDLED_LIBS=ON \\" + echo " -DGHEX_GIT_SUBMODULE=OFF \\" + echo " -DGHEX_BUILD_PYTHON_BINDINGS=ON \\" + echo " -DGHEX_WITH_TESTING=ON \\" + echo " -DGHEX_USE_GPU=OFF \\" + echo " -DGHEX_TRANSPORT_BACKEND=MPI" + ''; +} diff --git a/test/bindings/python/CMakeLists.txt b/test/bindings/python/CMakeLists.txt index 8678bd73..ace1e0a3 100644 --- a/test/bindings/python/CMakeLists.txt +++ b/test/bindings/python/CMakeLists.txt @@ -87,7 +87,7 @@ function(ghex_reg_parallel_pytest t n) ${venv_dir}/bin/python -m pytest -s --with-mpi ${CMAKE_CURRENT_BINARY_DIR}/test_${t}.py WORKING_DIRECTORY ${pyghex_test_workdir}) endif() - set_tests_properties(py_${t}_parallel PROPERTIES RUN_SERIAL ON LABELS "parallel-ranks-${n}") + set_tests_properties(py_${t}_parallel PROPERTIES RUN_SERIAL ON LABELS "parallel-ranks-${n}" TIMEOUT 1200) endfunction() ghex_reg_pytest(context) diff --git a/test/bindings/python/fixtures/context.py b/test/bindings/python/fixtures/context.py index 0f268b0f..77d28d71 100644 --- a/test/bindings/python/fixtures/context.py +++ b/test/bindings/python/fixtures/context.py @@ -11,6 +11,7 @@ from mpi4py import MPI import pytest +import ghex from ghex.context import make_context @@ -20,9 +21,24 @@ mpi4py.rc.thread_level = "multiple" +@pytest.fixture(params=[True, False], ids=["thread_safe", "not_thread_safe"]) +def thread_safe(request): + return request.param + + @pytest.fixture -def context(): - return make_context(MPI.COMM_WORLD, True) +def context(thread_safe): + if ghex.__config__["transport"] == "NCCL" and thread_safe: + pytest.skip("NCCL not supported with thread_safe = true") + ctx = make_context(MPI.COMM_WORLD, thread_safe) + yield ctx + # Explicit cleanup to ensure UCX/MPI resources are released. + # Necessary to prevent state accumulation between tests that can cause + # subsequent parallel tests to hang (especially with UCX backend). + del ctx + import gc + + gc.collect() @pytest.fixture @@ -30,9 +46,23 @@ def mpi_cart_comm(): mpi_comm = MPI.COMM_WORLD dims = MPI.Compute_dims(mpi_comm.Get_size(), [0, 0, 0]) mpi_cart_comm = mpi_comm.Create_cart(dims=dims, periods=[False, False, False]) - return mpi_cart_comm + yield mpi_cart_comm + # Explicitly free the communicator to clean up UCX/MPI state. + # Without this, subsequent parallel tests may hang due to leftover state + # from previous tests (particularly when mixing structured and unstructured tests). + mpi_cart_comm.Free() @pytest.fixture -def cart_context(mpi_cart_comm): - return make_context(mpi_cart_comm, True) +def cart_context(mpi_cart_comm, thread_safe): + if ghex.__config__["transport"] == "NCCL" and thread_safe: + pytest.skip("NCCL not supported with thread_safe = true") + ctx = make_context(mpi_cart_comm, thread_safe) + yield ctx + # Explicit cleanup to ensure UCX/MPI resources are released. + # Necessary to prevent state accumulation between tests that can cause + # subsequent parallel tests to hang (especially with UCX backend). + del ctx + import gc + + gc.collect() diff --git a/test/bindings/python/test_context.py b/test/bindings/python/test_context.py index e0b878ab..99d19f90 100644 --- a/test/bindings/python/test_context.py +++ b/test/bindings/python/test_context.py @@ -27,11 +27,18 @@ def test_mpi_comm(): comm = ghex.mpi_comm("invalid") +@pytest.mark.parametrize("thread_safe", [True, False], ids=["thread_safe", "not_thread_safe"]) @pytest.mark.mpi_skip -def test_context_mpi4py(): - ctx = make_context(MPI.COMM_WORLD, True) - assert ctx.size() == 1 - assert ctx.rank() == 0 +def test_context_mpi4py(thread_safe): + try: + ctx = make_context(MPI.COMM_WORLD, thread_safe) + assert ctx.size() == 1 + assert ctx.rank() == 0 + except RuntimeError as e: + if ghex.__config__["transport"] == "NCCL" and thread_safe: + assert str(e) == "NCCL not supported with thread_safe = true" + else: + raise @pytest.mark.mpi diff --git a/test/bindings/python/test_structured_domain_descriptor.py b/test/bindings/python/test_structured_domain_descriptor.py index 22600ce3..074cb2b5 100644 --- a/test/bindings/python/test_structured_domain_descriptor.py +++ b/test/bindings/python/test_structured_domain_descriptor.py @@ -9,7 +9,6 @@ # import pytest -from ghex.context import make_context from ghex.structured.cartesian_sets import IndexSpace, UnitRange from ghex.structured.regular import DomainDescriptor, HaloGenerator @@ -29,8 +28,8 @@ @pytest.mark.mpi -def test_domain_descriptor(capsys, mpi_cart_comm): - ctx = make_context(mpi_cart_comm, True) +def test_domain_descriptor(capsys, mpi_cart_comm, cart_context): + ctx = cart_context coords = mpi_cart_comm.Get_coords(mpi_cart_comm.Get_rank()) coords2 = mpi_cart_comm.Get_coords(ctx.rank()) @@ -61,17 +60,15 @@ def test_halo_gen_construction(capsys, mpi_cart_comm, halos): print(halos) dims = mpi_cart_comm.dims glob_domain_indices = ( - UnitRange(0, dims[0] * Nx) - * UnitRange(0, dims[1] * Ny) - * UnitRange(0, dims[2] * Nz) + UnitRange(0, dims[0] * Nx) * UnitRange(0, dims[1] * Ny) * UnitRange(0, dims[2] * Nz) ) halo_gen = HaloGenerator(glob_domain_indices, halos, (False, False, False)) @pytest.mark.parametrize("halos", haloss) @pytest.mark.mpi -def test_halo_gen_call(mpi_cart_comm, halos): - ctx = make_context(mpi_cart_comm, True) +def test_halo_gen_call(mpi_cart_comm, cart_context, halos): + ctx = cart_context periodicity = (False, False, False) p_coord = tuple(mpi_cart_comm.Get_coords(mpi_cart_comm.Get_rank())) @@ -100,8 +97,8 @@ def test_halo_gen_call(mpi_cart_comm, halos): @pytest.mark.parametrize("halos", haloss) @pytest.mark.mpi -def test_domain_descriptor_grid(mpi_cart_comm, halos): - ctx = make_context(mpi_cart_comm, True) +def test_domain_descriptor_grid(mpi_cart_comm, cart_context, halos): + ctx = cart_context p_coord = tuple(mpi_cart_comm.Get_coords(mpi_cart_comm.Get_rank())) diff --git a/test/bindings/python/test_unstructured_domain_descriptor.py b/test/bindings/python/test_unstructured_domain_descriptor.py index 844f8851..c0c13f81 100644 --- a/test/bindings/python/test_unstructured_domain_descriptor.py +++ b/test/bindings/python/test_unstructured_domain_descriptor.py @@ -28,7 +28,6 @@ def __cuda_stream__(self): STREAM_TYPES_TO_TEST = [None] # Must be at least one element. import ghex -from ghex.context import make_context from ghex.unstructured import make_communication_object from ghex.unstructured import DomainDescriptor from ghex.unstructured import HaloGenerator @@ -225,16 +224,17 @@ def __cuda_stream__(self): LEVELS = 2 + @pytest.mark.parametrize("dtype", [np.float64, np.float32, np.int32, np.int64]) @pytest.mark.parametrize("on_gpu", [True, False]) @pytest.mark.mpi -def test_domain_descriptor(on_gpu, capsys, mpi_cart_comm, dtype): +def test_domain_descriptor(on_gpu, capsys, mpi_cart_comm, cart_context, dtype): # Does not uses streams. if on_gpu and cp is None: pytest.skip(reason="`CuPy` is not installed.") - ctx = make_context(mpi_cart_comm, True) + ctx = cart_context assert ctx.size() == 4 domain_desc = DomainDescriptor( @@ -247,9 +247,7 @@ def test_domain_descriptor(on_gpu, capsys, mpi_cart_comm, dtype): def make_field(order): # Creation is always on host. - data = np.zeros( - [len(domains[ctx.rank()]["all"]), LEVELS], dtype=dtype, order=order - ) + data = np.zeros([len(domains[ctx.rank()]["all"]), LEVELS], dtype=dtype, order=order) inner_set = set(domains[ctx.rank()]["inner"]) all_list = domains[ctx.rank()]["all"] for x in range(len(all_list)): @@ -278,13 +276,11 @@ def check_field(data, order): if gid in inner_set: assert data[x, l] == ctx.rank() * 1000 + 10 * gid + l else: - assert ( - data[x, l] - 1000 * int((data[x, l]) / 1000) - ) == 10 * gid + l + assert (data[x, l] - 1000 * int((data[x, l]) / 1000)) == 10 * gid + l # TODO: Find out if there is a side effect that makes it important to keep them. - #field = make_field_descriptor(domain_desc, data) - #return data, field + # field = make_field_descriptor(domain_desc, data) + # return data, field halo_gen = HaloGenerator.from_gids(domains[ctx.rank()]["outer"]) pattern = make_pattern(ctx, halo_gen, [domain_desc]) @@ -304,7 +300,7 @@ def check_field(data, order): @pytest.mark.parametrize("on_gpu", [True, False]) @pytest.mark.parametrize("stream_type", STREAM_TYPES_TO_TEST) @pytest.mark.mpi -def test_domain_descriptor_async(on_gpu, stream_type, capsys, mpi_cart_comm, dtype): +def test_domain_descriptor_async(on_gpu, stream_type, capsys, mpi_cart_comm, cart_context, dtype): if on_gpu: if cp is None: @@ -312,9 +308,15 @@ def test_domain_descriptor_async(on_gpu, stream_type, capsys, mpi_cart_comm, dty if not cp.is_available(): pytest.skip(reason="`CuPy` is installed but no GPU could be found.") if not ghex.__config__["gpu"]: - pytest.skip(reason="Skipping `schedule_exchange()` tests because `GHEX` was not compiled with GPU support") + pytest.skip( + reason="Skipping `schedule_exchange()` tests because `GHEX` was not compiled with GPU support" + ) - ctx = make_context(mpi_cart_comm, True) + ctx = cart_context + print( + f"[RANK {ctx.rank()}] test_domain_descriptor_async: dtype={dtype}, on_gpu={on_gpu}, stream_type={stream_type}", + flush=True, + ) assert ctx.size() == 4 domain_desc = DomainDescriptor( @@ -326,9 +328,7 @@ def test_domain_descriptor_async(on_gpu, stream_type, capsys, mpi_cart_comm, dty assert domain_desc.inner_size() == len(domains[ctx.rank()]["inner"]) def make_field(order): - data = np.zeros( - [len(domains[ctx.rank()]["all"]), LEVELS], dtype=dtype, order=order - ) + data = np.zeros([len(domains[ctx.rank()]["all"]), LEVELS], dtype=dtype, order=order) inner_set = set(domains[ctx.rank()]["inner"]) all_list = domains[ctx.rank()]["all"] for x in range(len(all_list)): @@ -357,9 +357,7 @@ def check_field(data, order, stream): if gid in inner_set: assert data[x, l] == ctx.rank() * 1000 + 10 * gid + l else: - assert ( - data[x, l] - 1000 * int((data[x, l]) / 1000) - ) == 10 * gid + l + assert (data[x, l] - 1000 * int((data[x, l]) / 1000)) == 10 * gid + l halo_gen = HaloGenerator.from_gids(domains[ctx.rank()]["outer"]) pattern = make_pattern(ctx, halo_gen, [domain_desc]) diff --git a/test/structured/cubed_sphere/test_cubed_sphere_exchange.cpp b/test/structured/cubed_sphere/test_cubed_sphere_exchange.cpp index 88a38989..e69505fc 100644 --- a/test/structured/cubed_sphere/test_cubed_sphere_exchange.cpp +++ b/test/structured/cubed_sphere/test_cubed_sphere_exchange.cpp @@ -21,6 +21,8 @@ #include #include "../../util/memory.hpp" +#include "../../util/nccl_test_helpers.hpp" +#include #include #include @@ -929,116 +931,131 @@ check_field(const Field& field, int halo, int n) TEST_F(mpi_test_fixture, cubed_sphere) { + // NCCL PXN (PCIe x NVLink) warns about host buffers. This is expected for CPU memory + // with NCCL and is handled by setting NCCL_PXN_DISABLE=1 in the test environment. using namespace ghex::structured::cubed_sphere; EXPECT_TRUE(world_size == 6); - // create context - ghex::context ctxt(world, thread_safe); + try + { + // create context + ghex::context ctxt(world, thread_safe); - // halo generator with 2 halo lines in x and y dimensions (on both sides) - halo_generator halo_gen(2); + // halo generator with 2 halo lines in x and y dimensions (on both sides) + halo_generator halo_gen(2); - // cube with size 10 and 6 levels - cube c{10, 6}; + // cube with size 10 and 6 levels + cube c{10, 6}; - // define 4 local domains - domain_descriptor domain0(c, ctxt.rank(), 0, 4, 0, 4); - domain_descriptor domain1(c, ctxt.rank(), 5, 9, 0, 4); - domain_descriptor domain2(c, ctxt.rank(), 0, 4, 5, 9); - domain_descriptor domain3(c, ctxt.rank(), 5, 9, 5, 9); - std::vector local_domains{domain0, domain1, domain2, domain3}; + // define 4 local domains + domain_descriptor domain0(c, ctxt.rank(), 0, 4, 0, 4); + domain_descriptor domain1(c, ctxt.rank(), 5, 9, 0, 4); + domain_descriptor domain2(c, ctxt.rank(), 0, 4, 5, 9); + domain_descriptor domain3(c, ctxt.rank(), 5, 9, 5, 9); + std::vector local_domains{domain0, domain1, domain2, domain3}; - // allocate large enough memory for fields, sufficient for 3 halo lines - // use 8 components per field and 6 z-levels - const int halo = 3; - ghex::test::util::memory data_dom_0((2 * halo + 5) * (2 * halo + 5) * 6 * 8, - -1); // fields - ghex::test::util::memory data_dom_1((2 * halo + 5) * (2 * halo + 5) * 6 * 8, - -1); // fields - ghex::test::util::memory data_dom_2((2 * halo + 5) * (2 * halo + 5) * 6 * 8, - -1); // fields - ghex::test::util::memory data_dom_3((2 * halo + 5) * (2 * halo + 5) * 6 * 8, - -1); // fields + // allocate large enough memory for fields, sufficient for 3 halo lines + // use 8 components per field and 6 z-levels + const int halo = 3; + ghex::test::util::memory data_dom_0((2 * halo + 5) * (2 * halo + 5) * 6 * 8, + -1); // fields + ghex::test::util::memory data_dom_1((2 * halo + 5) * (2 * halo + 5) * 6 * 8, + -1); // fields + ghex::test::util::memory data_dom_2((2 * halo + 5) * (2 * halo + 5) * 6 * 8, + -1); // fields + ghex::test::util::memory data_dom_3((2 * halo + 5) * (2 * halo + 5) * 6 * 8, + -1); // fields - // initialize physical domain (leave halos as they are) - for (int comp = 0; comp < 8; ++comp) - for (int z = 0; z < 6; ++z) - for (int y = 0; y < 5; ++y) - for (int x = 0; x < 5; ++x) - { - const auto idx = (x + halo) + (y + halo) * (2 * halo + 5) + - z * (2 * halo + 5) * (2 * halo + 5) + - comp * (2 * halo + 5) * (2 * halo + 5) * 6; - data_dom_0[idx] = 100000 * (domain0.domain_id().tile + 1) + - 10000 * id_to_int(domain0.domain_id().id) + 1000 * comp + - 100 * x + 10 * y + 1 * z; - data_dom_1[idx] = 100000 * (domain1.domain_id().tile + 1) + - 10000 * id_to_int(domain1.domain_id().id) + 1000 * comp + - 100 * x + 10 * y + 1 * z; - data_dom_2[idx] = 100000 * (domain2.domain_id().tile + 1) + - 10000 * id_to_int(domain2.domain_id().id) + 1000 * comp + - 100 * x + 10 * y + 1 * z; - data_dom_3[idx] = 100000 * (domain3.domain_id().tile + 1) + - 10000 * id_to_int(domain3.domain_id().id) + 1000 * comp + - 100 * x + 10 * y + 1 * z; - } + // initialize physical domain (leave halos as they are) + for (int comp = 0; comp < 8; ++comp) + for (int z = 0; z < 6; ++z) + for (int y = 0; y < 5; ++y) + for (int x = 0; x < 5; ++x) + { + const auto idx = (x + halo) + (y + halo) * (2 * halo + 5) + + z * (2 * halo + 5) * (2 * halo + 5) + + comp * (2 * halo + 5) * (2 * halo + 5) * 6; + data_dom_0[idx] = 100000 * (domain0.domain_id().tile + 1) + + 10000 * id_to_int(domain0.domain_id().id) + 1000 * comp + + 100 * x + 10 * y + 1 * z; + data_dom_1[idx] = 100000 * (domain1.domain_id().tile + 1) + + 10000 * id_to_int(domain1.domain_id().id) + 1000 * comp + + 100 * x + 10 * y + 1 * z; + data_dom_2[idx] = 100000 * (domain2.domain_id().tile + 1) + + 10000 * id_to_int(domain2.domain_id().id) + 1000 * comp + + 100 * x + 10 * y + 1 * z; + data_dom_3[idx] = 100000 * (domain3.domain_id().tile + 1) + + 10000 * id_to_int(domain3.domain_id().id) + 1000 * comp + + 100 * x + 10 * y + 1 * z; + } #if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE) - using arch_t = ghex::gpu; - float* data_ptr_0 = data_dom_0.device_data(); - float* data_ptr_1 = data_dom_1.device_data(); - float* data_ptr_2 = data_dom_2.device_data(); - float* data_ptr_3 = data_dom_3.device_data(); - data_dom_0.clone_to_device(); - data_dom_1.clone_to_device(); - data_dom_2.clone_to_device(); - data_dom_3.clone_to_device(); + using arch_t = ghex::gpu; + float* data_ptr_0 = data_dom_0.device_data(); + float* data_ptr_1 = data_dom_1.device_data(); + float* data_ptr_2 = data_dom_2.device_data(); + float* data_ptr_3 = data_dom_3.device_data(); + data_dom_0.clone_to_device(); + data_dom_1.clone_to_device(); + data_dom_2.clone_to_device(); + data_dom_3.clone_to_device(); #else - using arch_t = ghex::cpu; - float* data_ptr_0 = data_dom_0.host_data(); - float* data_ptr_1 = data_dom_1.host_data(); - float* data_ptr_2 = data_dom_2.host_data(); - float* data_ptr_3 = data_dom_3.host_data(); + using arch_t = ghex::cpu; + float* data_ptr_0 = data_dom_0.host_data(); + float* data_ptr_1 = data_dom_1.host_data(); + float* data_ptr_2 = data_dom_2.host_data(); + float* data_ptr_3 = data_dom_3.host_data(); #endif - // wrap field memory in a field_descriptor - field_descriptor field_dom_0(domain0, data_ptr_0, - std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 6}, 8); - field_descriptor field_dom_1(domain1, data_ptr_1, - std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 6}, 8); - field_descriptor field_dom_2(domain2, data_ptr_2, - std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 6}, 8); - field_descriptor field_dom_3(domain3, data_ptr_3, - std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 6}, 8); + // wrap field memory in a field_descriptor + field_descriptor field_dom_0(domain0, data_ptr_0, + std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 6}, + 8); + field_descriptor field_dom_1(domain1, data_ptr_1, + std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 6}, + 8); + field_descriptor field_dom_2(domain2, data_ptr_2, + std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 6}, + 8); + field_descriptor field_dom_3(domain3, data_ptr_3, + std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 6}, + 8); - // create a structured pattern - auto pattern1 = ghex::make_pattern(ctxt, halo_gen, local_domains); + // create a structured pattern + auto pattern1 = ghex::make_pattern(ctxt, halo_gen, local_domains); - // make a communication object - using pattern_type = decltype(pattern1); - auto co = ghex::make_communication_object(ctxt); + // make a communication object + using pattern_type = decltype(pattern1); + auto co = ghex::make_communication_object(ctxt); - // exchange halo data - co.exchange(pattern1(field_dom_0), pattern1(field_dom_1), pattern1(field_dom_2), - pattern1(field_dom_3)) - .wait(); + // exchange halo data + co.exchange(pattern1(field_dom_0), pattern1(field_dom_1), pattern1(field_dom_2), + pattern1(field_dom_3)) + .wait(); #if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE) - data_dom_0.clone_to_host(); - data_dom_1.clone_to_host(); - data_dom_2.clone_to_host(); - data_dom_3.clone_to_host(); - field_dom_0.set_data(data_dom_0.host_data()); - field_dom_1.set_data(data_dom_1.host_data()); - field_dom_2.set_data(data_dom_2.host_data()); - field_dom_3.set_data(data_dom_3.host_data()); + data_dom_0.clone_to_host(); + data_dom_1.clone_to_host(); + data_dom_2.clone_to_host(); + data_dom_3.clone_to_host(); + field_dom_0.set_data(data_dom_0.host_data()); + field_dom_1.set_data(data_dom_1.host_data()); + field_dom_2.set_data(data_dom_2.host_data()); + field_dom_3.set_data(data_dom_3.host_data()); #endif - // check results - check_field(field_dom_0, 2, 5); - check_field(field_dom_1, 2, 5); - check_field(field_dom_2, 2, 5); - check_field(field_dom_3, 2, 5); + // check results + check_field(field_dom_0, 2, 5); + check_field(field_dom_1, 2, 5); + check_field(field_dom_2, 2, 5); + check_field(field_dom_3, 2, 5); + } + catch (std::runtime_error const& e) + { + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); + } } TEST_F(mpi_test_fixture, cubed_sphere_vector) @@ -1046,115 +1063,124 @@ TEST_F(mpi_test_fixture, cubed_sphere_vector) using namespace ghex::structured::cubed_sphere; EXPECT_TRUE(world_size == 6); - // create context - ghex::context ctxt(world, thread_safe); + try + { + // create context + ghex::context ctxt(world, thread_safe); - // halo generator with 2 halo lines in x and y dimensions (on both sides) - halo_generator halo_gen(2); + // halo generator with 2 halo lines in x and y dimensions (on both sides) + halo_generator halo_gen(2); - // cube with size 10 and 7 levels - cube c{10, 7}; + // cube with size 10 and 7 levels + cube c{10, 7}; - // define 4 local domains - domain_descriptor domain0(c, ctxt.rank(), 0, 4, 0, 4); - domain_descriptor domain1(c, ctxt.rank(), 5, 9, 0, 4); - domain_descriptor domain2(c, ctxt.rank(), 0, 4, 5, 9); - domain_descriptor domain3(c, ctxt.rank(), 5, 9, 5, 9); - std::vector local_domains{domain0, domain1, domain2, domain3}; + // define 4 local domains + domain_descriptor domain0(c, ctxt.rank(), 0, 4, 0, 4); + domain_descriptor domain1(c, ctxt.rank(), 5, 9, 0, 4); + domain_descriptor domain2(c, ctxt.rank(), 0, 4, 5, 9); + domain_descriptor domain3(c, ctxt.rank(), 5, 9, 5, 9); + std::vector local_domains{domain0, domain1, domain2, domain3}; - // allocate large enough memory for fields, sufficient for 3 halo lines - // use 8 components per field and 6 z-levels - const int halo = 3; - ghex::test::util::memory data_dom_0((2 * halo + 5) * (2 * halo + 5) * 3 * 7, - -1); // fields - ghex::test::util::memory data_dom_1((2 * halo + 5) * (2 * halo + 5) * 3 * 7, - -1); // fields - ghex::test::util::memory data_dom_2((2 * halo + 5) * (2 * halo + 5) * 3 * 7, - -1); // fields - ghex::test::util::memory data_dom_3((2 * halo + 5) * (2 * halo + 5) * 3 * 7, - -1); // fields + // allocate large enough memory for fields, sufficient for 3 halo lines + // use 8 components per field and 6 z-levels + const int halo = 3; + ghex::test::util::memory data_dom_0((2 * halo + 5) * (2 * halo + 5) * 3 * 7, + -1); // fields + ghex::test::util::memory data_dom_1((2 * halo + 5) * (2 * halo + 5) * 3 * 7, + -1); // fields + ghex::test::util::memory data_dom_2((2 * halo + 5) * (2 * halo + 5) * 3 * 7, + -1); // fields + ghex::test::util::memory data_dom_3((2 * halo + 5) * (2 * halo + 5) * 3 * 7, + -1); // fields - // initialize physical domain (leave halos as they are) - for (int comp = 0; comp < 3; ++comp) - for (int z = 0; z < 7; ++z) - for (int y = 0; y < 5; ++y) - for (int x = 0; x < 5; ++x) - { - const auto idx = (x + halo) + (y + halo) * (2 * halo + 5) + - z * (2 * halo + 5) * (2 * halo + 5) + - comp * (2 * halo + 5) * (2 * halo + 5) * 7; - data_dom_0[idx] = 100000 * (domain0.domain_id().tile + 1) + - 10000 * id_to_int(domain0.domain_id().id) + 1000 * comp + - 100 * x + 10 * y + 1 * z; - data_dom_1[idx] = 100000 * (domain1.domain_id().tile + 1) + - 10000 * id_to_int(domain1.domain_id().id) + 1000 * comp + - 100 * x + 10 * y + 1 * z; - data_dom_2[idx] = 100000 * (domain2.domain_id().tile + 1) + - 10000 * id_to_int(domain2.domain_id().id) + 1000 * comp + - 100 * x + 10 * y + 1 * z; - data_dom_3[idx] = 100000 * (domain3.domain_id().tile + 1) + - 10000 * id_to_int(domain3.domain_id().id) + 1000 * comp + - 100 * x + 10 * y + 1 * z; - } + // initialize physical domain (leave halos as they are) + for (int comp = 0; comp < 3; ++comp) + for (int z = 0; z < 7; ++z) + for (int y = 0; y < 5; ++y) + for (int x = 0; x < 5; ++x) + { + const auto idx = (x + halo) + (y + halo) * (2 * halo + 5) + + z * (2 * halo + 5) * (2 * halo + 5) + + comp * (2 * halo + 5) * (2 * halo + 5) * 7; + data_dom_0[idx] = 100000 * (domain0.domain_id().tile + 1) + + 10000 * id_to_int(domain0.domain_id().id) + 1000 * comp + + 100 * x + 10 * y + 1 * z; + data_dom_1[idx] = 100000 * (domain1.domain_id().tile + 1) + + 10000 * id_to_int(domain1.domain_id().id) + 1000 * comp + + 100 * x + 10 * y + 1 * z; + data_dom_2[idx] = 100000 * (domain2.domain_id().tile + 1) + + 10000 * id_to_int(domain2.domain_id().id) + 1000 * comp + + 100 * x + 10 * y + 1 * z; + data_dom_3[idx] = 100000 * (domain3.domain_id().tile + 1) + + 10000 * id_to_int(domain3.domain_id().id) + 1000 * comp + + 100 * x + 10 * y + 1 * z; + } #if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE) - using arch_t = ghex::gpu; - float* data_ptr_0 = data_dom_0.device_data(); - float* data_ptr_1 = data_dom_1.device_data(); - float* data_ptr_2 = data_dom_2.device_data(); - float* data_ptr_3 = data_dom_3.device_data(); - data_dom_0.clone_to_device(); - data_dom_1.clone_to_device(); - data_dom_2.clone_to_device(); - data_dom_3.clone_to_device(); + using arch_t = ghex::gpu; + float* data_ptr_0 = data_dom_0.device_data(); + float* data_ptr_1 = data_dom_1.device_data(); + float* data_ptr_2 = data_dom_2.device_data(); + float* data_ptr_3 = data_dom_3.device_data(); + data_dom_0.clone_to_device(); + data_dom_1.clone_to_device(); + data_dom_2.clone_to_device(); + data_dom_3.clone_to_device(); #else - using arch_t = ghex::cpu; - float* data_ptr_0 = data_dom_0.host_data(); - float* data_ptr_1 = data_dom_1.host_data(); - float* data_ptr_2 = data_dom_2.host_data(); - float* data_ptr_3 = data_dom_3.host_data(); + using arch_t = ghex::cpu; + float* data_ptr_0 = data_dom_0.host_data(); + float* data_ptr_1 = data_dom_1.host_data(); + float* data_ptr_2 = data_dom_2.host_data(); + float* data_ptr_3 = data_dom_3.host_data(); #endif - // wrap field memory in a field_descriptor - field_descriptor field_dom_0(domain0, data_ptr_0, - std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 7}, 3, - true); - field_descriptor field_dom_1(domain1, data_ptr_1, - std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 7}, 3, - true); - field_descriptor field_dom_2(domain2, data_ptr_2, - std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 7}, 3, - true); - field_descriptor field_dom_3(domain3, data_ptr_3, - std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 7}, 3, - true); + // wrap field memory in a field_descriptor + field_descriptor field_dom_0(domain0, data_ptr_0, + std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 7}, 3, + true); + field_descriptor field_dom_1(domain1, data_ptr_1, + std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 7}, 3, + true); + field_descriptor field_dom_2(domain2, data_ptr_2, + std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 7}, 3, + true); + field_descriptor field_dom_3(domain3, data_ptr_3, + std::array{halo, halo, 0}, std::array{2 * halo + 5, 2 * halo + 5, 7}, 3, + true); - // create a structured pattern - auto pattern1 = ghex::make_pattern(ctxt, halo_gen, local_domains); + // create a structured pattern + auto pattern1 = ghex::make_pattern(ctxt, halo_gen, local_domains); - // make a communication object - using pattern_type = decltype(pattern1); - auto co = ghex::make_communication_object(ctxt); + // make a communication object + using pattern_type = decltype(pattern1); + auto co = ghex::make_communication_object(ctxt); - // exchange halo data - co.exchange(pattern1(field_dom_0), pattern1(field_dom_1), pattern1(field_dom_2), - pattern1(field_dom_3)) - .wait(); + // exchange halo data + co.exchange(pattern1(field_dom_0), pattern1(field_dom_1), pattern1(field_dom_2), + pattern1(field_dom_3)) + .wait(); #if defined(GHEX_USE_GPU) || defined(GHEX_GPU_MODE_EMULATE) - data_dom_0.clone_to_host(); - data_dom_1.clone_to_host(); - data_dom_2.clone_to_host(); - data_dom_3.clone_to_host(); - field_dom_0.set_data(data_dom_0.host_data()); - field_dom_1.set_data(data_dom_1.host_data()); - field_dom_2.set_data(data_dom_2.host_data()); - field_dom_3.set_data(data_dom_3.host_data()); + data_dom_0.clone_to_host(); + data_dom_1.clone_to_host(); + data_dom_2.clone_to_host(); + data_dom_3.clone_to_host(); + field_dom_0.set_data(data_dom_0.host_data()); + field_dom_1.set_data(data_dom_1.host_data()); + field_dom_2.set_data(data_dom_2.host_data()); + field_dom_3.set_data(data_dom_3.host_data()); #endif - // check results - check_field(field_dom_0, 2, 5); - check_field(field_dom_1, 2, 5); - check_field(field_dom_2, 2, 5); - check_field(field_dom_3, 2, 5); + // check results + check_field(field_dom_0, 2, 5); + check_field(field_dom_1, 2, 5); + check_field(field_dom_2, 2, 5); + check_field(field_dom_3, 2, 5); + } + catch (std::runtime_error const& e) + { + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); + } } diff --git a/test/structured/regular/test_local_rma.cpp b/test/structured/regular/test_local_rma.cpp index c264770d..695cf495 100644 --- a/test/structured/regular/test_local_rma.cpp +++ b/test/structured/regular/test_local_rma.cpp @@ -366,9 +366,30 @@ struct simulation_1 TEST_F(mpi_test_fixture, rma_exchange) { - simulation_1 sim(thread_safe); - sim.exchange(); - sim.exchange(); - sim.exchange(); - EXPECT_TRUE(sim.check()); + try + { + simulation_1 sim(thread_safe); + sim.exchange(); + sim.exchange(); + sim.exchange(); + EXPECT_TRUE(sim.check()); + } + catch (std::runtime_error const& e) + { + if (ghex::context(world, false).transport_context()->get_transport_option("name") == + std::string("nccl")) + { + if (thread_safe) + { + EXPECT_STREQ(e.what(), "NCCL not supported with thread_safe = true"); + } + else + { + EXPECT_STREQ(e.what(), + "oomph NCCL backend: self-send/recv requires an active NCCL group. " + "Use start_group()/end_group() around self-send/recv operations."); + } + } + else { throw; } + } } diff --git a/test/structured/regular/test_regular_domain.cpp b/test/structured/regular/test_regular_domain.cpp index 0137b88d..26b09537 100644 --- a/test/structured/regular/test_regular_domain.cpp +++ b/test/structured/regular/test_regular_domain.cpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2026, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -24,6 +24,7 @@ #include #include "../../util/memory.hpp" +#include "../../util/nccl_test_helpers.hpp" #include #include #include @@ -438,19 +439,34 @@ TEST_F(mpi_test_fixture, exchange_host_host) { using namespace ghex; EXPECT_TRUE((world_size == 1) || (world_size % 2 == 0)); - context ctxt(world, thread_safe); - - if (!thread_safe) + try { - test_exchange::run(ctxt); - test_exchange::run_split(ctxt); + context ctxt(world, thread_safe); + + if (!thread_safe) + { + test_exchange::run(ctxt); + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_split(ctxt); + } + } + else + { + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_mt(ctxt); + test_exchange::run_mt_async(ctxt); + test_exchange::run_mt_async_ret(ctxt); + test_exchange::run_mt_deferred_ret(ctxt); + } + } } - else + catch (std::runtime_error const& e) { - test_exchange::run_mt(ctxt); - test_exchange::run_mt_async(ctxt); - test_exchange::run_mt_async_ret(ctxt); - test_exchange::run_mt_deferred_ret(ctxt); + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); } } @@ -458,19 +474,35 @@ TEST_F(mpi_test_fixture, exchange_host_host_vector) { using namespace ghex; EXPECT_TRUE((world_size == 1) || (world_size % 2 == 0)); - context ctxt(world, thread_safe); - - if (!thread_safe) + try { - test_exchange::run(ctxt); - test_exchange::run_split(ctxt); + context ctxt(world, thread_safe); + + if (!thread_safe) + { + test_exchange::run(ctxt); + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_split(ctxt); + } + } + else + { + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_mt(ctxt); + test_exchange::run_mt_async(ctxt); + test_exchange::run_mt_async_ret(ctxt); + test_exchange::run_mt_deferred_ret( + ctxt); + } + } } - else + catch (std::runtime_error const& e) { - test_exchange::run_mt(ctxt); - test_exchange::run_mt_async(ctxt); - test_exchange::run_mt_async_ret(ctxt); - test_exchange::run_mt_deferred_ret(ctxt); + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); } } @@ -479,19 +511,34 @@ TEST_F(mpi_test_fixture, exchange_device_device) { using namespace ghex; EXPECT_TRUE((world_size == 1) || (world_size % 2 == 0)); - context ctxt(world, thread_safe); - - if (!thread_safe) + try { - test_exchange::run(ctxt); - test_exchange::run_split(ctxt); + context ctxt(world, thread_safe); + + if (!thread_safe) + { + test_exchange::run(ctxt); + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_split(ctxt); + } + } + else + { + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_mt(ctxt); + test_exchange::run_mt_async(ctxt); + test_exchange::run_mt_async_ret(ctxt); + test_exchange::run_mt_deferred_ret(ctxt); + } + } } - else + catch (std::runtime_error const& e) { - test_exchange::run_mt(ctxt); - test_exchange::run_mt_async(ctxt); - test_exchange::run_mt_async_ret(ctxt); - test_exchange::run_mt_deferred_ret(ctxt); + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); } } @@ -499,19 +546,35 @@ TEST_F(mpi_test_fixture, exchange_device_device_vector) { using namespace ghex; EXPECT_TRUE((world_size == 1) || (world_size % 2 == 0)); - context ctxt(world, thread_safe); - - if (!thread_safe) + try { - test_exchange::run(ctxt); - test_exchange::run_split(ctxt); + context ctxt(world, thread_safe); + + if (!thread_safe) + { + test_exchange::run(ctxt); + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_split(ctxt); + } + } + else + { + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_mt(ctxt); + test_exchange::run_mt_async(ctxt); + test_exchange::run_mt_async_ret(ctxt); + test_exchange::run_mt_deferred_ret( + ctxt); + } + } } - else + catch (std::runtime_error const& e) { - test_exchange::run_mt(ctxt); - test_exchange::run_mt_async(ctxt); - test_exchange::run_mt_async_ret(ctxt); - test_exchange::run_mt_deferred_ret(ctxt); + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); } } @@ -519,19 +582,38 @@ TEST_F(mpi_test_fixture, exchange_host_device) { using namespace ghex; EXPECT_TRUE((world_size == 1) || (world_size % 2 == 0)); - context ctxt(world, thread_safe); - - if (!thread_safe) + if (ghex::test::is_nccl_backend(world)) + { + GTEST_SKIP() << "mixed-architecture exchanges not supported with NCCL backend"; + } + try { - test_exchange::run(ctxt); - test_exchange::run_split(ctxt); + context ctxt(world, thread_safe); + + if (!thread_safe) + { + test_exchange::run(ctxt); + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_split(ctxt); + } + } + else + { + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_mt(ctxt); + test_exchange::run_mt_async(ctxt); + test_exchange::run_mt_async_ret(ctxt); + test_exchange::run_mt_deferred_ret(ctxt); + } + } } - else + catch (std::runtime_error const& e) { - test_exchange::run_mt(ctxt); - test_exchange::run_mt_async(ctxt); - test_exchange::run_mt_async_ret(ctxt); - test_exchange::run_mt_deferred_ret(ctxt); + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); } } @@ -539,19 +621,39 @@ TEST_F(mpi_test_fixture, exchange_host_device_vector) { using namespace ghex; EXPECT_TRUE((world_size == 1) || (world_size % 2 == 0)); - context ctxt(world, thread_safe); - - if (!thread_safe) + if (ghex::test::is_nccl_backend(world)) + { + GTEST_SKIP() << "mixed-architecture exchanges not supported with NCCL backend"; + } + try { - test_exchange::run(ctxt); - test_exchange::run_split(ctxt); + context ctxt(world, thread_safe); + + if (!thread_safe) + { + test_exchange::run(ctxt); + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_split(ctxt); + } + } + else + { + if (!ghex::test::is_nccl_backend(world)) + { + test_exchange::run_mt(ctxt); + test_exchange::run_mt_async(ctxt); + test_exchange::run_mt_async_ret(ctxt); + test_exchange::run_mt_deferred_ret( + ctxt); + } + } } - else + catch (std::runtime_error const& e) { - test_exchange::run_mt(ctxt); - test_exchange::run_mt_async(ctxt); - test_exchange::run_mt_async_ret(ctxt); - test_exchange::run_mt_deferred_ret(ctxt); + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); } } #endif @@ -648,14 +750,12 @@ parameters::check_values(ghex::test::util::memory #include #include @@ -474,41 +475,50 @@ run(context& ctxt, const Pattern& pattern, const SPattern& spattern, const Domai void sim(bool multi_threaded) { - context ctxt(MPI_COMM_WORLD, multi_threaded); - // 2D domain decomposition - arr dims{0, 0}, coords{0, 0}; - MPI_Dims_create(ctxt.size(), 2, dims.data()); - coords[1] = ctxt.rank() / dims[0]; - coords[0] = ctxt.rank() - coords[1] * dims[0]; - // make 2 domains per rank - std::vector domains{make_domain(ctxt.rank(), 0, coords), - make_domain(ctxt.rank(), 1, coords)}; - // neighbor lookup - domain_lu d_lu{dims}; - - auto staged_pattern = structured::regular::make_staged_pattern(ctxt, domains, d_lu, arr{0, 0}, - arr{dims[0] * DIM - 1, dims[1] * DIM - 1}, halos, periodic); - - // make halo generator - halo_gen gen{arr{0, 0}, arr{dims[0] * DIM - 1, dims[1] * DIM - 1}, halos, periodic}; - // create a pattern for communication - auto pattern = make_pattern(ctxt, gen, domains); - // run - bool res = true; - if (multi_threaded) + try + { + context ctxt(MPI_COMM_WORLD, multi_threaded); + // 2D domain decomposition + arr dims{0, 0}, coords{0, 0}; + MPI_Dims_create(ctxt.size(), 2, dims.data()); + coords[1] = ctxt.rank() / dims[0]; + coords[0] = ctxt.rank() - coords[1] * dims[0]; + // make 2 domains per rank + std::vector domains{make_domain(ctxt.rank(), 0, coords), + make_domain(ctxt.rank(), 1, coords)}; + // neighbor lookup + domain_lu d_lu{dims}; + + auto staged_pattern = structured::regular::make_staged_pattern(ctxt, domains, d_lu, + arr{0, 0}, arr{dims[0] * DIM - 1, dims[1] * DIM - 1}, halos, periodic); + + // make halo generator + halo_gen gen{arr{0, 0}, arr{dims[0] * DIM - 1, dims[1] * DIM - 1}, halos, periodic}; + // create a pattern for communication + auto pattern = make_pattern(ctxt, gen, domains); + // run + bool res = true; + if (multi_threaded) + { + auto run_fct = [&ctxt, &pattern, &staged_pattern, &domains, &dims](int id) + { return run(ctxt, pattern, staged_pattern, domains, dims, id); }; + auto f1 = std::async(std::launch::async, run_fct, 0); + auto f2 = std::async(std::launch::async, run_fct, 1); + res = res && f1.get(); + res = res && f2.get(); + } + else { res = res && run(ctxt, pattern, staged_pattern, domains, dims); } + // reduce res + bool all_res = false; + MPI_Reduce(&res, &all_res, 1, MPI_C_BOOL, MPI_LAND, 0, MPI_COMM_WORLD); + if (ctxt.rank() == 0) { EXPECT_TRUE(all_res); } + } + catch (std::runtime_error const& e) { - auto run_fct = [&ctxt, &pattern, &staged_pattern, &domains, &dims](int id) - { return run(ctxt, pattern, staged_pattern, domains, dims, id); }; - auto f1 = std::async(std::launch::async, run_fct, 0); - auto f2 = std::async(std::launch::async, run_fct, 1); - res = res && f1.get(); - res = res && f2.get(); + if (multi_threaded) ghex::test::handle_nccl_thread_safe_exception(MPI_COMM_WORLD, e); + else + ghex::test::handle_nccl_self_comm_exception(MPI_COMM_WORLD, e); } - else { res = res && run(ctxt, pattern, staged_pattern, domains, dims); } - // reduce res - bool all_res = false; - MPI_Reduce(&res, &all_res, 1, MPI_C_BOOL, MPI_LAND, 0, MPI_COMM_WORLD); - if (ctxt.rank() == 0) { EXPECT_TRUE(all_res); } } TEST_F(mpi_test_fixture, simple_exchange) { sim(thread_safe); } diff --git a/test/test_context.cpp b/test/test_context.cpp index 72c899b4..964a32fb 100644 --- a/test/test_context.cpp +++ b/test/test_context.cpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2026, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -11,6 +11,7 @@ #include #include #include "./mpi_runner/mpi_test_fixture.hpp" +#include "./util/nccl_test_helpers.hpp" #include #include #include @@ -19,7 +20,14 @@ TEST_F(mpi_test_fixture, context) { using namespace ghex; - context ctxt(world, thread_safe); + try + { + context ctxt(world, thread_safe); + } + catch (std::runtime_error const& e) + { + ghex::test::handle_nccl_thread_safe_exception(world, e); + } } #if OOMPH_ENABLE_BARRIER @@ -27,27 +35,34 @@ TEST_F(mpi_test_fixture, barrier) { using namespace ghex; - context ctxt(world, thread_safe); - - if (thread_safe) + try { - barrier b(ctxt, 1); - b.rank_barrier(); - } - else - { - barrier b(ctxt, 4); + context ctxt(world, thread_safe); + + if (thread_safe) + { + barrier b(ctxt, 1); + b.rank_barrier(); + } + else + { + barrier b(ctxt, 4); - auto use_barrier = [&]() { b(); }; + auto use_barrier = [&]() { b(); }; - auto use_thread_barrier = [&]() { b.thread_barrier(); }; + auto use_thread_barrier = [&]() { b.thread_barrier(); }; - std::vector threads; - for (int i = 0; i < 4; ++i) threads.push_back(std::thread{use_thread_barrier}); - for (int i = 0; i < 4; ++i) threads[i].join(); - threads.clear(); - for (int i = 0; i < 4; ++i) threads.push_back(std::thread{use_barrier}); - for (int i = 0; i < 4; ++i) threads[i].join(); + std::vector threads; + for (int i = 0; i < 4; ++i) threads.push_back(std::thread{use_thread_barrier}); + for (int i = 0; i < 4; ++i) threads[i].join(); + threads.clear(); + for (int i = 0; i < 4; ++i) threads.push_back(std::thread{use_barrier}); + for (int i = 0; i < 4; ++i) threads[i].join(); + } + } + catch (std::runtime_error const& e) + { + ghex::test::handle_nccl_thread_safe_exception(world, e); } } #endif diff --git a/test/unstructured/test_user_concepts.cpp b/test/unstructured/test_user_concepts.cpp index 938081a0..3da57a82 100644 --- a/test/unstructured/test_user_concepts.cpp +++ b/test/unstructured/test_user_concepts.cpp @@ -1,7 +1,7 @@ /* * ghex-org * - * Copyright (c) 2014-2023, ETH Zurich + * Copyright (c) 2014-2026, ETH Zurich * All rights reserved. * * Please, refer to the LICENSE file in the root directory. @@ -18,6 +18,7 @@ #include #include "./unstructured_test_case.hpp" #include "../util/memory.hpp" +#include "../util/nccl_test_helpers.hpp" #include #include @@ -47,50 +48,86 @@ void test_in_place_receive_threads(ghex::context& ctxt); TEST_F(mpi_test_fixture, domain_descriptor) { - ghex::context ctxt{MPI_COMM_WORLD, thread_safe}; + try + { + ghex::context ctxt{MPI_COMM_WORLD, thread_safe}; - if (world_size == 4) { test_domain_descriptor_and_halos(ctxt); } + if (world_size == 4) { test_domain_descriptor_and_halos(ctxt); } + } + catch (std::runtime_error const& e) + { + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); + } } TEST_F(mpi_test_fixture, pattern_setup) { - ghex::context ctxt{MPI_COMM_WORLD, thread_safe}; - if (world_size == 4) { test_pattern_setup(ctxt); } - else if (world_size == 2) + try { - test_pattern_setup_oversubscribe(ctxt); - test_pattern_setup_oversubscribe_asymm(ctxt); + ghex::context ctxt{MPI_COMM_WORLD, thread_safe}; + if (world_size == 4) { test_pattern_setup(ctxt); } + else if (world_size == 2) + { + test_pattern_setup_oversubscribe(ctxt); + test_pattern_setup_oversubscribe_asymm(ctxt); + } + } + catch (std::runtime_error const& e) + { + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); } } TEST_F(mpi_test_fixture, data_descriptor) { - ghex::context ctxt{MPI_COMM_WORLD, thread_safe}; - - if (world_size == 4) + try { - test_data_descriptor(ctxt, 1, true); - test_data_descriptor(ctxt, 3, true); - test_data_descriptor(ctxt, 1, false); - test_data_descriptor(ctxt, 3, false); + ghex::context ctxt{MPI_COMM_WORLD, thread_safe}; + + if (world_size == 4) + { + test_data_descriptor(ctxt, 1, true); + test_data_descriptor(ctxt, 3, true); + test_data_descriptor(ctxt, 1, false); + test_data_descriptor(ctxt, 3, false); + } + else if (world_size == 2) + { + test_data_descriptor_oversubscribe(ctxt); + if (thread_safe) test_data_descriptor_threads(ctxt); + } } - else if (world_size == 2) + catch (std::runtime_error const& e) { - test_data_descriptor_oversubscribe(ctxt); - if (thread_safe) test_data_descriptor_threads(ctxt); + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); } } TEST_F(mpi_test_fixture, data_descriptor_async) { - ghex::context ctxt{MPI_COMM_WORLD, thread_safe}; + try + { + ghex::context ctxt{MPI_COMM_WORLD, thread_safe}; - if (world_size == 4) + if (world_size == 4) + { + test_data_descriptor_async(ctxt, 1, true); + test_data_descriptor_async(ctxt, 3, true); + test_data_descriptor_async(ctxt, 1, false); + test_data_descriptor_async(ctxt, 3, false); + } + } + catch (std::runtime_error const& e) { - test_data_descriptor_async(ctxt, 1, true); - test_data_descriptor_async(ctxt, 3, true); - test_data_descriptor_async(ctxt, 1, false); - test_data_descriptor_async(ctxt, 3, false); + if (thread_safe) ghex::test::handle_nccl_thread_safe_exception(world, e); + else + ghex::test::handle_nccl_self_comm_exception(world, e); } } @@ -320,11 +357,12 @@ test_data_descriptor(ghex::context& ctxt, std::size_t levels, bool levels_first) /** @brief Test data descriptor concept*/ void -test_data_descriptor_async(ghex::context& ctxt, std::size_t levels, bool levels_first) +test_data_descriptor_async([[maybe_unused]] ghex::context& ctxt, + [[maybe_unused]] std::size_t levels, [[maybe_unused]] bool levels_first) { #ifdef GHEX_CUDACC // NOTE: Async exchange is only implemented for the GPU, however, we also - // test it for CPU memory, although it is kind of botherline. + // test it for CPU memory, although it is kind of borderline. // domain std::vector local_domains{make_domain(ctxt.rank())}; diff --git a/test/util/nccl_test_helpers.hpp b/test/util/nccl_test_helpers.hpp new file mode 100644 index 00000000..25b3387c --- /dev/null +++ b/test/util/nccl_test_helpers.hpp @@ -0,0 +1,48 @@ +/* + * ghex-org + * + * Copyright (c) 2014-2026, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include +#include +#include + +#include + +namespace ghex::test +{ +inline bool +is_nccl_backend(MPI_Comm world) +{ + return ghex::context(world, false).transport_context()->get_transport_option("name") == + std::string("nccl"); +} + +inline void +handle_nccl_thread_safe_exception(MPI_Comm world, std::runtime_error const& e) +{ + if (is_nccl_backend(world)) + { + EXPECT_STREQ(e.what(), "NCCL not supported with thread_safe = true"); + } + else { throw; } +} + +inline void +handle_nccl_self_comm_exception(MPI_Comm world, std::runtime_error const& e) +{ + if (is_nccl_backend(world)) + { + EXPECT_STREQ(e.what(), "oomph NCCL backend: self-send/recv requires an active NCCL group. " + "Use start_group()/end_group() around self-send/recv operations."); + } + else { throw; } +} +} // namespace ghex::test