This commit is contained in:
Dan Saunders
2025-09-17 19:15:34 -04:00
parent e62979d11d
commit c6878beb7d

View File

@@ -56,35 +56,24 @@ def _stack_weights(
def _call_grouped_mm( def _call_grouped_mm(
As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
) -> Tuple[Optional[List[torch.Tensor]], Optional[str]]: ) -> List[torch.Tensor]:
if not As: if not As:
return [], None return []
if dtype not in (torch.bfloat16, torch.float16): if dtype not in (torch.bfloat16, torch.float16):
msg = f"unsupported dtype {dtype}" raise RuntimeError(f"unsupported dtype {dtype} for grouped_mm")
_LOGGER.debug("torch_grouped: %s", msg)
return None, msg
try: As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As] Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs] device = As2[0].device
device = As2[0].device lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
lengths = torch.tensor( offsets = torch.cumsum(lengths, dim=0)
[a.shape[0] for a in As2], device=device, dtype=torch.int32 Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets)
) outs: List[torch.Tensor] = []
offsets = torch.cumsum(lengths, dim=0).to(torch.int32) start = 0
Y_cat = torch._grouped_mm( for size in lengths.tolist():
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets outs.append(Y_cat[start : start + size])
) start += size
outs: List[torch.Tensor] = [] return outs
start = 0
for size in lengths.tolist():
outs.append(Y_cat[start : start + size])
start += size
return outs, None
except RuntimeError as err:
message = f"_grouped_mm failed ({err})"
_LOGGER.warning("torch_grouped: %s", message)
return None, message
def moe_ffn_forward_grouped( def moe_ffn_forward_grouped(
@@ -159,9 +148,7 @@ def moe_ffn_forward_grouped(
if not as_list: if not as_list:
return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits
up_out, reason = _call_grouped_mm(as_list, bs_list, expert_dtype) up_out = _call_grouped_mm(as_list, bs_list, expert_dtype)
if up_out is None:
return None, None
down_inputs: List[torch.Tensor] = [] down_inputs: List[torch.Tensor] = []
down_weights: List[torch.Tensor] = [] down_weights: List[torch.Tensor] = []
@@ -172,9 +159,7 @@ def moe_ffn_forward_grouped(
down_inputs.append(hidden) down_inputs.append(hidden)
down_weights.append(w2[i]) down_weights.append(w2[i])
down_out, reason = _call_grouped_mm(down_inputs, down_weights, expert_dtype) down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype)
if down_out is None:
return None, None
for (_i, sel), tensor in zip(slices, down_out, strict=False): for (_i, sel), tensor in zip(slices, down_out, strict=False):
buf[sel] = tensor buf[sel] = tensor