fix kwarg

This commit is contained in:
Wing Lian
2024-12-21 14:32:11 -05:00
parent e565694914
commit c0757e8a20

View File

@@ -235,7 +235,7 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
kd_loss = kl_val.sum()
# now compute dLoss/d stl
grad_stl = torch.autograd.grad(kd_loss, stl, grad_output=grad_output)[0]
grad_stl = torch.autograd.grad(kd_loss, stl, grad_outputs=grad_output)[0]
return grad_stl, None, None