Add KTO support (#1640)

* add kto support

* test cleanup

* fix outdated comment

* fix llama3 ultra

* chore: lint

* update to use rl_beta instead of dpo_beta

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
Ben Redmond
2024-05-20 16:05:16 -04:00
committed by GitHub
parent ba45531802
commit 22ae21a6c2
11 changed files with 434 additions and 17 deletions

View File

@@ -205,3 +205,66 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_kto_lora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "kto",
"rl_beta": 0.5,
"kto_desirable_weight": 1.0,
"kto_undesirable_weight": 1.0,
"remove_unused_columns": False,
"datasets": [
# {
# "path": "argilla/kto-mix-15k",
# "type": "chatml.argilla_chat",
# "split": "train",
# },
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
"type": "chatml.ultra",
"split": "train",
},
# {
# "path": "argilla/kto-mix-15k",
# "type": "llama3.argilla_chat",
# "split": "train",
# },
{
"path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
"type": "llama3.ultra",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

View File

@@ -1117,6 +1117,15 @@ class TestValidation(BaseValidation):
validate_config(cfg)
assert len(self._caplog.records) == 0
def test_dpo_beta_deprecation(self, minimal_cfg):
cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
with self._caplog.at_level(logging.WARNING):
new_cfg = validate_config(cfg)
assert new_cfg["rl_beta"] == 0.2
assert new_cfg["dpo_beta"] is None
assert len(self._caplog.records) == 1
class TestValidationCheckModelConfig(BaseValidation):
"""