Add precompute_ref_log_probs to config schema (#3555) [skip ci]
* Add precompute_ref_log_probs to config schema * chore: add description for config * Add test for precompute_ref_log_probs and move to training args * using precompute logprobs as the default slows down CI as it has to precompute --------- Co-authored-by: NanoCode012 <nano@axolotl.ai> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -289,6 +289,7 @@ class TestHFRLTrainerBuilder:
|
||||
# assert training_arguments.gradient_checkpointing is True
|
||||
|
||||
def test_dpo_training_arguments(self, dpo_cfg, model, tokenizer):
|
||||
dpo_cfg["precompute_ref_log_probs"] = True
|
||||
builder = HFRLTrainerBuilder(dpo_cfg, model, tokenizer)
|
||||
training_arguments, _ = builder._build_training_arguments(100)
|
||||
|
||||
@@ -298,6 +299,7 @@ class TestHFRLTrainerBuilder:
|
||||
assert hasattr(training_arguments, "use_weighting")
|
||||
assert training_arguments.use_weighting is True
|
||||
assert training_arguments.label_smoothing == 0.1
|
||||
assert training_arguments.precompute_ref_log_probs is True
|
||||
|
||||
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
|
||||
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
|
||||
|
||||
Reference in New Issue
Block a user