add helper to verify the correct model output file exists (#2245)
* add helper to verify the correct model output file exists * more checks using helper * chore: lint * fix import and relora model check * workaround for trl trainer saves * remove stray print
This commit is contained in:
@@ -4,7 +4,8 @@ E2E tests for llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from e2e.utils import check_model_output_exists
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -60,7 +61,7 @@ class TestLlama:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -103,7 +104,7 @@ class TestLlama:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_batch_flattening(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -142,4 +143,4 @@ class TestLlama:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
Reference in New Issue
Block a user