remove override
This commit is contained in:
@@ -14,7 +14,7 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Type, Union
|
||||
from typing import Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
@@ -817,70 +817,6 @@ class AxolotlDPOTrainer(DPOTrainer):
|
||||
res[key] = res[key][1:]
|
||||
return res
|
||||
|
||||
def dpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the preference loss for a batch of policy and reference log probabilities.

    The loss types ``sigmoid``, ``hinge``, ``ipo`` and ``kto_pair`` are delegated
    unchanged to the parent class. The ``sppo`` loss type is computed here.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
        The losses tensor contains the loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the detached,
        beta-scaled policy/reference log-ratios for the chosen and rejected
        responses, respectively.

    Raises:
        ValueError: If ``self.loss_type`` is not one of the supported loss types.
    """
    if self.loss_type in ["sigmoid", "hinge", "ipo", "kto_pair"]:
        return super().dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )

    if self.loss_type != "sppo":
        raise ValueError(
            f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo']"
        )

    # Move every operand to the accelerator device before subtracting, so the
    # loss and the rewards are computed from the same tensors and policy and
    # reference log probabilities living on different devices cannot raise a
    # device-mismatch error in the loss while the rewards would have succeeded.
    device = self.accelerator.device
    chosen_logratios = policy_chosen_logps.to(device) - reference_chosen_logps.to(
        device
    )
    rejected_logratios = policy_rejected_logps.to(
        device
    ) - reference_rejected_logps.to(device)

    # SPPO loss: squared distance of the beta-scaled chosen log-ratio from +0.5
    # plus squared distance of the beta-scaled rejected log-ratio from -0.5.
    scaled_chosen = self.beta * chosen_logratios
    scaled_rejected = self.beta * rejected_logratios
    losses = (scaled_chosen - 0.5) ** 2 + (scaled_rejected + 0.5) ** 2

    # Rewards are reported for logging only, so they are detached from the graph.
    chosen_rewards = self.beta * chosen_logratios.detach()
    rejected_rewards = self.beta * rejected_logratios.detach()

    return losses, chosen_rewards, rejected_rewards
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(ORPOTrainer):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user