Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/include/raft/linalg/strided_reduction.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void stridedReduction(OutType* dots,
// other cases, because coalescedReduction supports arbitrary types.
if constexpr (std::is_same_v<OutType, float> || std::is_same_v<OutType, double> ||
std::is_same_v<OutType, int> || std::is_same_v<OutType, long long> ||
std::is_same_v<OutType, unsigned long long>) {
std::is_same_v<OutType, unsigned long long> || std::is_same_v<OutType, half>) {
detail::stridedReduction<InType, OutType, IdxType>(
dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
} else {
Expand Down
35 changes: 22 additions & 13 deletions cpp/include/raft/stats/detail/mean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,33 @@
#include <raft/linalg/reduce.cuh>
#include <raft/util/cuda_utils.cuh>

#include <type_traits>

namespace raft {
namespace stats {
namespace detail {

template <bool rowMajor, typename Type, typename IdxType = int>
void mean(Type* mu, const Type* data, IdxType D, IdxType N, cudaStream_t stream)
template <bool rowMajor, typename OutType, typename InType, typename IdxType = int>
void mean(OutType* mu, const InType* data, IdxType D, IdxType N, cudaStream_t stream)
{
Type ratio = Type(1) / Type(N);
raft::linalg::reduce<rowMajor, false>(mu,
data,
D,
N,
Type(0),
stream,
false,
raft::identity_op(),
raft::add_op(),
raft::mul_const_op<Type>(ratio));
OutType ratio = OutType(1) / OutType(N);
auto main_op = [=]() {
if constexpr (std::is_same_v<InType, OutType>) {
return raft::identity_op();
} else {
return raft::cast_op<OutType>();
}
}();
raft::linalg::reduce<rowMajor, false, InType, OutType>(mu,
data,
D,
N,
OutType(0),
stream,
false,
main_op,
raft::add_op(),
raft::mul_const_op<OutType>(ratio));
}

template <bool rowMajor, typename Type, typename IdxType = int>
Expand Down
37 changes: 30 additions & 7 deletions cpp/include/raft/stats/mean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <raft/core/resources.hpp>
#include <raft/stats/detail/mean.cuh>

#include <type_traits>

namespace raft {
namespace stats {

Expand Down Expand Up @@ -68,21 +70,22 @@ template <bool rowMajor, typename Type, typename IdxType = int>
*/

/**
* @brief Compute mean of the input matrix
* @brief Compute mean of the input matrix with different input and output data types.
*
* Mean operation is assumed to be performed on a given column.
*
* @tparam value_t the data type
* @tparam in_value_t the input data type
* @tparam out_value_t the output data type
* @tparam idx_t index type
* @tparam layout_t Layout type of the input matrix.
* @param[in] handle the raft handle
* @param[in] data: the input matrix
* @param[out] mu: the output mean vector
* @param[in] data the input matrix
* @param[out] mu the output mean vector
*/
template <typename value_t, typename idx_t, typename layout_t>
template <typename in_value_t, typename out_value_t, typename idx_t, typename layout_t>
void mean(raft::resources const& handle,
raft::device_matrix_view<const value_t, idx_t, layout_t> data,
raft::device_vector_view<value_t, idx_t> mu)
raft::device_matrix_view<const in_value_t, idx_t, layout_t> data,
raft::device_vector_view<out_value_t, idx_t> mu)
{
static_assert(
std::is_same_v<layout_t, raft::row_major> || std::is_same_v<layout_t, raft::col_major>,
Expand All @@ -97,6 +100,26 @@ void mean(raft::resources const& handle,
resource::get_cuda_stream(handle));
}

/**
* @brief Compute mean of the input matrix
*
* Mean operation is assumed to be performed on a given column.
*
* @tparam value_t the data type
* @tparam idx_t index type
* @tparam layout_t Layout type of the input matrix.
* @param[in] handle the raft handle
* @param[in] data: the input matrix
* @param[out] mu: the output mean vector
*/
template <typename value_t, typename idx_t, typename layout_t>
void mean(raft::resources const& handle,
raft::device_matrix_view<const value_t, idx_t, layout_t> data,
raft::device_vector_view<value_t, idx_t> mu)
{
mean<value_t, value_t, idx_t, layout_t>(handle, data, mu);
}

/**
* @brief Compute mean of the input matrix
*
Expand Down
125 changes: 101 additions & 24 deletions cpp/tests/stats/mean.cu
Original file line number Diff line number Diff line change
@@ -1,49 +1,74 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include "../test_utils.cuh"

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/random/rng.cuh>
#include <raft/stats/mean.cuh>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <cuda_fp16.h>

#include <gtest/gtest.h>
#include <stdio.h>
#include <stdlib.h>

#include <cstdint>
#include <type_traits>
#include <vector>

namespace raft {
namespace stats {

template <typename T>
float toFloat(T value)
{
return static_cast<float>(value);
}

template <>
inline float toFloat<half>(half value)
{
return __half2float(value);
}

struct float_to_half_op {
__device__ half operator()(float x) const { return __float2half(x); }
};

template <typename InputT, typename OutputT = InputT>
struct MeanInputs {
T tolerance, mean;
OutputT tolerance;
InputT mean;
int rows, cols;
bool rowMajor;
unsigned long long int seed;
T stddev = (T)1.0;
InputT stddev = (InputT)1.0;
};

template <typename T>
::std::ostream& operator<<(::std::ostream& os, const MeanInputs<T>& dims)
template <typename InputT, typename OutputT>
::std::ostream& operator<<(::std::ostream& os, const MeanInputs<InputT, OutputT>& dims)
{
return os << "{ " << dims.tolerance << ", " << dims.rows << ", " << dims.cols << ", "
<< ", " << dims.rowMajor << ", " << dims.stddev << "}" << std::endl;
return os << "{ tol=" << toFloat(dims.tolerance) << ", mean=" << toFloat(dims.mean)
<< ", rows=" << dims.rows << ", cols=" << dims.cols << ", rowMajor=" << dims.rowMajor
<< ", stddev=" << toFloat(dims.stddev) << "}" << std::endl;
}

template <typename T>
class MeanTest : public ::testing::TestWithParam<MeanInputs<T>> {
template <typename InputT, typename OutputT = InputT>
class MeanTest : public ::testing::TestWithParam<MeanInputs<InputT, OutputT>> {
public:
MeanTest()
: params(::testing::TestWithParam<MeanInputs<T>>::GetParam()),
: params(::testing::TestWithParam<MeanInputs<InputT, OutputT>>::GetParam()),
stream(resource::get_cuda_stream(handle)),
rows(params.rows),
cols(params.cols),
data(rows * cols, stream),
mean_act(cols, stream)
data(raft::make_device_matrix<InputT, int>(handle, rows, cols)),
mean_act(raft::make_device_vector<OutputT, int>(handle, cols))
{
}

Expand All @@ -52,33 +77,42 @@ class MeanTest : public ::testing::TestWithParam<MeanInputs<T>> {
{
raft::random::RngState r(params.seed);
int len = rows * cols;
normal(handle, r, data.data(), len, params.mean, params.stddev);
meanSGtest(data.data(), stream);
if constexpr (std::is_same_v<InputT, half>) {
rmm::device_uvector<float> data_float(len, stream);
normal(handle, r, data_float.data(), len, toFloat(params.mean), toFloat(params.stddev));
raft::linalg::unaryOp(data.data_handle(), data_float.data(), len, float_to_half_op{}, stream);
} else if constexpr (std::is_integral_v<InputT>) {
normalInt(handle, r, data.data_handle(), len, params.mean, params.stddev);
} else {
normal(handle, r, data.data_handle(), len, params.mean, params.stddev);
}
meanSGtest();
}

void meanSGtest(T* data, cudaStream_t stream)
void meanSGtest()
{
int rows = params.rows, cols = params.cols;
if (params.rowMajor) {
using layout = raft::row_major;
mean(handle,
raft::make_device_matrix_view<const T, int, layout>(data, rows, cols),
raft::make_device_vector_view<T, int>(mean_act.data(), cols));
raft::make_device_matrix_view<const InputT, int, layout>(data.data_handle(), rows, cols),
raft::make_device_vector_view<OutputT, int>(mean_act.data_handle(), cols));
} else {
using layout = raft::col_major;
mean(handle,
raft::make_device_matrix_view<const T, int, layout>(data, rows, cols),
raft::make_device_vector_view<T, int>(mean_act.data(), cols));
raft::make_device_matrix_view<const InputT, int, layout>(data.data_handle(), rows, cols),
raft::make_device_vector_view<OutputT, int>(mean_act.data_handle(), cols));
}
}

protected:
raft::resources handle;
cudaStream_t stream;

MeanInputs<T> params;
MeanInputs<InputT, OutputT> params;
int rows, cols;
rmm::device_uvector<T> data, mean_act;
raft::device_matrix<InputT, int> data;
raft::device_vector<OutputT, int> mean_act;
};

// Note: For 1024 samples, 256 experiments, a mean of 1.0 with stddev=1.0, the
Expand Down Expand Up @@ -131,23 +165,66 @@ const std::vector<MeanInputs<double>> inputsd = {{0.15, -1.0, 1024, 32, false, 1
{1e-8, 1e-1, 1 << 27, 2, false, 1234ULL, 0.0001},
{1e-8, 1e-1, 1 << 27, 2, true, 1234ULL, 0.0001}};

const std::vector<MeanInputs<half, float>> inputshf = {
{0.15f, -1.f, 1024, 32, false, 1234ULL},
{0.15f, -1.f, 1024, 64, false, 1234ULL},
{0.15f, -1.f, 1024, 128, false, 1234ULL},
{0.15f, -1.f, 1024, 256, false, 1234ULL},
{0.15f, -1.f, 1024, 32, true, 1234ULL},
{0.15f, -1.f, 1024, 64, true, 1234ULL},
{0.0001f, 0.1f, 1 << 27, 2, false, 1234ULL, 0.0001f}};

const std::vector<MeanInputs<int8_t, half>> inputsi8h = {{0.95f, -5, 8096, 32, false, 1234ULL, 1},
{0.5f, 1, 8096, 10, false, 1234ULL, 10},
{0.15f, 0, 60000, 128, false, 1234ULL, 6},
{0.5f, -1, 8096, 256, false, 1234ULL, 2},
{1.0f, 8, 2000, 32, true, 1234ULL, 1},
{0.50f, -1, 20000, 64, true, 1234ULL, 5},
{1.0f, 6, 10024, 2, false, 1234ULL, 10}};

typedef MeanTest<float> MeanTestF;
TEST_P(MeanTestF, Result)
{
ASSERT_TRUE(
devArrMatch(params.mean, mean_act.data(), params.cols, CompareApprox<float>(params.tolerance)));
ASSERT_TRUE(devArrMatch(
params.mean, mean_act.data_handle(), params.cols, CompareApprox<float>(params.tolerance)));
}

typedef MeanTest<double> MeanTestD;
TEST_P(MeanTestD, Result)
{
ASSERT_TRUE(devArrMatch(
params.mean, mean_act.data(), params.cols, CompareApprox<double>(params.tolerance)));
params.mean, mean_act.data_handle(), params.cols, CompareApprox<double>(params.tolerance)));
}

INSTANTIATE_TEST_SUITE_P(MeanTests, MeanTestF, ::testing::ValuesIn(inputsf));

INSTANTIATE_TEST_SUITE_P(MeanTests, MeanTestD, ::testing::ValuesIn(inputsd));

typedef MeanTest<half, float> MeanTestHF;
TEST_P(MeanTestHF, Result)
{
ASSERT_TRUE(devArrMatch(toFloat(params.mean),
mean_act.data_handle(),
params.cols,
CompareApprox<float>(params.tolerance)));
}

typedef MeanTest<int8_t, half> MeanTestI8H;
TEST_P(MeanTestI8H, Result)
{
std::vector<half> mean_act_h(params.cols);
raft::update_host(mean_act_h.data(), mean_act.data_handle(), params.cols, stream);
raft::resource::sync_stream(handle);

auto expected = toFloat(params.mean);
auto tolerance = toFloat(params.tolerance);
for (int i = 0; i < params.cols; ++i) {
ASSERT_NEAR(toFloat(mean_act_h[i]), expected, tolerance) << " @col=" << i;
}
}

INSTANTIATE_TEST_SUITE_P(MeanTests, MeanTestHF, ::testing::ValuesIn(inputshf));

INSTANTIATE_TEST_SUITE_P(MeanTests, MeanTestI8H, ::testing::ValuesIn(inputsi8h));
} // end namespace stats
} // end namespace raft
Loading