numerics
This commit is contained in:
@@ -109,6 +109,9 @@ def main():
|
|||||||
)
|
)
|
||||||
p.add_argument("--iters", type=int, default=50)
|
p.add_argument("--iters", type=int, default=50)
|
||||||
p.add_argument("--warmup", type=int, default=10)
|
p.add_argument("--warmup", type=int, default=10)
|
||||||
|
p.add_argument(
|
||||||
|
"--check", action="store_true", help="Check numerical equivalence (outputs)"
|
||||||
|
)
|
||||||
args = p.parse_args()
|
args = p.parse_args()
|
||||||
|
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
@@ -166,6 +169,22 @@ def main():
|
|||||||
print(
|
print(
|
||||||
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s"
|
f"hf_triton\t{t_ms:.2f} ms\t{tokens / (t_ms / 1000):.1f} tok/s\t{tflops:.2f} TFLOP/s"
|
||||||
)
|
)
|
||||||
|
if args.check:
|
||||||
|
with torch.no_grad():
|
||||||
|
y_ref = forward_naive(x, gate, experts, args.top_k)
|
||||||
|
y_fast = y
|
||||||
|
# align dtypes for error metrics
|
||||||
|
y_ref32 = y_ref.float()
|
||||||
|
y_fast32 = y_fast.float()
|
||||||
|
diff = (y_ref32 - y_fast32).abs()
|
||||||
|
max_abs = diff.max().item()
|
||||||
|
mean_abs = diff.mean().item()
|
||||||
|
rel_l2 = (
|
||||||
|
(diff.pow(2).sum() / (y_ref32.pow(2).sum() + 1e-12)).sqrt().item()
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print("hf_triton\tN/A (kernels hub not available)")
|
print("hf_triton\tN/A (kernels hub not available)")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user