chore: lint
This commit is contained in:
@@ -338,6 +338,7 @@ class TopKKLDivergence(torch.autograd.Function):
|
||||
)
|
||||
kd_loss = token_losses.sum()
|
||||
|
||||
+# pylint: disable=duplicate-code
|
||||
# Apply temperature scaling
|
||||
if kd_temperature != 1.0:
|
||||
kd_loss = kd_loss * (kd_temperature**2)
|
||||
@@ -426,7 +427,7 @@ class TopKKLDivergence(torch.autograd.Function):
|
||||
target_mask.stride(0),
|
||||
target_mask.stride(1),
|
||||
target_mask.stride(2),
|
||||
-min(256, triton.next_power_of_2(top_k)),
|
||||
+min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
else:
|
||||
# Case 2: Softmax over full vocab
|
||||
@@ -457,7 +458,7 @@ class TopKKLDivergence(torch.autograd.Function):
|
||||
target_mask.stride(0),
|
||||
target_mask.stride(1),
|
||||
target_mask.stride(2),
|
||||
-min(256, triton.next_power_of_2(top_k)),
|
||||
+min(1024, triton.next_power_of_2(top_k)),
|
||||
)
|
||||
|
||||
# Return gradients for student_logits and None for other inputs
|
||||
|
||||
Reference in New Issue
Block a user