Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 81 additions & 27 deletions easyeditor/models/dpo/dpo_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List, Tuple
from peft import get_peft_model, AdaLoraConfig, TaskType, get_peft_model_state_dict, set_peft_model_state_dict, LoraConfig
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from ...util.device import normalize_device
Expand Down Expand Up @@ -103,19 +104,9 @@ def execute_dpo(
mask_token = -100
opt.zero_grad()

# Build inputs for positive samples
full_prompt_pos = [f"{p} {l}" for p, l in zip(txt_batch, tgt_pos_batch)]
tokens_pos = tok(full_prompt_pos, return_tensors="pt", padding=True, truncation=True)
tokens_pos["labels"] = tokens_pos["input_ids"].clone()
tokens_pos["labels"][tokens_pos["input_ids"] == tok.pad_token_id] = mask_token
tokens_pos = tokens_pos.to(device)

# Build inputs for negative samples
full_prompt_neg = [f"{p} {l}" for p, l in zip(txt_batch, tgt_neg_batch)]
tokens_neg = tok(full_prompt_neg, return_tensors="pt", padding=True, truncation=True)
tokens_neg["labels"] = tokens_neg["input_ids"].clone()
tokens_neg["labels"][tokens_neg["input_ids"] == tok.pad_token_id] = mask_token
tokens_neg = tokens_neg.to(device)
# Build inputs with labels only on completion tokens.
tokens_pos = build_completion_batch(tok, txt_batch, tgt_pos_batch, device, mask_token)
tokens_neg = build_completion_batch(tok, txt_batch, tgt_neg_batch, device, mask_token)

# Compute outputs with LoRA modules (current model)
outputs_pos = peft_model(**tokens_pos)
Expand All @@ -125,28 +116,28 @@ def execute_dpo(
peft_model.eval() # Switch to evaluation mode
peft_model.disable_adapter_layers() # Disable LoRA layers

with torch.no_grad():
ref_outputs_pos = peft_model(**tokens_pos)
ref_outputs_neg = peft_model(**tokens_neg)

peft_model.train() # Switch back to training mode
peft_model.enable_adapter_layers() # Enable LoRA layers
try:
with torch.no_grad():
ref_outputs_pos = peft_model(**tokens_pos)
ref_outputs_neg = peft_model(**tokens_neg)
finally:
peft_model.enable_adapter_layers() # Enable LoRA layers
peft_model.train() # Switch back to training mode

# Compute losses
lora_loss = outputs_pos.loss
beta = hparams.beta

ref_log_probs_pos = ref_outputs_pos.logits.log_softmax(-1)
ref_log_probs_neg = ref_outputs_neg.logits.log_softmax(-1)

log_probs_pos = outputs_pos.logits.log_softmax(-1)
log_probs_neg = outputs_neg.logits.log_softmax(-1)
policy_log_probs_pos = completion_log_probs(outputs_pos.logits, tokens_pos["labels"], mask_token)
policy_log_probs_neg = completion_log_probs(outputs_neg.logits, tokens_neg["labels"], mask_token)
ref_log_probs_pos = completion_log_probs(ref_outputs_pos.logits, tokens_pos["labels"], mask_token)
ref_log_probs_neg = completion_log_probs(ref_outputs_neg.logits, tokens_neg["labels"], mask_token)

dpo_advantage = beta * (
(log_probs_pos - ref_log_probs_pos).sum(-1) -
(log_probs_neg - ref_log_probs_neg).sum(-1)
(policy_log_probs_pos - ref_log_probs_pos) -
(policy_log_probs_neg - ref_log_probs_neg)
)
dpo_loss = -torch.mean(torch.log(torch.sigmoid(dpo_advantage)))
dpo_loss = -F.logsigmoid(dpo_advantage).mean()

# Total loss
loss = hparams.alpha * lora_loss + (1 - hparams.alpha) * dpo_loss
Expand All @@ -162,6 +153,69 @@ def execute_dpo(
return peft_model


def build_completion_batch(tok: AutoTokenizer, prompts: List[str], targets: List[str], device, mask_token: int):
    """
    Tokenize prompt + completion pairs and mask labels outside completion spans.

    DPO positive and negative completions may have different token lengths, so
    the preference objective must operate on per-example sequence logprobs.

    Args:
        tok: Tokenizer called on the joined ``"{prompt} {target}"`` strings.
            If it has no pad token id, its pad token is set to its EOS token
            (this change is not reverted). Its ``padding_side`` is forced to
            ``"right"`` for the duration of the call and restored afterwards.
        prompts: Prompt strings, one per example.
        targets: Completion strings, paired element-wise with ``prompts``.
        device: Passed to ``.to(...)`` on the tokenizer output.
        mask_token: Label value written at every position that is not part of
            the completion (prompt tokens, special tokens, padding).

    Returns:
        The tokenizer output with an added ``"labels"`` entry, moved to
        ``device``. Labels equal ``input_ids`` on completion tokens and
        ``mask_token`` everywhere else.

    Raises:
        ValueError: If ``prompts`` and ``targets`` differ in length.
    """
    # zip() would silently drop unpaired items and desynchronise rows from
    # their prompts, so reject mismatched batches explicitly.
    if len(prompts) != len(targets):
        raise ValueError(
            f"prompts and targets must have the same length, "
            f"got {len(prompts)} and {len(targets)}"
        )

    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token

    # Right padding keeps the real tokens in the first `seq_len` columns,
    # which the slicing below relies on.
    old_padding_side = tok.padding_side
    tok.padding_side = "right"
    try:
        full_texts = [f"{prompt} {target}" for prompt, target in zip(prompts, targets)]
        try:
            tokens = tok(
                full_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                return_offsets_mapping=True,
            )
            offset_mapping = tokens.pop("offset_mapping")
        except (NotImplementedError, TypeError):
            # Tokenizer cannot provide character offsets (presumably a
            # non-fast tokenizer); fall back to token-count estimation.
            tokens = tok(full_texts, return_tensors="pt", padding=True, truncation=True)
            offset_mapping = None

        input_ids = tokens["input_ids"]
        labels = torch.full_like(input_ids, mask_token)
        seq_lens = tokens["attention_mask"].sum(dim=-1).tolist()

        for row, (prompt, seq_len) in enumerate(zip(prompts, seq_lens)):
            seq_len = int(seq_len)
            prefix_text = f"{prompt} "
            if offset_mapping is not None:
                # A token belongs to the completion when its character span
                # is non-empty and ends past the "prompt + space" prefix.
                row_offsets = offset_mapping[row][:seq_len]
                starts = row_offsets[:, 0]
                ends = row_offsets[:, 1]
                in_completion = (ends > len(prefix_text)) & (ends > starts)
                labels[row, :seq_len] = torch.where(
                    in_completion,
                    input_ids[row, :seq_len],
                    labels[row, :seq_len],
                )
            else:
                prefix_ids = tok(prefix_text, add_special_tokens=True, truncation=True)["input_ids"]
                completion_start = min(len(prefix_ids), seq_len)
                if completion_start >= seq_len:
                    # The tokenized prefix already covers the whole sequence;
                    # retry with the bare prompt and cap the start so that at
                    # least the final real token stays labeled.
                    prompt_ids = tok(prompt, add_special_tokens=True, truncation=True)["input_ids"]
                    completion_start = min(len(prompt_ids), max(seq_len - 1, 0))
                labels[row, completion_start:seq_len] = input_ids[row, completion_start:seq_len]

        tokens["labels"] = labels
        return tokens.to(device)
    finally:
        tok.padding_side = old_padding_side


def completion_log_probs(logits: torch.Tensor, labels: torch.Tensor, mask_token: int) -> torch.Tensor:
    """Return summed log probability of the labeled completion tokens.

    Args:
        logits: Scores indexed as ``[row, position, vocab]``; a log-softmax is
            taken over the last axis.
        labels: Target ids indexed as ``[row, position]``. Entries equal to
            ``mask_token`` are excluded from the sum.
        mask_token: Sentinel label value marking positions to ignore.

    Returns:
        One value per row: the sum of the log-probabilities assigned to the
        non-masked labels. Rows with no labeled positions yield 0.
    """
    # Pair each logit step with the label that follows it (next-token
    # prediction): drop the last logit step and the first label.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    label_mask = shift_labels.ne(mask_token)
    # gather() needs valid vocabulary indices, so point masked slots at id 0;
    # their values are discarded below.
    safe_labels = shift_labels.masked_fill(~label_mask, 0)
    token_log_probs = shift_logits.log_softmax(-1).gather(
        dim=-1,
        index=safe_labels.unsqueeze(-1),
    ).squeeze(-1)
    # Overwrite masked slots instead of multiplying by the mask: a masked slot
    # may hold -inf, and -inf * 0 is nan, which would poison the whole row sum.
    return token_log_probs.masked_fill(~label_mask, 0.0).sum(-1)


class AverageMeter:
"""Computes and stores the average and current value"""

Expand Down