Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions slime/backends/megatron_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,13 +522,17 @@ def get_log_probs_and_entropy(
)

# --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy ---
# When entropy_coef == 0, entropy is only used for logging, so compute it without
# building the autograd graph to reduce memory pressure.
entropy_no_grad = with_entropy and args.entropy_coef == 0
log_prob_full, entropy_full = calculate_log_probs_and_entropy(
logits,
full_tokens,
tp_group,
with_entropy=with_entropy,
chunk_size=chunk_size,
log_prob_keep_mask=top_p_keep_mask,
entropy_no_grad=entropy_no_grad,
)
log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T]

Expand Down
29 changes: 21 additions & 8 deletions slime/utils/ppo_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Adapt from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/models/utils.py
# and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py

import contextlib
from argparse import Namespace

import torch
Expand Down Expand Up @@ -690,8 +691,18 @@ def chunked_gae(


def calculate_log_probs_and_entropy(
logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1, log_prob_keep_mask=None
logits,
tokens,
tp_group,
with_entropy: bool = False,
chunk_size: int = -1,
log_prob_keep_mask=None,
entropy_no_grad: bool = False,
):
# When entropy is only needed for logging (e.g. ``entropy_coef == 0``), compute it under
# ``torch.no_grad()`` so the autograd graph is not retained, reducing memory pressure.
entropy_ctx = torch.no_grad() if entropy_no_grad else contextlib.nullcontext()

logits = logits.contiguous()
entropy = None
if logits.size(0) != 0:
Expand All @@ -704,11 +715,12 @@ def calculate_log_probs_and_entropy(
)

if with_entropy:
entropys = []
for logits_chunk in logits_chunks:
entropy_input = logits_chunk.clone()
entropys.append(compute_entropy_from_logits(entropy_input, tp_group))
entropy = torch.cat(entropys, dim=0)
with entropy_ctx:
entropys = []
for logits_chunk in logits_chunks:
entropy_input = logits_chunk.clone()
entropys.append(compute_entropy_from_logits(entropy_input, tp_group))
entropy = torch.cat(entropys, dim=0)

log_probs = []
for tokens_chunk, logits_chunk, mask_chunk in zip(tokens_chunks, logits_chunks, mask_chunks, strict=True):
Expand All @@ -717,8 +729,9 @@ def calculate_log_probs_and_entropy(
log_prob = torch.cat(log_probs, dim=0)
else:
if with_entropy:
entropy_input = logits.clone()
entropy = compute_entropy_from_logits(entropy_input, tp_group)
with entropy_ctx:
entropy_input = logits.clone()
entropy = compute_entropy_from_logits(entropy_input, tp_group)

log_prob = compute_log_probs(logits.clone(), tokens, tp_group, keep_mask=log_prob_keep_mask)
else:
Expand Down
Loading