add helper to verify the correct model output file exists (#2245)
* add helper to verify the correct model output file exists * more checks using helper * chore: lint * fix import and relora model check * workaround for trl trainer saves * remove stray print
This commit is contained in:
@@ -14,6 +14,8 @@ import torch
|
||||
from packaging import version
|
||||
from tbparse import SummaryReader
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def with_temp_dir(test_func):
|
||||
@wraps(test_func)
|
||||
@@ -93,3 +95,27 @@ def check_tensorboard(
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < lt_val, assertion_err
|
||||
|
||||
|
||||
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
|
||||
"""
|
||||
helper function to check if a model output file exists after training
|
||||
|
||||
checks based on adapter or not and if safetensors saves are enabled or not
|
||||
"""
|
||||
|
||||
if cfg.save_safetensors:
|
||||
if not cfg.adapter:
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
else:
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
else:
|
||||
# check for both, b/c in trl, it often defaults to saving safetensors
|
||||
if not cfg.adapter:
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists() or (
|
||||
Path(temp_dir) / "model.safetensors"
|
||||
).exists()
|
||||
else:
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists() or (
|
||||
Path(temp_dir) / "adapter_model.safetensors"
|
||||
).exists()
|
||||
|
||||
Reference in New Issue
Block a user