Lint fixes and additional train metric checks for knowledge distillation (KD)
This commit is contained in:
@@ -90,6 +90,12 @@ class TestKnowledgeDistillation:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"load_in_8bit",
|
||||
@@ -121,3 +127,9 @@ class TestKnowledgeDistillation:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
|
||||
)
|
||||
|
||||
@@ -27,6 +27,7 @@ def test_kl_loss_gradient():
|
||||
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
|
||||
|
||||
# Generate random target token IDs, ensuring they're valid indices
|
||||
# pylint: disable=duplicate-code
|
||||
target_token_ids = torch.randint(
|
||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user