fix(test): use RLType directly to skip needing to validate
This commit is contained in:
@@ -11,6 +11,7 @@ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuil
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="base_cfg")
|
@pytest.fixture(name="base_cfg")
|
||||||
@@ -87,7 +88,7 @@ def fixture_dpo_cfg(base_cfg):
|
|||||||
cfg = base_cfg.copy()
|
cfg = base_cfg.copy()
|
||||||
cfg.update(
|
cfg.update(
|
||||||
{
|
{
|
||||||
"rl": "dpo",
|
"rl": RLType.DPO,
|
||||||
"dpo_use_weighting": True,
|
"dpo_use_weighting": True,
|
||||||
"dpo_use_logits_to_keep": True,
|
"dpo_use_logits_to_keep": True,
|
||||||
"dpo_label_smoothing": 0.1,
|
"dpo_label_smoothing": 0.1,
|
||||||
@@ -102,7 +103,7 @@ def fixture_orpo_cfg(base_cfg):
|
|||||||
cfg = base_cfg.copy()
|
cfg = base_cfg.copy()
|
||||||
cfg.update(
|
cfg.update(
|
||||||
{
|
{
|
||||||
"rl": "orpo",
|
"rl": RLType.ORPO,
|
||||||
"orpo_alpha": 0.1,
|
"orpo_alpha": 0.1,
|
||||||
"max_prompt_len": 512,
|
"max_prompt_len": 512,
|
||||||
}
|
}
|
||||||
@@ -115,7 +116,7 @@ def fixture_kto_cfg(base_cfg):
|
|||||||
cfg = base_cfg.copy()
|
cfg = base_cfg.copy()
|
||||||
cfg.update(
|
cfg.update(
|
||||||
{
|
{
|
||||||
"rl": "kto",
|
"rl": RLType.KTO,
|
||||||
"kto_desirable_weight": 1.0,
|
"kto_desirable_weight": 1.0,
|
||||||
"kto_undesirable_weight": 1.0,
|
"kto_undesirable_weight": 1.0,
|
||||||
"max_prompt_len": 512,
|
"max_prompt_len": 512,
|
||||||
@@ -129,7 +130,7 @@ def fixture_grpo_cfg(base_cfg):
|
|||||||
cfg = base_cfg.copy()
|
cfg = base_cfg.copy()
|
||||||
cfg.update(
|
cfg.update(
|
||||||
{
|
{
|
||||||
"rl": "grpo",
|
"rl": RLType.GRPO,
|
||||||
"trl": DictDefault(
|
"trl": DictDefault(
|
||||||
{
|
{
|
||||||
"beta": 0.001,
|
"beta": 0.001,
|
||||||
@@ -153,7 +154,7 @@ def fixture_ipo_cfg(base_cfg):
|
|||||||
cfg = base_cfg.copy()
|
cfg = base_cfg.copy()
|
||||||
cfg.update(
|
cfg.update(
|
||||||
{
|
{
|
||||||
"rl": "ipo",
|
"rl": RLType.IPO,
|
||||||
"dpo_label_smoothing": 0.1,
|
"dpo_label_smoothing": 0.1,
|
||||||
"beta": 0.1,
|
"beta": 0.1,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user