lint and additional train metric checks for kd

This commit is contained in:
Wing Lian
2025-02-26 03:19:42 -05:00
parent afbb44f08b
commit 23f029a89c
5 changed files with 35 additions and 4 deletions

View File

@@ -90,6 +90,12 @@ class TestKnowledgeDistillation:
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
)
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
)
@pytest.mark.parametrize(
"load_in_8bit",
@@ -121,3 +127,9 @@ class TestKnowledgeDistillation:
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
check_tensorboard(
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
)
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
)