lint and additional train metric checks for kd
This commit is contained in:
@@ -32,6 +32,7 @@ def benchmark_kl_div_loss_with_backward():
|
|||||||
student_logits = torch.randn(
|
student_logits = torch.randn(
|
||||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||||
)
|
)
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
target_token_ids = torch.randint(
|
target_token_ids = torch.randint(
|
||||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||||
)
|
)
|
||||||
@@ -56,6 +57,7 @@ def benchmark_kl_div_loss_with_backward():
|
|||||||
|
|
||||||
def run_triton():
|
def run_triton():
|
||||||
# Forward pass
|
# Forward pass
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
loss_triton = triton_loss(
|
loss_triton = triton_loss(
|
||||||
student_logits_triton,
|
student_logits_triton,
|
||||||
target_token_ids,
|
target_token_ids,
|
||||||
@@ -132,6 +134,7 @@ def benchmark_forward_backward_separately():
|
|||||||
student_logits = torch.randn(
|
student_logits = torch.randn(
|
||||||
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
batch_size, seq_len, vocab_size, device="cuda", requires_grad=True
|
||||||
)
|
)
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
target_token_ids = torch.randint(
|
target_token_ids = torch.randint(
|
||||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -422,6 +422,7 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
kd_loss = token_losses.sum()
|
kd_loss = token_losses.sum()
|
||||||
|
|
||||||
# Apply temperature scaling
|
# Apply temperature scaling
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
kd_loss = kd_loss * (kd_temperature**2)
|
kd_loss = kd_loss * (kd_temperature**2)
|
||||||
|
|
||||||
|
|||||||
@@ -90,6 +90,12 @@ class TestKnowledgeDistillation:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
|
||||||
|
)
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"load_in_8bit",
|
"load_in_8bit",
|
||||||
@@ -121,3 +127,9 @@ class TestKnowledgeDistillation:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
|
||||||
)
|
)
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/loss", 0.0, "Train Loss is too low", lt=False
|
||||||
|
)
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/grad_norm", 8.0, "Train grad norm is too high"
|
||||||
|
)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ def test_kl_loss_gradient():
|
|||||||
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
|
student_logits_triton = student_logits.detach().clone().requires_grad_(True)
|
||||||
|
|
||||||
# Generate random target token IDs, ensuring they're valid indices
|
# Generate random target token IDs, ensuring they're valid indices
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
target_token_ids = torch.randint(
|
target_token_ids = torch.randint(
|
||||||
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
0, vocab_size, (batch_size, seq_len, top_k), device="cuda"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -102,7 +102,11 @@ def is_hopper():
|
|||||||
|
|
||||||
|
|
||||||
def check_tensorboard(
|
def check_tensorboard(
|
||||||
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
temp_run_dir: str,
|
||||||
|
tag: str,
|
||||||
|
comparison_val: float,
|
||||||
|
assertion_err: str,
|
||||||
|
lt: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
helper function to parse and check tensorboard logs
|
helper function to parse and check tensorboard logs
|
||||||
@@ -112,10 +116,20 @@ def check_tensorboard(
|
|||||||
reader = SummaryReader(event_file)
|
reader = SummaryReader(event_file)
|
||||||
df = reader.scalars # pylint: disable=invalid-name
|
df = reader.scalars # pylint: disable=invalid-name
|
||||||
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||||
if "%s" in assertion_err:
|
if lt:
|
||||||
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
|
if "%s" in assertion_err:
|
||||||
|
assert df.value.values[-1] < comparison_val, (
|
||||||
|
assertion_err % df.value.values[-1]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert df.value.values[-1] < comparison_val, assertion_err
|
||||||
else:
|
else:
|
||||||
assert df.value.values[-1] < lt_val, assertion_err
|
if "%s" in assertion_err:
|
||||||
|
assert df.value.values[-1] > comparison_val, (
|
||||||
|
assertion_err % df.value.values[-1]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert df.value.values[-1] > comparison_val, assertion_err
|
||||||
|
|
||||||
|
|
||||||
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
|
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user