Implement batched matmul for large 1D dot products#3580
Conversation
|
@zcbenz would it possible for you to review this? |
zcbenz
left a comment
There was a problem hiding this comment.
This basically looks good to me, thanks!
|
@zcbenz I have made the changes could you review and merge. |
zcbenz
left a comment
There was a problem hiding this comment.
I would like another review from maintainers before merging.
Ok sure, is there someone specific I should tag to request a review? |
|
@angeloskath @jagrit06 it would be great if you could review these changes |
|
@zcbenz Its been several weeks since the changes and I believe that the changes are not very large in terms of number of lines, so would it be possible for this to be merged? |
|
Sorry WWDC had largely disrupted our schedule, there is no wrong with this PR it is just I need another view on this since I lack the background knowledge. WWDC is over now and there is a large backlog so please give us more time. |
|
No worries and thanks for the update. |
angeloskath
left a comment
There was a problem hiding this comment.
@Ved235 sorry for the super late reply, especially since it will be a negative response.
This is something we need to fix indeed but unfortunately this is not the way to fix it. Basically the ops should not really change based on the shape but the implementation should. The same way that we route to split-k kernel when the matrix K dimension is large.
I am gonna mark this as requested changes so we don't merge it by accident and make a new PR with a specialized kernel for this particular case. I am not sure how important it is but I think the kernel will end up being simple enough.
|
Should I make a PR for this specialised kernel? I have a solution ready |
|
@angeloskath I have implemented a specialised kernel. The performance benchmarks show the same improvement. Is this what you were expecting? |
angeloskath
left a comment
There was a problem hiding this comment.
It's nearly there.
I think the kernel needs a bit of restructuring but it is quite close to what is needed. I left comments inline.
|
Thanks for pushing this forward. I took a pass over the current specialized I could not benchmark this PR branch directly yet because my local source build Two things look worth checking before merge:
The new
The current route sends all contiguous On the current PyPI wheel, I see this baseline on GPU: So the large-vector problem is very real, but the small-K path probably needs A benchmark table over K and dtype would make the PR much easier to review:
The main thing to rule out is that broad routing trades the big-vector win for |
|
I have added a dtype checking for complex64. Apart from that here is a benchmark table for this implementation:
I tried comparing this implementation with the old GEMV path and according to me the results are highly variable. If required I think a condition around |
Proposed changes
Addresses issue #3533. Adds routing logic in
mlx/ops.cppso that it divides the large 1D dot product into chunks so gemv parallelizes.Benchmark
Using this benchmarking script the performance changes are:
to
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes