DPO support loss types (#3566)
* Support loss_type/loss_weights DPO * Validate dpo loss type/weights only set for dpo * Tests: Update ipo tests to use new path * Docs: Update docs for new ipo path * PR fixes - typo/validation * PR nit - warning * chore: fix warnings arg --------- Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -96,6 +96,8 @@ def fixture_dpo_cfg(base_cfg):
|
||||
"dpo_use_weighting": True,
|
||||
"dpo_label_smoothing": 0.1,
|
||||
"beta": 0.1, # DPO beta
|
||||
"dpo_loss_type": ["sigmoid", "sft"],
|
||||
"dpo_loss_weights": [1.0, 0.5],
|
||||
}
|
||||
)
|
||||
return cfg
|
||||
@@ -164,7 +166,8 @@ def fixture_ipo_cfg(base_cfg):
|
||||
cfg = base_cfg.copy()
|
||||
cfg.update(
|
||||
{
|
||||
"rl": RLType.IPO,
|
||||
"rl": RLType.DPO,
|
||||
"dpo_loss_type": ["ipo"],
|
||||
"dpo_label_smoothing": 0,
|
||||
"beta": 0.1,
|
||||
}
|
||||
@@ -300,6 +303,8 @@ class TestHFRLTrainerBuilder:
|
||||
assert training_arguments.use_weighting is True
|
||||
assert training_arguments.label_smoothing == 0.1
|
||||
assert training_arguments.precompute_ref_log_probs is True
|
||||
assert training_arguments.loss_type == ["sigmoid", "sft"]
|
||||
assert training_arguments.loss_weights == [1.0, 0.5]
|
||||
|
||||
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
|
||||
|
||||
@@ -116,6 +116,58 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_rpo(self, temp_dir):
|
||||
# For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha]
|
||||
# replaces loss_type="rpo", rpo_alpha=alpha.
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 64,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_linear": True,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"rl": "dpo",
|
||||
"dpo_loss_type": ["sigmoid", "sft"],
|
||||
"dpo_loss_weights": [1.0, 1.0],
|
||||
"datasets": [
|
||||
{
|
||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||
"type": "chatml.ultra",
|
||||
"split": "train",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "paged_adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||
@with_temp_dir
|
||||
def test_kto_pair_lora(self, temp_dir):
|
||||
@@ -181,7 +233,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"rl": "ipo",
|
||||
"rl": "dpo",
|
||||
"dpo_loss_type": ["ipo"],
|
||||
"datasets": [
|
||||
{
|
||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||
|
||||
Reference in New Issue
Block a user