refactor(3-pool): SUM-grad convention (supersedes #545; fixes 2nd stoch→V/U instance)#546
Draft
ocg-goodfire wants to merge 1 commit into
Draft
Conversation
…ssembly Replace the per-instance "pre-scale to survive a downstream reduce" patches (4 recurring bugs, latest PR #545's PPGD x n_ci) with a single convention: every data-parallel gradient reduction is SUM, and each producer emits a partial sum normalized only by the honest global count, carrying NO pool-size transport factor. SUM(partials) = total, so no producer needs any pool's size. Changes: - portals: all_reduce_ci_fn_grads and all_reduce_grads_in_block flip AVG -> SUM. - step_ppgd: V/U and CI collapse to one scale; the #545 x n_ci is deleted. - step_layerwise: stoch denom drops /n_ci (one scale now serves both CI leaves and V/U). faith + broadcast-PPGD V/U are "contribute once" (block leader only) so the SUM lands them exactly once instead of n_per_block x. - step_ci: imp-min uses the detached-global-residual trick instead of the autograd-aware all_reduce, so its backward is a local partial (no reliance on the old AVG cancelling an n_ci factor). - grad-clip n_replicas unchanged (counts distinct params for the global norm, independent of the reduce op). The replica count does not vanish for replicated contributions — it relocates from a numeric scale factor into a structural placement (contribute-once) or a graph decision (detach the global term). See SUM_GRAD_CONVENTION.md for the honest verdict on whether this is simpler than per-destination scale-splitting. Validated by a new distributed grad check at a non-square topology (n_ci=4 != n_per_block=2, 2 blocks, n_ppgd=2) with ALL loss terms enabled: fully-reduced CI-fn and V/U grads match a single-process full-batch reference (mean dist/ref = 1.000000, worst rel err 4e-07). Sensitivity confirmed: AVG on the CI reduce -> 1/n_ci on CI grads; faith on all ranks -> n_per_block x on V/U. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
A structural fix for the recurring 3-pool gradient-scaling bug class: replace the per-instance "pre-scale to survive a downstream reduction" pattern with a single convention — every gradient crossing a cross-rank reduction is a partial SUM, normalized only by the honest global count, carrying no pool-size transport factor; all data-parallel grad reductions are SUM.
Full rationale + the honest trade-off analysis in
param_decomp_lab/three_pool/SUM_GRAD_CONVENTION.md.Why
Four recurring bugs, all the same shape: a producer pre-scales its gradient by a pool-size factor (
n_ci,n_per_block) to survive an AVG-reduce it can't locally see; the factor is invisible atn_ci=1/square and wrong otherwise. Two are confirmed live by a real grad check:×n_ci, fixed manually in fix(3-pool): scale PPGD's CI grad by n_ci to survive the CI-pool AVG #545):1/n_citoo weak.n_ci/n_per_blockoff. Measured0.500at n_ci=2/n_per_block=4.On the live run
p-97bab993(n_ci=32, n_per_block=2) these are 16× (V/U) and 32× (CI) off — not subtle.How
all_reduce_ci_fn_gradsandall_reduce_grads_in_block: AVG → SUM.×n_ciand stoch's/n_ciare deleted (the factors evaporate for the data-parallel producers).Relationship to #545
This supersedes #545 by folding its fix into the convention (the manual
×n_ciis removed because the AVG it compensated for is now a SUM). Based onfix/ppgd-ci-grad-nci-scalingso the diff here is only the structural change. @danbraunai-goodfire — keen for your read on whether to merge #545 then this, or collapse them.Validation
test_three_pool_grad_check_distributed.py(10-rank gloo, non-square n_ci=4/n_per_block=2/n_ppgd=2, all four loss terms, RNG/masks pinned): fully-reduced CI-fn and V/U grads vs single-process reference — mean(dist/ref)=1.000000, worst rel err 4e-7. Sensitivity-checked: AVG→CI gives 1/n_ci, faith-on-all-ranks gives n_per_block× (the check bites).make checkclean; 42 three-pool/distributed tests pass.The fix changes effective gradient magnitudes (it corrects them), so training dynamics / the stoch:imp:ppgd balance change and LRs may need retuning. Pending a behavioural smoke at a real (non-square) topology before this comes out of draft. Correctness is grad-check-proven; dynamics are not yet observed at scale.
🤖 Generated with Claude Code