add helper to verify the correct model output file exists (#2245)
* add helper to verify the correct model output file exists * more checks using helper * chore: lint * fix import and relora model check * workaround for trl trainer saves * remove stray print
This commit is contained in:
@@ -13,7 +13,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -78,10 +78,10 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
||||
assert (
|
||||
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
|
||||
).exists()
|
||||
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
|
||||
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
|
||||
).exists(), "Relora model checkpoint not found"
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
|
||||
|
||||
Reference in New Issue
Block a user