Add precompute_ref_log_probs to config schema (#3555) [skip ci]

* Add precompute_ref_log_probs to config schema

* chore: add description for config

* Add test for precompute_ref_log_probs and move to training args

* Using precompute logprobs as the default slows down CI, as it has to precompute them first

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Author: Joaquin Hui
Date: 2026-04-01 18:28:40 +01:00
Committed by: GitHub
Parent: 96ae8bdd1d
Commit: 1b1fc917bc
4 changed files with 12 additions and 4 deletions
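For background, DPO's loss compares policy and reference-model log probabilities for each chosen/rejected pair. Because the reference model never updates, its log-probs can be cached in a single up-front pass, which is the trade-off the last commit-message bullet describes: a slower start in exchange for skipping the reference forward pass on every training step. A minimal, illustrative sketch of the idea (not Axolotl/TRL internals; real implementations also shift logits/labels for causal LMs and mask padding, both omitted here):

import torch


@torch.no_grad()
def precompute_ref_logps(ref_model, dataloader):
    # The reference model is frozen, so compute each example's sequence
    # log-prob once and cache it instead of re-running the model per step.
    cached = []
    for batch in dataloader:
        logits = ref_model(batch["input_ids"]).logits  # [B, T, V]
        logps = torch.log_softmax(logits, dim=-1)      # [B, T, V]
        token_logps = logps.gather(
            -1, batch["labels"].unsqueeze(-1)
        ).squeeze(-1)                                  # [B, T]
        cached.append(token_logps.sum(dim=-1))         # [B] per-sequence log-prob
    return torch.cat(cached)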


@@ -239,10 +239,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT)
         ):
             trainer_kwargs["peft_config"] = self.peft_config
-        if self.cfg.precompute_ref_log_probs is not None:
-            trainer_kwargs["precompute_ref_log_probs"] = (
-                self.cfg.precompute_ref_log_probs
-            )
         trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
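The four lines removed here had routed the flag directly into the trainer kwargs. Per the commit message, the option now travels with the training arguments instead (second hunk below), so it reaches the trainer through its config object. A hedged sketch of the two routes, assuming TRL's DPO API, where DPOConfig exposes a precompute_ref_log_probs field:

from trl import DPOConfig

# Old route (removed above): forwarded as a direct trainer kwarg.
# trainer = DPOTrainer(..., precompute_ref_log_probs=True)

# New route (added below): carried on the training arguments.
training_args = DPOConfig(
    output_dir="./out",
    precompute_ref_log_probs=True,  # cache reference log-probs before step 0
)
# trainer = DPOTrainer(model=model, args=training_args, ...)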


@@ -32,4 +32,8 @@ class DPOStrategy:
             training_args_kwargs["padding_free"] = cfg.dpo_padding_free
         if cfg.dpo_use_liger_kernel is not None:
             training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
+        if cfg.precompute_ref_log_probs is not None:
+            training_args_kwargs["precompute_ref_log_probs"] = (
+                cfg.precompute_ref_log_probs
+            )
         return training_args_kwargs
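The `is not None` guard means only options the user actually set are forwarded, so the trainer's own default (False in TRL's DPOConfig) wins whenever the config leaves the key unset. A minimal sketch of that pass-through pattern, with a hypothetical cut-down cfg and key names:

def build_training_args_kwargs(cfg: dict) -> dict:
    # Forward only keys the user set explicitly; unset (None) keys are
    # omitted so the trainer's own defaults apply downstream.
    kwargs = {}
    for key in ("padding_free", "use_liger_kernel", "precompute_ref_log_probs"):
        if cfg.get(key) is not None:
            kwargs[key] = cfg[key]
    return kwargs

assert build_training_args_kwargs({"precompute_ref_log_probs": True}) == {
    "precompute_ref_log_probs": True
}
assert build_training_args_kwargs({}) == {}  # default left to the trainer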


@@ -294,6 +294,12 @@ class AxolotlInputConfig(
         },
     )
     dpo_label_smoothing: float | None = None
+    precompute_ref_log_probs: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Precompute reference model log probabilities for DPO"
+        },
+    )
     dpo_use_liger_kernel: bool | None = Field(
         default=None,

@@ -289,6 +289,7 @@ class TestHFRLTrainerBuilder:
         # assert training_arguments.gradient_checkpointing is True

     def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
+        dpo_cfg["precompute_ref_log_probs"] = True
         builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
         training_arguments, _ = builder._build_training_arguments(100)
@@ -298,6 +299,7 @@ class TestHFRLTrainerBuilder:
         assert hasattr(training_arguments, "use_weighting")
         assert training_arguments.use_weighting is True
         assert training_arguments.label_smoothing == 0.1
+        assert training_arguments.precompute_ref_log_probs is True

     def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
         builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
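The test enables the flag on the per-test config rather than in the shared fixture default, which is the CI consideration from the last commit-message bullet: turning it on triggers a full pass over the dataset before training begins. From the user side, the new key is just another top-level config entry (sketch with a plain dict standing in for the parsed Axolotl YAML):

dpo_cfg = {
    "rl": "dpo",
    "precompute_ref_log_probs": True,  # one-time precompute pass before step 0
}
# After this change, HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
# ._build_training_arguments(100) yields training arguments with
# precompute_ref_log_probs set to True, as the test above asserts.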