add helper to verify the correct model output file exists (#2245)

* add helper to verify the correct model output file exists

* more checks using helper

* chore: lint

* fix import and relora model check

* workaround for trl trainer saves

* remove stray print
This commit is contained in:
Wing Lian
2025-01-13 10:43:29 -05:00
committed by GitHub
parent d8b4027200
commit dd26cc3c0f
29 changed files with 116 additions and 111 deletions

View File

@@ -14,6 +14,8 @@ import torch
from packaging import version
from tbparse import SummaryReader
from axolotl.utils.dict import DictDefault
def with_temp_dir(test_func):
@wraps(test_func)
@@ -93,3 +95,27 @@ def check_tensorboard(
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
assert df.value.values[-1] < lt_val, assertion_err
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
"""
helper function to check if a model output file exists after training
checks based on adapter or not and if safetensors saves are enabled or not
"""
if cfg.save_safetensors:
if not cfg.adapter:
assert (Path(temp_dir) / "model.safetensors").exists()
else:
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
else:
# check for both, b/c in trl, it often defaults to saving safetensors
if not cfg.adapter:
assert (Path(temp_dir) / "pytorch_model.bin").exists() or (
Path(temp_dir) / "model.safetensors"
).exists()
else:
assert (Path(temp_dir) / "adapter_model.bin").exists() or (
Path(temp_dir) / "adapter_model.safetensors"
).exists()