dtype issues

This commit is contained in:
Dan Saunders
2025-09-17 23:52:18 +00:00
parent c6878beb7d
commit bbf1f14ca4

View File

@@ -66,7 +66,7 @@ def _call_grouped_mm(
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
device = As2[0].device
lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
offsets = torch.cumsum(lengths, dim=0)
offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
outs: List[torch.Tensor] = []
start = 0
@@ -90,8 +90,8 @@ def moe_ffn_forward_grouped(
device = hidden_states.device
routing_dtype = gate_linear.weight.dtype
expert_dtype = hidden_states.dtype
if expert_dtype not in (torch.bfloat16, torch.float16):
_LOGGER.debug(
"torch_grouped: unsupported expert dtype %s; falling back to naive",
@@ -165,6 +165,9 @@ def moe_ffn_forward_grouped(
buf[sel] = tensor
combined = (
buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1)
).sum(dim=1)
(buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1))
.sum(dim=1)
.to(torch.bfloat16)
)
return combined.view(bsz, seqlen, hdim), router_logits