diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 61f03e7ad..77e698468 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -576,7 +576,7 @@ def prepare_opinionated_env(cfg): def setup_trainer( cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps ): - if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"): + if cfg.rl in ("dpo", "grpo", "ipo", "orpo", "kto", "simpo"): trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2]