
Implement Batch-Invariant Deterministic LogProb CUDA Kernel to Eliminate Batch-Size Drift #96

Description

@inaniloquentee

Implement a CUDA Batch-Invariant Deterministic LogProb kernel for token-level logprob computation. The same sequence must produce bitwise-identical logprobs regardless of batch size, its position in the batch, padding layout, or the other samples in the batch.

The goal is to eliminate the train/inference mismatch caused by reduction orders that change with batch shape or between runs. This mismatch shows up as KL spikes and can lead to training collapse.
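To illustrate how batch-size drift arises, here is a minimal sketch in pure Python that emulates FP32 addition. It assumes a hypothetical launch heuristic (`chunks_for_batch`) that splits each row into more partial sums when the batch is small. The row's data is identical in both calls; only the reduction order differs, yet the FP32 result changes. `f32`, `chunked_sum`, and `chunks_for_batch` are illustrative names, not part of any existing codebase.

```python
import struct


def f32(x):
    """Round a double to the nearest float32 (emulates one FP32 operation)."""
    return struct.unpack("f", struct.pack("f", x))[0]


def chunked_sum(row, num_chunks):
    """FP32 sum: each contiguous chunk left to right, then the partials left to right.

    The result depends on num_chunks, i.e. on the reduction order.
    """
    chunk = -(-len(row) // num_chunks)  # ceiling division
    partials = []
    for start in range(0, len(row), chunk):
        acc = 0.0
        for v in row[start:start + chunk]:
            acc = f32(acc + v)
        partials.append(acc)
    total = 0.0
    for p in partials:
        total = f32(total + p)
    return total


def chunks_for_batch(batch_size):
    """Hypothetical heuristic: split rows further when the batch is small."""
    return 1 if batch_size >= 8 else 2


row = [f32(1e8), 1.0, f32(-1e8), 1.0]
print(chunked_sum(row, chunks_for_batch(1)))  # 0.0 (two chunks)
print(chunked_sum(row, chunks_for_batch(8)))  # 1.0 (one chunk)
```

In FP32 the spacing between representable values near 1e8 is 8, so `1e8 + 1` rounds back to `1e8` and the order of additions decides whether the two `1.0` terms survive. Making the partition depend only on the row (for example, a fixed block size over the vocabulary dimension) removes this dependence on the batch.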

Requirements

  • Do not use atomicAdd or any nondeterministic accumulation path.
  • Use a fixed CUDA reduction topology:
    • fixed block/thread partitioning,
    • fixed tree-reduction order,
    • fixed reduction order for max, sum-exp, and final logprob computation.
  • The output for the same sequence must be bitwise identical across:
    • different batch sizes,
    • different positions within the batch,
    • different padding/mask layouts,
    • mixed batches containing unrelated sequences.
  • Inputs may be FP16, BF16, or FP32; accumulate in FP32 where needed.
  • Do not rely on cuDNN softmax or cuBLAS routines whose behavior may change based on workspace size, algorithm heuristics, or scheduling.
  • Preserve existing logprob semantics, including masking, padding, and ignore-index behavior.
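The requirements above can be modeled on the CPU before writing the kernel. The sketch below is a Python reference model, not CUDA, and rests on several assumptions: one thread block per (sequence, position) row, a power-of-two block size `NUM_THREADS`, strided per-thread accumulation followed by a shared-memory style tree, and `-100` as the ignore index (PyTorch's usual default; the project's existing semantics take precedence). All names are hypothetical. It uses Python's double-precision `exp`/`log` rounded to FP32, so it models the reduction topology rather than the exact bits of CUDA's `expf`/`logf`.

```python
import math
import struct

NUM_THREADS = 8      # assumed block size (power of two); a real kernel might use 256
IGNORE_INDEX = -100  # assumed ignore-index value


def f32(x):
    """Round a double to the nearest float32 (emulates one FP32 operation)."""
    return struct.unpack("f", struct.pack("f", x))[0]


def bf16_bits_to_f32(bits):
    """Upcast a BF16 value, given as its 16-bit pattern, to FP32 (exact)."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def block_reduce(values, op, identity):
    """Fixed-topology reduction over one row.

    Phase 1: "thread" t folds indices t, t+T, t+2T, ... in ascending order.
    Phase 2: tree with strides T/2, T/4, ..., 1 (threads t < stride combine t and t+stride).
    The order depends only on len(values) and NUM_THREADS, never on the batch.
    """
    partial = [identity] * NUM_THREADS
    for t in range(NUM_THREADS):
        for i in range(t, len(values), NUM_THREADS):
            partial[t] = op(partial[t], values[i])
    stride = NUM_THREADS // 2
    while stride > 0:
        for t in range(stride):
            partial[t] = op(partial[t], partial[t + stride])
        stride //= 2
    return partial[0]


def row_logprob(logits, target):
    """log p(target) for one row with FP32 accumulation; 0.0 for ignored tokens."""
    if target == IGNORE_INDEX:
        return 0.0
    row = [f32(v) for v in logits]
    row_max = block_reduce(row, max, -math.inf)                  # fixed-order max
    shifted_exp = [f32(math.exp(f32(v - row_max))) for v in row]
    sum_exp = block_reduce(shifted_exp, lambda a, b: f32(a + b), 0.0)  # fixed-order sum-exp
    return f32(f32(row[target] - row_max) - f32(math.log(sum_exp)))


def batch_logprobs(batch_logits, batch_targets):
    """One independent 'block' per (sequence, position) row; no state shared across rows."""
    return [[row_logprob(lg, tg) for lg, tg in zip(seq_lg, seq_tg)]
            for seq_lg, seq_tg in zip(batch_logits, batch_targets)]
```

Because each row is reduced with a topology fixed by the vocabulary size and block size alone, neither the batch size nor neighbouring rows can change the order of floating-point operations, and no atomics are needed.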

Acceptance Criteria

  • Add a deterministic CUDA logprob kernel.
  • Add batch-invariance tests:
    • place the same sequence in batches of different sizes,
    • move it to different batch positions,
    • mix it with unrelated sequences,
    • verify per-token and sequence logprobs are bitwise identical.
  • Add repeatability tests showing identical output across repeated runs.
  • Verify the kernel does not use atomicAdd.
  • Compare numerics against the existing logprob path within an acceptable tolerance, while prioritizing deterministic stability.
  • Leave the test structure ready for future ROCm and Triton parity checks.
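One possible shape for the batch-invariance and repeatability tests is sketched below in plain Python. It compares raw float32 bytes rather than using `==`, places a probe sequence (with a padded, ignored tail) at every position of batches of several sizes alongside unrelated random sequences, and checks both per-token and sequence logprobs. `reference_logprobs` is only a placeholder standing in for the real kernel binding. `BACKENDS` is a hypothetical registry where ROCm or Triton backends could be added later. The ignore index `-100`, the fixed left-to-right sequence sum, and all function names are assumptions.

```python
import math
import random
import struct

IGNORE_INDEX = -100  # assumed ignore-index value


def f32(x):
    """Round a double to the nearest float32."""
    return struct.unpack("f", struct.pack("f", x))[0]


def bits(values):
    """Raw float32 bytes, so comparisons are bitwise."""
    return struct.pack(f"{len(values)}f", *values)


def reference_logprobs(batch_logits, batch_targets):
    """Placeholder backend: sequential FP32 log-softmax per row."""
    out = []
    for seq_lg, seq_tg in zip(batch_logits, batch_targets):
        seq_out = []
        for logits, tgt in zip(seq_lg, seq_tg):
            if tgt == IGNORE_INDEX:
                seq_out.append(0.0)
                continue
            m = max(logits)
            s = 0.0
            for v in logits:
                s = f32(s + f32(math.exp(f32(v - m))))
            seq_out.append(f32(f32(logits[tgt] - m) - f32(math.log(s))))
        out.append(seq_out)
    return out


BACKENDS = {"reference": reference_logprobs}  # hypothetical: add "rocm", "triton" later


def seq_logprob(token_lps, targets):
    """Sequence logprob as a fixed left-to-right FP32 sum over non-ignored tokens."""
    total = 0.0
    for lp, t in zip(token_lps, targets):
        if t != IGNORE_INDEX:
            total = f32(total + lp)
    return total


def random_seq(rng, length, vocab):
    logits = [[f32(rng.uniform(-8, 8)) for _ in range(vocab)] for _ in range(length)]
    return logits, [rng.randrange(vocab) for _ in range(length)]


def check_batch_invariance(fn, seed=0, vocab=32, length=6):
    rng = random.Random(seed)
    probe_lg, probe_tg = random_seq(rng, length, vocab)
    probe_tg[-2:] = [IGNORE_INDEX, IGNORE_INDEX]  # padded tail
    alone = fn([probe_lg], [probe_tg])[0]
    for batch_size in (1, 2, 5, 8):
        for pos in range(batch_size):
            # Unrelated sequences of varying length, i.e. different padding layouts.
            others = [random_seq(rng, rng.randint(1, 2 * length), vocab)
                      for _ in range(batch_size - 1)]
            lgs = [o[0] for o in others]
            tgs = [o[1] for o in others]
            lgs.insert(pos, probe_lg)
            tgs.insert(pos, probe_tg)
            got = fn(lgs, tgs)[pos]
            assert bits(got) == bits(alone), (batch_size, pos)
            assert bits([seq_logprob(got, probe_tg)]) == bits([seq_logprob(alone, probe_tg)])
    # Repeatability: an identical call must match bit for bit.
    assert bits(fn([probe_lg], [probe_tg])[0]) == bits(alone)
    return True


def check_no_atomics(kernel_source):
    """Lightweight source scan; it cannot see atomics pulled in from headers or libraries."""
    return "atomicAdd" not in kernel_source


for name, fn in BACKENDS.items():
    print(name, check_batch_invariance(fn))
```

Comparing bytes rather than floats avoids treating `-0.0` and `0.0` as equal, which a bitwise-identity requirement should not allow. The numerical comparison against the existing logprob path would be a separate tolerance-based check.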

Metadata

Labels

  • component: kernels (tasks involving the development of underlying CUDA and Triton operators)
  • platform: cuda (optimizations or bugs specific to NVIDIA GPUs, such as FlashInfer or TMA optimizations)
  • priority: high (severe congestion issues that require the highest priority for resolution)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
