This commit is contained in:
Dan Saunders
2025-09-15 20:03:12 -04:00
parent fef47a5b7c
commit 773d7e4291

View File

@@ -155,7 +155,11 @@ def main():
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s" f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
) )
# HF Triton (stub compute for now) # Prepare reference output once for checks
with torch.no_grad():
y_ref = forward_naive(x, gate, experts, args.top_k)
# HF Triton
t_hf = forward_hf_triton t_hf = forward_hf_triton
y = t_hf(x, gate, experts, args.top_k) y = t_hf(x, gate, experts, args.top_k)
if y is not None: if y is not None:
@@ -167,8 +171,8 @@ def main():
print( print(
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×" f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
) )
# parity for hf_triton vs naive
with torch.no_grad(): with torch.no_grad():
y_ref = forward_naive(x, gate, experts, args.top_k)
y_fast = y y_fast = y
y_ref32 = y_ref.float() y_ref32 = y_ref.float()
y_fast32 = y_fast.float() y_fast32 = y_fast.float()
@@ -177,7 +181,7 @@ def main():
mean_abs = diff.mean().item() mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item() rel_l2 = (diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
print( print(
f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}" f"hf_triton_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
) )
else: else:
print("hf_triton\tN/A (kernels hub not available)") print("hf_triton\tN/A (kernels hub not available)")
@@ -209,6 +213,19 @@ def main():
print( print(
f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×" f"torch_grouped\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s\t{speedup:.2f}×"
) )
with torch.no_grad():
y_fast = y_tg
y_ref32 = y_ref.float()
y_fast32 = y_fast.float()
diff = (y_ref32 - y_fast32).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (
(diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
)
print(
f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
else: else:
try: try:
from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR