diff --git a/mlx/distributed/jaccl/lib/jaccl/ring_impl.h b/mlx/distributed/jaccl/lib/jaccl/ring_impl.h index 3ee91b3762..766f55386d 100644 --- a/mlx/distributed/jaccl/lib/jaccl/ring_impl.h +++ b/mlx/distributed/jaccl/lib/jaccl/ring_impl.h @@ -496,7 +496,7 @@ class RingImpl { // Prefill the pipeline for (int lw = 0; lw < n_wires; lw++) { int buff = 0; - while (N * buff < limits[lw] && buff < PIPELINE) { + while (N * buff < limits[lw] - write_offset[lw] && buff < PIPELINE) { recv_from(sz, buff, dir, lw); buff++;