fix(mooncake): use separate magic_recv buffer to prevent weight corruption#6813
fix(mooncake): use separate magic_recv buffer to prevent weight corruption#6813KunWuLuan wants to merge 2 commits into
Conversation
…ption Problem: MooncakeCheckpointEngine's daisy-chain weight sync produces degenerate inference output (e.g. '!!!!' repeated to max_response_length) when using multi-rank rollout on the same node. In Qwen3.6-27B, embed_tokens.weight first 4 bytes get overwritten with the magic completion marker [0xAB, 0xDC, 0xEF, 0x88], causing astronomical embedding values that saturate all downstream attention computations. Observation: After receiving a bucket, the receiver writes a 4-byte magic marker to the sender's DATA buffer as a completion signal via transfer_sync_write. Instrumentation shows the data buffer gets corrupted with magic bytes between the RDMA read and the subsequent usage, and the corruption propagates through the daisy chain to downstream ranks. The exact mechanism (whether transfer_sync_write has a local GPU side effect on intra-node RDMA, or another pathway) is still under investigation and not yet confirmed. Fix: - Add dedicated magic_recv buffer (8 bytes, one 4-byte slot per double-buffer) - Register magic_recv with TransferEngine for RDMA access - Send magic_ptr in info dict alongside data ptr - Write magic completion signal to magic_ptr instead of data buffer ptr - Regardless of the corruption mechanism, writing to a dedicated buffer ensures any side effects only modify magic_recv (harmless) instead of the data buffer (fatal) - Reset magic slot to 0 after detection for reuse Verification: - With fix, full training completes successfully - response_length/mean: 506.25 (vs 4096 with bug) - Agent trajectories: coherent English text - Checkpoint saved, no crashes
|
|
There was a problem hiding this comment.
Code Review
This pull request introduces a dedicated magic_recv buffer to handle magic completion signals separately from the data buffers, preventing potential data corruption during weight synchronization. However, a critical race condition was identified in wait_for_complete where the asynchronous zeroing of the magic buffer on the GPU can race with incoming RDMA writes from the next rank, potentially leading to a deadlock. Synchronizing the device after resetting the buffer is recommended to resolve this issue.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if torch.equal(buf[:4], magic): | ||
| buf[:4] = 0 # reset for next use | ||
| break |
There was a problem hiding this comment.
Critical Race Condition / Deadlock Risk
There is a critical race condition here that can lead to a permanent deadlock during weight synchronization.
Mechanism of the Bug:
buf[:4] = 0is a GPU operation that is queued asynchronously on the local GPU's default CUDA stream.- Immediately after queuing this operation,
wait_for_completereturns, and the receiver rank sends theinfometadata (containing the pointer to this magic slot) to the next rank viaself.store.send_obj(info, self.rank + 1). - The receiver rank then yields the tensors to the caller, suspending the coroutine. During this suspension, the next rank receives the
infometadata, reads the weights, and performs atransfer_sync_write(RDMA write) to write the magic bytes back to this rank's magic slot. - Because
transfer_sync_writeis an RDMA operation, it writes directly to the GPU memory via PCIe/NIC, completely bypassing the GPU's command processor and CUDA streams. - If the next rank's RDMA write completes before the local GPU has finished executing the queued
buf[:4] = 0kernel on the default stream (which is highly likely if the local GPU is busy or delayed), the local GPU's zeroing kernel will run after the RDMA write, overwriting the newly received magic bytes with0. - In the next iteration,
wait_for_completewill poll this slot forever because the magic bytes were overwritten and will never be sent again, causing a permanent hang.
Solution:
We must call get_torch_device().synchronize() immediately after zeroing the buffer to block the CPU until the zeroing operation is fully completed on the GPU, ensuring the slot is clean before the next rank is notified.
| if torch.equal(buf[:4], magic): | |
| buf[:4] = 0 # reset for next use | |
| break | |
| if torch.equal(buf[:4], magic): | |
| buf[:4] = 0 | |
| get_torch_device().synchronize() | |
| break |
…ve_weights
Add get_torch_device().synchronize() before writing magic completion
signal to prevent sender from reusing the buffer while consumer still
has pending GPU ops. Use info.get("magic_ptr") with fallback to data
ptr for backward compatibility with older senders. Add sglang import
fallback for StatelessProcessGroup.
What does this PR do?
Fix weight corruption in
MooncakeCheckpointEngine's daisy-chain weight sync where the magic completion marker[0xAB, 0xDC, 0xEF, 0x88]overwrites the first 4 bytes of the data buffer, causing degenerate inference output (e.g.!!!!repeated tomax_response_length). Introduces a dedicatedmagic_recvbuffer for completion signals to isolate them from the data path.Checklist Before Starting
[ckpt] fix: use separate magic_recv buffer to prevent weight corruption in Mooncake daisy-chain syncTest
This change cannot be tested by CI as it requires multi-GPU Mooncake RDMA environment. Validated by full training run:
response_length/meannum_turns/mean!!!!repeatedembed_tokensAPI and Usage Example
No API changes. The fix is internal to
MooncakeCheckpointEngine— all existing config and usage remains the same.Design & Code Changes
Problem: After receiving a bucket via RDMA, the receiver writes a 4-byte magic marker to the sender's data buffer as a completion signal. Instrumentation shows the data buffer gets corrupted with magic bytes, and this corruption propagates through the daisy chain to downstream ranks. In Qwen3.6-27B, this overwrites
embed_tokens.weight[0:2]with values like-3.85e+17, saturating all attention computations.Root cause mechanism (whether
transfer_sync_writehas a local GPU side effect on intra-node RDMA, or another pathway) is still under investigation.Fix: Introduce a separate RDMA-registered buffer (
magic_recv, 8 bytes) for completion signals:__init__: Addself.magic_recv = torch.zeros(8, ...)and register it withbatch_register_memorysend_weights: Includemagic_ptr(frommagic_slots) in info dict;wait_for_completechecksmagic_slots[idx]instead of data bufferreceive_weights: Extractmagic_ptrfrom received info; write magic tomagic_ptr(dedicated slot) instead ofptr(data buffer); forwardmagic_ptrto next rankwait_for_complete: Reset magic slot to 0 after detection for reuseRegardless of the corruption mechanism, writing to a dedicated buffer ensures any side effects only modify
magic_recv(harmless) instead of the data buffer (fatal).Checklist Before Submitting
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=alwaysci-requestchannel.recipesubmodule. — N/A, no recipe changes.