chore: lint

This commit is contained in:
Wing Lian
2026-03-19 07:27:23 +00:00
parent 31d8d068bb
commit fec0c3a99e
8 changed files with 443 additions and 191 deletions

View File

@@ -19,8 +19,8 @@ import pytest
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
ops as base_ops,
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
@@ -151,8 +151,14 @@ class TestScatter2ScatterLoRAForward:
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
kernel_out = lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)
@@ -164,8 +170,14 @@ class TestScatter2ScatterLoRAForward:
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
out = lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
assert out.shape == (T * k, N)
assert out.dtype == DTYPE
@@ -188,9 +200,16 @@ class TestScatter2ScatterLoRADX:
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
kernel_dx = lora_ops.scatter2scatter_lora_dX(
DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=1, lora_A=lA, lora_B=lB, scaling=SCALING,
dy_grouped=True, dx_grouped=False,
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
dy_grouped=True,
dx_grouped=False,
)
ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)
@@ -215,8 +234,13 @@ class TestGroupBwdLoRA:
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
kern_dA, kern_dB = lora_ops.group_bwd_lora(
DY=dy, X=gx, lora_A=lA, lora_B=lB,
expert_offsets=eo, E=E, scaling=SCALING,
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=E,
scaling=SCALING,
)
ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING)
@@ -225,12 +249,10 @@ class TestGroupBwdLoRA:
# fp32 loop), so max absolute error can be large on individual elements
# while the overall tensor is correct.
dA_norm_err = (
(kern_dA.float() - ref_dA.float()).norm()
/ (ref_dA.float().norm() + 1e-6)
(kern_dA.float() - ref_dA.float()).norm() / (ref_dA.float().norm() + 1e-6)
).item()
dB_norm_err = (
(kern_dB.float() - ref_dB.float()).norm()
/ (ref_dB.float().norm() + 1e-6)
(kern_dB.float() - ref_dB.float()).norm() / (ref_dB.float().norm() + 1e-6)
).item()
assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}"
assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}"
@@ -249,14 +271,21 @@ class TestGroupBwdLoRA:
lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)
dA, dB = lora_ops.group_bwd_lora(
DY=dy, X=gx, lora_A=lA, lora_B=lB,
expert_offsets=eo, E=E, scaling=2.0,
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=E,
scaling=2.0,
)
# Experts 1..7 should have zero gradients
for e in range(1, E):
assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero"
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dB not zero"
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, (
f"Expert {e} dB not zero"
)
# ─── Full autograd tests ────────────────────────────────────────────────────
@@ -278,9 +307,21 @@ class TestScatterMoELoRAAutograd:
lB = lB.requires_grad_(True)
out = ScatterMoELoRA.apply(
x, W, k, sei, ssi, eo,
lA, lB, SCALING,
None, None, False, False, True, False,
x,
W,
k,
sei,
ssi,
eo,
lA,
lB,
SCALING,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
@@ -293,7 +334,6 @@ class TestScatterMoELoRAAutograd:
assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero"
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
def test_split_matches_fused(self):
"""Split dispatch (for few large experts) matches fused kernel."""
# Use a shape where split would be dispatched (large K*N, few E)
@@ -304,15 +344,27 @@ class TestScatterMoELoRAAutograd:
orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
out_fused = lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
# Force split path
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
out_split = lora_ops.scatter2scatter_lora(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
lora_A=lA,
lora_B=lB,
scaling=SCALING,
)
lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig
@@ -328,12 +380,28 @@ class TestScatterMoELoRAAutograd:
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
out_lora = ScatterMoELoRA.apply(
x, W, k, sei, ssi, eo,
lA, lB, 0.0,
None, None, False, False, True, False,
x,
W,
k,
sei,
ssi,
eo,
lA,
lB,
0.0,
None,
None,
False,
False,
True,
False,
)
out_base = base_ops.scatter2scatter(
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k,
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=k,
)
err = (out_lora.float() - out_base.float()).abs().max().item()
assert err < 0.01, f"scaling=0 should match base: err={err}"