diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 50f1c4d8b..0abebc664 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -56,35 +56,24 @@ def _stack_weights( def _call_grouped_mm( As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype -) -> Tuple[Optional[List[torch.Tensor]], Optional[str]]: +) -> List[torch.Tensor]: if not As: - return [], None + return [] if dtype not in (torch.bfloat16, torch.float16): - msg = f"unsupported dtype {dtype}" - _LOGGER.debug("torch_grouped: %s", msg) - return None, msg + raise RuntimeError(f"unsupported dtype {dtype} for grouped_mm") - try: - As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As] - Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs] - device = As2[0].device - lengths = torch.tensor( - [a.shape[0] for a in As2], device=device, dtype=torch.int32 - ) - offsets = torch.cumsum(lengths, dim=0).to(torch.int32) - Y_cat = torch._grouped_mm( - torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets - ) - outs: List[torch.Tensor] = [] - start = 0 - for size in lengths.tolist(): - outs.append(Y_cat[start : start + size]) - start += size - return outs, None - except RuntimeError as err: - message = f"_grouped_mm failed ({err})" - _LOGGER.warning("torch_grouped: %s", message) - return None, message + As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As] + Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs] + device = As2[0].device + lengths = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32) + offsets = torch.cumsum(lengths, dim=0).to(torch.int32) + Y_cat = torch._grouped_mm(torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offsets) + outs: List[torch.Tensor] = [] + start = 0 + for size in lengths.tolist(): + outs.append(Y_cat[start : start + size]) + start += size + return outs def moe_ffn_forward_grouped( @@ -159,9 +148,7 
@@ def moe_ffn_forward_grouped( if not as_list: return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits - up_out, reason = _call_grouped_mm(as_list, bs_list, expert_dtype) - if up_out is None: - return None, None + up_out = _call_grouped_mm(as_list, bs_list, expert_dtype) down_inputs: List[torch.Tensor] = [] down_weights: List[torch.Tensor] = [] @@ -172,9 +159,7 @@ def moe_ffn_forward_grouped( down_inputs.append(hidden) down_weights.append(w2[i]) - down_out, reason = _call_grouped_mm(down_inputs, down_weights, expert_dtype) - if down_out is None: - return None, None + down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype) for (_i, sel), tensor in zip(slices, down_out, strict=False): buf[sel] = tensor