DPO support loss types (#3566)
* Support loss_type/loss_weights DPO
* Validate dpo loss type/weights only set for dpo
* Tests: Update ipo tests to use new path
* Docs: Update docs for new ipo path
* PR fixes - typo/validation
* PR nit - warning
* chore: fix warnings arg

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
This commit is contained in:
@@ -96,6 +96,8 @@ def fixture_dpo_cfg(base_cfg):
|
||||
"dpo_use_weighting": True,
|
||||
"dpo_label_smoothing": 0.1,
|
||||
"beta": 0.1, # DPO beta
|
||||
"dpo_loss_type": ["sigmoid", "sft"],
|
||||
"dpo_loss_weights": [1.0, 0.5],
|
||||
}
|
||||
)
|
||||
return cfg
|
||||
@@ -164,7 +166,8 @@ def fixture_ipo_cfg(base_cfg):
|
||||
cfg = base_cfg.copy()
|
||||
cfg.update(
|
||||
{
|
||||
"rl": RLType.IPO,
|
||||
"rl": RLType.DPO,
|
||||
"dpo_loss_type": ["ipo"],
|
||||
"dpo_label_smoothing": 0,
|
||||
"beta": 0.1,
|
||||
}
|
||||
@@ -300,6 +303,8 @@ class TestHFRLTrainerBuilder:
|
||||
assert training_arguments.use_weighting is True
|
||||
assert training_arguments.label_smoothing == 0.1
|
||||
assert training_arguments.precompute_ref_log_probs is True
|
||||
assert training_arguments.loss_type == ["sigmoid", "sft"]
|
||||
assert training_arguments.loss_weights == [1.0, 0.5]
|
||||
|
||||
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user