From bae1de8968751124ba22bf5ec0e14670830ce013 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ey=C3=BCp=20Can=20Akman?= Date: Mon, 8 Jun 2026 01:29:24 +0300 Subject: [PATCH] Add stft and istft to mlx.core.fft --- docs/src/python/fft.rst | 2 + mlx/fft.cpp | 265 +++++++++++++++++++++++++++++++++++++++ mlx/fft.h | 28 +++++ python/src/fft.cpp | 120 ++++++++++++++++++ python/tests/test_fft.py | 115 +++++++++++++++++ tests/fft_tests.cpp | 66 ++++++++++ 6 files changed, 596 insertions(+) diff --git a/docs/src/python/fft.rst b/docs/src/python/fft.rst index 78bfe7f8fc..0df0b665df 100644 --- a/docs/src/python/fft.rst +++ b/docs/src/python/fft.rst @@ -20,6 +20,8 @@ FFT irfft2 rfftn irfftn + stft + istft fftfreq rfftfreq fftshift diff --git a/mlx/fft.cpp b/mlx/fft.cpp index 8ddc1aca46..6a160674a4 100644 --- a/mlx/fft.cpp +++ b/mlx/fft.cpp @@ -1,4 +1,5 @@ // Copyright © 2023 Apple Inc. +#include #include #include #include @@ -338,4 +339,268 @@ array rfftfreq(int n, double d /* = 1.0 */, StreamOrDevice s /* = {} */) { return multiply(freqs, scale, s); } +namespace { + +// Pad the last axis; mx::pad has no reflect mode, so build it from take. 
+array pad_last_axis( + const array& a, + int pad_width, + const std::string& mode, + StreamOrDevice s) { + if (pad_width <= 0) { + return a; + } + int ax = a.ndim() - 1; + if (mode == "constant") { + return pad( + a, {ax}, {pad_width}, {pad_width}, array(0, a.dtype()), "constant", s); + } + if (mode == "edge") { + int n = a.shape(ax); + auto left = take(a, full(Shape{pad_width}, array(0), int32, s), ax, s); + auto right = take(a, full(Shape{pad_width}, array(n - 1), int32, s), ax, s); + return concatenate({left, a, right}, ax, s); + } + if (mode == "reflect") { + int n = a.shape(ax); + if (pad_width >= n) { + std::ostringstream msg; + msg << "[stft] Reflect padding (" << pad_width + << ") requires an input longer than the padding along the last axis (" + << n << ")."; + throw std::invalid_argument(msg.str()); + } + auto left = take(a, arange(pad_width, 0, -1, s), ax, s); + auto right = take(a, arange(n - 2, n - 2 - pad_width, -1, s), ax, s); + return concatenate({left, a, right}, ax, s); + } + std::ostringstream msg; + msg << "[stft] Invalid pad_mode '" << mode + << "'. Expected one of {'constant', 'reflect', 'edge'}."; + throw std::invalid_argument(msg.str()); +} + +// Center a window of length <= n_fft inside an n_fft-length frame. +array prepare_window( + const std::optional& window, + int win_length, + int n_fft, + const std::string& op, + StreamOrDevice s) { + array win = + window.has_value() ? 
window.value() : ones({win_length}, float32, s); + if (win.ndim() != 1) { + throw std::invalid_argument( + "[" + op + "] window must be a one-dimensional array."); + } + int wl = win.shape(0); + if (wl > n_fft) { + throw std::invalid_argument( + "[" + op + "] window length must not exceed n_fft."); + } + if (wl < n_fft) { + int pad_left = (n_fft - wl) / 2; + int pad_right = n_fft - wl - pad_left; + win = + pad(win, + {0}, + {pad_left}, + {pad_right}, + array(0, win.dtype()), + "constant", + s); + } + return win; +} + +} // namespace + +array stft( + const array& x_in, + int n_fft /* = 2048 */, + const std::optional& hop_length_ /* = std::nullopt */, + const std::optional& win_length_ /* = std::nullopt */, + const std::optional& window /* = std::nullopt */, + bool center /* = true */, + const std::string& pad_mode /* = "reflect" */, + FFTNorm norm /* = FFTNorm::Backward */, + const std::optional& onesided_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + if (n_fft <= 0) { + throw std::invalid_argument("[stft] n_fft must be positive."); + } + if (x_in.ndim() < 1) { + throw std::invalid_argument( + "[stft] Input must have at least one dimension."); + } + int hop_length = hop_length_.value_or(n_fft / 4); + int win_length = win_length_.value_or(n_fft); + if (hop_length <= 0) { + throw std::invalid_argument("[stft] hop_length must be positive."); + } + bool onesided = onesided_.value_or(x_in.dtype() != complex64); + + array win = prepare_window(window, win_length, n_fft, "stft", s); + + // Collapse leading dims into a single batch axis. 
+ Shape lead = x_in.shape(); + lead.pop_back(); + int t = x_in.shape(-1); + array x = reshape(x_in, {-1, t}, s); + + if (center) { + x = pad_last_axis(x, n_fft / 2, pad_mode, s); + } + int tp = x.shape(-1); + if (tp < n_fft) { + throw std::invalid_argument( + "[stft] Input is too short for the given n_fft."); + } + int n_frames = 1 + (tp - n_fft) / hop_length; + int b = x.shape(0); + + // Overlapping frames via as_strided (element strides: tp, hop, 1). + array frames = as_strided( + x, + {b, n_frames, n_fft}, + Strides{static_cast(tp), static_cast(hop_length), 1}, + 0, + s); + frames = multiply(frames, win, s); + + array spec = onesided ? rfftn(frames, Shape{n_fft}, {-1}, norm, s) + : fftn(frames, Shape{n_fft}, {-1}, norm, s); + spec = swapaxes(spec, -1, -2, s); + + Shape out_shape = lead; + out_shape.push_back(spec.shape(-2)); + out_shape.push_back(n_frames); + return reshape(spec, out_shape, s); +} + +array istft( + const array& stft_matrix, + const std::optional& n_fft_ /* = std::nullopt */, + const std::optional& hop_length_ /* = std::nullopt */, + const std::optional& win_length_ /* = std::nullopt */, + const std::optional& window /* = std::nullopt */, + bool center /* = true */, + FFTNorm norm /* = FFTNorm::Backward */, + const std::optional& onesided_ /* = std::nullopt */, + const std::optional& length_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + if (stft_matrix.ndim() < 2) { + throw std::invalid_argument( + "[istft] Input must have at least two dimensions (freq, frames)."); + } + bool onesided = onesided_.value_or(true); + int n_freq = stft_matrix.shape(-2); + int n_fft = n_fft_.value_or(onesided ? 
(n_freq - 1) * 2 : n_freq); + if (n_fft <= 0) { + throw std::invalid_argument("[istft] n_fft must be positive."); + } + int hop_length = hop_length_.value_or(n_fft / 4); + int win_length = win_length_.value_or(n_fft); + if (hop_length <= 0) { + throw std::invalid_argument("[istft] hop_length must be positive."); + } + + array win = prepare_window(window, win_length, n_fft, "istft", s); + + Shape lead = stft_matrix.shape(); + int n_frames = lead.back(); + lead.pop_back(); + lead.pop_back(); + array z = reshape(stft_matrix, {-1, n_freq, n_frames}, s); + int b = z.shape(0); + z = swapaxes(z, -1, -2, s); + + array frames = onesided ? irfftn(z, Shape{n_fft}, {-1}, norm, s) + : real(ifftn(z, Shape{n_fft}, {-1}, norm, s), s); + frames = multiply(frames, win, s); + + // Overlap-add: sub-block k of frame i lands on output block i + k, so + // summing seg = ceil(n_fft / hop) shifted copies is independent of n_frames. + int seg = (n_fft + hop_length - 1) / hop_length; + int wp = seg * hop_length; + if (wp > n_fft) { + frames = + pad(frames, + {2}, + {0}, + {wp - n_fft}, + array(0, frames.dtype()), + "constant", + s); + } + array c = reshape(frames, {b, n_frames, seg, hop_length}, s); + + array win_sq = square(win, s); + if (wp > n_fft) { + win_sq = + pad(win_sq, + {0}, + {0}, + {wp - n_fft}, + array(0, win_sq.dtype()), + "constant", + s); + } + array cw = broadcast_to( + reshape(win_sq, {1, 1, seg, hop_length}, s), + {1, n_frames, seg, hop_length}, + s); + + int out_blocks = n_frames + seg - 1; + int out_len = out_blocks * hop_length; + array signal = zeros({b, out_len}, frames.dtype(), s); + array envelope = zeros({1, out_len}, frames.dtype(), s); + for (int k = 0; k < seg; ++k) { + array ck = reshape( + slice(c, {0, 0, k, 0}, {b, n_frames, k + 1, hop_length}, s), + {b, n_frames, hop_length}, + s); + ck = pad(ck, {1}, {k}, {seg - 1 - k}, array(0, ck.dtype()), "constant", s); + signal = add(signal, reshape(ck, {b, out_len}, s), s); + + array wk = reshape( + slice(cw, {0, 0, 
k, 0}, {1, n_frames, k + 1, hop_length}, s), + {1, n_frames, hop_length}, + s); + wk = pad(wk, {1}, {k}, {seg - 1 - k}, array(0, wk.dtype()), "constant", s); + envelope = add(envelope, reshape(wk, {1, out_len}, s), s); + } + signal = + divide(signal, maximum(envelope, array(1e-8f, envelope.dtype()), s), s); + + // Drop the centering pad from the start. The end is bounded by `length` when + // given, otherwise by removing the centering pad from the end as well. The + // full reconstruction spans n_fft + (n_frames - 1) * hop before centering. + int sig_len = n_fft + (n_frames - 1) * hop_length; + int start = center ? n_fft / 2 : 0; + if (length_.has_value()) { + int length = length_.value(); + int stop = std::min(sig_len, start + length); + signal = slice(signal, {0, start}, {b, stop}, s); + int cur = signal.shape(-1); + if (cur < length) { + signal = + pad(signal, + {1}, + {0}, + {length - cur}, + array(0, signal.dtype()), + "constant", + s); + } + } else { + int stop = center ? sig_len - n_fft / 2 : sig_len; + signal = slice(signal, {0, start}, {b, stop}, s); + } + + Shape out_shape = lead; + out_shape.push_back(signal.shape(-1)); + return reshape(signal, out_shape, s); +} + } // namespace mlx::core::fft diff --git a/mlx/fft.h b/mlx/fft.h index f69b8aa2c4..f067ff49c5 100644 --- a/mlx/fft.h +++ b/mlx/fft.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include #include "array.h" @@ -212,6 +213,7 @@ inline array irfft2( StreamOrDevice s = {}) { return irfftn(a, axes, norm, s); } + /** Compute the discrete Fourier Transform sample frequencies. */ MLX_API array fftfreq(int n, double d = 1.0, StreamOrDevice s = {}); @@ -234,4 +236,30 @@ MLX_API array ifftshift(const array& a, StreamOrDevice s = {}); MLX_API array ifftshift(const array& a, const std::vector& axes, StreamOrDevice s = {}); +/** Compute the Short-Time Fourier Transform of the last axis of `x`. 
*/ +MLX_API array stft( + const array& x, + int n_fft = 2048, + const std::optional& hop_length = std::nullopt, + const std::optional& win_length = std::nullopt, + const std::optional& window = std::nullopt, + bool center = true, + const std::string& pad_mode = "reflect", + FFTNorm norm = FFTNorm::Backward, + const std::optional& onesided = std::nullopt, + StreamOrDevice s = {}); + +/** Compute the inverse Short-Time Fourier Transform. */ +MLX_API array istft( + const array& stft_matrix, + const std::optional& n_fft = std::nullopt, + const std::optional& hop_length = std::nullopt, + const std::optional& win_length = std::nullopt, + const std::optional& window = std::nullopt, + bool center = true, + FFTNorm norm = FFTNorm::Backward, + const std::optional& onesided = std::nullopt, + const std::optional& length = std::nullopt, + StreamOrDevice s = {}); + } // namespace mlx::core::fft diff --git a/python/src/fft.cpp b/python/src/fft.cpp index eb6d531c92..6a385cc88d 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -633,4 +633,124 @@ void init_fft(nb::module_& parent_module) { Returns: array: The inverse-shifted array with the same shape as the input. )pbdoc"); + m.def( + "stft", + [](const mx::array& x, + int n_fft, + const std::optional& hop_length, + const std::optional& win_length, + const std::optional& window, + bool center, + const std::string& pad_mode, + const std::string& norm, + const std::optional& onesided, + mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "stft"); + return mx::fft::stft( + x, + n_fft, + hop_length, + win_length, + window, + center, + pad_mode, + fft_norm, + onesided, + s); + }, + "x"_a, + "n_fft"_a = 2048, + "hop_length"_a = nb::none(), + "win_length"_a = nb::none(), + "window"_a = nb::none(), + "center"_a = true, + "pad_mode"_a = "reflect", + "norm"_a = "backward", + "onesided"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + Short-Time Fourier Transform. 
+ + Splits the last axis of ``x`` into overlapping frames of length + ``n_fft``, applies ``window``, and takes the Fourier Transform of each + frame. The output is always complex with shape ``(..., n_freq, n_frames)``. + + Args: + x (array): The input array. + n_fft (int, optional): Size of the Fourier Transform. Default is ``2048``. + hop_length (int, optional): Samples between frames. Default is ``n_fft // 4``. + win_length (int, optional): Size of the default window; unused when ``window`` is given. Default is ``n_fft``. + window (array, optional): A one-dimensional window, centered and + zero-padded to ``n_fft``. Default is a rectangular window. + center (bool, optional): Pad both ends so frames are centered. Default is ``True``. + pad_mode (str, optional): One of ``"reflect"``, ``"constant"``, or + ``"edge"``, used when ``center`` is ``True``. Default is ``"reflect"``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. + onesided (bool, optional): Return only non-negative frequencies. + Default is ``True`` for real input and ``False`` for complex input. + + Returns: + array: The STFT of ``x``. + )pbdoc"); + m.def( + "istft", + [](const mx::array& stft_matrix, + const std::optional& n_fft, + const std::optional& hop_length, + const std::optional& win_length, + const std::optional& window, + bool center, + const std::string& norm, + const std::optional& onesided, + const std::optional& length, + mx::StreamOrDevice s) { + auto fft_norm = parse_norm(norm, "istft"); + return mx::fft::istft( + stft_matrix, + n_fft, + hop_length, + win_length, + window, + center, + fft_norm, + onesided, + length, + s); + }, + "stft_matrix"_a, + "n_fft"_a = nb::none(), + "hop_length"_a = nb::none(), + "win_length"_a = nb::none(), + "window"_a = nb::none(), + "center"_a = true, + "norm"_a = "backward", + "onesided"_a = nb::none(), + "length"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + Inverse Short-Time Fourier Transform. 
+ + Inverts :func:`stft` by overlap-adding the windowed inverse transform + of each frame, normalized by the squared-window envelope. + + Args: + stft_matrix (array): The STFT with shape ``(..., n_freq, n_frames)``. + n_fft (int, optional): Size of the Fourier Transform. Default is + inferred from ``n_freq`` and ``onesided``. The onesided inference + always gives an even size; pass ``n_fft`` for an odd transform size. + hop_length (int, optional): Samples between frames. Default is ``n_fft // 4``. + win_length (int, optional): Size of the default window; unused when ``window`` is given. Default is ``n_fft``. + window (array, optional): A one-dimensional window. Default is a + rectangular window. + center (bool, optional): Whether the STFT was centered. Default is ``True``. + norm (str, optional): One of ``"backward"``, ``"ortho"``, or + ``"forward"``. Default is ``"backward"``. + onesided (bool, optional): Whether the STFT has only non-negative + frequencies. Default is ``True``. + length (int, optional): Trim or zero-pad the output to this length. + + Returns: + array: The reconstructed signal. 
+ )pbdoc"); } diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 9c03530b04..bbff46dba7 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -406,6 +406,121 @@ def g(x): dgdx = torch.func.grad(g)(torch.tensor(x)) self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4) + @unittest.skipIf(not has_torch, "requires PyTorch") + def test_stft(self): + def check(x_np, win_np, n_fft, hop, win_len, center, pad_mode, norm, one): + win_mx = None if win_np is None else mx.array(win_np) + win_pt = None if win_np is None else torch.tensor(win_np) + z_mx = mx.fft.stft( + mx.array(x_np), + n_fft=n_fft, + hop_length=hop, + win_length=win_len, + window=win_mx, + center=center, + pad_mode=pad_mode, + norm=norm, + onesided=one, + ) + z_pt = torch.stft( + torch.tensor(x_np), + n_fft=n_fft, + hop_length=hop, + win_length=win_len, + window=win_pt, + center=center, + pad_mode="replicate" if pad_mode == "edge" else pad_mode, + normalized=(norm == "ortho"), + onesided=one, + return_complex=True, + ) + np.testing.assert_allclose( + np.array(z_mx), z_pt.numpy(), atol=1e-4, rtol=1e-4 + ) + + np.random.seed(0) + x = np.random.randn(16000).astype(np.float32) + han = np.hanning(400).astype(np.float32) + h256 = np.hanning(256).astype(np.float32) + check(x, han, 400, 160, 400, True, "reflect", "backward", True) + check(x, han, 512, 128, 400, False, "constant", "backward", True) + check(x, han, 400, 160, 400, True, "edge", "backward", True) + check(x, han, 400, 160, 400, True, "reflect", "ortho", True) + check(x, h256, 256, 64, 256, True, "reflect", "backward", False) + check(x, None, 256, 64, 256, True, "reflect", "backward", True) + xb = np.random.randn(2, 16000).astype(np.float32) + check(xb, h256, 256, 64, 256, True, "reflect", "backward", True) + + @unittest.skipIf(not has_torch, "requires PyTorch") + def test_istft(self): + def check(x_np, win_np, n_fft, hop, center, norm, one, length): + win_mx = None if win_np is None else mx.array(win_np) 
+ win_pt = None if win_np is None else torch.tensor(win_np) + z = mx.fft.stft( + mx.array(x_np), + n_fft=n_fft, + hop_length=hop, + window=win_mx, + center=center, + norm=norm, + onesided=one, + ) + y_mx = mx.fft.istft( + z, + n_fft=n_fft, + hop_length=hop, + window=win_mx, + center=center, + norm=norm, + onesided=one, + length=length, + ) + y_pt = torch.istft( + torch.tensor(np.array(z)), + n_fft=n_fft, + hop_length=hop, + window=win_pt, + center=center, + normalized=(norm == "ortho"), + onesided=one, + length=length, + ) + np.testing.assert_allclose( + np.array(y_mx), y_pt.numpy(), atol=1e-4, rtol=1e-4 + ) + + np.random.seed(0) + x = np.random.randn(16000).astype(np.float32) + win = np.hanning(512).astype(np.float32) + check(x, win, 512, 128, True, "backward", True, 16000) + # n_fft not divisible by hop, no explicit length (length must still match) + check(x, win, 512, 160, True, "backward", True, None) + # center=False needs an edge-nonzero window for torch's NOLA check + check(x, None, 512, 160, False, "backward", True, None) + check(x, win, 512, 128, True, "ortho", True, 16000) + check(x, None, 512, 256, True, "backward", False, 16000) + # odd n_fft (passed explicitly, since onesided inference is even) + check( + x, + np.hanning(257).astype(np.float32), + 257, + 64, + True, + "backward", + True, + 16000, + ) + xb = np.random.randn(2, 16000).astype(np.float32) + check(xb, win, 512, 128, True, "backward", True, 16000) + + # norm="forward" has no torch.istft flag; check the round-trip directly + win_mx = mx.array(win) + z = mx.fft.stft( + mx.array(x), n_fft=512, hop_length=128, window=win_mx, norm="forward" + ) + y = mx.fft.istft(z, hop_length=128, window=win_mx, norm="forward", length=16000) + np.testing.assert_allclose(np.array(y)[1000:-1000], x[1000:-1000], atol=1e-4) + if __name__ == "__main__": mlx_tests.MLXTestRunner() diff --git a/tests/fft_tests.cpp b/tests/fft_tests.cpp index 7dc1b6e666..2039f937cb 100644 --- a/tests/fft_tests.cpp +++ 
b/tests/fft_tests.cpp @@ -393,3 +393,69 @@ TEST_CASE("test fftshift and ifftshift") { CHECK_THROWS_AS(fft::ifftshift(x, {3}), std::invalid_argument); CHECK_THROWS_AS(fft::ifftshift(x, {-5}), std::invalid_argument); } + +TEST_CASE("test stft and istft") { + auto win = ones({16}, float32); + + // Real, centered input -> (n_freq, n_frames). + auto x = astype(arange(64), float32); + auto z = fft::stft(x, 16, 8, 16, win, /* center = */ true); + CHECK_EQ(z.ndim(), 2); + CHECK_EQ(z.shape(0), 9); // n_fft / 2 + 1 + CHECK_EQ(z.shape(1), 9); // 1 + (64 + 16 - 16) / 8 + CHECK_EQ(z.dtype(), complex64); + + // Batched input -> (batch, n_freq, n_frames). + auto xb = reshape(astype(arange(2 * 64), float32), {2, 64}); + auto zb = fft::stft(xb, 16, 8, 16, win, true); + CHECK_EQ(zb.ndim(), 3); + CHECK_EQ(zb.shape(0), 2); + CHECK_EQ(zb.shape(1), 9); + + // Non-overlapping rectangular window reconstructs the signal exactly. + auto x2 = astype(arange(32), float32); + auto win2 = ones({8}, float32); + auto z2 = fft::stft(x2, 8, 8, 8, win2, /* center = */ false); + auto y2 = fft::istft( + z2, + /* n_fft = */ 8, + /* hop_length = */ 8, + /* win_length = */ 8, + win2, + /* center = */ false, + fft::FFTNorm::Backward, + /* onesided = */ true, + /* length = */ 32); + CHECK_EQ(y2.shape(0), 32); + CHECK(allclose(y2, x2, 1e-4, 1e-4).item()); + + // Overlapping window: round-trip recovers the fully covered interior. + auto x3 = astype(arange(64), float32); + auto z3 = fft::stft(x3, 16, 8, 16, win, /* center = */ false); + auto y3 = fft::istft(z3, 16, 8, 16, win, false, fft::FFTNorm::Backward, true); + CHECK_EQ(y3.shape(0), 64); // n_fft + (n_frames - 1) * hop = 16 + 6 * 8 + CHECK(allclose(slice(y3, {8}, {56}), slice(x3, {8}, {56}), 1e-4, 1e-4) + .item()); + + // n_fft not divisible by hop: length is n_fft + (n_frames - 1) * hop. 
+ auto w7 = ones({7}, float32); + auto z4 = fft::stft(x3, 7, 3, 7, w7, /* center = */ false); + auto y4 = fft::istft(z4, 7, 3, 7, w7, false, fft::FFTNorm::Backward, true); + CHECK_EQ(y4.shape(0), 64); // n_frames = 20 -> 7 + 19 * 3 + + // Two-sided transform keeps all frequency bins. + auto zf = fft::stft( + x, 16, 8, 16, win, true, "reflect", fft::FFTNorm::Backward, false); + CHECK_EQ(zf.shape(0), 16); + + // Edge padding must not raise (regression: edge mode on the last axis). + auto ze = fft::stft(x, 16, 8, 16, win, true, "edge"); + CHECK_EQ(ze.shape(0), 9); + + // Error cases. + CHECK_THROWS_AS(fft::stft(x, 0), std::invalid_argument); + CHECK_THROWS_AS( + fft::stft(x, 16, 8, 16, ones({32}, float32)), std::invalid_argument); + CHECK_THROWS_AS( + fft::stft(x, 16, 8, 16, win, true, "bogus"), std::invalid_argument); +}