remove override

This commit is contained in:
Wing Lian
2024-05-02 11:01:20 -04:00
parent df645906eb
commit b301068098

View File

@@ -14,7 +14,7 @@ from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple, Type, Union
from typing import Dict, List, Literal, Optional, Type, Union
import torch
import transformers
@@ -817,70 +817,6 @@ class AxolotlDPOTrainer(DPOTrainer):
res[key] = res[key][1:]
return res
def dpo_loss(
    self,
    policy_chosen_logps: torch.FloatTensor,
    policy_rejected_logps: torch.FloatTensor,
    reference_chosen_logps: torch.FloatTensor,
    reference_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the preference loss for a batch of policy and reference log probabilities.

    The loss types "sigmoid", "hinge", "ipo" and "kto_pair" are delegated
    unchanged to the parent class. Only "sppo" is computed here.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)

    Returns:
        A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
        The losses tensor contains the loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the
        beta-scaled policy/reference log-ratios for the chosen and rejected
        responses, detached from the autograd graph.

    Raises:
        ValueError: If ``self.loss_type`` is not one of the supported types.
    """
    if self.loss_type in ["sigmoid", "hinge", "ipo", "kto_pair"]:
        return super().dpo_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
        )

    if self.loss_type != "sppo":
        raise ValueError(
            f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'sppo']"
        )

    # Align every input on the accelerator device *before* subtracting, so the
    # loss and the rewards are built from the same tensors and a policy /
    # reference device mismatch cannot break the loss computation.
    device = self.accelerator.device
    chosen_logratios = self.beta * (
        policy_chosen_logps.to(device) - reference_chosen_logps.to(device)
    )
    rejected_logratios = self.beta * (
        policy_rejected_logps.to(device) - reference_rejected_logps.to(device)
    )

    # SPPO: squared distance of the scaled log-ratios from +0.5 (chosen) and
    # -0.5 (rejected).
    losses = (chosen_logratios - 0.5) ** 2 + (rejected_logratios + 0.5) ** 2

    # Rewards are reported for logging only, so they must not carry gradients.
    chosen_rewards = chosen_logratios.detach()
    rejected_rewards = rejected_logratios.detach()

    return losses, chosen_rewards, rejected_rewards
class AxolotlORPOTrainer(ORPOTrainer):
"""