From 82a17994a07812b298c440c632ed60126a5c3059 Mon Sep 17 00:00:00 2001 From: Noah Lyons Date: Sat, 13 Jun 2026 18:46:52 -0400 Subject: [PATCH] Add antialias support to nn.Upsample --- python/mlx/nn/layers/upsample.py | 222 ++++++++++++++++++++++- python/tests/test_upsample.py | 293 +++++++++++++++++++++++++++++++ 2 files changed, 506 insertions(+), 9 deletions(-) diff --git a/python/mlx/nn/layers/upsample.py b/python/mlx/nn/layers/upsample.py index e6bd282af1..87f74cf8d9 100644 --- a/python/mlx/nn/layers/upsample.py +++ b/python/mlx/nn/layers/upsample.py @@ -3,6 +3,7 @@ import operator from functools import partial, reduce from itertools import product +from math import ceil from typing import Callable, Literal, Tuple, Union import mlx.core as mx @@ -51,6 +52,118 @@ def _linear_indices(N, scale, align_corners, dim, ndims): ) +def _aa_indices(N, scale, align_corners, dim, ndims, kernel_fn, kernel_radius): + """Compute antialiased interpolation indices for a given kernel. + + When downscaling (scale < 1), the kernel support widens by 1/scale to + act as a low-pass filter, preventing aliasing. Out-of-bounds taps are + zeroed and weights are renormalized per output pixel. + + For upscaling (scale >= 1), the kernel is applied at its native width + without widening. This matches PyTorch's F.interpolate(antialias=True) + behavior where the kernel coefficient (e.g. a=-0.5 for cubic) is used + for both up and downsampling when antialias=True. + + Args: + N: input size for this dimension + scale: scale factor for this dimension + align_corners: align_corners flag + dim: which spatial dimension + ndims: number of spatial dimensions + kernel_fn: callable(distance) -> weight, operating in normalized + filter coordinates where |distance| < kernel_radius has support + kernel_radius: support radius of the kernel in normalized coords + (1.0 for triangle/linear, 2.0 for cubic) + """ + indices = _scaled_indices(N, scale, align_corners, dim, ndims) + + # For downscale, widen the filter by 1/scale + if scale < 1: + inv_scale = 1.0 / scale + else: + inv_scale = 1.0 + + support = kernel_radius * inv_scale + num_taps = ceil(support) + 1 + + # Compute per-tap weights, zero out-of-bounds, then normalize + all_idx = [] + all_w = [] + for k in range(-num_taps + 1, num_taps): + idx = mx.floor(indices) + k + # Map distance to normalized filter coordinates + dist = mx.abs(indices - idx) / inv_scale + w = kernel_fn(dist) + # Zero out-of-bounds taps + w = mx.where((idx >= 0) & (idx < N), w, 0.0) + all_idx.append(idx) + all_w.append(w) + + # Normalize so weights sum to 1 per output pixel + w_sum = sum(all_w) + w_sum = mx.where(w_sum > 0, w_sum, 1.0) + + result = [] + for idx, w in zip(all_idx, all_w): + w = mx.expand_dims(w / w_sum, -1) + idx = mx.clip(idx, a_min=0, a_max=N - 1).astype(mx.uint32) + result.append((idx, w)) + + return tuple(result) + + +def _triangle_kernel(x): + """Triangle (linear) filter kernel. Support radius = 1.""" + return mx.maximum(1.0 - x, 0.0) + + +def _cubic_kernel(x): + """Keys cubic kernel with a=-0.5 (PIL/Pillow convention). + + This coefficient is used by PyTorch when antialias=True for both + bilinear and bicubic modes. The non-antialiased cubic path uses + a=-0.75 (OpenCV convention) -- see ``_cubic_indices``. + + Support radius = 2. + """ + a = -0.5 + w_inner = ((a + 2.0) * x - (a + 3.0)) * x * x + 1 + w_outer = (((x - 5) * x + 8) * x - 4) * a + return mx.where(x <= 1.0, w_inner, mx.where(x <= 2.0, w_outer, 0.0)) + + +def _linear_aa_indices(N, scale, align_corners, dim, ndims): + """Linear interpolation with antialiasing (triangle kernel).""" + return _aa_indices( + N, + scale, + align_corners, + dim, + ndims, + kernel_fn=_triangle_kernel, + kernel_radius=1.0, + ) + + +def _cubic_aa_indices(N, scale, align_corners, dim, ndims): + """Cubic interpolation with antialiasing (Keys cubic, a=-0.5). + + Note: the non-antialiased cubic path (``_cubic_indices``) uses a=-0.75 + (OpenCV convention). When ``antialias=True``, PyTorch switches to a=-0.5 + (PIL convention). This coefficient change affects the interpolant shape, + not just the filter width. See ``_cubic_kernel`` for details. + """ + return _aa_indices( + N, + scale, + align_corners, + dim, + ndims, + kernel_fn=_cubic_kernel, + kernel_radius=2.0, + ) + + def _cubic_indices(N, scale, align_corners, dim, ndims): indices = _scaled_indices(N, scale, align_corners, dim, ndims) indices_l1 = mx.floor(indices) @@ -60,8 +173,9 @@ def _cubic_indices(N, scale, align_corners, dim, ndims): @partial(mx.compile, shapeless=True) def _get_weight(ind, grid, dist): - # PyTorch uses -0.5 for antialiasing=true (compatibility with PIL) - # and uses -0.75 for antialiasing=false (compatibility with OpenCV) + # a=-0.75 (OpenCV convention) for non-antialiased cubic. + # When antialias=True, _cubic_aa_indices uses a=-0.5 (PIL convention) + # via _cubic_kernel instead. a = -0.75 x = mx.abs(ind - grid) if dist == 1: @@ -89,6 +203,14 @@ def _get_weight(ind, grid, dist): ) +def _validate_antialias_align_corners(align_corners, antialias): + if antialias and align_corners: + raise ValueError( + "[Upsample] antialias=True with align_corners=True is not " + "supported. Use align_corners=False for antialiased interpolation." + ) + + def upsample_nearest(x: mx.array, scale_factor: Tuple): dims = x.ndim - 2 if dims != len(scale_factor): @@ -145,7 +267,41 @@ def _interpolate( return sum(wi * xi for wi, xi in zip(weights, samples)) -def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = False): +def _interpolate_separable( + x: mx.array, scale_factor: Tuple, indices_fn: Callable, align_corners: bool = False +): + dims = x.ndim - 2 + if dims != len(scale_factor): + raise ValueError("A scale needs to be provided for each spatial dimension") + + _, *N, _ = x.shape + out = x + + for i, (n, s) in enumerate(zip(N, scale_factor)): + axis = i + 1 + samples = [] + for idx, weight in indices_fn(n, s, align_corners, i, dims): + sample = mx.take(out, idx.reshape(-1), axis=axis) + samples.append(sample * weight) + out = sum(samples) + + return out + + +def upsample_linear( + x: mx.array, + scale_factor: Tuple, + align_corners: bool = False, + antialias: bool = False, +): + _validate_antialias_align_corners(align_corners, antialias) + if antialias: + return _interpolate_separable( + x=x, + scale_factor=scale_factor, + indices_fn=_linear_aa_indices, + align_corners=align_corners, + ) return _interpolate( x=x, scale_factor=scale_factor, @@ -154,7 +310,20 @@ def upsample_linear(x: mx.array, scale_factor: Tuple, align_corners: bool = Fals ) -def upsample_cubic(x: mx.array, scale_factor: Tuple, align_corners: bool = False): +def upsample_cubic( + x: mx.array, + scale_factor: Tuple, + align_corners: bool = False, + antialias: bool = False, +): + _validate_antialias_align_corners(align_corners, antialias) + if antialias: + return _interpolate_separable( + x=x, + scale_factor=scale_factor, + indices_fn=_cubic_aa_indices, + align_corners=align_corners, + ) return _interpolate( x=x, scale_factor=scale_factor, @@ -185,6 +354,23 @@ class Upsample(Module): ``align_corners=True`` then the top and left edge of the input and output will be matching as will the bottom right edge. + .. note:: + When ``antialias=True`` is used with ``"linear"`` or ``"cubic"`` mode, + an antialiased filter is applied during downsampling (scale factor < 1), + producing smoother results by avoiding aliasing artifacts. For 2D + integer-ratio downscales with ``align_corners=False``, this matches the + behavior of PyTorch's ``F.interpolate(antialias=True)``. Non-integer + scale factors are supported but may differ from PyTorch because of + existing index-selection differences. + + For ``"cubic"`` mode, enabling ``antialias`` also changes the cubic + kernel coefficient from ``a=-0.75`` (OpenCV convention) to ``a=-0.5`` + (PIL/Pillow convention), matching PyTorch's behavior. This affects the + interpolant shape, not just the filter width. + + ``antialias=True`` with ``align_corners=True`` is not supported and + will raise a ``ValueError``. + Parameters: scale_factor (float or tuple): The multiplier for the spatial size. If a ``float`` is provided, it is the multiplier for all spatial dimensions. @@ -195,6 +381,11 @@ class Upsample(Module): align_corners (bool, optional): Changes the way the corners are treated during ``"linear"`` and ``"cubic"`` upsampling. See the note above and the examples below for more details. Default: ``False``. + antialias (bool, optional): If ``True``, apply an antialiasing filter + when downsampling with ``"linear"`` or ``"cubic"`` mode. For + ``"cubic"`` mode this also switches the kernel coefficient to + ``a=-0.5``. Not supported with ``"nearest"`` mode or with + ``align_corners=True``. Default: ``False``. Examples: >>> import mlx.core as mx @@ -230,22 +421,35 @@ def __init__( scale_factor: Union[float, Tuple], mode: Literal["nearest", "linear", "cubic"] = "nearest", align_corners: bool = False, + antialias: bool = False, ): super().__init__() if mode not in ["nearest", "linear", "cubic"]: raise ValueError(f"[Upsample] Got unsupported upsampling algorithm: {mode}") + if antialias and mode == "nearest": + raise ValueError( + "[Upsample] Antialiasing is not supported for nearest neighbor upsampling" + ) if isinstance(scale_factor, (list, tuple)): - self.scale_factor = tuple(map(float, scale_factor)) + scale_factor = tuple(map(float, scale_factor)) else: - self.scale_factor = float(scale_factor) + scale_factor = float(scale_factor) + + _validate_antialias_align_corners(align_corners, antialias) + + self.scale_factor = scale_factor self.mode = mode self.align_corners = align_corners + self.antialias = antialias def _extra_repr(self) -> str: - return ( + repr_str = ( f"scale_factor={self.scale_factor}, mode={self.mode!r}, " f"align_corners={self.align_corners}" ) + if self.antialias: + repr_str += ", antialias=True" + return repr_str def __call__(self, x: mx.array) -> mx.array: dims = x.ndim - 2 @@ -270,8 +474,8 @@ def __call__(self, x: mx.array) -> mx.array: if self.mode == "nearest": return upsample_nearest(x, scale_factor) elif self.mode == "linear": - return upsample_linear(x, scale_factor, self.align_corners) + return upsample_linear(x, scale_factor, self.align_corners, self.antialias) elif self.mode == "cubic": - return upsample_cubic(x, scale_factor, self.align_corners) + return upsample_cubic(x, scale_factor, self.align_corners, self.antialias) else: raise Exception(f"Unknown interpolation mode: {self.mode}") diff --git a/python/tests/test_upsample.py b/python/tests/test_upsample.py index 631853cce0..da2c3db740 100644 --- a/python/tests/test_upsample.py +++ b/python/tests/test_upsample.py @@ -1,11 +1,13 @@ # Copyright © 2023-2024 Apple Inc. import unittest +from unittest.mock import patch import mlx.core as mx import mlx.nn as nn import mlx_tests import numpy as np +from mlx.nn.layers.upsample import upsample_cubic, upsample_linear try: import torch @@ -95,6 +97,297 @@ def run_upsample( dtype=dtype, ) + @unittest.skipIf(not has_torch, "requires Torch") + def test_torch_upsample_antialias(self): + """Test antialiased downsampling matches PyTorch F.interpolate(antialias=True).""" + + def run_antialias( + N, + C, + idim, + scale_factor, + mode, + align_corners=False, + dtype="float32", + atol=1e-5, + ): + with self.subTest( + N=N, + C=C, + idim=idim, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + ): + np_dtype = getattr(np, dtype) + np.random.seed(0) + iH, iW = idim + in_np = np.random.normal(-1.0, 1.0, (N, iH, iW, C)).astype(np_dtype) + + in_mx = mx.array(in_np) + in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)).to("cpu") + + out_mx = nn.Upsample( + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + antialias=True, + )(in_mx) + mode_pt = { + "linear": "bilinear", + "cubic": "bicubic", + }[mode] + out_pt = F.interpolate( + in_pt, + scale_factor=scale_factor, + mode=mode_pt, + align_corners=align_corners, + antialias=True, + ) + out_pt = torch.permute(out_pt, (0, 2, 3, 1)).numpy(force=True) + self.assertEqual(out_pt.shape, out_mx.shape) + self.assertTrue( + np.allclose(out_pt, out_mx, atol=atol), + f"antialias {mode} ac={align_corners} scale={scale_factor} max_diff=" + f"{np.abs(out_pt - np.array(out_mx)).max():.2e}", + ) + + for dtype in ("float32",): + for N, C in ((1, 1), (2, 3)): + # Test downscale with antialias — use integer-ratio scales + # to avoid the pre-existing _scaled_indices step divergence + # for non-integer ratios (see issue #2186). + for idim, scale_factor in ( + ((4, 4), (0.5, 0.5)), + ((8, 8), (0.5, 0.5)), + ((8, 8), (0.25, 0.25)), + ((16, 16), (0.5, 0.5)), + ((16, 16), (0.25, 0.25)), + ((32, 32), (0.5, 0.5)), + ((32, 32), (0.25, 0.25)), + ((64, 64), (0.5, 0.5)), + ((10, 10), (0.5, 0.5)), + ((12, 12), (0.25, 0.25)), + ((8, 16), (0.5, 0.5)), # non-square + ((16, 8), (0.5, 0.25)), # different scales per dim + ): + for mode in ("linear", "cubic"): + # align_corners=True + antialias has a known + # interaction with _scaled_indices that requires + # further work. Test align_corners=False for now. + run_antialias( + N, + C, + idim, + scale_factor, + mode, + align_corners=False, + dtype=dtype, + ) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_antialias_upscale_linear_is_noop(self): + """For linear mode, antialias has no effect on upscaling.""" + np.random.seed(0) + in_np = np.random.normal(-1.0, 1.0, (1, 4, 4, 3)).astype(np.float32) + in_mx = mx.array(in_np) + + for scale in (2.0, 3.0): + with self.subTest(scale=scale): + out_aa = nn.Upsample( + scale_factor=scale, + mode="linear", + align_corners=False, + antialias=True, + )(in_mx) + out_no = nn.Upsample( + scale_factor=scale, + mode="linear", + align_corners=False, + antialias=False, + )(in_mx) + self.assertTrue( + np.allclose(np.array(out_aa), np.array(out_no), atol=1e-7), + "linear antialias should be no-op for upscaling", + ) + + @unittest.skipIf(not has_torch, "requires Torch") + def test_antialias_upscale_cubic_matches_pytorch(self): + """For cubic mode, antialias changes a from -0.75 to -0.5 even on upscale.""" + np.random.seed(0) + in_np = np.random.normal(-1.0, 1.0, (1, 8, 8, 3)).astype(np.float32) + in_mx = mx.array(in_np) + in_pt = torch.from_numpy(in_np.transpose(0, 3, 1, 2)) + + for scale in (2.0, 3.0): + with self.subTest(scale=scale): + out_mx = nn.Upsample( + scale_factor=scale, + mode="cubic", + align_corners=False, + antialias=True, + )(in_mx) + out_pt = F.interpolate( + in_pt, + scale_factor=scale, + mode="bicubic", + align_corners=False, + antialias=True, + ) + out_pt_np = out_pt.permute(0, 2, 3, 1).numpy(force=True) + self.assertTrue( + np.allclose(np.array(out_mx), out_pt_np, atol=1e-5), + f"cubic antialias upscale {scale}x max_diff=" + f"{np.abs(np.array(out_mx) - out_pt_np).max():.2e}", + ) + + def test_antialias_non_integer_scale_smoke(self): + """Smoke test for non-integer scale factors (no PyTorch comparison).""" + np.random.seed(42) + in_np = np.random.normal(0, 1, (1, 32, 32, 3)).astype(np.float32) + in_mx = mx.array(in_np) + + for scale in (0.3, 0.7, 0.6): + for mode in ("linear", "cubic"): + with self.subTest(scale=scale, mode=mode): + out = nn.Upsample( + scale_factor=scale, + mode=mode, + align_corners=False, + antialias=True, + )(in_mx) + mx.eval(out) + out_np = np.array(out) + + # Correct shape + expected = int(32 * scale) + self.assertEqual(out_np.shape, (1, expected, expected, 3)) + + self.assertTrue(np.all(np.isfinite(out_np))) + if mode == "linear": + # Linear AA uses non-negative triangle weights, so it is + # bounded by the input range. Cubic interpolation can + # overshoot because the Keys kernel has negative lobes. + self.assertLessEqual( + out_np.max(), + in_np.max() + 0.01, + "linear AA output should not exceed input range", + ) + self.assertGreaterEqual( + out_np.min(), + in_np.min() - 0.01, + "linear AA output should not go below input range", + ) + + def test_antialias_1d_smoke(self): + """Test that antialias works on 1D spatial input (3D tensor). + + PyTorch does not support antialias on 1D tensors, so this is a + smoke test only (correct shape, non-trivial, smoother than non-AA). + """ + np.random.seed(0) + for length, scale in ((16, 0.5), (32, 0.25)): + for mode in ("linear", "cubic"): + with self.subTest(length=length, scale=scale, mode=mode): + in_np = np.random.normal(0, 1, (1, length, 3)).astype(np.float32) + in_mx = mx.array(in_np) + + out_aa = nn.Upsample( + scale_factor=scale, + mode=mode, + align_corners=False, + antialias=True, + )(in_mx) + out_no = nn.Upsample( + scale_factor=scale, + mode=mode, + align_corners=False, + antialias=False, + )(in_mx) + mx.eval(out_aa, out_no) + + expected_len = int(length * scale) + self.assertEqual(out_aa.shape, (1, expected_len, 3)) + # AA should differ from non-AA + self.assertGreater( + float(mx.abs(out_aa - out_no).max()), + 1e-6, + ) + + def test_antialias_uses_separable_path(self): + """AA interpolation should avoid cartesian product gather expansion.""" + in_mx = mx.zeros((1, 16, 16, 1)) + for mode in ("linear", "cubic"): + with self.subTest(mode=mode): + with patch( + "mlx.nn.layers.upsample.product", + side_effect=AssertionError("cartesian interpolation path used"), + ): + out = nn.Upsample( + scale_factor=0.25, + mode=mode, + align_corners=False, + antialias=True, + )(in_mx) + mx.eval(out) + self.assertEqual(out.shape, (1, 4, 4, 1)) + + def test_antialias_nearest_raises(self): + """Antialias + nearest should raise ValueError.""" + with self.assertRaises(ValueError): + nn.Upsample(scale_factor=0.5, mode="nearest", antialias=True) + + def test_antialias_align_corners_raises(self): + """Antialias + align_corners is unsupported and should raise.""" + for mode in ("linear", "cubic"): + for scale in (0.5, 2.0): + with self.subTest(mode=mode, scale=scale): + with self.assertRaises(ValueError): + nn.Upsample( + scale_factor=scale, + mode=mode, + align_corners=True, + antialias=True, + ) + + def test_antialias_align_corners_direct_functions_raise(self): + """Direct interpolation helpers should enforce the same AA contract.""" + x = mx.zeros((1, 4, 4, 1)) + for fn in (upsample_linear, upsample_cubic): + for scale in ((0.5, 0.5), (2.0, 2.0)): + with self.subTest(fn=fn.__name__, scale=scale): + with self.assertRaises(ValueError): + fn(x, scale, align_corners=True, antialias=True) + + def test_antialias_differs_from_non_antialias(self): + """Antialiased downscale should produce different output than non-AA.""" + np.random.seed(42) + in_np = np.random.normal(0, 1, (1, 32, 32, 3)).astype(np.float32) + in_mx = mx.array(in_np) + + out_no = nn.Upsample( + scale_factor=0.5, mode="linear", align_corners=False, antialias=False + )(in_mx) + out_aa = nn.Upsample( + scale_factor=0.5, mode="linear", align_corners=False, antialias=True + )(in_mx) + + # Outputs should differ (AA applies a wider filter) + diff = float(mx.abs(out_aa - out_no).max()) + self.assertGreater( + diff, + 1e-6, + "AA and non-AA downscale should produce different results", + ) + # AA output should have lower variance (wider filter = more averaging) + std_no = float(mx.std(out_no)) + std_aa = float(mx.std(out_aa)) + self.assertLess( + std_aa, + std_no, + f"AA output should have lower variance: std_aa={std_aa:.4f} >= std_no={std_no:.4f}", + ) + if __name__ == "__main__": mlx_tests.MLXTestRunner()