add support for SPPO
This commit is contained in:
@@ -138,7 +138,7 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto_pair'
|
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo'
|
||||||
rl:
|
rl:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Literal, Optional, Type, Union
|
from typing import Dict, List, Literal, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -817,6 +817,70 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
res[key] = res[key][1:]
|
res[key] = res[key][1:]
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def dpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the preference loss for a batch, adding the "sppo" loss type.

    Every loss type other than "sppo" is delegated unchanged to the parent
    trainer's ``dpo_loss``. For "sppo" the loss is computed here from the
    beta-scaled policy/reference log-ratios.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
        The losses tensor contains the loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the detached,
        beta-scaled log-ratios for the chosen and rejected responses, respectively.
    """
    # Only "sppo" is implemented in this override. Everything else (including
    # validation of unknown loss types) is left to the parent implementation.
    # The previous guard was inverted: it delegated "sppo" to the parent and
    # rejected the stock loss types locally.
    if self.loss_type != "sppo":
        return super().dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )

    # Beta-scaled log-ratios of policy vs. reference for each response.
    chosen_logratios = self.beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_logratios = self.beta * (
        policy_rejected_logps - reference_rejected_logps
    )

    # SPPO loss: squared error pulling the chosen log-ratio towards +0.5 and
    # the rejected log-ratio towards -0.5.
    losses = (chosen_logratios - 0.5) ** 2 + (rejected_logratios + 0.5) ** 2

    # Rewards are reported for logging only, so they are detached from the graph.
    chosen_rewards = (
        self.beta
        * (
            policy_chosen_logps.to(self.accelerator.device)
            - reference_chosen_logps.to(self.accelerator.device)
        ).detach()
    )
    rejected_rewards = (
        self.beta
        * (
            policy_rejected_logps.to(self.accelerator.device)
            - reference_rejected_logps.to(self.accelerator.device)
        ).detach()
    )

    return losses, chosen_rewards, rejected_rewards
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(ORPOTrainer):
|
class AxolotlORPOTrainer(ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1552,6 +1616,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
elif self.cfg.rl == "kto_pair":
|
elif self.cfg.rl == "kto_pair":
|
||||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||||
|
elif self.cfg.rl == "sppo":
|
||||||
|
dpo_trainer_kwargs["loss_type"] = "sppo"
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
@@ -1560,7 +1626,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|||||||
@@ -133,6 +133,7 @@ class RLType(str, Enum):
|
|||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
|
sppo = "sppo" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
class ChatTemplate(str, Enum):
|
||||||
|
|||||||
@@ -789,7 +789,11 @@ def load_model(
|
|||||||
if not reference_model or cfg.lora_model_dir:
|
if not reference_model or cfg.lora_model_dir:
|
||||||
# if we're not loading the reference model, then we're loading the model for training
|
# if we're not loading the reference model, then we're loading the model for training
|
||||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||||
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
|
if (
|
||||||
|
cfg.adapter
|
||||||
|
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]
|
||||||
|
and not cfg.merge_lora
|
||||||
|
):
|
||||||
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
||||||
else:
|
else:
|
||||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|||||||
@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
|
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
Reference in New Issue
Block a user