Handle base+LoRA split kernel for older MoE models
This commit is contained in:
@@ -134,7 +134,9 @@ def main():
|
||||
_clean()
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R)
|
||||
|
||||
# Forward with LoRA
|
||||
# Forward with LoRA (auto-dispatched: fused or split)
|
||||
dispatch = "split" if (E <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD) else "fused"
|
||||
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=k, lora_A=lA, lora_B=lB, scaling=2.0,
|
||||
@@ -162,7 +164,8 @@ def main():
|
||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||
|
||||
print(f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms base={t_base:>6.2f}ms "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead*100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms")
|
||||
@@ -184,7 +187,17 @@ def main():
|
||||
lB_ag.grad = None
|
||||
|
||||
t_full = _bench(_run_autograd)
|
||||
print(f" full_fwd_bwd={t_full:>6.2f}ms")
|
||||
|
||||
# Memory measurement
|
||||
_clean()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
_run_autograd()
|
||||
torch.cuda.synchronize()
|
||||
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||
|
||||
print(f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak/1e6:>6.1f}MB")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user