Implement a CUDA Batch-Invariant Deterministic LogProb kernel for token-level logprob computation. The same sequence must produce bitwise-identical logprobs regardless of batch size, batch position, padding layout, or other samples in the batch.
The goal is to eliminate the train/inference mismatch caused by nondeterministic reduction order, which can lead to KL spikes and training collapse.
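Reduction order matters because FP32 addition is not associative: the same values summed in a different grouping can round to a different result. The minimal C++ illustration below uses hypothetical function names and shows three values whose FP32 sum depends only on grouping. An atomicAdd-based reduction sums in thread-arrival order, so it is exposed to exactly this effect.

```cpp
#include <cassert>

// FP32 addition is not associative: the same three values summed in a
// different order can round differently. An atomicAdd reduction combines
// values in whatever order threads arrive, so its result is not reproducible.
float sum_left_first(float a, float b, float c) { return (a + b) + c; }
float sum_right_first(float a, float b, float c) { return a + (b + c); }

// Example: (1e8 + -1e8) + 1 == 1, but -1e8 + 1 rounds back to -1e8 in FP32
// (adjacent floats near 1e8 are 8 apart), so 1e8 + (-1e8 + 1) == 0.
```

A deterministic kernel avoids this by fixing the grouping once, so every run and every batch layout performs the same sequence of roundings.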
Requirements
- Do not use atomicAdd or any nondeterministic accumulation path.
- Use a fixed CUDA reduction topology:
  - fixed block/thread partitioning,
  - fixed tree-reduction order,
  - fixed reduction order for max, sum-exp, and final logprob computation.
- The output for the same sequence must be bitwise identical across:
  - different batch sizes,
  - different positions within the batch,
  - different padding/mask layouts,
  - mixed batches containing unrelated sequences.
- Support FP16, BF16, and FP32 inputs; accumulate in FP32 where needed.
- Do not rely on cuDNN or cuBLAS implementations (e.g., cuDNN softmax) whose behavior may change with workspace size, algorithm heuristics, or scheduling.
- Preserve existing logprob semantics, including masking, padding, and ignore-index behavior.
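The requirements above can be modeled on the host to make the invariance argument checkable. The following is a sketch under stated assumptions, not the kernel itself: it assumes one thread block of 256 threads per token row, an ignore index of -100 (PyTorch's cross_entropy default; the existing path's convention may differ), and FP32 accumulation. All names (`block_reduce`, `token_logprob`, `Bf16`, ...) are hypothetical. It reproduces the order in which a CUDA kernel with this topology would combine values: a strided per-thread pass followed by a halving shared-memory tree.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Assumptions (hypothetical): one block of kLanes threads per token row,
// ignore_index == -100. No atomics anywhere: every combine is in a fixed order.
constexpr int kLanes = 256;
constexpr int64_t kIgnoreIndex = -100;

struct Bf16 { uint16_t bits; };

inline float to_float(float x) { return x; }
inline float to_float(Bf16 x) {  // BF16 -> FP32 is exact: bits move to the high half
    uint32_t u = static_cast<uint32_t>(x.bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Fixed topology: lane i accumulates elements i, i + kLanes, i + 2*kLanes, ...
// in increasing order, then lanes are combined by a halving tree
// (stride 128, 64, ..., 1). The order depends only on the vocab size V,
// never on the batch size, the row's position, or other rows.
template <typename T, typename Load, typename Combine>
float block_reduce(const T* row, int64_t V, float init, Load load, Combine combine) {
    float partial[kLanes];
    for (int lane = 0; lane < kLanes; ++lane) {
        float acc = init;
        for (int64_t j = lane; j < V; j += kLanes) acc = combine(acc, load(row[j]));
        partial[lane] = acc;
    }
    for (int stride = kLanes / 2; stride > 0; stride /= 2)
        for (int lane = 0; lane < stride; ++lane)
            partial[lane] = combine(partial[lane], partial[lane + stride]);
    return partial[0];
}

// log_softmax(row)[target] = (x_t - max) - log(sum_j exp(x_j - max)), all in FP32.
template <typename T>
float token_logprob(const T* row, int64_t V, int64_t target, bool valid) {
    if (!valid || target == kIgnoreIndex) return 0.0f;  // masked or ignored token
    auto max_op = [](float a, float b) { return std::fmax(a, b); };
    auto add_op = [](float a, float b) { return a + b; };
    const float m = block_reduce(row, V, -INFINITY, [](T x) { return to_float(x); }, max_op);
    const float s = block_reduce(row, V, 0.0f,
                                 [m](T x) { return std::exp(to_float(x) - m); }, add_op);
    return (to_float(row[target]) - m) - std::log(s);
}

// Batched entry point: row r reads only logits[r*V, r*V + V), so adding,
// removing, or reordering other rows cannot change its bits.
template <typename T>
std::vector<float> batch_logprobs(const std::vector<T>& logits, int64_t V,
                                  const std::vector<int64_t>& targets,
                                  const std::vector<bool>& mask) {
    std::vector<float> out(targets.size());
    for (size_t r = 0; r < targets.size(); ++r)
        out[r] = token_logprob(logits.data() + static_cast<int64_t>(r) * V, V,
                               targets[r], mask[r]);
    return out;
}

// Sequence logprob: sum token logprobs left to right in FP32. Masked tokens
// contribute exactly 0.0f, so padding does not change the sum's bits.
inline float sequence_logprob(const std::vector<float>& tok) {
    float s = 0.0f;
    for (float x : tok) s += x;
    return s;
}
```

Bitwise equality is only expected within one implementation: device `expf`/`logf` and host `std::exp`/`std::log` can differ in the last bits, and nvcc flags such as `--use_fast_math` change device math, so the math functions and build flags should stay fixed.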
Acceptance Criteria
- Add a deterministic CUDA logprob kernel.
- Add batch-invariance tests:
  - place the same sequence in batches of different sizes,
  - move it to different batch positions,
  - mix it with unrelated sequences,
  - verify per-token and sequence logprobs are bitwise identical.
- Add repeatability tests showing identical output across repeated runs.
- Verify the kernel does not use atomicAdd.
- Compare numerics against the existing logprob path within an acceptable tolerance, while prioritizing deterministic stability.
- Leave the test structure ready for future ROCm and Triton parity checks.
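The test criteria above can be organized around a backend registry so that ROCm or Triton entries slot in later. The sketch below assumes a hypothetical `LogprobFn` signature and uses a sequential CPU function as a stand-in for the CUDA kernel; the filler logits, batch sizes, positions, and repeat count are illustrative choices, not values from an existing test suite.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical backend signature: (logits [B x V], V, targets [B]) -> logprobs [B].
using LogprobFn = std::function<std::vector<float>(
    const std::vector<float>&, int64_t, const std::vector<int64_t>&)>;

// Stand-in backend: fixed sequential reduction per row, FP32 accumulation.
// The CUDA kernel (and later ROCm / Triton kernels) would be registered instead.
std::vector<float> reference_backend(const std::vector<float>& logits, int64_t V,
                                     const std::vector<int64_t>& targets) {
    std::vector<float> out(targets.size());
    for (size_t r = 0; r < targets.size(); ++r) {
        const float* row = logits.data() + static_cast<int64_t>(r) * V;
        float m = -INFINITY, s = 0.0f;
        for (int64_t j = 0; j < V; ++j) m = std::fmax(m, row[j]);
        for (int64_t j = 0; j < V; ++j) s += std::exp(row[j] - m);
        out[r] = (row[targets[r]] - m) - std::log(s);
    }
    return out;
}

// Backends under test, keyed by name; add "rocm" / "triton" entries when available.
std::map<std::string, LogprobFn> backends() {
    return {{"reference", LogprobFn(reference_backend)}};
}

// Bitwise comparison: == would treat -0.0f and 0.0f as equal (and NaN != NaN).
bool same_bits(float a, float b) {
    uint32_t ua, ub;
    std::memcpy(&ua, &a, sizeof ua);
    std::memcpy(&ub, &b, sizeof ub);
    return ua == ub;
}

// Place `row` at `pos` in a batch of `batch` rows filled with unrelated logits
// (a deterministic pattern varied by `seed`), run `fn`, return the probe's logprob.
float probe(const LogprobFn& fn, const std::vector<float>& row, int64_t target,
            size_t batch, size_t pos, unsigned seed) {
    std::vector<float> logits(batch * row.size());
    for (size_t i = 0; i < logits.size(); ++i)
        logits[i] = static_cast<float>((i * 2654435761u + seed) % 1000) / 100.0f;
    std::copy(row.begin(), row.end(), logits.begin() + pos * row.size());
    std::vector<int64_t> targets(batch, 0);
    targets[pos] = target;
    return fn(logits, static_cast<int64_t>(row.size()), targets)[pos];
}

// Batch invariance + repeatability: every batch size, position, filler pattern,
// and repeated run must reproduce the bits of the batch-size-1 result.
bool batch_invariant(const LogprobFn& fn, const std::vector<float>& row, int64_t target) {
    const float base = probe(fn, row, target, 1, 0, 0);
    for (size_t batch : {2u, 7u, 64u})
        for (size_t pos : {size_t{0}, batch / 2, batch - 1})
            for (unsigned seed : {1u, 2u})
                for (int rep = 0; rep < 3; ++rep)
                    if (!same_bits(base, probe(fn, row, target, batch, pos, seed)))
                        return false;
    return true;
}

// Numerics vs. an FP64 reference: a tolerance check, separate from the bitwise ones.
double max_abs_err(const LogprobFn& fn, const std::vector<float>& row, int64_t target) {
    double m = -INFINITY, s = 0.0;
    for (float x : row) m = std::fmax(m, static_cast<double>(x));
    for (float x : row) s += std::exp(static_cast<double>(x) - m);
    const double ref = (row[target] - m) - std::log(s);
    return std::fabs(static_cast<double>(probe(fn, row, target, 1, 0, 0)) - ref);
}
```

A test would iterate over `backends()` and assert both `batch_invariant(...)` and a small `max_abs_err(...)`. For the atomicAdd criterion, a complementary check could search the kernel sources for `atomicAdd` and the compiled SASS (for example, via `cuobjdump -sass`) for atomic instructions.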