Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 213 additions & 9 deletions python/mlx/nn/layers/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import operator
from functools import partial, reduce
from itertools import product
from math import ceil
from typing import Callable, Literal, Tuple, Union

import mlx.core as mx
Expand Down Expand Up @@ -51,6 +52,118 @@ def _linear_indices(N, scale, align_corners, dim, ndims):
)


def _aa_indices(N, scale, align_corners, dim, ndims, kernel_fn, kernel_radius):
    """Build the (index, weight) taps for antialiased 1-D resampling.

    The sample positions come from ``_scaled_indices``. When ``scale < 1``
    the kernel is stretched by ``1 / scale`` so that it covers more input
    samples; for ``scale >= 1`` the kernel keeps its native width. Taps
    that fall outside ``[0, N)`` receive zero weight and the remaining
    weights are renormalized so that they sum to one for every output
    position.

    Args:
        N: input size along this spatial dimension.
        scale: scale factor for this dimension.
        align_corners: forwarded unchanged to ``_scaled_indices``.
        dim: which spatial dimension is being resampled.
        ndims: total number of spatial dimensions.
        kernel_fn: callable mapping a non-negative distance (expressed in
            normalized filter coordinates) to a weight.
        kernel_radius: support radius of ``kernel_fn`` in normalized
            coordinates (the callers in this file pass 1.0 for the
            triangle kernel and 2.0 for the cubic kernel).

    Returns:
        A tuple of ``(idx, weight)`` pairs, one per tap offset. ``idx`` is
        clipped to ``[0, N - 1]`` and cast to ``mx.uint32``; ``weight`` is
        the normalized weight with one trailing axis appended.
    """
    centers = _scaled_indices(N, scale, align_corners, dim, ndims)

    # Only downscaling widens the filter; upscaling uses the native kernel.
    stretch = 1.0 / scale if scale < 1 else 1.0
    reach = ceil(kernel_radius * stretch) + 1

    # Candidate taps are integer offsets around floor(center).
    anchor = mx.floor(centers)
    taps = [anchor + offset for offset in range(1 - reach, reach)]

    # Distances are divided by the stretch so kernel_fn always sees
    # normalized coordinates; out-of-range taps contribute nothing.
    raw_weights = [
        mx.where(
            (tap >= 0) & (tap < N),
            kernel_fn(mx.abs(centers - tap) / stretch),
            0.0,
        )
        for tap in taps
    ]

    # Renormalize per output position; a non-positive total is replaced
    # by 1.0 so the division below stays well defined.
    total = sum(raw_weights)
    total = mx.where(total > 0, total, 1.0)

    return tuple(
        (
            # Clipping keeps zero-weight out-of-range taps addressable.
            mx.clip(tap, a_min=0, a_max=N - 1).astype(mx.uint32),
            mx.expand_dims(weight / total, -1),
        )
        for tap, weight in zip(taps, raw_weights)
    )


def _triangle_kernel(x):
    """Triangle (linear) filter kernel with support radius 1.

    ``x`` is treated as a non-negative distance: the weight is ``1 - x``,
    floored at zero.
    """
    falloff = 1.0 - x
    return mx.maximum(falloff, 0.0)


def _cubic_kernel(x):
    """Keys cubic convolution kernel with coefficient ``a = -0.5``.

    ``x`` is expected to be a non-negative distance: values up to 1 use
    the inner polynomial, values in (1, 2] use the outer polynomial, and
    anything larger gets zero weight. Support radius = 2.

    The non-antialiased cubic path (``_cubic_indices``) uses ``a = -0.75``
    instead. According to the original author's note, ``-0.5`` follows the
    PIL/Pillow convention that PyTorch uses when ``antialias=True``.
    """
    coeff = -0.5
    # Inner lobe: (a + 2) x^3 - (a + 3) x^2 + 1, written in Horner form.
    near = ((coeff + 2.0) * x - (coeff + 3.0)) * x * x + 1
    # Outer lobe: a (x^3 - 5 x^2 + 8 x - 4), written in Horner form.
    far = (((x - 5) * x + 8) * x - 4) * coeff
    beyond_inner = mx.where(x <= 2.0, far, 0.0)
    return mx.where(x <= 1.0, near, beyond_inner)


def _linear_aa_indices(N, scale, align_corners, dim, ndims):
    """Antialiased linear taps: ``_aa_indices`` with the triangle kernel.

    All positional arguments are forwarded unchanged; the kernel is
    ``_triangle_kernel`` with a support radius of 1.0.
    """
    kernel_options = {"kernel_fn": _triangle_kernel, "kernel_radius": 1.0}
    return _aa_indices(N, scale, align_corners, dim, ndims, **kernel_options)


def _cubic_aa_indices(N, scale, align_corners, dim, ndims):
    """Antialiased cubic taps: ``_aa_indices`` with the cubic kernel.

    All positional arguments are forwarded unchanged; the kernel is
    ``_cubic_kernel`` (coefficient ``a = -0.5``) with a support radius of
    2.0.

    Note: the non-antialiased cubic path (``_cubic_indices``) uses
    ``a = -0.75``, so switching antialiasing on changes the shape of the
    interpolant and not only the filter width. Per the original author's
    note, the two coefficients follow the PIL and OpenCV conventions
    respectively, mirroring PyTorch.
    """
    kernel_options = {"kernel_fn": _cubic_kernel, "kernel_radius": 2.0}
    return _aa_indices(N, scale, align_corners, dim, ndims, **kernel_options)


def _cubic_indices(N, scale, align_corners, dim, ndims):
indices = _scaled_indices(N, scale, align_corners, dim, ndims)
indices_l1 = mx.floor(indices)
Expand All @@ -60,8 +173,9 @@ def _cubic_indices(N, scale, align_corners, dim, ndims):

@partial(mx.compile, shapeless=True)
def _get_weight(ind, grid, dist):
# PyTorch uses -0.5 for antialiasing=true (compatibility with PIL)
# and uses -0.75 for antialiasing=false (compatibility with OpenCV)
# a=-0.75 (OpenCV convention) for non-antialiased cubic.
# When antialias=True, _cubic_aa_indices uses a=-0.5 (PIL convention)
# via _cubic_kernel instead.
a = -0.75
x = mx.abs(ind - grid)
if dist == 1:
Expand Down Expand Up @@ -89,6 +203,14 @@ def _get_weight(ind, grid, dist):
)


def _validate_antialias_align_corners(align_corners, antialias):
    """Reject the unsupported ``antialias`` + ``align_corners`` combination.

    Args:
        align_corners: the align_corners flag requested by the caller.
        antialias: the antialias flag requested by the caller.

    Raises:
        ValueError: if both flags are truthy.
    """
    # Any combination with at least one flag off is acceptable.
    if not antialias or not align_corners:
        return
    raise ValueError(
        "[Upsample] antialias=True with align_corners=True is not "
        "supported. Use align_corners=False for antialiased interpolation."
    )


def upsample_nearest(x: mx.array, scale_factor: Tuple):
dims = x.ndim - 2
if dims != len(scale_factor):
Expand Down Expand Up @@ -145,7 +267,41 @@ def _interpolate(
return sum(wi * xi for wi, xi in zip(weights, samples))


def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
def _interpolate_separable(
    x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False
):
    """Resample ``x`` one spatial axis at a time.

    Every axis between the first and the last of ``x`` is treated as a
    spatial axis. For each of them, ``indices_fn`` supplies a sequence of
    ``(idx, weight)`` taps; the input is gathered along that axis for each
    tap, scaled by the tap weight, and the weighted gathers are summed.
    The result of one axis feeds the next.

    Args:
        x: input array; its first and last axes are left untouched.
        scale_factor: one scale per spatial axis.
        indices_fn: callable invoked as
            ``indices_fn(size, scale, align_corners, dim, ndims)`` that
            returns an iterable of ``(idx, weight)`` pairs.
        align_corners: forwarded unchanged to ``indices_fn``.

    Returns:
        The resampled array.

    Raises:
        ValueError: if the number of scales differs from the number of
            spatial axes.
    """
    spatial_rank = x.ndim - 2
    if len(scale_factor) != spatial_rank:
        raise ValueError("A scale needs to be provided for each spatial dimension")

    # Sizes are read from the original input; each pass only resizes its
    # own axis, so the sizes of the axes still to be processed are unchanged.
    spatial_sizes = x.shape[1:-1]

    result = x
    for axis, (size, scale) in enumerate(zip(spatial_sizes, scale_factor), start=1):
        taps = indices_fn(size, scale, align_corners, axis - 1, spatial_rank)
        weighted = [
            mx.take(result, idx.reshape(-1), axis=axis) * weight
            for idx, weight in taps
        ]
        result = sum(weighted)

    return result


def upsample_linear(
x: mx.array,
scale_factor: Tuple,
align_corners: bool = False,
antialias: bool = False,
):
_validate_antialias_align_corners(align_corners, antialias)
if antialias:
return _interpolate_separable(
x=x,
scale_factor=scale_factor,
indices_fn=_linear_aa_indices,
align_corners=align_corners,
)
return _interpolate(
x=x,
scale_factor=scale_factor,
Expand All @@ -154,7 +310,20 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals
)


def upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool = False):
def upsample_cubic(
x: mx.array,
scale_factor: Tuple,
align_corners: bool = False,
antialias: bool = False,
):
_validate_antialias_align_corners(align_corners, antialias)
if antialias:
return _interpolate_separable(
x=x,
scale_factor=scale_factor,
indices_fn=_cubic_aa_indices,
align_corners=align_corners,
)
return _interpolate(
x=x,
scale_factor=scale_factor,
Expand Down Expand Up @@ -185,6 +354,23 @@ class Upsample(Module):
``align_corners=True`` then the top and left edge of the input and
output will be matching as will the bottom right edge.

.. note::
When ``antialias=True`` is used with ``"linear"`` or ``"cubic"`` mode,
an antialiased filter is applied during downsampling (scale factor < 1),
producing smoother results by avoiding aliasing artifacts. For 2D
integer-ratio downscales with ``align_corners=False``, this matches the
behavior of PyTorch's ``F.interpolate(antialias=True)``. Non-integer
scale factors are supported but may differ from PyTorch because of
existing index-selection differences.

For ``"cubic"`` mode, enabling ``antialias`` also changes the cubic
kernel coefficient from ``a=-0.75`` (OpenCV convention) to ``a=-0.5``
(PIL/Pillow convention), matching PyTorch's behavior. This affects the
interpolant shape, not just the filter width.

``antialias=True`` with ``align_corners=True`` is not supported and
will raise a ``ValueError``.

Parameters:
scale_factor (float or tuple): The multiplier for the spatial size.
If a ``float`` is provided, it is the multiplier for all spatial dimensions.
Expand All @@ -195,6 +381,11 @@ class Upsample(Module):
align_corners (bool, optional): Changes the way the corners are treated
during ``"linear"`` and ``"cubic"`` upsampling. See the note above and the
examples below for more details. Default: ``False``.
antialias (bool, optional): If ``True``, apply an antialiasing filter
when downsampling with ``"linear"`` or ``"cubic"`` mode. For
``"cubic"`` mode this also switches the kernel coefficient to
``a=-0.5``. Not supported with ``"nearest"`` mode or with
``align_corners=True``. Default: ``False``.

Examples:
>>> import mlx.core as mx
Expand Down Expand Up @@ -230,22 +421,35 @@ def __init__(
scale_factor: Union[float, Tuple],
mode: Literal["nearest", "linear", "cubic"] = "nearest",
align_corners: bool = False,
antialias: bool = False,
):
super().__init__()
if mode not in ["nearest", "linear", "cubic"]:
raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}")
if antialias and mode == "nearest":
raise ValueError(
"[Upsample] Antialiasing is not supported for nearest neighbor upsampling"
)
if isinstance(scale_factor, (list, tuple)):
self.scale_factor = tuple(map(float, scale_factor))
scale_factor = tuple(map(float, scale_factor))
else:
self.scale_factor = float(scale_factor)
scale_factor = float(scale_factor)

_validate_antialias_align_corners(align_corners, antialias)

self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
self.antialias = antialias

def _extra_repr(self) -> str:
    """Return the argument summary shown in the module's repr.

    The summary always lists ``scale_factor``, ``mode`` and
    ``align_corners``. ``antialias`` is appended only when it is enabled,
    so the output for the default configuration is unchanged.
    """
    repr_str = (
        f"scale_factor={self.scale_factor}, mode={self.mode!r}, "
        f"align_corners={self.align_corners}"
    )
    # Mention antialias only when it deviates from the default (False).
    if self.antialias:
        repr_str += ", antialias=True"
    return repr_str

def __call__(self, x: mx.array) -> mx.array:
dims = x.ndim - 2
Expand All @@ -270,8 +474,8 @@ def __call__(self, x: mx.array) -> mx.array:
if self.mode == "nearest":
return upsample_nearest(x, scale_factor)
elif self.mode == "linear":
return upsample_linear(x, scale_factor, self.align_corners)
return upsample_linear(x, scale_factor, self.align_corners, self.antialias)
elif self.mode == "cubic":
return upsample_cubic(x, scale_factor, self.align_corners)
return upsample_cubic(x, scale_factor, self.align_corners, self.antialias)
else:
raise Exception(f"Unknown interpolation mode: {self.mode}")
Loading
Loading