Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 109 additions & 63 deletions cpp/include/raft/linalg/detail/pca.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,22 @@
namespace raft {
namespace linalg::detail {

template <typename math_t, typename idx_t>
template <typename math_t, typename idx_t, typename LayoutPolicy>
void trunc_comp_exp_vars(raft::resources const& handle,
const paramsTSVD& prms,
raft::device_matrix_view<math_t, idx_t, raft::col_major> in,
raft::device_matrix_view<math_t, idx_t, raft::col_major> components,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> in,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> components,
raft::device_vector_view<math_t, idx_t> explained_var,
raft::device_vector_view<math_t, idx_t> explained_var_ratio,
raft::device_scalar_view<math_t, idx_t> noise_vars,
std::size_t n_rows)
{
static_assert(
std::is_same_v<LayoutPolicy, raft::row_major> || std::is_same_v<LayoutPolicy, raft::col_major>,
"trunc_comp_exp_vars: layout must be raft::row_major or raft::col_major");

constexpr bool is_row_major = std::is_same_v<LayoutPolicy, raft::row_major>;

auto stream = resource::get_cuda_stream(handle);

auto n_cols = in.extent(0);
Expand All @@ -49,19 +55,26 @@ void trunc_comp_exp_vars(raft::resources const& handle,
rmm::device_uvector<math_t> explained_var_all(static_cast<std::size_t>(n_cols), stream);
rmm::device_uvector<math_t> explained_var_ratio_all(static_cast<std::size_t>(n_cols), stream);

detail::cal_eig<math_t, idx_t>(
detail::cal_eig<math_t, idx_t, LayoutPolicy>(
handle,
prms,
in,
raft::make_device_matrix_view<math_t, idx_t, raft::col_major>(
raft::make_device_matrix_view<math_t, idx_t, LayoutPolicy>(
components_all.data(), n_cols, n_cols),
raft::make_device_vector_view<math_t, idx_t>(explained_var_all.data(), n_cols));
raft::matrix::trunc_zero_origin(
handle,
raft::make_device_matrix_view<const math_t, idx_t, raft::col_major>(
components_all.data(), n_cols, n_cols),
raft::make_device_matrix_view<math_t, idx_t, raft::col_major>(
components.data_handle(), n_components, n_cols));
if constexpr (is_row_major) {
raft::copy(components.data_handle(),
components_all.data(),
static_cast<std::size_t>(n_components) * static_cast<std::size_t>(n_cols),
stream);
} else {
raft::matrix::trunc_zero_origin(
handle,
raft::make_device_matrix_view<const math_t, idx_t, raft::col_major>(
components_all.data(), n_cols, n_cols),
raft::make_device_matrix_view<math_t, idx_t, raft::col_major>(
components.data_handle(), n_components, n_cols));
}
raft::matrix::ratio(handle,
raft::make_device_matrix_view<const math_t, idx_t, raft::col_major>(
explained_var_all.data(), n_cols, idx_t(1)),
Expand Down Expand Up @@ -98,29 +111,37 @@ void trunc_comp_exp_vars(raft::resources const& handle,

/**
* @brief perform fit operation for PCA.
* @tparam math_t element type
* @tparam idx_t index type
* @tparam LayoutPolicy layout of the input matrix (raft::row_major or raft::col_major)
* @param[in] handle: raft::resources
* @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.)
* @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols (col-major).
* @param[out] components: the principal components. Size n_components x n_cols (col-major).
* @param[inout] input: the data to fit the PCA model to. Size n_rows x n_cols.
* @param[out] components: the principal components. Size n_components x n_cols.
* @param[out] explained_var: explained variances. Size n_components.
* @param[out] explained_var_ratio: ratio of explained to total variance. Size n_components.
* @param[out] singular_vals: singular values. Size n_components.
* @param[out] mu: mean of all features. Size n_cols.
* @param[out] noise_vars: noise variance scalar.
* @param[in] flip_signs_based_on_U: whether to determine signs by U (true) or V.T (false)
*/
template <typename math_t, typename idx_t>
template <typename math_t, typename idx_t, typename LayoutPolicy>
void pca_fit(raft::resources const& handle,
const paramsPCA& prms,
raft::device_matrix_view<math_t, idx_t, raft::col_major> input,
raft::device_matrix_view<math_t, idx_t, raft::col_major> components,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> input,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> components,
raft::device_vector_view<math_t, idx_t> explained_var,
raft::device_vector_view<math_t, idx_t> explained_var_ratio,
raft::device_vector_view<math_t, idx_t> singular_vals,
raft::device_vector_view<math_t, idx_t> mu,
raft::device_scalar_view<math_t, idx_t> noise_vars,
bool flip_signs_based_on_U = false)
{
static_assert(
std::is_same_v<LayoutPolicy, raft::row_major> || std::is_same_v<LayoutPolicy, raft::col_major>,
"pca_fit: input layout must be raft::row_major or raft::col_major");
constexpr bool input_row_major = std::is_same_v<LayoutPolicy, raft::row_major>;

auto stream = resource::get_cuda_stream(handle);
auto cublas_handle = raft::resource::get_cublas_handle(handle);

Expand All @@ -134,18 +155,19 @@ void pca_fit(raft::resources const& handle,
ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one");
ASSERT(n_components <= n_cols, "n_components cannot exceed n_cols");

raft::stats::mean<false>(mu.data_handle(), input.data_handle(), n_cols, n_rows, false, stream);
raft::stats::mean<input_row_major>(
mu.data_handle(), input.data_handle(), n_cols, n_rows, false, stream);

auto len = static_cast<std::size_t>(n_cols * n_cols);
rmm::device_uvector<math_t> cov(len, stream);

raft::stats::cov<false>(
raft::stats::cov<input_row_major>(
handle, cov.data(), input.data_handle(), mu.data_handle(), n_cols, n_rows, true, true, stream);

detail::trunc_comp_exp_vars(
detail::trunc_comp_exp_vars<math_t, idx_t, LayoutPolicy>(
handle,
prms,
raft::make_device_matrix_view<math_t, idx_t, raft::col_major>(cov.data(), n_cols, n_cols),
raft::make_device_matrix_view<math_t, idx_t, LayoutPolicy>(cov.data(), n_cols, n_cols),
components,
explained_var,
explained_var_ratio,
Expand All @@ -161,31 +183,39 @@ void pca_fit(raft::resources const& handle,
raft::make_host_scalar_view(&scalar),
true);

raft::stats::meanAdd<false, true>(
raft::stats::meanAdd<input_row_major, true>(
input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream);

detail::sign_flip_components(handle, input, components, true, flip_signs_based_on_U);
}

/**
* @brief performs transform operation for PCA. Transforms the data to eigenspace.
* @tparam math_t element type
* @tparam idx_t index type
* @tparam LayoutPolicy layout (raft::row_major or raft::col_major)
* @param[in] handle: raft::resources
* @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.)
* @param[inout] input: the data to transform. Size n_rows x n_cols (col-major).
* @param[in] components: principal components. Size n_components x n_cols (col-major).
* @param[inout] input: the data to transform. Size n_rows x n_cols.
* @param[in] components: principal components. Size n_components x n_cols.
* @param[in] singular_vals: singular values. Size n_components.
* @param[in] mu: mean of features. Size n_cols.
* @param[out] trans_input: the transformed data. Size n_rows x n_components (col-major).
* @param[out] trans_input: the transformed data. Size n_rows x n_components.
*/
template <typename math_t, typename idx_t>
template <typename math_t, typename idx_t, typename LayoutPolicy>
void pca_transform(raft::resources const& handle,
const paramsPCA& prms,
raft::device_matrix_view<math_t, idx_t, raft::col_major> input,
raft::device_matrix_view<math_t, idx_t, raft::col_major> components,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> input,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> components,
raft::device_vector_view<math_t, idx_t> singular_vals,
raft::device_vector_view<math_t, idx_t> mu,
raft::device_matrix_view<math_t, idx_t, raft::col_major> trans_input)
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> trans_input)
{
static_assert(
std::is_same_v<LayoutPolicy, raft::row_major> || std::is_same_v<LayoutPolicy, raft::col_major>,
"pca_transform: layout must be raft::row_major or raft::col_major");
constexpr bool input_row_major = std::is_same_v<LayoutPolicy, raft::row_major>;

auto stream = resource::get_cuda_stream(handle);

auto n_rows = input.extent(0);
Expand All @@ -204,45 +234,56 @@ void pca_transform(raft::resources const& handle,
math_t scalar = math_t(sqrt(n_rows - 1));
raft::linalg::scalarMultiply(
components_copy.data(), components_copy.data(), scalar, components_len, stream);
raft::linalg::binary_div_skip_zero<raft::Apply::ALONG_ROWS>(
raft::linalg::binary_div_skip_zero<raft::Apply::ALONG_COLUMNS>(
handle,
raft::make_device_matrix_view<math_t, idx_t, raft::row_major>(
components_copy.data(), n_cols, n_components),
raft::make_device_matrix_view<math_t, idx_t, LayoutPolicy>(
components_copy.data(), n_components, n_cols),
raft::make_device_vector_view<const math_t, idx_t>(singular_vals.data_handle(),
n_components));
}

raft::stats::meanCenter<false, true>(
raft::stats::meanCenter<input_row_major, true>(
input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream);
detail::tsvd_transform(handle,
prms,
input,
raft::make_device_matrix_view<math_t, idx_t, raft::col_major>(
components_copy.data(), n_components, n_cols),
trans_input);
raft::stats::meanAdd<false, true>(

detail::tsvd_transform<math_t, idx_t, LayoutPolicy>(
handle,
prms,
input,
raft::make_device_matrix_view<math_t, idx_t, LayoutPolicy>(
components_copy.data(), n_components, n_cols),
trans_input);

raft::stats::meanAdd<input_row_major, true>(
input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream);
}

/**
* @brief performs inverse transform operation for PCA.
* @tparam math_t element type
* @tparam idx_t index type
* @tparam LayoutPolicy layout (raft::row_major or raft::col_major)
* @param[in] handle: raft::resources
* @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.)
* @param[in] trans_input: the transformed data. Size n_rows x n_components (col-major).
* @param[in] components: principal components. Size n_components x n_cols (col-major).
* @param[in] trans_input: the transformed data. Size n_rows x n_components.
* @param[in] components: principal components. Size n_components x n_cols.
* @param[in] singular_vals: singular values. Size n_components.
* @param[in] mu: mean of features. Size n_cols.
* @param[out] output: the reconstructed data. Size n_rows x n_cols (col-major).
* @param[out] output: the reconstructed data. Size n_rows x n_cols.
*/
template <typename math_t, typename idx_t>
template <typename math_t, typename idx_t, typename LayoutPolicy>
void pca_inverse_transform(raft::resources const& handle,
const paramsPCA& prms,
raft::device_matrix_view<math_t, idx_t, raft::col_major> trans_input,
raft::device_matrix_view<math_t, idx_t, raft::col_major> components,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> trans_input,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> components,
raft::device_vector_view<math_t, idx_t> singular_vals,
raft::device_vector_view<math_t, idx_t> mu,
raft::device_matrix_view<math_t, idx_t, raft::col_major> output)
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> output)
{
static_assert(
std::is_same_v<LayoutPolicy, raft::row_major> || std::is_same_v<LayoutPolicy, raft::col_major>,
"pca_inverse_transform: layout must be raft::row_major or raft::col_major");
constexpr bool input_row_major = std::is_same_v<LayoutPolicy, raft::row_major>;

auto stream = resource::get_cuda_stream(handle);

auto n_rows = output.extent(0);
Expand All @@ -262,44 +303,49 @@ void pca_inverse_transform(raft::resources const& handle,
math_t scalar = n_rows - 1 > 0 ? math_t(1 / sqrt_n_samples) : 0;
raft::linalg::scalarMultiply(
components_copy.data(), components_copy.data(), scalar, components_len, stream);
raft::linalg::binary_mult_skip_zero<raft::Apply::ALONG_ROWS>(
raft::linalg::binary_mult_skip_zero<raft::Apply::ALONG_COLUMNS>(
handle,
raft::make_device_matrix_view<math_t, idx_t, raft::row_major>(
components_copy.data(), n_cols, n_components),
raft::make_device_matrix_view<math_t, idx_t, LayoutPolicy>(
components_copy.data(), n_components, n_cols),
raft::make_device_vector_view<const math_t, idx_t>(singular_vals.data_handle(),
n_components));
}

detail::tsvd_inverse_transform(handle,
prms,
trans_input,
raft::make_device_matrix_view<math_t, idx_t, raft::col_major>(
components_copy.data(), n_components, n_cols),
output);
raft::stats::meanAdd<false, true>(
detail::tsvd_inverse_transform<math_t, idx_t, LayoutPolicy>(
handle,
prms,
trans_input,
raft::make_device_matrix_view<math_t, idx_t, LayoutPolicy>(
components_copy.data(), n_components, n_cols),
output);

raft::stats::meanAdd<input_row_major, true>(
output.data_handle(), output.data_handle(), mu.data_handle(), n_cols, n_rows, stream);
}

/**
* @brief perform fit and transform operations for PCA.
* @tparam math_t element type
* @tparam idx_t index type
* @tparam LayoutPolicy layout (raft::row_major or raft::col_major)
* @param[in] handle: raft::resources
* @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.)
* @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols (col-major).
* @param[out] trans_input: the transformed data. Size n_rows x n_components (col-major).
* @param[out] components: the principal components. Size n_components x n_cols (col-major).
* @param[inout] input: the data to fit the PCA model to. Size n_rows x n_cols.
* @param[out] trans_input: the transformed data. Size n_rows x n_components.
* @param[out] components: the principal components. Size n_components x n_cols.
* @param[out] explained_var: explained variances. Size n_components.
* @param[out] explained_var_ratio: ratio of explained to total variance. Size n_components.
* @param[out] singular_vals: singular values. Size n_components.
* @param[out] mu: mean of all features. Size n_cols.
* @param[out] noise_vars: noise variance scalar.
* @param[in] flip_signs_based_on_U: whether to determine signs by U (true) or V.T (false)
*/
template <typename math_t, typename idx_t>
template <typename math_t, typename idx_t, typename LayoutPolicy>
void pca_fit_transform(raft::resources const& handle,
const paramsPCA& prms,
raft::device_matrix_view<math_t, idx_t, raft::col_major> input,
raft::device_matrix_view<math_t, idx_t, raft::col_major> trans_input,
raft::device_matrix_view<math_t, idx_t, raft::col_major> components,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> input,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> trans_input,
raft::device_matrix_view<math_t, idx_t, LayoutPolicy> components,
raft::device_vector_view<math_t, idx_t> explained_var,
raft::device_vector_view<math_t, idx_t> explained_var_ratio,
raft::device_vector_view<math_t, idx_t> singular_vals,
Expand Down
Loading
Loading