Use the correct trainer builder (HFRLTrainerBuilder) for GRPO

This commit is contained in:
Wing Lian
2025-02-03 00:25:24 -05:00
parent 626db6cb84
commit d155849e2c

View File

@@ -576,7 +576,7 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
-if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
+if cfg.rl in ("dpo", "grpo", "ipo", "orpo", "kto", "simpo"):
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]