fix(test): use RLType directly to skip needing to validate
This commit is contained in:
@@ -11,6 +11,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
|
||||
@pytest.fixture(name="base_cfg")
|
||||
@@ -87,7 +88,7 @@ def fixture_dpo_cfg(base_cfg):
|
||||
cfg = base_cfg.copy()
|
||||
cfg.update(
|
||||
{
|
||||
"rl": "dpo",
|
||||
"rl": RLType.DPO,
|
||||
"dpo_use_weighting": True,
|
||||
"dpo_use_logits_to_keep": True,
|
||||
"dpo_label_smoothing": 0.1,
|
||||
@@ -102,7 +103,7 @@ def fixture_orpo_cfg(base_cfg):
|
||||
cfg = base_cfg.copy()
|
||||
cfg.update(
|
||||
{
|
||||
"rl": "orpo",
|
||||
"rl": RLType.ORPO,
|
||||
"orpo_alpha": 0.1,
|
||||
"max_prompt_len": 512,
|
||||
}
|
||||
@@ -115,7 +116,7 @@ def fixture_kto_cfg(base_cfg):
|
||||
cfg = base_cfg.copy()
|
||||
cfg.update(
|
||||
{
|
||||
"rl": "kto",
|
||||
"rl": RLType.KTO,
|
||||
"kto_desirable_weight": 1.0,
|
||||
"kto_undesirable_weight": 1.0,
|
||||
"max_prompt_len": 512,
|
||||
@@ -129,7 +130,7 @@ def fixture_grpo_cfg(base_cfg):
|
||||
cfg = base_cfg.copy()
|
||||
cfg.update(
|
||||
{
|
||||
"rl": "grpo",
|
||||
"rl": RLType.GRPO,
|
||||
"trl": DictDefault(
|
||||
{
|
||||
"beta": 0.001,
|
||||
@@ -153,7 +154,7 @@ def fixture_ipo_cfg(base_cfg):
|
||||
cfg = base_cfg.copy()
|
||||
cfg.update(
|
||||
{
|
||||
"rl": "ipo",
|
||||
"rl": RLType.IPO,
|
||||
"dpo_label_smoothing": 0.1,
|
||||
"beta": 0.1,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user