diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py index 2aea80578..74c97897b 100644 --- a/src/axolotl/integrations/kd/kernels/liger.py +++ b/src/axolotl/integrations/kd/kernels/liger.py @@ -337,7 +337,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc) # For matching None returns in backward for non-tensor/non-grad_requiring inputs - ctx.hyperparams_count = 8 # Corresponds to number of hyperparams after main tensors in fwd signature + ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature ctx.bias_was_none = student_lm_head_bias is None ctx.orig_dims = (B, N, D, K)