From d7057ccd36b62fb44417a239c594498664f99191 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Jan 2024 13:30:45 -0500 Subject: [PATCH] paired kto support (#1069) --- README.md | 3 +++ requirements.txt | 2 +- src/axolotl/core/trainer_builder.py | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 12fad21b1..afb354385 100644 --- a/README.md +++ b/README.md @@ -595,6 +595,9 @@ datasets: # For `completion` datsets only, uses the provided field instead of `text` column field: +# use RL training: dpo, ipo, kto_pair +rl: + # Saves the desired chat template to the tokenizer_config.json for easier inferencing # Currently supports chatml and inst (mistral/mixtral) chat_template: chatml diff --git a/requirements.txt b/requirements.txt index 66a09fee7..6ec751ee9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,4 +40,4 @@ s3fs gcsfs # adlfs -trl @ git+https://github.com/huggingface/trl.git@main +trl>=0.7.9 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4c30fe517..7798ca455 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -927,6 +927,8 @@ class HFDPOTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs["loss_type"] = "ipo" if self.cfg.dpo_label_smoothing: dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing + elif self.cfg.rl == "kto_pair": + dpo_trainer_kwargs["loss_type"] = "kto_pair" dpo_trainer = DPOTrainer( self.model,