grouped_mm
@@ -109,9 +109,6 @@ def main():
    )
    p.add_argument("--iters", type=int, default=50)
    p.add_argument("--warmup", type=int, default=10)
    p.add_argument(
        "--check", action="store_true", help="Check numerical equivalence (outputs)"
    )
    args = p.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
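A hypothetical invocation of the benchmark (the script's path is not shown in this diff, so the name here is an assumption):

    python scripts/bench_moe.py --iters 50 --warmup 10 --check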
@@ -166,30 +163,56 @@ def main():
            t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
        )
        tflops = flops_total / ((t_ms / 1000.0) * 1e12)
+        speedup = t_naive / t_ms
        print(
-            f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s"
+            f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
        )
        if args.check:
-            with torch.no_grad():
-                y_ref = forward_naive(x, gate, experts, args.top_k)
-                y_fast = y
-                # align dtypes for error metrics
-                y_ref32 = y_ref.float()
-                y_fast32 = y_fast.float()
-                diff = (y_ref32 - y_fast32).abs()
-                max_abs = diff.max().item()
-                mean_abs = diff.mean().item()
-                rel_l2 = (
-                    (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
-                )
-                print(
-                    f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
-                )
+            with torch.no_grad():
+                y_ref = forward_naive(x, gate, experts, args.top_k)
+                y_fast = y
+                y_ref32 = y_ref.float()
+                y_fast32 = y_fast.float()
+                diff = (y_ref32 - y_fast32).abs()
+                max_abs = diff.max().item()
+                mean_abs = diff.mean().item()
+                rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
+                print(
+                    f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
+                )
    else:
        print("hf_triton\tN/A (kernels hub not available)")

-    # torch_grouped placeholder — not yet implemented
-    print("torch_grouped\tN/A (pending implementation)")
+    # torch_grouped backend (PyTorch 2.8+)
+    try:
+        from axolotl.kernels.moe import torch_grouped as tg
+    except Exception:
+        tg = None
+    if tg is not None and tg.available():
+
+        def forward_tg(a, g, ex, topk):
+            y, _ = tg.moe_ffn_forward_grouped(a, g, ex, topk)
+            return y
+
+        y_tg = forward_tg(x, gate, experts, args.top_k)
+        if y_tg is not None:
+            t_ms = bench(
+                forward_tg,
+                x,
+                gate,
+                experts,
+                args.top_k,
+                iters=args.iters,
+                warmup=args.warmup,
+            )
+            tflops = flops_total / ((t_ms / 1000.0) * 1e12)
+            speedup = t_naive / t_ms
+            print(
+                f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
+            )
+        else:
+            print("torch_grouped\tN/A (op not callable)")
+    else:
+        print("torch_grouped\tN/A (unavailable)")


if __name__ == "__main__":
    main()
scripts/probe_torch_grouped_ops.py (new file, 47 lines)
@@ -0,0 +1,47 @@
#!/usr/bin/env python
"""
Probe PyTorch for grouped GEMM operator names and namespaces.

Run: python scripts/probe_torch_grouped_ops.py
"""

import sys


def main():
    try:
        import torch
    except Exception as e:
        print("Failed to import torch:", e)
        sys.exit(1)

    print("torch version:", torch.__version__)
    namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
    print("ops namespaces:", namespaces)

    found_any = False
    for ns in namespaces:
        obj = getattr(torch.ops, ns, None)
        ops = []
        if obj is not None:
            try:
                ops = dir(obj)
            except Exception as e:
                print(f"warning: failed to list ops for namespace {ns}: {e}")
        # every pattern of interest ("grouped", "mm_grouped", "matmul_grouped")
        # contains "group", so a single substring test covers them all
        cands = [o for o in ops if "group" in o.lower()]
        if cands:
            found_any = True
            print(f"namespace {ns} candidates:", cands)

    if not found_any:
        print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")


if __name__ == "__main__":
    main()
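If the probe surfaces a candidate, the registered schema can be dumped as well; a minimal sketch via the op-packet API, assuming a build that exposes aten._grouped_mm:

    import torch

    # Sketch: print the schema(s) behind a candidate op packet.
    if hasattr(torch.ops.aten, "_grouped_mm"):
        packet = torch.ops.aten._grouped_mm
        for name in packet.overloads():  # overload names, e.g. "default"
            print(getattr(packet, name)._schema)

The next hunk fills in the torch_grouped module that the benchmark imports as axolotl.kernels.moe.torch_grouped (the file-path header was lost in extraction):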
@@ -1,16 +1,141 @@
"""
-Placeholder for PyTorch 2.8+ grouped GEMM MoE path.
-Currently probes availability; full integration to be implemented.
+PyTorch 2.8+ grouped GEMM MoE path (cuBLASLt-backed).
+This is a cautious first pass that probes available ops and only runs when supported.
"""

from __future__ import annotations

from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F


def available() -> bool:
    try:
        import torch  # noqa: F401

        ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
-        return ver >= (2, 8)
+        if ver < (2, 8):
+            return False
+        # Check for aten grouped mm ops
+        return hasattr(torch.ops, "aten") and (
+            hasattr(torch.ops.aten, "_grouped_mm")
+            or hasattr(torch.ops.aten, "_scaled_grouped_mm")
+        )
    except Exception:
        return False

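The version gate keeps only the leading major.minor pair, so build suffixes like +cu128 or .dev do not break the int() parse; a quick sanity sketch with invented version strings:

    # How available()'s version parse behaves on typical version strings.
    for v in ("2.7.1", "2.8.0+cu128", "2.9.0.dev20250101"):
        ver = tuple(int(x) for x in v.split("+")[0].split(".")[:2])
        print(v, ver >= (2, 8))  # 2.7.1 -> False, the other two -> True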
def _call_grouped_mm(
    As: List[torch.Tensor], Bs: List[torch.Tensor]
) -> Optional[List[torch.Tensor]]:
    """
    Try calling whichever grouped mm op this torch build exposes.
    Returns a list of outputs, or None on failure.
    """
    try:
        if hasattr(torch.ops.aten, "_grouped_mm"):
            return torch.ops.aten._grouped_mm(As, Bs)  # type: ignore[attr-defined]
        if hasattr(torch.ops.aten, "_scaled_grouped_mm"):
            # signature likely (As, Bs, alpha, beta)
            return torch.ops.aten._scaled_grouped_mm(As, Bs, 1.0, 0.0)  # type: ignore[attr-defined]
    except Exception:
        return None
    return None

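Whatever signature the aten op settles on, the semantics this module needs are one independent GEMM per group. A reference implementation the grouped op must match numerically (illustrative, handy for unit tests):

    from typing import List

    import torch

    def grouped_mm_reference(As: List[torch.Tensor], Bs: List[torch.Tensor]) -> List[torch.Tensor]:
        # One plain matmul per (A_i, B_i) pair; a grouped kernel fuses these
        # into a single launch but must return the same outputs.
        return [a @ b for a, b in zip(As, Bs)]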

def moe_ffn_forward_grouped(
    hidden_states, gate_linear, experts_module, top_k: int
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Attempt a grouped GEMM fast path using PyTorch 2.8+.
    If unavailable or anything fails, returns (None, None) so the caller can fall back.
    """
    try:
        bsz, seqlen, hdim = hidden_states.shape
        x = hidden_states.view(-1, hdim)
        router_logits = gate_linear(x)

        # top-k routing in plain torch (kept simple to avoid dependency cycles)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
        topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)

        # Build per-expert input lists
        flat_idx = topk_idx.view(-1)
        x_rep = x.repeat_interleave(top_k, dim=0)

        # Cache stacked weights on the experts module
        E = experts_module.num_experts
        dev, dt = x.device, x.dtype
        if (
            not hasattr(experts_module, "_stacked_w1")
            or experts_module._stacked_w1.device != dev
            or experts_module._stacked_w1.dtype != dt
        ):
            w1 = [experts_module[i].w1.weight.t() for i in range(E)]
            w3 = [experts_module[i].w3.weight.t() for i in range(E)]
            w2 = [experts_module[i].w2.weight.t() for i in range(E)]
            experts_module._stacked_w1 = (
                torch.stack(w1, dim=0)
                .to(device=dev, dtype=dt, non_blocking=True)
                .contiguous()
            )
            experts_module._stacked_w3 = (
                torch.stack(w3, dim=0)
                .to(device=dev, dtype=dt, non_blocking=True)
                .contiguous()
            )
            experts_module._stacked_w2 = (
                torch.stack(w2, dim=0)
                .to(device=dev, dtype=dt, non_blocking=True)
                .contiguous()
            )
            experts_module._stacked_w13 = torch.cat(
                [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
            ).contiguous()
        W13 = experts_module._stacked_w13
        W2 = experts_module._stacked_w2

        # Grouped GEMM for the fused up+gate projection
        As: List[torch.Tensor] = []
        Bs: List[torch.Tensor] = []
        expert_slices = []
        for i in range(E):
            sel = flat_idx == i
            if sel.any():
                Xi = x_rep[sel]
                As.append(Xi)
                Bs.append(W13[i])
                expert_slices.append((i, sel))

        if not As:
            # no tokens routed (edge case)
            out = torch.zeros_like(x)
            return out.view(bsz, seqlen, hdim), router_logits

        Y_list = _call_grouped_mm(As, Bs)
        if Y_list is None:
            return None, None

        # SwiGLU on each expert block, then prepare inputs for the down projection
        As2: List[torch.Tensor] = []
        Bs2: List[torch.Tensor] = []
        y_buf = torch.empty_like(x_rep)
        # split each Yi into its gate and up halves of width I
        for (i, sel), Yi in zip(expert_slices, Y_list):
            I2 = Yi.shape[-1] // 2
            Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
            As2.append(Yi_hidden)
            Bs2.append(W2[i])

        Y2_list = _call_grouped_mm(As2, Bs2)
        if Y2_list is None:
            return None, None

        # Write back, apply per-token routing weights, and reduce over top_k
        for (i, sel), Out_i in zip(expert_slices, Y2_list):
            y_buf[sel] = Out_i
        y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
        return y.view(bsz, seqlen, hdim), router_logits
    except Exception:
        return None, None
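To make the data flow easier to follow, a shape walk-through under small assumed sizes (bsz=2, seqlen=4, hdim=64, E=8 experts, top_k=2, intermediate width I=128; all illustrative):

    # hidden_states [2, 4, 64] -> x [8, 64]; router_logits [8, 8]
    # topk_idx, topk_weight: [8, 2]; x_rep: [16, 64] (row t*2+j = token t, slot j)
    # per expert i: Xi [n_i, 64] @ W13[i] [64, 256] -> Yi [n_i, 256]
    # SwiGLU: silu(Yi[:, :128]) * Yi[:, 128:] -> [n_i, 128] @ W2[i] [128, 64]
    # scatter back into y_buf [16, 64], view [8, 2, 64], weighted sum -> y [8, 64]
    # return y.view(2, 4, 64)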