Assemble decoder-only model: forward + next-token CE loss (MULTI-1383) #6

Merged
RobbieMcKinstry merged 1 commit into trunk from robbie/multi-1383
Jun 29, 2026
@RobbieMcKinstry

Contributor

Summary

  • Assemble the decoder-only model as a Burn Module: token embedding → sinusoidal PositionalEncoding → causal-masked TransformerEncoder (honoring norm_first) → final LayerNorm → untied LM head. Uses the encoder (not TransformerDecoder) because decoder-only GPTs have no cross-attention or encoder memory; the encoder + a causal mask is the correct primitive.
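  • Add next_token_cross_entropy(logits, tokens): the standard left shift (the logits at position t are scored against the token at position t+1), with the batch and sequence dimensions flattened before calling Burn's CrossEntropyLoss.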
  • Tests cover the four MULTI-1383 acceptance criteria.
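
To illustrate why an encoder stack plus a causal mask behaves like a decoder-only model, here is a minimal std-only Rust sketch of the masking step. It is not the PR's code and does not use Burn's tensor API; `causal_mask` and `masked_softmax` are hypothetical helper names, and the convention that `true` means "blocked" is an assumption made for this sketch.

```rust
/// Build a `seq_len x seq_len` causal mask. `mask[i][j]` is `true` when
/// query position `i` must NOT attend to key position `j` (that is, `j > i`).
fn causal_mask(seq_len: usize) -> Vec<Vec<bool>> {
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| j > i).collect())
        .collect()
}

/// Apply the mask to one attention score matrix, then softmax each row.
/// Blocked entries are set to negative infinity, so they get zero weight.
fn masked_softmax(scores: &[Vec<f32>], mask: &[Vec<bool>]) -> Vec<Vec<f32>> {
    scores
        .iter()
        .zip(mask)
        .map(|(row, mask_row)| {
            let masked: Vec<f32> = row
                .iter()
                .zip(mask_row)
                .map(|(&s, &blocked)| if blocked { f32::NEG_INFINITY } else { s })
                .collect();
            // The diagonal is never blocked, so `max` is always finite.
            let max = masked.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = masked.iter().map(|&s| (s - max).exp()).collect();
            let sum: f32 = exps.iter().sum();
            exps.iter().map(|e| e / sum).collect()
        })
        .collect()
}
```

With this mask, position 0 attends only to itself and position t attends to positions 0 through t, so no token sees its own future. That is the only property a decoder-only model needs from its attention layers.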

Test plan

  • cargo fmt --all --check
  • cargo clippy --all-targets --workspace --locked -- -D warnings
  • cargo nextest run --workspace --locked — 59 tests pass, including the four new ones:
    • gpt2_small_lands_in_the_100m_param_class — actual num_params() falls in 80M–140M with the locked 16k vocab.
    • forward_returns_batch_seq_vocab_with_no_nan_or_inf — output shape [B, T, V], all finite.
    • loss_at_init_is_near_ln_vocab — CE within ±2 nats of ln(vocab).
    • all_parameters_receive_non_zero_gradients — backward on TrainBackend produces a non-zero abs-sum gradient for every visited float param.
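
The ln(vocab) target in the loss test follows from a short calculation: a model that assigns probability 1/V to every token has cross-entropy -ln(1/V) = ln(V) nats. The sketch below spells out that check in plain Rust. The function names are hypothetical, and 16,384 is only an assumed reading of "16k"; the exact vocabulary size is not stated here.

```rust
/// Cross-entropy, in nats, of a predictor that is uniform over `vocab` classes.
fn uniform_cross_entropy(vocab: usize) -> f64 {
    (vocab as f64).ln()
}

/// Acceptance-style check: the loss must be finite and lie within `tol`
/// nats of ln(vocab).
fn loss_is_near_ln_vocab(loss: f64, vocab: usize, tol: f64) -> bool {
    loss.is_finite() && (loss - uniform_cross_entropy(vocab)).abs() <= tol
}
```

For an assumed vocabulary of 16,384 tokens the target is about 9.70 nats, so a ±2 nat window accepts an initial loss of 10.5 but rejects a loss near zero or a NaN.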

Closes MULTI-1383.

🤖 Generated with Claude Code

…-1383)

Build a Burn `Module` that turns raw tokens into a training loss: token
embedding → sinusoidal positional encoding → causal-masked
`TransformerEncoder` → final `LayerNorm` → untied LM head. Use the
encoder (not `TransformerDecoder`) because a decoder-only GPT has no
cross-attention or encoder memory; the encoder + a causal mask is the
correct primitive. The `norm_first` toggle threads through from
`ModelConfig`.

`next_token_cross_entropy` does the standard left-shift and flattens to
`CrossEntropyLoss`, so the last logit slot has no target and the first
token has no prediction.
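
The shift-and-flatten step described above can be sketched without any tensor library. This is an illustration on nested `Vec`s, not the PR's implementation; `shift_and_flatten` and `mean_cross_entropy` are hypothetical names, and the real function operates on Burn tensors.

```rust
/// Pair each position's logits with the token that follows it.
/// `logits[b][t]` is a length-`vocab` row and `tokens[b][t]` is a token id.
/// Returns flattened (logit row, target id) pairs, `T - 1` per sequence.
fn shift_and_flatten<'a>(
    logits: &'a [Vec<Vec<f32>>],
    tokens: &[Vec<usize>],
) -> Vec<(&'a [f32], usize)> {
    let mut pairs = Vec::new();
    for (seq_logits, seq_tokens) in logits.iter().zip(tokens) {
        let t = seq_tokens.len();
        // Drop the last logit row (no target) and the first token (no prediction).
        for pos in 0..t.saturating_sub(1) {
            pairs.push((seq_logits[pos].as_slice(), seq_tokens[pos + 1]));
        }
    }
    pairs
}

/// Mean cross-entropy over the flattened pairs, using a stable log-softmax.
/// Returns NaN when `pairs` is empty.
fn mean_cross_entropy(pairs: &[(&[f32], usize)]) -> f32 {
    let total: f32 = pairs
        .iter()
        .map(|(row, target)| {
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let log_sum_exp = max + row.iter().map(|&x| (x - max).exp()).sum::<f32>().ln();
            log_sum_exp - row[*target]
        })
        .sum();
    total / pairs.len() as f32
}
```

A batch of B sequences of length T therefore contributes B × (T − 1) scored positions, which is why the flattened target vector is shorter than the raw token tensor.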

Tests cover the four MULTI-1383 acceptance criteria:

  * `gpt2_small` instantiates in the ~100M-parameter class (80M–140M),
    catching slipped width / depth / vocab against the locked tokenizer
    from MULTI-1379.
  * Forward output is `[batch × seq × vocab]` with no NaN/Inf at init.
  * Loss at init lands within ±2 nats of `ln(vocab)` — the uniform
    predictor's entropy — catching missing log-softmax or wrong vocab
    dim.
  * Every float parameter receives a non-zero gradient after backward
    on `TrainBackend`, proving the autograd graph reaches the whole
    module.
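
As a rough cross-check on the 80M to 140M window, the parameter count can be estimated by hand. The sketch below assumes the commonly cited GPT-2 small dimensions (width 768, 12 layers, feed-forward width 3072), a 16,384-token vocabulary, biased attention and MLP projections, and a bias-free untied LM head. None of these values are confirmed by this PR, and `Dims` and `estimate_params` are hypothetical names.

```rust
/// Hypothetical hyperparameters; the PR's `ModelConfig` fields are not shown.
struct Dims {
    vocab: u64,
    d_model: u64,
    d_ff: u64,
    n_layers: u64,
}

/// Weights of a dense layer, plus an optional bias vector.
fn linear(inp: u64, out: u64, bias: bool) -> u64 {
    inp * out + if bias { out } else { 0 }
}

/// LayerNorm carries one scale vector and one shift vector.
fn layer_norm(d: u64) -> u64 {
    2 * d
}

/// Estimated trainable parameters. Sinusoidal positions add none.
fn estimate_params(c: &Dims) -> u64 {
    // Query, key, value, and output projections.
    let attention = 4 * linear(c.d_model, c.d_model, true);
    let mlp = linear(c.d_model, c.d_ff, true) + linear(c.d_ff, c.d_model, true);
    let block = attention + mlp + 2 * layer_norm(c.d_model);
    let embedding = c.vocab * c.d_model;
    // Untied head; leaving out the bias is an assumption of this sketch.
    let lm_head = linear(c.d_model, c.vocab, false);
    embedding + c.n_layers * block + layer_norm(c.d_model) + lm_head
}
```

Under these assumptions the estimate is 110,221,824 parameters, which sits inside the window. Doubling the width or swapping in a much larger vocabulary would push it outside, which is the kind of slip the test is meant to catch.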

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@RobbieMcKinstry RobbieMcKinstry added this pull request to the merge queue Jun 29, 2026
Merged via the queue into trunk with commit 6885f6a Jun 29, 2026
3 checks passed
@RobbieMcKinstry RobbieMcKinstry deleted the robbie/multi-1383 branch June 29, 2026 19:44