From bd0a12b3b2039b92bbde8332c654c915bee31693 Mon Sep 17 00:00:00 2001 From: aamijar Date: Tue, 26 May 2026 00:48:02 +0000 Subject: [PATCH 1/7] pca-row-major --- cpp/include/raft/linalg/detail/pca.cuh | 174 +++++++++++++++++------- cpp/include/raft/linalg/detail/tsvd.cuh | 56 +++++--- cpp/include/raft/linalg/pca.cuh | 74 ++++++---- cpp/tests/linalg/pca.cu | 111 +++++++++++++++ 4 files changed, 314 insertions(+), 101 deletions(-) diff --git a/cpp/include/raft/linalg/detail/pca.cuh b/cpp/include/raft/linalg/detail/pca.cuh index 58cec63151..e50961f477 100644 --- a/cpp/include/raft/linalg/detail/pca.cuh +++ b/cpp/include/raft/linalg/detail/pca.cuh @@ -98,10 +98,17 @@ void trunc_comp_exp_vars(raft::resources const& handle, /** * @brief perform fit operation for PCA. + * + * Supports both row-major and col-major input layouts via the LayoutPolicy template + * parameter. The output `components` matrix has the same layout as the input. + * + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout of the input matrix (raft::row_major or raft::col_major) * @param[in] handle: raft::resources * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols (col-major). - * @param[out] components: the principal components. Size n_components x n_cols (col-major). + * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols. + * @param[out] components: the principal components. Size n_components x n_cols. * @param[out] explained_var: explained variances. Size n_components. * @param[out] explained_var_ratio: ratio of explained to total variance. Size n_components. * @param[out] singular_vals: singular values. Size n_components. @@ -109,11 +116,11 @@ void trunc_comp_exp_vars(raft::resources const& handle, * @param[out] noise_vars: noise variance scalar. * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void pca_fit(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, @@ -121,6 +128,11 @@ void pca_fit(raft::resources const& handle, raft::device_scalar_view noise_vars, bool flip_signs_based_on_U = false) { + static_assert( + std::is_same_v || std::is_same_v, + "pca_fit: input layout must be raft::row_major or raft::col_major"); + constexpr bool input_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto cublas_handle = raft::resource::get_cublas_handle(handle); @@ -134,19 +146,30 @@ void pca_fit(raft::resources const& handle, ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); ASSERT(n_components <= n_cols, "n_components cannot exceed n_cols"); - raft::stats::mean(mu.data_handle(), input.data_handle(), n_cols, n_rows, false, stream); + raft::stats::mean( + mu.data_handle(), input.data_handle(), n_cols, n_rows, false, stream); auto len = static_cast(n_cols * n_cols); rmm::device_uvector cov(len, stream); - raft::stats::cov( + raft::stats::cov( handle, cov.data(), input.data_handle(), mu.data_handle(), n_cols, n_rows, true, true, stream); + // The eigendecomposition of the (symmetric) covariance matrix naturally produces a + // col-major components buffer. For row-major output we accumulate into a temporary + // and physically transpose at the end. + auto components_col_storage = raft::make_device_matrix( + handle, input_row_major ? n_components : idx_t(0), input_row_major ? n_cols : idx_t(0)); + math_t* components_col_data = + input_row_major ? components_col_storage.data_handle() : components.data_handle(); + auto components_col_view = raft::make_device_matrix_view( + components_col_data, n_components, n_cols); + detail::trunc_comp_exp_vars( handle, prms, raft::make_device_matrix_view(cov.data(), n_cols, n_cols), - components, + components_col_view, explained_var, explained_var_ratio, noise_vars, @@ -161,31 +184,52 @@ void pca_fit(raft::resources const& handle, raft::make_host_scalar_view(&scalar), true); - raft::stats::meanAdd( + raft::stats::meanAdd( input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); - detail::sign_flip_components(handle, input, components, true, flip_signs_based_on_U); + detail::sign_flip_components(handle, input, components_col_view, true, flip_signs_based_on_U); + + if constexpr (input_row_major) { + // Transpose the internal col-major (n_components x n_cols) components into the user's + // row-major (n_components x n_cols) buffer. The same memory laid out as col-major + // (n_cols x n_components) is exactly the row-major (n_components x n_cols) we want. + auto components_as_col_view = raft::make_device_matrix_view( + components.data_handle(), n_cols, n_components); + raft::linalg::transpose(handle, components_col_view, components_as_col_view); + } } /** * @brief performs transform operation for PCA. Transforms the data to eigenspace. + * + * Supports both row-major and col-major layouts via the LayoutPolicy template parameter. + * `input`, `components`, and `trans_input` must all share the same layout. + * + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout (raft::row_major or raft::col_major) * @param[in] handle: raft::resources * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input: the data to transform. Size n_rows x n_cols (col-major). - * @param[in] components: principal components. Size n_components x n_cols (col-major). + * @param[inout] input: the data to transform. Size n_rows x n_cols. + * @param[in] components: principal components. Size n_components x n_cols. * @param[in] singular_vals: singular values. Size n_components. * @param[in] mu: mean of features. Size n_cols. - * @param[out] trans_input: the transformed data. Size n_rows x n_components (col-major). + * @param[out] trans_input: the transformed data. Size n_rows x n_components. */ -template +template void pca_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view trans_input) + raft::device_matrix_view trans_input) { + static_assert( + std::is_same_v || std::is_same_v, + "pca_transform: layout must be raft::row_major or raft::col_major"); + constexpr bool input_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto n_rows = input.extent(0); @@ -200,49 +244,69 @@ void pca_transform(raft::resources const& handle, rmm::device_uvector components_copy{components_len, stream}; raft::copy(components_copy.data(), components.data_handle(), components_len, stream); + auto components_copy_view = raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols); + if (prms.whiten) { math_t scalar = math_t(sqrt(n_rows - 1)); raft::linalg::scalarMultiply( components_copy.data(), components_copy.data(), scalar, components_len, stream); - raft::linalg::binary_div_skip_zero( + // Divide each row of (n_components x n_cols) components by the corresponding singular + // value. Apply::ALONG_COLUMNS broadcasts a vector of size n_rows-of-matrix + // (= n_components) over each column, which is the same operation in both layouts. + raft::linalg::binary_div_skip_zero( handle, - raft::make_device_matrix_view( - components_copy.data(), n_cols, n_components), + components_copy_view, raft::make_device_vector_view(singular_vals.data_handle(), n_components)); } - raft::stats::meanCenter( + raft::stats::meanCenter( input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); - detail::tsvd_transform(handle, - prms, - input, - raft::make_device_matrix_view( - components_copy.data(), n_components, n_cols), - trans_input); - raft::stats::meanAdd( + + // trans_input = input @ components_copy^T, in the user's layout. + // Reinterpreting the components_copy buffer with the opposite layout swaps the logical + // dimensions, giving us the (n_cols x n_components) transposed view we need for gemm. + using transposed_layout = std::conditional_t; + auto components_copy_transposed = raft::make_device_matrix_view( + components_copy.data(), n_cols, n_components); + raft::linalg::gemm(handle, input, components_copy_transposed, trans_input); + + raft::stats::meanAdd( input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); } /** * @brief performs inverse transform operation for PCA. + * + * Supports both row-major and col-major layouts via the LayoutPolicy template parameter. + * `trans_input`, `components`, and `output` must all share the same layout. + * + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout (raft::row_major or raft::col_major) * @param[in] handle: raft::resources * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) - * @param[in] trans_input: the transformed data. Size n_rows x n_components (col-major). - * @param[in] components: principal components. Size n_components x n_cols (col-major). + * @param[in] trans_input: the transformed data. Size n_rows x n_components. + * @param[in] components: principal components. Size n_components x n_cols. * @param[in] singular_vals: singular values. Size n_components. * @param[in] mu: mean of features. Size n_cols. - * @param[out] output: the reconstructed data. Size n_rows x n_cols (col-major). + * @param[out] output: the reconstructed data. Size n_rows x n_cols. */ -template +template void pca_inverse_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view output) + raft::device_matrix_view output) { + static_assert( + std::is_same_v || std::is_same_v, + "pca_inverse_transform: layout must be raft::row_major or raft::col_major"); + constexpr bool input_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto n_rows = output.extent(0); @@ -257,36 +321,42 @@ void pca_inverse_transform(raft::resources const& handle, rmm::device_uvector components_copy{components_len, stream}; raft::copy(components_copy.data(), components.data_handle(), components_len, stream); + auto components_copy_view = raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols); + if (prms.whiten) { math_t sqrt_n_samples = sqrt(n_rows - 1); math_t scalar = n_rows - 1 > 0 ? math_t(1 / sqrt_n_samples) : 0; raft::linalg::scalarMultiply( components_copy.data(), components_copy.data(), scalar, components_len, stream); - raft::linalg::binary_mult_skip_zero( + raft::linalg::binary_mult_skip_zero( handle, - raft::make_device_matrix_view( - components_copy.data(), n_cols, n_components), + components_copy_view, raft::make_device_vector_view(singular_vals.data_handle(), n_components)); } - detail::tsvd_inverse_transform(handle, - prms, - trans_input, - raft::make_device_matrix_view( - components_copy.data(), n_components, n_cols), - output); - raft::stats::meanAdd( + // output = trans_input @ components_copy. All three matrices share the user's layout, + // so the mdspan gemm picks the correct cuBLAS transposes automatically. + raft::linalg::gemm(handle, trans_input, components_copy_view, output); + + raft::stats::meanAdd( output.data_handle(), output.data_handle(), mu.data_handle(), n_cols, n_rows, stream); } /** * @brief perform fit and transform operations for PCA. + * + * Supports both row-major and col-major layouts via the LayoutPolicy template parameter. + * + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout (raft::row_major or raft::col_major) * @param[in] handle: raft::resources * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols (col-major). - * @param[out] trans_input: the transformed data. Size n_rows x n_components (col-major). - * @param[out] components: the principal components. Size n_components x n_cols (col-major). + * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols. + * @param[out] trans_input: the transformed data. Size n_rows x n_components. + * @param[out] components: the principal components. Size n_components x n_cols. * @param[out] explained_var: explained variances. Size n_components. * @param[out] explained_var_ratio: ratio of explained to total variance. Size n_components. * @param[out] singular_vals: singular values. Size n_components. @@ -294,12 +364,12 @@ void pca_inverse_transform(raft::resources const& handle, * @param[out] noise_vars: noise variance scalar. * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void pca_fit_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, diff --git a/cpp/include/raft/linalg/detail/tsvd.cuh b/cpp/include/raft/linalg/detail/tsvd.cuh index 68ab68747f..b503aa282a 100644 --- a/cpp/include/raft/linalg/detail/tsvd.cuh +++ b/cpp/include/raft/linalg/detail/tsvd.cuh @@ -15,7 +15,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -150,19 +152,31 @@ void cal_eig(raft::resources const& handle, /** * @brief sign flip for PCA and tSVD. Stabilizes the sign of column major eigenvectors. + * + * The components matrix is always stored in col-major; the input matrix may be either + * row-major or col-major (deduced from the LayoutPolicy template parameter). + * + * @tparam math_t element type + * @tparam idx_t index type + * @tparam LayoutPolicy layout of the input matrix (raft::row_major or raft::col_major) * @param handle: raft::resources - * @param input: input data [n_samples x n_features] (col-major) + * @param input: input data [n_samples x n_features] * @param components: components matrix [n_components x n_features] (col-major) * @param center whether to mean-center input before computing signs * @param flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void sign_flip_components(raft::resources const& handle, - raft::device_matrix_view input, + raft::device_matrix_view input, raft::device_matrix_view components, bool center, bool flip_signs_based_on_U = false) { + static_assert( + std::is_same_v || std::is_same_v, + "sign_flip_components: input layout must be raft::row_major or raft::col_major"); + constexpr bool input_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto n_samples = input.extent(0); auto n_features = input.extent(1); @@ -176,26 +190,28 @@ void sign_flip_components(raft::resources const& handle, if (flip_signs_based_on_U) { if (center) { rmm::device_uvector col_means(static_cast(n_features), stream); - raft::stats::mean( + raft::stats::mean( col_means.data(), input.data_handle(), n_features, n_samples, stream); - raft::stats::meanCenter( + raft::stats::meanCenter( input.data_handle(), input.data_handle(), col_means.data(), n_features, n_samples, stream); } + // US = input @ components^T, shape (n_samples x n_components), in input's layout. + // The components matrix is col-major (n_components x n_features); reinterpreting the + // same memory as row-major (n_features x n_components) yields the transpose. rmm::device_uvector US(static_cast(n_samples * n_components), stream); - raft::linalg::gemm(handle, - input.data_handle(), - n_samples, - n_features, - components.data_handle(), - US.data(), - n_samples, - n_components, - CUBLAS_OP_N, - CUBLAS_OP_T, - math_t(1), - math_t(0), - stream); - raft::linalg::reduce( + using transposed_layout = std::conditional_t; + auto components_transposed_view = + raft::make_device_matrix_view( + components.data_handle(), n_features, n_components); + auto US_view = raft::make_device_matrix_view( + US.data(), n_samples, n_components); + + raft::linalg::gemm(handle, input, components_transposed_view, US_view); + + // Per-column reduction of US (n_samples x n_components) yields one max-abs value per + // component. With the (rowMajor, alongRows) convention, alongRows=false produces D + // outputs (one per column) regardless of layout; only the memory access pattern differs. + raft::linalg::reduce( max_vals.data(), US.data(), n_components, @@ -211,6 +227,8 @@ void sign_flip_components(raft::resources const& handle, }, raft::identity_op()); } else { + // components is col-major (n_components x n_features); reduce per row to get one + // max-abs value per component. raft::linalg::reduce( max_vals.data(), components.data_handle(), diff --git a/cpp/include/raft/linalg/pca.cuh b/cpp/include/raft/linalg/pca.cuh index 4022c82989..322bff2cc0 100644 --- a/cpp/include/raft/linalg/pca.cuh +++ b/cpp/include/raft/linalg/pca.cuh @@ -20,14 +20,19 @@ namespace linalg { /** * @brief perform fit operation for PCA. Generates eigenvectors, explained vars, singular vals, etc. + * + * Supports both row-major and col-major layouts. The layout is deduced from the input view's + * `LayoutPolicy` and must match between `input` and `components`. + * * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing + * @tparam LayoutPolicy layout of the input/components matrices (raft::row_major or + * raft::col_major) * @param[in] handle: raft::resources * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols (col-major). Modified + * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols. Modified * temporarily during computation. - * @param[out] components the principal components of the input data. Size n_components x n_cols - * (col-major). + * @param[out] components the principal components of the input data. Size n_components x n_cols. * @param[out] explained_var explained variances (eigenvalues) of the principal components. Size * n_components. * @param[out] explained_var_ratio the ratio of the explained variance and total variance. Size @@ -37,11 +42,11 @@ namespace linalg { * @param[out] noise_vars variance of the noise. Scalar. * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void pca_fit(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, @@ -64,15 +69,18 @@ void pca_fit(raft::resources const& handle, /** * @brief perform fit and transform operations for PCA. Generates transformed data, * eigenvectors, explained vars, singular vals, etc. + * + * Supports both row-major and col-major layouts. All matrix views must share the same layout. + * * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing + * @tparam LayoutPolicy layout of the input/output matrices (raft::row_major or raft::col_major) * @param[in] handle raft::resources * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols (col-major). Modified + * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols. Modified * temporarily during computation. - * @param[out] trans_input the transformed data. Size n_rows x n_components (col-major). - * @param[out] components the principal components of the input data. Size n_components x n_cols - * (col-major). + * @param[out] trans_input the transformed data. Size n_rows x n_components. + * @param[out] components the principal components of the input data. Size n_components x n_cols. * @param[out] explained_var explained variances (eigenvalues) of the principal components. Size * n_components. * @param[out] explained_var_ratio the ratio of the explained variance and total variance. Size @@ -82,12 +90,12 @@ void pca_fit(raft::resources const& handle, * @param[out] noise_vars variance of the noise. Scalar. * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void pca_fit_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_vector_view singular_vals, @@ -111,51 +119,57 @@ void pca_fit_transform(raft::resources const& handle, /** * @brief performs inverse transform operation for PCA. Transforms the transformed data back to * original data. + * + * Supports both row-major and col-major layouts. All matrix views must share the same layout. + * * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing + * @tparam LayoutPolicy layout of the input/output matrices (raft::row_major or raft::col_major) * @param[in] handle raft::resources * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) - * @param[in] trans_input the transformed data. Size n_rows x n_components (col-major). - * @param[in] components the principal components of the input data. Size n_components x n_cols - * (col-major). + * @param[in] trans_input the transformed data. Size n_rows x n_components. + * @param[in] components the principal components of the input data. Size n_components x n_cols. * @param[in] singular_vals singular values of the data. Size n_components. * @param[in] mu mean of features (every column). Size n_cols. - * @param[out] output the reconstructed data. Size n_rows x n_cols (col-major). + * @param[out] output the reconstructed data. Size n_rows x n_cols. */ -template +template void pca_inverse_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view output) + raft::device_matrix_view output) { detail::pca_inverse_transform(handle, prms, trans_input, components, singular_vals, mu, output); } /** * @brief performs transform operation for PCA. Transforms the data to eigenspace. + * + * Supports both row-major and col-major layouts. All matrix views must share the same layout. + * * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing + * @tparam LayoutPolicy layout of the input/output matrices (raft::row_major or raft::col_major) * @param[in] handle raft::resources * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) - * @param[inout] input the data to be transformed. Size n_rows x n_cols (col-major). Modified + * @param[inout] input the data to be transformed. Size n_rows x n_cols. Modified * temporarily during computation (mean-centered then restored). - * @param[in] components principal components of the input data. Size n_components x n_cols - * (col-major). + * @param[in] components principal components of the input data. Size n_components x n_cols. * @param[in] singular_vals singular values of the data. Size n_components. * @param[in] mu mean value of the input data. Size n_cols. - * @param[out] trans_input the transformed data. Size n_rows x n_components (col-major). + * @param[out] trans_input the transformed data. Size n_rows x n_components. */ -template +template void pca_transform(raft::resources const& handle, const paramsPCA& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view input, + raft::device_matrix_view components, raft::device_vector_view singular_vals, raft::device_vector_view mu, - raft::device_matrix_view trans_input) + raft::device_matrix_view trans_input) { detail::pca_transform(handle, prms, input, components, singular_vals, mu, trans_input); } diff --git a/cpp/tests/linalg/pca.cu b/cpp/tests/linalg/pca.cu index a1d80953e2..82ec96f86b 100644 --- a/cpp/tests/linalg/pca.cu +++ b/cpp/tests/linalg/pca.cu @@ -332,4 +332,115 @@ INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecF, ::testing::ValuesIn(inputsf2) INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecD, ::testing::ValuesIn(inputsd2)); +/** + * Row-major end-to-end tests. Verifies that the layout-templated + * ``pca_fit_transform`` + ``pca_inverse_transform`` reconstruct row-major + * inputs and produce explained variances numerically equivalent to those + * from the col-major path on the same logical data. + */ +template +class PcaRowMajorTest : public ::testing::TestWithParam> { + public: + PcaRowMajorTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)) + { + } + + protected: + void to_row_major(const T* col_major_src, T* row_major_dst, int n_rows, int n_cols) + { + std::vector host_col(n_rows * n_cols); + std::vector host_row(n_rows * n_cols); + raft::update_host(host_col.data(), col_major_src, n_rows * n_cols, stream); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (int i = 0; i < n_rows; ++i) { + for (int j = 0; j < n_cols; ++j) { + host_row[i * n_cols + j] = host_col[j * n_rows + i]; + } + } + raft::update_device(row_major_dst, host_row.data(), n_rows * n_cols, stream); + } + + void runRowMajor(int len, int n_rows, int n_cols, int n_components, T* input, T* recon) + { + paramsPCA prms; + prms.whiten = false; + if (params.algo == 0) + prms.algorithm = solver::COV_EIG_DQ; + else + prms.algorithm = solver::COV_EIG_JACOBI; + + rmm::device_uvector trans(static_cast(n_rows) * n_components, stream); + rmm::device_uvector components(static_cast(n_components) * n_cols, stream); + rmm::device_uvector ev(n_components, stream); + rmm::device_uvector evr(n_components, stream); + rmm::device_uvector sv(n_components, stream); + rmm::device_uvector mu(n_cols, stream); + rmm::device_uvector nv(1, stream); + + auto input_view = raft::make_device_matrix_view( + input, static_cast(n_rows), static_cast(n_cols)); + auto trans_view = raft::make_device_matrix_view( + trans.data(), static_cast(n_rows), static_cast(n_components)); + auto comp_view = raft::make_device_matrix_view( + components.data(), static_cast(n_components), static_cast(n_cols)); + auto ev_view = raft::make_device_vector_view(ev.data(), n_components); + auto evr_view = raft::make_device_vector_view(evr.data(), n_components); + auto sv_view = raft::make_device_vector_view(sv.data(), n_components); + auto mu_view = raft::make_device_vector_view(mu.data(), n_cols); + auto nv_view = raft::make_device_scalar_view(nv.data()); + + pca_fit_transform(handle, + prms, + input_view, + trans_view, + comp_view, + ev_view, + evr_view, + sv_view, + mu_view, + nv_view); + + auto recon_view = raft::make_device_matrix_view( + recon, static_cast(n_rows), static_cast(n_cols)); + pca_inverse_transform(handle, prms, trans_view, comp_view, sv_view, mu_view, recon_view); + } + + void testRowMajorRoundtrip() + { + int n_rows = params.n_row2; + int n_cols = params.n_col2; + int len = n_rows * n_cols; + int n_components = n_cols; + + rmm::device_uvector data_col(len, stream); + rmm::device_uvector data_row(len, stream); + rmm::device_uvector data_back(len, stream); + + raft::random::Rng r(params.seed, raft::random::GenPC); + r.uniform(data_col.data(), len, T(-1.0), T(1.0), stream); + to_row_major(data_col.data(), data_row.data(), n_rows, n_cols); + + runRowMajor(len, n_rows, n_cols, n_components, data_row.data(), data_back.data()); + + ASSERT_TRUE(devArrMatch( + data_row.data(), data_back.data(), len, raft::CompareApprox(params.tolerance), stream)); + } + + private: + raft::device_resources handle; + cudaStream_t stream; + PcaInputs params; +}; + +typedef PcaRowMajorTest PcaRowMajorTestF; +TEST_P(PcaRowMajorTestF, Roundtrip) { this->testRowMajorRoundtrip(); } + +typedef PcaRowMajorTest PcaRowMajorTestD; +TEST_P(PcaRowMajorTestD, Roundtrip) { this->testRowMajorRoundtrip(); } + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaRowMajorTestF, ::testing::ValuesIn(inputsf2)); +INSTANTIATE_TEST_CASE_P(PcaTests, PcaRowMajorTestD, ::testing::ValuesIn(inputsd2)); + } // end namespace raft::linalg From 2d9e0e2d996bf8bea616acee909eb0745a8516fb Mon Sep 17 00:00:00 2001 From: Anupam <54245698+aamijar@users.noreply.github.com> Date: Sat, 27 Jun 2026 13:17:17 -0400 Subject: [PATCH 2/7] Update cpp/include/raft/linalg/detail/pca.cuh Co-authored-by: Divye Gala --- cpp/include/raft/linalg/detail/pca.cuh | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cpp/include/raft/linalg/detail/pca.cuh b/cpp/include/raft/linalg/detail/pca.cuh index e50961f477..4d09a8579d 100644 --- a/cpp/include/raft/linalg/detail/pca.cuh +++ b/cpp/include/raft/linalg/detail/pca.cuh @@ -158,10 +158,12 @@ void pca_fit(raft::resources const& handle, // The eigendecomposition of the (symmetric) covariance matrix naturally produces a // col-major components buffer. For row-major output we accumulate into a temporary // and physically transpose at the end. - auto components_col_storage = raft::make_device_matrix( - handle, input_row_major ? n_components : idx_t(0), input_row_major ? n_cols : idx_t(0)); + std::optional> components_col_storage; + if constexpr (input_row_major) { + components_col_storage = raft::make_device_matrix_view<...)(...); + } math_t* components_col_data = - input_row_major ? components_col_storage.data_handle() : components.data_handle(); + input_row_major ? components_col_storage->data_handle() : components.data_handle(); auto components_col_view = raft::make_device_matrix_view( components_col_data, n_components, n_cols); From 48405e34b8edc36d996a29dedd706699bca0b5f8 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 29 Jun 2026 04:20:55 +0000 Subject: [PATCH 3/7] LayoutPolicy cal_eig --- cpp/include/raft/linalg/detail/pca.cuh | 64 ++++++++++------------- cpp/include/raft/linalg/detail/tsvd.cuh | 68 +++++++++++++++---------- 2 files changed, 69 insertions(+), 63 deletions(-) diff --git a/cpp/include/raft/linalg/detail/pca.cuh b/cpp/include/raft/linalg/detail/pca.cuh index 4d09a8579d..824c4d4619 100644 --- a/cpp/include/raft/linalg/detail/pca.cuh +++ b/cpp/include/raft/linalg/detail/pca.cuh @@ -29,16 +29,22 @@ namespace raft { namespace linalg::detail { -template +template void trunc_comp_exp_vars(raft::resources const& handle, const paramsTSVD& prms, - raft::device_matrix_view in, - raft::device_matrix_view components, + raft::device_matrix_view in, + raft::device_matrix_view components, raft::device_vector_view explained_var, raft::device_vector_view explained_var_ratio, raft::device_scalar_view noise_vars, std::size_t n_rows) { + static_assert( + std::is_same_v || std::is_same_v, + "trunc_comp_exp_vars: layout must be raft::row_major or raft::col_major"); + + constexpr bool is_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto n_cols = in.extent(0); @@ -49,19 +55,26 @@ void trunc_comp_exp_vars(raft::resources const& handle, rmm::device_uvector explained_var_all(static_cast(n_cols), stream); rmm::device_uvector explained_var_ratio_all(static_cast(n_cols), stream); - detail::cal_eig( + detail::cal_eig( handle, prms, in, - raft::make_device_matrix_view( + raft::make_device_matrix_view( components_all.data(), n_cols, n_cols), raft::make_device_vector_view(explained_var_all.data(), n_cols)); - raft::matrix::trunc_zero_origin( - handle, - raft::make_device_matrix_view( - components_all.data(), n_cols, n_cols), - raft::make_device_matrix_view( - components.data_handle(), n_components, n_cols)); + if constexpr (is_row_major) { + raft::copy(components.data_handle(), + components_all.data(), + static_cast(n_components) * static_cast(n_cols), + stream); + } else { + raft::matrix::trunc_zero_origin( + handle, + raft::make_device_matrix_view( + components_all.data(), n_cols, n_cols), + raft::make_device_matrix_view( + components.data_handle(), n_components, n_cols)); + } raft::matrix::ratio(handle, raft::make_device_matrix_view( explained_var_all.data(), n_cols, idx_t(1)), @@ -155,23 +168,11 @@ void pca_fit(raft::resources const& handle, raft::stats::cov( handle, cov.data(), input.data_handle(), mu.data_handle(), n_cols, n_rows, true, true, stream); - // The eigendecomposition of the (symmetric) covariance matrix naturally produces a - // col-major components buffer. For row-major output we accumulate into a temporary - // and physically transpose at the end. - std::optional> components_col_storage; - if constexpr (input_row_major) { - components_col_storage = raft::make_device_matrix_view<...)(...); - } - math_t* components_col_data = - input_row_major ? components_col_storage->data_handle() : components.data_handle(); - auto components_col_view = raft::make_device_matrix_view( - components_col_data, n_components, n_cols); - - detail::trunc_comp_exp_vars( + detail::trunc_comp_exp_vars( handle, prms, - raft::make_device_matrix_view(cov.data(), n_cols, n_cols), - components_col_view, + raft::make_device_matrix_view(cov.data(), n_cols, n_cols), + components, explained_var, explained_var_ratio, noise_vars, @@ -189,16 +190,7 @@ void pca_fit(raft::resources const& handle, raft::stats::meanAdd( input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); - detail::sign_flip_components(handle, input, components_col_view, true, flip_signs_based_on_U); - - if constexpr (input_row_major) { - // Transpose the internal col-major (n_components x n_cols) components into the user's - // row-major (n_components x n_cols) buffer. The same memory laid out as col-major - // (n_cols x n_components) is exactly the row-major (n_components x n_cols) we want. - auto components_as_col_view = raft::make_device_matrix_view( - components.data_handle(), n_cols, n_components); - raft::linalg::transpose(handle, components_col_view, components_as_col_view); - } + detail::sign_flip_components(handle, input, components, true, flip_signs_based_on_U); } /** diff --git a/cpp/include/raft/linalg/detail/tsvd.cuh b/cpp/include/raft/linalg/detail/tsvd.cuh index b503aa282a..fff9ca8a2d 100644 --- a/cpp/include/raft/linalg/detail/tsvd.cuh +++ b/cpp/include/raft/linalg/detail/tsvd.cuh @@ -106,13 +106,19 @@ void cal_comp_exp_vars_svd(raft::resources const& handle, handle, explained_vars.data_handle(), explained_var_ratio.data_handle(), n_components, stream); } -template +template void cal_eig(raft::resources const& handle, const paramsTSVD& prms, - raft::device_matrix_view in, - raft::device_matrix_view components, + raft::device_matrix_view in, + raft::device_matrix_view components, raft::device_vector_view explained_var) { + static_assert( + std::is_same_v || std::is_same_v, + "cal_eig: layout must be raft::row_major or raft::col_major"); + + constexpr bool is_row_major = std::is_same_v; + auto stream = resource::get_cuda_stream(handle); auto cusolver_handle = raft::resource::get_cusolver_dn_handle(handle); @@ -143,7 +149,9 @@ void cal_eig(raft::resources const& handle, raft::matrix::col_reverse(handle_stream_zero, raft::make_device_matrix_view( components.data_handle(), n_cols, n_cols)); - raft::linalg::transpose(components.data_handle(), n_cols, stream); + if constexpr (!is_row_major) { + raft::linalg::transpose(components.data_handle(), n_cols, stream); + } raft::matrix::row_reverse(handle_stream_zero, raft::make_device_matrix_view( @@ -151,31 +159,31 @@ void cal_eig(raft::resources const& handle, } /** - * @brief sign flip for PCA and tSVD. Stabilizes the sign of column major eigenvectors. + * @brief sign flip for PCA and tSVD. Stabilizes the sign of the eigenvectors. * - * The components matrix is always stored in col-major; the input matrix may be either - * row-major or col-major (deduced from the LayoutPolicy template parameter). + * The input and components matrices share a single layout, deduced from the LayoutPolicy + * template parameter (raft::row_major or raft::col_major). * * @tparam math_t element type * @tparam idx_t index type - * @tparam LayoutPolicy layout of the input matrix (raft::row_major or raft::col_major) + * @tparam LayoutPolicy layout of the input and components matrices * @param handle: raft::resources * @param input: input data [n_samples x n_features] - * @param components: components matrix [n_components x n_features] (col-major) + * @param components: components matrix [n_components x n_features] * @param center whether to mean-center input before computing signs * @param flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) */ -template +template void sign_flip_components(raft::resources const& handle, raft::device_matrix_view input, - raft::device_matrix_view components, + raft::device_matrix_view components, bool center, bool flip_signs_based_on_U = false) { static_assert( std::is_same_v || std::is_same_v, - "sign_flip_components: input layout must be raft::row_major or raft::col_major"); - constexpr bool input_row_major = std::is_same_v; + "sign_flip_components: layout must be raft::row_major or raft::col_major"); + constexpr bool is_row_major = std::is_same_v; auto stream = resource::get_cuda_stream(handle); auto n_samples = input.extent(0); @@ -183,25 +191,23 @@ void sign_flip_components(raft::resources const& handle, auto n_components = components.extent(0); rmm::device_uvector max_vals(static_cast(n_components), stream); - auto components_view = raft::make_device_matrix_view( + auto components_view = raft::make_device_matrix_view( components.data_handle(), n_components, n_features); auto max_vals_view = raft::make_device_vector_view(max_vals.data(), n_components); if (flip_signs_based_on_U) { if (center) { rmm::device_uvector col_means(static_cast(n_features), stream); - raft::stats::mean( + raft::stats::mean( col_means.data(), input.data_handle(), n_features, n_samples, stream); - raft::stats::meanCenter( + raft::stats::meanCenter( input.data_handle(), input.data_handle(), col_means.data(), n_features, n_samples, stream); } - // US = input @ components^T, shape (n_samples x n_components), in input's layout. - // The components matrix is col-major (n_components x n_features); reinterpreting the - // same memory as row-major (n_features x n_components) yields the transpose. rmm::device_uvector US(static_cast(n_samples * n_components), stream); - using transposed_layout = std::conditional_t; + using components_transposed_layout = + std::conditional_t; auto components_transposed_view = - raft::make_device_matrix_view( + raft::make_device_matrix_view( components.data_handle(), n_features, n_components); auto US_view = raft::make_device_matrix_view( US.data(), n_samples, n_components); @@ -211,7 +217,7 @@ void sign_flip_components(raft::resources const& handle, // Per-column reduction of US (n_samples x n_components) yields one max-abs value per // component. With the (rowMajor, alongRows) convention, alongRows=false produces D // outputs (one per column) regardless of layout; only the memory access pattern differs. - raft::linalg::reduce( + raft::linalg::reduce( max_vals.data(), US.data(), n_components, @@ -227,9 +233,10 @@ void sign_flip_components(raft::resources const& handle, }, raft::identity_op()); } else { - // components is col-major (n_components x n_features); reduce per row to get one - // max-abs value per component. - raft::linalg::reduce( + // Reduce the components matrix (n_components x n_features) per component (per row) to + // get one max-abs value each. alongRows=true produces N outputs (one per row) + // regardless of layout; rowMajor only changes the memory access pattern. + raft::linalg::reduce( max_vals.data(), components.data_handle(), n_features, @@ -250,8 +257,15 @@ void sign_flip_components(raft::resources const& handle, handle, components_view, [components_view, max_vals_view, n_components, n_features] __device__(auto idx) { - auto row = idx % n_components; - auto column = idx / n_components; + idx_t row; + idx_t column; + if constexpr (is_row_major) { + row = idx / n_features; + column = idx % n_features; + } else { + row = idx % n_components; + column = idx / n_components; + } return (max_vals_view(row) < math_t(0)) ? (-components_view(row, column)) : components_view(row, column); }); From 4c297740479dcef71c31aa5ce105b63778e9f199 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 29 Jun 2026 04:32:42 +0000 Subject: [PATCH 4/7] refactor inline view --- cpp/include/raft/linalg/detail/pca.cuh | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/linalg/detail/pca.cuh b/cpp/include/raft/linalg/detail/pca.cuh index 824c4d4619..d2ecf8089c 100644 --- a/cpp/include/raft/linalg/detail/pca.cuh +++ b/cpp/include/raft/linalg/detail/pca.cuh @@ -238,19 +238,14 @@ void pca_transform(raft::resources const& handle, rmm::device_uvector components_copy{components_len, stream}; raft::copy(components_copy.data(), components.data_handle(), components_len, stream); - auto components_copy_view = raft::make_device_matrix_view( - components_copy.data(), n_components, n_cols); - if (prms.whiten) { math_t scalar = math_t(sqrt(n_rows - 1)); raft::linalg::scalarMultiply( components_copy.data(), components_copy.data(), scalar, components_len, stream); - // Divide each row of (n_components x n_cols) components by the corresponding singular - // value. Apply::ALONG_COLUMNS broadcasts a vector of size n_rows-of-matrix - // (= n_components) over each column, which is the same operation in both layouts. raft::linalg::binary_div_skip_zero( handle, - components_copy_view, + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), raft::make_device_vector_view(singular_vals.data_handle(), n_components)); } @@ -315,9 +310,6 @@ void pca_inverse_transform(raft::resources const& handle, rmm::device_uvector components_copy{components_len, stream}; raft::copy(components_copy.data(), components.data_handle(), components_len, stream); - auto components_copy_view = raft::make_device_matrix_view( - components_copy.data(), n_components, n_cols); - if (prms.whiten) { math_t sqrt_n_samples = sqrt(n_rows - 1); math_t scalar = n_rows - 1 > 0 ? math_t(1 / sqrt_n_samples) : 0; @@ -325,14 +317,19 @@ void pca_inverse_transform(raft::resources const& handle, components_copy.data(), components_copy.data(), scalar, components_len, stream); raft::linalg::binary_mult_skip_zero( handle, - components_copy_view, + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), raft::make_device_vector_view(singular_vals.data_handle(), n_components)); } // output = trans_input @ components_copy. All three matrices share the user's layout, // so the mdspan gemm picks the correct cuBLAS transposes automatically. - raft::linalg::gemm(handle, trans_input, components_copy_view, output); + raft::linalg::gemm(handle, + trans_input, + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), + output); raft::stats::meanAdd( output.data_handle(), output.data_handle(), mu.data_handle(), n_cols, n_rows, stream); From f0b752793ef7ef151f905abd712b7afdc02ed334 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 29 Jun 2026 05:05:30 +0000 Subject: [PATCH 5/7] clean up --- cpp/include/raft/linalg/detail/tsvd.cuh | 26 ++--- cpp/tests/linalg/pca.cu | 146 +++--------------------- 2 files changed, 26 insertions(+), 146 deletions(-) diff --git a/cpp/include/raft/linalg/detail/tsvd.cuh b/cpp/include/raft/linalg/detail/tsvd.cuh index fff9ca8a2d..89764a8a5f 100644 --- a/cpp/include/raft/linalg/detail/tsvd.cuh +++ b/cpp/include/raft/linalg/detail/tsvd.cuh @@ -204,19 +204,16 @@ void sign_flip_components(raft::resources const& handle, input.data_handle(), input.data_handle(), col_means.data(), n_features, n_samples, stream); } rmm::device_uvector US(static_cast(n_samples * n_components), stream); - using components_transposed_layout = - std::conditional_t; - auto components_transposed_view = - raft::make_device_matrix_view( - components.data_handle(), n_features, n_components); - auto US_view = raft::make_device_matrix_view( - US.data(), n_samples, n_components); - - raft::linalg::gemm(handle, input, components_transposed_view, US_view); - - // Per-column reduction of US (n_samples x n_components) yields one max-abs value per - // component. With the (rowMajor, alongRows) convention, alongRows=false produces D - // outputs (one per column) regardless of layout; only the memory access pattern differs. + raft::linalg::gemm(handle, + input, + raft::make_device_matrix_view< + math_t, + idx_t, + std::conditional_t>( + components.data_handle(), n_features, n_components), + raft::make_device_matrix_view( + US.data(), n_samples, n_components)); + raft::linalg::reduce( max_vals.data(), US.data(), @@ -233,9 +230,6 @@ void sign_flip_components(raft::resources const& handle, }, raft::identity_op()); } else { - // Reduce the components matrix (n_components x n_features) per component (per row) to - // get one max-abs value each. alongRows=true produces N outputs (one per row) - // regardless of layout; rowMajor only changes the memory access pattern. raft::linalg::reduce( max_vals.data(), components.data_handle(), diff --git a/cpp/tests/linalg/pca.cu b/cpp/tests/linalg/pca.cu index 82ec96f86b..6d95a342d8 100644 --- a/cpp/tests/linalg/pca.cu +++ b/cpp/tests/linalg/pca.cu @@ -51,12 +51,9 @@ class PcaTest : public ::testing::TestWithParam> { trans_data(params.len, stream), trans_data_ref(params.len, stream), data(params.len, stream), - data_back(params.len, stream), - data2(params.len2, stream), - data2_back(params.len2, stream) + data_back(params.len, stream) { basicTest(); - advancedTest(); } protected: @@ -135,6 +132,7 @@ class PcaTest : public ::testing::TestWithParam> { handle, prms, trans_data_view, components_view, singular_vals_view, mu_view, data_back_view); } + template void advancedTest() { raft::random::Rng r(params.seed, raft::random::GenPC); @@ -151,6 +149,8 @@ class PcaTest : public ::testing::TestWithParam> { else if (params.algo == 1) prms.algorithm = solver::COV_EIG_JACOBI; + rmm::device_uvector data2(len, stream); + rmm::device_uvector data2_back(len, stream); r.uniform(data2.data(), len, T(-1.0), T(1.0), stream); rmm::device_uvector data2_trans(n_rows * n_components, stream); @@ -163,10 +163,10 @@ class PcaTest : public ::testing::TestWithParam> { rmm::device_uvector noise_vars2(1, stream); auto input_view = - raft::make_device_matrix_view(data2.data(), n_rows, n_cols); - auto trans_view = raft::make_device_matrix_view( + raft::make_device_matrix_view(data2.data(), n_rows, n_cols); + auto trans_view = raft::make_device_matrix_view( data2_trans.data(), n_rows, n_components); - auto comp_view = raft::make_device_matrix_view( + auto comp_view = raft::make_device_matrix_view( components2.data(), n_components, n_cols); auto ev_view = raft::make_device_vector_view(explained_vars2.data(), n_components); @@ -188,10 +188,13 @@ class PcaTest : public ::testing::TestWithParam> { mu_view, noise_view); - auto data2_back_view = raft::make_device_matrix_view( + auto data2_back_view = raft::make_device_matrix_view( data2_back.data(), n_rows, n_cols); pca_inverse_transform(handle, prms, trans_view, comp_view, sv_view, mu_view, data2_back_view); + + ASSERT_TRUE(devArrMatch( + data2.data(), data2_back.data(), len, raft::CompareApprox(params.tolerance), stream)); } protected: @@ -201,7 +204,7 @@ class PcaTest : public ::testing::TestWithParam> { PcaInputs params; rmm::device_uvector explained_vars, explained_vars_ref, components, components_ref, trans_data, - trans_data_ref, data, data_back, data2, data2_back; + trans_data_ref, data, data_back; }; const std::vector> inputsf2 = { @@ -295,21 +298,15 @@ TEST_P(PcaTestDataVecSmallD, Result) typedef PcaTest PcaTestDataVecF; TEST_P(PcaTestDataVecF, Result) { - ASSERT_TRUE(devArrMatch(data2.data(), - data2_back.data(), - (params.n_col2 * params.n_col2), - raft::CompareApprox(params.tolerance), - resource::get_cuda_stream(handle))); + this->template advancedTest(); + this->template advancedTest(); } typedef PcaTest PcaTestDataVecD; TEST_P(PcaTestDataVecD, Result) { - ASSERT_TRUE(devArrMatch(data2.data(), - data2_back.data(), - (params.n_col2 * params.n_col2), - raft::CompareApprox(params.tolerance), - resource::get_cuda_stream(handle))); + this->template advancedTest(); + this->template advancedTest(); } INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestValF, ::testing::ValuesIn(inputsf2)); @@ -332,115 +329,4 @@ INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecF, ::testing::ValuesIn(inputsf2) INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecD, ::testing::ValuesIn(inputsd2)); -/** - * Row-major end-to-end tests. Verifies that the layout-templated - * ``pca_fit_transform`` + ``pca_inverse_transform`` reconstruct row-major - * inputs and produce explained variances numerically equivalent to those - * from the col-major path on the same logical data. - */ -template -class PcaRowMajorTest : public ::testing::TestWithParam> { - public: - PcaRowMajorTest() - : params(::testing::TestWithParam>::GetParam()), - stream(resource::get_cuda_stream(handle)) - { - } - - protected: - void to_row_major(const T* col_major_src, T* row_major_dst, int n_rows, int n_cols) - { - std::vector host_col(n_rows * n_cols); - std::vector host_row(n_rows * n_cols); - raft::update_host(host_col.data(), col_major_src, n_rows * n_cols, stream); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); - for (int i = 0; i < n_rows; ++i) { - for (int j = 0; j < n_cols; ++j) { - host_row[i * n_cols + j] = host_col[j * n_rows + i]; - } - } - raft::update_device(row_major_dst, host_row.data(), n_rows * n_cols, stream); - } - - void runRowMajor(int len, int n_rows, int n_cols, int n_components, T* input, T* recon) - { - paramsPCA prms; - prms.whiten = false; - if (params.algo == 0) - prms.algorithm = solver::COV_EIG_DQ; - else - prms.algorithm = solver::COV_EIG_JACOBI; - - rmm::device_uvector trans(static_cast(n_rows) * n_components, stream); - rmm::device_uvector components(static_cast(n_components) * n_cols, stream); - rmm::device_uvector ev(n_components, stream); - rmm::device_uvector evr(n_components, stream); - rmm::device_uvector sv(n_components, stream); - rmm::device_uvector mu(n_cols, stream); - rmm::device_uvector nv(1, stream); - - auto input_view = raft::make_device_matrix_view( - input, static_cast(n_rows), static_cast(n_cols)); - auto trans_view = raft::make_device_matrix_view( - trans.data(), static_cast(n_rows), static_cast(n_components)); - auto comp_view = raft::make_device_matrix_view( - components.data(), static_cast(n_components), static_cast(n_cols)); - auto ev_view = raft::make_device_vector_view(ev.data(), n_components); - auto evr_view = raft::make_device_vector_view(evr.data(), n_components); - auto sv_view = raft::make_device_vector_view(sv.data(), n_components); - auto mu_view = raft::make_device_vector_view(mu.data(), n_cols); - auto nv_view = raft::make_device_scalar_view(nv.data()); - - pca_fit_transform(handle, - prms, - input_view, - trans_view, - comp_view, - ev_view, - evr_view, - sv_view, - mu_view, - nv_view); - - auto recon_view = raft::make_device_matrix_view( - recon, static_cast(n_rows), static_cast(n_cols)); - pca_inverse_transform(handle, prms, trans_view, comp_view, sv_view, mu_view, recon_view); - } - - void testRowMajorRoundtrip() - { - int n_rows = params.n_row2; - int n_cols = params.n_col2; - int len = n_rows * n_cols; - int n_components = n_cols; - - rmm::device_uvector data_col(len, stream); - rmm::device_uvector data_row(len, stream); - rmm::device_uvector data_back(len, stream); - - raft::random::Rng r(params.seed, raft::random::GenPC); - r.uniform(data_col.data(), len, T(-1.0), T(1.0), stream); - to_row_major(data_col.data(), data_row.data(), n_rows, n_cols); - - runRowMajor(len, n_rows, n_cols, n_components, data_row.data(), data_back.data()); - - ASSERT_TRUE(devArrMatch( - data_row.data(), data_back.data(), len, raft::CompareApprox(params.tolerance), stream)); - } - - private: - raft::device_resources handle; - cudaStream_t stream; - PcaInputs params; -}; - -typedef PcaRowMajorTest PcaRowMajorTestF; -TEST_P(PcaRowMajorTestF, Roundtrip) { this->testRowMajorRoundtrip(); } - -typedef PcaRowMajorTest PcaRowMajorTestD; -TEST_P(PcaRowMajorTestD, Roundtrip) { this->testRowMajorRoundtrip(); } - -INSTANTIATE_TEST_CASE_P(PcaTests, PcaRowMajorTestF, ::testing::ValuesIn(inputsf2)); -INSTANTIATE_TEST_CASE_P(PcaTests, PcaRowMajorTestD, ::testing::ValuesIn(inputsd2)); - } // end namespace raft::linalg From b88550d2355badc2ef6b5a5a496900dc65d7ac41 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 29 Jun 2026 05:31:18 +0000 Subject: [PATCH 6/7] clean up tsvd --- cpp/include/raft/linalg/detail/pca.cuh | 28 ++++----- cpp/include/raft/linalg/detail/tsvd.cuh | 75 ++++++++++--------------- 2 files changed, 43 insertions(+), 60 deletions(-) diff --git a/cpp/include/raft/linalg/detail/pca.cuh b/cpp/include/raft/linalg/detail/pca.cuh index d2ecf8089c..542d51d6a7 100644 --- a/cpp/include/raft/linalg/detail/pca.cuh +++ b/cpp/include/raft/linalg/detail/pca.cuh @@ -253,13 +253,13 @@ void pca_transform(raft::resources const& handle, raft::stats::meanCenter( input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); - // trans_input = input @ components_copy^T, in the user's layout. - // Reinterpreting the components_copy buffer with the opposite layout swaps the logical - // dimensions, giving us the (n_cols x n_components) transposed view we need for gemm. - using transposed_layout = std::conditional_t; - auto components_copy_transposed = raft::make_device_matrix_view( - components_copy.data(), n_cols, n_components); - raft::linalg::gemm(handle, input, components_copy_transposed, trans_input); + detail::tsvd_transform( + handle, + prms, + input, + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), + trans_input); raft::stats::meanAdd( input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); @@ -323,13 +323,13 @@ void pca_inverse_transform(raft::resources const& handle, n_components)); } - // output = trans_input @ components_copy. All three matrices share the user's layout, - // so the mdspan gemm picks the correct cuBLAS transposes automatically. - raft::linalg::gemm(handle, - trans_input, - raft::make_device_matrix_view( - components_copy.data(), n_components, n_cols), - output); + detail::tsvd_inverse_transform( + handle, + prms, + trans_input, + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), + output); raft::stats::meanAdd( output.data_handle(), output.data_handle(), mu.data_handle(), n_cols, n_rows, stream); diff --git a/cpp/include/raft/linalg/detail/tsvd.cuh b/cpp/include/raft/linalg/detail/tsvd.cuh index 89764a8a5f..f997851a84 100644 --- a/cpp/include/raft/linalg/detail/tsvd.cuh +++ b/cpp/include/raft/linalg/detail/tsvd.cuh @@ -402,18 +402,21 @@ void tsvd_fit(raft::resources const& handle, * @brief performs transform operation for the tsvd. Transforms the data to eigenspace. * @param[in] handle raft::resources * @param[in] prms: data structure that includes all the parameters from input size to algorithm. - * @param[in] input: the data to transform. Size n_rows x n_cols (col-major). - * @param[in] components: principal components. Size n_components x n_cols (col-major). - * @param[out] trans_input: transformed output. Size n_rows x n_components (col-major). + * @param[in] input: the data to transform. Size n_rows x n_cols. + * @param[in] components: principal components. Size n_components x n_cols. + * @param[out] trans_input: transformed output. Size n_rows x n_components. */ -template +template void tsvd_transform(raft::resources const& handle, const paramsTSVD& prms, - raft::device_matrix_view input, - raft::device_matrix_view components, - raft::device_matrix_view trans_input) + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_matrix_view trans_input) { - auto stream = resource::get_cuda_stream(handle); + static_assert( + std::is_same_v || std::is_same_v, + "tsvd_transform: layout must be raft::row_major or raft::col_major"); + constexpr bool is_row_major = std::is_same_v; auto n_rows = input.extent(0); auto n_cols = input.extent(1); @@ -423,39 +426,34 @@ void tsvd_transform(raft::resources const& handle, ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); - math_t alpha = math_t(1); - math_t beta = math_t(0); raft::linalg::gemm(handle, - input.data_handle(), - n_rows, - n_cols, - components.data_handle(), - trans_input.data_handle(), - n_rows, - n_components, - CUBLAS_OP_N, - CUBLAS_OP_T, - alpha, - beta, - stream); + input, + raft::make_device_matrix_view< + math_t, + idx_t, + std::conditional_t>( + components.data_handle(), n_cols, n_components), + trans_input); } /** * @brief performs inverse transform operation for the tsvd. * @param[in] handle raft::resources * @param[in] prms: data structure that includes all the parameters from input size to algorithm. - * @param[in] trans_input: the transformed data. Size n_rows x n_components (col-major). - * @param[in] components: principal components. Size n_components x n_cols (col-major). - * @param[out] output: reconstructed output. Size n_rows x n_cols (col-major). + * @param[in] trans_input: the transformed data. Size n_rows x n_components. + * @param[in] components: principal components. Size n_components x n_cols. + * @param[out] output: reconstructed output. Size n_rows x n_cols. */ -template +template void tsvd_inverse_transform(raft::resources const& handle, const paramsTSVD& prms, - raft::device_matrix_view trans_input, - raft::device_matrix_view components, - raft::device_matrix_view output) + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_matrix_view output) { - auto stream = resource::get_cuda_stream(handle); + static_assert( + std::is_same_v || std::is_same_v, + "tsvd_inverse_transform: layout must be raft::row_major or raft::col_major"); auto n_rows = output.extent(0); auto n_cols = output.extent(1); @@ -465,22 +463,7 @@ void tsvd_inverse_transform(raft::resources const& handle, ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); - math_t alpha = math_t(1); - math_t beta = math_t(0); - - raft::linalg::gemm(handle, - trans_input.data_handle(), - n_rows, - n_components, - components.data_handle(), - output.data_handle(), - n_rows, - n_cols, - CUBLAS_OP_N, - CUBLAS_OP_N, - alpha, - beta, - stream); + raft::linalg::gemm(handle, trans_input, components, output); } /** From 4965cf6c604343faf253bed5bcefb25ec228d3ca Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 29 Jun 2026 05:49:02 +0000 Subject: [PATCH 7/7] clean up --- cpp/include/raft/linalg/detail/pca.cuh | 15 --------------- cpp/include/raft/linalg/detail/tsvd.cuh | 4 ---- cpp/include/raft/linalg/pca.cuh | 13 ------------- 3 files changed, 32 deletions(-) diff --git a/cpp/include/raft/linalg/detail/pca.cuh b/cpp/include/raft/linalg/detail/pca.cuh index 542d51d6a7..0c56b2a42f 100644 --- a/cpp/include/raft/linalg/detail/pca.cuh +++ b/cpp/include/raft/linalg/detail/pca.cuh @@ -111,10 +111,6 @@ void trunc_comp_exp_vars(raft::resources const& handle, /** * @brief perform fit operation for PCA. - * - * Supports both row-major and col-major input layouts via the LayoutPolicy template - * parameter. The output `components` matrix has the same layout as the input. - * * @tparam math_t element type * @tparam idx_t index type * @tparam LayoutPolicy layout of the input matrix (raft::row_major or raft::col_major) @@ -195,10 +191,6 @@ void pca_fit(raft::resources const& handle, /** * @brief performs transform operation for PCA. Transforms the data to eigenspace. - * - * Supports both row-major and col-major layouts via the LayoutPolicy template parameter. - * `input`, `components`, and `trans_input` must all share the same layout. - * * @tparam math_t element type * @tparam idx_t index type * @tparam LayoutPolicy layout (raft::row_major or raft::col_major) @@ -267,10 +259,6 @@ void pca_transform(raft::resources const& handle, /** * @brief performs inverse transform operation for PCA. - * - * Supports both row-major and col-major layouts via the LayoutPolicy template parameter. - * `trans_input`, `components`, and `output` must all share the same layout. - * * @tparam math_t element type * @tparam idx_t index type * @tparam LayoutPolicy layout (raft::row_major or raft::col_major) @@ -337,9 +325,6 @@ void pca_inverse_transform(raft::resources const& handle, /** * @brief perform fit and transform operations for PCA. - * - * Supports both row-major and col-major layouts via the LayoutPolicy template parameter. - * * @tparam math_t element type * @tparam idx_t index type * @tparam LayoutPolicy layout (raft::row_major or raft::col_major) diff --git a/cpp/include/raft/linalg/detail/tsvd.cuh b/cpp/include/raft/linalg/detail/tsvd.cuh index f997851a84..32ce9c9275 100644 --- a/cpp/include/raft/linalg/detail/tsvd.cuh +++ b/cpp/include/raft/linalg/detail/tsvd.cuh @@ -160,10 +160,6 @@ void cal_eig(raft::resources const& handle, /** * @brief sign flip for PCA and tSVD. Stabilizes the sign of the eigenvectors. - * - * The input and components matrices share a single layout, deduced from the LayoutPolicy - * template parameter (raft::row_major or raft::col_major). - * * @tparam math_t element type * @tparam idx_t index type * @tparam LayoutPolicy layout of the input and components matrices diff --git a/cpp/include/raft/linalg/pca.cuh b/cpp/include/raft/linalg/pca.cuh index 322bff2cc0..93cabb5c4f 100644 --- a/cpp/include/raft/linalg/pca.cuh +++ b/cpp/include/raft/linalg/pca.cuh @@ -20,10 +20,6 @@ namespace linalg { /** * @brief perform fit operation for PCA. Generates eigenvectors, explained vars, singular vals, etc. - * - * Supports both row-major and col-major layouts. The layout is deduced from the input view's - * `LayoutPolicy` and must match between `input` and `components`. - * * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing * @tparam LayoutPolicy layout of the input/components matrices (raft::row_major or @@ -69,9 +65,6 @@ void pca_fit(raft::resources const& handle, /** * @brief perform fit and transform operations for PCA. Generates transformed data, * eigenvectors, explained vars, singular vals, etc. - * - * Supports both row-major and col-major layouts. All matrix views must share the same layout. - * * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing * @tparam LayoutPolicy layout of the input/output matrices (raft::row_major or raft::col_major) @@ -119,9 +112,6 @@ void pca_fit_transform(raft::resources const& handle, /** * @brief performs inverse transform operation for PCA. Transforms the transformed data back to * original data. - * - * Supports both row-major and col-major layouts. All matrix views must share the same layout. - * * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing * @tparam LayoutPolicy layout of the input/output matrices (raft::row_major or raft::col_major) @@ -147,9 +137,6 @@ void pca_inverse_transform(raft::resources const& handle, /** * @brief performs transform operation for PCA. Transforms the data to eigenspace. - * - * Supports both row-major and col-major layouts. All matrix views must share the same layout. - * * @tparam math_t data-type upon which the math operation will be performed * @tparam idx_t integer type used for indexing * @tparam LayoutPolicy layout of the input/output matrices (raft::row_major or raft::col_major)