Describe the bug
Under mx.compile, assigning to a negative-strided slice (a[::-1] = expr or a[::-1] += expr) produces wrong results when the right-hand side is a non-trivial elementwise expression of a negative-strided view. Only a single element ends up written; the rest stay zero. The same function is correct in eager mode, and is correct under compile when the slice is positive-strided or when the RHS is a bare view.
To Reproduce
```python
import mlx.core as mx

def f(x):
    base = mx.zeros_like(x)
    base[::-1] += 2.0 * x[::-1]  # negative-strided scatter; RHS is an elementwise expr
    return base

x = mx.arange(6, dtype=mx.float32)
print("eager :", f(x).tolist())                # [0.0, 2.0, 4.0, 6.0, 8.0, 10.0] (correct)
print("compiled:", mx.compile(f)(x).tolist())  # [0.0, 0.0, 0.0, 0.0, 0.0, 10.0] (WRONG)
```
The set form fails identically:
```python
def g(x):
    base = mx.zeros_like(x)
    base[::-1] = 2.0 * x[::-1]
    return base

mx.compile(g)(x).tolist()  # [0.0, 0.0, 0.0, 0.0, 0.0, 10.0] (WRONG; eager is [0,2,4,6,8,10])
```
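For reference, the expected eager result can be derived index by index. Position `i` of the reversed view `base[::-1]` is `base[n-1-i]`, and position `i` of `x[::-1]` is `x[n-1-i]`, so the two reversals cancel and every `base[j]` should receive `2.0 * x[j]`. The sketch below models this in plain Python (no MLX involved); `f_ref` is a hypothetical helper used only for illustration.

```python
def f_ref(x):
    """Elementwise model of f: base[::-1] += 2.0 * x[::-1], on plain Python lists."""
    n = len(x)
    base = [0.0] * n
    for i in range(n):
        # Position i of base[::-1] is base[n - 1 - i]; position i of x[::-1]
        # is x[n - 1 - i]. The indices coincide, so the reversals cancel.
        base[n - 1 - i] += 2.0 * x[n - 1 - i]
    return base

x = [float(v) for v in range(6)]  # stands in for mx.arange(6, dtype=mx.float32)
print(f_ref(x))  # [0.0, 2.0, 4.0, 6.0, 8.0, 10.0]
```

Each of the `n` positions is written exactly once. In the compiled output above, only `base[5]` (position 0 of the reversed view) holds its value.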
Characterization (what does / doesn't trigger it)
All cases run under mx.compile (the last row is the eager baseline), with x = mx.arange(6, dtype=mx.float32):
| Case | Compiled result | Correct? |
|---|---|---|
| `base[::-1] += 2.0 * x[::-1]` | `[0,0,0,0,0,10]` | ❌ |
| `base[::-1] = 2.0 * x[::-1]` | `[0,0,0,0,0,10]` | ❌ |
| `base[::-1] += x[::-1] + x[::-1]` | `[0,0,0,0,0,10]` | ❌ |
| `base[::-1] += x[::-1]` (bare view RHS) | `[0,1,2,3,4,5]` | ✅ |
| `base[::-1] += x` (non-strided RHS) | `[5,4,3,2,1,0]` | ✅ |
| `base[::2] += 2.0 * x[::2]` (positive step) | `[0,0,4,0,8,0]` | ✅ |
| any of the above in eager mode | n/a | ✅ |
So the bug needs the combination of (a) a negative-strided destination slice, (b) an RHS that is an elementwise expression (not a bare view), and (c) mx.compile.
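The correct values for each row can be cross-checked without MLX: Python lists support extended-slice assignment with the same view semantics, so `base[s] += rhs` can be modeled as a read-modify-write through the slice `s`. The helpers `add_into` and `set_into` below are hypothetical and only illustrate the expected (eager) semantics; they do not exercise `mx.compile`.

```python
def add_into(base, s, rhs):
    """Model `base[s] += rhs` on a 1-D list: read through slice s, add, write back."""
    out = list(base)
    out[s] = [b + r for b, r in zip(out[s], rhs)]
    return out


def set_into(base, s, rhs):
    """Model `base[s] = rhs` on a 1-D list (extended slices need equal lengths)."""
    out = list(base)
    out[s] = list(rhs)
    return out


x = [float(v) for v in range(6)]  # stands in for mx.arange(6, dtype=mx.float32)
zeros = [0.0] * len(x)            # stands in for mx.zeros_like(x)
rev, even = slice(None, None, -1), slice(None, None, 2)

expected = {
    "base[::-1] += 2.0 * x[::-1]": add_into(zeros, rev, [2.0 * v for v in x[rev]]),
    "base[::-1] = 2.0 * x[::-1]": set_into(zeros, rev, [2.0 * v for v in x[rev]]),
    "base[::-1] += x[::-1] + x[::-1]": add_into(zeros, rev, [v + v for v in x[rev]]),
    "base[::-1] += x[::-1]": add_into(zeros, rev, x[rev]),
    "base[::-1] += x": add_into(zeros, rev, x),
    "base[::2] += 2.0 * x[::2]": add_into(zeros, even, [2.0 * v for v in x[even]]),
}

for case, values in expected.items():
    print(f"{case:<33} {values}")
```

The first three rows should all come out as `[0.0, 2.0, 4.0, 6.0, 8.0, 10.0]`, which is where the compiled column diverges; the last three agree with the compiled column.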
Expected behavior
mx.compile(f)(x) should equal f(x) (eager), i.e. [0, 2, 4, 6, 8, 10].
Desktop (please complete the following information):
- OS: macOS 26.5.1 (arm64, Apple Silicon)
- MLX version: 0.31.2