Handle base+LoRA split kernel for older MoE models
This commit is contained in:
@@ -294,6 +294,34 @@ class TestScatterMoELoRAAutograd:
|
||||
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
|
||||
|
||||
|
||||
def test_split_matches_fused(self):
    """Split dispatch (for few large experts) matches the fused kernel.

    Temporarily overrides ``lora_ops._SPLIT_LORA_FWD_THRESHOLD`` to force
    each dispatch path in turn on identical inputs, then checks the two
    outputs agree to within 1% relative error.
    """
    # Use a shape where split would be dispatched (large K*N, few E).
    E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16
    x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)

    orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
    try:
        # Force fused path: a threshold this high is never exceeded.
        lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
        out_fused = lora_ops.scatter2scatter_lora(
            X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
            k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
        )

        # Force split path: a zero threshold always dispatches split.
        lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
        out_split = lora_ops.scatter2scatter_lora(
            X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
            k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
        )
    finally:
        # Restore even if a kernel raises, so the module-global override
        # cannot leak into later tests.
        lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig

    # Relative error in float32; 1e-6 guards against a zero-norm output.
    norm_err = (
        (out_fused.float() - out_split.float()).norm()
        / (out_fused.float().norm() + 1e-6)
    ).item()
    assert norm_err < 0.01, f"split vs fused norm_err={norm_err}"
|
||||
|
||||
def test_scaling_zero_gives_base_only(self):
|
||||
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
|
||||
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
|
||||
|
||||
Reference in New Issue
Block a user