Skip to content

refactor(3-pool): SUM-grad convention (supersedes #545; fixes 2nd stoch→V/U instance)#546

Draft
ocg-goodfire wants to merge 1 commit into
fix/ppgd-ci-grad-nci-scalingfrom
refactor/three-pool-sum-grad-convention
Draft

refactor(3-pool): SUM-grad convention (supersedes #545; fixes 2nd stoch→V/U instance)#546
ocg-goodfire wants to merge 1 commit into
fix/ppgd-ci-grad-nci-scalingfrom
refactor/three-pool-sum-grad-convention

Conversation

@ocg-goodfire

Copy link
Copy Markdown
Collaborator

What

A structural fix for the recurring 3-pool gradient-scaling bug class: replace the per-instance "pre-scale to survive a downstream reduction" pattern with a single convention — every gradient crossing a cross-rank reduction is a partial SUM, normalized only by the honest global count, carrying no pool-size transport factor; all data-parallel grad reductions are SUM.

Full rationale + the honest trade-off analysis in param_decomp_lab/three_pool/SUM_GRAD_CONVENTION.md.

Why

Four recurring bugs, all the same shape: a producer pre-scales its gradient by a pool-size factor (n_ci, n_per_block) to survive an AVG-reduce it can't locally see; the factor is invisible at n_ci=1/square and wrong otherwise. Two are confirmed live by a real grad check:

On the live run p-97bab993 (n_ci=32, n_per_block=2) these are 16× (V/U) and 32× (CI) off — not subtle.

How

  • all_reduce_ci_fn_grads and all_reduce_grads_in_block: AVG → SUM.
  • Producers normalize by the honest global count; fix(3-pool): scale PPGD's CI grad by n_ci to survive the CI-pool AVG #545's ×n_ci and stoch's /n_ci are deleted (the factors evaporate for the data-parallel producers).
  • Replicated contributions handled structurally, not numerically: faith + broadcast-PPGD V/U become contribute-once (block leader only); imp-min uses a detached-global-residual so its backward is a local partial sum.

Relationship to #545

This supersedes #545 by folding its fix into the convention (the manual ×n_ci is removed because the AVG it compensated for is now a SUM). Based on fix/ppgd-ci-grad-nci-scaling so the diff here is only the structural change. @danbraunai-goodfire — keen for your read on whether to merge #545 then this, or collapse them.

Validation

  • New test_three_pool_grad_check_distributed.py (10-rank gloo, non-square n_ci=4/n_per_block=2/n_ppgd=2, all four loss terms, RNG/masks pinned): fully-reduced CI-fn and V/U grads vs single-process reference — mean(dist/ref)=1.000000, worst rel err 4e-7. Sensitivity-checked: AVG→CI gives 1/n_ci, faith-on-all-ranks gives n_per_block× (the check bites).
  • make check clean; 42 three-pool/distributed tests pass.

⚠️ Draft — not merge-ready

The fix changes effective gradient magnitudes (it corrects them), so training dynamics / the stoch:imp:ppgd balance change and LRs may need retuning. Pending a behavioural smoke at a real (non-square) topology before this comes out of draft. Correctness is grad-check-proven; dynamics are not yet observed at scale.

🤖 Generated with Claude Code

…ssembly

Replace the per-instance "pre-scale to survive a downstream reduce" patches
(4 recurring bugs, latest PR #545's PPGD x n_ci) with a single convention:
every data-parallel gradient reduction is SUM, and each producer emits a
partial sum normalized only by the honest global count, carrying NO pool-size
transport factor. SUM(partials) = total, so no producer needs any pool's size.

Changes:
- portals: all_reduce_ci_fn_grads and all_reduce_grads_in_block flip AVG -> SUM.
- step_ppgd: V/U and CI collapse to one scale; the #545 x n_ci is deleted.
- step_layerwise: stoch denom drops /n_ci (one scale now serves both CI leaves
  and V/U). faith + broadcast-PPGD V/U are "contribute once" (block leader only)
  so the SUM lands them exactly once instead of n_per_block x.
- step_ci: imp-min uses the detached-global-residual trick instead of the
  autograd-aware all_reduce, so its backward is a local partial (no reliance on
  the old AVG cancelling an n_ci factor).
- grad-clip n_replicas unchanged (counts distinct params for the global norm,
  independent of the reduce op).

The replica count does not vanish for replicated contributions — it relocates
from a numeric scale factor into a structural placement (contribute-once) or a
graph decision (detach the global term). See SUM_GRAD_CONVENTION.md for the
honest verdict on whether this is simpler than per-destination scale-splitting.

Validated by a new distributed grad check at a non-square topology
(n_ci=4 != n_per_block=2, 2 blocks, n_ppgd=2) with ALL loss terms enabled:
fully-reduced CI-fn and V/U grads match a single-process full-batch reference
(mean dist/ref = 1.000000, worst rel err 4e-07). Sensitivity confirmed: AVG on
the CI reduce -> 1/n_ci on CI grads; faith on all ranks -> n_per_block x on V/U.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant