ORPO Trainer replacement (#1551)

* WIP use trl ORPOTrainer

* fixes to make orpo work with trl

* fix the chat template laoding

* make sure to handle the special tokens and add_generation for assistant turn too
This commit is contained in:
Wing Lian
2024-04-19 17:25:36 -04:00
committed by GitHub
parent 0e8f340945
commit 7d1d22f72f
10 changed files with 151 additions and 26 deletions

View File

@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
import pytest
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
return load_model(cfg, tokenizer)
class TestHFDPOTrainerBuilder:
class TestHFRLTrainerBuilder:
"""
TestCase class for DPO trainer builder
"""
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9