From 47523f0c61f0880a6dba6658ae9b862cc4d4d57f Mon Sep 17 00:00:00 2001 From: katlun-lgtm Date: Sat, 30 May 2026 15:52:12 -0400 Subject: [PATCH 1/3] feat: add reflect and symmetric padding modes to mx.pad Implements numpy.pad-compatible "reflect" and "symmetric" modes for mx.pad, matching numpy semantics for arbitrary pad sizes (the reflection repeats when the pad width exceeds the axis length). - mlx/ops.cpp: reflect_pad helper builds a per-axis triangle-wave index map and gathers with take; one take per padded axis. reflect uses period 2(n-1) and skips the edge; symmetric uses period 2n and repeats the edge. n==1 maps to 0. - python/src/ops.cpp: extend the pad mode Literal and docstring. - python/tests/test_ops.py: test_pad_reflect_symmetric covers in-bounds, multi-reflect, asymmetric per-axis, zero-width sides, and degenerate axes (n==1, n==2), checked against numpy.pad. - tests/ops_tests.cpp: reflect/symmetric CHECK cases incl. multi-reflect. --- mlx/ops.cpp | 51 ++++++++++++++++++++++++++++++++++++++++ python/src/ops.cpp | 4 +++- python/tests/test_ops.py | 32 +++++++++++++++++++++++++ tests/ops_tests.cpp | 35 +++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 1 deletion(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e4ce3d750f..cba20b847b 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1362,6 +1362,53 @@ array tile( return reshape(x, std::move(final_shape), s); } +array reflect_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + bool include_edge, + StreamOrDevice s /* = {} */) { + // Reflect (include_edge=false) or symmetric (include_edge=true) padding. + // Matches numpy.pad for arbitrary pad sizes (the reflection repeats as needed). + // For an out-of-range coordinate r (relative to the original axis [0, n)), + // map it back into [0, n) by reflection: + // reflect -> period 2(n-1), edge NOT repeated + // symmetric -> period 2n, edge repeated + auto reflect_coord = [](int r, int n, bool include_edge) -> int { + if (n == 1) { + return 0; + } + if (include_edge) { + int period = 2 * n; + int m = ((r % period) + period) % period; + return m < n ? m : (2 * n - 1 - m); + } else { + int period = 2 * (n - 1); + int m = ((r % period) + period) % period; + return m < n ? m : (period - m); + } + }; + array out = a; + for (size_t i = 0; i < axes.size(); i++) { + int ax = axes[i]; + int L = low_pad_size[i]; + int H = high_pad_size[i]; + if (L == 0 && H == 0) { + continue; + } + int n = out.shape(ax); + int total = L + n + H; + std::vector idx_vec(total); + for (int p = 0; p < total; p++) { + idx_vec[p] = reflect_coord(p - L, n, include_edge); + } + array idx = array(idx_vec.begin(), {total}, int32); + out = take(out, idx, ax, s); + } + return out; +} + array edge_pad( const array& a, const std::vector& axes, @@ -1454,6 +1501,10 @@ array pad( {a, astype(pad_value, a.dtype(), s)}); } else if (mode == "edge") { return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "reflect") { + return reflect_pad(a, axes, low_pad_size, high_pad_size, false, s); + } else if (mode == "symmetric") { + return reflect_pad(a, axes, low_pad_size, high_pad_size, true, s); } else { std::ostringstream msg; msg << "Invalid padding mode (" << mode << ") passed to pad"; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 198a4861d9..33a8cdcb13 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3271,7 +3271,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'reflect', 'symmetric'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Pad an array with a constant value @@ -3286,6 +3286,8 @@ void init_ops(nb::module_& m) { mode: Padding mode. One of the following strings: "constant" (default): Pads with a constant value. "edge": Pads with the edge values of array. + "reflect": Pads with the reflection of the array, without repeating the edge values. + "symmetric": Pads with the reflection of the array, repeating the edge values. constant_value (array or scalar, optional): Optional constant value to pad the edges of the array with. diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 484ae47c3a..3079ee8e0e 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2068,6 +2068,38 @@ def test_nan_to_num(self): out_mx = mx.nan_to_num(a, nan=0.0, posinf=1000, neginf=-1000) self.assertTrue(np.allclose(out_mx, out_np)) + def test_pad_reflect_symmetric(self): + # mx.pad reflect/symmetric must match numpy.pad exactly (it is a gather). + # Covers in-bounds, multi-reflect (pad larger than the axis), asymmetric + # per-axis widths, zero-width sides, and degenerate axes (n == 1, n == 2). + cases = [ + ((8,), [(2, 3)]), + ((8,), [(0, 4)]), + ((8,), [(3, 0)]), + ((8,), [(7, 8)]), + ((4,), [(10, 7)]), # multi-reflect + ((4,), [(20, 20)]), # multi-reflect, both sides + ((3,), [(9, 1)]), # multi-reflect + ((1,), [(3, 2)]), # degenerate axis + ((2,), [(5, 6)]), # smallest non-trivial, multi-reflect + ((5, 6), [(2, 3), (1, 2)]), + ((5, 6), [(9, 9), (11, 0)]), # both axes multi-reflect + ((3, 4, 5), [(1, 1), (0, 0), (2, 2)]), + ((3, 4, 5), [(4, 4), (0, 0), (7, 3)]), + ] + for mode in ("reflect", "symmetric"): + for shape, pw in cases: + a_npy = np.random.randn(*shape).astype(np.float32) + a_mlx = mx.array(a_npy) + b_npy = np.pad(a_npy, pw, mode=mode) + b_mlx = mx.pad(a_mlx, pw, mode=mode) + self.assertEqual(b_mlx.shape, tuple(b_npy.shape)) + self.assertTrue( + np.array_equal(np.array(b_mlx), b_npy), + msg=f"mismatch mode={mode} shape={shape} pad={pw}", + ) + self.assertEqual(b_mlx.dtype, mx.float32) + def test_as_strided(self): x_npy = np.random.randn(128).astype(np.float32) x_mlx = mx.array(x_npy) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c2230f77cf..cfa3a901eb 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2868,6 +2868,41 @@ TEST_CASE("test pad") { 0.0f}, {4, 4}); CHECK(array_equal(padded_x, expected).item()); + + // reflect padding (mirror without repeating the edge value) + x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {5}); + CHECK(array_equal( + pad(x, {{2, 2}}, array(0.0f), "reflect"), + array( + {3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 4.0f, 3.0f}, + {9})) + .item()); + CHECK(array_equal( + pad(x, {{0, 3}}, array(0.0f), "reflect"), + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 4.0f, 3.0f, 2.0f}, {8})) + .item()); + + // symmetric padding (mirror repeating the edge value) + CHECK(array_equal( + pad(x, {{2, 2}}, array(0.0f), "symmetric"), + array( + {2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 5.0f, 4.0f}, + {9})) + .item()); + CHECK(array_equal( + pad(x, {{3, 0}}, array(0.0f), "symmetric"), + array({3.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {8})) + .item()); + + // multi-reflect: pad larger than the axis repeats the reflection (numpy parity) + x = array({1.0f, 2.0f, 3.0f}, {3}); + CHECK(array_equal( + pad(x, {{5, 5}}, array(0.0f), "reflect"), + array( + {2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f, + 2.0f, 3.0f, 2.0f}, + {13})) + .item()); } TEST_CASE("test power") { From 579bfd36eae6f9ecad2396ccb0803b2428c7e1c9 Mon Sep 17 00:00:00 2001 From: Cheng Date: Sat, 13 Jun 2026 09:11:38 +0900 Subject: [PATCH 2/3] Fix lint --- mlx/ops.cpp | 6 +++--- python/tests/test_ops.py | 12 ++++++------ tests/ops_tests.cpp | 26 +++++++++++++++++--------- 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index cba20b847b..ed4db4d3c3 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1370,9 +1370,9 @@ array reflect_pad( bool include_edge, StreamOrDevice s /* = {} */) { // Reflect (include_edge=false) or symmetric (include_edge=true) padding. - // Matches numpy.pad for arbitrary pad sizes (the reflection repeats as needed). - // For an out-of-range coordinate r (relative to the original axis [0, n)), - // map it back into [0, n) by reflection: + // Matches numpy.pad for arbitrary pad sizes (the reflection repeats as + // needed). For an out-of-range coordinate r (relative to the original axis + // [0, n)), map it back into [0, n) by reflection: // reflect -> period 2(n-1), edge NOT repeated // symmetric -> period 2n, edge repeated auto reflect_coord = [](int r, int n, bool include_edge) -> int { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 3079ee8e0e..91a60034fd 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2077,13 +2077,13 @@ def test_pad_reflect_symmetric(self): ((8,), [(0, 4)]), ((8,), [(3, 0)]), ((8,), [(7, 8)]), - ((4,), [(10, 7)]), # multi-reflect - ((4,), [(20, 20)]), # multi-reflect, both sides - ((3,), [(9, 1)]), # multi-reflect - ((1,), [(3, 2)]), # degenerate axis - ((2,), [(5, 6)]), # smallest non-trivial, multi-reflect + ((4,), [(10, 7)]), # multi-reflect + ((4,), [(20, 20)]), # multi-reflect, both sides + ((3,), [(9, 1)]), # multi-reflect + ((1,), [(3, 2)]), # degenerate axis + ((2,), [(5, 6)]), # smallest non-trivial, multi-reflect ((5, 6), [(2, 3), (1, 2)]), - ((5, 6), [(9, 9), (11, 0)]), # both axes multi-reflect + ((5, 6), [(9, 9), (11, 0)]), # both axes multi-reflect ((3, 4, 5), [(1, 1), (0, 0), (2, 2)]), ((3, 4, 5), [(4, 4), (0, 0), (7, 3)]), ] diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index cfa3a901eb..7f0421a037 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2873,9 +2873,7 @@ TEST_CASE("test pad") { x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {5}); CHECK(array_equal( pad(x, {{2, 2}}, array(0.0f), "reflect"), - array( - {3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 4.0f, 3.0f}, - {9})) + array({3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 4.0f, 3.0f}, {9})) .item()); CHECK(array_equal( pad(x, {{0, 3}}, array(0.0f), "reflect"), @@ -2885,22 +2883,32 @@ TEST_CASE("test pad") { // symmetric padding (mirror repeating the edge value) CHECK(array_equal( pad(x, {{2, 2}}, array(0.0f), "symmetric"), - array( - {2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 5.0f, 4.0f}, - {9})) + array({2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 5.0f, 4.0f}, {9})) .item()); CHECK(array_equal( pad(x, {{3, 0}}, array(0.0f), "symmetric"), array({3.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {8})) .item()); - // multi-reflect: pad larger than the axis repeats the reflection (numpy parity) + // multi-reflect: pad larger than the axis repeats the reflection (numpy + // parity) x = array({1.0f, 2.0f, 3.0f}, {3}); CHECK(array_equal( pad(x, {{5, 5}}, array(0.0f), "reflect"), array( - {2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f, - 2.0f, 3.0f, 2.0f}, + {2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f}, {13})) .item()); } From 71a68c65bc74a2f99d11d358ad3049c200547f1e Mon Sep 17 00:00:00 2001 From: katlun-lgtm Date: Fri, 12 Jun 2026 23:01:00 -0400 Subject: [PATCH 3/3] docs: add ACKNOWLEDGMENTS entry for reflect/symmetric pad modes --- ACKNOWLEDGMENTS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 186908f09c..d83832be8f 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -20,6 +20,7 @@ MLX was developed with contributions from the following individuals: - Paul Paczuski: Improved stability of BCE loss calculation - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. - Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function. +- katlun-lgtm: Added `reflect` and `symmetric` padding modes.