diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index a7e9ee494..de6b1dcac 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -114,7 +114,7 @@ class TestKnowledgeDistillation: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) - assert (Path(temp_dir) / "model.safetensors").exists() + assert (Path(temp_dir) / "adapter_model.safetensors").exists() check_tensorboard( temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high" )