fix(test): use RLType directly to skip needing to validate

This commit is contained in:
NanoCode012
2025-05-14 16:17:34 +07:00
parent 06fae0d34e
commit c281c6e519

View File

@@ -11,6 +11,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.schemas.enums import RLType
@pytest.fixture(name="base_cfg")
@@ -87,7 +88,7 @@ def fixture_dpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": "dpo",
"rl": RLType.DPO,
"dpo_use_weighting": True,
"dpo_use_logits_to_keep": True,
"dpo_label_smoothing": 0.1,
@@ -102,7 +103,7 @@ def fixture_orpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": "orpo",
"rl": RLType.ORPO,
"orpo_alpha": 0.1,
"max_prompt_len": 512,
}
@@ -115,7 +116,7 @@ def fixture_kto_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": "kto",
"rl": RLType.KTO,
"kto_desirable_weight": 1.0,
"kto_undesirable_weight": 1.0,
"max_prompt_len": 512,
@@ -129,7 +130,7 @@ def fixture_grpo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": "grpo",
"rl": RLType.GRPO,
"trl": DictDefault(
{
"beta": 0.001,
@@ -153,7 +154,7 @@ def fixture_ipo_cfg(base_cfg):
cfg = base_cfg.copy()
cfg.update(
{
"rl": "ipo",
"rl": RLType.IPO,
"dpo_label_smoothing": 0.1,
"beta": 0.1,
}