Skip entropy gradient computation when entropy_coef == 0 #2130

Open: CSUN1997 wants to merge 1 commit into THUDM:main from CSUN1997:fix/entropy-grad-when-coef-zero

Summary

When entropy_coef == 0, slime still computes the policy entropy with autograd
enabled, retaining the full [num_tokens, vocab] entropy computation graph (and a
per-chunk logits.clone()) even though the entropy term is multiplied by zero in the
loss and contributes no gradient. This wasted activation memory dominates for long
multi-turn / agentic rollouts and is a frequent OOM source.
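
For a sense of scale, here is a back-of-the-envelope sketch of the size of a single saved [num_tokens, vocab] tensor. The helper name, token count, vocabulary size, dtype width, and tensor-parallel degree are all illustrative assumptions, not measurements from slime, and the real autograd graph may hold more than one such tensor.

```python
def saved_activation_bytes(num_tokens: int, vocab: int,
                           bytes_per_elem: int = 4, tp_size: int = 1) -> int:
    """Bytes for one [num_tokens, vocab / tp_size] tensor on a single rank."""
    vocab_per_rank = vocab // tp_size  # vocab-parallel logits are sharded across TP ranks
    return num_tokens * vocab_per_rank * bytes_per_elem


# Hypothetical workload: 32k response tokens, ~152k-entry vocabulary, fp32, TP=4
size = saved_activation_bytes(32_768, 151_936, bytes_per_elem=4, tp_size=4)
print(f"{size / 2**30:.2f} GiB per saved tensor per rank")
```

Under these assumed numbers a single retained tensor is already several GiB per rank, which is why dropping the graph matters for long rollouts.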

This PR gates the entropy autograd graph on whether it is actually needed.

Root cause

policy_loss_function requests entropy unconditionally:

# slime/backends/megatron_utils/loss.py
_, log_probs_and_entropy = get_log_probs_and_entropy(
    logits, ..., with_entropy=True, ...
)
...
entropy = log_probs_and_entropy["entropy"]
entropy_loss = sum_of_sample_mean(entropy)
loss = pg_loss - args.entropy_coef * entropy_loss   # ← × 0 when coef == 0

calculate_log_probs_and_entropy then computes entropy with grad tracking on:

entropy_input = logits.clone()
entropy = compute_entropy_from_logits(entropy_input, tp_group)

compute_entropy_from_logits is a torch.autograd.Function over the vocab-parallel
logits, so with grad enabled it saves [num_tokens, vocab] activations for backward.
When entropy_coef == 0 that graph is built and held for nothing.

Notably, the get_log_probs_and_entropy docstring already claims this is handled:

"When entropy_coef == 0, entropy is computed under torch.no_grad() to avoid
retaining the computation graph and to skip cloning."

…but the implementation never did it. A prior attempt (#1185) added the gate with an
inverted condition — with torch.no_grad() if args.entropy_coef else nullcontext()
— which disabled the gradient exactly when coef != 0 (breaking the entropy-bonus
path) and kept it when coef == 0. It was reverted the next day in #1189. This PR
implements the gate with the correct condition and locks the behavior with tests.
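
The inversion is easy to see with a small truth table. The sketch below uses two hypothetical helper functions (not part of slime) that return whether gradient tracking would be disabled for a given coefficient under each condition.

```python
def grad_disabled_inverted(entropy_coef: float) -> bool:
    # Condition from #1185: `no_grad() if entropy_coef else nullcontext()`
    # picks no_grad() whenever the coefficient is truthy, i.e. nonzero.
    return bool(entropy_coef)


def grad_disabled_fixed(entropy_coef: float) -> bool:
    # Intended gate: drop the graph only when the coefficient is zero.
    return entropy_coef == 0


for coef in (0.0, 0.01):
    print(f"coef={coef}: inverted disables grad={grad_disabled_inverted(coef)}, "
          f"fixed disables grad={grad_disabled_fixed(coef)}")
```

The two conditions disagree for every coefficient value, so the inverted version is wrong in both the coef == 0 and coef != 0 cases.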

Change

Add a need_entropy_grad parameter to calculate_log_probs_and_entropy
(default True, so external callers are unaffected). The caller computes:

need_entropy_grad = with_entropy and getattr(args, "entropy_coef", 0.0) != 0

When need_entropy_grad is False, entropy is computed under torch.no_grad() and
the defensive logits.clone() is skipped (the clone only exists to keep the backward's
in-place ops — _VocabParallelEntropy.backward mutates the saved logits via sub_/mul_
— off the shared logits tensor; with no backward there is nothing to protect).

The entropy values are identical; only the autograd graph is dropped. The logged
entropy_loss metric is unchanged.
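
A minimal sketch of the gating pattern, assuming a plain single-device softmax entropy as a stand-in for the vocab-parallel compute_entropy_from_logits. The names entropy_from_logits and gated_entropy and the SimpleNamespace args object are hypothetical and only illustrate the idea; this is not the code in the diff.

```python
import contextlib
from types import SimpleNamespace

import torch


def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
    # Stand-in entropy: H = -sum(p * log p) over the vocab axis.
    log_p = torch.log_softmax(logits, dim=-1)
    return -(log_p.exp() * log_p).sum(dim=-1)


def gated_entropy(logits: torch.Tensor, need_entropy_grad: bool = True) -> torch.Tensor:
    # Build the autograd graph only when the entropy term will receive a gradient.
    ctx = contextlib.nullcontext() if need_entropy_grad else torch.no_grad()
    with ctx:
        return entropy_from_logits(logits)


args = SimpleNamespace(entropy_coef=0.0)  # hypothetical args object
with_entropy = True
need_entropy_grad = with_entropy and getattr(args, "entropy_coef", 0.0) != 0

logits = torch.randn(4, 8, requires_grad=True)
detached = gated_entropy(logits, need_entropy_grad)      # no graph retained
tracked = gated_entropy(logits, need_entropy_grad=True)  # differentiable

print(detached.grad_fn, tracked.requires_grad, torch.allclose(detached, tracked))
```

With the coefficient at zero the returned entropy carries no grad_fn, while the values match the differentiable path, which is the property the tests below check.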

Files

  • slime/utils/ppo_utils.py (calculate_log_probs_and_entropy): add need_entropy_grad,
    wrap entropy compute in no_grad and skip the clone when grad isn't needed.
  • slime/backends/megatron_utils/loss.py (get_log_probs_and_entropy): derive
    need_entropy_grad from args.entropy_coef and pass it through; fix the stale docstring.
  • tests/test_entropy_grad_gating.py — new unit tests.

Why this is safe

  • Default preserves old behavior. need_entropy_grad defaults to True; the only
    in-repo caller opts into gating based on entropy_coef. Any external caller that does
    not pass the flag gets the previous (grad-on) behavior.
  • Numerically identical. Entropy is the same tensor either way; no_grad only changes
    whether the graph is retained (test 1).
  • Entropy-bonus training is unaffected. When entropy_coef != 0, need_entropy_grad
    is True, so entropy stays differentiable and still backpropagates to the logits
    (test 3). This is the exact correctness bug that sank #1185
    ("Don't calculate entropy grad when coef is 0").

Testing

tests/test_entropy_grad_gating.py (CPU, world_size=1 gloo group; compute_log_probs
stubbed to avoid the Megatron fused-CE dependency), parametrized over chunked and
non-chunked paths:

  1. test_entropy_values_match_regardless_of_grad: entropy values equal with/without grad.
  2. test_need_entropy_grad_false_detaches_graph: grad_fn is None, requires_grad False.
  3. test_need_entropy_grad_true_is_differentiable: entropy backprops to logits, finite grad.
  4. test_log_probs_unaffected_by_entropy_flag: log-prob output independent of the flag.
  5. test_empty_input_returns_empty_entropy: zero-length response edge case.

$ pytest tests/test_entropy_grad_gating.py -o addopts="" -v
========================= 9 passed in 7.59s =========================

ruff check passes on all changed files; existing tests/test_chunked_gae.py still passes.

Notes / alternatives

This is the minimal, targeted fix (drop the graph when the gradient is provably unused).
A complementary, more general optimization is to filter logits to loss_mask == 1
positions before the vocab-parallel softmax (cf. #1905), which shrinks
both the log-prob and entropy compute for agentic rollouts where most response tokens are
masked tool-result tokens. The two are orthogonal and can be combined; this PR addresses
only the entropy_coef == 0 wasted-gradient case and restores the behavior the docstring
already promises.

`policy_loss_function` always requests entropy with `with_entropy=True`, and
`calculate_log_probs_and_entropy` computed it with autograd enabled
unconditionally. But entropy enters the loss as
`loss = pg_loss - args.entropy_coef * entropy_loss`, so when
`entropy_coef == 0` the entropy term contributes no gradient — yet the full
`[num_tokens, vocab]` entropy autograd graph (plus a defensive `logits.clone()`
per chunk) was still retained. For long multi-turn rollouts this activation
memory dominates and OOMs.

Add a `need_entropy_grad` flag to `calculate_log_probs_and_entropy`. The caller
sets it to `with_entropy and args.entropy_coef != 0`. When false, entropy is
computed under `torch.no_grad()` and the clone is skipped (the clone only exists
to keep the backward's in-place ops off the shared logits tensor; there is no
backward under `no_grad`). Entropy values are unchanged — only the graph is
dropped — so the logged `entropy_loss` metric is identical.

This makes the code match the existing `get_log_probs_and_entropy` docstring,
which already claimed this behavior but was never implemented (the prior
attempt, THUDM#1185, inverted the `no_grad` condition and was reverted in THUDM#1189).

Add tests/test_entropy_grad_gating.py covering: entropy values match with/without
grad; need_entropy_grad=False detaches the graph; need_entropy_grad=True remains
differentiable; log_probs are unaffected; empty-input handling.