diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 5445bd239..987bfeb6e 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -32,26 +32,43 @@ def _call_grouped_mm( As: List[torch.Tensor], Bs: List[torch.Tensor] ) -> Optional[List[torch.Tensor]]: """ - Try calling the appropriate grouped mm op available in this torch build. - Returns list of outputs or None on failure. + Call grouped mm using aten._grouped_mm with packed representation. + + - A_cat: concat As along rows -> [sum_i Mi, K] + - B_stk: stack Bs per group -> [G, K, N] + - offs: lengths per group Mi -> [G] int32 + Returns list of per-group outputs split from concatenated result. """ global LAST_ERROR try: # Ensure 2D contiguous inputs - As = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As] - Bs = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs] + As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As] + Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs] + + if not As2: + return [] + device = As2[0].device + A_cat = torch.cat(As2, dim=0) + B_stk = torch.stack(Bs2, dim=0) + offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32) if hasattr(torch.ops.aten, "_grouped_mm"): try: - # Some builds expect tuples rather than lists - return torch.ops.aten._grouped_mm(tuple(As), tuple(Bs)) # type: ignore[attr-defined] + Y_cat = torch.ops.aten._grouped_mm(A_cat, B_stk, offs) # type: ignore[attr-defined] + outs: List[torch.Tensor] = [] + start = 0 + for m in offs.tolist(): + outs.append(Y_cat[start : start + m, :]) + start += m + return outs except Exception as e: LAST_ERROR = f"_grouped_mm failed: {e}" - # Avoid _scaled_grouped_mm for now; its signature requires packed inputs. + return None + LAST_ERROR = "aten._grouped_mm not present" + return None except Exception as e: LAST_ERROR = f"call error: {e}" return None - return None def moe_ffn_forward_grouped(