Compare commits
1 Commits
jagged-res
...
feat/pref_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8428b3f2c7 |
@@ -14,17 +14,22 @@ import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from liger_kernel.chunked_loss.fused_linear_preference import (
|
||||
LigerFusedLinearPreferenceBase,
|
||||
)
|
||||
from packaging import version
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch import amp, nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
@@ -1077,6 +1082,15 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
self.dataset_tags = dataset_tags
|
||||
self.optimizer = None
|
||||
|
||||
from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss
|
||||
|
||||
self.liger_loss = LigerFusedLinearDPOLoss(
|
||||
ignore_index=self.label_pad_token_id,
|
||||
beta=self.beta,
|
||||
compute_nll_loss=True, # not same as rpo_alpha hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None,
|
||||
use_ref_model=not self.reference_free,
|
||||
)
|
||||
|
||||
def create_optimizer(self):
|
||||
if self.args.loraplus_lr_ratio is None:
|
||||
return super().create_optimizer()
|
||||
@@ -1180,6 +1194,309 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
# transformers<=4.46
|
||||
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the DPO loss and other metrics using Liger kernel."""
|
||||
# return super().get_batch_loss_metrics(model, batch, train_eval)
|
||||
if not self.liger_loss:
|
||||
raise ValueError("Liger loss not initialized")
|
||||
|
||||
metrics = {}
|
||||
|
||||
model_output = self.concatenated_forward(model, batch)
|
||||
|
||||
# Get the lm_head weights and bias
|
||||
lin_weight = model.lm_head.weight
|
||||
lin_bias = getattr(model.lm_head, "bias", None)
|
||||
|
||||
hidden_states = model_output["hidden_states"]
|
||||
labels = model_output["labels"]
|
||||
|
||||
if not self.reference_free:
|
||||
# Adapted from DPO's compute_ref_log_probs
|
||||
compte_ref_context_manager = (
|
||||
amp.autocast("cuda")
|
||||
if self._peft_has_been_casted_to_bf16
|
||||
else nullcontext()
|
||||
)
|
||||
with torch.no_grad(), compte_ref_context_manager: # type: ignore
|
||||
if self.ref_model is None:
|
||||
with self.null_ref_context():
|
||||
ref_model_output = self.concatenated_forward(self.model, batch)
|
||||
ref_weight = self.model.lm_head.weight
|
||||
ref_bias = getattr(self.model.lm_head, "bias", None)
|
||||
|
||||
ref_hidden_states = ref_model_output["hidden_states"]
|
||||
|
||||
else:
|
||||
ref_model_output = self.concatenated_forward(self.ref_model, batch)
|
||||
ref_weight = self.ref_model.lm_head.weight
|
||||
ref_bias = getattr(self.ref_model.lm_head, "bias", None)
|
||||
|
||||
ref_hidden_states = ref_model_output["hidden_states"]
|
||||
(
|
||||
ref_chosen_logps,
|
||||
ref_rejected_logps,
|
||||
_ref_chosen_logits,
|
||||
_ref_rejected_logits,
|
||||
_ref_chosen_nll_loss,
|
||||
) = LigerFusedLinearPreferenceBase.chunk_forward(
|
||||
input_chunk=ref_hidden_states,
|
||||
weight=ref_weight,
|
||||
target_chunk=labels,
|
||||
bias=ref_bias,
|
||||
# ignore_index=ignore_index,
|
||||
compute_nll_loss=False,
|
||||
)
|
||||
|
||||
else:
|
||||
ref_hidden_states = None
|
||||
ref_weight = None
|
||||
ref_bias = None
|
||||
|
||||
# Compute loss using Liger kernel
|
||||
loss, return_vars = self.liger_loss(
|
||||
lin_weight=lin_weight,
|
||||
_input=hidden_states,
|
||||
target=labels,
|
||||
bias=lin_bias, # TODO: check whether to pass bias as FCLE doesn't
|
||||
ref_input=ref_hidden_states,
|
||||
ref_weight=ref_weight,
|
||||
ref_bias=ref_bias,
|
||||
)
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits_mean,
|
||||
policy_rejected_logits_mean,
|
||||
policy_nll_loss,
|
||||
) = return_vars
|
||||
|
||||
# Calculate rewards
|
||||
if not self.reference_free:
|
||||
chosen_rewards = (
|
||||
self.beta * (policy_chosen_logps - (ref_chosen_logps)).detach()
|
||||
)
|
||||
rejected_rewards = (
|
||||
self.beta * (policy_rejected_logps - (ref_rejected_logps)).detach()
|
||||
)
|
||||
|
||||
else:
|
||||
chosen_rewards = self.beta * policy_chosen_logps
|
||||
rejected_rewards = self.beta * policy_rejected_logps
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics.update(
|
||||
{
|
||||
f"{prefix}rewards/chosen": chosen_rewards.mean().cpu(),
|
||||
f"{prefix}rewards/rejected": rejected_rewards.mean().cpu(),
|
||||
f"{prefix}rewards/accuracies": reward_accuracies.mean().cpu(),
|
||||
f"{prefix}rewards/margins": (chosen_rewards - rejected_rewards)
|
||||
.mean()
|
||||
.cpu(),
|
||||
f"{prefix}logps/chosen": policy_chosen_logps.mean().cpu(),
|
||||
f"{prefix}logps/rejected": policy_rejected_logps.mean().cpu(),
|
||||
f"{prefix}logits/chosen": policy_chosen_logits_mean.cpu(),
|
||||
f"{prefix}logits/rejected": policy_rejected_logits_mean.cpu(),
|
||||
}
|
||||
)
|
||||
|
||||
if hasattr(self.args, "rpo_alpha") and self.args.rpo_alpha is not None:
|
||||
metrics[f"{prefix}nll_loss"] = policy_nll_loss.cpu()
|
||||
|
||||
# TODO: Handle use_weighting, aux_loss_enabled as in upstream
|
||||
|
||||
return loss, metrics
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
||||
):
|
||||
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
||||
|
||||
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
||||
|
||||
Overridden base function to return the hidden states and labels for the loss calculation.
|
||||
"""
|
||||
num_examples = batch["prompt_input_ids"].shape[0] # type: ignore
|
||||
|
||||
concatenated_batch = self.concatenated_inputs(
|
||||
batch, padding_value=self.padding_value
|
||||
)
|
||||
|
||||
model_kwargs = {}
|
||||
if self.aux_loss_enabled:
|
||||
model_kwargs["output_router_logits"] = True
|
||||
|
||||
# Add to get the hidden states for the loss
|
||||
model_kwargs["output_hidden_states"] = True
|
||||
|
||||
# Add the pixel values and attention masks for vision models
|
||||
if "pixel_values" in concatenated_batch:
|
||||
model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
|
||||
if "pixel_attention_mask" in concatenated_batch:
|
||||
model_kwargs["pixel_attention_mask"] = concatenated_batch[
|
||||
"pixel_attention_mask"
|
||||
]
|
||||
if "image_sizes" in concatenated_batch:
|
||||
model_kwargs["image_sizes"] = concatenated_batch["image_sizes"]
|
||||
|
||||
prompt_input_ids = concatenated_batch["prompt_input_ids"]
|
||||
prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
|
||||
completion_input_ids = concatenated_batch["completion_input_ids"]
|
||||
completion_attention_mask = concatenated_batch["completion_attention_mask"]
|
||||
if self.is_encoder_decoder:
|
||||
labels = completion_input_ids
|
||||
labels[completion_attention_mask == 0] = self.label_pad_token_id
|
||||
outputs = model(
|
||||
input_ids=prompt_input_ids,
|
||||
attention_mask=prompt_attention_mask,
|
||||
labels=labels, # we need the labels for the logits to be returned
|
||||
**model_kwargs,
|
||||
)
|
||||
logits = outputs.logits
|
||||
hidden_states = outputs.decoder_hidden_states[-1]
|
||||
loss_mask = completion_attention_mask.bool()
|
||||
else:
|
||||
# Concatenate the prompt and completion inputs
|
||||
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
|
||||
attention_mask = torch.cat(
|
||||
(prompt_attention_mask, completion_attention_mask), dim=1
|
||||
)
|
||||
# Mask the prompt but not the completion for the loss
|
||||
loss_mask = torch.cat(
|
||||
(torch.zeros_like(prompt_attention_mask), completion_attention_mask),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Flush left to reduce the memory usage
|
||||
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
||||
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
||||
for i in range(attention_mask.size(0)):
|
||||
first_one_idx = torch.nonzero(attention_mask[i])[0].item()
|
||||
input_ids[i] = torch.roll(input_ids[i], shifts=-first_one_idx) # type: ignore
|
||||
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx) # type: ignore
|
||||
loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx) # type: ignore
|
||||
|
||||
# Get the first column idx that is all zeros and remove every column after that
|
||||
empty_cols = torch.sum(attention_mask, dim=0) == 0
|
||||
first_empty_col = (
|
||||
torch.nonzero(empty_cols)[0].item()
|
||||
if empty_cols.any()
|
||||
else attention_mask.size(1)
|
||||
)
|
||||
input_ids = input_ids[:, :first_empty_col] # type: ignore
|
||||
attention_mask = attention_mask[:, :first_empty_col] # type: ignore
|
||||
loss_mask = loss_mask[:, :first_empty_col] # type: ignore
|
||||
|
||||
# Truncate right
|
||||
if self.args.max_length is not None:
|
||||
input_ids = input_ids[:, : self.args.max_length]
|
||||
attention_mask = attention_mask[:, : self.args.max_length]
|
||||
loss_mask = loss_mask[:, : self.args.max_length]
|
||||
|
||||
# if self.use_num_logits_to_keep:
|
||||
# # Compute num_logits_to_keep based on loss_mask pattern:
|
||||
# # [[0, 0, 0, x, x, x, x],
|
||||
# # [0, 0, 0, x, x, x, 0]]
|
||||
# # ^ start computing logits from here ([:, -(7-3+1):])
|
||||
# first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
|
||||
# num_logits_to_keep = loss_mask.shape[1] - first_compute_index
|
||||
# model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label
|
||||
|
||||
outputs = model(
|
||||
input_ids=input_ids, attention_mask=attention_mask, **model_kwargs
|
||||
)
|
||||
|
||||
# Offset the logits by one to align with the labels
|
||||
logits = outputs.logits[:, :-1, :]
|
||||
hidden_states = outputs.hidden_states[-1][:, :-1, :]
|
||||
labels = input_ids[:, 1:].clone()
|
||||
loss_mask = loss_mask[:, 1:].bool()
|
||||
|
||||
# if self.use_num_logits_to_keep:
|
||||
# # Align labels with logits
|
||||
# # logits: -, -, [x2, x3, x4, x5, x6]
|
||||
# # ^ --------- ^ after logits[:, :-1, :]
|
||||
# # labels: [y0, y1, y2, y3, y4, y5, y6]
|
||||
# # ^ --------- ^ with num_logits_to_keep=4, [:, -4:]
|
||||
# # loss_mask: [0, 0, 0, 1, 1, 1, 1]
|
||||
# labels = labels[:, -num_logits_to_keep:]
|
||||
# loss_mask = loss_mask[:, -num_logits_to_keep:]
|
||||
# hidden_states = hidden_states[:, -num_logits_to_keep:, :]
|
||||
|
||||
if logits.shape[:2] != labels.shape[:2]:
|
||||
# for llava, the returned logits include the image tokens (placed before the text tokens)
|
||||
seq_len = labels.shape[1]
|
||||
logits = logits[:, -seq_len:]
|
||||
hidden_states = hidden_states[:, -seq_len:]
|
||||
|
||||
# Compute the log probabilities of the labels
|
||||
labels[
|
||||
~loss_mask
|
||||
] = 0 # dummy token; we'll ignore the losses on these tokens later
|
||||
per_token_logps = torch.gather(
|
||||
logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
|
||||
).squeeze(2)
|
||||
per_token_logps[~loss_mask] = 0
|
||||
all_logps = per_token_logps.sum(-1)
|
||||
|
||||
output = {}
|
||||
|
||||
if self.use_weighting:
|
||||
with torch.no_grad():
|
||||
# Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827
|
||||
logprobs = F.log_softmax(logits, dim=-1)
|
||||
weights_adjustment_factor = torch.logsumexp(
|
||||
2 * logprobs, dim=-1
|
||||
) # same as sum(probs**2) in log space
|
||||
per_token_logps_adjusted = per_token_logps - weights_adjustment_factor
|
||||
all_weights = (per_token_logps_adjusted * loss_mask).sum(
|
||||
-1
|
||||
) / loss_mask.sum(-1)
|
||||
chosen_weights = all_weights[:num_examples]
|
||||
rejected_weights = all_weights[num_examples:]
|
||||
output["policy_weights"] = torch.clamp(
|
||||
torch.exp(chosen_weights + rejected_weights), max=1
|
||||
)
|
||||
|
||||
if self.args.rpo_alpha is not None:
|
||||
# Only use the chosen logits for the RPO loss
|
||||
chosen_logits = logits[:num_examples]
|
||||
chosen_labels = labels[:num_examples]
|
||||
|
||||
# Compute the log probabilities of the labels
|
||||
output["nll_loss"] = F.cross_entropy(
|
||||
torch.flatten(chosen_logits, end_dim=1),
|
||||
torch.flatten(chosen_labels, end_dim=1),
|
||||
ignore_index=0,
|
||||
)
|
||||
|
||||
if self.loss_type == "ipo":
|
||||
all_logps = all_logps / loss_mask.sum(-1)
|
||||
|
||||
output["chosen_logps"] = all_logps[:num_examples]
|
||||
output["rejected_logps"] = all_logps[num_examples:]
|
||||
output["mean_chosen_logits"] = logits[:num_examples][
|
||||
loss_mask[:num_examples]
|
||||
].mean()
|
||||
output["mean_rejected_logits"] = logits[num_examples:][
|
||||
loss_mask[num_examples:]
|
||||
].mean()
|
||||
output["hidden_states"] = hidden_states
|
||||
output["labels"] = labels
|
||||
|
||||
if self.aux_loss_enabled:
|
||||
output["aux_loss"] = outputs.aux_loss
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
@@ -2163,6 +2480,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
|
||||
report_to = []
|
||||
if self.cfg.use_wandb:
|
||||
report_to.append("wandb")
|
||||
if self.cfg.wandb_name:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
training_args_kwargs["report_to"] = report_to
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
output_dir=self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
|
||||
Reference in New Issue
Block a user