chore: lint
This commit is contained in:
@@ -12,14 +12,14 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import statistics
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
ops as base_ops,
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
@@ -36,7 +36,7 @@ ITERS = 20
|
||||
# ─── Model configs ──────────────────────────────────────────────────────────
|
||||
|
||||
BUILTIN_CONFIGS = {
|
||||
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||
"Qwen3-30B-A3B": (128, 2048, 768, 8),
|
||||
"OLMoE-1B-7B": (64, 2048, 1024, 8),
|
||||
"Mixtral-8x7B": (8, 4096, 14336, 2),
|
||||
@@ -50,26 +50,32 @@ def _resolve_config(spec):
|
||||
if key in name.lower() or name.lower() in key:
|
||||
return name, cfg
|
||||
|
||||
# Try HuggingFace AutoConfig
|
||||
from transformers import AutoConfig
|
||||
|
||||
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
|
||||
if callable(getattr(hf_cfg, "get_text_config", None)):
|
||||
tc = hf_cfg.get_text_config()
|
||||
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
|
||||
hf_cfg = tc
|
||||
H = hf_cfg.hidden_size
|
||||
I = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||
E = (getattr(hf_cfg, "num_experts", None)
|
||||
or getattr(hf_cfg, "num_local_experts", None)
|
||||
or getattr(hf_cfg, "n_routed_experts", None))
|
||||
k = (getattr(hf_cfg, "num_experts_per_tok", None)
|
||||
or getattr(hf_cfg, "num_experts_per_token", None) or 2)
|
||||
hidden = hf_cfg.hidden_size
|
||||
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||
experts = (
|
||||
getattr(hf_cfg, "num_experts", None)
|
||||
or getattr(hf_cfg, "num_local_experts", None)
|
||||
or getattr(hf_cfg, "n_routed_experts", None)
|
||||
)
|
||||
top_k = (
|
||||
getattr(hf_cfg, "num_experts_per_tok", None)
|
||||
or getattr(hf_cfg, "num_experts_per_token", None)
|
||||
or 2
|
||||
)
|
||||
name = spec.split("/")[-1]
|
||||
return name, (E, H, I, k)
|
||||
return name, (experts, hidden, inter, top_k)
|
||||
|
||||
|
||||
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _clean():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
@@ -87,29 +93,88 @@ def _bench(fn, warmup=WARMUP, iters=ITERS):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - t0) * 1000)
|
||||
return statistics.median(times)
|
||||
times.sort()
|
||||
return times[len(times) // 2]
|
||||
|
||||
|
||||
def _setup(E, K, N, T, top_k, R):
|
||||
def _setup(num_experts, K, N, T, top_k, R):
|
||||
torch.manual_seed(42)
|
||||
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||
W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||
lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
logits = torch.randn(T, E, device=DEVICE)
|
||||
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
logits = torch.randn(T, num_experts, device=DEVICE)
|
||||
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, E)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
|
||||
gx = base_ops.group(x, ssi, fan_out=top_k)
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
||||
|
||||
|
||||
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
|
||||
|
||||
|
||||
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _call_base(x, W, sei, ssi, top_k):
|
||||
return base_ops.scatter2scatter(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def _call_dx(dy, W, sei, ssi, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=1,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
dy_grouped=True,
|
||||
dx_grouped=False,
|
||||
)
|
||||
|
||||
|
||||
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
|
||||
return lora_ops.group_bwd_lora(
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=num_experts,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
|
||||
parser.add_argument("--models", "-m", nargs="+",
|
||||
help="Model names or HF IDs (default: all builtins)")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
"-m",
|
||||
nargs="+",
|
||||
help="Model names or HF IDs (default: all builtins)",
|
||||
)
|
||||
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
|
||||
parser.add_argument("--seq-len", "-T", type=int, default=2048)
|
||||
args = parser.parse_args()
|
||||
@@ -122,73 +187,84 @@ def main():
|
||||
configs = [_resolve_config(m) for m in args.models]
|
||||
else:
|
||||
configs = list(BUILTIN_CONFIGS.items())
|
||||
configs = [(n, c) for n, c in configs]
|
||||
|
||||
for model_name, (E, H, I, k) in configs:
|
||||
for model_name, (num_experts, hidden, inter, top_k) in configs:
|
||||
print(f"{'=' * 70}")
|
||||
print(f" {model_name}: E={E}, H={H}, I={I}, k={k}")
|
||||
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
for R in args.ranks:
|
||||
for proj, K, N in [("gate_up", H, 2 * I), ("down", I, H)]:
|
||||
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
|
||||
_clean()
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(E, K, N, T, k, R)
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
|
||||
num_experts, K, N, T, top_k, R
|
||||
)
|
||||
|
||||
# Forward with LoRA (auto-dispatched: fused or split)
|
||||
dispatch = "split" if (E <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD) else "fused"
|
||||
t_fwd = _bench(lambda: lora_ops.scatter2scatter_lora(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=k, lora_A=lA, lora_B=lB, scaling=2.0,
|
||||
))
|
||||
|
||||
# Forward without LoRA (base)
|
||||
t_base = _bench(lambda: base_ops.scatter2scatter(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k,
|
||||
))
|
||||
|
||||
# Backward dX
|
||||
t_dx = _bench(lambda: lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=1, lora_A=lA, lora_B=lB, scaling=2.0,
|
||||
dy_grouped=True, dx_grouped=False,
|
||||
))
|
||||
|
||||
# Backward dA/dB
|
||||
t_bwd = _bench(lambda: lora_ops.group_bwd_lora(
|
||||
DY=dy, X=gx, lora_A=lA, lora_B=lB,
|
||||
expert_offsets=eo, E=E, scaling=2.0,
|
||||
))
|
||||
dispatch = (
|
||||
"split"
|
||||
if (
|
||||
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
)
|
||||
else "fused"
|
||||
)
|
||||
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
|
||||
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
|
||||
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
|
||||
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
|
||||
|
||||
total = t_fwd + t_dx + t_bwd
|
||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||
|
||||
print(f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead*100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms")
|
||||
print(
|
||||
f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead * 100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms"
|
||||
)
|
||||
|
||||
# Full autograd fwd+bwd
|
||||
# Full autograd fwd+bwd with memory measurement
|
||||
x_ag = x.clone().requires_grad_(True)
|
||||
lA_ag = lA.clone().requires_grad_(True)
|
||||
lB_ag = lB.clone().requires_grad_(True)
|
||||
|
||||
def _run_autograd():
|
||||
def _run_autograd(
|
||||
_x=x_ag,
|
||||
_W=W,
|
||||
_k=top_k,
|
||||
_sei=sei,
|
||||
_ssi=ssi,
|
||||
_eo=eo,
|
||||
_lA=lA_ag,
|
||||
_lB=lB_ag,
|
||||
):
|
||||
out = ScatterMoELoRA.apply(
|
||||
x_ag, W, k, sei, ssi, eo,
|
||||
lA_ag, lB_ag, 2.0,
|
||||
None, None, False, False, True, False,
|
||||
_x,
|
||||
_W,
|
||||
_k,
|
||||
_sei,
|
||||
_ssi,
|
||||
_eo,
|
||||
_lA,
|
||||
_lB,
|
||||
2.0,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out.sum().backward()
|
||||
x_ag.grad = None
|
||||
lA_ag.grad = None
|
||||
lB_ag.grad = None
|
||||
_x.grad = None
|
||||
_lA.grad = None
|
||||
_lB.grad = None
|
||||
|
||||
t_full = _bench(_run_autograd)
|
||||
|
||||
# Memory measurement
|
||||
_clean()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
@@ -196,8 +272,10 @@ def main():
|
||||
torch.cuda.synchronize()
|
||||
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||
|
||||
print(f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak/1e6:>6.1f}MB")
|
||||
print(
|
||||
f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
@@ -566,30 +566,41 @@ def _scatter2scatter_lora_split(
|
||||
|
||||
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
|
||||
output = scatter2scatter(
|
||||
X=X, W=W, b=b,
|
||||
X=X,
|
||||
W=W,
|
||||
b=b,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k, x_grouped=x_grouped, y_grouped=y_grouped, out=out,
|
||||
k=k,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=y_grouped,
|
||||
out=out,
|
||||
)
|
||||
|
||||
# 2. XA = X @ A^T (tiny: output is [M*k, R])
|
||||
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
|
||||
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
|
||||
XA = scatter2scatter(
|
||||
X=X, W=W_A,
|
||||
X=X,
|
||||
W=W_A,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=k, x_grouped=x_grouped, y_grouped=True,
|
||||
k=k,
|
||||
x_grouped=x_grouped,
|
||||
y_grouped=True,
|
||||
)
|
||||
|
||||
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
|
||||
# Reshape B: [N, R*E] → [E, R, N]
|
||||
W_B = lora_B.T.reshape(E, R, N).contiguous()
|
||||
Y_lora = scatter2scatter(
|
||||
X=XA, W=W_B,
|
||||
X=XA,
|
||||
W=W_B,
|
||||
sorted_expert_idxs=sorted_expert_idxs,
|
||||
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||
k=1, x_grouped=True, y_grouped=y_grouped,
|
||||
k=1,
|
||||
x_grouped=True,
|
||||
y_grouped=y_grouped,
|
||||
)
|
||||
|
||||
# 4. Y = Y_base + scaling * Y_lora
|
||||
@@ -650,13 +661,20 @@ def scatter2scatter_lora(
|
||||
N = W.size(2)
|
||||
|
||||
# Dispatch: split for few large experts, fused for many small experts
|
||||
if (
|
||||
E <= _SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= _SPLIT_LORA_FWD_THRESHOLD
|
||||
):
|
||||
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
|
||||
return _scatter2scatter_lora_split(
|
||||
X, W, sorted_expert_idxs, sorted_scattered_idxs, k,
|
||||
lora_A, lora_B, scaling, b, x_grouped, y_grouped, out,
|
||||
X,
|
||||
W,
|
||||
sorted_expert_idxs,
|
||||
sorted_scattered_idxs,
|
||||
k,
|
||||
lora_A,
|
||||
lora_B,
|
||||
scaling,
|
||||
b,
|
||||
x_grouped,
|
||||
y_grouped,
|
||||
out,
|
||||
)
|
||||
|
||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||
@@ -1443,7 +1461,6 @@ def _prune_split_configs(configs, named_args, **kwargs):
|
||||
"""Prune split kernel configs based on SMEM capacity."""
|
||||
smem_cap = _get_smem_capacity()
|
||||
block_r = named_args.get("BLOCK_R", 64)
|
||||
inner_dim = named_args.get("INNER_DIM", 2048)
|
||||
|
||||
# Fixed inner tile for reduction dimension
|
||||
BLOCK_INNER = 64
|
||||
@@ -1470,33 +1487,47 @@ def _prune_split_configs(configs, named_args, **kwargs):
|
||||
key=["M", "K", "N"],
|
||||
prune_configs_by={"early_config_prune": _prune_split_configs},
|
||||
)
|
||||
@triton.heuristics({
|
||||
"NO_DIM_MASK": lambda args: (
|
||||
(args["K"] % args["BLOCK_DIM"]) == 0
|
||||
if args["COMPUTE_DA"]
|
||||
else (args["N"] % args["BLOCK_DIM"]) == 0
|
||||
),
|
||||
})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"NO_DIM_MASK": lambda args: (
|
||||
(args["K"] % args["BLOCK_DIM"]) == 0
|
||||
if args["COMPUTE_DA"]
|
||||
else (args["N"] % args["BLOCK_DIM"]) == 0
|
||||
),
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _group_bwd_lora_split(
|
||||
# Data tensors (DY and X are always present)
|
||||
DY_ptr, stride_dym, stride_dyn,
|
||||
X_ptr, stride_xm, stride_xk,
|
||||
DY_ptr,
|
||||
stride_dym,
|
||||
stride_dyn,
|
||||
X_ptr,
|
||||
stride_xm,
|
||||
stride_xk,
|
||||
# LoRA weight for the inner reduction (B for dA, A for dB)
|
||||
LW_ptr, stride_lw0, stride_lw1,
|
||||
LW_ptr,
|
||||
stride_lw0,
|
||||
stride_lw1,
|
||||
# Output gradient tensor (dA or dB)
|
||||
OUT_ptr, stride_out0, stride_out1,
|
||||
OUT_ptr,
|
||||
stride_out0,
|
||||
stride_out1,
|
||||
# Expert offsets
|
||||
expert_offsets_ptr,
|
||||
# Dimensions
|
||||
M, K: tl.constexpr, N: tl.constexpr,
|
||||
ACTUAL_R: tl.constexpr, BLOCK_R: tl.constexpr,
|
||||
M,
|
||||
K: tl.constexpr,
|
||||
N: tl.constexpr,
|
||||
ACTUAL_R: tl.constexpr,
|
||||
BLOCK_R: tl.constexpr,
|
||||
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
|
||||
scaling,
|
||||
# Mode flag
|
||||
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
|
||||
# Tile sizes
|
||||
BLOCK_M: tl.constexpr, BLOCK_DIM: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DIM: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr,
|
||||
allow_tf32: tl.constexpr,
|
||||
NO_DIM_MASK: tl.constexpr,
|
||||
@@ -1532,9 +1563,9 @@ def _group_bwd_lora_split(
|
||||
|
||||
# Output dimension tile (K for dA, N for dB)
|
||||
if COMPUTE_DA:
|
||||
OUT_DIM: tl.constexpr = K
|
||||
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
|
||||
else:
|
||||
OUT_DIM: tl.constexpr = N
|
||||
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
|
||||
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
|
||||
dim_mask = dim_block < OUT_DIM
|
||||
R_block = tl.arange(0, BLOCK_R)
|
||||
@@ -1577,7 +1608,8 @@ def _group_bwd_lora_split(
|
||||
# Load X[M, K_block] (the "outer" tensor for dA)
|
||||
outer = tl.load(
|
||||
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
|
||||
mask=M_mask[:, None] & dim_mask[None, :], other=0.0
|
||||
mask=M_mask[:, None] & dim_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
|
||||
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
|
||||
@@ -1588,23 +1620,34 @@ def _group_bwd_lora_split(
|
||||
inn_mask = inn_off < N
|
||||
|
||||
dy_tile = tl.load(
|
||||
DY_ptr + M_idx[:, None] * stride_dym + inn_off[None, :] * stride_dyn,
|
||||
mask=M_mask[:, None] & inn_mask[None, :], other=0.0
|
||||
DY_ptr
|
||||
+ M_idx[:, None] * stride_dym
|
||||
+ inn_off[None, :] * stride_dyn,
|
||||
mask=M_mask[:, None] & inn_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
|
||||
lw_tile = tl.load(
|
||||
LW_ptr + inn_off[:, None] * stride_lw0 + (lora_offset + R_block)[None, :] * stride_lw1,
|
||||
mask=inn_mask[:, None] & R_mask[None, :], other=0.0
|
||||
LW_ptr
|
||||
+ inn_off[:, None] * stride_lw0
|
||||
+ (lora_offset + R_block)[None, :] * stride_lw1,
|
||||
mask=inn_mask[:, None] & R_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
|
||||
|
||||
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
|
||||
acc += tl.dot(tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32)
|
||||
acc += tl.dot(
|
||||
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
|
||||
)
|
||||
else:
|
||||
# Load DY[M, N_block] (the "outer" tensor for dB)
|
||||
outer = tl.load(
|
||||
DY_ptr + M_idx[:, None] * stride_dym + dim_block[None, :] * stride_dyn,
|
||||
mask=M_mask[:, None] & dim_mask[None, :], other=0.0
|
||||
DY_ptr
|
||||
+ M_idx[:, None] * stride_dym
|
||||
+ dim_block[None, :] * stride_dyn,
|
||||
mask=M_mask[:, None] & dim_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
|
||||
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
|
||||
@@ -1615,27 +1658,45 @@ def _group_bwd_lora_split(
|
||||
inn_mask = inn_off < K
|
||||
|
||||
x_tile = tl.load(
|
||||
X_ptr + M_idx[:, None] * stride_xm + inn_off[None, :] * stride_xk,
|
||||
mask=M_mask[:, None] & inn_mask[None, :], other=0.0
|
||||
X_ptr
|
||||
+ M_idx[:, None] * stride_xm
|
||||
+ inn_off[None, :] * stride_xk,
|
||||
mask=M_mask[:, None] & inn_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
|
||||
# We want A[e]^T: [K, R], so load as [K_inner, R]
|
||||
lw_tile = tl.load(
|
||||
LW_ptr + (lora_offset + R_block)[None, :] * stride_lw0 + inn_off[:, None] * stride_lw1,
|
||||
mask=inn_mask[:, None] & R_mask[None, :], other=0.0
|
||||
LW_ptr
|
||||
+ (lora_offset + R_block)[None, :] * stride_lw0
|
||||
+ inn_off[:, None] * stride_lw1,
|
||||
mask=inn_mask[:, None] & R_mask[None, :],
|
||||
other=0.0,
|
||||
).to(INPUT_DTYPE)
|
||||
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
|
||||
|
||||
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
|
||||
acc += tl.dot(tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32)
|
||||
acc += tl.dot(
|
||||
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
|
||||
)
|
||||
|
||||
tl.store(out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask)
|
||||
tl.store(
|
||||
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
|
||||
)
|
||||
else:
|
||||
# Zero out this expert's slice — needed because output uses empty_like
|
||||
if COMPUTE_DA:
|
||||
tl.store(out_blk_ptrs, tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty), mask=out_mask)
|
||||
tl.store(
|
||||
out_blk_ptrs,
|
||||
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
|
||||
mask=out_mask,
|
||||
)
|
||||
else:
|
||||
tl.store(out_blk_ptrs, tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty), mask=out_mask)
|
||||
tl.store(
|
||||
out_blk_ptrs,
|
||||
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
|
||||
mask=out_mask,
|
||||
)
|
||||
|
||||
|
||||
def group_bwd_lora(
|
||||
@@ -1683,34 +1744,58 @@ def group_bwd_lora(
|
||||
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
|
||||
|
||||
_group_bwd_lora_split[grid_dA](
|
||||
DY, DY.stride(0), DY.stride(1),
|
||||
X, X.stride(0), X.stride(1),
|
||||
lora_B, lora_B.stride(0), lora_B.stride(1),
|
||||
dA, dA.stride(0), dA.stride(1),
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
lora_B,
|
||||
lora_B.stride(0),
|
||||
lora_B.stride(1),
|
||||
dA,
|
||||
dA.stride(0),
|
||||
dA.stride(1),
|
||||
expert_offsets,
|
||||
M=DY.size(0), K=K, N=N,
|
||||
ACTUAL_R=R, BLOCK_R=BLOCK_R,
|
||||
M=DY.size(0),
|
||||
K=K,
|
||||
N=N,
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
INNER_DIM=N,
|
||||
scaling=scaling,
|
||||
COMPUTE_DA=True,
|
||||
ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
def grid_dB(META):
|
||||
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
|
||||
|
||||
_group_bwd_lora_split[grid_dB](
|
||||
DY, DY.stride(0), DY.stride(1),
|
||||
X, X.stride(0), X.stride(1),
|
||||
lora_A, lora_A.stride(0), lora_A.stride(1),
|
||||
dB, dB.stride(0), dB.stride(1),
|
||||
DY,
|
||||
DY.stride(0),
|
||||
DY.stride(1),
|
||||
X,
|
||||
X.stride(0),
|
||||
X.stride(1),
|
||||
lora_A,
|
||||
lora_A.stride(0),
|
||||
lora_A.stride(1),
|
||||
dB,
|
||||
dB.stride(0),
|
||||
dB.stride(1),
|
||||
expert_offsets,
|
||||
M=DY.size(0), K=K, N=N,
|
||||
ACTUAL_R=R, BLOCK_R=BLOCK_R,
|
||||
M=DY.size(0),
|
||||
K=K,
|
||||
N=N,
|
||||
ACTUAL_R=R,
|
||||
BLOCK_R=BLOCK_R,
|
||||
INNER_DIM=K,
|
||||
scaling=scaling,
|
||||
COMPUTE_DA=False,
|
||||
ACC_TYPE=tl.float32, allow_tf32=ALLOW_TF32,
|
||||
ACC_TYPE=tl.float32,
|
||||
allow_tf32=ALLOW_TF32,
|
||||
)
|
||||
|
||||
return dA, dB
|
||||
|
||||
@@ -511,20 +511,26 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
|
||||
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
|
||||
remapped_expert_idxs, compact_offsets = remap_expert_indices(
|
||||
sorted_expert_idxs, expert_offsets, active_experts, num_experts,
|
||||
sorted_expert_idxs,
|
||||
expert_offsets,
|
||||
active_experts,
|
||||
num_experts,
|
||||
)
|
||||
num_active = len(active_experts)
|
||||
|
||||
# Dequantize only active experts' weights
|
||||
gate_up_W = selective_expert_weights(
|
||||
experts, "gate_up_proj", active_experts,
|
||||
experts,
|
||||
"gate_up_proj",
|
||||
active_experts,
|
||||
).transpose(2, 1) # [num_active, hidden, 2*inter]
|
||||
|
||||
# Remap LoRA weights to match compact expert indices
|
||||
if gup_lora is not None:
|
||||
gup_A, gup_B, gup_scaling = gup_lora
|
||||
gup_A, gup_B = selective_lora_weights(
|
||||
gup_A, gup_B, active_experts, num_experts,
|
||||
gup_A,
|
||||
gup_B,
|
||||
active_experts,
|
||||
num_experts,
|
||||
)
|
||||
gup_lora = (gup_A, gup_B, gup_scaling)
|
||||
|
||||
@@ -576,13 +582,18 @@ class HFScatterMoEGatedMLP(nn.Module):
|
||||
# ====================================================================
|
||||
if use_selective:
|
||||
down_W = selective_expert_weights(
|
||||
experts, "down_proj", active_experts,
|
||||
experts,
|
||||
"down_proj",
|
||||
active_experts,
|
||||
).transpose(2, 1) # [num_active, inter, hidden]
|
||||
|
||||
if down_lora is not None:
|
||||
down_A, down_B, down_scaling = down_lora
|
||||
down_A, down_B = selective_lora_weights(
|
||||
down_A, down_B, active_experts, num_experts,
|
||||
down_A,
|
||||
down_B,
|
||||
active_experts,
|
||||
num_experts,
|
||||
)
|
||||
down_lora = (down_A, down_B, down_scaling)
|
||||
|
||||
|
||||
@@ -21,8 +21,6 @@ from global (0..E-1) to compact (0..num_active-1) and pass the smaller
|
||||
weight tensor.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -79,7 +77,7 @@ def _selective_dequant_bnb4(
|
||||
raw_param: torch.Tensor,
|
||||
quant_state,
|
||||
active_experts: torch.Tensor,
|
||||
expert_shape: tuple[int, ...],
|
||||
expert_shape: tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""Dequantize only selected experts from BnB 4-bit packed data.
|
||||
|
||||
@@ -231,7 +229,9 @@ def selective_expert_weights(
|
||||
if E_total is None:
|
||||
E_total = int(active_experts.max().item()) + 1
|
||||
expert_numel = orig_shape[0] // E_total
|
||||
d2 = getattr(experts_module, "hidden_dim", None) or getattr(experts_module, "intermediate_dim", None)
|
||||
d2 = getattr(experts_module, "hidden_dim", None) or getattr(
|
||||
experts_module, "intermediate_dim", None
|
||||
)
|
||||
if d2 and expert_numel % d2 == 0:
|
||||
expert_shape = (expert_numel // d2, d2)
|
||||
else:
|
||||
@@ -241,9 +241,7 @@ def selective_expert_weights(
|
||||
full = getattr(experts_module, param_name)
|
||||
return full[active_experts]
|
||||
|
||||
return _selective_dequant_bnb4(
|
||||
raw_param, qs, active_experts, expert_shape
|
||||
)
|
||||
return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)
|
||||
|
||||
# Dense parameter (bf16/fp32) — direct indexing
|
||||
param = getattr(experts_module, param_name)
|
||||
|
||||
@@ -19,15 +19,25 @@ import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# NF4 codebook (16 values, precomputed by BnB)
|
||||
# These are the normalized float4 reconstruction values
|
||||
NF4_CODEBOOK = [
|
||||
-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
|
||||
-0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
|
||||
0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
|
||||
0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
|
||||
0.7229568362236023, 1.0,
|
||||
-1.0,
|
||||
-0.6961928009986877,
|
||||
-0.5250730514526367,
|
||||
-0.39491748809814453,
|
||||
-0.28444138169288635,
|
||||
-0.18477343022823334,
|
||||
-0.09105003625154495,
|
||||
0.0,
|
||||
0.07958029955625534,
|
||||
0.16093020141124725,
|
||||
0.24611230194568634,
|
||||
0.33791524171829224,
|
||||
0.44070982933044434,
|
||||
0.5626170039176941,
|
||||
0.7229568362236023,
|
||||
1.0,
|
||||
]
|
||||
|
||||
|
||||
@@ -46,8 +56,8 @@ def _selective_dequant_nf4_kernel(
|
||||
stride_out_e, # stride for expert dim in output
|
||||
# Dimensions
|
||||
num_active,
|
||||
packed_per_expert, # expert_numel // 2
|
||||
blocks_per_expert, # expert_numel // blocksize
|
||||
packed_per_expert, # expert_numel // 2
|
||||
blocks_per_expert, # expert_numel // blocksize
|
||||
blocksize: tl.constexpr,
|
||||
# Tile size
|
||||
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
|
||||
@@ -79,7 +89,9 @@ def _selective_dequant_nf4_kernel(
|
||||
|
||||
# Read packed bytes from the global expert's region
|
||||
packed_global_offset = expert_global * packed_per_expert + byte_idx
|
||||
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(tl.int32)
|
||||
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
|
||||
tl.int32
|
||||
)
|
||||
|
||||
# Extract 4-bit nibble
|
||||
# BnB packing: high nibble = even element, low nibble = odd element
|
||||
@@ -133,8 +145,9 @@ def selective_dequant_nf4_triton(
|
||||
|
||||
# Prepare codebook on device
|
||||
if codebook is None:
|
||||
codebook = torch.tensor(NF4_CODEBOOK, dtype=torch.float32,
|
||||
device=packed_data.device)
|
||||
codebook = torch.tensor(
|
||||
NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
|
||||
)
|
||||
else:
|
||||
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
|
||||
|
||||
@@ -143,8 +156,7 @@ def selective_dequant_nf4_triton(
|
||||
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
|
||||
|
||||
# Output buffer
|
||||
out = torch.empty(num_active, expert_numel, dtype=dtype,
|
||||
device=packed_data.device)
|
||||
out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)
|
||||
|
||||
BLOCK_SIZE = 1024 # Process 1024 elements per thread block
|
||||
|
||||
|
||||
@@ -66,7 +66,10 @@ class KernelsPlugin(BasePlugin):
|
||||
# Prefer text backbone type for VLMs, but fall back to base type
|
||||
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
|
||||
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
||||
if moe_model_type not in SPARSE_MOE_BLOCK and cfg.model_config_type in SPARSE_MOE_BLOCK:
|
||||
if (
|
||||
moe_model_type not in SPARSE_MOE_BLOCK
|
||||
and cfg.model_config_type in SPARSE_MOE_BLOCK
|
||||
):
|
||||
moe_model_type = cfg.model_config_type
|
||||
|
||||
if cfg.use_scattermoe:
|
||||
|
||||
@@ -28,9 +28,7 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
if profiler_steps_start == 0:
|
||||
# start recording memory allocations before everything is allocated, because if we start
|
||||
# at the beginning of step 0, we won't have any memory allocations in the traces
|
||||
torch.cuda.memory._record_memory_history(
|
||||
enabled="all", stacks="all"
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||
profiler_steps_start = -1
|
||||
self.profiler_steps_start = profiler_steps_start
|
||||
self._profiler = None
|
||||
@@ -43,13 +41,11 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
**kwargs,
|
||||
):
|
||||
if state.global_step == self.profiler_steps_start:
|
||||
torch.cuda.memory._record_memory_history(
|
||||
enabled="all", stacks="all"
|
||||
)
|
||||
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||
|
||||
# Start torch.profiler on the first profiled step
|
||||
if state.global_step == max(self.profiler_steps_start, 0):
|
||||
self._profiler = torch.profiler.profile(
|
||||
profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
@@ -58,7 +54,8 @@ class PytorchProfilerCallback(TrainerCallback):
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
)
|
||||
self._profiler.__enter__()
|
||||
profiler.__enter__()
|
||||
self._profiler = profiler
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
|
||||
@@ -19,8 +19,8 @@ import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
ops as base_ops,
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
@@ -151,8 +151,14 @@ class TestScatter2ScatterLoRAForward:
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
|
||||
kernel_out = lora_ops.scatter2scatter_lora(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
)
|
||||
ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)
|
||||
|
||||
@@ -164,8 +170,14 @@ class TestScatter2ScatterLoRAForward:
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
|
||||
out = lora_ops.scatter2scatter_lora(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
)
|
||||
assert out.shape == (T * k, N)
|
||||
assert out.dtype == DTYPE
|
||||
@@ -188,9 +200,16 @@ class TestScatter2ScatterLoRADX:
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
|
||||
kernel_dx = lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=1, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||
dy_grouped=True, dx_grouped=False,
|
||||
DY=dy,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=1,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
dy_grouped=True,
|
||||
dx_grouped=False,
|
||||
)
|
||||
ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)
|
||||
|
||||
@@ -215,8 +234,13 @@ class TestGroupBwdLoRA:
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
|
||||
kern_dA, kern_dB = lora_ops.group_bwd_lora(
|
||||
DY=dy, X=gx, lora_A=lA, lora_B=lB,
|
||||
expert_offsets=eo, E=E, scaling=SCALING,
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=E,
|
||||
scaling=SCALING,
|
||||
)
|
||||
ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING)
|
||||
|
||||
@@ -225,12 +249,10 @@ class TestGroupBwdLoRA:
|
||||
# fp32 loop), so max absolute error can be large on individual elements
|
||||
# while the overall tensor is correct.
|
||||
dA_norm_err = (
|
||||
(kern_dA.float() - ref_dA.float()).norm()
|
||||
/ (ref_dA.float().norm() + 1e-6)
|
||||
(kern_dA.float() - ref_dA.float()).norm() / (ref_dA.float().norm() + 1e-6)
|
||||
).item()
|
||||
dB_norm_err = (
|
||||
(kern_dB.float() - ref_dB.float()).norm()
|
||||
/ (ref_dB.float().norm() + 1e-6)
|
||||
(kern_dB.float() - ref_dB.float()).norm() / (ref_dB.float().norm() + 1e-6)
|
||||
).item()
|
||||
assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}"
|
||||
assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}"
|
||||
@@ -249,14 +271,21 @@ class TestGroupBwdLoRA:
|
||||
lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)
|
||||
|
||||
dA, dB = lora_ops.group_bwd_lora(
|
||||
DY=dy, X=gx, lora_A=lA, lora_B=lB,
|
||||
expert_offsets=eo, E=E, scaling=2.0,
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=E,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
# Experts 1..7 should have zero gradients
|
||||
for e in range(1, E):
|
||||
assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero"
|
||||
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dB not zero"
|
||||
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, (
|
||||
f"Expert {e} dB not zero"
|
||||
)
|
||||
|
||||
|
||||
# ─── Full autograd tests ────────────────────────────────────────────────────
|
||||
@@ -278,9 +307,21 @@ class TestScatterMoELoRAAutograd:
|
||||
lB = lB.requires_grad_(True)
|
||||
|
||||
out = ScatterMoELoRA.apply(
|
||||
x, W, k, sei, ssi, eo,
|
||||
lA, lB, SCALING,
|
||||
None, None, False, False, True, False,
|
||||
x,
|
||||
W,
|
||||
k,
|
||||
sei,
|
||||
ssi,
|
||||
eo,
|
||||
lA,
|
||||
lB,
|
||||
SCALING,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out.sum().backward()
|
||||
|
||||
@@ -293,7 +334,6 @@ class TestScatterMoELoRAAutograd:
|
||||
assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero"
|
||||
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
|
||||
|
||||
|
||||
def test_split_matches_fused(self):
|
||||
"""Split dispatch (for few large experts) matches fused kernel."""
|
||||
# Use a shape where split would be dispatched (large K*N, few E)
|
||||
@@ -304,15 +344,27 @@ class TestScatterMoELoRAAutograd:
|
||||
orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
|
||||
out_fused = lora_ops.scatter2scatter_lora(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
)
|
||||
|
||||
# Force split path
|
||||
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
|
||||
out_split = lora_ops.scatter2scatter_lora(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi,
|
||||
k=k, lora_A=lA, lora_B=lB, scaling=SCALING,
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=SCALING,
|
||||
)
|
||||
lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig
|
||||
|
||||
@@ -328,12 +380,28 @@ class TestScatterMoELoRAAutograd:
|
||||
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||
|
||||
out_lora = ScatterMoELoRA.apply(
|
||||
x, W, k, sei, ssi, eo,
|
||||
lA, lB, 0.0,
|
||||
None, None, False, False, True, False,
|
||||
x,
|
||||
W,
|
||||
k,
|
||||
sei,
|
||||
ssi,
|
||||
eo,
|
||||
lA,
|
||||
lB,
|
||||
0.0,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out_base = base_ops.scatter2scatter(
|
||||
X=x, W=W, sorted_expert_idxs=sei, sorted_scattered_idxs=ssi, k=k,
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=k,
|
||||
)
|
||||
err = (out_lora.float() - out_base.float()).abs().max().item()
|
||||
assert err < 0.01, f"scaling=0 should match base: err={err}"
|
||||
|
||||
Reference in New Issue
Block a user