diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
index 12bdda272..b1ec058d4 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl_triton.py
@@ -338,6 +338,7 @@ class TopKKLDivergence(torch.autograd.Function):
         )
         kd_loss = token_losses.sum()
 
+        # pylint: disable=duplicate-code
         # Apply temperature scaling
         if kd_temperature != 1.0:
             kd_loss = kd_loss * (kd_temperature**2)
@@ -426,7 +427,7 @@ class TopKKLDivergence(torch.autograd.Function):
                 target_mask.stride(0),
                 target_mask.stride(1),
                 target_mask.stride(2),
-                min(256, triton.next_power_of_2(top_k)),
+                min(1024, triton.next_power_of_2(top_k)),
             )
         else:
             # Case 2: Softmax over full vocab
@@ -457,7 +458,7 @@ class TopKKLDivergence(torch.autograd.Function):
                 target_mask.stride(0),
                 target_mask.stride(1),
                 target_mask.stride(2),
-                min(256, triton.next_power_of_2(top_k)),
+                min(1024, triton.next_power_of_2(top_k)),
             )
 
         # Return gradients for student_logits and None for other inputs
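
Reviewer note on the second and third hunks: the final positional argument at each kernel launch is the Triton block size, computed by rounding `top_k` up to the next power of two and clamping it, and this diff raises the clamp from 256 to 1024. Below is a minimal sketch of that computation; `pick_block_size` is a hypothetical helper name introduced here for illustration (the diff inlines the expression at both launch sites).

```python
import triton


def pick_block_size(top_k: int, cap: int = 1024) -> int:
    """Round top_k up to the next power of two, then clamp it to `cap`.

    Triton typically requires power-of-two extents (e.g. for tl.arange),
    so the block size is rounded up rather than set to top_k directly.
    The cap bounds per-program resource usage for very large top_k; this
    diff raises it from 256 to 1024.
    """
    return min(cap, triton.next_power_of_2(top_k))


# e.g. top_k=64 -> 64, top_k=300 -> 512, top_k=4096 -> 1024 (clamped)
```

The practical effect is that a `top_k` between 257 and 1024 now maps to a matching power-of-two block size (e.g. `top_k=512` yields 512) instead of being clamped to 256.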