From 1b1fc917bc4b106dec0482c1f2ad9a48422bd468 Mon Sep 17 00:00:00 2001 From: Joaquin Hui <132194176+joaquinhuigomez@users.noreply.github.com> Date: Wed, 1 Apr 2026 18:28:40 +0100 Subject: [PATCH] Add precompute_ref_log_probs to config schema (#3555) [skip ci] * Add precompute_ref_log_probs to config schema * chore: add description for config * Add test for precompute_ref_log_probs and move to training args * using precompute logprobs as the default slows down CI as it has to precompute --------- Co-authored-by: NanoCode012 Co-authored-by: Wing Lian --- src/axolotl/core/builders/rl.py | 4 ---- src/axolotl/core/trainers/dpo/__init__.py | 4 ++++ src/axolotl/utils/schemas/config.py | 6 ++++++ tests/core/test_builders.py | 2 ++ 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index b4d8b4d47..447b64eb8 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -239,10 +239,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase): and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT) ): trainer_kwargs["peft_config"] = self.peft_config - if self.cfg.precompute_ref_log_probs is not None: - trainer_kwargs["precompute_ref_log_probs"] = ( - self.cfg.precompute_ref_log_probs - ) trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs) diff --git a/src/axolotl/core/trainers/dpo/__init__.py b/src/axolotl/core/trainers/dpo/__init__.py index 6d5251de1..7d979e6bf 100644 --- a/src/axolotl/core/trainers/dpo/__init__.py +++ b/src/axolotl/core/trainers/dpo/__init__.py @@ -32,4 +32,8 @@ class DPOStrategy: training_args_kwargs["padding_free"] = cfg.dpo_padding_free if cfg.dpo_use_liger_kernel is not None: training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel + if cfg.precompute_ref_log_probs is not None: + training_args_kwargs["precompute_ref_log_probs"] = ( + cfg.precompute_ref_log_probs + ) return training_args_kwargs diff --git 
a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 5c3357d72..e45534640 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -294,6 +294,12 @@ class AxolotlInputConfig( }, ) dpo_label_smoothing: float | None = None + precompute_ref_log_probs: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Precompute reference model log probabilities for DPO" + }, + ) dpo_use_liger_kernel: bool | None = Field( default=None, diff --git a/tests/core/test_builders.py b/tests/core/test_builders.py index a241e8549..338d48171 100644 --- a/tests/core/test_builders.py +++ b/tests/core/test_builders.py @@ -289,6 +289,7 @@ class TestHFRLTrainerBuilder: # assert training_arguments.gradient_checkpointing is True def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer): + dpo_cfg["precompute_ref_log_probs"] = True builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer) training_arguments, _ = builder._build_training_arguments(100) @@ -298,6 +299,7 @@ class TestHFRLTrainerBuilder: assert hasattr(training_arguments, "use_weighting") assert training_arguments.use_weighting is True assert training_arguments.label_smoothing == 0.1 + assert training_arguments.precompute_ref_log_probs is True def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer): builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)