handle base+lora split kernel for older moe models

This commit is contained in:
Wing Lian
2026-03-19 07:11:30 +00:00
parent 66fea258c7
commit 31d8d068bb
3 changed files with 139 additions and 12 deletions

View File

@@ -134,7 +134,9 @@ def main():
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R)
# Forward with LoRA
# Forward with LoRA (auto-dispatched: fused or split)
dispatch = "split" if (E <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD) else "fused"
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=2.0,
@@ -162,7 +164,8 @@ def main():
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms base={t_base:>6.2f}ms "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead*100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms")
@@ -184,7 +187,17 @@ def main():
lB_ag.grad = None
t_full = _bench(_run_autograd)
print(f" full_fwd_bwd={t_full:>6.2f}ms")
# Memory measurement
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
_run_autograd()
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak/1e6:>6.1f}MB")
print()