Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,10 @@ def get_log_probs_and_entropy(
per-sample slicing) so backward traverses ``[T, V]`` only once, then
extracts per-sample response portions.

When ``entropy_coef == 0``, entropy is computed under ``torch.no_grad()``
to avoid retaining the computation graph and to skip cloning.
When ``entropy_coef == 0``, entropy is still returned as a metric but is computed under
``torch.no_grad()`` (and without the defensive logits clone), since it is multiplied out of
the loss and therefore needs no gradient. This avoids retaining the ``[T, V]`` entropy graph,
which is a dominant activation-memory cost for long multi-turn rollouts.
"""
assert non_loss_data
assert logits.dtype == torch.float32, f"{logits.dtype}"
Expand Down Expand Up @@ -522,13 +524,19 @@ def get_log_probs_and_entropy(
)

# --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy ---
# Entropy only contributes a gradient when it is weighted into the loss (``entropy_coef != 0``).
# When the coefficient is 0 the entropy is logged as a metric but multiplied out of the loss, so
# we compute it under ``no_grad`` to avoid retaining the [T, V] entropy graph (a major activation-
# memory cost / OOM source for long rollouts).
need_entropy_grad = with_entropy and getattr(args, "entropy_coef", 0.0) != 0
log_prob_full, entropy_full = calculate_log_probs_and_entropy(
logits,
full_tokens,
tp_group,
with_entropy=with_entropy,
chunk_size=chunk_size,
log_prob_keep_mask=top_p_keep_mask,
need_entropy_grad=need_entropy_grad,
)
log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T]

Expand Down
36 changes: 28 additions & 8 deletions slime/utils/ppo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py

from argparse import Namespace
from contextlib import nullcontext

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -690,10 +691,27 @@ def chunked_gae(


def calculate_log_probs_and_entropy(
logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1, log_prob_keep_mask=None
logits,
tokens,
tp_group,
with_entropy: bool = False,
chunk_size: int = -1,
log_prob_keep_mask=None,
need_entropy_grad: bool = True,
):
"""Compute per-token log-probs and (optionally) entropy from vocab-parallel logits.

When ``with_entropy`` is set but ``need_entropy_grad`` is ``False`` (i.e. the caller
multiplies entropy by ``entropy_coef == 0``, so it never contributes a gradient), the
entropy is computed under ``torch.no_grad()``. This avoids retaining the autograd graph
for the ``[num_tokens, vocab]`` entropy computation, which otherwise dominates activation
memory for long multi-turn rollouts and can OOM. Under ``no_grad`` we also skip the
defensive ``logits.clone()`` (only needed to keep the backward's in-place ops off the
shared logits tensor), saving a full logits-sized allocation per chunk.
"""
logits = logits.contiguous()
entropy = None
entropy_ctx = nullcontext() if need_entropy_grad else torch.no_grad()
if logits.size(0) != 0:
if chunk_size > 0:
num_chunks = (logits.size(0) - 1) // chunk_size + 1
Expand All @@ -704,11 +722,12 @@ def calculate_log_probs_and_entropy(
)

if with_entropy:
entropys = []
for logits_chunk in logits_chunks:
entropy_input = logits_chunk.clone()
entropys.append(compute_entropy_from_logits(entropy_input, tp_group))
entropy = torch.cat(entropys, dim=0)
with entropy_ctx:
entropys = []
for logits_chunk in logits_chunks:
entropy_input = logits_chunk.clone() if need_entropy_grad else logits_chunk
entropys.append(compute_entropy_from_logits(entropy_input, tp_group))
entropy = torch.cat(entropys, dim=0)

log_probs = []
for tokens_chunk, logits_chunk, mask_chunk in zip(tokens_chunks, logits_chunks, mask_chunks, strict=True):
Expand All @@ -717,8 +736,9 @@ def calculate_log_probs_and_entropy(
log_prob = torch.cat(log_probs, dim=0)
else:
if with_entropy:
entropy_input = logits.clone()
entropy = compute_entropy_from_logits(entropy_input, tp_group)
with entropy_ctx:
entropy_input = logits.clone() if need_entropy_grad else logits
entropy = compute_entropy_from_logits(entropy_input, tp_group)

log_prob = compute_log_probs(logits.clone(), tokens, tp_group, keep_mask=log_prob_keep_mask)
else:
Expand Down
148 changes: 148 additions & 0 deletions tests/test_entropy_grad_gating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Unit tests for entropy gradient gating in ``calculate_log_probs_and_entropy``.

Background
----------
Entropy enters the policy loss as ``loss = pg_loss - args.entropy_coef * entropy_loss``
(see ``policy_loss_function`` in ``slime/backends/megatron_utils/loss.py``). When
``entropy_coef == 0`` the entropy term contributes no gradient, yet the entropy was still
computed *with* autograd enabled — retaining the ``[num_tokens, vocab]`` entropy graph and a
defensive ``logits.clone()`` per chunk. For long multi-turn rollouts that activation memory
dominates and can OOM.

The fix adds a ``need_entropy_grad`` flag: when ``False`` the entropy is computed under
``torch.no_grad()`` and the clone is skipped. These tests pin the contract:

1. The entropy *values* are identical whether or not grad is tracked.
2. With ``need_entropy_grad=False`` the returned entropy carries no autograd graph.
3. With ``need_entropy_grad=True`` the entropy is differentiable w.r.t. the logits.
4. The log-prob output is unaffected by the entropy flag.

These run on CPU with a ``world_size=1`` gloo group (the TP all-reduce in the entropy kernel
is a no-op for a single rank). ``compute_log_probs`` is monkeypatched to a cheap stub so the
test does not depend on Megatron's fused cross-entropy (absent from the CPU CI image); the
gating logic under test lives entirely in the entropy branch.
"""

import os

import pytest
import torch
import torch.distributed as dist

import slime.utils.ppo_utils as ppo_utils
from slime.utils.ppo_utils import calculate_log_probs_and_entropy


@pytest.fixture(scope="module")
def single_rank_group():
"""A real ``world_size=1`` gloo group for the entropy kernel's ``all_reduce`` calls."""
if dist.is_initialized():
yield dist.group.WORLD
return
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29555")
dist.init_process_group(backend="gloo", rank=0, world_size=1)
try:
yield dist.group.WORLD
finally:
dist.destroy_process_group()


@pytest.fixture(autouse=True)
def stub_compute_log_probs(monkeypatch):
"""Avoid the Megatron fused-CE dependency: log-probs are not what these tests exercise."""

def _fake_log_probs(logits, tokens, process_group, keep_mask=None):
# Shape/contract-compatible stub: one scalar log-prob per token, differentiable.
return logits.sum(dim=-1)

monkeypatch.setattr(ppo_utils, "compute_log_probs", _fake_log_probs)


@pytest.mark.unit
@pytest.mark.parametrize("chunk_size", [-1, 4])
def test_entropy_values_match_regardless_of_grad(single_rank_group, chunk_size):
"""no_grad entropy must equal grad-tracked entropy numerically (only the graph differs)."""
torch.manual_seed(0)
num_tokens, vocab = 11, 32
logits = torch.randn(num_tokens, vocab, dtype=torch.float32)
tokens = torch.randint(0, vocab, (num_tokens,))

_, entropy_grad = calculate_log_probs_and_entropy(
logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=True
)
_, entropy_nograd = calculate_log_probs_and_entropy(
logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=False
)

assert entropy_grad.shape == (num_tokens,)
assert entropy_nograd.shape == (num_tokens,)
torch.testing.assert_close(entropy_grad.detach(), entropy_nograd)


@pytest.mark.unit
@pytest.mark.parametrize("chunk_size", [-1, 4])
def test_need_entropy_grad_false_detaches_graph(single_rank_group, chunk_size):
"""need_entropy_grad=False -> entropy is a leaf with no autograd graph (memory is freed)."""
torch.manual_seed(1)
num_tokens, vocab = 9, 16
logits = torch.randn(num_tokens, vocab, dtype=torch.float32, requires_grad=True)
tokens = torch.randint(0, vocab, (num_tokens,))

_, entropy = calculate_log_probs_and_entropy(
logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=False
)

assert entropy.grad_fn is None
assert not entropy.requires_grad


@pytest.mark.unit
@pytest.mark.parametrize("chunk_size", [-1, 4])
def test_need_entropy_grad_true_is_differentiable(single_rank_group, chunk_size):
"""need_entropy_grad=True -> entropy backpropagates to the logits (the entropy-bonus path)."""
torch.manual_seed(2)
num_tokens, vocab = 9, 16
logits = torch.randn(num_tokens, vocab, dtype=torch.float32, requires_grad=True)
tokens = torch.randint(0, vocab, (num_tokens,))

_, entropy = calculate_log_probs_and_entropy(
logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=True
)

assert entropy.requires_grad
assert entropy.grad_fn is not None
entropy.sum().backward()
assert logits.grad is not None
assert torch.isfinite(logits.grad).all()
assert logits.grad.abs().sum() > 0


@pytest.mark.unit
@pytest.mark.parametrize("chunk_size", [-1, 4])
def test_log_probs_unaffected_by_entropy_flag(single_rank_group, chunk_size):
"""The log-prob output must not depend on need_entropy_grad."""
torch.manual_seed(3)
num_tokens, vocab = 11, 16
logits = torch.randn(num_tokens, vocab, dtype=torch.float32)
tokens = torch.randint(0, vocab, (num_tokens,))

log_prob_a, _ = calculate_log_probs_and_entropy(
logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=True
)
log_prob_b, _ = calculate_log_probs_and_entropy(
logits, tokens, single_rank_group, with_entropy=True, chunk_size=chunk_size, need_entropy_grad=False
)
torch.testing.assert_close(log_prob_a, log_prob_b)


@pytest.mark.unit
def test_empty_input_returns_empty_entropy(single_rank_group):
"""Zero-length response (all positions masked away upstream) is handled without compute."""
logits = torch.zeros(0, 16, dtype=torch.float32)
tokens = torch.zeros(0, dtype=torch.long)

_, entropy = calculate_log_probs_and_entropy(
logits, tokens, single_rank_group, with_entropy=True, need_entropy_grad=False
)
assert entropy.shape == (0,)
Loading