From 23d7ae6caa97193324774ab94e76c5dee5ee52b7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 21 Dec 2024 14:32:11 -0500 Subject: [PATCH] fix kwarg --- src/axolotl/integrations/kd/kernels/kd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/integrations/kd/kernels/kd.py b/src/axolotl/integrations/kd/kernels/kd.py index f45717a01..26cf76aac 100644 --- a/src/axolotl/integrations/kd/kernels/kd.py +++ b/src/axolotl/integrations/kd/kernels/kd.py @@ -235,7 +235,7 @@ class _KLDivergenceTritonFn(torch.autograd.Function): kd_loss = kl_val.sum() # now compute dLoss/d stl - grad_stl = torch.autograd.grad(kd_loss, stl, grad_output=grad_output)[0] + grad_stl = torch.autograd.grad(kd_loss, stl, grad_outputs=grad_output)[0] return grad_stl, None, None