chore: lint

This commit is contained in:
Wing Lian
2025-02-25 07:29:46 -05:00
parent e82268e580
commit a2e52a29e9

View File

@@ -338,6 +338,7 @@ class TopKKLDivergence(torch.autograd.Function):
)
kd_loss = token_losses.sum()
# pylint: disable=duplicate-code
# Apply temperature scaling
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
@@ -426,7 +427,7 @@ class TopKKLDivergence(torch.autograd.Function):
target_mask.stride(0),
target_mask.stride(1),
target_mask.stride(2),
min(256, triton.next_power_of_2(top_k)),
min(1024, triton.next_power_of_2(top_k)),
)
else:
# Case 2: Softmax over full vocab
@@ -457,7 +458,7 @@ class TopKKLDivergence(torch.autograd.Function):
target_mask.stride(0),
target_mask.stride(1),
target_mask.stride(2),
min(256, triton.next_power_of_2(top_k)),
min(1024, triton.next_power_of_2(top_k)),
)
# Return gradients for student_logits and None for other inputs