From 556d6448fe91ce7a9d2ffc22db5d5172ef5a84f5 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 15 Sep 2025 19:36:00 -0400 Subject: [PATCH] fix --- src/axolotl/kernels/moe/torch_grouped.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 1ed2ce20b..5445bd239 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -43,14 +43,11 @@ def _call_grouped_mm( if hasattr(torch.ops.aten, "_grouped_mm"): try: - return torch.ops.aten._grouped_mm(As, Bs) # type: ignore[attr-defined] + # Some builds expect tuples rather than lists + return torch.ops.aten._grouped_mm(tuple(As), tuple(Bs)) # type: ignore[attr-defined] except Exception as e: LAST_ERROR = f"_grouped_mm failed: {e}" - if hasattr(torch.ops.aten, "_scaled_grouped_mm"): - try: - return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0) # type: ignore[attr-defined] - except Exception as e: - LAST_ERROR = f"_scaled_grouped_mm failed: {e}" + # Avoid _scaled_grouped_mm for now; its signature requires packed inputs. except Exception as e: LAST_ERROR = f"call error: {e}" return None