Add precompute_ref_log_probs to config schema (#3555) [skip ci]
* Add precompute_ref_log_probs to config schema * chore: add description for config * Add test for precompute_ref_log_probs and move to training args * using precompute logprobs as the default slows down CI, as it has to precompute --------- Co-authored-by: NanoCode012 <nano@axolotl.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -239,10 +239,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO, RLType.EBFT)
|
||||
):
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
self.cfg.precompute_ref_log_probs
|
||||
)
|
||||
|
||||
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
|
||||
|
||||
|
||||
@@ -32,4 +32,8 @@ class DPOStrategy:
|
||||
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
||||
if cfg.dpo_use_liger_kernel is not None:
|
||||
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||
if cfg.precompute_ref_log_probs is not None:
|
||||
training_args_kwargs["precompute_ref_log_probs"] = (
|
||||
cfg.precompute_ref_log_probs
|
||||
)
|
||||
return training_args_kwargs
|
||||
|
||||
@@ -294,6 +294,12 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
dpo_label_smoothing: float | None = None
|
||||
precompute_ref_log_probs: bool | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "Precompute reference model log probabilities for DPO"
|
||||
},
|
||||
)
|
||||
|
||||
dpo_use_liger_kernel: bool | None = Field(
|
||||
default=None,
|
||||
|
||||
@@ -289,6 +289,7 @@ class TestHFRLTrainerBuilder:
|
||||
# assert training_arguments.gradient_checkpointing is True
|
||||
|
||||
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
|
||||
dpo_cfg["precompute_ref_log_probs"] = True
|
||||
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
|
||||
training_arguments, _ = builder._build_training_arguments(100)
|
||||
|
||||
@@ -298,6 +299,7 @@ class TestHFRLTrainerBuilder:
|
||||
assert hasattr(training_arguments, "use_weighting")
|
||||
assert training_arguments.use_weighting is True
|
||||
assert training_arguments.label_smoothing == 0.1
|
||||
assert training_arguments.precompute_ref_log_probs is True
|
||||
|
||||
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user