add helper to verify the correct model output file exists (#2245)

* add helper to verify the correct model output file exists

* more checks using helper

* chore: lint

* fix import and relora model check

* workaround for trl trainer saves

* remove stray print
This commit is contained in:
Wing Lian
2025-01-13 10:43:29 -05:00
committed by GitHub
parent d8b4027200
commit dd26cc3c0f
29 changed files with 116 additions and 111 deletions

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_tensorboard, with_temp_dir
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -78,10 +78,10 @@ class TestReLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
assert (
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
).exists()
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
).exists(), "Relora model checkpoint not found"
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"