diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 186908f09c..d83832be8f 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -20,6 +20,7 @@ MLX was developed with contributions from the following individuals: - Paul Paczuski: Improved stability of BCE loss calculation - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops. - Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function. +- katlun-lgtm: Added `reflect` and `symmetric` padding modes. diff --git a/mlx/ops.cpp b/mlx/ops.cpp index e4ce3d750f..ed4db4d3c3 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1362,6 +1362,53 @@ array tile( return reshape(x, std::move(final_shape), s); } +array reflect_pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + bool include_edge, + StreamOrDevice s /* = {} */) { + // Reflect (include_edge=false) or symmetric (include_edge=true) padding. + // Matches numpy.pad for arbitrary pad sizes (the reflection repeats as + // needed). For an out-of-range coordinate r (relative to the original axis + // [0, n)), map it back into [0, n) by reflection: + // reflect -> period 2(n-1), edge NOT repeated + // symmetric -> period 2n, edge repeated + auto reflect_coord = [](int r, int n, bool include_edge) -> int { + if (n == 1) { + return 0; + } + if (include_edge) { + int period = 2 * n; + int m = ((r % period) + period) % period; + return m < n ? m : (2 * n - 1 - m); + } else { + int period = 2 * (n - 1); + int m = ((r % period) + period) % period; + return m < n ? m : (period - m); + } + }; + array out = a; + for (size_t i = 0; i < axes.size(); i++) { + int ax = axes[i]; + int L = low_pad_size[i]; + int H = high_pad_size[i]; + if (L == 0 && H == 0) { + continue; + } + int n = out.shape(ax); + int total = L + n + H; + std::vector idx_vec(total); + for (int p = 0; p < total; p++) { + idx_vec[p] = reflect_coord(p - L, n, include_edge); + } + array idx = array(idx_vec.begin(), {total}, int32); + out = take(out, idx, ax, s); + } + return out; +} + array edge_pad( const array& a, const std::vector& axes, @@ -1454,6 +1501,10 @@ array pad( {a, astype(pad_value, a.dtype(), s)}); } else if (mode == "edge") { return edge_pad(a, axes, low_pad_size, high_pad_size, out_shape, s); + } else if (mode == "reflect") { + return reflect_pad(a, axes, low_pad_size, high_pad_size, false, s); + } else if (mode == "symmetric") { + return reflect_pad(a, axes, low_pad_size, high_pad_size, true, s); } else { std::ostringstream msg; msg << "Invalid padding mode (" << mode << ") passed to pad"; diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 198a4861d9..33a8cdcb13 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3271,7 +3271,7 @@ void init_ops(nb::module_& m) { nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + "def pad(a: array, pad_width: Union[int, tuple[int], tuple[int, int], list[tuple[int, int]]], mode: Literal['constant', 'edge', 'reflect', 'symmetric'] = 'constant', constant_values: Union[scalar, array] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), R"pbdoc( Pad an array with a constant value @@ -3286,6 +3286,8 @@ void init_ops(nb::module_& m) { mode: Padding mode. One of the following strings: "constant" (default): Pads with a constant value. "edge": Pads with the edge values of array. + "reflect": Pads with the reflection of the array, without repeating the edge values. + "symmetric": Pads with the reflection of the array, repeating the edge values. constant_value (array or scalar, optional): Optional constant value to pad the edges of the array with. diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 484ae47c3a..91a60034fd 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2068,6 +2068,38 @@ def test_nan_to_num(self): out_mx = mx.nan_to_num(a, nan=0.0, posinf=1000, neginf=-1000) self.assertTrue(np.allclose(out_mx, out_np)) + def test_pad_reflect_symmetric(self): + # mx.pad reflect/symmetric must match numpy.pad exactly (it is a gather). + # Covers in-bounds, multi-reflect (pad larger than the axis), asymmetric + # per-axis widths, zero-width sides, and degenerate axes (n == 1, n == 2). + cases = [ + ((8,), [(2, 3)]), + ((8,), [(0, 4)]), + ((8,), [(3, 0)]), + ((8,), [(7, 8)]), + ((4,), [(10, 7)]), # multi-reflect + ((4,), [(20, 20)]), # multi-reflect, both sides + ((3,), [(9, 1)]), # multi-reflect + ((1,), [(3, 2)]), # degenerate axis + ((2,), [(5, 6)]), # smallest non-trivial, multi-reflect + ((5, 6), [(2, 3), (1, 2)]), + ((5, 6), [(9, 9), (11, 0)]), # both axes multi-reflect + ((3, 4, 5), [(1, 1), (0, 0), (2, 2)]), + ((3, 4, 5), [(4, 4), (0, 0), (7, 3)]), + ] + for mode in ("reflect", "symmetric"): + for shape, pw in cases: + a_npy = np.random.randn(*shape).astype(np.float32) + a_mlx = mx.array(a_npy) + b_npy = np.pad(a_npy, pw, mode=mode) + b_mlx = mx.pad(a_mlx, pw, mode=mode) + self.assertEqual(b_mlx.shape, tuple(b_npy.shape)) + self.assertTrue( + np.array_equal(np.array(b_mlx), b_npy), + msg=f"mismatch mode={mode} shape={shape} pad={pw}", + ) + self.assertEqual(b_mlx.dtype, mx.float32) + def test_as_strided(self): x_npy = np.random.randn(128).astype(np.float32) x_mlx = mx.array(x_npy) diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index c2230f77cf..7f0421a037 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2868,6 +2868,49 @@ TEST_CASE("test pad") { 0.0f}, {4, 4}); CHECK(array_equal(padded_x, expected).item()); + + // reflect padding (mirror without repeating the edge value) + x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {5}); + CHECK(array_equal( + pad(x, {{2, 2}}, array(0.0f), "reflect"), + array({3.0f, 2.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 4.0f, 3.0f}, {9})) + .item()); + CHECK(array_equal( + pad(x, {{0, 3}}, array(0.0f), "reflect"), + array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 4.0f, 3.0f, 2.0f}, {8})) + .item()); + + // symmetric padding (mirror repeating the edge value) + CHECK(array_equal( + pad(x, {{2, 2}}, array(0.0f), "symmetric"), + array({2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 5.0f, 4.0f}, {9})) + .item()); + CHECK(array_equal( + pad(x, {{3, 0}}, array(0.0f), "symmetric"), + array({3.0f, 2.0f, 1.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {8})) + .item()); + + // multi-reflect: pad larger than the axis repeats the reflection (numpy + // parity) + x = array({1.0f, 2.0f, 3.0f}, {3}); + CHECK(array_equal( + pad(x, {{5, 5}}, array(0.0f), "reflect"), + array( + {2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f, + 1.0f, + 2.0f, + 3.0f, + 2.0f}, + {13})) + .item()); } TEST_CASE("test power") {