Widen RunBatch/ReconstructionLoss for structured outputs #525
Open
ocg-goodfire wants to merge 1 commit into
0599cca to 7e3dff9
Lets experiments package per-batch context (padding masks, labels, MSA aux features) into output dataclasses instead of smuggling them through tensor shapes. Surfaced while stress-testing the abstractions against ESM2, Carbon, and GPN-MSA bio models.

- RunBatch: (model, batch) -> Tensor → (model, batch) -> Any
- ReconstructionLoss args: (pred, target) → (output, target_output); types Any
- OutputWithCache.output, MetricContext.target_out: Tensor → Any
- (sum, n) return shape kept; it earns its keep for variable-mask eval
- Notes the tied-embedding gap in make_components (deferred)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
7e3dff9 to 1da9e13
Description
Two small protocol changes in `param_decomp/batch_and_loss_fns.py`:

- `RunBatch`: `(model, batch) -> Tensor` → `(model, batch) -> Any`
- `ReconstructionLoss`: `(pred: Tensor, target: Tensor) -> (sum, n)` → `(output: Any, target_output: Any) -> (sum, n)`

Plus the matching type widenings to keep callers honest: `OutputWithCache.output`, `MetricContext.target_out`, the `target_out: Tensor` annotation in 9 metric modules, and a kwarg rename at one PPGD callsite. Lab adapters (`recon_loss_mse`, `recon_loss_kl`) and two tests get the same arg rename.

The `(sum, n)` return shape is kept: it earns its keep for exact micro-averaging in eval-time DDP all_reduce when batch weights vary.

Also adds a `NOTE` in `make_components` flagging that storage-tied weights (`tie_word_embeddings=True` on Llama / ESM / GPT-2 / BERT) are not detected today. Deferred, since we don't decompose embeddings currently.

Motivation and Context
Surfaced while stress-testing the abstractions against three bio models (ESM2, Carbon-500M, GPN-MSA). Real bio data needs per-batch context (padding masks for variable-length protein sequences, MLM-masked positions, MSA aux features) routed into the recon loss. Today the only way to do that is to smuggle it through `pred`'s tensor shape, which breaks the `pred.shape == target.shape` invariant the existing impls assume.

After this change, an experiment can ship its own `Output` dataclass from `RunBatch` carrying whatever the recon needs.

Nothing in-tree exercises the new shape yet; this PR is purely the protocol widening.
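A minimal sketch of what such an experiment-side output could look like, assuming the widened shapes described above. `MaskedOutput`, `run_batch`, `masked_mse`, and the two toy models are hypothetical names, not code from this PR, and plain Python lists stand in for tensors so the sketch runs without torch. It also shows why returning `(sum, n)` matters when the number of unmasked positions varies per batch.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class MaskedOutput:
    """Hypothetical experiment-defined output; not part of the PR."""

    values: list[float]  # stand-in for a prediction tensor
    mask: list[bool]     # True where a position counts toward the loss


def run_batch(model: Callable[[float], float], batch: dict[str, Any]) -> MaskedOutput:
    """RunBatch-shaped callable: (model, batch) -> Any (here, a dataclass)."""
    return MaskedOutput(
        values=[model(x) for x in batch["inputs"]],
        mask=list(batch["mask"]),
    )


def masked_mse(output: MaskedOutput, target_output: MaskedOutput) -> tuple[float, int]:
    """ReconstructionLoss-shaped callable: returns (sum, n) over unmasked positions."""
    total, n = 0.0, 0
    for pred, tgt, keep in zip(output.values, target_output.values, target_output.mask):
        if keep:
            total += (pred - tgt) ** 2
            n += 1
    return total, n


def target_model(x: float) -> float:
    return 2.0 * x


def component_model(x: float) -> float:
    return 3.0 * x  # differs from the target by exactly x at every position


batches = [
    {"inputs": [1.0, 2.0, 3.0], "mask": [True, True, True]},
    {"inputs": [4.0, 0.0, 0.0], "mask": [True, False, False]},  # mostly padding
]

per_batch = [
    masked_mse(run_batch(component_model, b), run_batch(target_model, b))
    for b in batches
]

# Micro-average: pool sums and counts first, then divide once.
micro_avg = sum(s for s, _ in per_batch) / sum(n for _, n in per_batch)

# Averaging per-batch means instead over-weights the mostly-padded batch.
mean_of_means = sum(s / n for s, n in per_batch) / len(per_batch)
```

Here the pooled result is 30.0 / 4 = 7.5, while the mean of per-batch means is roughly 10.33. In a distributed eval, the same pooling would presumably be done by reducing the sum and the count separately across ranks before dividing; that step is not shown here.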
How Has This Been Tested?
- `make check` clean (0 errors, 0 warnings)
- `make test`: 411 passed, 5 skipped
- `feature/bio-demo-model` confirmed they continue to work on top of this (they use the simple `RunBatch -> Tensor` path which is still a valid `Any`)

Does this PR introduce a breaking change?
Yes, but only for out-of-tree callers that implement `RunBatch` / `ReconstructionLoss` against the old types:

- `RunBatch` impls typed as returning `Tensor` are still valid (Tensor is Any).
- `ReconstructionLoss` impls using positional args (`def f(pred, target)`) need to rename to `(output, target_output)` to satisfy the Protocol name-match in pyright. Functionally identical otherwise.
- Lab adapters (`recon_loss_mse`, `recon_loss_kl`) called by kwarg need `pred=` → `output=` and `target=` → `target_output=`.

🤖 Generated with Claude Code
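To illustrate the rename described under the breaking-change heading, here is a small sketch of how keyword call sites behave before and after. `recon_loss_old` and `recon_loss_new` are hypothetical stand-ins with a toy squared-error body over lists; they are not the real `recon_loss_mse` or `recon_loss_kl` implementations.

```python
from __future__ import annotations

from typing import Any, Callable


def recon_loss_old(pred: Any, target: Any) -> tuple[float, int]:
    """Toy loss using the pre-PR parameter names."""
    diffs = [(p - t) ** 2 for p, t in zip(pred, target)]
    return sum(diffs), len(diffs)


def recon_loss_new(output: Any, target_output: Any) -> tuple[float, int]:
    """Same body; only the parameter names changed."""
    diffs = [(o - t) ** 2 for o, t in zip(output, target_output)]
    return sum(diffs), len(diffs)


def accepts_old_kwargs(fn: Callable[..., Any]) -> bool:
    """Return True if fn can still be called with pred=/target= keywords."""
    try:
        fn(pred=[1.0, 2.0], target=[1.0, 4.0])
    except TypeError:
        return False
    return True


# Positional calls are unaffected by the rename at runtime.
positional_result = recon_loss_new([1.0, 2.0], [1.0, 4.0])

# Keyword calls must use the new names.
keyword_result = recon_loss_new(output=[1.0, 2.0], target_output=[1.0, 4.0])

old_kwargs_on_old = accepts_old_kwargs(recon_loss_old)
old_kwargs_on_new = accepts_old_kwargs(recon_loss_new)
```

Positional callers keep working at runtime, so for them the rename only matters to the type checker. Keyword callers using `pred=` / `target=` would hit a `TypeError` until updated.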