From be4798f727bd502a393dfbf2ceb23effc030f9df Mon Sep 17 00:00:00 2001 From: Chuanneng Sun Date: Thu, 25 Jun 2026 21:29:49 +0000 Subject: [PATCH] Skip entropy gradient when entropy_coef == 0 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `policy_loss_function` always requests entropy with `with_entropy=True`, and `calculate_log_probs_and_entropy` computed it with autograd enabled unconditionally. But entropy enters the loss as `loss = pg_loss - args.entropy_coef * entropy_loss`, so when `entropy_coef == 0` the entropy term contributes no gradient — yet the full `[num_tokens, vocab]` entropy autograd graph (plus a defensive `logits.clone()` per chunk) was still retained. For long multi-turn rollouts this activation memory dominates and OOMs. Add a `need_entropy_grad` flag to `calculate_log_probs_and_entropy`. The caller sets it to `with_entropy and args.entropy_coef != 0`. When false, entropy is computed under `torch.no_grad()` and the clone is skipped (the clone only exists to keep the backward's in-place ops off the shared logits tensor; there is no backward under `no_grad`). Entropy values are unchanged — only the graph is dropped — so the logged `entropy_loss` metric is identical. This makes the code match the existing `get_log_probs_and_entropy` docstring, which already claimed this behavior but was never implemented (the prior attempt, #1185, inverted the `no_grad` condition and was reverted in #1189). Add tests/test_entropy_grad_gating.py covering: entropy values match with/without grad; need_entropy_grad=False detaches the graph; need_entropy_grad=True remains differentiable; log_probs are unaffected; empty-input handling. 
--- slime/backends/megatron_utils/loss.py | 12 ++- slime/utils/ppo_utils.py | 36 +++++-- tests/test_entropy_grad_gating.py | 148 ++++++++++++++++++++++++++ 3 files changed, 186 insertions(+), 10 deletions(-) create mode 100644 tests/test_entropy_grad_gating.py diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 72afdfa66c..b63d41bbe9 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -485,8 +485,10 @@ def get_log_probs_and_entropy( per-sample slicing) so backward traverses ``[T, V]`` only once, then extracts per-sample response portions. - When ``entropy_coef == 0``, entropy is computed under ``torch.no_grad()`` - to avoid retaining the computation graph and to skip cloning. + When ``entropy_coef == 0``, entropy is still returned as a metric but is computed under + ``torch.no_grad()`` (and without the defensive logits clone), since it is multiplied by zero in + the loss and therefore needs no gradient. This avoids retaining the ``[T, V]`` entropy graph, + which is a dominant activation-memory cost for long multi-turn rollouts. """ assert non_loss_data assert logits.dtype == torch.float32, f"{logits.dtype}" @@ -522,6 +524,11 @@ def get_log_probs_and_entropy( ) # --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy --- + # Entropy only contributes a gradient when it is weighted into the loss (``entropy_coef != 0``). + # When the coefficient is 0 the entropy is logged as a metric but multiplied by zero in the loss, so + # we compute it under ``no_grad`` to avoid retaining the [T, V] entropy graph (a major activation- + # memory cost / OOM source for long rollouts).
+ need_entropy_grad = with_entropy and getattr(args, "entropy_coef", 0.0) != 0 log_prob_full, entropy_full = calculate_log_probs_and_entropy( logits, full_tokens, @@ -529,6 +536,7 @@ def get_log_probs_and_entropy( with_entropy=with_entropy, chunk_size=chunk_size, log_prob_keep_mask=top_p_keep_mask, + need_entropy_grad=need_entropy_grad, ) log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T] diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index a4dd0c0181..228ca653d9 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -2,6 +2,7 @@ # and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py from argparse import Namespace +from contextlib import nullcontext import torch import torch.distributed as dist @@ -690,10 +691,27 @@ def chunked_gae( def calculate_log_probs_and_entropy( - logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1, log_prob_keep_mask=None + logits, + tokens, + tp_group, + with_entropy: bool = False, + chunk_size: int = -1, + log_prob_keep_mask=None, + need_entropy_grad: bool = True, ): + """Compute per-token log-probs and (optionally) entropy from vocab-parallel logits. + + When ``with_entropy`` is set but ``need_entropy_grad`` is ``False`` (i.e. the caller + weights entropy by an ``entropy_coef`` of 0, so it never contributes a gradient), the + entropy is computed under ``torch.no_grad()``. This avoids retaining the autograd graph + for the ``[num_tokens, vocab]`` entropy computation, which otherwise dominates activation + memory for long multi-turn rollouts and can OOM. Under ``no_grad`` we also skip the + defensive ``logits.clone()`` (only needed to keep the backward's in-place ops off the + shared logits tensor), saving a logits copy (chunk-sized per chunk when chunking is enabled).
+ """ logits = logits.contiguous() entropy = None + entropy_ctx = nullcontext() if need_entropy_grad else torch.no_grad() if logits.size(0) != 0: if chunk_size > 0: num_chunks = (logits.size(0) - 1) // chunk_size + 1 @@ -704,11 +722,12 @@ def calculate_log_probs_and_entropy( ) if with_entropy: - entropys = [] - for logits_chunk in logits_chunks: - entropy_input = logits_chunk.clone() - entropys.append(compute_entropy_from_logits(entropy_input, tp_group)) - entropy = torch.cat(entropys, dim=0) + with entropy_ctx: + entropys = [] + for logits_chunk in logits_chunks: + entropy_input = logits_chunk.clone() if need_entropy_grad else logits_chunk + entropys.append(compute_entropy_from_logits(entropy_input, tp_group)) + entropy = torch.cat(entropys, dim=0) log_probs = [] for tokens_chunk, logits_chunk, mask_chunk in zip(tokens_chunks, logits_chunks, mask_chunks, strict=True): @@ -717,8 +736,9 @@ def calculate_log_probs_and_entropy( log_prob = torch.cat(log_probs, dim=0) else: if with_entropy: - entropy_input = logits.clone() - entropy = compute_entropy_from_logits(entropy_input, tp_group) + with entropy_ctx: + entropy_input = logits.clone() if need_entropy_grad else logits + entropy = compute_entropy_from_logits(entropy_input, tp_group) log_prob = compute_log_probs(logits.clone(), tokens, tp_group, keep_mask=log_prob_keep_mask) else: diff --git a/tests/test_entropy_grad_gating.py b/tests/test_entropy_grad_gating.py new file mode 100644 index 0000000000..dd93f3c2f5 --- /dev/null +++ b/tests/test_entropy_grad_gating.py @@ -0,0 +1,148 @@ +"""Unit tests for entropy gradient gating in ``calculate_log_probs_and_entropy``. + +Background +---------- +Entropy enters the policy loss as ``loss = pg_loss - args.entropy_coef * entropy_loss`` +(see ``policy_loss_function`` in ``slime/backends/megatron_utils/loss.py``). 
When +``entropy_coef == 0`` the entropy term contributes no gradient, yet the entropy was still +computed *with* autograd enabled — retaining the ``[num_tokens, vocab]`` entropy graph and a +defensive ``logits.clone()`` per chunk. For long multi-turn rollouts that activation memory +dominates and can OOM. + +The fix adds a ``need_entropy_grad`` flag: when ``False`` the entropy is computed under +``torch.no_grad()`` and the clone is skipped. These tests pin the contract: + +1. The entropy *values* are identical whether or not grad is tracked. +2. With ``need_entropy_grad=False`` the returned entropy carries no autograd graph. +3. With ``need_entropy_grad=True`` the entropy is differentiable w.r.t. the logits. +4. The log-prob output is unaffected by the entropy flag. + +These run on CPU with a ``world_size=1`` gloo group (the TP all-reduce in the entropy kernel +is a no-op for a single rank). ``compute_log_probs`` is monkeypatched to a cheap stub so the +test does not depend on Megatron's fused cross-entropy (absent from the CPU CI image); the +gating logic under test lives entirely in the entropy branch. 
+""" + +import os + +import pytest +import torch +import torch.distributed as dist + +import slime.utils.ppo_utils as ppo_utils +from slime.utils.ppo_utils import calculate_log_probs_and_entropy + + +@pytest.fixture(scope="module") +def single_rank_group(): + """A real ``world_size=1`` gloo group for the entropy kernel's ``all_reduce`` calls.""" + if dist.is_initialized(): + yield dist.group.WORLD + return + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29555") + dist.init_process_group(backend="gloo", rank=0, world_size=1) + try: + yield dist.group.WORLD + finally: + dist.destroy_process_group() + + +@pytest.fixture(autouse=True) +def stub_compute_log_probs(monkeypatch): + """Avoid the Megatron fused-CE dependency: log-probs are not what these tests exercise.""" + + def _fake_log_probs(logits, tokens, process_group, keep_mask=None): + # Shape/contract-compatible stub: one scalar log-prob per token, differentiable. + return logits.sum(dim=-1) + + monkeypatch.setattr(ppo_utils, "compute_log_probs", _fake_log_probs) + + +@pytest.mark.unit +@pytest.mark.parametrize("chunk_size", [-1, 4]) +def test_entropy_values_match_regardless_of_grad(single_rank_group, chunk_size): + """no_grad entropy must equal grad-tracked entropy numerically (only the graph differs).""" + torch.manual_seed(0) + num_tokens, vocab = 11, 32 + logits = torch.randn(num_tokens, vocab, dtype=torch.float32) + tokens = torch.randint(0, vocab, (num_tokens,)) + + _, entropy_grad = calculate_log_probs_and_entropy( + logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=True + ) + _, entropy_nograd = calculate_log_probs_and_entropy( + logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=False + ) + + assert entropy_grad.shape == (num_tokens,) + assert entropy_nograd.shape == (num_tokens,) + torch.testing.assert_close(entropy_grad.detach(), entropy_nograd) + + +@pytest.mark.unit 
+@pytest.mark.parametrize("chunk_size", [-1, 4]) +def test_need_entropy_grad_false_detaches_graph(single_rank_group, chunk_size): + """need_entropy_grad=False -> entropy is a leaf with no autograd graph (memory is freed).""" + torch.manual_seed(1) + num_tokens, vocab = 9, 16 + logits = torch.randn(num_tokens, vocab, dtype=torch.float32, requires_grad=True) + tokens = torch.randint(0, vocab, (num_tokens,)) + + _, entropy = calculate_log_probs_and_entropy( + logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=False + ) + + assert entropy.grad_fn is None + assert not entropy.requires_grad + + +@pytest.mark.unit +@pytest.mark.parametrize("chunk_size", [-1, 4]) +def test_need_entropy_grad_true_is_differentiable(single_rank_group, chunk_size): + """need_entropy_grad=True -> entropy backpropagates to the logits (the entropy-bonus path).""" + torch.manual_seed(2) + num_tokens, vocab = 9, 16 + logits = torch.randn(num_tokens, vocab, dtype=torch.float32, requires_grad=True) + tokens = torch.randint(0, vocab, (num_tokens,)) + + _, entropy = calculate_log_probs_and_entropy( + logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=True + ) + + assert entropy.requires_grad + assert entropy.grad_fn is not None + entropy.sum().backward() + assert logits.grad is not None + assert torch.isfinite(logits.grad).all() + assert logits.grad.abs().sum() > 0 + + +@pytest.mark.unit +@pytest.mark.parametrize("chunk_size", [-1, 4]) +def test_log_probs_unaffected_by_entropy_flag(single_rank_group, chunk_size): + """The log-prob output must not depend on need_entropy_grad.""" + torch.manual_seed(3) + num_tokens, vocab = 11, 16 + logits = torch.randn(num_tokens, vocab, dtype=torch.float32) + tokens = torch.randint(0, vocab, (num_tokens,)) + + log_prob_a, _ = calculate_log_probs_and_entropy( + logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=True + ) + log_prob_b, _ = 
calculate_log_probs_and_entropy( + logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=False + ) + torch.testing.assert_close(log_prob_a, log_prob_b) + + +@pytest.mark.unit +def test_empty_input_returns_empty_entropy(single_rank_group): + """Zero-length response (all positions masked away upstream) is handled without compute.""" + logits = torch.zeros(0, 16, dtype=torch.float32) + tokens = torch.zeros(0, dtype=torch.long) + + _, entropy = calculate_log_probs_and_entropy( + logits, tokens, single_rank_group, with_entropy=True, need_entropy_grad=False + ) + assert entropy.shape == (0,)