diff --git a/cpp/include/raft/linalg/detail/pca.cuh b/cpp/include/raft/linalg/detail/pca.cuh index 58cec63151..0c56b2a42f 100644 --- a/cpp/include/raft/linalg/detail/pca.cuh +++ b/cpp/include/raft/linalg/detail/pca.cuh @@ -29,16 +29,22 @@ namespace raft { namespace linalg::detail { -template +template void trunc_comp_exp_vars(raft::resources const& handle, const paramsTSVD& prms, - raft::device_matrix_view in, - raft::device_matrix_view components, + raft::device_matrix_view in, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_scalar_view noise_vars, std::size_t n_rows) { + static_assert( + std::is_same_v || std::is_same_v, + "trunc_comp_exp_vars: layout must be raft::row_major or raft::col_major"); + + constexpr bool is_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto n_cols = in.extent(0); @@ -49,19 +55,26 @@ void trunc_comp_exp_vars(raft::resources const& handle, rmm::device_uvector explained_var_all(static_cast(n_cols), stream); rmm::device_uvector explained_var_ratio_all(static_cast(n_cols), stream); - detail::cal_eig( + detail::cal_eig( handle, prms, in, - raft::make_device_matrix_view( + raft::make_device_matrix_view( components_all.data(), n_cols, n_cols), raft::make_device_vector_view(explained_var_all.data(), n_cols)); - raft::matrix::trunc_zero_origin( - handle, - raft::make_device_matrix_view( - components_all.data(), n_cols, n_cols), - raft::make_device_matrix_view( - components.data_handle(), n_components, n_cols)); + if constexpr (is_row_major) { + raft::copy(components.data_handle(), + components_all.data(), + static_cast(n_components) * static_cast(n_cols), + stream); + } else { + raft::matrix::trunc_zero_origin( + handle, + raft::make_device_matrix_view( + components_all.data(), n_cols, n_cols), + raft::make_device_matrix_view( + components.data_handle(), n_components, n_cols)); + } raft::matrix::ratio(handle, raft::make_device_matrix_view( explained_var_all.data(), n_cols, idx_t(1)), @@ -98,10 +111,13 @@ void trunc_comp_exp_vars(raft::resources const& handle, /** * @brief perform fit operation for PCA. + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout of the input matrix (raft::row_major or raft::col_major) * @param[in] handle: raft::resources * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols (col-major). - * @param[out] components: the principal components. Size n_components x n_cols (col-major). + * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols. + * @param[out] components: the principal components. Size n_components x n_cols. * @param[out] explained_var: explained variances. Size n_components. * @param[out] explained_var_ratio: ratio of explained to total variance. Size n_components. * @param[out] singular_vals: singular values. Size n_components. @@ -109,11 +125,11 @@ void trunc_comp_exp_vars(raft::resources const& handle, * @param[out] noise_vars: noise variance scalar. * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void pca_fit(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, @@ -121,6 +137,11 @@ void pca_fit(raft::resources const& handle, raft::device_scalar_view noise_vars, bool flip_signs_based_on_U = false) { + static_assert( + std::is_same_v || std::is_same_v, + "pca_fit: input layout must be raft::row_major or raft::col_major"); + constexpr bool input_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto cublas_handle = raft::resource::get_cublas_handle(handle); @@ -134,18 +155,19 @@ void pca_fit(raft::resources const& handle, ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); ASSERT(n_components <= n_cols, "n_components cannot exceed n_cols"); - raft::stats::mean(mu.data_handle(), input.data_handle(), n_cols, n_rows, false, stream); + raft::stats::mean( + mu.data_handle(), input.data_handle(), n_cols, n_rows, false, stream); auto len = static_cast(n_cols * n_cols); rmm::device_uvector cov(len, stream); - raft::stats::cov( + raft::stats::cov( handle, cov.data(), input.data_handle(), mu.data_handle(), n_cols, n_rows, true, true, stream); - detail::trunc_comp_exp_vars( + detail::trunc_comp_exp_vars( handle, prms, - raft::make_device_matrix_view(cov.data(), n_cols, n_cols), + raft::make_device_matrix_view(cov.data(), n_cols, n_cols), components, explained_var, explained_var_ratio, @@ -161,7 +183,7 @@ void pca_fit(raft::resources const& handle, raft::make_host_scalar_view(&scalar), true); - raft::stats::meanAdd( + raft::stats::meanAdd( input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); detail::sign_flip_components(handle, input, components, true, flip_signs_based_on_U); @@ -169,23 +191,31 @@ void pca_fit(raft::resources const& handle, /** * @brief performs transform operation for PCA. Transforms the data to eigenspace. + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout (raft::row_major or raft::col_major) * @param[in] handle: raft::resources * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input: the data to transform. Size n_rows x n_cols (col-major). - * @param[in] components: principal components. Size n_components x n_cols (col-major). + * @param[inout] input: the data to transform. Size n_rows x n_cols. + * @param[in] components: principal components. Size n_components x n_cols. * @param[in] singular_vals: singular values. Size n_components. * @param[in] mu: mean of features. Size n_cols. - * @param[out] trans_input: the transformed data. Size n_rows x n_components (col-major). + * @param[out] trans_input: the transformed data. Size n_rows x n_components. */ -template +template void pca_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view trans_input) + raft::device_matrix_view trans_input) { + static_assert( + std::is_same_v || std::is_same_v, + "pca_transform: layout must be raft::row_major or raft::col_major"); + constexpr bool input_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto n_rows = input.extent(0); @@ -204,45 +234,56 @@ void pca_transform(raft::resources const& handle, math_t scalar = math_t(sqrt(n_rows - 1)); raft::linalg::scalarMultiply( components_copy.data(), components_copy.data(), scalar, components_len, stream); - raft::linalg::binary_div_skip_zero( + raft::linalg::binary_div_skip_zero( handle, - raft::make_device_matrix_view( - components_copy.data(), n_cols, n_components), + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), raft::make_device_vector_view(singular_vals.data_handle(), n_components)); } - raft::stats::meanCenter( + raft::stats::meanCenter( input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); - detail::tsvd_transform(handle, - prms, - input, - raft::make_device_matrix_view( - components_copy.data(), n_components, n_cols), - trans_input); - raft::stats::meanAdd( + + detail::tsvd_transform( + handle, + prms, + input, + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), + trans_input); + + raft::stats::meanAdd( input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); } /** * @brief performs inverse transform operation for PCA. + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout (raft::row_major or raft::col_major) * @param[in] handle: raft::resources * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) - * @param[in] trans_input: the transformed data. Size n_rows x n_components (col-major). - * @param[in] components: principal components. Size n_components x n_cols (col-major). + * @param[in] trans_input: the transformed data. Size n_rows x n_components. + * @param[in] components: principal components. Size n_components x n_cols. * @param[in] singular_vals: singular values. Size n_components. * @param[in] mu: mean of features. Size n_cols. - * @param[out] output: the reconstructed data. Size n_rows x n_cols (col-major). + * @param[out] output: the reconstructed data. Size n_rows x n_cols. */ -template +template void pca_inverse_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view output) + raft::device_matrix_view output) { + static_assert( + std::is_same_v || std::is_same_v, + "pca_inverse_transform: layout must be raft::row_major or raft::col_major"); + constexpr bool input_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto n_rows = output.extent(0); @@ -262,31 +303,36 @@ void pca_inverse_transform(raft::resources const& handle, math_t scalar = n_rows - 1 > 0 ? math_t(1 / sqrt_n_samples) : 0; raft::linalg::scalarMultiply( components_copy.data(), components_copy.data(), scalar, components_len, stream); - raft::linalg::binary_mult_skip_zero( + raft::linalg::binary_mult_skip_zero( handle, - raft::make_device_matrix_view( - components_copy.data(), n_cols, n_components), + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), raft::make_device_vector_view(singular_vals.data_handle(), n_components)); } - detail::tsvd_inverse_transform(handle, - prms, - trans_input, - raft::make_device_matrix_view( - components_copy.data(), n_components, n_cols), - output); - raft::stats::meanAdd( + detail::tsvd_inverse_transform( + handle, + prms, + trans_input, + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), + output); + + raft::stats::meanAdd( output.data_handle(), output.data_handle(), mu.data_handle(), n_cols, n_rows, stream); } /** * @brief perform fit and transform operations for PCA. + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout (raft::row_major or raft::col_major) * @param[in] handle: raft::resources * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols (col-major). - * @param[out] trans_input: the transformed data. Size n_rows x n_components (col-major). - * @param[out] components: the principal components. Size n_components x n_cols (col-major). + * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols. + * @param[out] trans_input: the transformed data. Size n_rows x n_components. + * @param[out] components: the principal components. Size n_components x n_cols. * @param[out] explained_var: explained variances. Size n_components. * @param[out] explained_var_ratio: ratio of explained to total variance. Size n_components. * @param[out] singular_vals: singular values. Size n_components. @@ -294,12 +340,12 @@ void pca_inverse_transform(raft::resources const& handle, * @param[out] noise_vars: noise variance scalar. * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void pca_fit_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, diff --git a/cpp/include/raft/linalg/detail/tsvd.cuh b/cpp/include/raft/linalg/detail/tsvd.cuh index 68ab68747f..32ce9c9275 100644 --- a/cpp/include/raft/linalg/detail/tsvd.cuh +++ b/cpp/include/raft/linalg/detail/tsvd.cuh @@ -15,7 +15,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -104,13 +106,19 @@ void cal_comp_exp_vars_svd(raft::resources const& handle, handle, explained_vars.data_handle(), explained_var_ratio.data_handle(), n_components, stream); } -template +template void cal_eig(raft::resources const& handle, const paramsTSVD& prms, - raft::device_matrix_view in, - raft::device_matrix_view components, + raft::device_matrix_view in, + raft::device_matrix_view components, raft::device_vector_view explained_var) { + static_assert( + std::is_same_v || std::is_same_v, + "cal_eig: layout must be raft::row_major or raft::col_major"); + + constexpr bool is_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto cusolver_handle = raft::resource::get_cusolver_dn_handle(handle); @@ -141,7 +149,9 @@ void cal_eig(raft::resources const& handle, raft::matrix::col_reverse(handle_stream_zero, raft::make_device_matrix_view( components.data_handle(), n_cols, n_cols)); - raft::linalg::transpose(components.data_handle(), n_cols, stream); + if constexpr (!is_row_major) { + raft::linalg::transpose(components.data_handle(), n_cols, stream); + } raft::matrix::row_reverse(handle_stream_zero, raft::make_device_matrix_view( @@ -149,53 +159,58 @@ void cal_eig(raft::resources const& handle, } /** - * @brief sign flip for PCA and tSVD. Stabilizes the sign of column major eigenvectors. + * @brief sign flip for PCA and tSVD. Stabilizes the sign of the eigenvectors. + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout of the input and components matrices * @param handle: raft::resources - * @param input: input data [n_samples x n_features] (col-major) - * @param components: components matrix [n_components x n_features] (col-major) + * @param input: input data [n_samples x n_features] + * @param components: components matrix [n_components x n_features] * @param center whether to mean-center input before computing signs * @param flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void sign_flip_components(raft::resources const& handle, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, bool center, bool flip_signs_based_on_U = false) { + static_assert( + std::is_same_v || std::is_same_v, + "sign_flip_components: layout must be raft::row_major or raft::col_major"); + constexpr bool is_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto n_samples = input.extent(0); auto n_features = input.extent(1); auto n_components = components.extent(0); rmm::device_uvector max_vals(static_cast(n_components), stream); - auto components_view = raft::make_device_matrix_view( + auto components_view = raft::make_device_matrix_view( components.data_handle(), n_components, n_features); auto max_vals_view = raft::make_device_vector_view(max_vals.data(), n_components); if (flip_signs_based_on_U) { if (center) { rmm::device_uvector col_means(static_cast(n_features), stream); - raft::stats::mean( + raft::stats::mean( col_means.data(), input.data_handle(), n_features, n_samples, stream); - raft::stats::meanCenter( + raft::stats::meanCenter( input.data_handle(), input.data_handle(), col_means.data(), n_features, n_samples, stream); } rmm::device_uvector US(static_cast(n_samples * n_components), stream); - raft::linalg::gemm(handle, - input.data_handle(), - n_samples, - n_features, - components.data_handle(), - US.data(), - n_samples, - n_components, - CUBLAS_OP_N, - CUBLAS_OP_T, - math_t(1), - math_t(0), - stream); - raft::linalg::reduce( + raft::linalg::gemm(handle, + input, + raft::make_device_matrix_view< + math_t, + idx_t, + std::conditional_t>( + components.data_handle(), n_features, n_components), + raft::make_device_matrix_view( + US.data(), n_samples, n_components)); + + raft::linalg::reduce( max_vals.data(), US.data(), n_components, @@ -211,7 +226,7 @@ void sign_flip_components(raft::resources const& handle, }, raft::identity_op()); } else { - raft::linalg::reduce( + raft::linalg::reduce( max_vals.data(), components.data_handle(), n_features, @@ -232,8 +247,15 @@ void sign_flip_components(raft::resources const& handle, handle, components_view, [components_view, max_vals_view, n_components, n_features] __device__(auto idx) { - auto row = idx % n_components; - auto column = idx / n_components; + idx_t row; + idx_t column; + if constexpr (is_row_major) { + row = idx / n_features; + column = idx % n_features; + } else { + row = idx % n_components; + column = idx / n_components; + } return (max_vals_view(row) < math_t(0)) ? (-components_view(row, column)) : components_view(row, column); }); @@ -376,18 +398,21 @@ void tsvd_fit(raft::resources const& handle, * @brief performs transform operation for the tsvd. Transforms the data to eigenspace. * @param[in] handle raft::resources * @param[in] prms: data structure that includes all the parameters from input size to algorithm. - * @param[in] input: the data to transform. Size n_rows x n_cols (col-major). - * @param[in] components: principal components. Size n_components x n_cols (col-major). - * @param[out] trans_input: transformed output. Size n_rows x n_components (col-major). + * @param[in] input: the data to transform. Size n_rows x n_cols. + * @param[in] components: principal components. Size n_components x n_cols. + * @param[out] trans_input: transformed output. Size n_rows x n_components. */ -template +template void tsvd_transform(raft::resources const& handle, const paramsTSVD& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, - raft::device_matrix_view trans_input) + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_matrix_view trans_input) { - auto stream = resource::get_cuda_stream(handle); + static_assert( + std::is_same_v || std::is_same_v, + "tsvd_transform: layout must be raft::row_major or raft::col_major"); + constexpr bool is_row_major = std::is_same_v; auto n_rows = input.extent(0); auto n_cols = input.extent(1); @@ -397,39 +422,34 @@ void tsvd_transform(raft::resources const& handle, ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); - math_t alpha = math_t(1); - math_t beta = math_t(0); raft::linalg::gemm(handle, - input.data_handle(), - n_rows, - n_cols, - components.data_handle(), - trans_input.data_handle(), - n_rows, - n_components, - CUBLAS_OP_N, - CUBLAS_OP_T, - alpha, - beta, - stream); + input, + raft::make_device_matrix_view< + math_t, + idx_t, + std::conditional_t>( + components.data_handle(), n_cols, n_components), + trans_input); } /** * @brief performs inverse transform operation for the tsvd. * @param[in] handle raft::resources * @param[in] prms: data structure that includes all the parameters from input size to algorithm. - * @param[in] trans_input: the transformed data. Size n_rows x n_components (col-major). - * @param[in] components: principal components. Size n_components x n_cols (col-major). - * @param[out] output: reconstructed output. Size n_rows x n_cols (col-major). + * @param[in] trans_input: the transformed data. Size n_rows x n_components. + * @param[in] components: principal components. Size n_components x n_cols. + * @param[out] output: reconstructed output. Size n_rows x n_cols. */ -template +template void tsvd_inverse_transform(raft::resources const& handle, const paramsTSVD& prms, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, - raft::device_matrix_view output) + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_matrix_view output) { - auto stream = resource::get_cuda_stream(handle); + static_assert( + std::is_same_v || std::is_same_v, + "tsvd_inverse_transform: layout must be raft::row_major or raft::col_major"); auto n_rows = output.extent(0); auto n_cols = output.extent(1); @@ -439,22 +459,7 @@ void tsvd_inverse_transform(raft::resources const& handle, ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); - math_t alpha = math_t(1); - math_t beta = math_t(0); - - raft::linalg::gemm(handle, - trans_input.data_handle(), - n_rows, - n_components, - components.data_handle(), - output.data_handle(), - n_rows, - n_cols, - CUBLAS_OP_N, - CUBLAS_OP_N, - alpha, - beta, - stream); + raft::linalg::gemm(handle, trans_input, components, output); } /** diff --git a/cpp/include/raft/linalg/pca.cuh b/cpp/include/raft/linalg/pca.cuh index 4022c82989..93cabb5c4f 100644 --- a/cpp/include/raft/linalg/pca.cuh +++ b/cpp/include/raft/linalg/pca.cuh @@ -22,12 +22,13 @@ namespace linalg { * @brief perform fit operation for PCA. Generates eigenvectors, explained vars, singular vals, etc. * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing + * @tparam LayoutPolicy layout of the input/components matrices (raft::row_major or + * raft::col_major) * @param[in] handle: raft::resources * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols (col-major). Modified + * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols. Modified * temporarily during computation. - * @param[out] components the principal components of the input data. Size n_components x n_cols - * (col-major). + * @param[out] components the principal components of the input data. Size n_components x n_cols. * @param[out] explained_var explained variances (eigenvalues) of the principal components. Size * n_components. * @param[out] explained_var_ratio the ratio of the explained variance and total variance. Size @@ -37,11 +38,11 @@ namespace linalg { * @param[out] noise_vars variance of the noise. Scalar. * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void pca_fit(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, @@ -66,13 +67,13 @@ void pca_fit(raft::resources const& handle, * eigenvectors, explained vars, singular vals, etc. * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing + * @tparam LayoutPolicy layout of the input/output matrices (raft::row_major or raft::col_major) * @param[in] handle raft::resources * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols (col-major). Modified + * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols. Modified * temporarily during computation. - * @param[out] trans_input the transformed data. Size n_rows x n_components (col-major). - * @param[out] components the principal components of the input data. Size n_components x n_cols - * (col-major). + * @param[out] trans_input the transformed data. Size n_rows x n_components. + * @param[out] components the principal components of the input data. Size n_components x n_cols. * @param[out] explained_var explained variances (eigenvalues) of the principal components. Size * n_components. * @param[out] explained_var_ratio the ratio of the explained variance and total variance. Size @@ -82,12 +83,12 @@ void pca_fit(raft::resources const& handle, * @param[out] noise_vars variance of the noise. Scalar. * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void pca_fit_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, @@ -113,23 +114,23 @@ void pca_fit_transform(raft::resources const& handle, * original data. * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing + * @tparam LayoutPolicy layout of the input/output matrices (raft::row_major or raft::col_major) * @param[in] handle raft::resources * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) - * @param[in] trans_input the transformed data. Size n_rows x n_components (col-major). - * @param[in] components the principal components of the input data. Size n_components x n_cols - * (col-major). + * @param[in] trans_input the transformed data. Size n_rows x n_components. + * @param[in] components the principal components of the input data. Size n_components x n_cols. * @param[in] singular_vals singular values of the data. Size n_components. * @param[in] mu mean of features (every column). Size n_cols. - * @param[out] output the reconstructed data. Size n_rows x n_cols (col-major). + * @param[out] output the reconstructed data. Size n_rows x n_cols. */ -template +template void pca_inverse_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view output) + raft::device_matrix_view output) { detail::pca_inverse_transform(handle, prms, trans_input, components, singular_vals, mu, output); } @@ -138,24 +139,24 @@ void pca_inverse_transform(raft::resources const& handle, * @brief performs transform operation for PCA. Transforms the data to eigenspace. * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing + * @tparam LayoutPolicy layout of the input/output matrices (raft::row_major or raft::col_major) * @param[in] handle raft::resources * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input the data to be transformed. Size n_rows x n_cols (col-major). Modified + * @param[inout] input the data to be transformed. Size n_rows x n_cols. Modified * temporarily during computation (mean-centered then restored). - * @param[in] components principal components of the input data. Size n_components x n_cols - * (col-major). + * @param[in] components principal components of the input data. Size n_components x n_cols. * @param[in] singular_vals singular values of the data. Size n_components. * @param[in] mu mean value of the input data. Size n_cols. - * @param[out] trans_input the transformed data. Size n_rows x n_components (col-major). + * @param[out] trans_input the transformed data. Size n_rows x n_components. */ -template +template void pca_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view trans_input) + raft::device_matrix_view trans_input) { detail::pca_transform(handle, prms, input, components, singular_vals, mu, trans_input); } diff --git a/cpp/tests/linalg/pca.cu b/cpp/tests/linalg/pca.cu index a1d80953e2..6d95a342d8 100644 --- a/cpp/tests/linalg/pca.cu +++ b/cpp/tests/linalg/pca.cu @@ -51,12 +51,9 @@ class PcaTest : public ::testing::TestWithParam> { trans_data(params.len, stream), trans_data_ref(params.len, stream), data(params.len, stream), - data_back(params.len, stream), - data2(params.len2, stream), - data2_back(params.len2, stream) + data_back(params.len, stream) { basicTest(); - advancedTest(); } protected: @@ -135,6 +132,7 @@ class PcaTest : public ::testing::TestWithParam> { handle, prms, trans_data_view, components_view, singular_vals_view, mu_view, data_back_view); } + template void advancedTest() { raft::random::Rng r(params.seed, raft::random::GenPC); @@ -151,6 +149,8 @@ class PcaTest : public ::testing::TestWithParam> { else if (params.algo == 1) prms.algorithm = solver::COV_EIG_JACOBI; + rmm::device_uvector data2(len, stream); + rmm::device_uvector data2_back(len, stream); r.uniform(data2.data(), len, T(-1.0), T(1.0), stream); rmm::device_uvector data2_trans(n_rows * n_components, stream); @@ -163,10 +163,10 @@ class PcaTest : public ::testing::TestWithParam> { rmm::device_uvector noise_vars2(1, stream); auto input_view = - raft::make_device_matrix_view(data2.data(), n_rows, n_cols); - auto trans_view = raft::make_device_matrix_view( + raft::make_device_matrix_view(data2.data(), n_rows, n_cols); + auto trans_view = raft::make_device_matrix_view( data2_trans.data(), n_rows, n_components); - auto comp_view = raft::make_device_matrix_view( + auto comp_view = raft::make_device_matrix_view( components2.data(), n_components, n_cols); auto ev_view = raft::make_device_vector_view(explained_vars2.data(), n_components); @@ -188,10 +188,13 @@ class PcaTest : public ::testing::TestWithParam> { mu_view, noise_view); - auto data2_back_view = raft::make_device_matrix_view( + auto data2_back_view = raft::make_device_matrix_view( data2_back.data(), n_rows, n_cols); pca_inverse_transform(handle, prms, trans_view, comp_view, sv_view, mu_view, data2_back_view); + + ASSERT_TRUE(devArrMatch( + data2.data(), data2_back.data(), len, raft::CompareApprox(params.tolerance), stream)); } protected: @@ -201,7 +204,7 @@ class PcaTest : public ::testing::TestWithParam> { PcaInputs params; rmm::device_uvector explained_vars, explained_vars_ref, components, components_ref, trans_data, - trans_data_ref, data, data_back, data2, data2_back; + trans_data_ref, data, data_back; }; const std::vector> inputsf2 = { @@ -295,21 +298,15 @@ TEST_P(PcaTestDataVecSmallD, Result) typedef PcaTest PcaTestDataVecF; TEST_P(PcaTestDataVecF, Result) { - ASSERT_TRUE(devArrMatch(data2.data(), - data2_back.data(), - (params.n_col2 * params.n_col2), - raft::CompareApprox(params.tolerance), - resource::get_cuda_stream(handle))); + this->template advancedTest(); + this->template advancedTest(); } typedef PcaTest PcaTestDataVecD; TEST_P(PcaTestDataVecD, Result) { - ASSERT_TRUE(devArrMatch(data2.data(), - data2_back.data(), - (params.n_col2 * params.n_col2), - raft::CompareApprox(params.tolerance), - resource::get_cuda_stream(handle))); + this->template advancedTest(); + this->template advancedTest(); } INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestValF, ::testing::ValuesIn(inputsf2));