diff --git a/AGENTS.md b/AGENTS.md index e9b747ce3..43470d9f8 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema | Method | Config Key | When to Use | |--------|-----------|-------------| | SFT | *(default)* | Input-output pairs, instruction tuning | -| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) | +| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) | | KTO | `rl: kto` | Unpaired binary preference labels | | ORPO | `rl: orpo` | Single-stage alignment, no ref model | | GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) | diff --git a/docs/agents/preference_tuning.md b/docs/agents/preference_tuning.md index bed973009..3414d22ce 100644 --- a/docs/agents/preference_tuning.md +++ b/docs/agents/preference_tuning.md @@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da 1. Paired preference data (chosen + rejected)? - Default → `rl: dpo` - - Overfitting → `rl: ipo` + - Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]` - VRAM-limited → `rl: orpo` (no ref model) - Length-sensitive → `rl: simpo` (no ref model) 2. Only binary labels (good/bad)? → `rl: kto` diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd index 514c1c034..75d20414c 100644 --- a/docs/rlhf.qmd +++ b/docs/rlhf.qmd @@ -320,8 +320,10 @@ The input format is a simple JSON input with customizable fields based on the ab As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO. ```yaml -rl: ipo +rl: dpo +dpo_loss_type: ["ipo"] ``` +*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated. 
### ORPO diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 7d979e6bf..a73ec3149 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -20,8 +20,16 @@ class DPOStrategy: @classmethod def set_training_args_kwargs(cls, cfg): training_args_kwargs = {} + if cfg.rl is RLType.DPO: + if cfg.dpo_loss_type is not None: + training_args_kwargs["loss_type"] = cfg.dpo_loss_type + + if cfg.dpo_loss_weights is not None: + training_args_kwargs["loss_weights"] = cfg.dpo_loss_weights + if cfg.rl is RLType.IPO: training_args_kwargs["loss_type"] = ["ipo"] + # Label smoothing is not compatible with IPO if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing: training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index c762d7f80..3211b7c36 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -309,6 +309,16 @@ class AxolotlInputConfig( dpo_padding_free: bool | None = None + dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field( + default=None, + json_schema_extra={"description": "List of DPO losses to use."}, + ) + + dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field( + default=None, + json_schema_extra={"description": "Weights for each DPO loss."}, + ) + datasets: ( Annotated[ list[ diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 484a1fb47..9161765d0 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -760,6 +760,40 @@ class RLValidationMixin: ) return data + @model_validator(mode="before") + @classmethod + def check_dpo(cls, data): + dpo_loss_type = data.get("dpo_loss_type") + dpo_loss_weights = data.get("dpo_loss_weights") + rl = data.get("rl") + + if rl == "ipo": + LOG.warning( + "rl: ipo will soon be deprecated. 
Use `rl: dpo` with `dpo_loss_type: ['ipo']` instead." + ) + + if rl == "dpo": + if dpo_loss_weights is not None and dpo_loss_type is None: + raise ValueError( + "`dpo_loss_weights` requires `dpo_loss_type` to be set" + ) + if ( + dpo_loss_type is not None + and dpo_loss_weights is not None + and len(dpo_loss_type) != len(dpo_loss_weights) + ): + raise ValueError( + f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, " + f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights" + ) + elif dpo_loss_type is not None or dpo_loss_weights is not None: + raise ValueError( + f"`dpo_loss_type` and `dpo_loss_weights` are for DPO only, " + f"but got {rl=}, {dpo_loss_type=} and {dpo_loss_weights=}" + ) + + return data + @model_validator(mode="before") @classmethod def check_grpo_batch_size_divisibility(cls, data): diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index 338d48171..0a4b2ad0b 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -96,6 +96,8 @@ def fixture_dpo_cfg(base_cfg): "dpo_use_weighting": True, "dpo_label_smoothing": 0.1, "beta": 0.1, # DPO beta + "dpo_loss_type": ["sigmoid", "sft"], + "dpo_loss_weights": [1.0, 0.5], } ) return cfg @@ -164,7 +166,8 @@ def fixture_ipo_cfg(base_cfg): cfg = base_cfg.copy() cfg.update( { - "rl": RLType.IPO, + "rl": RLType.DPO, + "dpo_loss_type": ["ipo"], "dpo_label_smoothing": 0, "beta": 0.1, } @@ -300,6 +303,8 @@ class TestHFRLTrainerBuilder: assert training_arguments.use_weighting is True assert training_arguments.label_smoothing == 0.1 assert training_arguments.precompute_ref_log_probs is True + assert training_arguments.loss_type == ["sigmoid", "sft"] + assert training_arguments.loss_weights == [1.0, 0.5] def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer): builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer) diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py index 0aca1807c..24784eb2c 100644 --- a/tests/e2e/test_dpo.py +++ 
b/tests/e2e/test_dpo.py @@ -116,6 +116,58 @@ class TestDPOLlamaLora(unittest.TestCase): train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) + @with_temp_dir + def test_rpo(self, temp_dir): + # For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha] + # replaces loss_type="rpo", rpo_alpha=alpha. + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "tokenizer_type": "AutoTokenizer", + "sequence_len": 1024, + "load_in_8bit": True, + "adapter": "lora", + "lora_r": 64, + "lora_alpha": 32, + "lora_dropout": 0.1, + "lora_target_linear": True, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "rl": "dpo", + "dpo_loss_type": ["sigmoid", "sft"], + "dpo_loss_weights": [1.0, 1.0], + "datasets": [ + { + "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized", + "type": "chatml.ultra", + "split": "train", + }, + ], + "num_epochs": 1, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "paged_adamw_8bit", + "lr_scheduler": "cosine", + "max_steps": 20, + "save_steps": 10, + "warmup_steps": 5, + "gradient_checkpointing": True, + "gradient_checkpointing_kwargs": {"use_reentrant": True}, + "save_first_step": False, + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg) + @pytest.mark.skip("kto_pair no longer supported in trl") @with_temp_dir def test_kto_pair_lora(self, temp_dir): @@ -181,7 +233,8 @@ class TestDPOLlamaLora(unittest.TestCase): "special_tokens": { "pad_token": "<|endoftext|>", }, - "rl": "ipo", + "rl": "dpo", + "dpo_loss_type": ["ipo"], "datasets": [ { "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",