Skip to content

[codex] Remove Vulkan host sync fallbacks#57

Open
goniz wants to merge 9 commits into
feat/vulkanfrom
codex/vulkan-native-pipeline-fixes
Open

[codex] Remove Vulkan host sync fallbacks#57
goniz wants to merge 9 commits into
feat/vulkanfrom
codex/vulkan-native-pipeline-fixes

Conversation

@goniz

@goniz goniz commented Jun 19, 2026

Copy link
Copy Markdown
Owner

Summary

  • replace RandomBits host readback/generation with a Vulkan shader dispatch
  • remove gather/scatter host-loop fallbacks and route supported composed cases through GPU dispatch
  • keep scalar/singleton gather fast paths limited to CPU-resident constants
  • remove the low-precision arange CPU materialization path

Validation

  • ./dev.sh build
  • ./dev.sh test-cpp - 249 tests, 3314 assertions
  • ./dev.sh test-py - 699 passed, 13 skipped
  • ./dev.sh generate - coherent output
  • ./dev.sh model-report - all 10 models generated coherent output
  • ./dev.sh update-benchmark --skip-model-report

Benchmark delta from results.csv

Compared with the previous latest rows from 2026-06-18T15:23:39Z:

Model Prompt TPS Generation TPS Peak Mem
Qwen3-0.6B-bf16 2860.080 -> 2862.407 (+0.08%) 65.950 -> 65.486 (-0.70%) 2.614 -> 2.614
Qwen3-0.6B-8bit 1464.916 -> 1458.232 (-0.46%) 87.609 -> 85.857 (-2.00%) 2.056 -> 2.055
Qwen3.6-35B-A3B-8bit 120.009 -> 120.704 (+0.58%) 21.288 -> 20.996 (-1.37%) 40.350 -> 40.349

@goniz goniz marked this pull request as ready for review June 19, 2026 15:46
@goniz

goniz commented Jun 19, 2026

Copy link
Copy Markdown
Owner Author

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7835acef9b

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".


layout(binding = 0) readonly buffer Keys { uint keys[]; };
layout(binding = 1) writeonly buffer Out { uint out_data[]; };
layout(scalar, binding = 1) writeonly buffer Out { uint8_t out_data[]; };

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore the 8-bit storage extension for random bits

This shader now declares the output as uint8_t out_data[] in a storage buffer, but it only enables scalar block layout; storage-buffer use of uint8_t requires the 8-bit storage extension. Other Vulkan shaders in this tree that put uint8_t in buffers enable GL_EXT_shader_8bit_storage, so RandomBits will fail shader compilation/pipeline creation instead of producing bits for random::bits/uniform.

Useful? React with 👍 / 👎.

Comment on lines +688 to +690
(reduce_type == Scatter::None || reduce_type == Scatter::Sum ||
reduce_type == Scatter::Prod || reduce_type == Scatter::Max ||
reduce_type == Scatter::Min)) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep bf16 and complex reductions off generic scatter

This newly routes scatter_prod/scatter_max/scatter_min through the generic shader for every dtype, but that shader operates on dtype_to_glsl_storage_type: bfloat16 is uint16_t, and complex64 prod is just vec2 *= vec2, so the reductions are performed on encoded bits/component-wise values rather than MLX dtype values. For example, a bfloat16 scatter_prod with an update of 2 multiplies the 0x3f80/0x4000 encodings instead of 1 * 2, producing corrupt results even without duplicate indices.

Useful? React with 👍 / 👎.

Comment on lines +203 to +204
} else if (use_float_atomic_cas) {
os << "\n atomic_reduce(dst_offset, read_update(linear_idx));\n";

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Make non-float scatter reductions atomic before enabling them

The CAS path added here only protects float32; the newly enabled integer/half Prod/Max/Min cases below still do a plain read-modify-write to out_data[dst_offset]. When two updates target the same destination, parallel invocations race and can drop one reduction, e.g. int32 scatter_prod with indices [0, 0] and updates [2, 3] can leave 2 or 3 instead of 6.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

vulkan::dispatch_arange_op(
out,
*shader_id,
command_buffer,
s,
start,
step);

P2 Badge Preserve sequential low-precision arange semantics

When out.dtype() is float16 or bfloat16 with integral start/step, this now always falls through to the shader path; arange.comp computes start_f32 + step_f32 * float(i) and then casts, while the previous Vulkan path intentionally advanced in the target dtype each step to match CPU semantics. For example, a bfloat16 arange with step 1 diverges once the running bfloat16 value reaches 256, so Vulkan will silently produce different values from CPU for larger low-precision ranges.

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread mlx/backend/vulkan/random.cpp Outdated
Comment on lines +62 to +64
{static_cast<uint32_t>(num_keys),
static_cast<uint32_t>(half_size + odd),
1});

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Dispatch random bits by workgroups, not keys

The explicit grid passed here is used directly as vkCmdDispatch workgroup counts in dispatch_with_spec, but random_bits.comp declares local_size_x = 256. Passing num_keys therefore launches 256 * num_keys x-invocations; batched random generation wastes 255 extra lanes per key and can exceed the device's max x workgroup count even when ceil(num_keys / 256) would be valid.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c5875165ef

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread mlx/backend/vulkan/kernels/arange.comp Outdated
Comment on lines +30 to +31
for (uint j = 0; j < i; ++j) {
value = float16_t(value + step);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid quadratic work in low-precision arange

When float16/bfloat16 arange has integral start and step, KY is set and every invocation recomputes its element by looping from 0 to i (the bf16 branch does the same just below). Large ranges therefore do roughly n² additions in the shader; for example a 1M-element half arange requires about 5e11 additions and can easily hang or time out the GPU, replacing the previous linear host loop with a much worse fallback.

Useful? React with 👍 / 👎.

Comment thread mlx/backend/vulkan/scatter.cpp Outdated
reduce_type)) {
return true;
}
return false;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Let pair scatter-add fall through to static shaders

For two-axis scatter_add with float16 or bfloat16 values and composed updates, this branch is entered but try_dispatch_generic_scatter returns false because dynamic sum reductions only allow float32/int32/uint32. The new unconditional return then skips the existing scatter_sum_pair_shader_id path below, which has f16/bf16 pair shaders and alignment checks, so valid Vulkan pair scatter-adds now fail instead of using the supported GPU shader.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9e4ecb7dd9

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread mlx/backend/vulkan/kernels.cpp Outdated
Comment on lines +832 to +833
num_elements <= kMaxSequentialLowPrecisionArange &&
start == std::trunc(start) && step == std::trunc(step);

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve sequential fp16/bf16 arange past the cutoff

When float16/bfloat16 arange has integral start/step and more than 4096 elements, this leaves KY disabled and the shader falls back to start + step * i in fp32 before casting. CPU arange, and the previous Vulkan fallback, advance in the target dtype; for example arange(0, 4097, 1, float16) keeps rounding after 2048 while the fp32 formula writes values around 4096 at the tail. This makes larger low-precision ranges silently diverge from CPU/Vulkan's intended semantics rather than just avoiding the quadratic shader path.

Useful? React with 👍 / 👎.

if (out_slice_shape == slice_sizes) {
std::vector<array> generic_inputs = {src_input, idx};
std::vector<int> norm_axes = {axis};
if (try_dispatch_generic_gather(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep narrow indices out of generic gather

When a take-like gather uses int8/int16/uint8/uint16 indices, gather_shader_id is empty, so this new fallback dispatches the generic shader. That shader maps every non-int64/uint64/uint32 index dtype to int and does not enable 8/16-bit index storage, so an int16 index buffer is read as 32-bit entries; e.g. take(x, array([1, 2], int16), axis) can gather from packed pairs/out-of-bounds instead of elements 1 and 2. Either reject these index widths here or emit the correct index storage type.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: fb5d93b5ce

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread mlx/backend/vulkan/random.cpp Outdated
Comment on lines +62 to +64
{static_cast<uint32_t>((num_keys + 255) / 256),
static_cast<uint32_t>(half_size + odd),
1});

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Split random-bit work over valid dispatch dimensions

The x-count issue is fixed, but this still passes half_size + odd directly as vkCmdDispatch's Y workgroup count. For a single large random array, e.g. random.uniform({1000000}), half_size + odd is about 500k, which exceeds the common Vulkan maxComputeWorkGroupCount[1] limit of 65,535 and makes the command invalid instead of generating random bits; tile this dimension or spill into Z the way other large dispatches do.

Useful? React with 👍 / 👎.

Comment on lines +727 to +729
(reduce_type == Scatter::None || reduce_type == Scatter::Sum ||
reduce_type == Scatter::Prod || reduce_type == Scatter::Max ||
reduce_type == Scatter::Min)) {

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Reject narrow indices before generic scatter reductions

By enabling scatter_prod/scatter_max/scatter_min to use the generic path here, narrow index arrays (int8/int16/uint8/uint16, which CPU scatter accepts) now produce wrong destinations: try_dispatch_generic_scatter does not reject those dtypes, and the generated shader maps every non-uint32/int64/uint64 index to int, so it reads packed 8/16-bit index buffers as 32-bit entries. Please keep these reductions off the generic shader unless the index storage type is emitted correctly.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a304c37f30

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

case Scatter::Prod:
case Scatter::Max:
case Scatter::Min:
return dtype == float32 || dtype == int32 || dtype == uint32;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve non-generic scatter reduction fallback

Fresh evidence in the current code is that this dtype guard now rejects float16/bfloat16/complex Prod/Max/Min, and the composed single/pair callers return false for non-sum reductions when generic dispatch fails. That means previously supported Vulkan cases that went through the SliceUpdate-based path, such as a one-axis bfloat16 scatter_prod with normal composed updates, now throw as unsupported instead of producing a result.

Useful? React with 👍 / 👎.

Comment thread mlx/backend/vulkan/random.cpp Outdated
command_buffer,
stream(),
push_constants,
{static_cast<uint32_t>((num_keys + 255) / 256),

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Check X dispatch limit for random bits

This dispatch now tiles the output words over Y/Z, but the X workgroup count is still passed directly from (num_keys + 255) / 256 without comparing it to limits.maxComputeWorkGroupCount[0]. When RandomBits is vmapped or otherwise receives a large batch of keys, the command can exceed the device's X dispatch limit and become invalid even though the Y/Z dimensions were guarded.

Useful? React with 👍 / 👎.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9d129bb8d7

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread mlx/backend/vulkan/random.cpp Outdated
Comment on lines 34 to 37
if (!keys.flags().contiguous || keys.offset() != 0 ||
keys.strides().back() != 1) {
keys = contiguous_copy_gpu(keys, stream());
}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Copy host random keys before dispatch

When the key is the usual host array from random::key/the default KeySequence, it is already contiguous with offset 0, so this condition skips contiguous_copy_gpu and the later dispatch_random_bits_op binds a non-Vulkan buffer. make_buffer_info then throws Missing Vulkan buffer for keys, so ordinary Vulkan random generation with a CPU key now fails instead of using the previous host key path; also check vulkan::is_vulkan_storage_array(keys) before deciding no copy is needed.

Useful? React with 👍 / 👎.

make_arange_push_constants_t<float16_t>(num_elements, start, step);
const bool use_sequential_low_precision =
start == std::trunc(start) && step == std::trunc(step);
push_constants.KY = use_sequential_low_precision ? 1 : 0;

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve rounded step in low-precision arange

Enabling the sequential shader path here makes f16/bf16 arange use p.step_f32, but make_arange_push_constants_t derives that from round(start + step) - round(start) instead of the separately rounded step used by the removed CPU fallback. For integral starts that are not exactly representable, e.g. arange(2049, 2053, 1, float16), this sends a step of 2 and writes 2048, 2050, ... while the intended sequential low-precision update adds the rounded step value 1 each iteration and rounds after each addition.

Useful? React with 👍 / 👎.

@goniz

goniz commented Jun 21, 2026

Copy link
Copy Markdown
Owner Author

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9ac80fd6a2

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines +5 to +6
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_shader_8bit_storage : require

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Move random_bits extensions before the include

In this shader, types.glsl is included before these new #extension directives, and that include expands to declarations such as structs using typed fields before the directives are seen. GLSL/shaderc requires extension directives to appear before any non-preprocessor tokens, so regenerating or compiling random_bits.comp can reject GL_EXT_scalar_block_layout/GL_EXT_shader_8bit_storage, breaking Vulkan RandomBits dispatch. Put the new extensions before #include "types.glsl".

Useful? React with 👍 / 👎.

@goniz

goniz commented Jun 21, 2026

Copy link
Copy Markdown
Owner Author

@codex review

@chatgpt-codex-connector

Copy link
Copy Markdown

You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard.

@goniz

goniz commented Jun 21, 2026

Copy link
Copy Markdown
Owner Author

@codex review

@chatgpt-codex-connector

Copy link
Copy Markdown

You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant