dtype
This commit is contained in:
@@ -135,11 +135,11 @@ def _kernel_cg_backward_dw(
|
|||||||
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
|
||||||
)
|
)
|
||||||
mask_go = mask_m[:, None] & mask_n[None, :]
|
mask_go = mask_m[:, None] & mask_n[None, :]
|
||||||
go = tl.load(go_ptrs, mask=mask_go, other=0.0)
|
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
|
||||||
|
|
||||||
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
|
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
|
||||||
mask_in = mask_m[:, None] & mask_k[None, :]
|
mask_in = mask_m[:, None] & mask_k[None, :]
|
||||||
inp = tl.load(in_ptrs, mask=mask_in, other=0.0)
|
inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32)
|
||||||
|
|
||||||
go_t = tl.trans(go)
|
go_t = tl.trans(go)
|
||||||
grad_weights += tl.dot(go_t, inp)
|
grad_weights += tl.dot(go_t, inp)
|
||||||
|
|||||||
Reference in New Issue
Block a user