diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index b1554f0c0..55eecf839 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -14,7 +14,7 @@ from collections import defaultdict from dataclasses import dataclass, field from functools import wraps from pathlib import Path -from typing import Dict, List, Literal, Optional, Tuple, Type, Union +from typing import Dict, List, Literal, Optional, Type, Union import torch import transformers @@ -817,70 +817,6 @@ class AxolotlDPOTrainer(DPOTrainer): res[key] = res[key][1:] return res - def dpo_loss( - self, - policy_chosen_logps: torch.FloatTensor, - policy_rejected_logps: torch.FloatTensor, - reference_chosen_logps: torch.FloatTensor, - reference_rejected_logps: torch.FloatTensor, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Compute the DPO loss for a batch of policy and reference model log probabilities. - - Args: - policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) - policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) - reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) - reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) - - Returns: - A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). - The losses tensor contains the DPO loss for each example in the batch. - The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. - """ - if self.loss_type in ["sigmoid", "hinge", "ipo", "kto_pair"]: - return super().dpo_loss( - policy_chosen_logps, - policy_rejected_logps, - reference_chosen_logps, - reference_rejected_logps, - ) - - # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. - # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and - # calculates a conservative DPO loss. - if self.loss_type == "sppo": - # Calculate a and b - a = self.beta * ( # pylint: disable=invalid-name - policy_chosen_logps - reference_chosen_logps - ) - b = self.beta * ( # pylint: disable=invalid-name - policy_rejected_logps - reference_rejected_logps - ) - - # Compute the SPPO loss - losses = (a - 0.5) ** 2 + (b + 0.5) ** 2 - else: - raise ValueError( - f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo']" - ) - - chosen_rewards = ( - self.beta - * ( - policy_chosen_logps.to(self.accelerator.device) - - reference_chosen_logps.to(self.accelerator.device) - ).detach() - ) - rejected_rewards = ( - self.beta - * ( - policy_rejected_logps.to(self.accelerator.device) - - reference_rejected_logps.to(self.accelerator.device) - ).detach() - ) - - return losses, chosen_rewards, rejected_rewards - class AxolotlORPOTrainer(ORPOTrainer): """