invert check

This commit is contained in:
Wing Lian
2024-05-02 09:31:43 -04:00
parent 7fea5822f0
commit df645906eb

View File

@@ -837,7 +837,7 @@ class AxolotlDPOTrainer(DPOTrainer):
The losses tensor contains the DPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
- if self.loss_type not in ["sigmoid", "hinge", "ipo", "kto_pair"]:
+ if self.loss_type in ["sigmoid", "hinge", "ipo", "kto_pair"]:
return super().dpo_loss(
policy_chosen_logps,
policy_rejected_logps,