From fbc4c0fce5ab45f751fc70bf021ec2925b8c5ed6 Mon Sep 17 00:00:00 2001 From: Markus Mayer Date: Sun, 31 May 2026 20:35:10 +0200 Subject: [PATCH 1/4] feat!: redesign as complete STFT/ISTFT toolkit (0.4) Ground-up rewrite turning the single-file 0.3 streaming STFT into a full transform toolkit. Removes the old STFT type, WindowType enum, FromF64 trait and log10_positive function. Correctness fixes (were part of the 0.3 contract): - n_freqs is now fft_size/2 + 1, including the Nyquist bin (was fft_size/2, which dropped it). - bin frequencies are now k*fs/fft_size (were k*fs/(2*(n_freqs-1))). Added: - Stft/StftBuilder: streaming (append/ready/process_into/step/columns) and one-shot spectrogram() batch with centered framing (reflect/edge/zero pad) and optional rayon parallelism. - Istft/IstftBuilder: weighted overlap-add inverse for perfect reconstruction; Stft::inverse() mirrors a forward transform. - Window/WindowFunction/Symmetry: 14 window families in periodic and symmetric variants, with cached sum and sum-of-squares. - Scaling (none/magnitude/density) and PadMode. - spectrum helpers: magnitude, power, phase, dB conversions. - mel feature: mel filterbank, mel scales and an orthonormal DCT-II for MFCCs (librosa-compatible defaults), no extra dependencies. - Optional ndarray (Array2 output), rayon, serde integrations. - no_std support (with alloc) for the window, spectrum and mel math. Changed: - FFT backend switched to realfft (real-only): ~2x faster, half the memory on real input. - #![forbid(unsafe_code)] across the crate; safety-dance badge. Tooling: - New CI (fmt, clippy, MSRV, no_std build, docs, cargo-deny, coverage), rustfmt.toml, deny.toml, CHANGELOG.md, dependabot. Removed dead .travis.yml. README rewritten with examples and a 0.3 migration table. Tests: analytic correctness tests plus proptest linearity and STFT/ISTFT round-trip properties; runnable examples. 
--- .github/dependabot.yml | 11 + .github/workflows/ci.yml | 95 ++++++++ .travis.yml | 30 --- CHANGELOG.md | 48 ++++ Cargo.toml | 56 ++++- README.md | 134 ++++++++++- Taskfile.dist.yaml | 122 ++++++++++ benches/lib.rs | 199 +++++++--------- deny.toml | 29 +++ examples/mfcc.rs | 44 ++++ examples/roundtrip.rs | 42 ++++ examples/spectrogram.rs | 52 +++++ rustfmt.toml | 1 + src/batch.rs | 204 ++++++++++++++++ src/config.rs | 42 ++++ src/error.rs | 79 +++++++ src/istft.rs | 271 +++++++++++++++++++++ src/lib.rs | 493 +++++++++------------------------------ src/mel.rs | 252 ++++++++++++++++++++ src/sample.rs | 27 +++ src/spectrum.rs | 106 +++++++++ src/stft.rs | 352 ++++++++++++++++++++++++++++ src/window/functions.rs | 202 ++++++++++++++++ src/window/mod.rs | 237 +++++++++++++++++++ tests/lib.rs | 350 ++++++++++++++++++--------- tests/proptests.rs | 63 +++++ 26 files changed, 2885 insertions(+), 656 deletions(-) create mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/ci.yml delete mode 100644 .travis.yml create mode 100644 CHANGELOG.md create mode 100644 Taskfile.dist.yaml create mode 100644 deny.toml create mode 100644 examples/mfcc.rs create mode 100644 examples/roundtrip.rs create mode 100644 examples/spectrogram.rs create mode 100644 rustfmt.toml create mode 100644 src/batch.rs create mode 100644 src/config.rs create mode 100644 src/error.rs create mode 100644 src/istft.rs create mode 100644 src/mel.rs create mode 100644 src/sample.rs create mode 100644 src/spectrum.rs create mode 100644 src/stft.rs create mode 100644 src/window/functions.rs create mode 100644 src/window/mod.rs create mode 100644 tests/proptests.rs diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..421d7c8 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: + - package-ecosystem: cargo + directory: "/" + schedule: + interval: weekly + open-pull-requests-limit: 5 + - package-ecosystem: github-actions + 
directory: "/" + schedule: + interval: weekly diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d9c1dd0 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,95 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: "-D warnings" + +jobs: + fmt: + name: rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - run: cargo fmt --all --check + + clippy: + name: clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + - uses: Swatinem/rust-cache@v2 + - run: cargo clippy --all-targets --all-features + # The FFT processors (and thus examples/tests/benches) need `std`, so + # lint only the library for the no_std subsets. + - run: cargo clippy --lib --no-default-features + - run: cargo clippy --lib --no-default-features --features mel,serde + + test: + name: test (${{ matrix.rust }}) + runs-on: ubuntu-latest + strategy: + matrix: + rust: [stable, "1.75"] + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + - uses: Swatinem/rust-cache@v2 + - run: cargo test --all-features + - run: cargo test # default features + + no_std: + name: no_std build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + targets: thumbv7em-none-eabihf + - run: cargo build --no-default-features --features mel,serde --target thumbv7em-none-eabihf + + docs: + name: doc + runs-on: ubuntu-latest + env: + RUSTDOCFLAGS: "-D warnings --cfg docsrs" + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly + - run: cargo doc --all-features --no-deps + + deny: + name: cargo-deny + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: 
EmbarkStudios/cargo-deny-action@v2 + + coverage: + name: coverage + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: llvm-tools-preview + - uses: taiki-e/install-action@cargo-llvm-cov + - uses: Swatinem/rust-cache@v2 + - run: cargo llvm-cov --all-features --lcov --output-path lcov.info + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + fail_ci_if_error: false diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 6b33f90..0000000 --- a/.travis.yml +++ /dev/null @@ -1,30 +0,0 @@ -sudo: true -language: rust -rust: - - stable - - beta - - nightly -matrix: - allow_failures: - - rust: nightly -script: - # travis default rust test script - - cargo build --verbose - - cargo test --verbose - # benches use the test feature which is only available on nightly - # run benches once per build to detect performance regressions - - if [ "$TRAVIS_RUST_VERSION" = "nightly" ] ; then - cargo bench --verbose; - fi -env: - global: - - secure: "eRqOvqFJkpV2uhp27Va2S1Tq6FUXdkoKRYrLnvcyO4eEry0Mjohc0T3M2syWNCJNupyUIECunYGUx3k6f2YtpdgQg0UndS67YZqUusFjnUNYUTR9KDwHE9MBS/df4x7dbGLqMX9PHi+Oo1wQXZhRb56s3K7CuV7GqqSXpt+FQ2qswugYfYpYzzYnlj0ak2YjgP65dZaFiR/p5awg6Vs1iWd1hPkfeyiDP9xXtPSrwgYoivv9RVkRIk0EZ7ak9e3710pQCMIafmZPuKTmSF/5BLUAHnu9EShMGaYnf5tSXmbFiqW6Jp+72EnbaEEAIoqbxuBRYJA/hqvnJY/DADT9Y4njwCqHJ6GWwe0+RzG0YKOx1Y19xXiercwuxllc7d+H4RU84NHH67JgcE91Btoj2bIrkLSoCiOcMEOITKbye299m7H1PQbf9D6AB0N8hCQYaQq0qKogwrBUx16KYcnvSSybHAJMCGPHcNiHyHMtiiyvgwOmB9aDH3BIlUlXDq18PrV+OreQ7TZL3D+DJ9s6F8P+DB+7ZFckVwxSCND5auDSMa4EyagQIjVNy56K4Tcwn2szF/bzPKpXwFMM7jo25F1nud9tox9k8ohjMStlgfcb6jS25xVDbeogteqKp6g/wjNSJBCOdxcVVNn1AYSgwptG8e3PmHRPrQ1VaeJj6t8=" -# after success build rust documentation and push it to gh-pages branch -# only do this for the stable branch (once per build) -after_success: | - if [ "$TRAVIS_RUST_VERSION" = "stable" ] ; then - cargo doc --no-deps && \ - sudo pip install ghp-import && \ - ghp-import -n 
target/doc && \ - git push -qf https://${GITHUB_TOKEN}@github.com/${TRAVIS_REPO_SLUG}.git gh-pages - fi diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..92d9b1a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,48 @@ +# Changelog + +All notable changes to this project are documented here. The format is based +on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project +adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.4.0] + +This is a ground-up redesign with a clean, encapsulated API. The old `STFT` +type, `WindowType` enum, `FromF64` trait and `log10_positive` free function +have been removed. + +### Added + +- `Istft` / `IstftBuilder`: inverse STFT with weighted overlap-add for perfect + reconstruction, plus `Stft::inverse()` to mirror a forward transform. +- One-shot batch processing: `Stft::spectrogram()` returning a `Spectrogram`, + with optional centered framing (`reflect`/`edge`/`zero` padding) and optional + `rayon` parallelism. +- A full window library (`Window`, `WindowFunction`, `Symmetry`): rectangular, + Hann, Hamming, Blackman, Blackman-Harris, Nuttall, flat-top, Bartlett, + triangular, Welch, cosine, Tukey, Kaiser and Gaussian, in periodic and + symmetric variants. +- Coefficient scaling modes (`Scaling`: none, magnitude, density). +- `spectrum` helpers: magnitude, power, phase, and decibel conversions. +- `mel` feature: mel filterbank, mel scale conversions, and an orthonormal + DCT-II for MFCCs (librosa-compatible defaults). +- Optional integrations: `ndarray` (`Array2` output), `rayon` (parallel batch), + `serde` (config (de)serialization). +- `no_std` support (with `alloc`) for the window, spectrum and mel math. +- `#![forbid(unsafe_code)]` across the crate. + +### Changed + +- Switched the FFT backend to `realfft`, roughly halving time and memory for + real-valued input. 
+ +### Fixed + +- The number of frequency bins is now `fft_size / 2 + 1`, correctly including + the Nyquist bin (previously `fft_size / 2`, which dropped it). +- Bin center frequencies are now `k · fs / fft_size` (previously + `k · fs / (2·(n_freqs − 1))`, which was off). + +[Unreleased]: https://github.com/sunsided/stft/compare/v0.4.0...HEAD +[0.4.0]: https://github.com/sunsided/stft/compare/v0.3.1...v0.4.0 diff --git a/Cargo.toml b/Cargo.toml index 3a6a939..dfcc046 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,26 +1,62 @@ [package] -authors = ["Maximilian Krüger ", "Markus Mayer "] -description = "computes the short-time fourier transform on streaming data" +authors = [ + "Maximilian Krüger ", + "Markus Mayer ", +] +categories = ["science", "multimedia::audio", "mathematics", "no-std"] +description = "Short-time Fourier transform (STFT) and inverse STFT: streaming and batch spectrograms, a rich window library, mel spectrograms and MFCCs." documentation = "https://docs.rs/ruststft" +edition = "2021" homepage = "https://github.com/sunsided/stft" -keywords = ["dsp", "fft", "stream", "data", "fourier"] +keywords = ["dsp", "fft", "stft", "spectrogram", "mfcc"] license = "MIT OR Apache-2.0" name = "ruststft" readme = "README.md" repository = "https://github.com/sunsided/stft.git" -version = "0.3.1" -edition = "2021" +rust-version = "1.75" +version = "0.4.0" + +[features] +default = ["std"] +# Enables the FFT-backed processors (forward/inverse STFT, batch spectrograms). +# Required because the `realfft`/`rustfft` backend depends on `std`. +std = ["dep:realfft", "num-traits/std", "num-complex/std"] +# Mel filterbank, (log-)mel spectrogram and MFCC. Pure math, `no_std`-capable. +mel = [] +# `Array2` input/output for batch spectrograms. Pulls in `std`. +ndarray = ["dep:ndarray", "std"] +# Parallel per-frame batch STFT. Pulls in `std`. +rayon = ["dep:rayon", "std"] +# `serde` derives on the configuration and window-specification types. 
+serde = ["dep:serde"] [dependencies] -apodize = "1.0.0" -num = "0.4.0" -rustfft = "6.1.0" -strider = "0.1.3" +num-complex = { version = "0.4.6", default-features = false, features = ["libm"] } +num-traits = { version = "0.2.19", default-features = false, features = ["libm"] } +realfft = { version = "3.4.0", optional = true } +ndarray = { version = "0.16.1", optional = true, default-features = false, features = ["std"] } +rayon = { version = "1.10.0", optional = true } +serde = { version = "1.0", optional = true, default-features = false, features = ["derive", "alloc"] } [dev-dependencies] -criterion = "0.5.1" approx = "0.5.1" +criterion = "0.5.1" +proptest = "1.6.0" [[bench]] name = "lib" harness = false + +[[example]] +name = "mfcc" +required-features = ["mel"] + +[lints.rust] +unsafe_code = "forbid" + +[lints.clippy] +all = "deny" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] diff --git a/README.md b/README.md index 141503a..cf1a7c9 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,131 @@ -# STFT +# ruststft -*status: working. missing some tests. api still in flux.* +[![CI](https://github.com/sunsided/stft/actions/workflows/ci.yml/badge.svg)](https://github.com/sunsided/stft/actions/workflows/ci.yml) +[![crates.io](https://img.shields.io/crates/v/ruststft.svg)](https://crates.io/crates/ruststft) +[![docs.rs](https://img.shields.io/docsrs/ruststft)](https://docs.rs/ruststft) +[![unsafe forbidden](https://img.shields.io/badge/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/) +![license](https://img.shields.io/crates/l/ruststft.svg) -![Build Status](https://github.com/sunsided/stft/actions/workflows/rust.yml/badge.svg) +A complete [short-time Fourier transform](https://en.wikipedia.org/wiki/Short-time_Fourier_transform) +toolkit for Rust: forward **and** inverse STFT, a rich window library, batch and +streaming APIs, and optional mel spectrograms / MFCCs. 
-**computes the [short-time fourier transform](https://en.wikipedia.org/wiki/Short-time_Fourier_transform) -on streaming data. written in [rust](https://www.rust-lang.org/).** +- **Forward STFT** over real signals, streaming or batch, backed by + [`realfft`](https://crates.io/crates/realfft) (≈2× faster and half the memory + of a full complex FFT on real input). +- **Inverse STFT** with weighted overlap-add (WOLA) for perfect reconstruction. +- **Windows**: rectangular, Hann, Hamming, Blackman, Blackman-Harris, Nuttall, + flat-top, Bartlett, triangular, Welch, cosine, Tukey, Kaiser, Gaussian - + periodic (spectral-analysis) or symmetric (filter-design). +- **Spectrum helpers**: magnitude, power, phase, and decibel conversions. +- **Mel & MFCC** (`mel` feature): librosa-compatible filterbank, mel scales and + an orthonormal DCT-II. +- **`#![forbid(unsafe_code)]`** - 100% safe Rust. +- **`no_std`** (with `alloc`) for the window, spectrum and mel math. -to use add `ruststft = "*"` -to the `[dependencies]` section of your `Cargo.toml` and `use ruststft;` in your code. 
+## Install -## [read the documentation for an example and more !](https://docs.rs/ruststft/latest/ruststft/) +```toml +[dependencies] +ruststft = "0.4" +``` -### [contributing](contributing.md) +## Batch spectrogram -### licensed under either of [apache-2.0](LICENSE-APACHE) ([tl;dr](https://tldrlegal.com/license/apache-license-2.0-(apache-2.0))) or [MIT](LICENSE-MIT) ([tl;dr](https://tldrlegal.com/license/mit-license)) at your option +```rust +use ruststft::{Stft, Window}; + +let fs = 8_000.0; +let signal: Vec<f64> = (0..8_000) + .map(|n| (2.0 * std::f64::consts::PI * 1_000.0 * n as f64 / fs).sin()) + .collect(); + +let mut stft = Stft::builder() + .window(Window::<f64>::hann(1024)) + .hop_size(256) + .center(true) + .build() + .unwrap(); + +let spec = stft.spectrogram(&signal); +assert_eq!(spec.n_freqs(), 1024 / 2 + 1); // includes the Nyquist bin +``` + +## Perfect reconstruction (STFT → ISTFT) + +```rust +use ruststft::{Stft, Window}; + +let signal: Vec<f64> = (0..8_000).map(|n| (n as f64 * 0.01).sin()).collect(); + +let mut stft = Stft::builder() + .window(Window::<f64>::hann(1024)) + .hop_size(256) // 75% overlap: Hann is COLA-compliant + .center(true) + .build() + .unwrap(); + +let spec = stft.spectrogram(&signal); +let recon = stft.inverse().unwrap().reconstruct(&spec).unwrap(); +// recon matches `signal` in the interior to ~machine precision. +``` + +## Streaming + +```rust +use ruststft::{Complex, Stft, Window}; + +let mut stft = Stft::builder() + .window(Window::<f32>::hann(1024)) + .hop_size(512) + .build() + .unwrap(); + +let mut column = vec![Complex::new(0.0f32, 0.0); stft.n_freqs()]; +let chunk: Vec<f32> = (0..3000).map(|x| x as f32).collect(); + +stft.append(&chunk); +while stft.ready() { + stft.process_into(&mut column).unwrap(); + // ... use `column` ... 
+ stft.step(); +} +``` + +## Feature flags + +| Feature | Default | Description | +|-----------|:-------:|----------------------------------------------------------| +| `std` | yes | FFT-backed processors (`Stft`, `Istft`, batch). Required for the transforms. | +| `mel` | no | Mel filterbank, mel scales, and DCT-II for MFCCs. | +| `ndarray` | no | `Spectrogram::to_array2` (`[n_freqs, n_frames]`). | +| `rayon` | no | Parallel per-frame batch spectrograms. | +| `serde` | no | (De)serialize configuration and window descriptions. | + +Without the default `std` feature the crate builds as `no_std` (with `alloc`), +exposing the window library, the [`spectrum`](https://docs.rs/ruststft/latest/ruststft/spectrum/) +helpers and the [`mel`](https://docs.rs/ruststft/latest/ruststft/mel/) math. The +FFT processors require `std` because the FFT backend does. + +## Migrating from 0.3 + +`0.4` is a breaking redesign. Rough mapping: + +| 0.3 | 0.4 | +|---------------------------------------|------------------------------------------------| +| `STFT::new(WindowType::Hanning, w, s)`| `Stft::builder().window(Window::hann(w)).hop_size(s).build()?` | +| `WindowType::Hanning` (etc.) | `Window::hann(len)` / `WindowFunction::Hann` | +| `stft.append_samples(x)` | `stft.append(x)` | +| `stft.contains_enough_to_compute()` | `stft.ready()` | +| `stft.compute_complex_column(&mut c)` | `stft.process_into(&mut c)?` | +| `stft.move_to_next_column()` | `stft.step()` | +| `stft.output_size()` (= `fft/2`) | `stft.n_freqs()` (= `fft/2 + 1`, fixes Nyquist)| +| `compute_magnitude_column` / `compute_column` | `spectrum::magnitude` / `power_to_db` on a column | +| (no inverse) | `stft.inverse()?.reconstruct(&spec)?` | + +See [`CHANGELOG.md`](CHANGELOG.md) for details. + +## [Contributing](contributing.md) + +Licensed under either of [Apache-2.0](LICENSE-APACHE) or [MIT](LICENSE-MIT) at +your option. 
diff --git a/Taskfile.dist.yaml b/Taskfile.dist.yaml new file mode 100644 index 0000000..8659d4c --- /dev/null +++ b/Taskfile.dist.yaml @@ -0,0 +1,122 @@ +# ruststft task runner +# https://taskfile.dev +version: '3' + +vars: + MSRV: '1.75' + +tasks: + default: + desc: List available tasks + cmds: + - task --list --sort=none + silent: true + + fmt: + desc: Format all Rust sources + aliases: [format] + preconditions: + - sh: command -v cargo + msg: "cargo not found — install Rust from https://rustup.rs/" + cmds: + - cargo fmt --all + + fmt:check: + desc: Check formatting without modifying files + cmds: + - cargo fmt --all -- --check + + lint: + desc: Run clippy with all features (warnings as errors) + aliases: [lint:check] + cmds: + - cargo clippy --all-targets --all-features -- -D warnings + + lint:fix: + desc: Auto-fix clippy lints, then format + aliases: [fix] + cmds: + - cargo clippy --all-targets --all-features --fix --allow-dirty --allow-staged + - task: fmt + + check: + desc: Type-check (all features) and verify formatting + cmds: + - cargo check --all-targets --all-features + - task: fmt:check + + check:no-std: + desc: Verify the crate builds without std (no default features) + cmds: + - cargo check --no-default-features + - cargo check --no-default-features --features mel + + build: + desc: Build with default features (debug) + cmds: + - cargo build --all-targets + + build:release: + desc: Build with all features (release) + cmds: + - cargo build --release --all-features + + test: + desc: Run the test suite with all features + cmds: + - cargo test --all-targets --all-features {{.CLI_ARGS}} + + test:doc: + desc: Run doctests + cmds: + - cargo test --doc --all-features + + test:no-std: + desc: Run tests without default features + cmds: + - cargo test --no-default-features --features mel {{.CLI_ARGS}} + + bench: + desc: Run the criterion benchmarks + cmds: + - cargo bench --all-features {{.CLI_ARGS}} + + doc: + desc: Build rustdoc with all features (docs.rs 
config) + aliases: [docs] + env: + RUSTDOCFLAGS: '--cfg docsrs' + cmds: + - cargo doc --no-deps --all-features + + doc:open: + desc: Build and open rustdoc in the browser + aliases: [docs:open] + env: + RUSTDOCFLAGS: '--cfg docsrs' + cmds: + - cargo doc --no-deps --all-features --open + + msrv: + desc: Verify the crate builds on the declared MSRV ({{.MSRV}}) + preconditions: + - sh: rustup toolchain list | grep -q '{{.MSRV}}' + msg: "Rust {{.MSRV}} toolchain required — rustup toolchain install {{.MSRV}}" + cmds: + - cargo +{{.MSRV}} build --all-features + + clean: + desc: Remove build artifacts + status: + - test ! -d target + cmds: + - cargo clean + + ci: + desc: Run the full CI sequence (format check, lint, no-std, tests, doctests) + cmds: + - task: fmt:check + - task: lint + - task: check:no-std + - task: test + - task: test:doc diff --git a/benches/lib.rs b/benches/lib.rs index 37237c4..624b654 100644 --- a/benches/lib.rs +++ b/benches/lib.rs @@ -1,134 +1,91 @@ -use criterion::{criterion_group, criterion_main, Criterion}; -use num::complex::Complex; -use rustfft::{FftDirection, FftPlanner}; -use ruststft::{WindowType, STFT}; +//! Criterion benchmarks for the forward/inverse STFT. -macro_rules! 
bench_fft_process { - ($c:expr, $window_size:expr, $float:ty) => {{ - let mut planner = FftPlanner::new(); - let fft = planner.plan_fft($window_size, FftDirection::Forward); - // input is processed in-place - let mut output = std::iter::repeat(Complex::new(0., 0.)) - .take($window_size) - .collect::>>(); - $c.bench_function( - concat!( - "bench_fft_process_", - stringify!($window_size), - "_", - stringify!($float) - ), - |b| b.iter(|| fft.process(&mut output[..])), - ); - }}; -} +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ruststft::{Complex, Stft, Window}; -fn bench_fft_process_1024_f32(c: &mut Criterion) { - bench_fft_process!(c, 1024, f32); +fn signal_f32(seconds: usize) -> Vec { + let fs = 44_100usize; + (0..fs * seconds) + .map(|n| { + let t = n as f32 / fs as f32; + 0.5 * (2.0 * std::f32::consts::PI * 440.0 * t).sin() + }) + .collect() } -fn bench_fft_process_1024_f64(c: &mut Criterion) { - bench_fft_process!(c, 1024, f64); +fn signal_f64(seconds: usize) -> Vec { + let fs = 44_100usize; + (0..fs * seconds) + .map(|n| { + let t = n as f64 / fs as f64; + 0.5 * (2.0 * std::f64::consts::PI * 440.0 * t).sin() + }) + .collect() } -criterion_group!( - benches_fft_process, - bench_fft_process_1024_f32, - bench_fft_process_1024_f64 -); +fn bench_forward(c: &mut Criterion) { + let mut group = c.benchmark_group("forward_spectrogram"); + let s32 = signal_f32(10); + let s64 = signal_f64(10); + group.throughput(Throughput::Elements(s32.len() as u64)); -macro_rules! bench_stft_compute { - ($c:expr, $window_size:expr, $float:ty) => {{ - let step_size: usize = 512; - let mut stft = STFT::<$float>::new(WindowType::Hanning, $window_size, step_size); - let input = std::iter::repeat(1.) - .take($window_size) - .collect::>(); - let mut output = std::iter::repeat(0.) 
- .take(stft.output_size()) - .collect::>(); - stft.append_samples(&input[..]); - $c.bench_function( - concat!( - "bench_stft_compute_", - stringify!($window_size), - "_", - stringify!($float) - ), - |b| b.iter(|| stft.compute_column(&mut output[..])), - ); - }}; + for &win in &[1024usize, 4096] { + group.bench_with_input(BenchmarkId::new("f32", win), &win, |b, &win| { + let mut stft = Stft::builder() + .window(Window::::hann(win)) + .hop_size(win / 4) + .build() + .unwrap(); + b.iter(|| stft.spectrogram(&s32)); + }); + group.bench_with_input(BenchmarkId::new("f64", win), &win, |b, &win| { + let mut stft = Stft::builder() + .window(Window::::hann(win)) + .hop_size(win / 4) + .build() + .unwrap(); + b.iter(|| stft.spectrogram(&s64)); + }); + } + group.finish(); } -fn bench_stft_compute_1024_f32(c: &mut Criterion) { - bench_stft_compute!(c, 1024, f32); +fn bench_streaming(c: &mut Criterion) { + let signal = signal_f32(10); + c.bench_function("streaming_columns_1024_f32", |b| { + b.iter(|| { + let mut stft = Stft::builder() + .window(Window::::hann(1024)) + .hop_size(512) + .build() + .unwrap(); + let mut column = vec![Complex::new(0.0f32, 0.0); stft.n_freqs()]; + for chunk in signal.chunks(4096) { + stft.append(chunk); + while stft.ready() { + stft.process_into(&mut column).unwrap(); + stft.step(); + } + } + }); + }); } -fn bench_stft_compute_1024_f64(c: &mut Criterion) { - bench_stft_compute!(c, 1024, f64); +fn bench_round_trip(c: &mut Criterion) { + let signal = signal_f64(5); + c.bench_function("round_trip_1024_f64", |b| { + let mut stft = Stft::builder() + .window(Window::::hann(1024)) + .hop_size(256) + .build() + .unwrap(); + b.iter(|| { + let spec = stft.spectrogram(&signal); + let istft = stft.inverse().unwrap(); + istft.reconstruct(&spec).unwrap() + }); + }); } -criterion_group!( - benches_stft_compute, - bench_stft_compute_1024_f32, - bench_stft_compute_1024_f64 -); - -macro_rules! 
bench_stft_audio { - ($c:expr, $seconds:expr, $float:ty) => {{ - // let's generate some fake audio - let sample_rate: usize = 44100; - let seconds: usize = $seconds; - let sample_count = sample_rate * seconds; - let all_samples = (0..sample_count) - .map(|x| x as $float) - .collect::>(); - $c.bench_function( - concat!( - "bench_stft_audio_", - stringify!($windowsize), - "_", - stringify!($float) - ), - |b| { - b.iter(|| { - // let's initialize our short-time fourier transform - let window_type: WindowType = WindowType::Hanning; - let window_size: usize = 1024; - let step_size: usize = 512; - let mut stft = STFT::<$float>::new(window_type, window_size, step_size); - // we need a buffer to hold a computed column of the spectrogram - let mut spectrogram_column: Vec<$float> = - std::iter::repeat(0.).take(stft.output_size()).collect(); - for some_samples in (&all_samples[..]).chunks(3000) { - stft.append_samples(some_samples); - while stft.contains_enough_to_compute() { - stft.compute_column(&mut spectrogram_column[..]); - stft.move_to_next_column(); - } - } - }) - }, - ); - }}; -} - -fn bench_stft_10_seconds_audio_f32(c: &mut Criterion) { - bench_stft_audio!(c, 10, f32); -} - -fn bench_stft_10_seconds_audio_f64(c: &mut Criterion) { - bench_stft_audio!(c, 10, f64); -} - -criterion_group!( - benches_stft_audio, - bench_stft_10_seconds_audio_f32, - bench_stft_10_seconds_audio_f64 -); - -criterion_main!( - benches_fft_process, - benches_stft_compute, - benches_stft_audio -); +criterion_group!(benches, bench_forward, bench_streaming, bench_round_trip); +criterion_main!(benches); diff --git a/deny.toml b/deny.toml new file mode 100644 index 0000000..e66698e --- /dev/null +++ b/deny.toml @@ -0,0 +1,29 @@ +# Configuration for cargo-deny (https://embarkstudios.github.io/cargo-deny/). 
+ +[advisories] +version = 2 +yanked = "deny" +ignore = [] + +[licenses] +version = 2 +allow = [ + "MIT", + "Apache-2.0", + "Apache-2.0 WITH LLVM-exception", + "BSD-2-Clause", + "BSD-3-Clause", + "ISC", + "Zlib", + "Unicode-3.0", + "Unicode-DFS-2016", +] +confidence-threshold = 0.9 + +[bans] +multiple-versions = "warn" +wildcards = "deny" + +[sources] +unknown-registry = "deny" +unknown-git = "deny" diff --git a/examples/mfcc.rs b/examples/mfcc.rs new file mode 100644 index 0000000..e39b655 --- /dev/null +++ b/examples/mfcc.rs @@ -0,0 +1,44 @@ +//! Compute MFCCs from a tone, mirroring librosa's pipeline: +//! power spectrogram -> mel filterbank -> dB -> DCT-II. +//! +//! Run with: `cargo run --example mfcc --features mel` + +use ruststft::mel::{DctII, MelFilterBank, MelScale}; +use ruststft::spectrum::{power, power_to_db}; +use ruststft::{Stft, Window}; + +fn main() { + let fs = 16_000.0; + let n_fft = 1024usize; + let n_mels = 40usize; + let n_mfcc = 13usize; + + let signal: Vec = (0..fs as usize) + .map(|i| { + let t = i as f64 / fs; + (2.0 * std::f64::consts::PI * 440.0 * t).sin() + }) + .collect(); + + let mut stft = Stft::builder() + .window(Window::::hann(n_fft)) + .hop_size(n_fft / 4) + .center(true) + .build() + .expect("valid configuration"); + let spec = stft.spectrogram(&signal); + + let bank = MelFilterBank::::new(n_mels, n_fft, fs, 0.0, fs / 2.0, MelScale::Slaney); + let dct = DctII::::new(n_mels, n_mfcc); + + // Transform the middle frame. + let frame = spec.n_frames() / 2; + let mut mel = bank.transform(&power(spec.column(frame))); + power_to_db(&mut mel, 1.0, None); + let mfcc = dct.transform(&mel); + + println!("MFCCs (frame {frame}):"); + for (i, c) in mfcc.iter().enumerate() { + println!(" c{i:<2} = {c:+.4}"); + } +} diff --git a/examples/roundtrip.rs b/examples/roundtrip.rs new file mode 100644 index 0000000..a1de7f7 --- /dev/null +++ b/examples/roundtrip.rs @@ -0,0 +1,42 @@ +//! 
Demonstrate perfect reconstruction: signal -> STFT -> ISTFT -> signal. +//! +//! Run with: `cargo run --example roundtrip` + +use ruststft::{Stft, Window}; + +fn main() { + let fs = 8_000.0; + let n = 8_000usize; + let signal: Vec = (0..n) + .map(|i| { + let t = i as f64 / fs; + (2.0 * std::f64::consts::PI * 220.0 * t).sin() + + 0.3 * (2.0 * std::f64::consts::PI * 880.0 * t).sin() + }) + .collect(); + + let mut stft = Stft::builder() + .window(Window::::hann(1024)) + .hop_size(256) // 75% overlap -> Hann is COLA compliant + .center(true) + .build() + .expect("valid configuration"); + + let spec = stft.spectrogram(&signal); + let recon = stft + .inverse() + .expect("invertible") + .reconstruct(&spec) + .expect("reconstruction"); + + // Maximum reconstruction error over the interior (edges taper off). + let frame = 1024; + let max_err = (frame..n - frame) + .map(|i| (recon[i] - signal[i]).abs()) + .fold(0.0_f64, f64::max); + + println!("frames: {}", spec.n_frames()); + println!("interior max reconstruction error: {max_err:.3e}"); + assert!(max_err < 1e-6, "reconstruction should be near-perfect"); + println!("reconstruction is near-perfect ✔"); +} diff --git a/examples/spectrogram.rs b/examples/spectrogram.rs new file mode 100644 index 0000000..7d70455 --- /dev/null +++ b/examples/spectrogram.rs @@ -0,0 +1,52 @@ +//! Compute a magnitude spectrogram of a linear chirp and print it in decibels. +//! +//! Run with: `cargo run --example spectrogram` + +use ruststft::spectrum::{magnitude, power_to_db}; +use ruststft::{Stft, Window}; + +fn main() { + let fs = 16_000.0; + let n = fs as usize * 2; // 2 seconds + + // A linear chirp sweeping from 200 Hz to 4 kHz. 
+ let signal: Vec = (0..n) + .map(|i| { + let t = i as f32 / fs; + let f = 200.0 + (4000.0 - 200.0) * (t / 2.0); + (std::f32::consts::PI * f * t).sin() + }) + .collect(); + + let mut stft = Stft::builder() + .window(Window::::hann(1024)) + .hop_size(256) + .center(true) + .build() + .expect("valid configuration"); + + let spec = stft.spectrogram(&signal); + println!( + "spectrogram: {} frames x {} freqs", + spec.n_frames(), + spec.n_freqs() + ); + + let freqs = stft.freqs(fs as f64); + // Report the dominant frequency in a handful of frames. + for frame in (0..spec.n_frames()).step_by(spec.n_frames().max(1) / 8 + 1) { + let mags = magnitude(spec.column(frame)); + let mut powers: Vec = mags.iter().map(|m| m * m).collect(); + power_to_db(&mut powers, 1.0, Some(80.0)); + let peak = mags + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) + .map(|(k, _)| k) + .unwrap_or(0); + println!( + "frame {frame:>4}: peak ~ {:>6.0} Hz ({:>5.1} dB)", + freqs[peak], powers[peak] + ); + } +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..3a26366 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +edition = "2021" diff --git a/src/batch.rs b/src/batch.rs new file mode 100644 index 0000000..0631d43 --- /dev/null +++ b/src/batch.rs @@ -0,0 +1,204 @@ +//! One-shot batch spectrograms and the [`Spectrogram`] container. + +use crate::config::PadMode; +use crate::sample::Sample; +use crate::stft::Stft; +use alloc::vec; +use alloc::vec::Vec; +use num_complex::Complex; +use realfft::FftNum; + +#[cfg(feature = "rayon")] +use rayon::prelude::*; + +/// A dense spectrogram stored frame-major: `n_frames` columns of `n_freqs` +/// complex bins each. +#[derive(Debug, Clone, PartialEq)] +pub struct Spectrogram { + data: Vec>, + n_frames: usize, + n_freqs: usize, +} + +impl Spectrogram { + /// Build a spectrogram from a frame-major flat buffer. + /// + /// # Panics + /// Panics if `data.len() != n_frames * n_freqs`. 
+ #[must_use] + pub fn from_flat(data: Vec>, n_frames: usize, n_freqs: usize) -> Self { + assert_eq!(data.len(), n_frames * n_freqs, "spectrogram shape mismatch"); + Self { + data, + n_frames, + n_freqs, + } + } + + /// Number of frames (columns). + #[must_use] + pub fn n_frames(&self) -> usize { + self.n_frames + } + + /// Number of frequency bins per frame. + #[must_use] + pub fn n_freqs(&self) -> usize { + self.n_freqs + } + + /// Whether the spectrogram has no frames. + #[must_use] + pub fn is_empty(&self) -> bool { + self.n_frames == 0 + } + + /// Borrow frame `index` as a slice of `n_freqs` bins. + #[must_use] + pub fn column(&self, index: usize) -> &[Complex] { + let start = index * self.n_freqs; + &self.data[start..start + self.n_freqs] + } + + /// Iterate over the frames (columns). + pub fn columns(&self) -> impl Iterator]> { + self.data.chunks_exact(self.n_freqs) + } + + /// The underlying frame-major buffer. + #[must_use] + pub fn as_flat(&self) -> &[Complex] { + &self.data + } + + /// Consume into the frame-major buffer. + #[must_use] + pub fn into_flat(self) -> Vec> { + self.data + } + + /// Convert to a `[n_freqs, n_frames]` [`ndarray::Array2`] (librosa layout). + #[cfg(feature = "ndarray")] + #[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] + #[must_use] + pub fn to_array2(&self) -> ndarray::Array2> { + ndarray::Array2::from_shape_fn((self.n_freqs, self.n_frames), |(bin, frame)| { + self.data[frame * self.n_freqs + bin] + }) + } +} + +/// Reflect an index `p` into `0..len` using mirror (no edge repeat) semantics, +/// matching NumPy's `reflect` padding. +fn reflect_index(p: isize, len: isize) -> usize { + if len == 1 { + return 0; + } + let period = 2 * (len - 1); + let mut m = p.rem_euclid(period); + if m >= len { + m = period - m; + } + m as usize +} + +/// Pad `signal` by `pad` samples on each side according to `mode`. 
+fn pad_signal(signal: &[T], pad: usize, mode: PadMode) -> Vec { + let len = signal.len(); + if pad == 0 { + return signal.to_vec(); + } + let mut out = vec![T::zero(); len + 2 * pad]; + let len_i = len as isize; + for (i, slot) in out.iter_mut().enumerate() { + let p = i as isize - pad as isize; + *slot = if p >= 0 && p < len_i { + signal[p as usize] + } else if len == 0 { + T::zero() + } else { + match mode { + PadMode::Zero => T::zero(), + PadMode::Edge => signal[p.clamp(0, len_i - 1) as usize], + PadMode::Reflect => signal[reflect_index(p, len_i)], + } + }; + } + out +} + +impl Stft { + /// Number of full frames produced for a signal of `signal_len` samples, + /// accounting for centered padding. + fn frame_count(&self, signal_len: usize) -> (usize, usize) { + let pad = if self.center { self.frame_len() / 2 } else { 0 }; + let padded_len = signal_len + 2 * pad; + let n_frames = if padded_len >= self.frame_len() { + 1 + (padded_len - self.frame_len()) / self.hop() + } else { + 0 + }; + (pad, n_frames) + } + + /// Compute the full spectrogram of `signal` in one call. + /// + /// Resets the internal streaming buffer. With + /// [`center`](crate::StftBuilder::center) enabled the signal is padded by + /// `frame_len / 2` on each side using the configured [`PadMode`]. With the + /// `rayon` feature the frames are computed in parallel. 
+ #[must_use] + pub fn spectrogram(&mut self, signal: &[T]) -> Spectrogram { + self.reset(); + let (pad, n_frames) = self.frame_count(signal.len()); + let n_freqs = self.n_freqs(); + let frame_len = self.frame_len(); + let hop = self.hop(); + + let padded = pad_signal(signal, pad, self.pad_mode); + let zero = Complex::new(T::zero(), T::zero()); + let mut data = vec![zero; n_frames * n_freqs]; + + #[cfg(feature = "rayon")] + { + let fft = self.fft_handle(); + let win = self.window().coefficients(); + let scale = self.scale(); + let one = T::one(); + data.par_chunks_mut(n_freqs).enumerate().for_each_init( + || (fft.make_input_vec(), fft.make_scratch_vec()), + |(input, scratch), (frame_idx, out_col)| { + let start = frame_idx * hop; + let frame = &padded[start..start + frame_len]; + let (head, tail) = input.split_at_mut(frame_len); + for ((dst, &w), &s) in head.iter_mut().zip(win).zip(frame) { + *dst = s * w; + } + for dst in tail { + *dst = T::zero(); + } + fft.process_with_scratch(input, out_col, scratch) + .expect("realfft forward"); + if scale != one { + for bin in out_col.iter_mut() { + *bin = *bin * scale; + } + } + }, + ); + } + + #[cfg(not(feature = "rayon"))] + { + for frame_idx in 0..n_frames { + let start = frame_idx * hop; + let frame = &padded[start..start + frame_len]; + let spectrum = self.compute_frame(frame).expect("realfft forward"); + let out_start = frame_idx * n_freqs; + data[out_start..out_start + n_freqs].copy_from_slice(spectrum); + } + } + + Spectrogram::from_flat(data, n_frames, n_freqs) + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..7c88468 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,42 @@ +//! Configuration enums shared by the forward and inverse transforms. + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// How the complex STFT coefficients are scaled. 
+/// +/// The variants mirror the scaling modes of SciPy's `ShortTimeFFT`: +/// +/// - [`Scaling::None`] leaves the raw FFT output untouched. +/// - [`Scaling::Magnitude`] divides every bin by the sum of the window +/// coefficients, so a sinusoid's bin reflects its amplitude. +/// - [`Scaling::Density`] divides by `sqrt(fs * Σ wᵢ²)`, so that `|S|²` +/// approximates a power spectral density. Requires a configured sample rate. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Scaling { + /// No scaling: raw FFT coefficients. + #[default] + None, + /// Amplitude (magnitude) scaling: divide by the window sum. + Magnitude, + /// Power-spectral-density scaling: divide by `sqrt(fs * Σ wᵢ²)`. + Density, +} + +/// How a signal is padded when centered framing is enabled in batch mode. +/// +/// With [`center`](crate::StftBuilder::center) enabled the signal is padded by +/// `frame_len / 2` samples on each side so that frame `t` is centered on sample +/// `t * hop`, matching librosa's convention. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum PadMode { + /// Pad with zeros. + #[default] + Zero, + /// Mirror the signal at the boundary (without repeating the edge sample). + Reflect, + /// Repeat the edge sample. + Edge, +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..622ac76 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,79 @@ +//! Error type returned by the fallible parts of the crate. + +use core::fmt; + +/// Errors produced while configuring or running an STFT/ISTFT. +#[derive(Debug, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum StftError { + /// No window was supplied to a builder, or a zero-length window was given. + MissingWindow, + /// The frame length (window length) is zero. + InvalidFrameLength, + /// The hop size is invalid (must be `1..=frame_len`).
+ InvalidHopSize { + /// The requested hop size. + hop: usize, + /// The frame length it was checked against. + frame_len: usize, + }, + /// The FFT size is smaller than the frame length. + InvalidFftSize { + /// The requested FFT size. + fft_size: usize, + /// The frame length it must be at least as large as. + frame_len: usize, + }, + /// A supplied buffer did not have the expected length. + LengthMismatch { + /// The length that was expected. + expected: usize, + /// The length that was supplied. + got: usize, + }, + /// [`Scaling::Density`](crate::Scaling) was requested without a sample rate. + MissingSampleRate, + /// A processing call needed more buffered samples than were available. + NotEnoughData { + /// Samples required to compute a frame. + needed: usize, + /// Samples currently available. + available: usize, + }, + /// The underlying FFT backend reported an error. + Fft, +} + +impl fmt::Display for StftError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::MissingWindow => f.write_str("no window was supplied"), + Self::InvalidFrameLength => f.write_str("frame length must be greater than zero"), + Self::InvalidHopSize { hop, frame_len } => write!( + f, + "hop size {hop} is invalid for frame length {frame_len} (expected 1..={frame_len})" + ), + Self::InvalidFftSize { + fft_size, + frame_len, + } => write!( + f, + "fft size {fft_size} must be at least the frame length {frame_len}" + ), + Self::LengthMismatch { expected, got } => { + write!(f, "buffer length mismatch: expected {expected}, got {got}") + } + Self::MissingSampleRate => { + f.write_str("density scaling requires a sample rate to be configured") + } + Self::NotEnoughData { needed, available } => write!( + f, + "not enough buffered samples: needed {needed}, have {available}" + ), + Self::Fft => f.write_str("the FFT backend reported an error"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for StftError {} diff --git a/src/istft.rs b/src/istft.rs 
new file mode 100644 index 0000000..0ef4631 --- /dev/null +++ b/src/istft.rs @@ -0,0 +1,271 @@ +//! Inverse short-time Fourier transform via weighted overlap-add (WOLA). + +use crate::batch::Spectrogram; +use crate::error::StftError; +use crate::sample::{cast, Sample}; +use crate::stft::Stft; +use crate::window::Window; +use alloc::sync::Arc; +use alloc::vec::Vec; +use num_complex::Complex; +use realfft::{ComplexToReal, FftNum, RealFftPlanner}; + +/// Builder for [`Istft`]. +#[must_use] +pub struct IstftBuilder { + window: Option>, + hop: Option, + fft_size: Option, + forward_scale: T, + center: bool, +} + +impl Default for IstftBuilder { + fn default() -> Self { + Self { + window: None, + hop: None, + fft_size: None, + forward_scale: T::one(), + center: false, + } + } +} + +impl IstftBuilder { + /// Set the synthesis window. For perfect reconstruction this must be the + /// same window used by the forward transform. + pub fn window(mut self, window: Window) -> Self { + self.window = Some(window); + self + } + + /// Set the hop size. Must equal the forward hop size. + pub fn hop_size(mut self, hop: usize) -> Self { + self.hop = Some(hop); + self + } + + /// Set the FFT size. Must equal the forward FFT size. + pub fn fft_size(mut self, fft_size: usize) -> Self { + self.fft_size = Some(fft_size); + self + } + + /// Set the multiplicative scaling factor that the forward transform + /// applied, so it can be undone. Defaults to `1`. + pub fn forward_scale(mut self, scale: T) -> Self { + self.forward_scale = scale; + self + } + + /// Indicate that the forward transform used centered framing, so that + /// [`Istft::finish`] trims the `frame_len / 2` padding from each end. + pub fn center(mut self, center: bool) -> Self { + self.center = center; + self + } + + /// Validate the configuration and build the [`Istft`]. + /// + /// # Errors + /// Returns [`StftError`] for a missing/empty window, out-of-range hop, or + /// an FFT size smaller than the frame length. 
+ pub fn build(self) -> Result, StftError> { + let window = self.window.ok_or(StftError::MissingWindow)?; + let frame_len = window.len(); + if frame_len == 0 { + return Err(StftError::InvalidFrameLength); + } + let hop = self.hop.unwrap_or((frame_len / 4).max(1)); + if hop == 0 || hop > frame_len { + return Err(StftError::InvalidHopSize { hop, frame_len }); + } + let fft_size = self.fft_size.unwrap_or(frame_len); + if fft_size < frame_len { + return Err(StftError::InvalidFftSize { + fft_size, + frame_len, + }); + } + + let fft = RealFftPlanner::::new().plan_fft_inverse(fft_size); + let spec_in = fft.make_input_vec(); + let frame_out = fft.make_output_vec(); + let scratch = fft.make_scratch_vec(); + let n_freqs = spec_in.len(); + + // Undo the unnormalized inverse FFT (factor `fft_size`) and the forward + // scaling factor in one division. + let inv_scale = T::one() / (self.forward_scale * cast::(fft_size as f64)); + + Ok(Istft { + window, + frame_len, + hop, + fft_size, + n_freqs, + inv_scale, + center: self.center, + fft, + spec_in, + frame_out, + scratch, + output: Vec::new(), + norm: Vec::new(), + pos: 0, + frames: 0, + }) + } +} + +/// An inverse short-time Fourier transform that reconstructs a real signal +/// from spectrogram columns using weighted overlap-add. +/// +/// Feed columns with [`process_column`](Istft::process_column) (or a whole +/// [`Spectrogram`] via [`reconstruct`](Istft::reconstruct)) and obtain the +/// signal with [`finish`](Istft::finish). +pub struct Istft { + window: Window, + frame_len: usize, + hop: usize, + fft_size: usize, + n_freqs: usize, + inv_scale: T, + center: bool, + fft: Arc>, + spec_in: Vec>, + frame_out: Vec, + scratch: Vec>, + output: Vec, + norm: Vec, + pos: usize, + frames: usize, +} + +impl Istft { + /// Start building an [`Istft`]. + pub fn builder() -> IstftBuilder { + IstftBuilder::default() + } + + /// Number of frequency bins expected per column (`fft_size / 2 + 1`). 
+ #[must_use] + pub fn n_freqs(&self) -> usize { + self.n_freqs + } + + /// Number of columns processed so far. + #[must_use] + pub fn frames(&self) -> usize { + self.frames + } + + /// Overlap-add one spectrogram column. + /// + /// # Errors + /// Returns [`StftError::LengthMismatch`] if `column.len() != n_freqs`, or + /// [`StftError::Fft`] if the backend fails. + pub fn process_column(&mut self, column: &[Complex]) -> Result<(), StftError> { + if column.len() != self.n_freqs { + return Err(StftError::LengthMismatch { + expected: self.n_freqs, + got: column.len(), + }); + } + + self.spec_in.copy_from_slice(column); + // The inverse real FFT requires the DC bin (and the Nyquist bin for an + // even-length transform) to be purely real; force it to avoid backend + // errors from round-off. + self.spec_in[0].im = T::zero(); + if self.fft_size % 2 == 0 { + let last = self.n_freqs - 1; + self.spec_in[last].im = T::zero(); + } + + self.fft + .process_with_scratch(&mut self.spec_in, &mut self.frame_out, &mut self.scratch) + .map_err(|_| StftError::Fft)?; + + let end = self.pos + self.frame_len; + if self.output.len() < end { + self.output.resize(end, T::zero()); + self.norm.resize(end, T::zero()); + } + + let inv = self.inv_scale; + let frame_len = self.frame_len; + let pos = self.pos; + let win = self.window.coefficients(); + let out_seg = &mut self.output[pos..pos + frame_len]; + let norm_seg = &mut self.norm[pos..pos + frame_len]; + for (((o, n), &w), &fo) in out_seg + .iter_mut() + .zip(norm_seg.iter_mut()) + .zip(win) + .zip(&self.frame_out[..frame_len]) + { + let recon = fo * inv; + *o = *o + w * recon; + *n = *n + w * w; + } + + self.pos += self.hop; + self.frames += 1; + Ok(()) + } + + /// Overlap-add an entire [`Spectrogram`] and return the reconstructed + /// signal. + /// + /// # Errors + /// Propagates errors from [`process_column`](Istft::process_column). 
+ pub fn reconstruct(mut self, spectrogram: &Spectrogram) -> Result, StftError> { + for column in spectrogram.columns() { + self.process_column(column)?; + } + Ok(self.finish()) + } + + /// Finish reconstruction: normalize by the accumulated window energy and + /// return the signal, trimming centered padding if configured. + #[must_use] + pub fn finish(self) -> Vec { + let mut output = self.output; + let eps = cast::(1e-12); + for (o, n) in output.iter_mut().zip(&self.norm) { + if *n > eps { + *o = *o / *n; + } else { + *o = T::zero(); + } + } + + if self.center { + let pad = self.frame_len / 2; + if output.len() >= 2 * pad { + output.truncate(output.len() - pad); + output.drain(..pad); + } + } + output + } +} + +impl Stft { + /// Build an [`Istft`] that exactly inverts this forward transform + /// (same window, hop, FFT size, scaling and centering). + /// + /// # Errors + /// Returns [`StftError`] if the mirrored configuration is invalid. + pub fn inverse(&self) -> Result, StftError> { + IstftBuilder::default() + .window(self.window().clone()) + .hop_size(self.hop()) + .fft_size(self.fft_size()) + .forward_scale(self.scale()) + .center(self.center) + .build() + } +} diff --git a/src/lib.rs b/src/lib.rs index 1c4c418..cacfa32 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,377 +1,116 @@ -/*! - -**computes the [short-time fourier transform](https://en.wikipedia.org/wiki/Short-time_Fourier_transform) -on streaming data.** - -to use add `ruststft = "*"` -to the `[dependencies]` section of your `Cargo.toml` and `use ruststft;` in your code. 
- -## example - -``` -use ruststft::{STFT, WindowType}; - -// let's generate ten seconds of fake audio -let sample_rate: usize = 44100; -let seconds: usize = 10; -let sample_count = sample_rate * seconds; -let all_samples = (0..sample_count).map(|x| x as f64).collect::>(); - -// let's initialize our short-time fourier transform -let window_type: WindowType = WindowType::Hanning; -let window_size: usize = 1024; -let step_size: usize = 512; -let mut stft = STFT::new(window_type, window_size, step_size); - -// we need a buffer to hold a computed column of the spectrogram -let mut spectrogram_column: Vec = - std::iter::repeat(0.).take(stft.output_size()).collect(); - -// iterate over all the samples in chunks of 3000 samples. -// in a real program you would probably read from something instead. -for some_samples in (&all_samples[..]).chunks(3000) { - // append the samples to the internal ringbuffer of the stft - stft.append_samples(some_samples); - - // as long as there remain window_size samples in the internal - // ringbuffer of the stft - while stft.contains_enough_to_compute() { - // compute one column of the stft by - // taking the first window_size samples of the internal ringbuffer, - // multiplying them with the window, - // computing the fast fourier transform, - // taking half of the symetric complex outputs, - // computing the norm of the complex outputs and - // taking the log10 - stft.compute_column(&mut spectrogram_column[..]); - - // here's where you would do something with the - // spectrogram_column... - - // drop step_size samples from the internal ringbuffer of the stft - // making a step of size step_size - stft.move_to_next_column(); - } -} -``` -*/ - -use num::complex::Complex; -use num::traits::{Float, Signed, Zero}; -use rustfft::{Fft, FftDirection, FftNum, FftPlanner}; -use std::str::FromStr; -use std::sync::Arc; -use strider::{SliceRing, SliceRingImpl}; - -/// returns `0` if `log10(value).is_negative()`. -/// otherwise returns `log10(value)`. 
-/// `log10` turns values in domain `0..1` into values -/// in range `-inf..0`. -/// `log10_positive` turns values in domain `0..1` into `0`. -/// this sets very small values to zero which may not be -/// what you want depending on your application. -#[inline] -pub fn log10_positive(value: T) -> T { - // Float.log10 - // Signed.is_negative - // Zero.zero - let log = value.log10(); - if log.is_negative() { - T::zero() - } else { - log - } -} - -/// the type of apodization window to use -#[derive(Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Debug, Hash)] -pub enum WindowType { - Hanning, - Hamming, - Blackman, - Nuttall, - None, -} - -impl FromStr for WindowType { - type Err = &'static str; - - fn from_str(s: &str) -> Result { - let lower = s.to_lowercase(); - match lower.as_str() { - "hanning" => Ok(WindowType::Hanning), - "hann" => Ok(WindowType::Hanning), - "hamming" => Ok(WindowType::Hamming), - "blackman" => Ok(WindowType::Blackman), - "nuttall" => Ok(WindowType::Nuttall), - "none" => Ok(WindowType::None), - _ => Err("no match"), - } - } -} - -// this also implements ToString::to_string -impl std::fmt::Display for WindowType { - fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "{:?}", self) - } -} - -// TODO write a macro that does this automatically for any enum -static WINDOW_TYPES: [WindowType; 5] = [ - WindowType::Hanning, - WindowType::Hamming, - WindowType::Blackman, - WindowType::Nuttall, - WindowType::None, -]; - -impl WindowType { - pub fn values() -> [WindowType; 5] { - WINDOW_TYPES - } -} - -pub struct STFT -where - T: FftNum + FromF64 + num::Float, -{ - pub window_size: usize, - pub fft_size: usize, - pub step_size: usize, - pub fft: Arc>, - pub window: Option>, - /// internal ringbuffer used to store samples - pub sample_ring: SliceRingImpl, - pub real_input: Vec, - pub complex_input_output: Vec>, - fft_scratch: Vec>, -} - -impl STFT -where - T: FftNum + FromF64 + num::Float, -{ - pub fn 
window_type_to_window_vec( - window_type: WindowType, - window_size: usize, - ) -> Option> { - match window_type { - WindowType::Hanning => Some( - apodize::hanning_iter(window_size) - .map(FromF64::from_f64) - .collect(), - ), - WindowType::Hamming => Some( - apodize::hamming_iter(window_size) - .map(FromF64::from_f64) - .collect(), - ), - WindowType::Blackman => Some( - apodize::blackman_iter(window_size) - .map(FromF64::from_f64) - .collect(), - ), - WindowType::Nuttall => Some( - apodize::nuttall_iter(window_size) - .map(FromF64::from_f64) - .collect(), - ), - WindowType::None => None, - } - } - - pub fn new(window_type: WindowType, window_size: usize, step_size: usize) -> Self { - let window = Self::window_type_to_window_vec(window_type, window_size); - Self::new_with_window_vec(window, window_size, step_size) - } - - pub fn new_with_zero_padding( - window_type: WindowType, - window_size: usize, - fft_size: usize, - step_size: usize, - ) -> Self { - let window = Self::window_type_to_window_vec(window_type, window_size); - Self::new_with_window_vec_and_zero_padding(window, window_size, fft_size, step_size) - } - - // TODO this should ideally take an iterator and not a vec - pub fn new_with_window_vec_and_zero_padding( - window: Option>, - window_size: usize, - fft_size: usize, - step_size: usize, - ) -> Self { - assert!(step_size > 0 && step_size < window_size); - let fft = FftPlanner::new().plan_fft(fft_size, FftDirection::Forward); - - // allocate a scratch buffer for the FFT - let scratch_len = fft.get_inplace_scratch_len(); - let fft_scratch = vec![Complex::::zero(); scratch_len]; - - STFT { - window_size, - fft_size, - step_size, - fft, - fft_scratch, - sample_ring: SliceRingImpl::new(), - window, - real_input: std::iter::repeat(T::zero()).take(window_size).collect(), - // zero-padded complex_input, so the size is fft_size, not window_size - complex_input_output: std::iter::repeat(Complex::::zero()) - .take(fft_size) - .collect(), - // same size as 
complex_output - } - } - - pub fn new_with_window_vec( - window: Option>, - window_size: usize, - step_size: usize, - ) -> Self { - Self::new_with_window_vec_and_zero_padding(window, window_size, window_size, step_size) - } - - #[inline] - pub fn output_size(&self) -> usize { - self.fft_size / 2 - } - - #[inline] - pub fn len(&self) -> usize { - self.sample_ring.len() - } - - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn append_samples(&mut self, input: &[T]) { - self.sample_ring.push_many_back(input); - } - - #[inline] - pub fn contains_enough_to_compute(&self) -> bool { - self.window_size <= self.sample_ring.len() - } - - pub fn compute_into_complex_output(&mut self) { - assert!(self.contains_enough_to_compute()); - - // read into real_input - self.sample_ring.read_many_front(&mut self.real_input); - - // multiply real_input with window - if let Some(ref window) = self.window { - for (dst, src) in self.real_input.iter_mut().zip(window.iter()) { - *dst = *dst * *src; - } - } - - // copy windowed real_input as real parts into complex_input - // only copy `window_size` size, leave the rest in `complex_input` be zero - for (src, dst) in self - .real_input - .iter() - .zip(self.complex_input_output.iter_mut()) - { - dst.re = *src; - dst.im = T::zero(); - } - - // ensure the buffer is indeed zero-padded when needed. 
- if self.window_size < self.fft_size { - for dst in self.complex_input_output.iter_mut().skip(self.window_size) { - dst.re = T::zero(); - dst.im = T::zero(); - } - } - - // compute fft - self.fft - .process_with_scratch(&mut self.complex_input_output, &mut self.fft_scratch) - } - - /// # Panics - /// panics unless `self.output_size() == output.len()` - pub fn compute_complex_column(&mut self, output: &mut [Complex]) { - assert_eq!(self.output_size(), output.len()); - - self.compute_into_complex_output(); - - for (dst, src) in output.iter_mut().zip(self.complex_input_output.iter()) { - *dst = *src; - } - } - - /// # Panics - /// panics unless `self.output_size() == output.len()` - pub fn compute_magnitude_column(&mut self, output: &mut [T]) { - assert_eq!(self.output_size(), output.len()); - - self.compute_into_complex_output(); - - for (dst, src) in output.iter_mut().zip(self.complex_input_output.iter()) { - *dst = src.norm(); - } - } - - /// computes a column of the spectrogram - /// # Panics - /// panics unless `self.output_size() == output.len()` - pub fn compute_column(&mut self, output: &mut [T]) { - assert_eq!(self.output_size(), output.len()); - - self.compute_into_complex_output(); - - for (dst, src) in output.iter_mut().zip(self.complex_input_output.iter()) { - *dst = log10_positive(src.norm()); - } - } - - /// make a step - /// drops `self.step_size` samples from the internal buffer `self.sample_ring`. - pub fn move_to_next_column(&mut self) { - self.sample_ring.drop_many_front(self.step_size); - } - - /// corresponding frequencies of a column of the spectrogram - /// # Arguments - /// `fs`: sampling frequency. - pub fn freqs(&self, fs: f64) -> Vec { - let n_freqs = self.output_size(); - (0..n_freqs) - .map(|f| (f as f64) / ((n_freqs - 1) as f64) * (fs / 2.)) - .collect() - } - - /// corresponding time of first columns of the spectrogram - pub fn first_time(&self, fs: f64) -> f64 { - (self.window_size as f64) / (fs * 2.) 
- } - - /// time interval between two adjacent columns of the spectrogram - pub fn time_interval(&self, fs: f64) -> f64 { - (self.step_size as f64) / fs - } -} - -pub trait FromF64 { - fn from_f64(n: f64) -> Self; -} - -impl FromF64 for f64 { - fn from_f64(n: f64) -> Self { - n - } -} - -impl FromF64 for f32 { - fn from_f64(n: f64) -> Self { - n as f32 - } -} +//! # ruststft +//! +//! A complete short-time Fourier transform (STFT) toolkit for Rust: +//! +//! - **Forward STFT** over real-valued signals, in both *streaming* and *batch* +//! modes, backed by [`realfft`] (≈2× faster and half the memory of a full +//! complex FFT on real input). +//! - **Inverse STFT** with weighted overlap-add (WOLA) for perfect +//! reconstruction. +//! - A **rich window library** (Hann, Hamming, Blackman/-Harris, Nuttall, +//! Bartlett, triangular, Welch, cosine, Tukey, Kaiser, Gaussian, flat-top) +//! with both periodic (spectral-analysis) and symmetric (filter-design) +//! variants. +//! - **Spectrum helpers**: magnitude, power, phase and decibel conversions. +//! - Optional **mel** spectrograms and **MFCC**s (`mel` feature). +//! - Optional `ndarray` I/O (`ndarray`), parallel batch processing +//! (`rayon`) and configuration (de)serialization (`serde`). +//! +//! ## Quick start (batch) +//! +//! ``` +//! # #[cfg(feature = "std")] { +//! use ruststft::{Stft, Window}; +//! +//! // A 1 kHz tone sampled at 8 kHz. +//! let fs = 8_000.0; +//! let signal: Vec = (0..8_000) +//! .map(|n| (2.0 * std::f64::consts::PI * 1_000.0 * n as f64 / fs).sin()) +//! .collect(); +//! +//! let mut stft = Stft::builder() +//! .window(Window::::hann(1024)) +//! .hop_size(256) +//! .build() +//! .unwrap(); +//! +//! let spec = stft.spectrogram(&signal); +//! assert_eq!(spec.n_freqs(), 1024 / 2 + 1); // includes the Nyquist bin +//! # } +//! ``` +//! +//! ## Streaming +//! +//! ``` +//! # #[cfg(feature = "std")] { +//! use ruststft::{Stft, Window}; +//! +//! let mut stft = Stft::builder() +//! 
.window(Window::::hann(1024)) +//! .hop_size(512) +//! .build() +//! .unwrap(); +//! +//! let mut column = vec![num_complex::Complex::new(0.0f32, 0.0); stft.n_freqs()]; +//! let chunk: Vec = (0..3000).map(|x| x as f32).collect(); +//! +//! stft.append(&chunk); +//! while stft.ready() { +//! stft.process_into(&mut column).unwrap(); +//! // ... use `column` ... +//! stft.step(); +//! } +//! # } +//! ``` +//! +//! ## `no_std` +//! +//! The crate is `#![no_std]` (with `alloc`). The FFT-backed processors +//! ([`Stft`], [`Istft`], batch spectrograms) require the default `std` +//! feature because the underlying FFT backend needs `std`. The window +//! library, [`crate::mel`] filterbank/MFCC math and the +//! [`crate::spectrum`] helpers build without `std`. + +#![cfg_attr(not(feature = "std"), no_std)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![forbid(unsafe_code)] +#![deny(missing_docs)] + +extern crate alloc; + +pub mod error; +pub mod sample; +pub mod spectrum; +pub mod window; + +mod config; + +#[cfg(feature = "mel")] +#[cfg_attr(docsrs, doc(cfg(feature = "mel")))] +pub mod mel; + +#[cfg(feature = "std")] +mod batch; +#[cfg(feature = "std")] +mod istft; +#[cfg(feature = "std")] +mod stft; + +pub use config::{PadMode, Scaling}; +pub use error::StftError; +pub use sample::Sample; +pub use window::{Symmetry, Window, WindowFunction}; + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +pub use batch::Spectrogram; +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +pub use istft::{Istft, IstftBuilder}; +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +pub use stft::{Stft, StftBuilder}; + +// Re-export the complex number type so downstream users do not need to track +// the exact `num-complex` version themselves. +pub use num_complex::Complex; diff --git a/src/mel.rs b/src/mel.rs new file mode 100644 index 0000000..6b47371 --- /dev/null +++ b/src/mel.rs @@ -0,0 +1,252 @@ +//! 
Mel filterbank, (log-)mel spectrograms and MFCCs. +//! +//! This module is pure math and builds without `std`. It operates on power +//! spectra (`|S|²`), which you obtain from a [`Spectrogram`](crate::Spectrogram) +//! column via [`spectrum::power`](crate::spectrum::power). +//! +//! Conventions follow librosa's defaults: the Slaney mel scale with +//! area-normalized triangular filters, and an orthonormal DCT-II for MFCCs. + +use crate::sample::{cast, Sample}; +use alloc::vec; +use alloc::vec::Vec; +use core::f64::consts::PI; + +#[cfg(not(feature = "std"))] +use num_traits::Float; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// The mel frequency scale convention. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum MelScale { + /// Slaney (Auditory Toolbox) scale, librosa's default. + #[default] + Slaney, + /// HTK scale: `2595·log₁₀(1 + f/700)`. + Htk, +} + +const SLANEY_F_SP: f64 = 200.0 / 3.0; +const SLANEY_MIN_LOG_HZ: f64 = 1000.0; + +/// Convert a frequency in Hz to mels. +#[must_use] +pub fn hz_to_mel(hz: f64, scale: MelScale) -> f64 { + match scale { + MelScale::Htk => 2595.0 * (1.0 + hz / 700.0).log10(), + MelScale::Slaney => { + let min_log_mel = SLANEY_MIN_LOG_HZ / SLANEY_F_SP; + let logstep = 6.4f64.ln() / 27.0; + if hz >= SLANEY_MIN_LOG_HZ { + min_log_mel + (hz / SLANEY_MIN_LOG_HZ).ln() / logstep + } else { + hz / SLANEY_F_SP + } + } + } +} + +/// Convert mels to a frequency in Hz. 
+#[must_use] +pub fn mel_to_hz(mel: f64, scale: MelScale) -> f64 { + match scale { + MelScale::Htk => 700.0 * (10.0f64.powf(mel / 2595.0) - 1.0), + MelScale::Slaney => { + let min_log_mel = SLANEY_MIN_LOG_HZ / SLANEY_F_SP; + let logstep = 6.4f64.ln() / 27.0; + if mel >= min_log_mel { + SLANEY_MIN_LOG_HZ * (logstep * (mel - min_log_mel)).exp() + } else { + mel * SLANEY_F_SP + } + } + } +} + +/// A bank of triangular mel filters mapping `n_freqs` linear bins to `n_mels` +/// mel bands. +#[derive(Debug, Clone, PartialEq)] +pub struct MelFilterBank { + weights: Vec, + n_mels: usize, + n_freqs: usize, +} + +impl MelFilterBank { + /// Construct a mel filterbank (librosa `mel` with `norm='slaney'`). + /// + /// - `n_fft` is the FFT size (the bank has `n_fft / 2 + 1` linear bins). + /// - `fmin`/`fmax` bound the mel band edges in Hz. + #[must_use] + pub fn new( + n_mels: usize, + n_fft: usize, + sample_rate: f64, + fmin: f64, + fmax: f64, + scale: MelScale, + ) -> Self { + let n_freqs = n_fft / 2 + 1; + let mut weights = vec![T::zero(); n_mels * n_freqs]; + + // Linear FFT bin frequencies. + let fft_freqs: Vec = (0..n_freqs) + .map(|k| k as f64 * sample_rate / n_fft as f64) + .collect(); + + // `n_mels + 2` mel band edges, converted back to Hz. + let mel_min = hz_to_mel(fmin, scale); + let mel_max = hz_to_mel(fmax, scale); + let hz_points: Vec = (0..n_mels + 2) + .map(|i| { + let mel = mel_min + (mel_max - mel_min) * i as f64 / (n_mels + 1) as f64; + mel_to_hz(mel, scale) + }) + .collect(); + + for m in 0..n_mels { + let lower_edge = hz_points[m]; + let center = hz_points[m + 1]; + let upper_edge = hz_points[m + 2]; + let lower_width = center - lower_edge; + let upper_width = upper_edge - center; + // Slaney area normalization. 
+ let enorm = 2.0 / (upper_edge - lower_edge); + for (k, &f) in fft_freqs.iter().enumerate() { + let lower = if lower_width > 0.0 { + (f - lower_edge) / lower_width + } else { + 0.0 + }; + let upper = if upper_width > 0.0 { + (upper_edge - f) / upper_width + } else { + 0.0 + }; + let w = lower.min(upper).max(0.0) * enorm; + weights[m * n_freqs + k] = cast(w); + } + } + + Self { + weights, + n_mels, + n_freqs, + } + } + + /// Number of mel bands. + #[must_use] + pub fn n_mels(&self) -> usize { + self.n_mels + } + + /// Number of linear frequency bins expected as input. + #[must_use] + pub fn n_freqs(&self) -> usize { + self.n_freqs + } + + /// Filter weights, row-major `[n_mels, n_freqs]`. + #[must_use] + pub fn weights(&self) -> &[T] { + &self.weights + } + + /// Apply the filterbank to one power-spectrum column into `out`. + /// + /// # Panics + /// Panics if `power.len() != n_freqs` or `out.len() != n_mels`. + pub fn transform_into(&self, power: &[T], out: &mut [T]) { + assert_eq!(power.len(), self.n_freqs, "mel input length mismatch"); + assert_eq!(out.len(), self.n_mels, "mel output length mismatch"); + for (m, slot) in out.iter_mut().enumerate() { + let row = &self.weights[m * self.n_freqs..(m + 1) * self.n_freqs]; + let mut acc = T::zero(); + for (&w, &p) in row.iter().zip(power) { + acc = acc + w * p; + } + *slot = acc; + } + } + + /// Apply the filterbank to one power-spectrum column, allocating the result. + #[must_use] + pub fn transform(&self, power: &[T]) -> Vec { + let mut out = vec![T::zero(); self.n_mels]; + self.transform_into(power, &mut out); + out + } +} + +/// An orthonormal type-II DCT, precomputed as a basis matrix. +/// +/// Matches `scipy.fftpack.dct(type=2, norm='ortho')`, which is what librosa +/// uses to turn a log-mel spectrum into MFCCs. +#[derive(Debug, Clone, PartialEq)] +pub struct DctII { + basis: Vec, + n_in: usize, + n_out: usize, +} + +impl DctII { + /// Build a DCT-II that maps `n_in` inputs to the first `n_out` coefficients. 
+ #[must_use] + pub fn new(n_in: usize, n_out: usize) -> Self { + let mut basis = vec![T::zero(); n_out * n_in]; + let n = n_in as f64; + for k in 0..n_out { + let f = if k == 0 { + (1.0 / (4.0 * n)).sqrt() + } else { + (1.0 / (2.0 * n)).sqrt() + }; + for m in 0..n_in { + let v = 2.0 * f * (PI * k as f64 * (2.0 * m as f64 + 1.0) / (2.0 * n)).cos(); + basis[k * n_in + m] = cast(v); + } + } + Self { basis, n_in, n_out } + } + + /// Number of input samples. + #[must_use] + pub fn n_in(&self) -> usize { + self.n_in + } + + /// Number of output coefficients. + #[must_use] + pub fn n_out(&self) -> usize { + self.n_out + } + + /// Transform one input column into `out`. + /// + /// # Panics + /// Panics if `input.len() != n_in` or `out.len() != n_out`. + pub fn transform_into(&self, input: &[T], out: &mut [T]) { + assert_eq!(input.len(), self.n_in, "DCT input length mismatch"); + assert_eq!(out.len(), self.n_out, "DCT output length mismatch"); + for (k, slot) in out.iter_mut().enumerate() { + let row = &self.basis[k * self.n_in..(k + 1) * self.n_in]; + let mut acc = T::zero(); + for (&b, &x) in row.iter().zip(input) { + acc = acc + b * x; + } + *slot = acc; + } + } + + /// Transform one input column, allocating the result. + #[must_use] + pub fn transform(&self, input: &[T]) -> Vec { + let mut out = vec![T::zero(); self.n_out]; + self.transform_into(input, &mut out); + out + } +} diff --git a/src/sample.rs b/src/sample.rs new file mode 100644 index 0000000..4399b90 --- /dev/null +++ b/src/sample.rs @@ -0,0 +1,27 @@ +//! Numeric scalar abstraction shared by the whole crate. + +use num_traits::{Float, FromPrimitive}; + +/// Scalar sample type supported by the crate. +/// +/// This is a convenience trait alias implemented for every type that is both a +/// floating-point number ([`num_traits::Float`]) and constructible from +/// primitives ([`num_traits::FromPrimitive`]). In practice that means +/// [`f32`] and [`f64`]. 
+/// +/// The FFT-backed processors ([`Stft`](crate::Stft), [`Istft`](crate::Istft)) +/// additionally require the backend's `FftNum` bound, which is also satisfied +/// by `f32`/`f64`. +pub trait Sample: Float + FromPrimitive + 'static {} + +impl Sample for T {} + +/// Convert an `f64` constant into the sample type `T`. +/// +/// Window coefficients, mel scale conversions and decibel references are all +/// computed in `f64` and cast once. For `f32`/`f64` this conversion is +/// infallible, hence the `expect`. +#[inline] +pub(crate) fn cast(value: f64) -> T { + T::from_f64(value).expect("f32/f64 are always constructible from f64") +} diff --git a/src/spectrum.rs b/src/spectrum.rs new file mode 100644 index 0000000..0c885be --- /dev/null +++ b/src/spectrum.rs @@ -0,0 +1,106 @@ +//! Helpers for turning complex STFT coefficients into magnitudes, powers, +//! phases and decibels. +//! +//! All functions operate on plain slices so they are usable on a single +//! spectrogram column, a whole flattened spectrogram, or any other layout. + +use crate::sample::{cast, Sample}; +use alloc::vec::Vec; +use num_complex::Complex; + +/// Compute the magnitude (`|z|`) of each complex bin into `out`. +/// +/// # Panics +/// Panics if `spectrum.len() != out.len()`. +pub fn magnitude_into(spectrum: &[Complex], out: &mut [T]) { + assert_eq!(spectrum.len(), out.len(), "magnitude_into length mismatch"); + for (dst, z) in out.iter_mut().zip(spectrum) { + *dst = z.norm(); + } +} + +/// Allocate and return the magnitudes of `spectrum`. +#[must_use] +pub fn magnitude(spectrum: &[Complex]) -> Vec { + spectrum.iter().map(|z| z.norm()).collect() +} + +/// Compute the power (`|z|²`) of each complex bin into `out`. +/// +/// # Panics +/// Panics if `spectrum.len() != out.len()`. 
+pub fn power_into(spectrum: &[Complex], out: &mut [T]) { + assert_eq!(spectrum.len(), out.len(), "power_into length mismatch"); + for (dst, z) in out.iter_mut().zip(spectrum) { + *dst = z.norm_sqr(); + } +} + +/// Allocate and return the power of `spectrum`. +#[must_use] +pub fn power(spectrum: &[Complex]) -> Vec { + spectrum.iter().map(|z| z.norm_sqr()).collect() +} + +/// Compute the phase angle (in radians, `-π..=π`) of each bin into `out`. +/// +/// # Panics +/// Panics if `spectrum.len() != out.len()`. +pub fn phase_into(spectrum: &[Complex], out: &mut [T]) { + assert_eq!(spectrum.len(), out.len(), "phase_into length mismatch"); + for (dst, z) in out.iter_mut().zip(spectrum) { + *dst = z.arg(); + } +} + +/// Allocate and return the phase angles of `spectrum`. +#[must_use] +pub fn phase(spectrum: &[Complex]) -> Vec { + spectrum.iter().map(|z| z.arg()).collect() +} + +/// Smallest value clamped to before taking a logarithm, avoiding `-inf`. +fn amin() -> T { + cast(1e-10) +} + +/// Convert amplitudes to decibels in place: `20·log₁₀(max(|a|, amin) / reference)`. +/// +/// If `top_db` is `Some(d)`, values are floored at `max - d`, matching +/// librosa's `amplitude_to_db`. +pub fn amplitude_to_db(amplitudes: &mut [T], reference: T, top_db: Option) { + let amin = amin::(); + let twenty = cast::(20.0); + let log_ref = reference.abs().max(amin).log10(); + convert_to_db(amplitudes, twenty, log_ref, amin, top_db); +} + +/// Convert powers to decibels in place: `10·log₁₀(max(p, amin) / reference)`. +/// +/// If `top_db` is `Some(d)`, values are floored at `max - d`, matching +/// librosa's `power_to_db`. 
+pub fn power_to_db(powers: &mut [T], reference: T, top_db: Option) { + let amin = amin::(); + let ten = cast::(10.0); + let log_ref = reference.abs().max(amin).log10(); + convert_to_db(powers, ten, log_ref, amin, top_db); +} + +fn convert_to_db(values: &mut [T], factor: T, log_ref: T, amin: T, top_db: Option) { + let mut max_db = T::neg_infinity(); + for v in values.iter_mut() { + let db = factor * (v.max(amin).log10() - log_ref); + *v = db; + if db > max_db { + max_db = db; + } + } + if let Some(top) = top_db { + let floor = max_db - top.abs(); + for v in values.iter_mut() { + if *v < floor { + *v = floor; + } + } + } +} diff --git a/src/stft.rs b/src/stft.rs new file mode 100644 index 0000000..e40f40c --- /dev/null +++ b/src/stft.rs @@ -0,0 +1,352 @@ +//! Forward short-time Fourier transform: streaming processor and builder. + +use crate::config::{PadMode, Scaling}; +use crate::error::StftError; +use crate::sample::{cast, Sample}; +use crate::window::Window; +use alloc::collections::VecDeque; +use alloc::sync::Arc; +use alloc::vec::Vec; +use num_complex::Complex; +use realfft::{FftNum, RealFftPlanner, RealToComplex}; + +/// Builder for [`Stft`]. +/// +/// A window is mandatory; everything else has a sensible default +/// (hop = `frame_len / 4`, fft size = `frame_len`, no scaling, no centering). +#[must_use] +pub struct StftBuilder { + window: Option>, + hop: Option, + fft_size: Option, + scaling: Scaling, + center: bool, + pad_mode: PadMode, + sample_rate: Option, +} + +impl Default for StftBuilder { + fn default() -> Self { + Self { + window: None, + hop: None, + fft_size: None, + scaling: Scaling::None, + center: false, + pad_mode: PadMode::Zero, + sample_rate: None, + } + } +} + +impl StftBuilder { + /// Set the analysis window. Its length becomes the frame length. + pub fn window(mut self, window: Window) -> Self { + self.window = Some(window); + self + } + + /// Set the hop size (samples advanced between frames). Defaults to + /// `frame_len / 4`. 
+ pub fn hop_size(mut self, hop: usize) -> Self { + self.hop = Some(hop); + self + } + + /// Set the FFT size; values larger than the frame length zero-pad each + /// frame. Defaults to the frame length. + pub fn fft_size(mut self, fft_size: usize) -> Self { + self.fft_size = Some(fft_size); + self + } + + /// Set the coefficient [`Scaling`] mode. + pub fn scaling(mut self, scaling: Scaling) -> Self { + self.scaling = scaling; + self + } + + /// Enable centered framing for batch [`spectrogram`](Stft::spectrogram): + /// the signal is padded by `fft_size / 2` on each side. + pub fn center(mut self, center: bool) -> Self { + self.center = center; + self + } + + /// Set how the signal is padded when centered framing is enabled. + pub fn pad_mode(mut self, pad_mode: PadMode) -> Self { + self.pad_mode = pad_mode; + self + } + + /// Set the sample rate, required for [`Scaling::Density`]. + pub fn sample_rate(mut self, fs: f64) -> Self { + self.sample_rate = Some(fs); + self + } + + /// Validate the configuration and build the [`Stft`]. + /// + /// # Errors + /// Returns [`StftError`] if the window is missing/empty, the hop size is + /// out of range, the FFT size is smaller than the frame length, or density + /// scaling is requested without a sample rate. 
+ pub fn build(self) -> Result, StftError> { + let window = self.window.ok_or(StftError::MissingWindow)?; + let frame_len = window.len(); + if frame_len == 0 { + return Err(StftError::InvalidFrameLength); + } + + let hop = self.hop.unwrap_or((frame_len / 4).max(1)); + if hop == 0 || hop > frame_len { + return Err(StftError::InvalidHopSize { hop, frame_len }); + } + + let fft_size = self.fft_size.unwrap_or(frame_len); + if fft_size < frame_len { + return Err(StftError::InvalidFftSize { + fft_size, + frame_len, + }); + } + + let scale = match self.scaling { + Scaling::None => T::one(), + Scaling::Magnitude => T::one() / window.sum(), + Scaling::Density => { + let fs = self.sample_rate.ok_or(StftError::MissingSampleRate)?; + T::one() / (cast::(fs) * window.sum_squared()).sqrt() + } + }; + + let fft = RealFftPlanner::::new().plan_fft_forward(fft_size); + let input = fft.make_input_vec(); + let spectrum = fft.make_output_vec(); + let scratch = fft.make_scratch_vec(); + let n_freqs = spectrum.len(); + + Ok(Stft { + window, + frame_len, + hop, + fft_size, + n_freqs, + scale, + center: self.center, + pad_mode: self.pad_mode, + fft, + input, + spectrum, + scratch, + ring: VecDeque::new(), + }) + } +} + +/// A streaming forward short-time Fourier transform over real samples. +/// +/// Feed samples with [`append`](Stft::append); whenever [`ready`](Stft::ready) +/// is true, compute a column with [`process_into`](Stft::process_into) (or use +/// the [`columns`](Stft::columns) iterator) and advance with +/// [`step`](Stft::step). For one-shot processing of a whole signal use +/// [`spectrogram`](Stft::spectrogram). +pub struct Stft { + window: Window, + frame_len: usize, + hop: usize, + fft_size: usize, + n_freqs: usize, + scale: T, + pub(crate) center: bool, + pub(crate) pad_mode: PadMode, + fft: Arc>, + input: Vec, + spectrum: Vec>, + scratch: Vec>, + ring: VecDeque, +} + +impl Stft { + /// Start building an [`Stft`]. 
+ pub fn builder() -> StftBuilder { + StftBuilder::default() + } + + /// Number of frequency bins per column: `fft_size / 2 + 1` (DC … Nyquist). + #[must_use] + pub fn n_freqs(&self) -> usize { + self.n_freqs + } + + /// The frame (window) length. + #[must_use] + pub fn frame_len(&self) -> usize { + self.frame_len + } + + /// The hop size. + #[must_use] + pub fn hop(&self) -> usize { + self.hop + } + + /// The FFT size. + #[must_use] + pub fn fft_size(&self) -> usize { + self.fft_size + } + + /// The multiplicative scaling factor applied to each coefficient. + #[must_use] + pub fn scale(&self) -> T { + self.scale + } + + /// The analysis window. + #[must_use] + pub fn window(&self) -> &Window { + &self.window + } + + /// A cloned handle to the shared forward-FFT plan (used by batch workers). + #[cfg(feature = "rayon")] + pub(crate) fn fft_handle(&self) -> Arc> { + self.fft.clone() + } + + /// The bin center frequencies for a sample rate `fs`: `freqs[k] = k·fs / fft_size`. + #[must_use] + pub fn freqs(&self, fs: f64) -> Vec { + let fft_size = self.fft_size as f64; + (0..self.n_freqs) + .map(|k| cast(k as f64 * fs / fft_size)) + .collect() + } + + /// Append samples to the internal ring buffer. + pub fn append(&mut self, samples: &[T]) { + self.ring.extend(samples.iter().copied()); + } + + /// Number of buffered samples awaiting processing. + #[must_use] + pub fn buffered(&self) -> usize { + self.ring.len() + } + + /// Whether a full frame is available to process. + #[must_use] + pub fn ready(&self) -> bool { + self.ring.len() >= self.frame_len + } + + /// Clear the internal ring buffer. + pub fn reset(&mut self) { + self.ring.clear(); + } + + /// Compute the current column into `out` without advancing. + /// + /// # Errors + /// Returns [`StftError::LengthMismatch`] if `out.len() != n_freqs`, or + /// [`StftError::NotEnoughData`] if fewer than `frame_len` samples are + /// buffered. 
+ pub fn process_into(&mut self, out: &mut [Complex]) -> Result<(), StftError> { + if out.len() != self.n_freqs { + return Err(StftError::LengthMismatch { + expected: self.n_freqs, + got: out.len(), + }); + } + if self.ring.len() < self.frame_len { + return Err(StftError::NotEnoughData { + needed: self.frame_len, + available: self.ring.len(), + }); + } + self.compute_from_ring()?; + out.copy_from_slice(&self.spectrum); + Ok(()) + } + + /// Drop `hop` samples from the front of the ring buffer. + pub fn step(&mut self) { + let drop = self.hop.min(self.ring.len()); + self.ring.drain(..drop); + } + + /// Iterate over spectrogram columns, advancing by `hop` after each, until + /// fewer than `frame_len` samples remain buffered. + pub fn columns(&mut self) -> Columns<'_, T> { + Columns { stft: self } + } + + /// Fill `self.input` from the front `frame_len` samples of the ring, + /// applying the window and zero-padding, then run the FFT and scaling. + fn compute_from_ring(&mut self) -> Result<(), StftError> { + let frame_len = self.frame_len; + let win = self.window.coefficients(); + let (head, tail) = self.input.split_at_mut(frame_len); + for ((dst, &w), &s) in head.iter_mut().zip(win).zip(self.ring.iter()) { + *dst = s * w; + } + for dst in tail { + *dst = T::zero(); + } + self.run_fft() + } + + /// Run the forward FFT on `self.input`, writing to `self.spectrum` and + /// applying the scaling factor. + fn run_fft(&mut self) -> Result<(), StftError> { + self.fft + .process_with_scratch(&mut self.input, &mut self.spectrum, &mut self.scratch) + .map_err(|_| StftError::Fft)?; + if self.scale != T::one() { + let scale = self.scale; + for bin in &mut self.spectrum { + *bin = *bin * scale; + } + } + Ok(()) + } + + /// Fill `self.input` from an arbitrary `frame` slice (length `frame_len`) + /// and compute its spectrum. Used by serial batch processing. 
+ #[cfg(not(feature = "rayon"))] + pub(crate) fn compute_frame(&mut self, frame: &[T]) -> Result<&[Complex], StftError> { + debug_assert_eq!(frame.len(), self.frame_len); + let frame_len = self.frame_len; + let win = self.window.coefficients(); + let (head, tail) = self.input.split_at_mut(frame_len); + for ((dst, &w), &s) in head.iter_mut().zip(win).zip(frame) { + *dst = s * w; + } + for dst in tail { + *dst = T::zero(); + } + self.run_fft()?; + Ok(&self.spectrum) + } +} + +/// Iterator over spectrogram columns produced by [`Stft::columns`]. +pub struct Columns<'a, T: Sample + FftNum> { + stft: &'a mut Stft, +} + +impl Iterator for Columns<'_, T> { + type Item = Vec>; + + fn next(&mut self) -> Option { + if !self.stft.ready() { + return None; + } + // `ready()` guarantees a full frame, so `compute_from_ring` succeeds. + self.stft.compute_from_ring().ok()?; + let column = self.stft.spectrum.clone(); + self.stft.step(); + Some(column) + } +} diff --git a/src/window/functions.rs b/src/window/functions.rs new file mode 100644 index 0000000..2baa645 --- /dev/null +++ b/src/window/functions.rs @@ -0,0 +1,202 @@ +//! Raw window-coefficient generators, all computed in `f64`. +//! +//! Every generator builds a *symmetric* window of the requested length using +//! the conventional `denom = M - 1` normalization. Periodic ("DFT-even") +//! windows are obtained by [`super::mod`] generating a symmetric window one +//! sample longer and truncating, which matches NumPy/SciPy `fftbins=True` and +//! librosa. + +use alloc::vec; +use alloc::vec::Vec; +use core::f64::consts::PI; +// In `no_std` builds `f64` has no inherent transcendental methods, so the +// `Float` trait (backed by `libm`) supplies them. Under `std` the inherent +// methods are used and importing the trait would be flagged as unused. 
+#[cfg(not(feature = "std"))] +use num_traits::Float; + +/// Sum-of-cosines window of length `m` with the given coefficients +/// `[a0, a1, a2, ...]`: `w[i] = a0 - a1·cos(θ) + a2·cos(2θ) - a3·cos(3θ) + ...` +/// where `θ = 2π·i/(m-1)`. +pub(super) fn cosine_sum(m: usize, coeffs: &[f64]) -> Vec { + if m == 1 { + return vec![1.0]; + } + let denom = (m - 1) as f64; + (0..m) + .map(|i| { + let theta = 2.0 * PI * (i as f64) / denom; + let mut acc = 0.0; + for (k, &a) in coeffs.iter().enumerate() { + let sign = if k % 2 == 0 { 1.0 } else { -1.0 }; + acc += sign * a * (k as f64 * theta).cos(); + } + acc + }) + .collect() +} + +pub(super) fn rectangular(m: usize) -> Vec { + vec![1.0; m] +} + +pub(super) fn hann(m: usize) -> Vec { + cosine_sum(m, &[0.5, 0.5]) +} + +pub(super) fn hamming(m: usize) -> Vec { + cosine_sum(m, &[0.54, 0.46]) +} + +pub(super) fn blackman(m: usize) -> Vec { + cosine_sum(m, &[0.42, 0.5, 0.08]) +} + +pub(super) fn blackman_harris(m: usize) -> Vec { + cosine_sum(m, &[0.358_75, 0.488_29, 0.141_28, 0.011_68]) +} + +pub(super) fn nuttall(m: usize) -> Vec { + cosine_sum(m, &[0.363_581_9, 0.489_177_5, 0.136_599_5, 0.010_641_1]) +} + +pub(super) fn flat_top(m: usize) -> Vec { + cosine_sum( + m, + &[ + 0.215_578_95, + 0.416_631_58, + 0.277_263_158, + 0.083_578_947, + 0.006_947_368, + ], + ) +} + +/// Sine (a.k.a. cosine) window: `w[i] = sin(π·i/(m-1))`. +pub(super) fn cosine(m: usize) -> Vec { + if m == 1 { + return vec![1.0]; + } + let denom = (m - 1) as f64; + (0..m).map(|i| (PI * (i as f64) / denom).sin()).collect() +} + +/// Bartlett (triangular with zero endpoints), matching `numpy.bartlett`. +pub(super) fn bartlett(m: usize) -> Vec { + if m == 1 { + return vec![1.0]; + } + let half = (m - 1) as f64 / 2.0; + (0..m) + .map(|i| 1.0 - ((i as f64 - half) / half).abs()) + .collect() +} + +/// Triangular window without zero endpoints, matching `scipy.signal.windows.triang`. 
+pub(super) fn triangular(m: usize) -> Vec { + if m == 1 { + return vec![1.0]; + } + let mf = m as f64; + (0..m) + .map(|i| { + // distance from the center, in samples + let d = (i as f64 - (m - 1) as f64 / 2.0).abs(); + if m % 2 == 0 { + 1.0 - (2.0 * d) / mf + } else { + 1.0 - (2.0 * d) / (mf + 1.0) + } + }) + .collect() +} + +/// Welch (parabolic) window. +pub(super) fn welch(m: usize) -> Vec { + if m == 1 { + return vec![1.0]; + } + let half = (m - 1) as f64 / 2.0; + (0..m) + .map(|i| { + let x = (i as f64 - half) / half; + 1.0 - x * x + }) + .collect() +} + +/// Tukey (tapered cosine) window, matching `scipy.signal.windows.tukey`. +pub(super) fn tukey(m: usize, alpha: f64) -> Vec { + if m == 1 { + return vec![1.0]; + } + let alpha = alpha.clamp(0.0, 1.0); + if alpha == 0.0 { + return rectangular(m); + } + if alpha >= 1.0 { + return hann(m); + } + let denom = (m - 1) as f64; + let width = alpha * denom / 2.0; + (0..m) + .map(|i| { + let n = i as f64; + if n < width { + 0.5 * (1.0 + (PI * (-1.0 + 2.0 * n / (alpha * denom))).cos()) + } else if n <= denom - width { + 1.0 + } else { + 0.5 * (1.0 + (PI * (-2.0 / alpha + 1.0 + 2.0 * n / (alpha * denom))).cos()) + } + }) + .collect() +} + +/// Modified Bessel function of the first kind, order zero (series expansion). +fn bessel_i0(x: f64) -> f64 { + let mut sum = 1.0; + let mut term = 1.0; + let mut k = 1.0; + loop { + let r = x / (2.0 * k); + term *= r * r; + sum += term; + if term <= 1e-16 * sum || k > 1_000.0 { + break; + } + k += 1.0; + } + sum +} + +/// Kaiser window with shape parameter `beta`. +pub(super) fn kaiser(m: usize, beta: f64) -> Vec { + if m == 1 { + return vec![1.0]; + } + let denom = (m - 1) as f64; + let i0_beta = bessel_i0(beta); + (0..m) + .map(|i| { + let r = 2.0 * (i as f64) / denom - 1.0; + bessel_i0(beta * (1.0 - r * r).max(0.0).sqrt()) / i0_beta + }) + .collect() +} + +/// Gaussian window with standard deviation `std` (in samples), matching +/// `scipy.signal.windows.gaussian`. 
+pub(super) fn gaussian(m: usize, std: f64) -> Vec { + if m == 1 { + return vec![1.0]; + } + let center = (m - 1) as f64 / 2.0; + (0..m) + .map(|i| { + let n = (i as f64 - center) / std; + (-0.5 * n * n).exp() + }) + .collect() +} diff --git a/src/window/mod.rs b/src/window/mod.rs new file mode 100644 index 0000000..ab81a50 --- /dev/null +++ b/src/window/mod.rs @@ -0,0 +1,237 @@ +//! Analysis/synthesis windows. +//! +//! [`Window`] holds the (already evaluated) coefficients together with their +//! sum and sum-of-squares, which the scaling modes and overlap-add +//! normalization need. [`WindowFunction`] is a serializable description of a +//! parametric window family that can be evaluated into a [`Window`]. + +mod functions; + +use crate::sample::{cast, Sample}; +use alloc::vec::Vec; + +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + +/// Whether a window is *symmetric* (for filter design) or *periodic* (a.k.a. +/// "DFT-even", for spectral analysis). +/// +/// A periodic window of length `N` equals the symmetric window of length +/// `N + 1` with its last sample removed, matching NumPy/SciPy `fftbins=True` +/// and librosa. [`Symmetry::Periodic`] is the default because it is the right +/// choice for STFT spectral analysis. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum Symmetry { + /// DFT-even window, the correct choice for spectral analysis. + #[default] + Periodic, + /// Symmetric window, the correct choice for FIR filter design. + Symmetric, +} + +/// A parametric window family that can be evaluated into a [`Window`]. +#[derive(Debug, Clone, Copy, PartialEq)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[non_exhaustive] +pub enum WindowFunction { + /// Rectangular (boxcar) window: all ones. + Rectangular, + /// Hann window. + Hann, + /// Hamming window. + Hamming, + /// Blackman window. + Blackman, + /// 4-term Blackman-Harris window. 
+ BlackmanHarris, + /// 4-term Nuttall window. + Nuttall, + /// 5-term flat-top window (excellent amplitude accuracy). + FlatTop, + /// Bartlett (triangular, zero endpoints) window. + Bartlett, + /// Triangular window without zero endpoints. + Triangular, + /// Welch (parabolic) window. + Welch, + /// Cosine (sine) window. + Cosine, + /// Tukey (tapered cosine) window; `alpha` is the taper fraction in `0..=1`. + Tukey { + /// Fraction of the window inside the cosine tapers (`0` = rectangular, + /// `1` = Hann). + alpha: f64, + }, + /// Kaiser window with shape parameter `beta`. + Kaiser { + /// Shape parameter; larger values trade main-lobe width for side-lobe + /// attenuation. + beta: f64, + }, + /// Gaussian window with standard deviation `std` (in samples). + Gaussian { + /// Standard deviation in samples. + std: f64, + }, +} + +impl WindowFunction { + /// Evaluate this window family into raw `f64` coefficients of length `len` + /// with the given symmetry. + fn coefficients(self, len: usize, symmetry: Symmetry) -> Vec { + if len == 0 { + return Vec::new(); + } + if len == 1 { + return alloc::vec![1.0]; + } + // Build a symmetric window of length `m` (= len for symmetric, len + 1 + // for periodic) and truncate to `len`. 
+ let m = match symmetry { + Symmetry::Symmetric => len, + Symmetry::Periodic => len + 1, + }; + let mut coeffs = match self { + Self::Rectangular => functions::rectangular(m), + Self::Hann => functions::hann(m), + Self::Hamming => functions::hamming(m), + Self::Blackman => functions::blackman(m), + Self::BlackmanHarris => functions::blackman_harris(m), + Self::Nuttall => functions::nuttall(m), + Self::FlatTop => functions::flat_top(m), + Self::Bartlett => functions::bartlett(m), + Self::Triangular => functions::triangular(m), + Self::Welch => functions::welch(m), + Self::Cosine => functions::cosine(m), + Self::Tukey { alpha } => functions::tukey(m, alpha), + Self::Kaiser { beta } => functions::kaiser(m, beta), + Self::Gaussian { std } => functions::gaussian(m, std), + }; + coeffs.truncate(len); + coeffs + } + + /// Evaluate this window family into a typed [`Window`]. + #[must_use] + pub fn generate(self, len: usize, symmetry: Symmetry) -> Window { + let coeffs = self + .coefficients(len, symmetry) + .into_iter() + .map(cast) + .collect(); + Window::from_coefficients(coeffs) + } +} + +/// A window: its coefficients plus cached `sum` and `sum_of_squares`. +#[derive(Debug, Clone, PartialEq)] +pub struct Window { + coeffs: Vec, + sum: T, + sum_sq: T, +} + +impl Window { + /// Build a window from explicit coefficients, computing the cached sums. + #[must_use] + pub fn from_coefficients(coeffs: Vec) -> Self { + let mut sum = T::zero(); + let mut sum_sq = T::zero(); + for &c in &coeffs { + sum = sum + c; + sum_sq = sum_sq + c * c; + } + Self { + coeffs, + sum, + sum_sq, + } + } + + /// Build a window from a [`WindowFunction`] family. + #[must_use] + pub fn new(function: WindowFunction, len: usize, symmetry: Symmetry) -> Self { + function.generate(len, symmetry) + } + + /// The window coefficients. + #[must_use] + pub fn coefficients(&self) -> &[T] { + &self.coeffs + } + + /// Number of coefficients (the frame length). 
+ #[must_use] + pub fn len(&self) -> usize { + self.coeffs.len() + } + + /// Whether the window is empty. + #[must_use] + pub fn is_empty(&self) -> bool { + self.coeffs.is_empty() + } + + /// Sum of the coefficients (`Σ wᵢ`). Used by magnitude scaling. + #[must_use] + pub fn sum(&self) -> T { + self.sum + } + + /// Sum of the squared coefficients (`Σ wᵢ²`). Used by density scaling and + /// overlap-add normalization. + #[must_use] + pub fn sum_squared(&self) -> T { + self.sum_sq + } +} + +/// Generate a shorthand constructor for a parameter-free window family. +macro_rules! window_ctor { + ($name:ident, $variant:ident, $doc:literal) => { + #[doc = $doc] + #[doc = ""] + #[doc = "Uses [`Symmetry::Periodic`]; use [`Window::new`] for symmetric windows."] + #[must_use] + pub fn $name(len: usize) -> Self { + WindowFunction::$variant.generate(len, Symmetry::Periodic) + } + }; +} + +impl Window { + window_ctor!(rectangular, Rectangular, "A rectangular (boxcar) window."); + window_ctor!(hann, Hann, "A periodic Hann window."); + window_ctor!(hamming, Hamming, "A periodic Hamming window."); + window_ctor!(blackman, Blackman, "A periodic Blackman window."); + window_ctor!( + blackman_harris, + BlackmanHarris, + "A periodic Blackman-Harris window." + ); + window_ctor!(nuttall, Nuttall, "A periodic Nuttall window."); + window_ctor!(flat_top, FlatTop, "A periodic flat-top window."); + window_ctor!(bartlett, Bartlett, "A Bartlett window (zero endpoints)."); + window_ctor!(triangular, Triangular, "A triangular window."); + window_ctor!(welch, Welch, "A Welch (parabolic) window."); + window_ctor!(cosine, Cosine, "A cosine (sine) window."); + + /// A periodic Tukey (tapered cosine) window with taper fraction `alpha`. + #[must_use] + pub fn tukey(len: usize, alpha: f64) -> Self { + WindowFunction::Tukey { alpha }.generate(len, Symmetry::Periodic) + } + + /// A periodic Kaiser window with shape parameter `beta`. 
+ #[must_use] + pub fn kaiser(len: usize, beta: f64) -> Self { + WindowFunction::Kaiser { beta }.generate(len, Symmetry::Periodic) + } + + /// A periodic Gaussian window with standard deviation `std` (in samples). + #[must_use] + pub fn gaussian(len: usize, std: f64) -> Self { + WindowFunction::Gaussian { std }.generate(len, Symmetry::Periodic) + } +} diff --git a/tests/lib.rs b/tests/lib.rs index 466c10f..6c921c7 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -1,126 +1,260 @@ -use approx::assert_ulps_eq; -use std::str::FromStr; -use ruststft::{WindowType, STFT}; +//! Integration tests exercising the public API against analytic ground truth. + +use approx::assert_abs_diff_eq; +use core::f64::consts::PI; +#[cfg(feature = "mel")] +use ruststft::mel::{hz_to_mel, mel_to_hz, DctII, MelFilterBank, MelScale}; +use ruststft::spectrum::{amplitude_to_db, magnitude}; +use ruststft::{Complex, Scaling, Stft, Symmetry, Window, WindowFunction}; + +// --------------------------------------------------------------------------- +// Windows +// --------------------------------------------------------------------------- + +#[test] +fn periodic_hann_matches_truncated_symmetric() { + let w = Window::::hann(4); + // Periodic Hann(4) == symmetric Hann(5) without its last sample. 
+ let expected = [0.0, 0.5, 1.0, 0.5]; + for (got, want) in w.coefficients().iter().zip(expected) { + assert_abs_diff_eq!(*got, want, epsilon = 1e-12); + } +} + +#[test] +fn symmetric_window_is_symmetric() { + let w = Window::::new(WindowFunction::Hann, 9, Symmetry::Symmetric); + let c = w.coefficients(); + for i in 0..c.len() { + assert_abs_diff_eq!(c[i], c[c.len() - 1 - i], epsilon = 1e-12); + } + assert_abs_diff_eq!(c[0], 0.0, epsilon = 1e-12); +} + +#[test] +fn rectangular_window_sums() { + let w = Window::::rectangular(16); + assert_abs_diff_eq!(w.sum(), 16.0, epsilon = 1e-12); + assert_abs_diff_eq!(w.sum_squared(), 16.0, epsilon = 1e-12); +} + +// --------------------------------------------------------------------------- +// Forward STFT correctness +// --------------------------------------------------------------------------- + +#[test] +fn nyquist_bin_is_included() { + let stft = Stft::builder() + .window(Window::::hann(1024)) + .hop_size(256) + .build() + .unwrap(); + assert_eq!(stft.n_freqs(), 1024 / 2 + 1); +} + +#[test] +fn frequencies_are_correct() { + let fft_size = 1024usize; + let fs = 8_000.0; + let stft = Stft::builder() + .window(Window::::rectangular(fft_size)) + .hop_size(256) + .build() + .unwrap(); + let freqs = stft.freqs(fs); + assert_eq!(freqs.len(), fft_size / 2 + 1); + assert_abs_diff_eq!(freqs[0], 0.0, epsilon = 1e-9); + // Last bin is exactly Nyquist. + assert_abs_diff_eq!(*freqs.last().unwrap(), fs / 2.0, epsilon = 1e-9); + // Arbitrary bin k -> k * fs / fft_size. 
+ assert_abs_diff_eq!(freqs[10], 10.0 * fs / fft_size as f64, epsilon = 1e-9); +} + +#[test] +fn pure_tone_peaks_at_expected_bin() { + let n = 1024usize; + let fs = 1024.0; + let k0 = 64usize; // exact bin + let signal: Vec = (0..n) + .map(|i| (2.0 * PI * k0 as f64 * i as f64 / n as f64).cos()) + .collect(); + + let mut stft = Stft::builder() + .window(Window::::rectangular(n)) + .hop_size(n) + .build() + .unwrap(); + let spec = stft.spectrogram(&signal); + assert_eq!(spec.n_frames(), 1); + + let mags = magnitude(spec.column(0)); + let argmax = mags + .iter() + .enumerate() + .max_by(|a, b| a.1.partial_cmp(b.1).unwrap()) + .unwrap() + .0; + assert_eq!(argmax, k0); + assert_eq!(spec.column(0).len(), fs as usize / 2 + 1); +} #[test] -fn test_window_type_from_string() { - assert_eq!( - WindowType::from_str("Hanning").unwrap(), - WindowType::Hanning - ); - assert_eq!( - WindowType::from_str("hanning").unwrap(), - WindowType::Hanning - ); - assert_eq!(WindowType::from_str("hann").unwrap(), WindowType::Hanning); - assert_eq!( - WindowType::from_str("blackman").unwrap(), - WindowType::Blackman - ); +fn magnitude_scaling_recovers_sine_amplitude() { + let n = 2048usize; + let amplitude = 0.7; + let k0 = 100usize; + let signal: Vec = (0..n) + .map(|i| amplitude * (2.0 * PI * k0 as f64 * i as f64 / n as f64).cos()) + .collect(); + + let mut stft = Stft::builder() + .window(Window::::rectangular(n)) + .hop_size(n) + .scaling(Scaling::Magnitude) + .build() + .unwrap(); + let spec = stft.spectrogram(&signal); + let mags = magnitude(spec.column(0)); + // One-sided magnitude of a real cosine at an interior bin is amplitude/2. 
+ assert_abs_diff_eq!(mags[k0], amplitude / 2.0, epsilon = 1e-6); } #[test] -fn test_window_type_to_string() { - assert_eq!(WindowType::Hanning.to_string(), "Hanning"); +fn streaming_and_batch_agree() { + let n = 4096usize; + let signal: Vec = (0..n).map(|i| (i as f64 * 0.01).sin()).collect(); + + let mut batch = Stft::builder() + .window(Window::::hann(256)) + .hop_size(64) + .build() + .unwrap(); + let spec = batch.spectrogram(&signal); + + let mut stream = Stft::builder() + .window(Window::::hann(256)) + .hop_size(64) + .build() + .unwrap(); + stream.append(&signal); + let columns: Vec>> = stream.columns().collect(); + + assert_eq!(columns.len(), spec.n_frames()); + for (frame, batch_col) in columns.iter().zip(spec.columns()) { + for (a, b) in frame.iter().zip(batch_col) { + assert_abs_diff_eq!(a.re, b.re, epsilon = 1e-9); + assert_abs_diff_eq!(a.im, b.im, epsilon = 1e-9); + } + } } +// --------------------------------------------------------------------------- +// Round-trip reconstruction (STFT -> ISTFT) +// --------------------------------------------------------------------------- + #[test] -fn test_window_types_to_strings() { - assert_eq!( - vec!["Hanning", "Hamming", "Blackman", "Nuttall", "None"], - WindowType::values() - .iter() - .map(|x| x.to_string()) - .collect::>() - ); +fn round_trip_reconstructs_signal() { + let n = 8000usize; + let fs = 8000.0; + let signal: Vec = (0..n) + .map(|i| { + (2.0 * PI * 220.0 * i as f64 / fs).sin() + + 0.3 * (2.0 * PI * 600.0 * i as f64 / fs).sin() + }) + .collect(); + + let mut stft = Stft::builder() + .window(Window::::hann(1024)) + .hop_size(256) // 75% overlap: Hann is COLA-compliant + .center(true) + .build() + .unwrap(); + let spec = stft.spectrogram(&signal); + + let istft = stft.inverse().unwrap(); + let recon = istft.reconstruct(&spec).unwrap(); + + // Compare the interior, away from edge-taper artifacts. 
+ let lo = 1024; + let hi = n - 1024; + for i in lo..hi { + assert_abs_diff_eq!(recon[i], signal[i], epsilon = 1e-6); + } } #[test] -fn test_log10_positive() { - assert!(ruststft::log10_positive(-1. as f64).is_nan()); - assert_eq!(ruststft::log10_positive(0.), 0.); - assert_eq!(ruststft::log10_positive(1.), 0.); - assert_eq!(ruststft::log10_positive(10.), 1.); - assert_eq!(ruststft::log10_positive(100.), 2.); - assert_eq!(ruststft::log10_positive(1000.), 3.); +fn round_trip_rectangular_no_overlap_is_exact_interior() { + let n = 4096usize; + let signal: Vec<f64> = (0..n).map(|i| ((i * 7 % 13) as f64) - 6.0).collect(); + + let mut stft = Stft::builder() + .window(Window::<f64>::rectangular(512)) + .hop_size(512) // contiguous, non-overlapping frames + .build() + .unwrap(); + let spec = stft.spectrogram(&signal); + let istft = stft.inverse().unwrap(); + let recon = istft.reconstruct(&spec).unwrap(); + + for i in 0..recon.len() { + assert_abs_diff_eq!(recon[i], signal[i], epsilon = 1e-9); + } +} + +// --------------------------------------------------------------------------- +// Spectrum helpers +// --------------------------------------------------------------------------- + +#[test] +fn amplitude_to_db_floor_and_reference() { + let mut v = vec![1.0f64, 0.1, 0.01, 0.0]; + amplitude_to_db(&mut v, 1.0, Some(80.0)); + assert_abs_diff_eq!(v[0], 0.0, epsilon = 1e-9); // 20*log10(1) = 0 + assert_abs_diff_eq!(v[1], -20.0, epsilon = 1e-9); // 20*log10(0.1) + // Floored at max - 80 = -80. 
+ assert_abs_diff_eq!(v[3], -80.0, epsilon = 1e-9); } +// --------------------------------------------------------------------------- +// Mel + MFCC +// --------------------------------------------------------------------------- + +#[cfg(feature = "mel")] #[test] -fn test_stft() { - let mut stft = STFT::new(WindowType::Hanning, 8, 4); - assert!(!stft.contains_enough_to_compute()); - assert_eq!(stft.output_size(), 4); - assert_eq!(stft.len(), 0); - stft.append_samples(&[500., 0., 100.]); - assert_eq!(stft.len(), 3); - assert!(!stft.contains_enough_to_compute()); - stft.append_samples(&[500., 0., 100., 0.]); - assert_eq!(stft.len(), 7); - assert!(!stft.contains_enough_to_compute()); - - stft.append_samples(&[500.]); - assert!(stft.contains_enough_to_compute()); - - let mut output: Vec = vec![0.; 4]; - stft.compute_column(&mut output); - println!("{:?}", output); - - let expected = vec![ - 2.7763337740785166, - 2.7149781042402594, - 2.6218024907053796, - 2.647816050270838, - ]; - assert_ulps_eq!(output.as_slice(), expected.as_slice(), max_ulps = 10); - - // repeat the calculation to ensure results are independent of the internal buffer - let mut output2: Vec = vec![0.; 4]; - stft.compute_column(&mut output2); - assert_ulps_eq!(output.as_slice(), output2.as_slice(), max_ulps = 10); +fn mel_scale_round_trips_and_is_monotonic() { + for scale in [MelScale::Slaney, MelScale::Htk] { + assert_abs_diff_eq!(hz_to_mel(0.0, scale), 0.0, epsilon = 1e-9); + for f in [100.0, 440.0, 1000.0, 4000.0, 8000.0] { + assert_abs_diff_eq!(mel_to_hz(hz_to_mel(f, scale), scale), f, epsilon = 1e-6); + } + assert!(hz_to_mel(2000.0, scale) > hz_to_mel(1000.0, scale)); + } +} + +#[cfg(feature = "mel")] +#[test] +fn mel_filterbank_shape_and_nonnegativity() { + let bank = MelFilterBank::::new(40, 1024, 16_000.0, 0.0, 8_000.0, MelScale::Slaney); + assert_eq!(bank.n_mels(), 40); + assert_eq!(bank.n_freqs(), 1024 / 2 + 1); + assert!(bank.weights().iter().all(|&w| w >= 0.0)); + + // Applying to a flat 
power spectrum yields positive energy in every band. + let power = vec![1.0f64; bank.n_freqs()]; + let mel = bank.transform(&power); + assert_eq!(mel.len(), 40); + assert!(mel.iter().all(|&m| m > 0.0)); } +#[cfg(feature = "mel")] #[test] -fn test_stft_padded() { - let mut stft = STFT::new_with_zero_padding(WindowType::Hanning, 8, 32, 4); - assert!(!stft.contains_enough_to_compute()); - assert_eq!(stft.output_size(), 16); - assert_eq!(stft.len(), 0); - stft.append_samples(&[500., 0., 100.]); - assert_eq!(stft.len(), 3); - assert!(!stft.contains_enough_to_compute()); - stft.append_samples(&[500., 0., 100., 0.]); - assert_eq!(stft.len(), 7); - assert!(!stft.contains_enough_to_compute()); - - stft.append_samples(&[500.]); - assert!(stft.contains_enough_to_compute()); - - let mut output: Vec = vec![0.; 16]; - stft.compute_column(&mut output); - println!("{:?}", output); - - let expected = vec![ - 2.7763337740785166, - 2.772158781619449, - 2.7598791705720664, - 2.740299218211912, - 2.7149781042402594, - 2.686495897766628, - 2.6585877421915676, - 2.635728083951981, - 2.6218024907053796, - 2.6183544930578027, - 2.6238833073831658, - 2.634925941918913, - 2.647816050270838, - 2.65977332745612, - 2.6691025866822033, - 2.6749381613735683, - ]; - assert_ulps_eq!(output.as_slice(), expected.as_slice(), max_ulps = 10); - - // repeat the calculation to ensure results are independent of the internal buffer - let mut output2: Vec = vec![0.; 16]; - stft.compute_column(&mut output2); - assert_ulps_eq!(output.as_slice(), output2.as_slice(), max_ulps = 10); +fn dct2_of_constant_is_a_single_coefficient() { + let n = 32usize; + let dct = DctII::::new(n, n); + let x = vec![1.0f64; n]; + let y = dct.transform(&x); + assert_abs_diff_eq!(y[0], (n as f64).sqrt(), epsilon = 1e-9); + for &c in &y[1..] 
{ + assert_abs_diff_eq!(c, 0.0, epsilon = 1e-9); + } } diff --git a/tests/proptests.rs b/tests/proptests.rs new file mode 100644 index 0000000..8d2f808 --- /dev/null +++ b/tests/proptests.rs @@ -0,0 +1,63 @@ +//! Property-based tests: STFT linearity and STFT/ISTFT round-trip fidelity. + +use proptest::prelude::*; +use ruststft::{Stft, Window}; + +proptest! { + #![proptest_config(ProptestConfig::with_cases(24))] + + /// The STFT is linear: `STFT(a·x + b·y) == a·STFT(x) + b·STFT(y)`. + #[test] + fn stft_is_linear( + x in prop::collection::vec(-1.0f64..1.0, 1024..2048), + a in -3.0f64..3.0, + b in -3.0f64..3.0, + ) { + let n = x.len(); + // A deterministic second signal, independent of the first. + let y: Vec<f64> = (0..n).map(|i| (i as f64 * 0.013).sin()).collect(); + let z: Vec<f64> = (0..n).map(|i| a * x[i] + b * y[i]).collect(); + + let build = || Stft::builder() + .window(Window::<f64>::hann(256)) + .hop_size(128) + .build() + .unwrap(); + + let sx = build().spectrogram(&x); + let sy = build().spectrogram(&y); + let sz = build().spectrogram(&z); + + prop_assert_eq!(sx.n_frames(), sz.n_frames()); + for f in 0..sz.n_frames() { + for (k, zc) in sz.column(f).iter().enumerate() { + let lhs = *zc; + let rhs = sx.column(f)[k] * a + sy.column(f)[k] * b; + prop_assert!((lhs.re - rhs.re).abs() < 1e-6); + prop_assert!((lhs.im - rhs.im).abs() < 1e-6); + } + } + } + + /// A Hann window at 75% overlap is COLA-compliant, so STFT->ISTFT + /// reconstructs the interior of the signal. 
+ #[test] + fn round_trip_reconstructs_interior( + signal in prop::collection::vec(-1.0f64..1.0, 4096..6000), + ) { + let mut stft = Stft::builder() + .window(Window::<f64>::hann(512)) + .hop_size(128) + .center(true) + .build() + .unwrap(); + let spec = stft.spectrogram(&signal); + let recon = stft.inverse().unwrap().reconstruct(&spec).unwrap(); + + let frame = 512; + let hi = signal.len().min(recon.len()); + for i in frame..(hi - frame) { + prop_assert!((recon[i] - signal[i]).abs() < 1e-6); + } + } +} From 14ca7d4a1f767eea782d6a4e406214f331f07066 Mon Sep 17 00:00:00 2001 From: Markus Mayer Date: Sun, 31 May 2026 20:46:30 +0200 Subject: [PATCH 2/4] build: add opt-in wasm_simd feature for WASM SIMD FFT kernels Adds a `wasm_simd` feature that forwards to `realfft/wasm_simd` (and thus `rustfft/wasm_simd`). It implies `std` because the FFT backend does, so it cannot half-activate the optional realfft dependency. No effect off wasm32; on wasm32 it enables the simd128 kernels when built with `-C target-feature=+simd128`. --- Cargo.toml | 3 +++ README.md | 1 + 2 files changed, 4 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index dfcc046..5c7046a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,9 @@ ndarray = ["dep:ndarray", "std"] rayon = ["dep:rayon", "std"] # `serde` derives on the configuration and window-specification types. serde = ["dep:serde"] +# Enable the WASM SIMD (`simd128`) FFT kernels. Implies `std`. Only takes +# effect on `wasm32` targets built with `-C target-feature=+simd128`. +wasm_simd = ["std", "realfft/wasm_simd"] [dependencies] num-complex = { version = "0.4.6", default-features = false, features = ["libm"] } diff --git a/README.md b/README.md index cf1a7c9..52c23b1 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ while stft.ready() { | `ndarray` | no | `Spectrogram::to_array2` (`[n_freqs, n_frames]`). | | `rayon` | no | Parallel per-frame batch spectrograms. | | `serde` | no | (De)serialize configuration and window descriptions. 
| +| `wasm_simd` | no | WASM `simd128` FFT kernels (implies `std`; build with `-C target-feature=+simd128`). | Without the default `std` feature the crate builds as `no_std` (with `alloc`), exposing the window library, the [`spectrum`](https://docs.rs/ruststft/latest/ruststft/spectrum/) From c7073d6f35053d30162a55cfe6f6177744065fa5 Mon Sep 17 00:00:00 2001 From: Markus Mayer Date: Sun, 31 May 2026 20:56:19 +0200 Subject: [PATCH 3/4] fix: address PR review feedback - Taskfile `test:no-std`: build the library only instead of running the integration tests, which require `std` (they import `Stft`). - Remove the obsolete `.github/workflows/rust.yml`; the new `ci.yml` replaces it and the duplicate would double-run on every push/PR. - Correct the centered-padding docs on `StftBuilder::center` and `PadMode` to say `frame_len / 2` (the value the implementation uses), not `fft_size / 2`. - Fix the chirp in examples/spectrogram.rs to integrate the linear frequency sweep, so it actually starts at 200 Hz (a bare sin(pi*f*t) starts at f0/2 = 100 Hz). 
--- .github/workflows/rust.yml | 24 ------------------------ Taskfile.dist.yaml | 6 ++++-- examples/spectrogram.rs | 12 +++++++++--- src/config.rs | 4 ++-- src/stft.rs | 2 +- 5 files changed, 16 insertions(+), 32 deletions(-) delete mode 100644 .github/workflows/rust.yml diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index c1fc968..0000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: Rust - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - name: Build - run: cargo build --verbose - - name: Run tests - run: cargo test --tests --verbose - - name: Run doctests - run: cargo test --doc --verbose diff --git a/Taskfile.dist.yaml b/Taskfile.dist.yaml index 8659d4c..cd9bee7 100644 --- a/Taskfile.dist.yaml +++ b/Taskfile.dist.yaml @@ -72,9 +72,11 @@ tasks: - cargo test --doc --all-features test:no-std: - desc: Run tests without default features + desc: Build the library without default features (the no_std subset; the + FFT processors and integration tests require std) cmds: - - cargo test --no-default-features --features mel {{.CLI_ARGS}} + - cargo build --lib --no-default-features {{.CLI_ARGS}} + - cargo build --lib --no-default-features --features mel,serde bench: desc: Run the criterion benchmarks diff --git a/examples/spectrogram.rs b/examples/spectrogram.rs index 7d70455..e297ee5 100644 --- a/examples/spectrogram.rs +++ b/examples/spectrogram.rs @@ -9,12 +9,18 @@ fn main() { let fs = 16_000.0; let n = fs as usize * 2; // 2 seconds - // A linear chirp sweeping from 200 Hz to 4 kHz. + // A linear chirp sweeping from 200 Hz to 4 kHz over 2 seconds. The phase is + // the integral of the instantaneous frequency f(t) = f0 + k*t, so the + // sweep actually starts at f0 (a bare `sin(pi*f*t)` would start at f0/2). 
+ let f0 = 200.0f32; + let f1 = 4000.0f32; + let duration = n as f32 / fs; + let k = (f1 - f0) / duration; // Hz per second let signal: Vec<f32> = (0..n) .map(|i| { let t = i as f32 / fs; - let f = 200.0 + (4000.0 - 200.0) * (t / 2.0); - (std::f32::consts::PI * f * t).sin() + let phase = 2.0 * std::f32::consts::PI * (f0 * t + 0.5 * k * t * t); + phase.sin() }) .collect(); diff --git a/src/config.rs b/src/config.rs index 7c88468..96fcb51 100644 --- a/src/config.rs +++ b/src/config.rs @@ -27,8 +27,8 @@ pub enum Scaling { /// How a signal is padded when centered framing is enabled in batch mode. /// /// With [`center`](crate::StftBuilder::center) enabled the signal is padded by -/// `fft_size / 2` samples on each side so that frame `t` is centered on sample -/// `t * hop`, matching librosa's convention. +/// `frame_len / 2` samples on each side so that frame `t` is centered on sample +/// `t * hop`. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum PadMode { diff --git a/src/stft.rs b/src/stft.rs index e40f40c..c0a64d1 100644 --- a/src/stft.rs +++ b/src/stft.rs @@ -67,7 +67,7 @@ impl StftBuilder { } /// Enable centered framing for batch [`spectrogram`](Stft::spectrogram): - /// the signal is padded by `fft_size / 2` on each side. + /// the signal is padded by `frame_len / 2` on each side. 
pub fn center(mut self, center: bool) -> Self { self.center = center; self From 6523ee105c9b6dc8994a39eab5ae2bd7b38a5313 Mon Sep 17 00:00:00 2001 From: Markus Mayer Date: Sun, 31 May 2026 21:06:29 +0200 Subject: [PATCH 4/4] Bump MSRV to 1.85 --- .github/workflows/ci.yml | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d9c1dd0..0482bf7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,7 +41,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - rust: [stable, "1.75"] + rust: [stable, "1.85"] steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master diff --git a/Cargo.toml b/Cargo.toml index 5c7046a..9873b98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ license = "MIT OR Apache-2.0" name = "ruststft" readme = "README.md" repository = "https://github.com/sunsided/stft.git" -rust-version = "1.75" +rust-version = "1.85" version = "0.4.0" [features]