DPO support loss types (#3566)
* Support loss_type/loss_weights DPO * Validate dpo loss type/weights only set for dpo * Tests: Update ipo tests to use new path * Docs: Update docs for new ipo path * PR fixes - typo/validation * PR nit - warning * chore: fix warnings arg --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -116,6 +116,58 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_rpo(self, temp_dir):
|
||||
# For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha]
|
||||
# replaces loss_type="rpo", rpo_alpha=alpha.
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 64,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_linear": True,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"rl": "dpo",
|
||||
"dpo_loss_type": ["sigmoid", "sft"],
|
||||
"dpo_loss_weights": [1.0, 1.0],
|
||||
"datasets": [
|
||||
{
|
||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||
"type": "chatml.ultra",
|
||||
"split": "train",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "paged_adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||
@with_temp_dir
|
||||
def test_kto_pair_lora(self, temp_dir):
|
||||
@@ -181,7 +233,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"rl": "ipo",
|
||||
"rl": "dpo",
|
||||
"dpo_loss_type": ["ipo"],
|
||||
"datasets": [
|
||||
{
|
||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||
|
||||
Reference in New Issue
Block a user