add support for SPPO

Wing Lian
2024-05-02 08:56:15 -04:00
parent 3367fca732
commit 7fea5822f0
5 changed files with 76 additions and 5 deletions

View File

@@ -138,7 +138,7 @@ test_datasets:
    data_files:
      - /workspace/data/eval.jsonl

-# use RL training: 'dpo', 'ipo', 'kto_pair'
+# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo'
rl:

# Saves the desired chat template to the tokenizer_config.json for easier inferencing
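With this commit, SPPO is enabled from the config the same way as the other RL methods listed in that comment. A minimal sketch of such a config (the base model and dataset path are placeholders, and the dataset/prompt-format options a real run needs are omitted):

base_model: NousResearch/Llama-2-7b-hf   # placeholder
rl: sppo
datasets:
  - path: /workspace/data/preferences.jsonl   # placeholder; preference pairs in the same format used for dpo
    split: train
dpo_beta: 0.1   # reused as the SPPO beta; defaults to 0.1 when unset (see the trainer_builder change below)

Training is then launched through the usual entry point, e.g. accelerate launch -m axolotl.cli.train sppo.yml.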

View File

@@ -14,7 +14,7 @@ from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
-from typing import Dict, List, Literal, Optional, Type, Union
+from typing import Dict, List, Literal, Optional, Tuple, Type, Union
import torch
import transformers
@@ -817,6 +817,70 @@ class AxolotlDPOTrainer(DPOTrainer):
                res[key] = res[key][1:]
        return res

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the DPO loss for a batch of policy and reference model log probabilities.

        Args:
            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)

        Returns:
            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
            The losses tensor contains the DPO loss for each example in the batch.
            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
        """
        if self.loss_type in ["sigmoid", "hinge", "ipo", "kto_pair"]:
            # loss types already implemented upstream are deferred to the parent DPOTrainer
            return super().dpo_loss(
                policy_chosen_logps,
                policy_rejected_logps,
                reference_chosen_logps,
                reference_rejected_logps,
            )

        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative DPO loss.
        if self.loss_type == "sppo":
            # Calculate a and b
            a = self.beta * (  # pylint: disable=invalid-name
                policy_chosen_logps - reference_chosen_logps
            )
            b = self.beta * (  # pylint: disable=invalid-name
                policy_rejected_logps - reference_rejected_logps
            )
            # Compute the SPPO loss
            losses = (a - 0.5) ** 2 + (b + 0.5) ** 2
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo']"
            )

        chosen_rewards = (
            self.beta
            * (
                policy_chosen_logps.to(self.accelerator.device)
                - reference_chosen_logps.to(self.accelerator.device)
            ).detach()
        )
        rejected_rewards = (
            self.beta
            * (
                policy_rejected_logps.to(self.accelerator.device)
                - reference_rejected_logps.to(self.accelerator.device)
            ).detach()
        )

        return losses, chosen_rewards, rejected_rewards


class AxolotlORPOTrainer(ORPOTrainer):
    """
@@ -1552,6 +1616,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl == "kto_pair":
dpo_trainer_kwargs["loss_type"] = "kto_pair"
elif self.cfg.rl == "sppo":
dpo_trainer_kwargs["loss_type"] = "sppo"
if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config:
@@ -1560,7 +1626,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
            dpo_trainer_kwargs[
                "precompute_ref_log_probs"
            ] = self.cfg.precompute_ref_log_probs
-        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
            trainer_cls = AxolotlDPOTrainer
            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
            trainer_cls_args = [self.model, self.model_ref]
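These kwargs are passed straight through to the trainer constructor, so the net effect of rl: sppo is roughly the call sketched below (not the literal builder code; training_args, train_dataset, and tokenizer stand in for objects the builder assembles elsewhere, and the keyword names follow the TRL DPOTrainer API this code targets):

trainer = AxolotlDPOTrainer(
    model,                       # trainer_cls_args = [self.model, self.model_ref]
    model_ref,
    args=training_args,          # assembled HF TrainingArguments (placeholder name)
    beta=cfg.dpo_beta or 0.1,    # dpo_beta from the axolotl config, defaulting to 0.1
    loss_type="sppo",            # routes dpo_loss() into the SPPO branch added above
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)

Everything else, including dataset tokenization, reference-model handling, and adapter setup, is inherited unchanged from the existing DPO path.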

View File

@@ -133,6 +133,7 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
sppo = "sppo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):

View File

@@ -789,7 +789,11 @@ def load_model(
    if not reference_model or cfg.lora_model_dir:
        # if we're not loading the reference model, then we're loading the model for training
        # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
-        if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
+        if (
+            cfg.adapter
+            and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]
+            and not cfg.merge_lora
+        ):
            _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
        else:
            model, lora_config = load_adapter(model, cfg, cfg.adapter)

View File

@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]