continued cleanup and documentation

This commit is contained in:
Dan Saunders
2025-01-08 19:15:03 +00:00
parent 929ee15cc3
commit 6e72baf287
42 changed files with 280 additions and 300 deletions

View File

@@ -4,8 +4,8 @@ Simple end-to-end test for Cut Cross Entropy integration
import pytest
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
@@ -64,9 +64,9 @@ class TestCutCrossEntropyIntegration:
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
@@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration:
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,6 +6,7 @@ from e2e.utils import require_torch_2_4_1
from axolotl.cli.datasets import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
@@ -60,7 +61,7 @@ class LigerIntegrationTestCase:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@require_torch_2_4_1
@@ -105,5 +106,5 @@ class LigerIntegrationTestCase:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)