invert check
This commit is contained in:
@@ -837,7 +837,7 @@ class AxolotlDPOTrainer(DPOTrainer):
|
|||||||
The losses tensor contains the DPO loss for each example in the batch.
|
The losses tensor contains the DPO loss for each example in the batch.
|
||||||
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
||||||
"""
|
"""
|
||||||
if self.loss_type not in ["sigmoid", "hinge", "ipo", "kto_pair"]:
|
if self.loss_type in ["sigmoid", "hinge", "ipo", "kto_pair"]:
|
||||||
return super().dpo_loss(
|
return super().dpo_loss(
|
||||||
policy_chosen_logps,
|
policy_chosen_logps,
|
||||||
policy_rejected_logps,
|
policy_rejected_logps,
|
||||||
|
|||||||
Reference in New Issue
Block a user