dtype issues
This commit is contained in:
@@ -66,7 +66,7 @@ def _call_grouped_mm(
|
||||
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||
device = As2[0].device
|
||||
lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
||||
offsets = torch.cumsum(lengths, dim=0)
|
||||
offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
|
||||
Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
|
||||
outs: List[torch.Tensor] = []
|
||||
start = 0
|
||||
@@ -90,8 +90,8 @@ def moe_ffn_forward_grouped(
|
||||
device = hidden_states.device
|
||||
|
||||
routing_dtype = gate_linear.weight.dtype
|
||||
|
||||
expert_dtype = hidden_states.dtype
|
||||
|
||||
if expert_dtype not in (torch.bfloat16, torch.float16):
|
||||
_LOGGER.debug(
|
||||
"torch_grouped: unsupported expert dtype %s; falling back to naive",
|
||||
@@ -165,6 +165,9 @@ def moe_ffn_forward_grouped(
|
||||
buf[sel] = tensor
|
||||
|
||||
combined = (
|
||||
buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1)
|
||||
).sum(dim=1)
|
||||
(buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1))
|
||||
.sum(dim=1)
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
|
||||
return combined.view(bsz, seqlen, hdim), router_logits
|
||||
|
||||
Reference in New Issue
Block a user