add helper to verify the correct model output file exists (#2245)

* add helper to verify the correct model output file exists

* more checks using helper

* chore: lint

* fix import and relora model check

* workaround for trl trainer saves

* remove stray print
This commit is contained in:
Wing Lian
2025-01-13 10:43:29 -05:00
committed by GitHub
parent d8b4027200
commit dd26cc3c0f
29 changed files with 116 additions and 111 deletions

View File

@@ -2,8 +2,6 @@
Simple end-to-end test for Cut Cross Entropy integration
"""
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
@@ -13,6 +11,8 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
# pylint: disable=duplicate-code
@@ -67,7 +67,7 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
"attention_type",
@@ -95,4 +95,4 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_model_output_exists(temp_dir, cfg)