Add KTO support (#1640)

* add kto support

* test cleanup

* fix outdated comment

* fix llama3 ultra

* chore: lint

* update to use rl_beta instead of dpo_beta

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Ben Redmond
2024-05-20 16:05:16 -04:00
committed by GitHub
parent ba45531802
commit 22ae21a6c2
11 changed files with 434 additions and 17 deletions

View File

@@ -803,7 +803,11 @@ def load_model(
if not reference_model or cfg.lora_model_dir:
# if we're not loading the reference model, then we're loading the model for training;
# in that case the RL trainer (dpo, ipo, kto_pair, kto) doesn't want the peft model loaded over it, it just wants the lora/peft config
if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
if (
cfg.adapter
and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
and not cfg.merge_lora
):
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)