just grouped_mm for now

This commit is contained in:
Dan Saunders
2025-09-15 23:03:18 -04:00
parent 773d7e4291
commit 7d572b58d1
6 changed files with 32 additions and 264 deletions

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env python
import argparse
import os
import time
import torch
@@ -52,20 +51,6 @@ def forward_naive(
return y.view(bsz, seqlen, hdim)
def forward_hf_triton(
hidden_states: torch.Tensor, gate: nn.Linear, experts: Experts, top_k: int
):
try:
from axolotl.kernels.moe import hf_triton as _hf
except Exception:
return None
try:
y, _ = _hf.moe_ffn_forward_stub(hidden_states, gate, experts, top_k)
return y
except Exception:
return None
def bench(fn, *args, iters=50, warmup=10, sync=True):
# warmup
for _ in range(warmup):
@@ -159,33 +144,6 @@ def main():
with torch.no_grad():
y_ref = forward_naive(x, gate, experts, args.top_k)
# HF Triton
t_hf = forward_hf_triton
y = t_hf(x, gate, experts, args.top_k)
if y is not None:
t_ms = bench(
t_hf, x, gate, experts, args.top_k, iters=args.iters, warmup=args.warmup
)
tflops = flops_total / ((t_ms / 1000.0) * 1e12)
speedup = t_naive / t_ms
print(
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
)
# parity for hf_triton vs naive
with torch.no_grad():
y_fast = y
y_ref32 = y_ref.float()
y_fast32 = y_fast.float()
diff = (y_ref32 - y_fast32).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
print(
f"hf_triton_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
else:
print("hf_triton\tN/A (kernels hub not available)")
# torch_grouped backend (PyTorch 2.8+)
try:
from axolotl.kernels.moe import torch_grouped as tg