add correctness unit tests and benchmarks for scattermoe + lora
This commit is contained in:
193
benchmarks/bench_scattermoe_lora.py
Normal file
193
benchmarks/bench_scattermoe_lora.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""Benchmark for ScatterMoE LoRA Triton kernels.
|
||||||
|
|
||||||
|
Measures forward, backward dX, and backward dA/dB kernels at common MoE
|
||||||
|
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
|
||||||
|
and full fwd+bwd autograd throughput.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import statistics
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||||
|
ops as base_ops,
|
||||||
|
lora_ops,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||||
|
flatten_sort_count,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||||
|
ScatterMoELoRA,
|
||||||
|
)
|
||||||
|
|
||||||
|
DEVICE = "cuda"
|
||||||
|
DTYPE = torch.bfloat16
|
||||||
|
WARMUP = 5
|
||||||
|
ITERS = 20
|
||||||
|
|
||||||
|
# ─── Model configs ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
BUILTIN_CONFIGS = {
|
||||||
|
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||||
|
"Qwen3-30B-A3B": (128, 2048, 768, 8),
|
||||||
|
"OLMoE-1B-7B": (64, 2048, 1024, 8),
|
||||||
|
"Mixtral-8x7B": (8, 4096, 14336, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_config(spec):
|
||||||
|
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
|
||||||
|
key = spec.lower().replace("/", "-")
|
||||||
|
for name, cfg in BUILTIN_CONFIGS.items():
|
||||||
|
if key in name.lower() or name.lower() in key:
|
||||||
|
return name, cfg
|
||||||
|
|
||||||
|
# Try HuggingFace AutoConfig
|
||||||
|
from transformers import AutoConfig
|
||||||
|
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
|
||||||
|
if callable(getattr(hf_cfg, "get_text_config", None)):
|
||||||
|
tc = hf_cfg.get_text_config()
|
||||||
|
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
|
||||||
|
hf_cfg = tc
|
||||||
|
H = hf_cfg.hidden_size
|
||||||
|
I = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||||
|
E = (getattr(hf_cfg, "num_experts", None)
|
||||||
|
or getattr(hf_cfg, "num_local_experts", None)
|
||||||
|
or getattr(hf_cfg, "n_routed_experts", None))
|
||||||
|
k = (getattr(hf_cfg, "num_experts_per_tok", None)
|
||||||
|
or getattr(hf_cfg, "num_experts_per_token", None) or 2)
|
||||||
|
name = spec.split("/")[-1]
|
||||||
|
return name, (E, H, I, k)
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _clean():
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def _bench(fn, warmup=WARMUP, iters=ITERS):
|
||||||
|
for _ in range(warmup):
|
||||||
|
fn()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
times = []
|
||||||
|
for _ in range(iters):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
fn()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
times.append((time.perf_counter() - t0) * 1000)
|
||||||
|
return statistics.median(times)
|
||||||
|
|
||||||
|
|
||||||
|
def _setup(E, K, N, T, top_k, R):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||||
|
lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||||
|
lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01
|
||||||
|
logits = torch.randn(T, E, device=DEVICE)
|
||||||
|
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||||
|
sei, ssi, eo = flatten_sort_count(top_idx, E)
|
||||||
|
gx = base_ops.group(x, ssi, fan_out=top_k)
|
||||||
|
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||||
|
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
|
||||||
|
parser.add_argument("--models", "-m", nargs="+",
|
||||||
|
help="Model names or HF IDs (default: all builtins)")
|
||||||
|
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
|
||||||
|
parser.add_argument("--seq-len", "-T", type=int, default=2048)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
T = args.seq_len
|
||||||
|
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||||
|
print(f"T={T}, ranks={args.ranks}\n")
|
||||||
|
|
||||||
|
if args.models:
|
||||||
|
configs = [_resolve_config(m) for m in args.models]
|
||||||
|
else:
|
||||||
|
configs = list(BUILTIN_CONFIGS.items())
|
||||||
|
configs = [(n, c) for n, c in configs]
|
||||||
|
|
||||||
|
for model_name, (E, H, I, k) in configs:
|
||||||
|
print(f"{'=' * 70}")
|
||||||
|
print(f" {model_name}: E={E}, H={H}, I={I}, k={k}")
|
||||||
|
print(f"{'=' * 70}")
|
||||||
|
|
||||||
|
for R in args.ranks:
|
||||||
|
for proj, K, N in [("gate_up", H, 2 * I), ("down", I, H)]:
|
||||||
|
_clean()
|
||||||
|
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
# Forward with LoRA
|
||||||
|
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
|
||||||
|
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||||
|
k=k, lora_A=lA, lora_B=lB, scaling=2.0,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Forward without LoRA (base)
|
||||||
|
t_base = _bench(lambda: base_ops.scatter2scatter(
|
||||||
|
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Backward dX
|
||||||
|
t_dx = _bench(lambda: lora_ops.scatter2scatter_lora_dX(
|
||||||
|
DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||||
|
k=1, lora_A=lA, lora_B=lB, scaling=2.0,
|
||||||
|
dy_grouped=True, dx_grouped=False,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Backward dA/dB
|
||||||
|
t_bwd = _bench(lambda: lora_ops.group_bwd_lora(
|
||||||
|
DY=dy, X=gx, lora_A=lA, lora_B=lB,
|
||||||
|
expert_offsets=eo, E=E, scaling=2.0,
|
||||||
|
))
|
||||||
|
|
||||||
|
total = t_fwd + t_dx + t_bwd
|
||||||
|
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||||
|
|
||||||
|
print(f" R={R:>2} {proj:<8} "
|
||||||
|
f"fwd={t_fwd:>6.2f}ms base={t_base:>6.2f}ms "
|
||||||
|
f"(+{overhead*100:.0f}%) "
|
||||||
|
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||||
|
f"total={total:>6.2f}ms")
|
||||||
|
|
||||||
|
# Full autograd fwd+bwd
|
||||||
|
x_ag = x.clone().requires_grad_(True)
|
||||||
|
lA_ag = lA.clone().requires_grad_(True)
|
||||||
|
lB_ag = lB.clone().requires_grad_(True)
|
||||||
|
|
||||||
|
def _run_autograd():
|
||||||
|
out = ScatterMoELoRA.apply(
|
||||||
|
x_ag, W, k, sei, ssi, eo,
|
||||||
|
lA_ag, lB_ag, 2.0,
|
||||||
|
None, None, False, False, True, False,
|
||||||
|
)
|
||||||
|
out.sum().backward()
|
||||||
|
x_ag.grad = None
|
||||||
|
lA_ag.grad = None
|
||||||
|
lB_ag.grad = None
|
||||||
|
|
||||||
|
t_full = _bench(_run_autograd)
|
||||||
|
print(f" full_fwd_bwd={t_full:>6.2f}ms")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
311
tests/integrations/test_scattermoe_lora_kernels.py
Normal file
311
tests/integrations/test_scattermoe_lora_kernels.py
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for ScatterMoE LoRA Triton kernels.
|
||||||
|
|
||||||
|
Tests correctness of:
|
||||||
|
- scatter2scatter_lora (forward)
|
||||||
|
- scatter2scatter_lora_dX (backward input gradient)
|
||||||
|
- group_bwd_lora (backward LoRA weight gradients via split dA/dB)
|
||||||
|
- ScatterMoELoRA autograd function (full forward + backward)
|
||||||
|
|
||||||
|
Each kernel is tested against a pure PyTorch per-expert-loop reference
|
||||||
|
implementation at multiple model shapes and LoRA ranks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||||
|
ops as base_ops,
|
||||||
|
lora_ops,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||||
|
flatten_sort_count,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||||
|
ScatterMoELoRA,
|
||||||
|
)
|
||||||
|
|
||||||
|
DEVICE = "cuda"
|
||||||
|
DTYPE = torch.bfloat16
|
||||||
|
|
||||||
|
|
||||||
|
def _requires_cuda():
|
||||||
|
return pytest.mark.skipif(
|
||||||
|
not torch.cuda.is_available(), reason="CUDA not available"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = _requires_cuda()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _setup(E, K, N, T, top_k, R, seed=42):
|
||||||
|
"""Create synthetic expert weights, LoRA, routing, and grouped inputs."""
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||||
|
lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||||
|
lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01
|
||||||
|
logits = torch.randn(T, E, device=DEVICE)
|
||||||
|
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||||
|
sei, ssi, eo = flatten_sort_count(top_idx, E)
|
||||||
|
return x, W, lora_A, lora_B, sei, ssi, eo
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E):
|
||||||
|
"""Per-expert loop reference: Y = X@W + scaling*(X@A^T)@B^T."""
|
||||||
|
grouped_x = base_ops.group(x, ssi, fan_out=k)
|
||||||
|
M, N = grouped_x.size(0), W.size(2)
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
out = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
|
||||||
|
for e in range(E):
|
||||||
|
s = eo[e - 1].item() if e > 0 else 0
|
||||||
|
end = eo[e].item()
|
||||||
|
if s == end:
|
||||||
|
continue
|
||||||
|
xe = grouped_x[s:end].float()
|
||||||
|
we = W[e].float()
|
||||||
|
ae = lora_A[e * R : (e + 1) * R].float()
|
||||||
|
be = lora_B[:, e * R : (e + 1) * R].float()
|
||||||
|
out[s:end] = (xe @ we + scaling * (xe @ ae.T) @ be.T).to(DTYPE)
|
||||||
|
result = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
|
||||||
|
result[ssi] = out
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E):
|
||||||
|
"""Per-expert loop reference: dX = dY@W^T + scaling*(dY@B)@A."""
|
||||||
|
M, K = dy_grouped.size(0), W.size(1)
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
out = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
for e in range(E):
|
||||||
|
s = eo[e - 1].item() if e > 0 else 0
|
||||||
|
end = eo[e].item()
|
||||||
|
if s == end:
|
||||||
|
continue
|
||||||
|
dye = dy_grouped[s:end].float()
|
||||||
|
we = W[e].float()
|
||||||
|
ae = lora_A[e * R : (e + 1) * R].float()
|
||||||
|
be = lora_B[:, e * R : (e + 1) * R].float()
|
||||||
|
out[s:end] = (dye @ we.T + scaling * (dye @ be) @ ae).to(DTYPE)
|
||||||
|
result = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
result[ssi] = out
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling):
|
||||||
|
"""Per-expert loop reference: dA, dB for LoRA weight gradients."""
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
dA = torch.zeros_like(lora_A)
|
||||||
|
dB = torch.zeros_like(lora_B)
|
||||||
|
for e in range(E):
|
||||||
|
s = eo[e - 1].item() if e > 0 else 0
|
||||||
|
end = eo[e].item()
|
||||||
|
if s == end:
|
||||||
|
continue
|
||||||
|
xe = grouped_x[s:end].float()
|
||||||
|
dye = dy[s:end].float()
|
||||||
|
ae = lora_A[e * R : (e + 1) * R].float()
|
||||||
|
be = lora_B[:, e * R : (e + 1) * R].float()
|
||||||
|
dA[e * R : (e + 1) * R] = (scaling * (dye @ be).T @ xe).to(DTYPE)
|
||||||
|
dB[:, e * R : (e + 1) * R] = (scaling * dye.T @ (xe @ ae.T)).to(DTYPE)
|
||||||
|
return dA, dB
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Model shape configs ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# (E, K, N, T, top_k, R, description)
|
||||||
|
CONFIGS_SMALL = [
|
||||||
|
(32, 128, 64, 64, 2, 4, "tiny"),
|
||||||
|
(64, 256, 128, 128, 4, 8, "small"),
|
||||||
|
]
|
||||||
|
|
||||||
|
CONFIGS_REAL = [
|
||||||
|
(256, 2048, 1024, 2048, 8, 16, "qwen35_gate_up"),
|
||||||
|
(256, 512, 2048, 2048, 8, 16, "qwen35_down"),
|
||||||
|
(64, 2048, 2048, 2048, 8, 16, "olmoe_gate_up"),
|
||||||
|
(128, 2048, 1536, 2048, 8, 16, "qwen3_gate_up"),
|
||||||
|
]
|
||||||
|
|
||||||
|
SCALING = 2.0
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Forward tests ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestScatter2ScatterLoRAForward:
|
||||||
|
"""Test scatter2scatter_lora forward kernel vs reference."""
|
||||||
|
|
||||||
|
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
|
||||||
|
def config(self, request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
def test_matches_reference(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
kernel_out = lora_ops.scatter2scatter_lora(
|
||||||
|
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||||
|
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||||
|
)
|
||||||
|
ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)
|
||||||
|
|
||||||
|
err = (kernel_out.float() - ref_out.float()).abs().max().item()
|
||||||
|
assert err < 1.0, f"[{desc}] fwd max_err={err}"
|
||||||
|
|
||||||
|
def test_output_shape(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
out = lora_ops.scatter2scatter_lora(
|
||||||
|
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||||
|
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||||
|
)
|
||||||
|
assert out.shape == (T * k, N)
|
||||||
|
assert out.dtype == DTYPE
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Backward dX tests ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestScatter2ScatterLoRADX:
|
||||||
|
"""Test scatter2scatter_lora_dX backward kernel vs reference."""
|
||||||
|
|
||||||
|
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
|
||||||
|
def config(self, request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
def test_matches_reference(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
gx = base_ops.group(x, ssi, fan_out=k)
|
||||||
|
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||||
|
|
||||||
|
kernel_dx = lora_ops.scatter2scatter_lora_dX(
|
||||||
|
DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||||
|
k=1, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||||
|
dy_grouped=True, dx_grouped=False,
|
||||||
|
)
|
||||||
|
ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)
|
||||||
|
|
||||||
|
err = (kernel_dx.float() - ref_dx.float()).abs().max().item()
|
||||||
|
assert err < 1.0, f"[{desc}] dX max_err={err}"
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Backward LoRA gradient tests ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestGroupBwdLoRA:
|
||||||
|
"""Test group_bwd_lora (split dA/dB kernel) vs reference."""
|
||||||
|
|
||||||
|
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
|
||||||
|
def config(self, request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
def test_matches_reference(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
gx = base_ops.group(x, ssi, fan_out=k)
|
||||||
|
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||||
|
|
||||||
|
kern_dA, kern_dB = lora_ops.group_bwd_lora(
|
||||||
|
DY=dy, X=gx, lora_A=lA, lora_B=lB,
|
||||||
|
expert_offsets=eo, E=E, scaling=SCALING,
|
||||||
|
)
|
||||||
|
ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING)
|
||||||
|
|
||||||
|
# Use norm-relative error: bf16 accumulation order differs between
|
||||||
|
# kernel (tiled + different reduction order) and reference (per-expert
|
||||||
|
# fp32 loop), so max absolute error can be large on individual elements
|
||||||
|
# while the overall tensor is correct.
|
||||||
|
dA_norm_err = (
|
||||||
|
(kern_dA.float() - ref_dA.float()).norm()
|
||||||
|
/ (ref_dA.float().norm() + 1e-6)
|
||||||
|
).item()
|
||||||
|
dB_norm_err = (
|
||||||
|
(kern_dB.float() - ref_dB.float()).norm()
|
||||||
|
/ (ref_dB.float().norm() + 1e-6)
|
||||||
|
).item()
|
||||||
|
assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}"
|
||||||
|
assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}"
|
||||||
|
|
||||||
|
def test_zero_expert_tokens(self):
|
||||||
|
"""Experts with zero routed tokens produce zero gradients."""
|
||||||
|
E, K, N, R = 8, 64, 32, 4
|
||||||
|
torch.manual_seed(42)
|
||||||
|
# Route all tokens to expert 0 only
|
||||||
|
T, k = 16, 1
|
||||||
|
top_idx = torch.zeros(T, k, dtype=torch.long, device=DEVICE)
|
||||||
|
sei, ssi, eo = flatten_sort_count(top_idx, E)
|
||||||
|
gx = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
dy = torch.randn(T, N, device=DEVICE, dtype=DTYPE)
|
||||||
|
lA = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)
|
||||||
|
|
||||||
|
dA, dB = lora_ops.group_bwd_lora(
|
||||||
|
DY=dy, X=gx, lora_A=lA, lora_B=lB,
|
||||||
|
expert_offsets=eo, E=E, scaling=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Experts 1..7 should have zero gradients
|
||||||
|
for e in range(1, E):
|
||||||
|
assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero"
|
||||||
|
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dB not zero"
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Full autograd tests ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestScatterMoELoRAAutograd:
|
||||||
|
"""Test full forward + backward through ScatterMoELoRA autograd function."""
|
||||||
|
|
||||||
|
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL[:2])
|
||||||
|
def config(self, request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
def test_gradients_exist_and_finite(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
x = x.requires_grad_(True)
|
||||||
|
lA = lA.requires_grad_(True)
|
||||||
|
lB = lB.requires_grad_(True)
|
||||||
|
|
||||||
|
out = ScatterMoELoRA.apply(
|
||||||
|
x, W, k, sei, ssi, eo,
|
||||||
|
lA, lB, SCALING,
|
||||||
|
None, None, False, False, True, False,
|
||||||
|
)
|
||||||
|
out.sum().backward()
|
||||||
|
|
||||||
|
assert x.grad is not None, f"[{desc}] x.grad is None"
|
||||||
|
assert lA.grad is not None, f"[{desc}] lA.grad is None"
|
||||||
|
assert lB.grad is not None, f"[{desc}] lB.grad is None"
|
||||||
|
assert torch.isfinite(x.grad).all(), f"[{desc}] x.grad has non-finite"
|
||||||
|
assert torch.isfinite(lA.grad).all(), f"[{desc}] lA.grad has non-finite"
|
||||||
|
assert torch.isfinite(lB.grad).all(), f"[{desc}] lB.grad has non-finite"
|
||||||
|
assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero"
|
||||||
|
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
|
||||||
|
|
||||||
|
|
||||||
|
def test_scaling_zero_gives_base_only(self):
|
||||||
|
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
|
||||||
|
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
out_lora = ScatterMoELoRA.apply(
|
||||||
|
x, W, k, sei, ssi, eo,
|
||||||
|
lA, lB, 0.0,
|
||||||
|
None, None, False, False, True, False,
|
||||||
|
)
|
||||||
|
out_base = base_ops.scatter2scatter(
|
||||||
|
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k,
|
||||||
|
)
|
||||||
|
err = (out_lora.float() - out_base.float()).abs().max().item()
|
||||||
|
assert err < 0.01, f"scaling=0 should match base: err={err}"
|
||||||
Reference in New Issue
Block a user