From 23f029a89c6734da68efa5ad3058f865049fcceb Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 26 Feb 2025 03:19:42 -0500
Subject: [PATCH] lint and additional train metric checks for kd

---
 .../integrations/kd/topk_logprob/bench_kl.py  |  3 +++
 .../kd/topk_logprob/forward_kl_triton.py      |  1 +
 tests/e2e/integrations/test_kd.py             | 12 ++++++++++
 tests/e2e/integrations/test_kl_loss.py        |  1 +
 tests/e2e/utils.py                            | 22 +++++++++++++++----
 5 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/integrations/kd/topk_logprob/bench_kl.py b/src/axolotl/integrations/kd/topk_logprob/bench_kl.py
index d67320bb7..01fa223fd 100644
--- a/src/axolotl/integrations/kd/topk_logprob/bench_kl.py
+++ b/src/axolotl/integrations/kd/topk_logprob/bench_kl.py
@@ -32,6 +32,7 @@ def benchmark_kl_div_loss_with_backward():
     student_logits = torch.randn(
         batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
     )
+    # pylint: disable=duplicate-code
     target_token_ids = torch.randint(
         0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
     )
@@ -56,6 +57,7 @@ def benchmark_kl_div_loss_with_backward():
 
     def run_triton():
         # Forward pass
+        # pylint: disable=duplicate-code
         loss_triton = triton_loss(
             student_logits_triton,
             target_token_ids,
@@ -132,6 +134,7 @@ def benchmark_forward_backward_separately():
     student_logits = torch.randn(
         batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
     )
+    # pylint: disable=duplicate-code
     target_token_ids = torch.randint(
         0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
     )
diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
index f921f80bb..e79d799df 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
@@ -422,6 +422,7 @@ class TopKKLDivergence(torch.autograd.Function):
         kd_loss = token_losses.sum()
 
         # Apply temperature scaling
+        # pylint: disable=duplicate-code
         if kd_temperature != 1.0:
             kd_loss = kd_loss * (kd_temperature**2)
 
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index a90b48d67..84da1aa38 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -90,6 +90,12 @@ class TestKnowledgeDistillation:
         check_tensorboard(
             temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
         )
+        check_tensorboard(
+            temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
+        )
+        check_tensorboard(
+            temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
+        )
     @pytest.mark.parametrize(
         "load_in_8bit",
     )
@@ -121,3 +127,9 @@ class TestKnowledgeDistillation:
         check_tensorboard(
             temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
         )
+        check_tensorboard(
+            temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
+        )
+        check_tensorboard(
+            temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
+        )
diff --git a/tests/e2e/integrations/test_kl_loss.py b/tests/e2e/integrations/test_kl_loss.py
index 52a789166..0a1edd899 100644
--- a/tests/e2e/integrations/test_kl_loss.py
+++ b/tests/e2e/integrations/test_kl_loss.py
@@ -27,6 +27,7 @@ def test_kl_loss_gradient():
     student_logits_triton = student_logits.detach().clone().requires_grad_(True)
 
     # Generate random target token IDs, ensuring they're valid indices
+    # pylint: disable=duplicate-code
    target_token_ids = torch.randint(
         0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
     )
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index ff96f1f58..3ca1ea479 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -102,7 +102,11 @@ def is_hopper():
 
 
 def check_tensorboard(
-    temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
+    temp_run_dir: str,
+    tag: str,
+    comparison_val: float,
+    assertion_err: str,
+    lt: bool = True,
 ) -> None:
     """
     helper function to parse and check tensorboard logs
@@ -112,10 +116,20 @@ def check_tensorboard(
     reader = SummaryReader(event_file)
     df = reader.scalars  # pylint: disable=invalid-name
     df = df[(df.tag == tag)]  # pylint: disable=invalid-name
-    if "%s" in assertion_err:
-        assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
+    if lt:
+        if "%s" in assertion_err:
+            assert df.value.values[-1] < comparison_val, (
+                assertion_err % df.value.values[-1]
+            )
+        else:
+            assert df.value.values[-1] < comparison_val, assertion_err
     else:
-        assert df.value.values[-1] < lt_val, assertion_err
+        if "%s" in assertion_err:
+            assert df.value.values[-1] > comparison_val, (
+                assertion_err % df.value.values[-1]
+            )
+        else:
+            assert df.value.values[-1] > comparison_val, assertion_err
 
 
 def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
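A minimal usage sketch of the reworked helper, mirroring the assertions this
patch adds in test_kd.py. It assumes check_tensorboard is imported from
tests/e2e/utils.py and that temp_dir points at a finished training run whose
tensorboard scalars live under temp_dir + "/runs"; both names are
illustrative, not part of the patch itself:

    from tests.e2e.utils import check_tensorboard

    # Default lt=True keeps the original behavior: the last logged scalar for
    # the tag must be strictly below the threshold.
    check_tensorboard(temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high")

    # lt=False flips the comparison into a lower bound: the final train/loss
    # must stay strictly above 0.0, e.g. guarding against a degenerate zero loss.
    check_tensorboard(temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False)

    # A "%s" in the message is interpolated with the offending value
    # (assertion_err % df.value.values[-1]) to make failures easier to read.
    check_tensorboard(temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm %s is too high")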