review comments

This commit is contained in:
Dan Saunders
2025-01-10 17:27:03 +00:00
parent 2b7b37413d
commit 5ff1322f32
16 changed files with 130 additions and 158 deletions

View File

@@ -9,8 +9,8 @@ from pathlib import Path
import pytest
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_rl_datasets
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_dpo_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -65,7 +65,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -110,7 +110,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -155,7 +155,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -200,7 +200,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -244,7 +244,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -291,7 +291,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
@@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase):
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
dataset_meta = load_dpo_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)