fix kwarg
This commit is contained in:
@@ -235,7 +235,7 @@ class _KLDivergenceTritonFn(torch.autograd.Function):
         kd_loss = kl_val.sum()
         # now compute dLoss/d stl
-        grad_stl = torch.autograd.grad(kd_loss, stl, grad_output=grad_output)[0]
+        grad_stl = torch.autograd.grad(kd_loss, stl, grad_outputs=grad_output)[0]

         return grad_stl, None, None
Reference in New Issue
Block a user