diff --git a/README.md b/README.md index 12fad21b1..afb354385 100644 --- a/README.md +++ b/README.md @@ -595,6 +595,9 @@ datasets: # For `completion` datsets only, uses the provided field instead of `text` column field: +# use RL training: dpo, ipo, kto_pair +rl: + # Saves the desired chat template to the tokenizer_config.json for easier inferencing # Currently supports chatml and inst (mistral/mixtral) chat_template: chatml diff --git a/requirements.txt b/requirements.txt index 66a09fee7..6ec751ee9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,4 +40,4 @@ s3fs gcsfs # adlfs -trl @ git+https://github.com/huggingface/trl.git@main +trl>=0.7.9 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 4c30fe517..7798ca455 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -927,6 +927,8 @@ class HFDPOTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs["loss_type"] = "ipo" if self.cfg.dpo_label_smoothing: dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing + elif self.cfg.rl == "kto_pair": + dpo_trainer_kwargs["loss_type"] = "kto_pair" dpo_trainer = DPOTrainer( self.model,