dtype issues

This commit is contained in:
Dan Saunders
2025-09-17 23:52:18 +00:00
parent c6878beb7d
commit bbf1f14ca4

View File

@@ -66,7 +66,7 @@ def _call_grouped_mm(
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs] Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
device = As2[0].device device = As2[0].device
lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32) lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
offsets = torch.cumsum(lengths, dim=0) offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets) Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
outs: List[torch.Tensor] = [] outs: List[torch.Tensor] = []
start = 0 start = 0
@@ -90,8 +90,8 @@ def moe_ffn_forward_grouped(
device = hidden_states.device device = hidden_states.device
routing_dtype = gate_linear.weight.dtype routing_dtype = gate_linear.weight.dtype
expert_dtype = hidden_states.dtype expert_dtype = hidden_states.dtype
if expert_dtype not in (torch.bfloat16, torch.float16): if expert_dtype not in (torch.bfloat16, torch.float16):
_LOGGER.debug( _LOGGER.debug(
"torch_grouped: unsupported expert dtype %s; falling back to naive", "torch_grouped: unsupported expert dtype %s; falling back to naive",
@@ -165,6 +165,9 @@ def moe_ffn_forward_grouped(
buf[sel] = tensor buf[sel] = tensor
combined = ( combined = (
buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1) (buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1))
).sum(dim=1) .sum(dim=1)
.to(torch.bfloat16)
)
return combined.view(bsz, seqlen, hdim), router_logits return combined.view(bsz, seqlen, hdim), router_logits