Implement batched matmul for large 1D dot products #3580

Open
Ved235 wants to merge 7 commits into ml-explore:main from Ved235:main

Ved235 commented May 22, 2026

Proposed changes

Addresses issue #3533. Adds routing logic in mlx/ops.cpp that splits a large 1D dot product into chunks so that gemv can parallelize across them.
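
As a rough illustration of the chunking idea, here is a minimal sketch in plain Python. It is not the code in mlx/ops.cpp; the function name chunked_dot and the chunk_size parameter are hypothetical and only show how a long dot product decomposes into independent partial sums plus one final reduction.

```python
def chunked_dot(a, b, chunk_size):
    """Dot product computed as per-chunk partial sums, then one final sum.

    Hypothetical helper for illustration only; not part of MLX.
    """
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")

    partials = []
    for start in range(0, len(a), chunk_size):
        stop = start + chunk_size
        # Each chunk depends only on its own slice, so on a GPU the
        # chunks could be processed in parallel.
        partials.append(sum(x * y for x, y in zip(a[start:stop], b[start:stop])))

    # Final reduction over the much shorter list of partial results.
    return sum(partials)


a = [float(i) for i in range(10)]
b = [2.0] * 10
result = chunked_dot(a, b, chunk_size=4)
print(result)
```

The chunk size only changes how the work is grouped. For floating-point inputs the result can differ slightly from a single sequential sum because the additions happen in a different order.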

Benchmark

import mlx.core as mx
import numpy as np
import time

def bench(fn, rounds=100, label=""):
    # Warm-up runs so one-time setup cost is excluded from the timings.
    for _ in range(3):
        r = fn()
        mx.eval(r)

    times = []
    for _ in range(rounds):
        mx.eval()
        t0 = time.perf_counter()
        r = fn()
        # MLX is lazy: mx.eval forces the computation to run before the timer stops.
        mx.eval(r)
        times.append(time.perf_counter() - t0)

    times.sort()
    median = times[len(times) // 2]
    best = times[0]
    worst = times[-1]
    print(f"{label}")
    print(f"median={median*1000:.3f}ms | min={best*1000:.3f}ms | max={worst*1000:.3f}ms")
    return r

a = mx.random.normal(shape=(50_000_000,), dtype=mx.float32)
b = mx.random.normal(shape=(50_000_000,), dtype=mx.float32)

a_np = np.array(a, copy=False)
b_np = np.array(b, copy=False)

ccc = bench(lambda: mx.inner(a, b), label="MLX native")

print(f"mx.inner : {float(ccc)}")

With this benchmark script, the timings change from:

median=15.393ms | min=15.323ms | max=15.769ms

to

median=1.741ms | min=1.682ms | max=1.835ms

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Ved235 commented May 23, 2026

Author

@zcbenz would it be possible for you to review this?

@zcbenz zcbenz left a comment

Collaborator

This basically looks good to me, thanks!

Comment thread mlx/ops.cpp Outdated

Ved235 commented May 26, 2026

Author

@zcbenz I have made the changes. Could you review and merge?

@zcbenz zcbenz left a comment

Collaborator

I would like another review from maintainers before merging.

Ved235 commented May 28, 2026

Author

> I would like another review from maintainers before merging.

Ok sure, is there someone specific I should tag to request a review?

@zcbenz zcbenz requested review from angeloskath and jagrit06 May 28, 2026 22:49

Ved235 commented Jun 6, 2026

Author

@angeloskath @jagrit06 it would be great if you could review these changes.

Ved235 commented Jun 12, 2026

Author

@zcbenz It's been several weeks since the changes, and I believe they are not very large in terms of number of lines, so would it be possible for this to be merged?

zcbenz commented Jun 12, 2026

Collaborator

Sorry, WWDC largely disrupted our schedule. There is nothing wrong with this PR; it is just that I need another view on this since I lack the background knowledge. WWDC is over now and there is a large backlog, so please give us more time.

Ved235 commented Jun 12, 2026

Author

No worries and thanks for the update.

@angeloskath angeloskath left a comment

Member

@Ved235 sorry for the super late reply, especially since it will be a negative response.

This is indeed something we need to fix, but unfortunately this is not the way to fix it. Basically, the ops should not change based on the shape; the implementation should, the same way that we route to the split-k kernel when the matrix K dimension is large.

I am gonna mark this as requested changes so we don't merge it by accident and make a new PR with a specialized kernel for this particular case. I am not sure how important it is but I think the kernel will end up being simple enough.

Ved235 commented Jun 13, 2026

Author

Should I make a PR for this specialised kernel? I have a solution ready.

Ved235 commented Jun 16, 2026

Author

@angeloskath I have implemented a specialised kernel. The performance benchmarks show the same improvement. Is this what you were expecting?

@Ved235 Ved235 requested a review from angeloskath June 16, 2026 16:33

@angeloskath angeloskath left a comment

Member

It's nearly there.

I think the kernel needs a bit of restructuring but it is quite close to what is needed. I left comments inline.

Comment thread mlx/backend/metal/kernels/dot.metal Outdated
Comment thread mlx/backend/metal/kernels/dot.metal Outdated
Comment thread mlx/backend/metal/matmul.cpp Outdated
Comment thread mlx/backend/metal/matmul.cpp Outdated

jeonghoon-ad commented Jun 20, 2026

Thanks for pushing this forward. I took a pass over the current specialized
kernel direction and ran a small baseline sweep on the PyPI MLX wheel.

I could not benchmark this PR branch directly yet because my local source build
fails after the Metal library is built (dot.air / mlx.metallib are produced,
then the build fails later in the C++/SDK setup). So please treat the comments
below as review notes rather than measured PR results.

Two things look worth checking before merge:

  1. Dtype routing

The new Matmul::eval_gpu route does not appear to guard dtype, but
dot.metal only instantiates float32, float16, and bfloat16 kernels. Existing
Metal GEMV has complex64 support. If a contiguous complex64 1 x K @ K x 1
case reaches this path, it looks like it may try to load a missing
dot_product_complex64 kernel. A dtype guard that falls back to GEMV for
complex64, or a complex64 kernel plus test, would make this safer.

  2. Small-K regression risk

The current route sends all contiguous M == 1 && N == 1 && batch_size_out == 1
cases to the dot path. The implementation always launches dot_product and
then dot_reduce, even when blocks == 1, and also allocates a partials array.
That looks great for huge vectors, but it could regress small vectors where the
old GEMV path is already dominated by launch overhead.

On the current PyPI wheel, I see this baseline on GPU:

K=1          median 210 us
K=16         median 207 us
K=256        median 283 us
K=4,096      median 311 us
K=16,384     median 316 us
K=131,072    median 349 us
K=1,000,000  median 927 us
K=50,000,000 median 19.36 ms

So the large-vector problem is very real, but the small-K path probably needs
either a direct one-kernel blocks == 1 case or a measured threshold.
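
To make the blocks == 1 concern concrete, here is a small host-side model in plain Python. It is only a sketch and not the Metal kernel from this PR; dot_two_stage, block_size, and the returned launch count are hypothetical names used for illustration. It models a partial-sum pass followed by a reduce pass, with a shortcut that skips the partials buffer and the second pass when one block covers the whole vector.

```python
def dot_two_stage(a, b, block_size):
    """Model of a two-pass dot product with a single-block fast path.

    Returns (value, launches), where launches counts the simulated
    kernel launches. Hypothetical helper; not part of MLX.
    """
    n = len(a)
    # Ceiling division, with at least one block even for empty input.
    blocks = max(1, -(-n // block_size))

    def block_partial(i):
        lo = i * block_size
        hi = min(n, lo + block_size)
        return sum(a[j] * b[j] for j in range(lo, hi))

    if blocks == 1:
        # Fast path: one pass, no partials buffer, no second reduction.
        return block_partial(0), 1

    # General path: first pass writes one partial per block,
    # second pass reduces the partials to a scalar.
    partials = [block_partial(i) for i in range(blocks)]
    return sum(partials), 2


small = dot_two_stage([1, 2, 3], [4, 5, 6], block_size=8)
large = dot_two_stage([1, 2, 3], [4, 5, 6], block_size=2)
print(small, large)
```

Whether such a fast path actually beats the existing GEMV route for small K would still need to be measured on the device.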

A benchmark table over K and dtype would make the PR much easier to review:

  • K: 1, 16, 256, 4096, 16384, 131072, 1M, 50M
  • dtype: float32, float16, bfloat16, plus complex64 fallback/guard coverage
  • shape cases: mx.inner(1D, 1D) and explicit (1, K) @ (K, 1)

The main thing to rule out is that broad routing trades the big-vector win for
small-vector or complex64 regressions.
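
A sweep like the one requested above could be organized roughly as follows. This is a sketch using only the Python standard library; median_seconds, sweep, and make_case are hypothetical names, and the stand-in workload is a pure-Python dot product rather than an MLX call. With MLX, the callable would also need to force evaluation (for example with mx.eval on the result, as in the benchmark script in the PR description) so that the timer covers the actual computation.

```python
import itertools
import statistics
import time


def median_seconds(fn, rounds=20, warmup=3):
    """Median wall-clock time of fn() over several rounds, after warm-up."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(rounds):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)


def sweep(make_case, sizes, dtypes, rounds=20):
    """Time every (K, dtype) combination; returns rows of (K, dtype, seconds)."""
    rows = []
    for k, dtype in itertools.product(sizes, dtypes):
        fn = make_case(k, dtype)
        rows.append((k, dtype, median_seconds(fn, rounds=rounds)))
    return rows


def make_case(k, dtype):
    # Stand-in workload: dtype is ignored here because plain Python
    # floats are used instead of real MLX arrays.
    a = [1.0] * k
    b = [2.0] * k
    return lambda: sum(x * y for x, y in zip(a, b))


rows = sweep(make_case, sizes=[1, 16, 256], dtypes=["float32", "float16"], rounds=5)
for k, dtype, seconds in rows:
    print(f"K={k:<6} dtype={dtype:<8} median {seconds * 1e6:.1f} us")
```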

Ved235 commented Jun 20, 2026

Author

I have added a dtype check for complex64. Apart from that, here is a benchmark table for this implementation:

Size        MLX native median (ms)   NumPy median (ms)
1           0.163                    0.000
16          0.146                    0.000
256         0.143                    0.000
256         0.153                    0.000
4096        0.151                    0.001
16384       0.156                    0.001
131072      0.144                    0.003
1000000     0.188                    0.011
50000000    1.755                    2.611

I tried comparing this implementation with the old GEMV path, but in my measurements the results were highly variable. If required, I think a threshold of around K >= 500000 would be good enough to separate small vectors from large ones.
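
The routing condition discussed in this thread could be summarized as a single predicate. The sketch below is plain Python, not the C++ in Matmul::eval_gpu; use_dot_kernel, DOT_DTYPES, and min_k are hypothetical names. The dtype set follows the review note that dot.metal instantiates float32, float16, and bfloat16, and the default threshold is the 500000 value suggested above, which has not been tuned.

```python
# Dtypes the review says dot.metal instantiates; everything else
# (for example complex64) falls back to the existing GEMV path.
DOT_DTYPES = {"float32", "float16", "bfloat16"}


def use_dot_kernel(M, N, K, batch_size_out, dtype, contiguous, min_k=500_000):
    """Return True if a 1 x K @ K x 1 product should take the dot path.

    Hypothetical predicate for illustration; min_k is an assumed,
    unmeasured threshold.
    """
    return (
        contiguous
        and M == 1
        and N == 1
        and batch_size_out == 1
        and dtype in DOT_DTYPES
        and K >= min_k
    )


print(use_dot_kernel(1, 1, 50_000_000, 1, "float32", True))    # large float32 vector
print(use_dot_kernel(1, 1, 50_000_000, 1, "complex64", True))  # unsupported dtype
print(use_dot_kernel(1, 1, 4096, 1, "float32", True))          # below the threshold
```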


4 participants