[codex] Remove Vulkan host sync fallbacks #57
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7835acef9b
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
  layout(binding = 0) readonly buffer Keys { uint keys[]; };
- layout(binding = 1) writeonly buffer Out { uint out_data[]; };
+ layout(scalar, binding = 1) writeonly buffer Out { uint8_t out_data[]; };
Restore the 8-bit storage extension for random bits
This shader now declares the output as uint8_t out_data[] in a storage buffer, but it only enables scalar block layout; storage-buffer use of uint8_t requires the 8-bit storage extension. Other Vulkan shaders in this tree that put uint8_t in buffers enable GL_EXT_shader_8bit_storage, so RandomBits will fail shader compilation/pipeline creation instead of producing bits for random::bits/uniform.
(reduce_type == Scatter::None || reduce_type == Scatter::Sum ||
 reduce_type == Scatter::Prod || reduce_type == Scatter::Max ||
 reduce_type == Scatter::Min)) {
Keep bf16 and complex reductions off generic scatter
This newly routes scatter_prod/scatter_max/scatter_min through the generic shader for every dtype, but that shader operates on dtype_to_glsl_storage_type: bfloat16 is uint16_t, and complex64 prod is just vec2 *= vec2, so the reductions are performed on encoded bits/component-wise values rather than MLX dtype values. For example, a bfloat16 scatter_prod with an update of 2 multiplies the 0x3f80/0x4000 encodings instead of 1 * 2, producing corrupt results even without duplicate indices.
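The bfloat16 case can be reproduced on the host. The sketch below is illustrative only: the helper names are hypothetical and not part of the backend, and it assumes a bfloat16 value is stored as the upper 16 bits of an IEEE float32. It contrasts multiplying the storage bits with multiplying the decoded values.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Decode a bfloat16 bit pattern (assumed: upper 16 bits of a float32).
inline float bf16_to_float(uint16_t bits) {
  uint32_t wide = static_cast<uint32_t>(bits) << 16;
  float value;
  std::memcpy(&value, &wide, sizeof(value));
  return value;
}

// Encode by truncation. This is enough for the exactly representable
// values used here; a real conversion would round to nearest even.
inline uint16_t float_to_bf16_trunc(float value) {
  uint32_t wide;
  std::memcpy(&wide, &value, sizeof(wide));
  return static_cast<uint16_t>(wide >> 16);
}

// What a shader computes if it treats the uint16_t storage type as the value.
inline uint16_t prod_on_encoded_bits(uint16_t a, uint16_t b) {
  return static_cast<uint16_t>(static_cast<uint32_t>(a) * b);
}

// What the reduction is meant to compute: decode, multiply, re-encode.
inline uint16_t prod_on_values(uint16_t a, uint16_t b) {
  return float_to_bf16_trunc(bf16_to_float(a) * bf16_to_float(b));
}
```

With 0x3f80 (1.0) and 0x4000 (2.0), the value-based product re-encodes to 0x4000, while the product of the raw encodings truncates to a different 16-bit pattern.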
} else if (use_float_atomic_cas) {
  os << "\n atomic_reduce(dst_offset, read_update(linear_idx));\n";
Make non-float scatter reductions atomic before enabling them
The CAS path added here only protects float32; the newly enabled integer/half Prod/Max/Min cases below still do a plain read-modify-write to out_data[dst_offset]. When two updates target the same destination, parallel invocations race and can drop one reduction, e.g. int32 scatter_prod with indices [0, 0] and updates [2, 3] can leave 2 or 3 instead of 6.
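For reference, the usual shape of a compare-and-swap reduction is shown below as a host-side C++ analog. This is a sketch, not the shader code: the function here is a hypothetical template built on `std::atomic`, and a GLSL version would use the shader's own atomic compare-exchange on the destination word instead.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Apply `op(current, update)` to `dst` atomically with a CAS retry loop.
// Returns the value that was stored before this update was applied.
template <typename T, typename Op>
T atomic_reduce(std::atomic<T>& dst, T update, Op op) {
  T expected = dst.load(std::memory_order_relaxed);
  // On failure, compare_exchange_weak refreshes `expected` with the current
  // value, so the next attempt recomputes the reduction from fresh data.
  while (!dst.compare_exchange_weak(
      expected, op(expected, update), std::memory_order_relaxed)) {
  }
  return expected;
}
```

Because every update is retried until it lands on an unmodified value, two updates of 2 and 3 to a destination holding 1 always leave 6, regardless of interleaving.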
💡 Codex Review
mlx/mlx/backend/vulkan/arange.cpp
Lines 60 to 66 in 7a95237
When out.dtype() is float16 or bfloat16 with integral start/step, this now always falls through to the shader path; arange.comp computes start_f32 + step_f32 * float(i) and then casts, while the previous Vulkan path intentionally advanced in the target dtype each step to match CPU semantics. For example, a bfloat16 arange with step 1 diverges once the running bfloat16 value reaches 256, so Vulkan will silently produce different values from CPU for larger low-precision ranges.
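The divergence can be checked with a small host-side model. The sketch below is an assumption-laden illustration: `round_to_precision` is a hypothetical helper that rounds a float to a given number of significand bits (8 for bfloat16, 11 for float16) with round-to-nearest-even, and it ignores overflow, subnormals, and NaN.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Round x to `precision` significand bits using the default rounding mode
// (round to nearest, ties to even). Not a full dtype conversion.
inline float round_to_precision(float x, int precision) {
  if (x == 0.0f) return x;
  int exponent = 0;
  float mantissa = std::frexp(x, &exponent);  // |mantissa| in [0.5, 1)
  float scaled = std::ldexp(mantissa, precision);
  return std::ldexp(std::nearbyint(scaled), exponent - precision);
}

// Advance in the target dtype: round after every addition.
inline std::vector<float> arange_sequential(
    float start, float step, int n, int precision) {
  std::vector<float> out(n);
  float value = round_to_precision(start, precision);
  float rounded_step = round_to_precision(step, precision);
  for (int i = 0; i < n; ++i) {
    out[i] = value;
    value = round_to_precision(value + rounded_step, precision);
  }
  return out;
}

// Compute start + step * i in float32, then cast once.
inline std::vector<float> arange_fp32_formula(
    float start, float step, int n, int precision) {
  std::vector<float> out(n);
  for (int i = 0; i < n; ++i) {
    out[i] = round_to_precision(start + step * static_cast<float>(i), precision);
  }
  return out;
}
```

With 8 significand bits, start 0 and step 1, the sequential version stops advancing at 256 because 256 + 1 rounds back to 256, while the float32 formula keeps producing larger values.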
{static_cast<uint32_t>(num_keys),
 static_cast<uint32_t>(half_size + odd),
 1});
Dispatch random bits by workgroups, not keys
The explicit grid passed here is used directly as vkCmdDispatch workgroup counts in dispatch_with_spec, but random_bits.comp declares local_size_x = 256. Passing num_keys therefore launches 256 * num_keys x-invocations; batched random generation wastes 255 extra lanes per key and can exceed the device's max x workgroup count even when ceil(num_keys / 256) would be valid.
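The conversion from invocations to workgroups is a ceiling division. A minimal sketch, assuming the local size of 256 quoted above; the names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Local size in x, as quoted in the review for random_bits.comp.
constexpr uint32_t kLocalSizeX = 256;

// Number of workgroups needed to cover `invocations` threads.
// Written without `invocations + local_size - 1` to avoid uint32 overflow.
constexpr uint32_t workgroups_for(uint32_t invocations, uint32_t local_size) {
  return invocations / local_size + (invocations % local_size != 0 ? 1u : 0u);
}
```

For example, 257 keys need 2 workgroups in x rather than 257.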
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c5875165ef
for (uint j = 0; j < i; ++j) {
  value = float16_t(value + step);
Avoid quadratic work in low-precision arange
When float16/bfloat16 arange has integral start and step, KY is set and every invocation recomputes its element by looping from 0 to i (the bf16 branch does the same just below). Large ranges therefore do roughly n² additions in the shader; for example a 1M-element half arange requires about 5e11 additions and can easily hang or time out the GPU, replacing the previous linear host loop with a much worse fallback.
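The addition count follows from invocation i running its loop i times. A short sketch of that arithmetic, with hypothetical function names:

```cpp
#include <cassert>
#include <cstdint>

// Invocation i performs i additions, so n invocations perform
// 0 + 1 + ... + (n - 1) = n * (n - 1) / 2 additions in total.
constexpr uint64_t per_invocation_loop_additions(uint64_t n) {
  return n == 0 ? 0 : n * (n - 1) / 2;
}

// A single sequential pass needs one addition per element after the first.
constexpr uint64_t single_pass_additions(uint64_t n) {
  return n == 0 ? 0 : n - 1;
}
```

For one million elements the per-invocation loop costs 499,999,500,000 additions, against 999,999 for a single sequential pass.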
    reduce_type)) {
  return true;
}
return false;
Let pair scatter-add fall through to static shaders
For two-axis scatter_add with float16 or bfloat16 values and composed updates, this branch is entered but try_dispatch_generic_scatter returns false because dynamic sum reductions only allow float32/int32/uint32. The new unconditional return then skips the existing scatter_sum_pair_shader_id path below, which has f16/bf16 pair shaders and alignment checks, so valid Vulkan pair scatter-adds now fail instead of using the supported GPU shader.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9e4ecb7dd9
num_elements <= kMaxSequentialLowPrecisionArange &&
start == std::trunc(start) && step == std::trunc(step);
Preserve sequential fp16/bf16 arange past the cutoff
When float16/bfloat16 arange has integral start/step and more than 4096 elements, this leaves KY disabled and the shader falls back to start + step * i in fp32 before casting. CPU arange, and the previous Vulkan fallback, advance in the target dtype; for example arange(0, 4097, 1, float16) keeps rounding after 2048 while the fp32 formula writes values around 4096 at the tail. This makes larger low-precision ranges silently diverge from CPU/Vulkan's intended semantics rather than just avoiding the quadratic shader path.
if (out_slice_shape == slice_sizes) {
  std::vector<array> generic_inputs = {src_input, idx};
  std::vector<int> norm_axes = {axis};
  if (try_dispatch_generic_gather(
Keep narrow indices out of generic gather
When a take-like gather uses int8/int16/uint8/uint16 indices, gather_shader_id is empty, so this new fallback dispatches the generic shader. That shader maps every non-int64/uint64/uint32 index dtype to int and does not enable 8/16-bit index storage, so an int16 index buffer is read as 32-bit entries; e.g. take(x, array([1, 2], int16), axis) can gather from packed pairs/out-of-bounds instead of elements 1 and 2. Either reject these index widths here or emit the correct index storage type.
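The misread is easy to model on the host. The sketch below assumes a tightly packed little-endian index buffer; the helpers are hypothetical and only build and read a byte array, so the result does not depend on the host's endianness.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Lay out int16 indices as packed little-endian bytes.
inline std::vector<uint8_t> pack_int16_le(const std::vector<int16_t>& indices) {
  std::vector<uint8_t> bytes;
  for (int16_t v : indices) {
    uint16_t u = static_cast<uint16_t>(v);
    bytes.push_back(static_cast<uint8_t>(u & 0xFF));
    bytes.push_back(static_cast<uint8_t>(u >> 8));
  }
  return bytes;
}

// Correct read: element i occupies 2 bytes.
inline int16_t read_int16_le(const std::vector<uint8_t>& b, std::size_t i) {
  uint16_t u = static_cast<uint16_t>(b[2 * i] | (b[2 * i + 1] << 8));
  return static_cast<int16_t>(u);
}

// Wrong read: the same buffer declared as an array of 32-bit ints.
inline int32_t read_int32_le(const std::vector<uint8_t>& b, std::size_t i) {
  uint32_t u = static_cast<uint32_t>(b[4 * i]) |
      (static_cast<uint32_t>(b[4 * i + 1]) << 8) |
      (static_cast<uint32_t>(b[4 * i + 2]) << 16) |
      (static_cast<uint32_t>(b[4 * i + 3]) << 24);
  return static_cast<int32_t>(u);
}
```

For the indices [1, 2], the 16-bit read returns 1 and 2, while the 32-bit read sees a single entry 0x00020001 (131073), and a second 32-bit entry would lie past the end of the buffer.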
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: fb5d93b5ce
{static_cast<uint32_t>((num_keys + 255) / 256),
 static_cast<uint32_t>(half_size + odd),
 1});
Split random-bit work over valid dispatch dimensions
The x-count issue is fixed, but this still passes half_size + odd directly as vkCmdDispatch's Y workgroup count. For a single large random array, e.g. random.uniform({1000000}), half_size + odd is about 500k, which exceeds the common Vulkan maxComputeWorkGroupCount[1] limit of 65,535 and makes the command invalid instead of generating random bits; tile this dimension or spill into Z the way other large dispatches do.
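One way to spill an oversized Y count into Z is sketched below. It is not taken from this repository: the struct and function are hypothetical, and it assumes the shader rebuilds a linear index as `gl_WorkGroupID.z * y + gl_WorkGroupID.y` and skips indices at or beyond the real total.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

struct GridYZ {
  uint32_t y;
  uint32_t z;
};

// Split `total` workgroups over Y and Z so that y <= max_y and y * z >= total.
// The caller still has to compare z against the device's Z limit.
inline GridYZ split_over_yz(uint32_t total, uint32_t max_y) {
  uint32_t y = std::min(total, max_y);
  if (y == 0) {
    return {0, 1};  // nothing to dispatch
  }
  uint32_t z = total / y + (total % y != 0 ? 1u : 0u);
  return {y, z};
}
```

With a limit of 65,535, a count of 500,000 becomes a 65,535 by 8 grid; the surplus workgroups in the last Z slice rely on the assumed bounds check to exit early.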
(reduce_type == Scatter::None || reduce_type == Scatter::Sum ||
 reduce_type == Scatter::Prod || reduce_type == Scatter::Max ||
 reduce_type == Scatter::Min)) {
Reject narrow indices before generic scatter reductions
By enabling scatter_prod/scatter_max/scatter_min to use the generic path here, narrow index arrays (int8/int16/uint8/uint16, which CPU scatter accepts) now produce wrong destinations: try_dispatch_generic_scatter does not reject those dtypes, and the generated shader maps every non-uint32/int64/uint64 index to int, so it reads packed 8/16-bit index buffers as 32-bit entries. Please keep these reductions off the generic shader unless the index storage type is emitted correctly.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a304c37f30
case Scatter::Prod:
case Scatter::Max:
case Scatter::Min:
  return dtype == float32 || dtype == int32 || dtype == uint32;
Preserve non-generic scatter reduction fallback
Fresh evidence in the current code is that this dtype guard now rejects float16/bfloat16/complex Prod/Max/Min, and the composed single/pair callers return false for non-sum reductions when generic dispatch fails. That means previously supported Vulkan cases that went through the SliceUpdate-based path, such as a one-axis bfloat16 scatter_prod with normal composed updates, now throw as unsupported instead of producing a result.
command_buffer,
stream(),
push_constants,
{static_cast<uint32_t>((num_keys + 255) / 256),
Check X dispatch limit for random bits
This dispatch now tiles the output words over Y/Z, but the X workgroup count is still passed directly from (num_keys + 255) / 256 without comparing it to limits.maxComputeWorkGroupCount[0]. When RandomBits is vmapped or otherwise receives a large batch of keys, the command can exceed the device's X dispatch limit and become invalid even though the Y/Z dimensions were guarded.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9d129bb8d7
if (!keys.flags().contiguous || keys.offset() != 0 ||
    keys.strides().back() != 1) {
  keys = contiguous_copy_gpu(keys, stream());
}
Copy host random keys before dispatch
When the key is the usual host array from random::key/the default KeySequence, it is already contiguous with offset 0, so this condition skips contiguous_copy_gpu and the later dispatch_random_bits_op binds a non-Vulkan buffer. make_buffer_info then throws Missing Vulkan buffer for keys, so ordinary Vulkan random generation with a CPU key now fails instead of using the previous host key path; also check vulkan::is_vulkan_storage_array(keys) before deciding no copy is needed.
    make_arange_push_constants_t<float16_t>(num_elements, start, step);
const bool use_sequential_low_precision =
    start == std::trunc(start) && step == std::trunc(step);
push_constants.KY = use_sequential_low_precision ? 1 : 0;
Preserve rounded step in low-precision arange
Enabling the sequential shader path here makes f16/bf16 arange use p.step_f32, but make_arange_push_constants_t derives that from round(start + step) - round(start) instead of the separately rounded step used by the removed CPU fallback. For integral starts that are not exactly representable, e.g. arange(2049, 2053, 1, float16), this sends a step of 2 and writes 2048, 2050, ... while the intended sequential low-precision update adds the rounded step value 1 each iteration and rounds after each addition.
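The two step derivations can be compared with a small host-side model of float16 rounding. This is a sketch under stated assumptions: `round_to_precision` is a hypothetical helper that rounds to 11 significand bits with round-to-nearest-even and ignores overflow, subnormals, and NaN; the other two functions only mirror the formulas described in this comment.

```cpp
#include <cassert>
#include <cmath>

// Round x to `precision` significand bits (11 models float16) using the
// default rounding mode (round to nearest, ties to even).
inline float round_to_precision(float x, int precision) {
  if (x == 0.0f) return x;
  int exponent = 0;
  float mantissa = std::frexp(x, &exponent);  // |mantissa| in [0.5, 1)
  float scaled = std::ldexp(mantissa, precision);
  return std::ldexp(std::nearbyint(scaled), exponent - precision);
}

// Step derived as round(start + step) - round(start).
inline float step_from_rounded_endpoints(float start, float step, int precision) {
  return round_to_precision(start + step, precision) -
      round_to_precision(start, precision);
}

// Step rounded on its own, independent of start.
inline float separately_rounded_step(float step, int precision) {
  return round_to_precision(step, precision);
}
```

For start 2049 and step 1 in float16, the start rounds to 2048, the endpoint difference gives a step of 2, and the separately rounded step stays 1.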
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9ac80fd6a2
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_shader_8bit_storage : require
Move random_bits extensions before the include
In this shader, types.glsl is included before these new #extension directives, and that include expands to declarations such as structs using typed fields before the directives are seen. GLSL/shaderc requires extension directives to appear before any non-preprocessor tokens, so regenerating or compiling random_bits.comp can reject GL_EXT_scalar_block_layout/GL_EXT_shader_8bit_storage, breaking Vulkan RandomBits dispatch. Put the new extensions before #include "types.glsl".
@codex review
You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard.
@codex review
You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard.
Summary
Validation
- ./dev.sh build
- ./dev.sh test-cpp - 249 tests, 3314 assertions
- ./dev.sh test-py - 699 passed, 13 skipped
- ./dev.sh generate - coherent output
- ./dev.sh model-report - all 10 models generated coherent output
- ./dev.sh update-benchmark --skip-model-report
Benchmark delta from results.csv
Compared with the previous latest rows from 2026-06-18T15:23:39Z: