From c281c6e519926d724eec6f31ef9d65511d63cd85 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 14 May 2025 16:17:34 +0700 Subject: [PATCH] fix(test): use RLType directly to skip needing to validate --- tests/core/test_trainer_builder.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py index 42c66a608..824a7b4c7 100644 --- a/tests/core/test_trainer_builder.py +++ b/tests/core/test_trainer_builder.py @@ -11,6 +11,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer +from axolotl.utils.schemas.enums import RLType @pytest.fixture(name="base_cfg") @@ -87,7 +88,7 @@ def fixture_dpo_cfg(base_cfg): cfg = base_cfg.copy() cfg.update( { - "rl": "dpo", + "rl": RLType.DPO, "dpo_use_weighting": True, "dpo_use_logits_to_keep": True, "dpo_label_smoothing": 0.1, @@ -102,7 +103,7 @@ def fixture_orpo_cfg(base_cfg): cfg = base_cfg.copy() cfg.update( { - "rl": "orpo", + "rl": RLType.ORPO, "orpo_alpha": 0.1, "max_prompt_len": 512, } @@ -115,7 +116,7 @@ def fixture_kto_cfg(base_cfg): cfg = base_cfg.copy() cfg.update( { - "rl": "kto", + "rl": RLType.KTO, "kto_desirable_weight": 1.0, "kto_undesirable_weight": 1.0, "max_prompt_len": 512, @@ -129,7 +130,7 @@ def fixture_grpo_cfg(base_cfg): cfg = base_cfg.copy() cfg.update( { - "rl": "grpo", + "rl": RLType.GRPO, "trl": DictDefault( { "beta": 0.001, @@ -153,7 +154,7 @@ def fixture_ipo_cfg(base_cfg): cfg = base_cfg.copy() cfg.update( { - "rl": "ipo", + "rl": RLType.IPO, "dpo_label_smoothing": 0.1, "beta": 0.1, }