dtype issues
This commit is contained in:
@@ -66,7 +66,7 @@ def _call_grouped_mm(
|
|||||||
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||||
device = As2[0].device
|
device = As2[0].device
|
||||||
lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
||||||
offsets = torch.cumsum(lengths, dim=0)
|
offsets = torch.cumsum(lengths, dim=0).to(torch.int32)
|
||||||
Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
|
Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
|
||||||
outs: List[torch.Tensor] = []
|
outs: List[torch.Tensor] = []
|
||||||
start = 0
|
start = 0
|
||||||
@@ -90,8 +90,8 @@ def moe_ffn_forward_grouped(
|
|||||||
device = hidden_states.device
|
device = hidden_states.device
|
||||||
|
|
||||||
routing_dtype = gate_linear.weight.dtype
|
routing_dtype = gate_linear.weight.dtype
|
||||||
|
|
||||||
expert_dtype = hidden_states.dtype
|
expert_dtype = hidden_states.dtype
|
||||||
|
|
||||||
if expert_dtype not in (torch.bfloat16, torch.float16):
|
if expert_dtype not in (torch.bfloat16, torch.float16):
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"torch_grouped: unsupported expert dtype %s; falling back to naive",
|
"torch_grouped: unsupported expert dtype %s; falling back to naive",
|
||||||
@@ -165,6 +165,9 @@ def moe_ffn_forward_grouped(
|
|||||||
buf[sel] = tensor
|
buf[sel] = tensor
|
||||||
|
|
||||||
combined = (
|
combined = (
|
||||||
buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1)
|
(buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1))
|
||||||
).sum(dim=1)
|
.sum(dim=1)
|
||||||
|
.to(torch.bfloat16)
|
||||||
|
)
|
||||||
|
|
||||||
return combined.view(bsz, seqlen, hdim), router_logits
|
return combined.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|||||||
Reference in New Issue
Block a user