[FEAT] TP-invariant reductions for FSDP(TP=1) vs TP>1 rollout/training parity #102

Description

@inaniloquentee

This issue scopes the P0.3 roadmap item for tensor-parallel-invariant reductions across rollout and training engines.

Target identity:

same model + same sequence + same policy state
=> selected logprobs / loss reductions are invariant to tensor-parallel degree

This is a focused follow-up to the train-inference consistency roadmap (#83) and the batch-invariant consistency RFC (#101). #101 covers batch/cache/layout invariance broadly; this issue covers the distributed reduction contract needed when training uses FSDP with TP=1 but rollout or scoring uses TP>1.

Concrete Example

Consider one token position with vocab size 6 and target token id 4.

Full-vocab TP=1 logits:

id:      0     1     2     3     4     5
logit:  0.1  -0.2   1.7   0.3   1.2  -0.5

FSDP(TP=1) computes:

logp(id=4) = 1.2 - logsumexp([0.1, -0.2, 1.7, 0.3, 1.2, -0.5])
           = -1.3395806963

With TP=2, vocab is sharded:

rank0 ids 0..2: [0.1, -0.2, 1.7]
rank1 ids 3..5: [0.3,  1.2, -0.5]

The TP-invariant reduction is:

global_max = all_reduce_max(local_max) = 1.7
partial_sum_rank0 = sum(exp(rank0_logits - global_max))
partial_sum_rank1 = sum(exp(rank1_logits - global_max))
global_sum = all_reduce_sum(local_partial_sum) = partial_sum_rank0 + partial_sum_rank1
target_logit = 1.2   (held by rank1, the owner of id 4)
logp(id=4) = target_logit - (global_max + log(global_sum))
           = -1.3395806963
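
The reduction above can be simulated on a single process. The following is a minimal sketch, not the project's implementation: the function names are hypothetical, plain Python `max` and `sum` stand in for `all_reduce(MAX)` and `all_reduce(SUM)`, and the target logit is assumed to be shared by having the owner rank contribute its local value while the other ranks contribute 0.

```python
import math

def full_vocab_logprob(logits, target_id):
    """TP=1 reference: target logit minus a max-shifted logsumexp."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return logits[target_id] - lse

def sharded_logprob(shards, target_id):
    """Simulated TP>1 path over contiguous vocab shards (one list per rank)."""
    # Stand-in for all_reduce(MAX) over the per-rank local maxima.
    global_max = max(max(s) for s in shards)
    # Each rank sums exp(logit - global_max) over its own shard only.
    partial_sums = [sum(math.exp(x - global_max) for x in s) for s in shards]
    # Stand-in for all_reduce(SUM) over the per-rank partial sums.
    global_sum = sum(partial_sums)
    # Assumption: the owner rank contributes the target logit, others contribute 0.
    offset, contributions = 0, []
    for s in shards:
        local_id = target_id - offset
        contributions.append(s[local_id] if 0 <= local_id < len(s) else 0.0)
        offset += len(s)
    target_logit = sum(contributions)
    return target_logit - (global_max + math.log(global_sum))

logits = [0.1, -0.2, 1.7, 0.3, 1.2, -0.5]
tp1 = full_vocab_logprob(logits, target_id=4)
tp2 = sharded_logprob([logits[:3], logits[3:]], target_id=4)
print(tp1, tp2)  # both approximately -1.33958
```

The two paths add the exponentials in a different association order, so they agree to floating-point rounding rather than necessarily bit for bit, which is why the contract has to say where bitwise identity is expected and where a tolerance applies.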

Bad reductions drift immediately:

owner-rank local logsumexp only:  1.2 - logsumexp([0.3, 1.2, -0.5]) = -0.4632642102
averaging per-rank logsumexp:     1.2 - mean(local_lse_rank0, local_lse_rank1) = -0.6322267505
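
As a quick check of these numbers, the sketch below recomputes both bad reductions in plain Python. It also shows one valid way to combine per-rank results, namely taking the logsumexp of the local logsumexp values, which is mathematically equal to the full-vocab logsumexp. The helper name `lse` is illustrative only.

```python
import math

def lse(xs):
    """Max-shifted logsumexp of a list of floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

rank0 = [0.1, -0.2, 1.7]
rank1 = [0.3, 1.2, -0.5]
target_logit = 1.2  # id 4 lives on rank1

correct = target_logit - lse(rank0 + rank1)
owner_only = target_logit - lse(rank1)                     # ignores rank0 entirely
averaged = target_logit - (lse(rank0) + lse(rank1)) / 2    # mean is not a valid combine
combined = target_logit - lse([lse(rank0), lse(rank1)])    # logsumexp of local logsumexps

print(correct, owner_only, averaged, combined)
```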

That kind of drift is not a harmless implementation detail: GRPO/PPO consume logprob deltas through ratios and KL terms, so a TP-dependent selected-logprob can change the policy update even when the model weights and sampled sequence are identical.
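
To make the effect on the update concrete, the sketch below feeds the example values into a per-token importance ratio `exp(logp_train - logp_rollout)`. With identical weights and an identical sequence the ratio should be exactly 1 and the log-ratio 0. The clip range of 0.2 and the function name are assumptions for illustration and are not taken from this issue.

```python
import math

CLIP_EPS = 0.2  # assumed PPO-style clip range, for illustration only

def ratio_and_clip(logp_train, logp_rollout, eps=CLIP_EPS):
    """Return (ratio, log_ratio, whether the ratio falls outside [1-eps, 1+eps])."""
    log_ratio = logp_train - logp_rollout
    ratio = math.exp(log_ratio)
    return ratio, log_ratio, abs(ratio - 1.0) > eps

reference = -1.3395806963    # TP=1 selected logprob from the example
owner_only = -0.4632642102   # the "owner-rank local logsumexp only" value

same_path = ratio_and_clip(reference, reference)   # ratio 1.0, nothing clipped
drifted = ratio_and_clip(reference, owner_only)    # ratio well below 1 - eps

print(same_path, drifted)
```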

Problem

Rollout/scoring stacks often run TP>1, while the training/reference scorer may run FSDP with TP=1. The same logical operation can then travel through different reduction paths:

  • full-vocab selected logprob in FSDP(TP=1) vs vocab-sharded selected logprob in TP>1;
  • lm_head output projection with full logits vs sharded vocab logits;
  • masked token/sentence loss denominators under different shard or micro-batch layouts;
  • group reward/advantage statistics when group members are partitioned differently;
  • dtype and accumulation differences such as fp16/bf16 inputs with fp32 reduction state;
  • reduction topology differences across CUDA / ROCm / Triton / torch.distributed paths.

Without an explicit TP-invariant reduction contract, rollout-vs-training parity tests can fail even when weight sync, tokenization, causal shifting, and masks are correct.
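
The dtype and topology items interact: a low-precision accumulator can make the result depend on the TP degree itself. The sketch below is a contrived illustration that emulates fp16 and fp32 accumulators in pure Python through `struct` rounding. It assumes a 4096-entry vocab in which every logit equals the max, so each `exp(logit - global_max)` term is 1.0. All helper names are hypothetical.

```python
import math
import struct

def round_to(x, fmt):
    """Round a Python float to IEEE half ('e') or single ('f') precision."""
    return struct.unpack("<" + fmt, struct.pack("<" + fmt, x))[0]

def accumulate(values, fmt):
    """Sequential sum whose running state is rounded to `fmt` after every add."""
    acc = 0.0
    for v in values:
        acc = round_to(acc + v, fmt)
    return acc

def sharded_sum(values, tp_size, fmt):
    """Per-rank partial sums followed by a simulated all_reduce(SUM)."""
    per_rank = len(values) // tp_size
    partials = [accumulate(values[r * per_rank:(r + 1) * per_rank], fmt)
                for r in range(tp_size)]
    return accumulate(partials, fmt)

terms = [1.0] * 4096  # exp(logit - global_max) when every logit equals the max

fp16_tp1 = sharded_sum(terms, 1, "e")   # stalls at 2048.0: 2048 + 1 rounds back to 2048
fp16_tp4 = sharded_sum(terms, 4, "e")   # 4096.0: each rank only counts to 1024
fp32_tp1 = sharded_sum(terms, 1, "f")   # 4096.0
fp32_tp4 = sharded_sum(terms, 4, "f")   # 4096.0

print(math.log(fp16_tp4) - math.log(fp16_tp1))  # normalizer drift of log(2)
```

In this toy case the fp16 accumulator gives a different normalizer for TP=1 and TP=4, while the fp32 accumulator gives the same value for both, which is the kind of behavior a declared accumulation policy is meant to pin down.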

Scope

Define and test canonical TP-invariant reductions for:

  • selected-token logprob / cross entropy over vocab-sharded logits;
  • masked token reductions used by logprob summaries and loss denominators;
  • GRPO/PPO-style policy-ratio, KL, and loss aggregation inputs where TP or micro-batch partitioning can alter denominator semantics;
  • optional simulated TP partitions on CPU/CUDA so the contract can be tested without launching a full distributed engine;
  • real distributed TP smoke tests when multi-GPU CI or nightly hardware is available.

The contract should state:

  • reduction order/topology where bitwise identity is expected;
  • dtype policy for local partials, global reductions, and final casts;
  • mask and padding denominator semantics;
  • behavior for uneven vocab shards and target tokens owned by different ranks;
  • tolerance policy when bitwise identity is not realistic across backends.
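
For the mask and padding item, the sketch below contrasts two denominator semantics on two micro-batches with different completion lengths. It assumes the contract picks a global token mean, which is only one possible choice. Under that assumption, reducing numerator and denominator separately before dividing is independent of how sequences are partitioned, whereas averaging per-micro-batch means is not.

```python
def masked_sums(values, mask):
    """Return (sum of values at unmasked positions, count of unmasked positions)."""
    num = sum(v for v, m in zip(values, mask) if m)
    den = sum(mask)
    return num, den

# Two micro-batches: (per-token values, completion mask with 0 for padding).
micro_batches = [
    ([-1.0, -2.0, -3.0, 0.0], [1, 1, 1, 0]),
    ([-4.0, 0.0, 0.0, 0.0], [1, 0, 0, 0]),
]

# Layout-invariant: reduce numerators and denominators separately, divide once.
partials = [masked_sums(v, m) for v, m in micro_batches]
token_mean = sum(n for n, _ in partials) / sum(d for _, d in partials)

# Layout-dependent: divide per micro-batch, then average the means.
mean_of_means = sum(n / d for n, d in partials) / len(partials)

# Single-batch layout for comparison: everything concatenated.
all_values = [x for v, _ in micro_batches for x in v]
all_mask = [x for _, m in micro_batches for x in m]
num, den = masked_sums(all_values, all_mask)
single_batch_mean = num / den

print(token_mean, mean_of_means, single_batch_mean)  # -2.5 -3.0 -2.5
```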

Non-Goals

Proposed Deliverables

  • A short RFC/design note describing TP-invariant reduction semantics.
  • A reference implementation or test helper that computes selected logprob from full logits and from simulated TP shards using the same contract.
  • A parity test matrix for TP=1 vs TP=2/4 simulated shards, covering target-token ownership on each rank, padding masks, variable completion lengths, and uneven vocab shard tails.
  • A drift report format with max abs/relative error, token position, target id, owner rank, TP size, dtype, backend, and reduction name.
  • A path to plug the same checks into the end-to-end rollout/training logprob cross-benchmark.
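
One possible shape for the drift report is sketched below as a small dataclass carrying the fields listed in the deliverables. The class name, field names, and the `worst_drift` helper are hypothetical, and the final format is left to the RFC.

```python
import json
from dataclasses import asdict, dataclass

@dataclass(frozen=True)
class DriftRecord:
    reduction: str        # e.g. "vocab_logsumexp" or "mask_denominator"
    tp_size: int
    dtype: str
    backend: str
    token_position: int   # position with the largest absolute error
    target_id: int
    owner_rank: int
    max_abs_error: float
    max_rel_error: float

def worst_drift(reference, observed, target_ids, owner_ranks, *,
                reduction, tp_size, dtype, backend):
    """Summarize per-token drift between a reference and an observed run."""
    abs_errs = [abs(o - r) for r, o in zip(reference, observed)]
    rel_errs = [e / max(abs(r), 1e-12) for e, r in zip(abs_errs, reference)]
    pos = max(range(len(abs_errs)), key=abs_errs.__getitem__)
    return DriftRecord(
        reduction=reduction, tp_size=tp_size, dtype=dtype, backend=backend,
        token_position=pos, target_id=target_ids[pos], owner_rank=owner_ranks[pos],
        max_abs_error=abs_errs[pos], max_rel_error=max(rel_errs),
    )

record = worst_drift(
    reference=[-1.0, -2.0], observed=[-1.0, -2.5],
    target_ids=[4, 1], owner_ranks=[1, 0],
    reduction="vocab_logsumexp", tp_size=2, dtype="float32", backend="cpu-sim",
)
print(json.dumps(asdict(record)))
```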

Acceptance Criteria

  • TP=1 full-vocab reference and simulated TP>1 sharded reference produce matching selected logprobs within the declared tolerance.
  • Tests include at least TP=2 and TP=4, with target tokens on every shard.
  • Tests include prompt/completion/padding masks and verify denominator semantics do not change with TP degree.
  • Tests include uneven vocab shard sizes.
  • The design explicitly names the dtype/accumulation policy, including fp16/bf16 input and fp32 reduction state if chosen.
  • Failure output is actionable enough to identify whether drift came from vocab logsumexp, selected-token ownership, mask denominator, group/loss aggregation, or backend dtype behavior.
  • The issue links cleanly into [RFC] Batch-Invariant RL Kernel Suite for Train-Inference Consistency #101 and the future end-to-end logprob cross-benchmark tool.
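
A simulated parity matrix covering the first, second, and fourth criteria could look like the sketch below. It makes several assumptions: a vocab of 7 so that both TP=2 and TP=4 produce uneven shard tails, a ceil-divided contiguous split (real engines may pad the vocab instead), random logits from a fixed seed, and a tolerance of 1e-12 for float64 on one process. All names are hypothetical.

```python
import math
import random

def reference_logprob(logits, target_id):
    """TP=1 full-vocab selected logprob."""
    m = max(logits)
    return logits[target_id] - (m + math.log(sum(math.exp(x - m) for x in logits)))

def split_vocab(logits, tp_size):
    """Contiguous ceil-divided shards; the last shard may be shorter."""
    per_rank = -(-len(logits) // tp_size)
    return [logits[r * per_rank:(r + 1) * per_rank] for r in range(tp_size)]

def sharded_logprob(shards, target_id):
    """Simulated max/sum reductions; also returns the rank that owns the target."""
    global_max = max(max(s, default=-math.inf) for s in shards)
    global_sum = sum(sum(math.exp(x - global_max) for x in s) for s in shards)
    offset, target_logit, owner = 0, 0.0, None
    for rank, s in enumerate(shards):
        if offset <= target_id < offset + len(s):
            target_logit, owner = s[target_id - offset], rank
        offset += len(s)
    return target_logit - (global_max + math.log(global_sum)), owner

def run_matrix(vocab_size=7, tp_sizes=(2, 4), tol=1e-12, seed=0):
    rng = random.Random(seed)
    logits = [rng.uniform(-3.0, 3.0) for _ in range(vocab_size)]
    failures, owners_seen = [], {tp: set() for tp in tp_sizes}
    for tp in tp_sizes:
        shards = split_vocab(logits, tp)
        for target_id in range(vocab_size):
            got, owner = sharded_logprob(shards, target_id)
            owners_seen[tp].add(owner)
            err = abs(got - reference_logprob(logits, target_id))
            if err > tol:
                failures.append((tp, target_id, owner, err))
    return failures, owners_seen

failures, owners_seen = run_matrix()
print(failures, owners_seen)
```

Sweeping every target id guarantees that each shard owns the target at least once, and the returned failure tuples already carry the TP size, target id, and owner rank needed for an actionable report.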

Metadata

Labels

  • component: distributed (Tasks involving Ray actor management, cross-node scheduling, and communication synchronization.)
  • component: kernels (Tasks involving the development of CUDA and Triton underlying operators)
  • component: testing (Add test cases and benchmark-related tasks)
  • feature
  • priority: high (Severe congestion issues require the highest priority for resolution.)
  • type: design (Issues requiring in-depth discussion of architecture design)