diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py index 3520aff10..20bcc8dd6 100644 --- a/src/axolotl/core/trainers/dpo/trainer.py +++ b/src/axolotl/core/trainers/dpo/trainer.py @@ -177,12 +177,8 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer): # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs if res["chosen_input_ids"][0] == processing_class.bos_token_id: res["chosen_input_ids"] = res["chosen_input_ids"][1:] - res["chosen_labels"] = res["chosen_labels"][1:] - res["chosen_attention_mask"] = res["chosen_attention_mask"][1:] if res["rejected_input_ids"][0] == processing_class.bos_token_id: res["rejected_input_ids"] = res["rejected_input_ids"][1:] - res["rejected_labels"] = res["rejected_labels"][1:] - res["rejected_attention_mask"] = res["rejected_attention_mask"][1:] return res