fix kwarg
This commit is contained in:
@@ -235,7 +235,7 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
|
|||||||
|
|
||||||
kd_loss = kl_val.sum()
|
kd_loss = kl_val.sum()
|
||||||
# now compute dLoss/d stl
|
# now compute dLoss/d stl
|
||||||
grad_stl = torch.autograd.grad(kd_loss, stl, grad_output=grad_output)[0]
|
grad_stl = torch.autograd.grad(kd_loss, stl, grad_outputs=grad_output)[0]
|
||||||
|
|
||||||
return grad_stl, None, None
|
return grad_stl, None, None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user