DPO support loss types (#3566)
* Support loss_type/loss_weights DPO

* Validate dpo loss type/weights only set for dpo

* Tests: Update ipo tests to use new path

* Docs: Update docs for new ipo path

* PR fixes - typo/validation

* PR nit - warning

* chore: fix warnings arg

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
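For orientation before the diff: the headline change is two new config keys, `dpo_loss_type` and `dpo_loss_weights`. A minimal sketch in the `DictDefault` style the tests below use; the loss names and weights here are illustrative, not prescribed by this PR:

```python
# Sketch only: the two keys this PR adds, with illustrative values.
from axolotl.utils.dict import DictDefault  # assumes axolotl's usual import path

cfg = DictDefault(
    {
        "rl": "dpo",
        "dpo_loss_type": ["sigmoid", "sft"],  # one or more TRL DPO losses
        "dpo_loss_weights": [1.0, 0.5],       # one weight per loss (same length)
    }
)
```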
@@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema
 | Method | Config Key | When to Use |
 |--------|-----------|-------------|
 | SFT | *(default)* | Input-output pairs, instruction tuning |
-| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
+| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
 | KTO | `rl: kto` | Unpaired binary preference labels |
 | ORPO | `rl: orpo` | Single-stage alignment, no ref model |
 | GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da
 1. Paired preference data (chosen + rejected)?
    - Default → `rl: dpo`
-   - Overfitting → `rl: ipo`
+   - Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
    - VRAM-limited → `rl: orpo` (no ref model)
    - Length-sensitive → `rl: simpo` (no ref model)
 2. Only binary labels (good/bad)? → `rl: kto`
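Read as configs, the branches of this guide look roughly like the following sketch (keys only; datasets and hyperparameters omitted):

```python
# Sketch: the decision-guide branches above as bare config dicts.
default_dpo = {"rl": "dpo"}                     # default paired-preference path
ipo = {"rl": "dpo", "dpo_loss_type": ["ipo"]}   # overfitting -> IPO loss
orpo = {"rl": "orpo"}                           # VRAM-limited, no ref model
simpo = {"rl": "simpo"}                         # length-sensitive, no ref model
kto = {"rl": "kto"}                             # unpaired binary labels
```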
@@ -320,8 +320,10 @@ The input format is a simple JSON input with customizable fields based on the ab
 As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.

 ```yaml
-rl: ipo
+rl: dpo
+dpo_loss_type: ["ipo"]
 ```

+*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.
+
 ### ORPO
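Because `dpo_loss_type` takes a list, the same mechanism also covers weighted blends of losses, not just the IPO swap. A hedged sketch of an RPO-style mix (weights illustrative), mirroring the `test_rpo` test added below:

```python
# Sketch: blend the standard sigmoid DPO loss with an SFT term.
# Per the test comment below, on newer TRL this replaces
# loss_type="rpo" with rpo_alpha=alpha.
rpo_style = {
    "rl": "dpo",
    "dpo_loss_type": ["sigmoid", "sft"],  # losses to combine
    "dpo_loss_weights": [1.0, 1.0],       # must match dpo_loss_type in length
}
```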
@@ -20,8 +20,16 @@ class DPOStrategy:
     @classmethod
     def set_training_args_kwargs(cls, cfg):
         training_args_kwargs = {}
+        if cfg.rl is RLType.DPO:
+            if cfg.dpo_loss_type is not None:
+                training_args_kwargs["loss_type"] = cfg.dpo_loss_type
+
+            if cfg.dpo_loss_weights is not None:
+                training_args_kwargs["loss_weights"] = cfg.dpo_loss_weights
+
         if cfg.rl is RLType.IPO:
             training_args_kwargs["loss_type"] = ["ipo"]

         # Label smoothing is not compatible with IPO
         if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
             training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
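For context, these kwargs are eventually passed through to TRL's `DPOConfig` training arguments (the asserts in the builder tests below check exactly that). A simplified sketch of the hand-off, not the actual axolotl builder code:

```python
# Simplified sketch of where the kwargs collected above end up.
# Assumes a recent TRL whose DPOConfig accepts list-valued loss_type
# and loss_weights, as the tests below assert.
from trl import DPOConfig

training_args_kwargs = {
    "loss_type": ["sigmoid", "sft"],  # from cfg.dpo_loss_type
    "loss_weights": [1.0, 0.5],       # from cfg.dpo_loss_weights
    "label_smoothing": 0.1,           # from cfg.dpo_label_smoothing
}
args = DPOConfig(output_dir="outputs", **training_args_kwargs)
```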
@@ -309,6 +309,16 @@ class AxolotlInputConfig(
     dpo_padding_free: bool | None = None

+    dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(
+        default=None,
+        json_schema_extra={"description": "List of DPO losses to use."},
+    )
+
+    dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field(
+        default=None,
+        json_schema_extra={"description": "Weights for each DPO loss."},
+    )
+
     datasets: (
         Annotated[
             list[
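As an aside, the `Annotated[..., MinLen(1)]` constraint means an explicitly empty list fails schema validation while `None` (unset) is fine. A standalone sketch of the same pattern, not the real `AxolotlInputConfig`:

```python
# Isolated illustration of the field pattern above.
from typing import Annotated

from annotated_types import MinLen
from pydantic import BaseModel, Field, ValidationError

class LossConfigSketch(BaseModel):  # hypothetical stand-in model
    dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(default=None)

LossConfigSketch(dpo_loss_type=["sigmoid"])  # accepted
LossConfigSketch()                           # accepted: field is optional
try:
    LossConfigSketch(dpo_loss_type=[])       # rejected by MinLen(1)
except ValidationError as exc:
    print(exc)
```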
@@ -760,6 +760,40 @@ class RLValidationMixin:
             )
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_dpo(cls, data):
+        dpo_loss_type = data.get("dpo_loss_type")
+        dpo_loss_weights = data.get("dpo_loss_weights")
+        rl = data.get("rl")
+
+        if rl == "ipo":
+            LOG.warning(
+                "rl: ipo will soon be deprecated. Use `rl: dpo` with `dpo_loss_type: ['ipo']` instead."
+            )
+
+        if rl == "dpo":
+            if dpo_loss_weights is not None and dpo_loss_type is None:
+                raise ValueError(
+                    "`dpo_loss_weights` requires `dpo_loss_type` to be set"
+                )
+            if (
+                dpo_loss_type is not None
+                and dpo_loss_weights is not None
+                and len(dpo_loss_type) != len(dpo_loss_weights)
+            ):
+                raise ValueError(
+                    f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
+                    f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
+                )
+        elif dpo_loss_type is not None or dpo_loss_weights is not None:
+            raise ValueError(
+                f"`dpo_loss_type` and `dpo_loss_weights` are for DPO only, "
+                f"but got {rl=}, {dpo_loss_type=} and {dpo_loss_weights=}"
+            )
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_grpo_batch_size_divisibility(cls, data):
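To summarize what this validator enforces, a sketch of configs it accepts or rejects (raw dicts as the `mode="before"` validator sees them; values are illustrative):

```python
# Accepted: losses and weights both set, same length, under rl: dpo.
ok = {"rl": "dpo", "dpo_loss_type": ["sigmoid", "sft"], "dpo_loss_weights": [1.0, 0.5]}

# Rejected: weights given without loss types.
weights_without_types = {"rl": "dpo", "dpo_loss_weights": [1.0]}

# Rejected: two losses but one weight.
mismatched_lengths = {"rl": "dpo", "dpo_loss_type": ["sigmoid", "sft"], "dpo_loss_weights": [1.0]}

# Rejected: these keys are DPO-only.
wrong_rl = {"rl": "kto", "dpo_loss_type": ["sigmoid"]}

# Warns (still accepted): legacy rl: ipo path, soon to be deprecated.
legacy_ipo = {"rl": "ipo"}
```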
@@ -96,6 +96,8 @@ def fixture_dpo_cfg(base_cfg):
             "dpo_use_weighting": True,
             "dpo_label_smoothing": 0.1,
             "beta": 0.1,  # DPO beta
+            "dpo_loss_type": ["sigmoid", "sft"],
+            "dpo_loss_weights": [1.0, 0.5],
         }
     )
     return cfg
@@ -164,7 +166,8 @@ def fixture_ipo_cfg(base_cfg):
     cfg = base_cfg.copy()
     cfg.update(
         {
-            "rl": RLType.IPO,
+            "rl": RLType.DPO,
+            "dpo_loss_type": ["ipo"],
             "dpo_label_smoothing": 0,
             "beta": 0.1,
         }
@@ -300,6 +303,8 @@ class TestHFRLTrainerBuilder:
         assert training_arguments.use_weighting is True
         assert training_arguments.label_smoothing == 0.1
         assert training_arguments.precompute_ref_log_probs is True
+        assert training_arguments.loss_type == ["sigmoid", "sft"]
+        assert training_arguments.loss_weights == [1.0, 0.5]

     def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
         builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
@@ -116,6 +116,58 @@ class TestDPOLlamaLora(unittest.TestCase):
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)

+    @with_temp_dir
+    def test_rpo(self, temp_dir):
+        # For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha]
+        # replaces loss_type="rpo", rpo_alpha=alpha.
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "tokenizer_type": "AutoTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "rl": "dpo",
+                "dpo_loss_type": ["sigmoid", "sft"],
+                "dpo_loss_weights": [1.0, 1.0],
+                "datasets": [
+                    {
+                        "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
+                        "type": "chatml.ultra",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+                "save_first_step": False,
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
+
     @pytest.mark.skip("kto_pair no longer supported in trl")
     @with_temp_dir
     def test_kto_pair_lora(self, temp_dir):
@@ -181,7 +233,8 @@ class TestDPOLlamaLora(unittest.TestCase):
                 "special_tokens": {
                     "pad_token": "<|endoftext|>",
                 },
-                "rl": "ipo",
+                "rl": "dpo",
+                "dpo_loss_type": ["ipo"],
                 "datasets": [
                     {
                         "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",