Increase hyperparams_count from 8 to 9 so backward returns a matching None gradient for the added normalize_topk hyperparameter
This commit is contained in:
@@ -337,7 +337,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
||||||
|
|
||||||
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
||||||
ctx.hyperparams_count = 8 # Corresponds to number of hyperparams after main tensors in fwd signature
|
ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
|
||||||
ctx.bias_was_none = student_lm_head_bias is None
|
ctx.bias_was_none = student_lm_head_bias is None
|
||||||
ctx.orig_dims = (B, N, D, K)
|
ctx.orig_dims = (B, N, D, K)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user