Widen RunBatch/ReconstructionLoss for structured outputs #525

Open: ocg-goodfire wants to merge 1 commit into main from refactor/widen-runbatch-reconloss

@ocg-goodfire

Collaborator

Description

Two small protocol changes in param_decomp/batch_and_loss_fns.py:

  • RunBatch: (model, batch) -> Tensor → (model, batch) -> Any
  • ReconstructionLoss: (pred: Tensor, target: Tensor) -> (sum, n) → (output: Any, target_output: Any) -> (sum, n)

Plus the matching type widenings to keep callers honest: OutputWithCache.output, MetricContext.target_out, the target_out: Tensor annotation in 9 metric modules, and a kwarg-rename at one PPGD callsite. Lab adapters (recon_loss_mse, recon_loss_kl) and two tests get the same arg-rename.
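As a rough illustration of the two widened signatures, here is a minimal sketch written as callback Protocols. The real definitions live in param_decomp/batch_and_loss_fns.py and are not reproduced in this description, so the parameter types, the `tuple[float, int]` return annotation, and the two toy implementations (`run_batch_identity`, `recon_loss_sq`) are assumptions made for the example.

```python
from typing import Any, Protocol


class RunBatch(Protocol):
    # Widened: the return value is no longer required to be a Tensor.
    def __call__(self, model: Any, batch: Any) -> Any: ...


class ReconstructionLoss(Protocol):
    # Returns (sum, n): the summed loss and the number of contributing elements.
    def __call__(self, output: Any, target_output: Any) -> tuple[float, int]: ...


# Hypothetical toy implementations that satisfy the protocols above.
def run_batch_identity(model: Any, batch: Any) -> Any:
    return model(batch)


def recon_loss_sq(output: Any, target_output: Any) -> tuple[float, int]:
    diffs = [(o - t) ** 2 for o, t in zip(output, target_output)]
    return sum(diffs), len(diffs)


run: RunBatch = run_batch_identity
loss: ReconstructionLoss = recon_loss_sq

out = run(lambda b: [x * 2 for x in b], [1.0, 2.0])
result = loss(output=out, target_output=[2.0, 3.0])
```

Because both parameters are `Any`, the loss function is free to receive a plain tensor, a dataclass, or any other structure that `RunBatch` chooses to return.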

The (sum, n) return shape is kept: it allows exact micro-averaging under eval-time DDP all_reduce when per-batch weights vary.
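A small worked example of why the pair matters, assuming each rank or batch reports its own `(sum, n)`. The helper names are hypothetical, and a plain Python `sum` stands in for the all_reduce step. Averaging the per-batch means weights every batch equally, whereas summing the numerators and the counts separately weights every element equally.

```python
def micro_average(partials: list[tuple[float, int]]) -> float:
    # Reduce sums and counts separately, then divide once.
    total = sum(s for s, _ in partials)
    count = sum(n for _, n in partials)
    return total / count


def mean_of_means(partials: list[tuple[float, int]]) -> float:
    # Divide per batch first, then average: biased when n varies.
    return sum(s / n for s, n in partials) / len(partials)


# Two batches with different numbers of unmasked positions.
partials = [(8.0, 4), (1.0, 1)]

exact = micro_average(partials)   # (8 + 1) / (4 + 1)
biased = mean_of_means(partials)  # (8/4 + 1/1) / 2
```

Here the single-element batch pulls the mean of means down to 1.5, while the micro-average over all five elements is 1.8.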

Also adds a NOTE in make_components flagging that storage-tied weights (tie_word_embeddings=True on Llama / ESM / GPT-2 / BERT) are not detected today. The fix is deferred, since we don't currently decompose embeddings.

Motivation and Context

Surfaced while stress-testing the abstractions against three bio models (ESM2, Carbon-500M, GPN-MSA). Real bio data needs per-batch context (padding masks for variable-length protein sequences, MLM-masked positions, MSA aux features) routed into the recon loss. Today the only way to do that is to smuggle it through pred's tensor shape, which breaks the pred.shape == target.shape invariant the existing impls assume.

After this change, an experiment can ship its own Output dataclass from RunBatch carrying whatever the recon needs:

from dataclasses import dataclass

from torch import Tensor

@dataclass
class ESM2Output:
    logits: Tensor
    attention_mask: Tensor

def run_batch_esm2(model, batch) -> ESM2Output:
    out = model(batch.input_ids, attention_mask=batch.attention_mask)
    return ESM2Output(logits=out.logits, attention_mask=batch.attention_mask)

def recon_loss_masked_kl(output: ESM2Output, target_output: ESM2Output):
    # use output.attention_mask to zero out padding contributions
    ...

Nothing in-tree exercises the new shape yet; this PR is purely the protocol widening.
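To make the masked-loss idea concrete without depending on torch, here is a dependency-free sketch that uses plain lists in place of tensors. `ToyOutput` and `recon_loss_masked_kl_toy` are hypothetical stand-ins, and the KL direction (target relative to output) is an assumption; the PR does not specify how a real masked KL would be written.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyOutput:
    probs: list[list[float]]   # one probability distribution per position
    attention_mask: list[int]  # 1 = real token, 0 = padding


def recon_loss_masked_kl_toy(
    output: ToyOutput, target_output: ToyOutput
) -> tuple[float, int]:
    total = 0.0
    n = 0
    rows = zip(target_output.probs, output.probs, output.attention_mask)
    for p_row, q_row, keep in rows:
        if not keep:
            continue  # padding positions contribute to neither sum nor n
        # KL(p || q) for this position; terms with p == 0 contribute zero.
        total += sum(p * math.log(p / q) for p, q in zip(p_row, q_row) if p > 0)
        n += 1
    return total, n


target = ToyOutput(probs=[[0.5, 0.5], [0.5, 0.5]], attention_mask=[1, 0])
pred = ToyOutput(probs=[[0.5, 0.5], [0.9, 0.1]], attention_mask=[1, 0])

masked_sum, masked_n = recon_loss_masked_kl_toy(pred, target)
```

The second position disagrees with the target but is padding, so it is skipped and only one position is counted.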

How Has This Been Tested?

  • make check clean (0 errors, 0 warnings)
  • make test: 411 passed, 5 skipped
  • Bio scaffolds on feature/bio-demo-model continue to work on top of this change (they use the simple RunBatch -> Tensor path, which is still a valid Any)

Does this PR introduce a breaking change?

Yes — but only for out-of-tree callers that implement RunBatch / ReconstructionLoss against the old types:

  • RunBatch impls typed as returning Tensor are still valid (Tensor is assignable to Any).
  • ReconstructionLoss impls using positional args (def f(pred, target)) need to rename to (output, target_output) to satisfy the Protocol name-match in pyright. Functionally identical otherwise.
  • Lab adapters (recon_loss_mse, recon_loss_kl) called by kwarg need pred= → output= and target= → target_output=.
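The keyword rename in the last bullet can be seen at runtime with a toy adapter. `recon_loss_mse_toy` below is a hypothetical list-based stand-in, not the real recon_loss_mse; it only illustrates that a call site still passing the old keyword names fails once the parameters are renamed.

```python
def recon_loss_mse_toy(
    output: list[float], target_output: list[float]
) -> tuple[float, int]:
    # Summed squared error plus element count, mirroring the (sum, n) shape.
    sq = [(o - t) ** 2 for o, t in zip(output, target_output)]
    return sum(sq), len(sq)


# New-style call site: keyword names match the renamed parameters.
new_call = recon_loss_mse_toy(output=[1.0, 2.0], target_output=[1.0, 4.0])

# Old-style call site: the previous keyword names are no longer accepted.
try:
    recon_loss_mse_toy(pred=[1.0], target=[1.0])  # intentionally wrong
    old_call_error = None
except TypeError as exc:
    old_call_error = str(exc)
```

Purely positional call sites such as `recon_loss_mse_toy(a, b)` are unaffected at runtime; for those, the rename matters only for the type checker's Protocol name-match.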

🤖 Generated with Claude Code

@ocg-goodfire ocg-goodfire force-pushed the refactor/widen-runbatch-reconloss branch from 0599cca to 7e3dff9 Compare May 27, 2026 12:55
Lets experiments package per-batch context (padding masks, labels, MSA
aux features) into output dataclasses instead of smuggling them through
tensor shapes. Surfaced while stress-testing the abstractions against
ESM2, Carbon, and GPN-MSA bio models.

- RunBatch: (model, batch) -> Tensor → -> Any
- ReconstructionLoss args: (pred, target) → (output, target_output); types Any
- OutputWithCache.output, MetricContext.target_out: Tensor → Any
- (sum, n) return shape kept — earns its keep for variable-mask eval
- Notes the tied-embedding gap in make_components (deferred)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@ocg-goodfire ocg-goodfire force-pushed the refactor/widen-runbatch-reconloss branch from 7e3dff9 to 1da9e13 Compare May 27, 2026 13:07