increase hyperparams_count for gradients for added normalize_topk
This commit is contained in:
@@ -337,7 +337,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
||||
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
||||
|
||||
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
||||
ctx.hyperparams_count = 8 # Corresponds to number of hyperparams after main tensors in fwd signature
|
||||
ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
|
||||
ctx.bias_was_none = student_lm_head_bias is None
|
||||
ctx.orig_dims = (B, N, D, K)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user