fix adapter model check

This commit is contained in:
Wing Lian
2025-01-08 20:15:42 -05:00
parent 0da0cd02e5
commit c45ab03487

View File

@@ -114,7 +114,7 @@ class TestKnowledgeDistillation:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)