diff --git a/docs/config.qmd b/docs/config.qmd
index 570a173f9..800293535 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -138,7 +138,7 @@ test_datasets:
     data_files:
       - /workspace/data/eval.jsonl
 
-# use RL training: 'dpo', 'ipo', 'kto_pair'
+# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo'
 rl:
 
 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 742a88633..cc53fb79b 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -14,7 +14,7 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Type, Union
+from typing import Dict, List, Literal, Optional, Tuple, Type, Union
 
 import torch
 import transformers
@@ -817,6 +817,70 @@ class AxolotlDPOTrainer(DPOTrainer):
                 res[key] = res[key][1:]
         return res
 
+    def dpo_loss(
+        self,
+        policy_chosen_logps: torch.FloatTensor,
+        policy_rejected_logps: torch.FloatTensor,
+        reference_chosen_logps: torch.FloatTensor,
+        reference_rejected_logps: torch.FloatTensor,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Compute the DPO loss for a batch of policy and reference model log probabilities.
+
+        Args:
+            policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+            policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+            reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
+            reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
+
+        Returns:
+            A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
+            The losses tensor contains the DPO loss for each example in the batch.
+            The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
+        """
+        if self.loss_type in ["sigmoid", "hinge", "ipo", "kto_pair"]:
+            return super().dpo_loss(
+                policy_chosen_logps,
+                policy_rejected_logps,
+                reference_chosen_logps,
+                reference_rejected_logps,
+            )
+
+        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
+        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
+        # calculates a conservative DPO loss.
+        if self.loss_type == "sppo":
+            # Calculate a and b
+            a = self.beta * (  # pylint: disable=invalid-name
+                policy_chosen_logps - reference_chosen_logps
+            )
+            b = self.beta * (  # pylint: disable=invalid-name
+                policy_rejected_logps - reference_rejected_logps
+            )
+            # Compute the SPPO loss
+            losses = (a - 0.5) ** 2 + (b + 0.5) ** 2
+        else:
+            raise ValueError(
+                f"Unknown loss type: {self.loss_type}. "
+                "Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo']"
+            )
+
+        chosen_rewards = (
+            self.beta
+            * (
+                policy_chosen_logps.to(self.accelerator.device)
+                - reference_chosen_logps.to(self.accelerator.device)
+            ).detach()
+        )
+        rejected_rewards = (
+            self.beta
+            * (
+                policy_rejected_logps.to(self.accelerator.device)
+                - reference_rejected_logps.to(self.accelerator.device)
+            ).detach()
+        )
+
+        return losses, chosen_rewards, rejected_rewards
+
 
 class AxolotlORPOTrainer(ORPOTrainer):
     """
@@ -1552,6 +1616,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
                 dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         elif self.cfg.rl == "kto_pair":
             dpo_trainer_kwargs["loss_type"] = "kto_pair"
+        elif self.cfg.rl == "sppo":
+            dpo_trainer_kwargs["loss_type"] = "sppo"
         if self.eval_dataset:
             dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
@@ -1560,7 +1626,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
             trainer_cls = AxolotlDPOTrainer
             dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
             trainer_cls_args = [self.model, self.model_ref]
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 419deee58..53d60e76c 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -133,6 +133,7 @@ class RLType(str, Enum):
     ipo = "ipo"  # pylint: disable=invalid-name
     kto_pair = "kto_pair"  # pylint: disable=invalid-name
     orpo = "orpo"  # pylint: disable=invalid-name
+    sppo = "sppo"  # pylint: disable=invalid-name
 
 
 class ChatTemplate(str, Enum):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 8537b7e75..f0ae55a73 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -789,7 +789,11 @@ def load_model(
     if not reference_model or cfg.lora_model_dir:
         # if we're not loading the reference model, then we're loading the model for training
         # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
-        if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
+        if (
+            cfg.adapter
+            and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]
+            and not cfg.merge_lora
+        ):
             _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
         else:
             model, lora_config = load_adapter(model, cfg, cfg.adapter)
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 2e3728cc8..fe1f6e0bd 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
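
For reference, a minimal sketch (not part of the patch) of the SPPO objective that the new dpo_loss branch computes, evaluated on toy log-probabilities. The tensor values below are hypothetical; beta corresponds to cfg.dpo_beta (default 0.1) as wired up in HFRLTrainerBuilder above.

import torch

beta = 0.1  # maps to cfg.dpo_beta in the trainer builder

# hypothetical per-example sequence log-probs for a batch of 2 preference pairs
policy_chosen_logps = torch.tensor([-12.0, -9.5])
policy_rejected_logps = torch.tensor([-14.0, -11.0])
reference_chosen_logps = torch.tensor([-12.5, -10.0])
reference_rejected_logps = torch.tensor([-13.5, -10.5])

# beta-scaled policy-vs-reference log-ratios, matching `a` and `b` in dpo_loss()
a = beta * (policy_chosen_logps - reference_chosen_logps)
b = beta * (policy_rejected_logps - reference_rejected_logps)

# SPPO loss: minimized when the chosen log-ratio sits at +0.5 and the rejected at -0.5
losses = (a - 0.5) ** 2 + (b + 0.5) ** 2

# rewards reported for logging, identical to the DPO chosen/rejected rewards
chosen_rewards = (beta * (policy_chosen_logps - reference_chosen_logps)).detach()
rejected_rewards = (beta * (policy_rejected_logps - reference_rejected_logps)).detach()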