Lint fixes and additional train metric checks for knowledge distillation (KD)

This commit is contained in:
Wing Lian
2025-02-26 03:19:42 -05:00
parent afbb44f08b
commit 23f029a89c
5 changed files with 35 additions and 4 deletions

View File

@@ -32,6 +32,7 @@ def benchmark_kl_div_loss_with_backward():
 student_logits = torch.randn(
 batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
 )
+# pylint: disable=duplicate-code
 target_token_ids = torch.randint(
 0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
 )
@@ -56,6 +57,7 @@ def benchmark_kl_div_loss_with_backward():
 def run_triton():
 # Forward pass
+# pylint: disable=duplicate-code
 loss_triton = triton_loss(
 student_logits_triton,
 target_token_ids,
@@ -132,6 +134,7 @@ def benchmark_forward_backward_separately():
 student_logits = torch.randn(
 batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
 )
+# pylint: disable=duplicate-code
 target_token_ids = torch.randint(
 0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
 )

View File

@@ -422,6 +422,7 @@ class TopKKLDivergence(torch.autograd.Function):
 kd_loss = token_losses.sum()
 # Apply temperature scaling
+# pylint: disable=duplicate-code
 if kd_temperature != 1.0:
 kd_loss = kd_loss * (kd_temperature**2)

View File

@@ -90,6 +90,12 @@ class TestKnowledgeDistillation:
 check_tensorboard(
 temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
 )
+check_tensorboard(
+temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
+)
+check_tensorboard(
+temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
+)
 @pytest.mark.parametrize(
 "load_in_8bit",
@@ -121,3 +127,9 @@ class TestKnowledgeDistillation:
 check_tensorboard(
 temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
 )
+check_tensorboard(
+temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
+)
+check_tensorboard(
+temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
+)

View File

@@ -27,6 +27,7 @@ def test_kl_loss_gradient():
 student_logits_triton = student_logits.detach().clone().requires_grad_(True)
 # Generate random target token IDs, ensuring they're valid indices
+# pylint: disable=duplicate-code
 target_token_ids = torch.randint(
 0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
 )

View File

@@ -102,7 +102,11 @@ def is_hopper():
 def check_tensorboard(
-temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
+temp_run_dir: str,
+tag: str,
+comparison_val: float,
+assertion_err: str,
+lt: bool = True,
 ) -> None:
 """
 helper function to parse and check tensorboard logs
@@ -112,10 +116,20 @@ def check_tensorboard(
 reader = SummaryReader(event_file)
 df = reader.scalars # pylint: disable=invalid-name
 df = df[(df.tag == tag)] # pylint: disable=invalid-name
-if "%s" in assertion_err:
-assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
+if lt:
+if "%s" in assertion_err:
+assert df.value.values[-1] < comparison_val, (
+assertion_err % df.value.values[-1]
+)
+else:
+assert df.value.values[-1] < comparison_val, assertion_err
 else:
-assert df.value.values[-1] < lt_val, assertion_err
+if "%s" in assertion_err:
+assert df.value.values[-1] > comparison_val, (
+assertion_err % df.value.values[-1]
+)
+else:
+assert df.value.values[-1] > comparison_val, assertion_err
 def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None: