Add more array API functions #3684

Closed
katlun-lgtm wants to merge 1 commit into ml-explore:main from katlun-lgtm:array-api-batch
@katlun-lgtm

Contributor

Proposed changes

Adds a batch of array API functions to mlx.core, all built on existing primitives (no core/Metal changes):

Elementwise / utility

  • positive(a), logical_xor(a, b), trunc(a)
  • count_nonzero(a, /, *, axis=None, keepdims=False)
  • diff(a, /, n=1, axis=-1, *, prepend=None, append=None)
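
To make the `diff` parameters concrete, here is a minimal sketch of the n-th discrete difference on a 1-D input. It uses plain Python lists as a stand-in for MLX arrays, the helper name `diff_1d` is hypothetical, and the `axis` argument is left out; it illustrates the semantics as described here and is not the code in this PR.

```python
def diff_1d(values, n=1, prepend=None, append=None):
    """n-th discrete difference of a 1-D sequence (lists stand in for arrays)."""
    if n < 0:
        raise ValueError("n must be non-negative")
    # prepend/append are concatenated to the input once, before differencing.
    out = list(prepend or []) + list(values) + list(append or [])
    for _ in range(n):
        # Each pass replaces the sequence with its adjacent differences,
        # so the length shrinks by one per pass.
        out = [b - a for a, b in zip(out, out[1:])]
    return out


print(diff_1d([1, 4, 9, 16]))               # [3, 5, 7]
print(diff_1d([1, 4, 9, 16], n=2))          # [2, 2]
print(diff_1d([1, 4, 9, 16], prepend=[0]))  # [1, 3, 5, 7]
```

With `prepend=[0]` the output keeps the same length as the input, which is the usual reason to pass it.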

Creation

  • full_like(a, vals, dtype=None)
  • empty(shape, dtype=...), empty_like(a, dtype=None) — these return zeros, since MLX does not expose uninitialized memory

Free-function wrappers

  • astype(a, dtype) and matrix_transpose(a) mirror the existing array method/property
  • cumulative_sum / cumulative_prod wrap cumsum / cumprod with the array API axis (flatten when None), dtype, and include_initial semantics

Inspection

  • __array_namespace_info__() returning an object with capabilities(), default_device(), default_dtypes(), devices(), and dtypes(kind=...)
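
As a rough illustration of the shape of such an inspection object, the sketch below implements the five methods in plain Python. Everything in it is an assumption made for illustration: the class name `NamespaceInfo`, the device names, the dtype table, and the capability and default values are placeholders, and strings stand in for real dtype and device objects. It does not reproduce what this PR returns.

```python
class NamespaceInfo:
    """Hypothetical stand-in for the object returned by __array_namespace_info__()."""

    # Placeholder dtype names grouped by kind; strings stand in for dtype objects.
    _KINDS = {
        "bool": ("bool",),
        "signed integer": ("int8", "int16", "int32", "int64"),
        "unsigned integer": ("uint8", "uint16", "uint32", "uint64"),
        "real floating": ("float16", "float32"),
        "complex floating": ("complex64",),
    }
    # Composite kinds expand to several base kinds.
    _GROUPS = {
        "integral": ("signed integer", "unsigned integer"),
        "numeric": ("signed integer", "unsigned integer",
                    "real floating", "complex floating"),
    }

    def capabilities(self):
        # Placeholder values, not a statement about MLX.
        return {"boolean indexing": False, "data-dependent shapes": False}

    def default_device(self):
        return "gpu"  # placeholder device name

    def devices(self):
        return ["cpu", "gpu"]  # placeholder device names

    def default_dtypes(self, *, device=None):
        # Placeholder defaults for each category.
        return {
            "real floating": "float32",
            "complex floating": "complex64",
            "integral": "int32",
            "indexing": "int32",
        }

    def dtypes(self, *, device=None, kind=None):
        # kind may be None (all dtypes), one kind name, or a tuple of kind names.
        if kind is None:
            kinds = tuple(self._KINDS)
        elif isinstance(kind, str):
            kinds = (kind,)
        else:
            kinds = tuple(kind)
        names = []
        for k in kinds:
            for base in self._GROUPS.get(k, (k,)):
                names.extend(self._KINDS[base])
        return {name: name for name in names}


info = NamespaceInfo()
print(list(info.dtypes(kind="integral")))
print(info.dtypes(kind=("bool", "real floating")))
```

Filtering by `kind` is the part most callers rely on, since it lets generic code ask for, say, every integral dtype without hard-coding names.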

All added to the ops docs and tested in test_ops.py / test_array.py. Part of #3484.

Checklist

  • I have read the CONTRIBUTING document
  • clang-format and black (the formatters configured in .pre-commit-config.yaml) report no changes on the modified files
  • Added tests (test_ops.py: test_array_api_elementwise, test_diff, test_array_api_creation, test_astype_and_matrix_transpose, test_cumulative_sum_prod; test_array.py: test_array_namespace_info)
  • Built from source and ran the tests locally: the new tests pass, and the full test_ops.py (144) and test_array.py (75) suites pass with no regressions

Adds array API functions toward ml-explore#3484, all built on existing primitives
(no core changes):

- Elementwise / utility: positive, logical_xor, trunc, count_nonzero, diff
- Creation: full_like, empty, empty_like (empty / empty_like return zeros
  since MLX does not expose uninitialized memory)
- Free functions: astype, matrix_transpose, cumulative_sum, cumulative_prod
- Inspection: __array_namespace_info__ (capabilities, default_device,
  default_dtypes, devices, dtypes)

Adds them to the ops docs and tests in test_ops.py / test_array.py.
@katlun-lgtm

Contributor Author

Same question as #3683: these are currently composed in python/src/ops.cpp (all pure compositions of existing ops, no new kernels). I'm happy to move them into mlx/ops.{h,cpp} as core ops with C++ tests, the same way as #3683 — but it's a fair number of functions, so I wanted to check scope before reworking.

A few design calls I'd like your steer on first:

  • empty / empty_like: MLX has no uninitialized buffer, so these currently return zeros — happy to drop them if you'd rather not have them.
  • astype: a free-function form of the existing array.astype method.
  • positive: identity (array API completeness).
  • cumulative_sum / cumulative_prod: wrap cumsum/cumprod with the array-API axis=None flatten + include_initial semantics.

Want me to move the whole set to core, or would you prefer to split/trim it first?

@zcbenz

zcbenz commented Jun 19, 2026

Collaborator

This PR adds too many different things, so it is hard to review. Generally speaking, pure aliases should be done like #3678; otherwise the ops should have C++ versions that are then exposed in Python. I'm not sure about __array_namespace_info__ and I need to read the docs later.

  • empty / empty_like: MLX has no uninitialized buffer, so these currently return zeros — happy to drop them if you'd rather not have them.

I'm fine with just making them aliases of zeros/zeros_like.

  • astype: a free-function form of the existing array.astype method.

👍

  • positive: identity (array API completeness).

It can just be a shallow copy like how __copy__ is implemented.

  • cumulative_sum / cumulative_prod: wrap cumsum/cumprod with the array-API axis=None flatten + include_initial semantics.

I haven't checked how cumulative_sum differs from cumsum yet, but in principle we should not have two APIs with similar names that do different things.

@katlun-lgtm

Contributor Author

Per @zcbenz's feedback I've split this PR into smaller, focused pieces. All three new PRs are open against main:

Held pending review:
Branch katlun-lgtm:array-api-namespace-info contains __array_namespace_info__ + dtype_matches_kind. Not opening as a PR until there is clarity on the approach. Closing this original PR in favour of the above.

katlun-lgtm deleted the array-api-batch branch June 20, 2026 00:36
katlun-lgtm added a commit to katlun-lgtm/mlx that referenced this pull request Jun 20, 2026
empty/empty_like are pure aliases of zeros/zeros_like via
m.attr("empty") = m.attr("zeros"), matching the pattern from ml-explore#3678.
MLX does not expose uninitialized memory so zeros is the correct
semantic match.
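
In Python terms, a pure alias of this kind amounts to binding a second name to the same function object. The sketch below shows that on a throwaway module; `toy_ops` and the list-based `zeros` are hypothetical stand-ins, not MLX code.

```python
import types


def zeros(shape):
    """Toy stand-in for zeros: nested lists of 0.0 with the given shape."""
    if not shape:
        return 0.0
    return [zeros(shape[1:]) for _ in range(shape[0])]


# A throwaway module object standing in for the real extension module.
ops = types.ModuleType("toy_ops")
ops.zeros = zeros
# The alias is just a second attribute bound to the same callable.
ops.empty = ops.zeros

print(ops.empty is ops.zeros)  # True
print(ops.empty((2, 3)))       # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
```

Because both names refer to one object, the alias cannot drift from the original's behaviour.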

astype exposes mx::astype as a free function (Array API §2.0).
matrix_transpose transposes the last two dimensions and validates
ndim >= 2.
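
A minimal sketch of the matrix_transpose behaviour described above, using nested Python lists in place of arrays. The helper `_ndim` and the error message are assumptions for illustration; this is not the code from the commit.

```python
def _ndim(a):
    """Nesting depth of a (non-ragged) nested list."""
    depth = 0
    while isinstance(a, list):
        depth += 1
        a = a[0] if a else None
    return depth


def matrix_transpose(a):
    """Swap the last two dimensions; leading (batch) dimensions are untouched."""
    nd = _ndim(a)
    if nd < 2:
        raise ValueError("matrix_transpose requires at least 2 dimensions")
    if nd == 2:
        # zip(*rows) turns rows into columns.
        return [list(col) for col in zip(*a)]
    # Recurse through leading dimensions until a 2-D matrix is reached.
    return [matrix_transpose(sub) for sub in a]


print(matrix_transpose([[1, 2, 3], [4, 5, 6]]))  # [[1, 4], [2, 5], [3, 6]]
print(matrix_transpose([[[1, 2]], [[3, 4]]]))    # [[[1], [2]], [[3], [4]]]
```

The second call has shape (2, 1, 2) and comes back as (2, 2, 1): only the trailing pair of axes is swapped.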

Docs and tests included.

Part of the array API split from ml-explore#3684.
zcbenz pushed a commit to katlun-lgtm/mlx that referenced this pull request Jun 21, 2026
katlun-lgtm added a commit to katlun-lgtm/mlx that referenced this pull request Jun 21, 2026
These are the Array API standard equivalents of cumsum/cumprod with
three key differences that justify the separate names:

1. axis=None (default) flattens the input first; cumsum/cumprod require
   an explicit axis.
2. include_initial=True prepends the identity element (0 for sum, 1 for
   prod) so the output length along axis is len+1.  This matches the
   Array API spec's include_initial parameter and has no equivalent in
   cumsum/cumprod.
3. dtype parameter casts the input before accumulating, matching NumPy
   2.0 / Array API behaviour.
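
The three differences listed above can be sketched on plain Python lists. This is an illustration under stated assumptions, not the commit's implementation: lists stand in for arrays, `dtype` is any Python callable such as `float`, the helpers `_flatten` and `_cumulative` are hypothetical, and an explicit `axis` is only supported for 1-D input.

```python
from itertools import accumulate
import operator


def _flatten(a):
    """Yield the scalars of a nested list in row-major order."""
    if isinstance(a, list):
        for item in a:
            yield from _flatten(item)
    else:
        yield a


def _cumulative(a, op, identity, axis, dtype, include_initial):
    if axis is None:
        # Difference 1: axis=None flattens the input first.
        a = list(_flatten(a))
    elif axis not in (0, -1) or any(isinstance(x, list) for x in a):
        raise ValueError("this sketch only handles 1-D input when axis is given")
    if dtype is not None:
        # Difference 3: cast before accumulating.
        a = [dtype(x) for x in a]
        identity = dtype(identity)
    running = list(accumulate(a, op))
    # Difference 2: include_initial prepends the identity, giving len + 1 items.
    return [identity] + running if include_initial else running


def cumulative_sum(a, *, axis=None, dtype=None, include_initial=False):
    return _cumulative(a, operator.add, 0, axis, dtype, include_initial)


def cumulative_prod(a, *, axis=None, dtype=None, include_initial=False):
    return _cumulative(a, operator.mul, 1, axis, dtype, include_initial)


print(cumulative_sum([[1, 2], [3, 4]]))                  # [1, 3, 6, 10]
print(cumulative_sum([1, 2, 3], include_initial=True))   # [0, 1, 3, 6]
print(cumulative_prod([1, 2, 3], dtype=float, include_initial=True))
# [1.0, 1.0, 2.0, 6.0]
```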

Docs and tests included.

Part of the array API split from ml-explore#3684.
katlun-lgtm added a commit to katlun-lgtm/mlx that referenced this pull request Jun 21, 2026
katlun-lgtm added a commit to katlun-lgtm/mlx that referenced this pull request Jun 21, 2026
…f, full_like

New elementwise and utility ops required by the Python array API standard
(https://data-apis.org/array-api/latest/):

- positive: unary plus, returns a copy (mx::astype to same dtype)
- logical_xor: element-wise XOR via not_equal(bool(a), bool(b))
- trunc: truncate toward zero (where(a < 0, ceil, floor))
- count_nonzero: count non-zero elements; returns int32; supports axis/keepdims
- diff: n-th discrete difference along an axis, with optional prepend/append
- full_like: fill an array shaped like the input; optional dtype override
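
Three of the compositions named in this list can be mirrored on plain Python lists. The sketch below follows the formulas as written above (XOR as inequality of truth values, trunc as ceil for negatives and floor otherwise); lists stand in for arrays, the axis/keepdims handling of count_nonzero is omitted, and none of this is the commit's actual code.

```python
import math


def logical_xor(a, b):
    # XOR of truth values is the same as "the two booleans differ".
    return [bool(x) != bool(y) for x, y in zip(a, b)]


def trunc(a):
    # Round toward zero: ceil for negative inputs, floor for the rest.
    return [float(math.ceil(x)) if x < 0 else float(math.floor(x)) for x in a]


def count_nonzero(a):
    # Count the entries that compare unequal to zero.
    return sum(1 for x in a if x != 0)


print(logical_xor([0, 2, 0, 3], [0, 0, 5, 7]))  # [False, True, True, False]
print(trunc([-1.7, -0.5, 0.5, 2.9]))            # [-1.0, 0.0, 0.0, 2.0]
print(count_nonzero([0, 1.5, 0, -2]))           # 2
```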

Docs and tests included.

Part of the array API split from ml-explore#3684.