diff --git a/cpp/include/raft/matrix/detail/gather_inplace.cuh b/cpp/include/raft/matrix/detail/gather_inplace.cuh index 1cfd7664ec..a30a3430c9 100644 --- a/cpp/include/raft/matrix/detail/gather_inplace.cuh +++ b/cpp/include/raft/matrix/detail/gather_inplace.cuh @@ -54,7 +54,7 @@ void gatherInplaceImpl(raft::resources const& handle, transform_op, batch_offset, map_length, - cols_per_batch = raft::util::FastIntDiv(cols_per_batch), + cols_per_batch = raft::util::FastIntDiv(cols_per_batch), n, ld] __device__(auto idx) { IndexT row = idx / cols_per_batch; @@ -74,7 +74,7 @@ void gatherInplaceImpl(raft::resources const& handle, scratch_space = scratch_space.data_handle(), batch_offset, map_length, - cols_per_batch = raft::util::FastIntDiv(cols_per_batch), + cols_per_batch = raft::util::FastIntDiv(cols_per_batch), n, ld] __device__(auto idx) { IndexT row = idx / cols_per_batch; diff --git a/cpp/include/raft/matrix/detail/scatter_inplace.cuh b/cpp/include/raft/matrix/detail/scatter_inplace.cuh index 2c735e3fda..392928143a 100644 --- a/cpp/include/raft/matrix/detail/scatter_inplace.cuh +++ b/cpp/include/raft/matrix/detail/scatter_inplace.cuh @@ -77,7 +77,7 @@ void scatterInplaceImpl( auto copy_op = [inout = inout.data_handle(), map = map.data_handle(), batch_offset, - cols_per_batch = raft::util::FastIntDiv(cols_per_batch), + cols_per_batch = raft::util::FastIntDiv(cols_per_batch), n] __device__(auto idx) { IndexT row = idx / cols_per_batch; IndexT col = idx % cols_per_batch; @@ -92,7 +92,7 @@ void scatterInplaceImpl( map = map.data_handle(), scratch_space = scratch_space.data_handle(), batch_offset, - cols_per_batch = raft::util::FastIntDiv(cols_per_batch), + cols_per_batch = raft::util::FastIntDiv(cols_per_batch), n] __device__(auto idx) { IndexT row = idx / cols_per_batch; IndexT col = idx % cols_per_batch; diff --git a/cpp/include/raft/util/fast_int_div.cuh b/cpp/include/raft/util/fast_int_div.cuh index 527fd46e14..d7f385a5d4 100644 --- 
a/cpp/include/raft/util/fast_int_div.cuh +++ b/cpp/include/raft/util/fast_int_div.cuh @@ -10,6 +10,8 @@ #include +#include + namespace raft { namespace util { @@ -17,18 +19,25 @@ namespace util { * @brief Perform fast integer division and modulo using a known divisor * From Hacker's Delight, Second Edition, Chapter 10 * - * @note This currently only supports 32b signed integers + * @note 32b signed integers are supported. + * @note 64b signed integers are supported for input values up to 2^31 + * because GPU-non-native int128 arithmetic is avoided for performance. * @todo Extend support for signed divisors */ +template struct FastIntDiv { + static_assert(std::is_same_v || std::is_same_v, + "FastIntDiv: IntT must be int32_t or int64_t"); + using UIntT = std::make_unsigned_t; + /** * @defgroup HostMethods Ctor's that are accessible only from host * @{ * @brief Host-only ctor's * @param _d the divisor */ - FastIntDiv(int _d) : d(_d) { computeScalars(); } - FastIntDiv& operator=(int _d) + FastIntDiv(IntT _d) : d(_d) { computeScalars(); } + FastIntDiv& operator=(IntT _d) { d = _d; computeScalars(); @@ -53,9 +62,9 @@ struct FastIntDiv { /** @} */ /** divisor */ - int d; + IntT d; /** the term 'm' as found in the reference chapter */ - unsigned m; + UIntT m; /** the term 'p' as found in the reference chapter */ int p; @@ -90,10 +99,11 @@ struct FastIntDiv { * @param divisor the denominator * @return the quotient */ -HDI int operator/(int n, const FastIntDiv& divisor) +template +HDI IntT operator/(IntT n, const FastIntDiv& divisor) { if (divisor.d == 1) return n; - int ret = (int64_t(divisor.m) * int64_t(n)) >> divisor.p; + IntT ret = (int64_t(divisor.m) * int64_t(n)) >> divisor.p; if (n < 0) ++ret; return ret; } @@ -105,10 +115,11 @@ HDI int operator/(int n, const FastIntDiv& divisor) * @param divisor the denominator * @return the remainder */ -HDI int operator%(int n, const FastIntDiv& divisor) +template +HDI IntT operator%(IntT n, const FastIntDiv& divisor) { - int quotient = n / 
divisor; - int remainder = n - quotient * divisor.d; + IntT quotient = n / divisor; + IntT remainder = n - quotient * divisor.d; return remainder; } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 00bfe1f32a..7127312b7b 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -317,6 +317,7 @@ if(BUILD_TESTS) util/bitonic_sort.cu util/cudart_utils.cpp util/device_atomics.cu + util/fast_int_div.cu util/integer_utils.cpp util/integer_utils.cu util/memory_type_dispatcher.cu diff --git a/cpp/tests/util/fast_int_div.cu b/cpp/tests/util/fast_int_div.cu new file mode 100644 index 0000000000..eb391c1dbb --- /dev/null +++ b/cpp/tests/util/fast_int_div.cu @@ -0,0 +1,54 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +#include +#include +#include + +namespace raft::util { + +constexpr int64_t kInt32Max = std::numeric_limits::max(); + +template +class FastIntDivTest : public ::testing::Test { + protected: + void CompareWithNativeDivision() + { + std::vector magnitudes{0, 1, 2, 3, 7, 13, 255, 12345, (1 << 20), kInt32Max}; + std::vector divisors{1, 2, 4, 7, 16, 31, 63, 128, 1000, (1 << 15), kInt32Max}; + + for (IntT d : divisors) { + FastIntDiv fid(d); + for (IntT mag : magnitudes) { + for (IntT n : {mag, -mag}) { + ASSERT_EQ(n / fid, n / d) << "operator/ mismatch for numerator=" << n << " divisor=" << d; + ASSERT_EQ(n % fid, n % d) << "operator% mismatch for numerator=" << n << " divisor=" << d; + } + } + } + } +}; + +using FastIntDivTypes = ::testing::Types; +TYPED_TEST_CASE(FastIntDivTest, FastIntDivTypes); + +TYPED_TEST(FastIntDivTest, CompareWithNativeDivision) { this->CompareWithNativeDivision(); } + +TEST(FastIntDiv, Int64NumeratorPastInt32Boundary) +{ + for (int64_t d : {129, 772, 1000}) { + FastIntDiv fid(d); + for (int64_t n : {1LL << 31, 2147704000LL, 3LL << 30}) { // in [2^31, 2^32) + ASSERT_EQ(n / fid, n / d) << "operator/ 
mismatch for numerator=" << n << " divisor=" << d; + ASSERT_EQ(n % fid, n % d) << "operator% mismatch for numerator=" << n << " divisor=" << d; + } + } +} + +} // namespace raft::util