diff --git a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py
index d605b652d..2a221c13c 100644
--- a/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py
+++ b/src/axolotl/integrations/kernels/libs/scattermoe_lora/kernels/lora_ops.py
@@ -313,12 +313,11 @@ def _compute_expert_block_lora(
         B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
     )  # [BLOCK_N, BLOCK_R]
 
-    # Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
-    # Both operands must match; cast to float32 (accumulator type) for precision.
-    b_f32 = b.to(tl.float32)
+    # tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
+    b_inp = b.to(INPUT_DTYPE)
 
     # (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
-    lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
+    lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
 
     acc += scaling * lora_out
     return acc
@@ -867,13 +866,13 @@ def _compute_expert_block_lora_dX(
         + (A_expert_offset + R_block)[:, None] * stride_ar
         + K_block[None, :] * stride_ak
     )
-    a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
-
-    # Cast to float32 for precision
-    a_f32 = a_e.to(tl.float32)
+    a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
+        INPUT_DTYPE
+    )
 
     # (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
-    lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
+    # tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
+    lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
 
     acc += scaling * lora_dx
     return acc
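
The pattern behind both hunks: the first `tl.dot` already accumulates in float32 for precision, so its output is down-cast to the input dtype (bf16/fp16) before feeding the second `tl.dot`, letting that GEMM take the low-precision tensor-core path instead of an fp32/bf16 mixed-dtype dot. Below is a minimal, self-contained sketch of that pattern under assumed names; `lora_fused_kernel` and its shapes are hypothetical and are not the Axolotl kernel, and `INPUT_DTYPE` is approximated here via the pointer's element type.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def lora_fused_kernel(
    x_ptr, a_ptr, b_ptr, out_ptr,
    scaling,
    M: tl.constexpr, K: tl.constexpr, R: tl.constexpr, N: tl.constexpr,
    allow_tf32: tl.constexpr,
):
    # Single-block sketch: M, K, R, N are powers of two >= 16 so tl.dot is legal.
    m = tl.arange(0, M)
    k = tl.arange(0, K)
    r = tl.arange(0, R)
    n = tl.arange(0, N)

    x = tl.load(x_ptr + m[:, None] * K + k[None, :])  # [M, K], input dtype (e.g. bf16)
    a = tl.load(a_ptr + r[:, None] * K + k[None, :])  # [R, K]
    b = tl.load(b_ptr + n[:, None] * R + r[None, :])  # [N, R]

    # First GEMM: X @ A^T. tl.dot accumulates in float32, so xa_acc is fp32.
    xa_acc = tl.dot(x, tl.trans(a), allow_tf32=allow_tf32)  # [M, R]

    # Cast the fp32 accumulator back to the input dtype before the second GEMM
    # (the role INPUT_DTYPE plays in the diff), so tl.dot uses the
    # low-precision tensor-core path rather than a mixed fp32/bf16 dot.
    in_dtype = x_ptr.dtype.element_ty
    lora_out = tl.dot(xa_acc.to(in_dtype), tl.trans(b), allow_tf32=allow_tf32)  # [M, N], fp32

    out = scaling * lora_out
    tl.store(out_ptr + m[:, None] * N + n[None, :], out.to(in_dtype))


if __name__ == "__main__":
    # Hypothetical launch on one block; requires a CUDA GPU.
    M, K, R, N = 32, 64, 16, 32
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    a = torch.randn(R, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(N, R, device="cuda", dtype=torch.bfloat16)
    out = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
    lora_fused_kernel[(1,)](x, a, b, out, 0.5, M=M, K=K, R=R, N=N, allow_tf32=True)
    ref = 0.5 * (x.float() @ a.float().t()) @ b.float().t()
    print(torch.allclose(out.float(), ref, atol=1e-1, rtol=1e-1))
```

The dX hunk follows the same recipe in the backward direction: the fp32 accumulator of `DY @ B` is cast back to the input dtype before the `[M, R] @ [R, K]` dot with `A`, trading a small amount of precision in the second GEMM for the tensor-core fast path.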