fix
This commit is contained in:
@@ -32,26 +32,43 @@ def _call_grouped_mm(
|
|||||||
As: List[torch.Tensor], Bs: List[torch.Tensor]
|
As: List[torch.Tensor], Bs: List[torch.Tensor]
|
||||||
) -> Optional[List[torch.Tensor]]:
|
) -> Optional[List[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Try calling the appropriate grouped mm op available in this torch build.
|
Call grouped mm using aten._grouped_mm with packed representation.
|
||||||
Returns list of outputs or None on failure.
|
|
||||||
|
- A_cat: concat As along rows -> [sum_i Mi, K]
|
||||||
|
- B_stk: stack Bs per group -> [G, K, N]
|
||||||
|
- offs: lengths per group Mi -> [G] int32
|
||||||
|
Returns list of per-group outputs split from concatenated result.
|
||||||
"""
|
"""
|
||||||
global LAST_ERROR
|
global LAST_ERROR
|
||||||
try:
|
try:
|
||||||
# Ensure 2D contiguous inputs
|
# Ensure 2D contiguous inputs
|
||||||
As = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||||
Bs = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||||
|
|
||||||
|
if not As2:
|
||||||
|
return []
|
||||||
|
device = As2[0].device
|
||||||
|
A_cat = torch.cat(As2, dim=0)
|
||||||
|
B_stk = torch.stack(Bs2, dim=0)
|
||||||
|
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
||||||
|
|
||||||
if hasattr(torch.ops.aten, "_grouped_mm"):
|
if hasattr(torch.ops.aten, "_grouped_mm"):
|
||||||
try:
|
try:
|
||||||
# Some builds expect tuples rather than lists
|
Y_cat = torch.ops.aten._grouped_mm(A_cat, B_stk, offs) # type: ignore[attr-defined]
|
||||||
return torch.ops.aten._grouped_mm(tuple(As), tuple(Bs)) # type: ignore[attr-defined]
|
outs: List[torch.Tensor] = []
|
||||||
|
start = 0
|
||||||
|
for m in offs.tolist():
|
||||||
|
outs.append(Y_cat[start : start + m, :])
|
||||||
|
start += m
|
||||||
|
return outs
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LAST_ERROR = f"_grouped_mm failed: {e}"
|
LAST_ERROR = f"_grouped_mm failed: {e}"
|
||||||
# Avoid _scaled_grouped_mm for now; its signature requires packed inputs.
|
return None
|
||||||
|
LAST_ERROR = "aten._grouped_mm not present"
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LAST_ERROR = f"call error: {e}"
|
LAST_ERROR = f"call error: {e}"
|
||||||
return None
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def moe_ffn_forward_grouped(
|
def moe_ffn_forward_grouped(
|
||||||
|
|||||||
Reference in New Issue
Block a user