ORPO Trainer replacement (#1551)
* WIP use trl ORPOTrainer * fixes to make orpo work with trl * fix the chat template laoding * make sure to handle the special tokens and add_generation for assistant turn too
This commit is contained in:
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.core.trainer_builder import HFDPOTrainerBuilder
|
||||
from axolotl.core.trainer_builder import HFRLTrainerBuilder
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
|
||||
return load_model(cfg, tokenizer)
|
||||
|
||||
|
||||
class TestHFDPOTrainerBuilder:
|
||||
class TestHFRLTrainerBuilder:
|
||||
"""
|
||||
TestCase class for DPO trainer builder
|
||||
"""
|
||||
|
||||
def test_build_training_arguments(self, cfg, model, tokenizer):
|
||||
builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
|
||||
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||
training_arguments = builder.build_training_arguments(100)
|
||||
assert training_arguments.adam_beta1 == 0.998
|
||||
assert training_arguments.adam_beta2 == 0.9
|
||||
|
||||
Reference in New Issue
Block a user