From fd312f60585f41a372e598a4174c150ae3e3c7df Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Tue, 23 Sep 2025 12:20:39 -0400 Subject: [PATCH] dtype --- src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py b/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py index 514211d50..2c1b62f54 100644 --- a/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py +++ b/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py @@ -135,11 +135,11 @@ def _kernel_cg_backward_dw( grad_output_ptr + offs_m[:, None] * N + offs_n[None, :] ) mask_go = mask_m[:, None] & mask_n[None, :] - go = tl.load(go_ptrs, mask=mask_go, other=0.0) + go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32) in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :] mask_in = mask_m[:, None] & mask_k[None, :] - inp = tl.load(in_ptrs, mask=mask_in, other=0.0) + inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32) go_t = tl.trans(go) grad_weights += tl.dot(go_t, inp)