From c0fe2d5e3eac3baf6ad18ec67ef896e49084c48b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BA=94=E9=A3=9E=E6=89=AC?= Date: Sat, 4 Jul 2026 19:20:07 +0800 Subject: [PATCH] Fix DPO completion log-prob loss --- easyeditor/models/dpo/dpo_main.py | 108 ++++++++++++++++++++++-------- 1 file changed, 81 insertions(+), 27 deletions(-) diff --git a/easyeditor/models/dpo/dpo_main.py b/easyeditor/models/dpo/dpo_main.py index ad111ecc..b060dd55 100644 --- a/easyeditor/models/dpo/dpo_main.py +++ b/easyeditor/models/dpo/dpo_main.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Tuple from peft import get_peft_model, AdaLoraConfig, TaskType, get_peft_model_state_dict, set_peft_model_state_dict, LoraConfig import torch +import torch.nn.functional as F from transformers import AutoModelForCausalLM, AutoTokenizer from ...util.device import normalize_device @@ -103,19 +104,9 @@ def execute_dpo( mask_token = -100 opt.zero_grad() - # Build inputs for positive samples - full_prompt_pos = [f"{p} {l}" for p, l in zip(txt_batch, tgt_pos_batch)] - tokens_pos = tok(full_prompt_pos, return_tensors="pt", padding=True, truncation=True) - tokens_pos["labels"] = tokens_pos["input_ids"].clone() - tokens_pos["labels"][tokens_pos["input_ids"] == tok.pad_token_id] = mask_token - tokens_pos = tokens_pos.to(device) - - # Build inputs for negative samples - full_prompt_neg = [f"{p} {l}" for p, l in zip(txt_batch, tgt_neg_batch)] - tokens_neg = tok(full_prompt_neg, return_tensors="pt", padding=True, truncation=True) - tokens_neg["labels"] = tokens_neg["input_ids"].clone() - tokens_neg["labels"][tokens_neg["input_ids"] == tok.pad_token_id] = mask_token - tokens_neg = tokens_neg.to(device) + # Build inputs with labels only on completion tokens. + tokens_pos = build_completion_batch(tok, txt_batch, tgt_pos_batch, device, mask_token) + tokens_neg = build_completion_batch(tok, txt_batch, tgt_neg_batch, device, mask_token) # Compute outputs with LoRA modules (current model) outputs_pos = peft_model(**tokens_pos) @@ -125,28 +116,28 @@ def execute_dpo( peft_model.eval() # Switch to evaluation mode peft_model.disable_adapter_layers() # Disable LoRA layers - with torch.no_grad(): - ref_outputs_pos = peft_model(**tokens_pos) - ref_outputs_neg = peft_model(**tokens_neg) - - peft_model.train() # Switch back to training mode - peft_model.enable_adapter_layers() # Enable LoRA layers + try: + with torch.no_grad(): + ref_outputs_pos = peft_model(**tokens_pos) + ref_outputs_neg = peft_model(**tokens_neg) + finally: + peft_model.enable_adapter_layers() # Enable LoRA layers + peft_model.train() # Switch back to training mode # Compute losses lora_loss = outputs_pos.loss beta = hparams.beta - ref_log_probs_pos = ref_outputs_pos.logits.log_softmax(-1) - ref_log_probs_neg = ref_outputs_neg.logits.log_softmax(-1) - - log_probs_pos = outputs_pos.logits.log_softmax(-1) - log_probs_neg = outputs_neg.logits.log_softmax(-1) + policy_log_probs_pos = completion_log_probs(outputs_pos.logits, tokens_pos["labels"], mask_token) + policy_log_probs_neg = completion_log_probs(outputs_neg.logits, tokens_neg["labels"], mask_token) + ref_log_probs_pos = completion_log_probs(ref_outputs_pos.logits, tokens_pos["labels"], mask_token) + ref_log_probs_neg = completion_log_probs(ref_outputs_neg.logits, tokens_neg["labels"], mask_token) dpo_advantage = beta * ( - (log_probs_pos - ref_log_probs_pos).sum(-1) - - (log_probs_neg - ref_log_probs_neg).sum(-1) + (policy_log_probs_pos - ref_log_probs_pos) - + (policy_log_probs_neg - ref_log_probs_neg) ) - dpo_loss = -torch.mean(torch.log(torch.sigmoid(dpo_advantage))) + dpo_loss = -F.logsigmoid(dpo_advantage).mean() # Total loss loss = hparams.alpha * lora_loss + (1 - hparams.alpha) * dpo_loss @@ -162,6 +153,69 @@ def execute_dpo( return peft_model +def build_completion_batch(tok: AutoTokenizer, prompts: List[str], targets: List[str], device, mask_token: int): + """ + Tokenize prompt + completion pairs and mask labels outside completion spans. + DPO positive and negative completions may have different token lengths, so + the preference objective must operate on per-example sequence logprobs. + """ + if tok.pad_token_id is None: + tok.pad_token = tok.eos_token + + old_padding_side = tok.padding_side + tok.padding_side = "right" + try: + full_texts = [f"{prompt} {target}" for prompt, target in zip(prompts, targets)] + prefix_texts = [f"{prompt} " for prompt in prompts] + try: + tokens = tok( + full_texts, + return_tensors="pt", + padding=True, + truncation=True, + return_offsets_mapping=True, + ) + offset_mapping = tokens.pop("offset_mapping") + except (NotImplementedError, TypeError): + tokens = tok(full_texts, return_tensors="pt", padding=True, truncation=True) + offset_mapping = None + labels = torch.full_like(tokens["input_ids"], mask_token) + + for row, prefix_text in enumerate(prefix_texts): + seq_len = int(tokens["attention_mask"][row].sum().item()) + if offset_mapping is not None: + prefix_len = len(prefix_text) + for col in range(seq_len): + start, end = offset_mapping[row][col].tolist() + if end > prefix_len and end > start: + labels[row, col] = tokens["input_ids"][row, col] + else: + prefix_ids = tok(prefix_text, add_special_tokens=True, truncation=True)["input_ids"] + completion_start = min(len(prefix_ids), seq_len) + if completion_start >= seq_len: + prompt_ids = tok(prompts[row], add_special_tokens=True, truncation=True)["input_ids"] + completion_start = min(len(prompt_ids), max(seq_len - 1, 0)) + labels[row, completion_start:seq_len] = tokens["input_ids"][row, completion_start:seq_len] + + tokens["labels"] = labels + return tokens.to(device) + finally: + tok.padding_side = old_padding_side + + +def completion_log_probs(logits: torch.Tensor, labels: torch.Tensor, mask_token: int) -> torch.Tensor: + """Return summed log probability of the labeled completion tokens.""" + shift_logits = logits[:, :-1, :] + shift_labels = labels[:, 1:] + label_mask = shift_labels.ne(mask_token) + safe_labels = shift_labels.masked_fill(~label_mask, 0) + token_log_probs = shift_logits.log_softmax(-1).gather( + dim=-1, + index=safe_labels.unsqueeze(-1), + ).squeeze(-1) + return (token_log_probs * label_mask).sum(-1) + + class AverageMeter: """Computes and stores the average and current value"""