continued cleanup and documentation
This commit is contained in:
@@ -4,8 +4,8 @@ Simple end-to-end test for Cut Cross Entropy integration
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.datasets import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import get_pytorch_version
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
@@ -64,9 +64,9 @@ class TestCutCrossEntropyIntegration:
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
with pytest.raises(ImportError):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
else:
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration:
|
||||
major, minor, _ = get_pytorch_version()
|
||||
if (major, minor) < (2, 4):
|
||||
with pytest.raises(ImportError):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
else:
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,6 +6,7 @@ from e2e.utils import require_torch_2_4_1
|
||||
|
||||
from axolotl.cli.datasets import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -60,7 +61,7 @@ class LigerIntegrationTestCase:
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@require_torch_2_4_1
|
||||
@@ -105,5 +106,5 @@ class LigerIntegrationTestCase:
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
Reference in New Issue
Block a user