increase hyperparams_count for gradients for added normalize_topk

This commit is contained in:
Wing Lian
2025-05-31 08:42:06 -04:00
parent d55a51623f
commit a8e2bddd19

View File

@@ -337,7 +337,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
ctx.hyperparams_count = 8 # Corresponds to number of hyperparams after main tensors in fwd signature
ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)