diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py
index b4d8b4d47..447b64eb8 100644
--- a/src/axolotl/core/builders/rl.py
+++ b/src/axolotl/core/builders/rl.py
@@ -239,10 +239,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT)
         ):
             trainer_kwargs["peft_config"] = self.peft_config
-        if self.cfg.precompute_ref_log_probs is not None:
-            trainer_kwargs["precompute_ref_log_probs"] = (
-                self.cfg.precompute_ref_log_probs
-            )
 
         trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
 
diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py
index 6d5251de1..7d979e6bf 100644
--- a/src/axolotl/core/trainers/dpo/__init__.py
+++ b/src/axolotl/core/trainers/dpo/__init__.py
@@ -32,4 +32,8 @@ class DPOStrategy:
             training_args_kwargs["padding_free"] = cfg.dpo_padding_free
         if cfg.dpo_use_liger_kernel is not None:
             training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
+        if cfg.precompute_ref_log_probs is not None:
+            training_args_kwargs["precompute_ref_log_probs"] = (
+                cfg.precompute_ref_log_probs
+            )
         return training_args_kwargs
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 5c3357d72..e45534640 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -294,6 +294,12 @@ class AxolotlInputConfig(
         },
     )
     dpo_label_smoothing: float | None = None
+    precompute_ref_log_probs: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Precompute reference model log probabilities for DPO"
+        },
+    )
     dpo_use_liger_kernel: bool | None = Field(
         default=None,
         json_schema_extra={
diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py
index a241e8549..338d48171 100644
--- a/tests/core/test_builders.py
+++ b/tests/core/test_builders.py
@@ -289,6 +289,7 @@ class TestHFRLTrainerBuilder:
         # assert training_arguments.gradient_checkpointing is True
 
     def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
+        dpo_cfg["precompute_ref_log_probs"] = True
         builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
         training_arguments, _ = builder._build_training_arguments(100)
 
@@ -298,6 +299,7 @@ class TestHFRLTrainerBuilder:
         assert hasattr(training_arguments, "use_weighting")
         assert training_arguments.use_weighting is True
         assert training_arguments.label_smoothing == 0.1
+        assert training_arguments.precompute_ref_log_probs is True
 
     def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
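
For context on the relocation above: `precompute_ref_log_probs` is an existing field on TRL's `DPOConfig`, so the diff stops passing it as a trainer constructor kwarg and instead routes it through `training_args_kwargs` in `DPOStrategy`, where it becomes part of the training arguments and lets the DPO trainer compute the reference model's log probabilities once up front rather than on every step. The snippet below is a minimal, hypothetical sketch of that plumbing, assuming TRL is installed; `build_dpo_training_args` and the dict-style `cfg` are illustrative stand-ins for the Axolotl builder and config object, not real project APIs.

# Hypothetical sketch of the kwargs plumbing shown in the diff above.
from trl import DPOConfig

def build_dpo_training_args(cfg: dict) -> DPOConfig:
    """Stand-in for the DPOStrategy -> training args flow (names assumed)."""
    kwargs = {}
    # Forward the flag only when the user set it explicitly, mirroring the
    # `is not None` guard in the diff, so TRL's own default applies otherwise.
    if cfg.get("precompute_ref_log_probs") is not None:
        kwargs["precompute_ref_log_probs"] = cfg["precompute_ref_log_probs"]
    return DPOConfig(output_dir="./out", **kwargs)

# Usage mirroring the updated test: the flag survives the round trip.
args = build_dpo_training_args({"precompute_ref_log_probs": True})
assert args.precompute_ref_log_probs is True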