dtype
This commit is contained in:
@@ -135,11 +135,11 @@ def _kernel_cg_backward_dw(
|
||||
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||
)
|
||||
mask_go = mask_m[:, None] & mask_n[None, :]
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0)
|
||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||
|
||||
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||
mask_in = mask_m[:, None] & mask_k[None, :]
|
||||
inp = tl.load(in_ptrs, mask=mask_in, other=0.0)
|
||||
inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32)
|
||||
|
||||
go_t = tl.trans(go)
|
||||
grad_weights += tl.dot(go_t, inp)
|
||||
|
||||
Reference in New Issue
Block a user