chore: lint
This commit is contained in:
@@ -338,6 +338,7 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
)
|
)
|
||||||
kd_loss = token_losses.sum()
|
kd_loss = token_losses.sum()
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
# Apply temperature scaling
|
# Apply temperature scaling
|
||||||
if kd_temperature != 1.0:
|
if kd_temperature != 1.0:
|
||||||
kd_loss = kd_loss * (kd_temperature**2)
|
kd_loss = kd_loss * (kd_temperature**2)
|
||||||
@@ -426,7 +427,7 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
target_mask.stride(0),
|
target_mask.stride(0),
|
||||||
target_mask.stride(1),
|
target_mask.stride(1),
|
||||||
target_mask.stride(2),
|
target_mask.stride(2),
|
||||||
min(256, triton.next_power_of_2(top_k)),
|
min(1024, triton.next_power_of_2(top_k)),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Case 2: Softmax over full vocab
|
# Case 2: Softmax over full vocab
|
||||||
@@ -457,7 +458,7 @@ class TopKKLDivergence(torch.autograd.Function):
|
|||||||
target_mask.stride(0),
|
target_mask.stride(0),
|
||||||
target_mask.stride(1),
|
target_mask.stride(1),
|
||||||
target_mask.stride(2),
|
target_mask.stride(2),
|
||||||
min(256, triton.next_power_of_2(top_k)),
|
min(1024, triton.next_power_of_2(top_k)),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return gradients for student_logits and None for other inputs
|
# Return gradients for student_logits and None for other inputs
|
||||||
|
|||||||
Reference in New Issue
Block a user