diff --git a/CMakeLists.txt b/CMakeLists.txt index 93afd5cd..15fa262e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ find_package(Spt3g REQUIRED) find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) find_package(FLAC) find_package(GSL) +find_package(FFTW) find_package(Ceres) find_package(OpenMP) @@ -86,7 +87,15 @@ target_link_libraries(so3g PUBLIC spt3g::core) target_include_directories(so3g PRIVATE ${GSL_INCLUDE_DIR}) target_link_libraries(so3g PUBLIC ${GSL_LIBRARIES}) # Link Ceres -target_link_libraries(so3g PUBLIC Ceres::ceres Eigen3::Eigen) +target_link_libraries(so3g PUBLIC Ceres::Ceres Eigen3::Eigen) + +# Link FFTW +target_include_directories(so3g PRIVATE ${FFTW_INCLUDE_DIRS}) +target_link_libraries(so3g PUBLIC ${FFTW_LIBRARIES}) +target_link_libraries(so3g PUBLIC ${FFTW_OMP_LIBRARY}) +# Link FFTWF +target_link_libraries(so3g PUBLIC ${FFTWF_LIBRARIES}) +target_link_libraries(so3g PUBLIC ${FFTWF_OMP_LIBRARY}) # FLAC- library already comes from spt3g dependencies, but # we need to have the headers. diff --git a/Dockerfile b/Dockerfile index bf94ef60..df6722ee 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,6 +15,7 @@ RUN apt update && apt install -y \ libopenblas-openmp-dev \ libbz2-dev \ python-is-python3 \ + libfftw3-dev \ libgoogle-glog-dev \ libgflags-dev \ libmetis-dev \ diff --git a/cmake/FindFFTW.cmake b/cmake/FindFFTW.cmake new file mode 100644 index 00000000..d9dfb4e3 --- /dev/null +++ b/cmake/FindFFTW.cmake @@ -0,0 +1,39 @@ +# Find FFTW +# FFTW_INCLUDES - where to find fftw3.h +# FFTW_LIBRARIES - List of libraries when using FFTW +# FFTW_FOUND - True if FFTW is found + +if (FFTW_INCLUDES) + set (FFTW_FIND_QUIETLY TRUE) # Already in cache, be silent +endif () + +find_path (FFTW_INCLUDES fftw3.h HINTS ENV FFTW_INC) +find_library (FFTW_LIBRARIES NAMES fftw3 HINTS ENV FFTW_DIR) +find_library (FFTW_OMP_LIBRARY NAMES fftw3_omp HINTS ENV FFTW_DIR) # Find OMP implementation + +include (FindPackageHandleStandardArgs) +set(FPHSA_NAME_MISMATCHED TRUE) +find_package_handle_standard_args(FFTW_OMP DEFAULT_MSG FFTW_OMP_LIBRARY) +mark_as_advanced(FFTW_OMP_LIBRARY) + +find_package_handle_standard_args(FFTW DEFAULT_MSG FFTW_LIBRARIES FFTW_INCLUDES) +mark_as_advanced(FFTW_LIBRARIES FFTW_INCLUDES) + +# FFTWF_LIBRARIES - List of libraries when using FFTWF +# FFTWF_FOUND - True if FFTWF is found + +if (FFTW_INCLUDES) + set (FFTWF_FIND_QUIETLY TRUE) # Already in cache, be silent +endif () + +find_path (FFTW_INCLUDES fftw3f.h HINTS ENV FFTW_INC) +find_library (FFTWF_LIBRARIES NAMES fftw3f HINTS ENV FFTW_DIR) +find_library (FFTWF_OMP_LIBRARY NAMES fftw3f_omp HINTS ENV FFTW_DIR) # Find OMP implementation + +include (FindPackageHandleStandardArgs) +set(FPHSA_NAME_MISMATCHED TRUE) +find_package_handle_standard_args(FFTWF_OMP DEFAULT_MSG FFTWF_OMP_LIBRARY) +mark_as_advanced(FFTWF_OMP_LIBRARY) + +find_package_handle_standard_args(FFTWF DEFAULT_MSG FFTWF_LIBRARIES FFTW_INCLUDES) +mark_as_advanced(FFTWF_LIBRARIES FFTW_INCLUDES) diff --git a/docker/so3g-setup.sh b/docker/so3g-setup.sh index f8594493..dfa16ea2 100644 --- a/docker/so3g-setup.sh +++ b/docker/so3g-setup.sh @@ -6,6 +6,7 @@ cmake \ -DCMAKE_VERBOSE_MAKEFILE=ON \ -DCMAKE_BUILD_TYPE=Release \ -DPython_EXECUTABLE=$(which python3) \ + -DCMAKE_MODULE_PATH=$(pwd)/../cmake \ .. make -j 2 make install diff --git a/src/array_ops.cxx b/src/array_ops.cxx index 66bc53cd..2c72d717 100644 --- a/src/array_ops.cxx +++ b/src/array_ops.cxx @@ -23,6 +23,8 @@ extern "C" { #include #include +#include + #include #include "so3g_numpy.h" #include "numpy_assist.h" @@ -1260,6 +1262,260 @@ void detrend(bp::object & tod, const std::string & method, const int linear_ncou } } +template +void _hanning_window(T* window, const int n) +{ + for (int i = 0; i < n; ++i) { + window[i] = 0.5 * (1 - cos(2 * M_PI * i / (n - 1))); + } +} + +template +auto _allocate_fftw_output(const int n) { + if constexpr (std::is_same::value) { + return static_cast(fftwf_malloc(sizeof(fftwf_complex) * n)); + } + else if constexpr (std::is_same::value) { + return static_cast(fftw_malloc(sizeof(fftw_complex) * n)); + } +} + +template +auto _select_fftw_plan(T1* in, T2* out, const int ndets, + const int nperseg, const int npsd, + const int idist, const int odist) +{ + int rank = 1; // 1D FFT + int n[] = {static_cast(nperseg)}; // FFT size + int howmany = ndets; // Number of transforms to compute + int istride = 1; // Input stride + int ostride = 1; // Output stride + int *inembed = n; // Input array dimensions + int *onembed = n; // Output array dimensions + + if constexpr (std::is_same::value) { + return fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, + out, onembed, ostride, odist, FFTW_ESTIMATE); + } + else if constexpr (std::is_same::value) { + return fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, + out, onembed, ostride, odist, FFTW_ESTIMATE); + } +} + +template +void _execute_fftw_plan(T plan) +{ + if constexpr (std::is_same::value) { + fftw_execute(plan); + } + else if constexpr (std::is_same::value) { + fftwf_execute(plan); + } +} + +template +void _welch(const bp::object & signal, bp::object & psd, const double fs, + const int nperseg, int noverlap, const std::string & detrend_method, + const int detrend_linear_ncount, const std::string & scaling) +{ + BufferWrapper signal_buf ("signal", signal, false, std::vector{-1, -1}); + if (signal_buf->strides[1] != signal_buf->itemsize) + throw ValueError_exception("Argument 'signal' must be contiguous in last axis."); + const int ndets = signal_buf->shape[0]; + const int nsamps = signal_buf->shape[1]; + T* signal_data = (T*)signal_buf->buf; + + BufferWrapper psd_buf ("psd", psd, false, std::vector{-1, nperseg / 2 + 1}); + if (psd_buf->strides[1] != psd_buf->itemsize) + throw ValueError_exception("Argument 'psd' must be contiguous in last axis."); + const int npsd = psd_buf->shape[1]; + T* psd_data = (T*)psd_buf->buf; + + if (nperseg > nsamps) { + throw ValueError_exception("nperseg must be <= nsamps"); + } + if (noverlap >= nperseg) { + throw ValueError_exception("noverlap must be < nperseg"); + } + if (fs <= 0) { + throw ValueError_exception("fs must be > 0"); + } + + // Data strides + int signal_stride = signal_buf->strides[0] / sizeof(T); + int psd_stride = psd_buf->strides[0] / sizeof(T); + + // Get number of threads for fftw + int nthreads = 1; + #pragma omp parallel + { + #ifdef _OPENMP + if (omp_get_thread_num() == 0) + nthreads = omp_get_num_threads(); + #endif + } + + if constexpr (std::is_same::value) { + fftwf_init_threads(); + fftwf_plan_with_nthreads(nthreads); + } + else if constexpr (std::is_same::value) { + fftw_init_threads(); + fftw_plan_with_nthreads(nthreads); + } + + // Default noverlap + if (noverlap < 0) { + noverlap = nperseg / 2; + } + + int nstep = nperseg - noverlap; + + // Window array + T window[nperseg]; + _hanning_window(window, nperseg); + + T scale = 1.0; + + if (scaling == "density") { + T window_sum_sq = 0.0; + for (int i = 0; i < nperseg; ++i) { + window_sum_sq += window[i] * window[i]; + } + scale = 1.0 / (fs * window_sum_sq); + } + else if (scaling == "spectrum") { + T window_sum = 0.0; + for (int i = 0; i < nperseg; ++i) { + window_sum += window[i]; + } + scale = 1.0 / (window_sum * window_sum); + } + else { + throw ValueError_exception("Supported scaling options are 'density' " + "or 'spectrum'"); + } + + // Number of segments to average over + int nsegments = ((nsamps) - noverlap) / nstep; + + // Input array for segment + T* segment = (T*) malloc(ndets * nperseg * sizeof(T)); + + // Either fftw_complex* or fftwf_complex* + auto out = _allocate_fftw_output(ndets * npsd); + + // Plan creation is not thread-safe and creating many upfront + // is less efficient, so just reuse one sequentially and + // parallelize internally. + auto plan = _select_fftw_plan(segment, out, ndets, nperseg, npsd, + nperseg, psd_stride); + + // Loop over segments + for (int s = 0; s < nsegments; ++s) { + int start = s * nstep; + int end = start + nperseg; + + #pragma omp parallel for + for (int i = 0; i < ndets; ++i) { + int signal_ioff = i * signal_stride + start; + int segment_ioff = i * nperseg; + + T* signal_row = signal_data + signal_ioff; + T* segment_row = segment + segment_ioff; + + // Copy data due to detrending and windowing + for (int j = 0; j < nperseg; ++j) { + segment_row[j] = signal_row[j]; + } + + // More efficient to detrend each row in a parallel loop + // with data copying and windowing than to do each individually + // in parallel loops + if (detrend_method != "none") { + _detrend(segment_row, 1, nperseg, nperseg, detrend_method, + detrend_linear_ncount, 1); + } + // Apply window function + for (int j = 0; j < nperseg; ++j) { + segment_row[j] *= window[j]; + } + } + + // Execute either fftw or fftwf plan + _execute_fftw_plan(plan); + + // Add segement psd into total psd array + #pragma omp parallel for + for (int i = 0; i < ndets; ++i) { + int ioff = i * psd_stride; + for (int j = 0; j < npsd; ++j) { + T real = out[ioff + j][0]; + T imag = out[ioff + j][1]; + psd_data[ioff + j] += (real * real + imag * imag) * scale; + } + } + } + + // Get average psd over segments + #pragma omp parallel for + for (int i = 0; i < ndets; ++i) { + int ioff = i * psd_stride; + for (int j = 0; j < npsd; ++j) { + psd_data[ioff + j] /= nsegments; + } + } + + // Normalization of endpoints + int end_index = npsd; + + if (nperseg % 2) { + end_index -= 1; + } + + #pragma omp parallel for + for (int i = 0; i < ndets; ++i) { + int ioff = i * psd_stride; + for (int j = 1; j < end_index; ++j) { + psd_data[ioff + j] *= 2; + } + } + + if constexpr (std::is_same::value) { + fftw_destroy_plan(plan); + fftw_free(out); + fftw_cleanup_threads(); + } + else if constexpr (std::is_same::value) { + fftwf_destroy_plan(plan); + fftwf_free(out); + fftwf_cleanup_threads(); + } + free(segment); +} + +void welch(const bp::object & signal, bp::object & psd, const double fs, + const int nperseg, const int noverlap, const std::string & detrend_method, + const int detrend_linear_ncount, const std::string & scaling) +{ + // get data type + int dtype = get_dtype(signal); + + if (dtype == NPY_FLOAT) { + _welch(signal, psd, fs, nperseg, noverlap, detrend_method, + detrend_linear_ncount, scaling); + } + else if (dtype == NPY_DOUBLE) { + _welch(signal, psd, fs, nperseg, noverlap, detrend_method, + detrend_linear_ncount, scaling); + } + else { + throw TypeError_exception("Only float32 or float64 arrays are supported."); + } +} + + PYBINDINGS("so3g") { bp::def("nmat_detvecs_apply", nmat_detvecs_apply); @@ -1418,4 +1674,25 @@ PYBINDINGS("so3g") " linear_ncount: Number (int) of samples to use on each end, when measuring mean level for 'linear'" " detrend. Must be a positive integer or -1. If -1, nsamps / 2 will be used. Values " " larger than 1 suppress the influence of white noise.\n"); -} \ No newline at end of file + bp::def("welch", welch, + "welch(signal, psd, fs, nperseg, noverlap, detrend_method, detrend_linear_ncount, scaling)" + "\n" + "Calculate the PSD of each row a 2D data array (float32/float64) using Welch's method.\n" + "A Hanning window is applied. OMP is used to parallelize across dets (rows) and within FFTW.\n" + "Uses mean normalization for PSD segments. Based on the scipy implementation of Welch's method:\n" + "https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.welch.html" + "\n" + "Args:\n" + " signal: input array (float32/float64) buffer with shape (ndets, nsamps)\n" + " psd: output data buffer (float32/float64) to store computed PSDs with shape\n" + " (ndets,(nperseg // 2) + 1). It is modified in place.\n" + " fs: Sample rate in Hz (double)\n" + " nperseg: size (int) of each segment to be averaged over. nperseg must be <= nsamps.\n" + " noverlap: number (int) of samples to overlap for each segment. Set to nperseg / 2 if < 0\n" + " detrend_method: how to detrend each row (string). Options are 'mean', 'median', 'linear', and 'none'.\n" + " See docstring for detrend.\n" + " detrend_linear_ncount: 'linear_ncount' parameter for 'linear' detrending (int). See docstring for detrend.\n" + " scaling: how to normalize the averaged PSD (string). Options are 'density' or 'spectrum.'\n" + " density normalizes by 1.0 / (fs * sum(window^2)). spectrum normalizes by \n" + " 1.0 / (sum(window)^2\n"); +} diff --git a/test/test_array_ops.py b/test/test_array_ops.py index 20c6036a..97a76d2f 100644 --- a/test/test_array_ops.py +++ b/test/test_array_ops.py @@ -425,5 +425,67 @@ def test_02_linear_detrending(self): np.testing.assert_allclose(signal_copy, signal, rtol=rtol, atol=atol) +class TestWelch(unittest.TestCase): + """ + Test Welch PSD calculation. + """ + + def test_00_psd_float32(self): + nsamps = 1000 + ndets = 3 + dtype = "float32" + order = "C" + + x = np.linspace(0., 1., nsamps, dtype=dtype) + signal = np.array([(i + 1) * np.sin(2*np.pi*x + i) for i in range(ndets)], dtype=dtype, order=order) + + fs = 200. + nperseg = 256 + noverlap = -1 # noverlap = nperseg // 2 + detrend_method = "mean" # "constant" in scipy's welch + detrend_ncount = 0 + scaling = "density" + + npsd = (nperseg // 2) + 1 + + window = np.hanning(nperseg) + scipy_f, scipy_psd = welch(signal, fs, nperseg=nperseg, window=window, detrend="constant", scaling=scaling) + + so3g_psd = np.zeros((ndets, npsd), dtype=dtype, order=order) + so3g.welch(signal, so3g_psd, fs, nperseg, noverlap, detrend_method, detrend_ncount, scaling) + + rtol = 1e-4 + atol = 1e-8 + np.testing.assert_allclose(scipy_psd, so3g_psd, rtol=rtol, atol=atol) + + def test_00_psd_float64(self): + nsamps = 1000 + ndets = 3 + dtype = "float64" + order = "C" + + x = np.linspace(0., 1., nsamps, dtype=dtype) + signal = np.array([(i + 1) * np.sin(2*np.pi*x + i) for i in range(ndets)], dtype=dtype, order=order) + + fs = 200. + nperseg = 256 + noverlap = -1 # noverlap = nperseg // 2 + detrend_method = "mean" # "constant" in scipy's welch + detrend_ncount = 0 + scaling = "density" + + npsd = (nperseg // 2) + 1 + + window = np.hanning(nperseg) + scipy_f, scipy_psd = welch(signal, fs, nperseg=nperseg, window=window, detrend="constant", scaling=scaling) + + so3g_psd = np.zeros((ndets, npsd), dtype=dtype, order=order) + so3g.welch(signal, so3g_psd, fs, nperseg, noverlap, detrend_method, detrend_ncount, scaling) + + rtol = 1e-10 + atol = 1e-10 + np.testing.assert_allclose(scipy_psd, so3g_psd, rtol=rtol, atol=atol) + + if __name__ == "__main__": unittest.main()