[WS1][kernels] Batch-invariant embedding + LM head projection #151

Description

@Flink-ddd

Part of WS1 — Full Batch-Invariant Forward Chain (epic: #)

Why

The input embedding lookup and the final vocab projection bracket the network. The LM head is a large matmul directly upstream of the logprob computation, so any batch-dependent reduction there shows up in the logprobs immediately. The embedding lookup is simpler, but we still need to confirm that it does not branch on batch shape. Both must be on the batch-invariant path, or the chain is not actually closed.
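To make the concern concrete, here is a minimal pure-Python sketch (not the project's kernel code) of how the accumulation order over K changes the result of a dot product. `dot_sequential` and `dot_split_k` are hypothetical names used only for illustration; a real GEMM that picks its split-K factor based on batch shape would show the same effect at the logit level.

```python
def dot_sequential(x, w):
    # Fixed accumulation order: k = 0, 1, ..., K-1.
    acc = 0.0
    for k in range(len(x)):
        acc += x[k] * w[k]
    return acc

def dot_split_k(x, w, split):
    # Two partial sums combined at the end, as a split-K kernel would do.
    left = dot_sequential(x[:split], w[:split])
    right = dot_sequential(x[split:], w[split:])
    return left + right

x = [0.1, 0.2, 0.3]
w = [1.0, 1.0, 1.0]

fixed = dot_sequential(x, w)   # (0.1 + 0.2) + 0.3
split = dot_split_k(x, w, 1)   # 0.1 + (0.2 + 0.3)
print(fixed == split)          # prints False
```

Floating-point addition is not associative, so the same inputs give different bits once the reduction tree changes. That is why the LM head needs one fixed K order regardless of batch size.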

Scope

Confirm the input embedding and the LM-head projection run batch-invariantly.

  • Confirm the embedding lookup (gather) produces identical vectors for a token regardless of batch size, position, or padding (no shape-dependent path).
  • Route the LM-head vocab projection through the batch-invariant GEMM (the matmul issue), with a fixed K-accumulation order over the hidden dimension.
  • Confirm tied-embedding weight sharing (if used by the model) does not change the reduction path between the two.
  • Validate both against the ground-truth harness from #108 ([WS1] Ground-truth harness + numerical contract for batch-invariant ops) across the standard sweep, with particular attention to the LM-head -> logprob handoff.
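The checks in the list above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: `embed`, `project_row`, `lm_head`, and `bits` are hypothetical helpers, the weights are tied for the sake of the example, and the real tests would run against the actual kernels through the #108 harness.

```python
import random

def embed(table, token_ids):
    # Pure gather: copy one row per token id; no arithmetic, no shape-dependent branch.
    return [list(table[t]) for t in token_ids]

def project_row(hidden, weight):
    # One logit per vocab row, accumulating over the hidden dim in fixed order k = 0..K-1.
    logits = []
    for w_row in weight:
        acc = 0.0
        for k in range(len(hidden)):
            acc += hidden[k] * w_row[k]
        logits.append(acc)
    return logits

def lm_head(batch_hidden, weight):
    # Each row is reduced independently, so batch size cannot change the order.
    return [project_row(h, weight) for h in batch_hidden]

def bits(vec):
    # Exact bit-level fingerprint of a float vector.
    return [x.hex() for x in vec]

rng = random.Random(0)
VOCAB, HIDDEN, PAD = 16, 8, 0
table = [[rng.uniform(-1, 1) for _ in range(HIDDEN)] for _ in range(VOCAB)]
weight = table  # assumption: tied embeddings, the head reuses the embedding table

token = 5
solo = lm_head(embed(table, [token]), weight)[0]

# Same token at position 2 of a padded batch of 8.
batched_ids = [PAD, 3, token, 9, PAD, PAD, 11, 2]
batched = lm_head(embed(table, batched_ids), weight)[2]

invariant = bits(solo) == bits(batched)
print(invariant)
```

Comparing `float.hex()` strings rather than using a tolerance keeps the check bitwise, which is the property the batch-invariant chain relies on.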

Out of scope

  • The GEMM kernel itself (covered by the matmul issue; this issue consumes it).
  • The logprob reduction (covered by the logprob issue; this issue feeds it).
  • Vocab-parallel distributed LM head; weight-tying policy changes.
  • FP8.

Acceptance criteria

Notes

Planned PRs

  • Embedding-lookup invariance tests (fixed token/position -> identical vector)
  • Route the LM-head vocab projection through the deterministic GEMM
  • Tied-weight sharing consistency check (if the model uses tied embeddings)
  • LM-head -> logprob handoff test
  • Embedding / LM-head backward invariance via the shared gradient check; wire through the ground-truth harness from #108 ([WS1] Ground-truth harness + numerical contract for batch-invariant ops)
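As a rough idea of what the LM-head -> logprob handoff test could assert, the sketch below feeds the same logits row through a stand-in log-softmax alone and inside a larger batch, then compares the target logprob bit for bit. `log_softmax` and `token_logprob` here are hypothetical placeholders, not the logprob reduction owned by the logprob issue.

```python
import math

def log_softmax(logits):
    # Stand-in reduction with a fixed summation order over the vocab.
    m = max(logits)
    total = 0.0
    for x in logits:
        total += math.exp(x - m)
    lse = m + math.log(total)
    return [x - lse for x in logits]

def token_logprob(logits, target):
    # Logprob of a single target token from one row of logits.
    return log_softmax(logits)[target]

row = [0.5, -1.25, 2.0, 0.0]      # logits for the sequence under test
others = [[1.0, 1.0, 1.0, 1.0],   # unrelated rows sharing the batch
          [-3.0, 0.25, 0.0, 4.0]]
target = 2

solo = token_logprob(row, target)
batch = [others[0], row, others[1]]
batched = [token_logprob(r, target) for r in batch][1]

handoff_ok = solo.hex() == batched.hex()
print(handoff_ok)
```

If the logits arriving from the LM head are bitwise identical across batch shapes and the reduction is row-local with a fixed order, the logprob comes out identical as well; a mismatch points to whichever side of the handoff broke invariance.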

Metadata

Labels

component: kernels, feature, platform: cuda, priority: high, sprint-0615
