From 1094c254134c6501a8cbc42022742ace509ca833 Mon Sep 17 00:00:00 2001 From: none0663 Date: Tue, 23 Jun 2026 19:31:04 +0800 Subject: [PATCH] Reduce entropy logging memory when entropy coef is zero --- slime/backends/megatron_utils/loss.py | 4 ++++ slime/utils/ppo_utils.py | 29 +++++++++++++++++++-------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 72afdfa66c..d1e70edb7e 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -522,6 +522,9 @@ def get_log_probs_and_entropy( ) # --- compute on full [T,V] logits at once via calculate_log_probs_and_entropy --- + # When entropy_coef == 0, entropy is only used for logging, so compute it without + # building the autograd graph to reduce memory pressure. + entropy_no_grad = with_entropy and args.entropy_coef == 0 log_prob_full, entropy_full = calculate_log_probs_and_entropy( logits, full_tokens, @@ -529,6 +532,7 @@ def get_log_probs_and_entropy( with_entropy=with_entropy, chunk_size=chunk_size, log_prob_keep_mask=top_p_keep_mask, + entropy_no_grad=entropy_no_grad, ) log_prob_full = log_prob_full.squeeze(-1) # [T, 1] -> [T] diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index a4dd0c0181..f1ded4cfe9 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -1,6 +1,7 @@ # Adapt from https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/models/utils.py # and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py +import contextlib from argparse import Namespace import torch @@ -690,8 +691,18 @@ def chunked_gae( def calculate_log_probs_and_entropy( - logits, tokens, tp_group, with_entropy: bool = False, chunk_size: int = -1, log_prob_keep_mask=None + logits, + tokens, + tp_group, + with_entropy: bool = False, + chunk_size: int = -1, + log_prob_keep_mask=None, + entropy_no_grad: bool = False, ): + # When entropy is only needed for logging (e.g. ``entropy_coef == 0``), compute it under + # ``torch.no_grad()`` so the autograd graph is not retained, reducing memory pressure. + entropy_ctx = torch.no_grad() if entropy_no_grad else contextlib.nullcontext() + logits = logits.contiguous() entropy = None if logits.size(0) != 0: @@ -704,11 +715,12 @@ def calculate_log_probs_and_entropy( ) if with_entropy: - entropys = [] - for logits_chunk in logits_chunks: - entropy_input = logits_chunk.clone() - entropys.append(compute_entropy_from_logits(entropy_input, tp_group)) - entropy = torch.cat(entropys, dim=0) + with entropy_ctx: + entropys = [] + for logits_chunk in logits_chunks: + entropy_input = logits_chunk.clone() + entropys.append(compute_entropy_from_logits(entropy_input, tp_group)) + entropy = torch.cat(entropys, dim=0) log_probs = [] for tokens_chunk, logits_chunk, mask_chunk in zip(tokens_chunks, logits_chunks, mask_chunks, strict=True): @@ -717,8 +729,9 @@ def calculate_log_probs_and_entropy( log_prob = torch.cat(log_probs, dim=0) else: if with_entropy: - entropy_input = logits.clone() - entropy = compute_entropy_from_logits(entropy_input, tp_group) + with entropy_ctx: + entropy_input = logits.clone() + entropy = compute_entropy_from_logits(entropy_input, tp_group) log_prob = compute_log_probs(logits.clone(), tokens, tp_group, keep_mask=log_prob_keep_mask) else: