Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/python/fft.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ FFT
irfft2
rfftn
irfftn
stft
istft
fftfreq
rfftfreq
fftshift
Expand Down
265 changes: 265 additions & 0 deletions mlx/fft.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
Expand Down Expand Up @@ -338,4 +339,268 @@ array rfftfreq(int n, double d /* = 1.0 */, StreamOrDevice s /* = {} */) {
return multiply(freqs, scale, s);
}

namespace {

// Pad the last axis; mx::pad has no reflect mode, so build it from take.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would #3608 help?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, its reflect mode matches what this helper does and also handles pad >= n, so I can drop the helper and call pad directly once it lands. Want me to rebase on it after it merges, or keep the local version for now?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for verifying, let's just keep the local version for now.

array pad_last_axis(
const array& a,
int pad_width,
const std::string& mode,
StreamOrDevice s) {
if (pad_width <= 0) {
return a;
}
int ax = a.ndim() - 1;
if (mode == "constant") {
return pad(
a, {ax}, {pad_width}, {pad_width}, array(0, a.dtype()), "constant", s);
}
if (mode == "edge") {
int n = a.shape(ax);
auto left = take(a, full(Shape{pad_width}, array(0), int32, s), ax, s);
auto right = take(a, full(Shape{pad_width}, array(n - 1), int32, s), ax, s);
return concatenate({left, a, right}, ax, s);
}
if (mode == "reflect") {
int n = a.shape(ax);
if (pad_width >= n) {
std::ostringstream msg;
msg << "[stft] Reflect padding (" << pad_width
<< ") requires an input longer than the padding along the last axis ("
<< n << ").";
throw std::invalid_argument(msg.str());
}
auto left = take(a, arange(pad_width, 0, -1, s), ax, s);
auto right = take(a, arange(n - 2, n - 2 - pad_width, -1, s), ax, s);
return concatenate({left, a, right}, ax, s);
}
std::ostringstream msg;
msg << "[stft] Invalid pad_mode '" << mode
<< "'. Expected one of {'constant', 'reflect', 'edge'}.";
throw std::invalid_argument(msg.str());
}

// Center a window of length <= n_fft inside an n_fft-length frame.
array prepare_window(
const std::optional<array>& window,
int win_length,
int n_fft,
const std::string& op,
StreamOrDevice s) {
array win =
window.has_value() ? window.value() : ones({win_length}, float32, s);
if (win.ndim() != 1) {
throw std::invalid_argument(
"[" + op + "] window must be a one-dimensional array.");
}
int wl = win.shape(0);
if (wl > n_fft) {
throw std::invalid_argument(
"[" + op + "] window length must not exceed n_fft.");
}
if (wl < n_fft) {
int pad_left = (n_fft - wl) / 2;
int pad_right = n_fft - wl - pad_left;
win =
pad(win,
{0},
{pad_left},
{pad_right},
array(0, win.dtype()),
"constant",
s);
}
return win;
}

} // namespace

array stft(
const array& x_in,
int n_fft /* = 2048 */,
const std::optional<int>& hop_length_ /* = std::nullopt */,
const std::optional<int>& win_length_ /* = std::nullopt */,
const std::optional<array>& window /* = std::nullopt */,
bool center /* = true */,
const std::string& pad_mode /* = "reflect" */,
FFTNorm norm /* = FFTNorm::Backward */,
const std::optional<bool>& onesided_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
if (n_fft <= 0) {
throw std::invalid_argument("[stft] n_fft must be positive.");
}
if (x_in.ndim() < 1) {
throw std::invalid_argument(
"[stft] Input must have at least one dimension.");
}
int hop_length = hop_length_.value_or(n_fft / 4);
int win_length = win_length_.value_or(n_fft);
if (hop_length <= 0) {
throw std::invalid_argument("[stft] hop_length must be positive.");
}
bool onesided = onesided_.value_or(x_in.dtype() != complex64);

array win = prepare_window(window, win_length, n_fft, "stft", s);

// Collapse leading dims into a single batch axis.
Shape lead = x_in.shape();
lead.pop_back();
int t = x_in.shape(-1);
array x = reshape(x_in, {-1, t}, s);

if (center) {
x = pad_last_axis(x, n_fft / 2, pad_mode, s);
}
int tp = x.shape(-1);
if (tp < n_fft) {
throw std::invalid_argument(
"[stft] Input is too short for the given n_fft.");
}
int n_frames = 1 + (tp - n_fft) / hop_length;
int b = x.shape(0);

// Overlapping frames via as_strided (element strides: tp, hop, 1).
array frames = as_strided(
x,
{b, n_frames, n_fft},
Strides{static_cast<int64_t>(tp), static_cast<int64_t>(hop_length), 1},
0,
s);
frames = multiply(frames, win, s);

array spec = onesided ? rfftn(frames, Shape{n_fft}, {-1}, norm, s)
: fftn(frames, Shape{n_fft}, {-1}, norm, s);
spec = swapaxes(spec, -1, -2, s);

Shape out_shape = lead;
out_shape.push_back(spec.shape(-2));
out_shape.push_back(n_frames);
return reshape(spec, out_shape, s);
}

array istft(
const array& stft_matrix,
const std::optional<int>& n_fft_ /* = std::nullopt */,
const std::optional<int>& hop_length_ /* = std::nullopt */,
const std::optional<int>& win_length_ /* = std::nullopt */,
const std::optional<array>& window /* = std::nullopt */,
bool center /* = true */,
FFTNorm norm /* = FFTNorm::Backward */,
const std::optional<bool>& onesided_ /* = std::nullopt */,
const std::optional<int>& length_ /* = std::nullopt */,
StreamOrDevice s /* = {} */) {
if (stft_matrix.ndim() < 2) {
throw std::invalid_argument(
"[istft] Input must have at least two dimensions (freq, frames).");
}
bool onesided = onesided_.value_or(true);
int n_freq = stft_matrix.shape(-2);
int n_fft = n_fft_.value_or(onesided ? (n_freq - 1) * 2 : n_freq);
if (n_fft <= 0) {
throw std::invalid_argument("[istft] n_fft must be positive.");
}
int hop_length = hop_length_.value_or(n_fft / 4);
int win_length = win_length_.value_or(n_fft);
if (hop_length <= 0) {
throw std::invalid_argument("[istft] hop_length must be positive.");
}

array win = prepare_window(window, win_length, n_fft, "istft", s);

Shape lead = stft_matrix.shape();
int n_frames = lead.back();
lead.pop_back();
lead.pop_back();
array z = reshape(stft_matrix, {-1, n_freq, n_frames}, s);
int b = z.shape(0);
z = swapaxes(z, -1, -2, s);

array frames = onesided ? irfftn(z, Shape{n_fft}, {-1}, norm, s)
: real(ifftn(z, Shape{n_fft}, {-1}, norm, s), s);
frames = multiply(frames, win, s);

// Overlap-add: sub-block k of frame i lands on output block i + k, so
// summing seg = ceil(n_fft / hop) shifted copies is independent of n_frames.
int seg = (n_fft + hop_length - 1) / hop_length;
int wp = seg * hop_length;
if (wp > n_fft) {
frames =
pad(frames,
{2},
{0},
{wp - n_fft},
array(0, frames.dtype()),
"constant",
s);
}
array c = reshape(frames, {b, n_frames, seg, hop_length}, s);

array win_sq = square(win, s);
if (wp > n_fft) {
win_sq =
pad(win_sq,
{0},
{0},
{wp - n_fft},
array(0, win_sq.dtype()),
"constant",
s);
}
array cw = broadcast_to(
reshape(win_sq, {1, 1, seg, hop_length}, s),
{1, n_frames, seg, hop_length},
s);

int out_blocks = n_frames + seg - 1;
int out_len = out_blocks * hop_length;
array signal = zeros({b, out_len}, frames.dtype(), s);
array envelope = zeros({1, out_len}, frames.dtype(), s);
for (int k = 0; k < seg; ++k) {
array ck = reshape(
slice(c, {0, 0, k, 0}, {b, n_frames, k + 1, hop_length}, s),
{b, n_frames, hop_length},
s);
ck = pad(ck, {1}, {k}, {seg - 1 - k}, array(0, ck.dtype()), "constant", s);
signal = add(signal, reshape(ck, {b, out_len}, s), s);

array wk = reshape(
slice(cw, {0, 0, k, 0}, {1, n_frames, k + 1, hop_length}, s),
{1, n_frames, hop_length},
s);
wk = pad(wk, {1}, {k}, {seg - 1 - k}, array(0, wk.dtype()), "constant", s);
envelope = add(envelope, reshape(wk, {1, out_len}, s), s);
}
signal =
divide(signal, maximum(envelope, array(1e-8f, envelope.dtype()), s), s);

// Drop the centering pad from the start. The end is bounded by `length` when
// given, otherwise by removing the centering pad from the end as well. The
// full reconstruction spans n_fft + (n_frames - 1) * hop before centering.
int sig_len = n_fft + (n_frames - 1) * hop_length;
int start = center ? n_fft / 2 : 0;
if (length_.has_value()) {
int length = length_.value();
int stop = std::min(sig_len, start + length);
signal = slice(signal, {0, start}, {b, stop}, s);
int cur = signal.shape(-1);
if (cur < length) {
signal =
pad(signal,
{1},
{0},
{length - cur},
array(0, signal.dtype()),
"constant",
s);
}
} else {
int stop = center ? sig_len - n_fft / 2 : sig_len;
signal = slice(signal, {0, start}, {b, stop}, s);
}

Shape out_shape = lead;
out_shape.push_back(signal.shape(-1));
return reshape(signal, out_shape, s);
}

} // namespace mlx::core::fft
28 changes: 28 additions & 0 deletions mlx/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#pragma once

#include <cstdint>
#include <optional>
#include <variant>

#include "array.h"
Expand Down Expand Up @@ -212,6 +213,7 @@ inline array irfft2(
StreamOrDevice s = {}) {
return irfftn(a, axes, norm, s);
}

/** Compute the discrete Fourier Transform sample frequencies. */
MLX_API array fftfreq(int n, double d = 1.0, StreamOrDevice s = {});

Expand All @@ -234,4 +236,30 @@ MLX_API array ifftshift(const array& a, StreamOrDevice s = {});
MLX_API array
ifftshift(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});

/** Compute the Short-Time Fourier Transform of the last axis of `x`. */
MLX_API array stft(
const array& x,
int n_fft = 2048,
const std::optional<int>& hop_length = std::nullopt,
const std::optional<int>& win_length = std::nullopt,
const std::optional<array>& window = std::nullopt,
bool center = true,
const std::string& pad_mode = "reflect",
FFTNorm norm = FFTNorm::Backward,
const std::optional<bool>& onesided = std::nullopt,
StreamOrDevice s = {});

/** Compute the inverse Short-Time Fourier Transform. */
MLX_API array istft(
Comment thread
zcbenz marked this conversation as resolved.
const array& stft_matrix,
const std::optional<int>& n_fft = std::nullopt,
const std::optional<int>& hop_length = std::nullopt,
const std::optional<int>& win_length = std::nullopt,
const std::optional<array>& window = std::nullopt,
bool center = true,
FFTNorm norm = FFTNorm::Backward,
const std::optional<bool>& onesided = std::nullopt,
const std::optional<int>& length = std::nullopt,
StreamOrDevice s = {});

} // namespace mlx::core::fft
Loading
Loading