Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/include/raft/matrix/detail/gather_inplace.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void gatherInplaceImpl(raft::resources const& handle,
transform_op,
batch_offset,
map_length,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
cols_per_batch = raft::util::FastIntDiv<IndexT>(cols_per_batch),
n,
ld] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
Expand All @@ -74,7 +74,7 @@ void gatherInplaceImpl(raft::resources const& handle,
scratch_space = scratch_space.data_handle(),
batch_offset,
map_length,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
cols_per_batch = raft::util::FastIntDiv<IndexT>(cols_per_batch),
n,
ld] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/matrix/detail/scatter_inplace.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void scatterInplaceImpl(
auto copy_op = [inout = inout.data_handle(),
map = map.data_handle(),
batch_offset,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
cols_per_batch = raft::util::FastIntDiv<IndexT>(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
Expand All @@ -92,7 +92,7 @@ void scatterInplaceImpl(
map = map.data_handle(),
scratch_space = scratch_space.data_handle(),
batch_offset,
cols_per_batch = raft::util::FastIntDiv(cols_per_batch),
cols_per_batch = raft::util::FastIntDiv<IndexT>(cols_per_batch),
n] __device__(auto idx) {
IndexT row = idx / cols_per_batch;
IndexT col = idx % cols_per_batch;
Expand Down
31 changes: 21 additions & 10 deletions cpp/include/raft/util/fast_int_div.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,34 @@

#include <stdint.h>

#include <type_traits>

namespace raft {
namespace util {

/**
* @brief Perform fast integer division and modulo using a known divisor
* From Hacker's Delight, Second Edition, Chapter 10
*
* @note This currently only supports 32b signed integers
* @note 32b signed integers are supported.
* @note 64b signed integers are supported for input data up to 2^31,
* because the GPU-non-native int128 is avoided for performance.
* @todo Extend support for signed divisors
*/
template <typename IntT>
struct FastIntDiv {
static_assert(std::is_same_v<IntT, int32_t> || std::is_same_v<IntT, int64_t>,
"FastIntDiv: IntT must be int32_t or int64_t");
using UIntT = std::make_unsigned_t<IntT>;

/**
* @defgroup HostMethods Ctor's that are accessible only from host
* @{
* @brief Host-only ctor's
* @param _d the divisor
*/
FastIntDiv(int _d) : d(_d) { computeScalars(); }
FastIntDiv& operator=(int _d)
FastIntDiv(IntT _d) : d(_d) { computeScalars(); }
FastIntDiv& operator=(IntT _d)
{
d = _d;
computeScalars();
Expand All @@ -53,9 +62,9 @@ struct FastIntDiv {
/** @} */

/** divisor */
int d;
IntT d;
/** the term 'm' as found in the reference chapter */
unsigned m;
UIntT m;
/** the term 'p' as found in the reference chapter */
int p;

Expand Down Expand Up @@ -90,10 +99,11 @@ struct FastIntDiv {
* @param divisor the denominator
* @return the quotient
*/
HDI int operator/(int n, const FastIntDiv& divisor)
template <typename IntT>
HDI IntT operator/(IntT n, const FastIntDiv<IntT>& divisor)
{
if (divisor.d == 1) return n;
int ret = (int64_t(divisor.m) * int64_t(n)) >> divisor.p;
IntT ret = (int64_t(divisor.m) * int64_t(n)) >> divisor.p;
if (n < 0) ++ret;
return ret;
}
Expand All @@ -105,10 +115,11 @@ HDI int operator/(int n, const FastIntDiv& divisor)
* @param divisor the denominator
* @return the remainder
*/
HDI int operator%(int n, const FastIntDiv& divisor)
template <typename IntT>
HDI IntT operator%(IntT n, const FastIntDiv<IntT>& divisor)
{
int quotient = n / divisor;
int remainder = n - quotient * divisor.d;
IntT quotient = n / divisor;
IntT remainder = n - quotient * divisor.d;
return remainder;
}

Expand Down
1 change: 1 addition & 0 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ if(BUILD_TESTS)
util/bitonic_sort.cu
util/cudart_utils.cpp
util/device_atomics.cu
util/fast_int_div.cu
util/integer_utils.cpp
util/integer_utils.cu
util/memory_type_dispatcher.cu
Expand Down
54 changes: 54 additions & 0 deletions cpp/tests/util/fast_int_div.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <raft/util/fast_int_div.cuh>

#include <gtest/gtest.h>

#include <cstdint>
#include <limits>
#include <vector>

namespace raft::util {

// Largest int32_t value, held as int64_t; used as the extreme magnitude and divisor in the tests.
constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();

/**
 * @brief Typed fixture that validates FastIntDiv against the built-in integer operators.
 *
 * @tparam IntT integer type the suite is instantiated with
 */
template <typename IntT>
class FastIntDivTest : public ::testing::Test {
 protected:
  /**
   * @brief Compare `n / fid` and `n % fid` with the native `n / d` and `n % d`.
   *
   * Every divisor is paired with every magnitude, and each magnitude is tried
   * with both signs. ASSERT_EQ is fatal, so the first mismatch ends the helper.
   */
  void CompareWithNativeDivision()
  {
    const std::vector<IntT> divisor_values{1, 2, 4, 7, 16, 31, 63, 128, 1000, (1 << 15), kInt32Max};
    const std::vector<IntT> magnitude_values{0, 1, 2, 3, 7, 13, 255, 12345, (1 << 20), kInt32Max};

    for (const IntT divisor : divisor_values) {
      const FastIntDiv<IntT> fast_divisor(divisor);
      for (const IntT magnitude : magnitude_values) {
        // exercise the non-negative and the negated numerator
        for (const IntT numerator : {magnitude, -magnitude}) {
          ASSERT_EQ(numerator / fast_divisor, numerator / divisor)
            << "operator/ mismatch for numerator=" << numerator << " divisor=" << divisor;
          ASSERT_EQ(numerator % fast_divisor, numerator % divisor)
            << "operator% mismatch for numerator=" << numerator << " divisor=" << divisor;
        }
      }
    }
  }
};

// Run the typed fixture for 32-bit and 64-bit signed integers.
using FastIntDivTypes = ::testing::Types<int32_t, int64_t>;
// TYPED_TEST_SUITE is the current GoogleTest spelling; TYPED_TEST_CASE is deprecated.
TYPED_TEST_SUITE(FastIntDivTest, FastIntDivTypes);

// Delegates to the fixture helper so the comparison runs once per type in FastIntDivTypes.
TYPED_TEST(FastIntDivTest, CompareWithNativeDivision) { this->CompareWithNativeDivision(); }

/**
 * Checks the int64_t instantiation of FastIntDiv with numerators that no longer
 * fit in int32_t, comparing quotient and remainder with native division.
 */
TEST(FastIntDiv, Int64NumeratorPastInt32Boundary)
{
  for (const int64_t divisor : {129, 772, 1000}) {
    const FastIntDiv<int64_t> fast_divisor(divisor);
    // all numerators lie in [2^31, 2^32)
    for (const int64_t numerator : {1LL << 31, 2147704000LL, 3LL << 30}) {
      ASSERT_EQ(numerator / fast_divisor, numerator / divisor)
        << "operator/ mismatch for numerator=" << numerator << " divisor=" << divisor;
      ASSERT_EQ(numerator % fast_divisor, numerator % divisor)
        << "operator% mismatch for numerator=" << numerator << " divisor=" << divisor;
    }
  }
}

} // namespace raft::util
Loading