diff --git a/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py b/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py index 514211d50..2c1b62f54 100644 --- a/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py +++ b/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py @@ -135,11 +135,11 @@ def _kernel_cg_backward_dw( grad_output_ptr + offs_m[:, None] * N + offs_n[None, :] ) mask_go = mask_m[:, None] & mask_n[None, :] - go = tl.load(go_ptrs, mask=mask_go, other=0.0) + go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32) in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :] mask_in = mask_m[:, None] & mask_k[None, :] - inp = tl.load(in_ptrs, mask=mask_in, other=0.0) + inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32) go_t = tl.trans(go) grad_weights += tl.dot(go_t, inp)