From bbf1f14ca44e6b7188ed6531c43c52b1354226e1 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 17 Sep 2025 23:52:18 +0000 Subject: [PATCH] dtype issues --- src/axolotl/kernels/moe/torch_grouped.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 0abebc664..fbb6f9b4d 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -66,7 +66,7 @@ def _call_grouped_mm( Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs] device = As2[0].device lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32) - offsets = torch.cumsum(lengths, dim=0) + offsets = torch.cumsum(lengths, dim=0).to(torch.int32) Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets) outs: List[torch.Tensor] = [] start = 0 @@ -90,8 +90,8 @@ def moe_ffn_forward_grouped( device = hidden_states.device routing_dtype = gate_linear.weight.dtype - expert_dtype = hidden_states.dtype + if expert_dtype not in (torch.bfloat16, torch.float16): _LOGGER.debug( "torch_grouped: unsupported expert dtype %s; falling back to naive", @@ -165,6 +165,9 @@ def moe_ffn_forward_grouped( buf[sel] = tensor combined = ( - buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1) - ).sum(dim=1) + (buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1)) + .sum(dim=1) + .to(torch.bfloat16) + ) + return combined.view(bsz, seqlen, hdim), router_logits