Scattermoe LoRA optimizations (#3513)
* optimize moe + lora * more scattermoe optims * selective dequant * add correctness unit tests and benchmarks for scattermoe + lora * handle base+lora split kernel for older moe models * chore: lint * fix casting for H200 and B200 * register pressure estimation and pruning for h200/b200 * use soft limit for pruning * qkv patch for qwen3.5moe * support text_model for qwen3.5 moe * nesting of qwen3 * use udpated cce with zero3 support * Fix decomposed backward for QKV and O projections eliminates B @ A materialization in LoRA attention backward, replacing full [out, in] matmuls with two small [T, R] matmuls.
This commit is contained in:
284
benchmarks/bench_scattermoe_lora.py
Normal file
284
benchmarks/bench_scattermoe_lora.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
"""Benchmark for ScatterMoE LoRA Triton kernels.
|
||||||
|
|
||||||
|
Measures forward, backward dX, and backward dA/dB kernels at common MoE
|
||||||
|
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
|
||||||
|
and full fwd+bwd autograd throughput.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import time
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||||
|
lora_ops,
|
||||||
|
ops as base_ops,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||||
|
flatten_sort_count,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||||
|
ScatterMoELoRA,
|
||||||
|
)
|
||||||
|
|
||||||
|
DEVICE = "cuda"
|
||||||
|
DTYPE = torch.bfloat16
|
||||||
|
WARMUP = 5
|
||||||
|
ITERS = 20
|
||||||
|
|
||||||
|
# ─── Model configs ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
BUILTIN_CONFIGS = {
|
||||||
|
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||||
|
"Qwen3-30B-A3B": (128, 2048, 768, 8),
|
||||||
|
"OLMoE-1B-7B": (64, 2048, 1024, 8),
|
||||||
|
"Mixtral-8x7B": (8, 4096, 14336, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_config(spec):
|
||||||
|
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
|
||||||
|
key = spec.lower().replace("/", "-")
|
||||||
|
for name, cfg in BUILTIN_CONFIGS.items():
|
||||||
|
if key in name.lower() or name.lower() in key:
|
||||||
|
return name, cfg
|
||||||
|
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
|
||||||
|
if callable(getattr(hf_cfg, "get_text_config", None)):
|
||||||
|
tc = hf_cfg.get_text_config()
|
||||||
|
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
|
||||||
|
hf_cfg = tc
|
||||||
|
hidden = hf_cfg.hidden_size
|
||||||
|
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||||
|
experts = (
|
||||||
|
getattr(hf_cfg, "num_experts", None)
|
||||||
|
or getattr(hf_cfg, "num_local_experts", None)
|
||||||
|
or getattr(hf_cfg, "n_routed_experts", None)
|
||||||
|
)
|
||||||
|
top_k = (
|
||||||
|
getattr(hf_cfg, "num_experts_per_tok", None)
|
||||||
|
or getattr(hf_cfg, "num_experts_per_token", None)
|
||||||
|
or 2
|
||||||
|
)
|
||||||
|
name = spec.split("/")[-1]
|
||||||
|
return name, (experts, hidden, inter, top_k)
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _clean():
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def _bench(fn, warmup=WARMUP, iters=ITERS):
|
||||||
|
for _ in range(warmup):
|
||||||
|
fn()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
times = []
|
||||||
|
for _ in range(iters):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
fn()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
times.append((time.perf_counter() - t0) * 1000)
|
||||||
|
times.sort()
|
||||||
|
return times[len(times) // 2]
|
||||||
|
|
||||||
|
|
||||||
|
def _setup(num_experts, K, N, T, top_k, R):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||||
|
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||||
|
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
|
||||||
|
logits = torch.randn(T, num_experts, device=DEVICE)
|
||||||
|
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||||
|
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
|
||||||
|
gx = base_ops.group(x, ssi, fan_out=top_k)
|
||||||
|
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||||
|
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
|
||||||
|
return lora_ops.scatter2scatter_lora(
|
||||||
|
X=x,
|
||||||
|
W=W,
|
||||||
|
sorted_expert_idxs=sei,
|
||||||
|
sorted_scattered_idxs=ssi,
|
||||||
|
k=top_k,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
scaling=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _call_base(x, W, sei, ssi, top_k):
|
||||||
|
return base_ops.scatter2scatter(
|
||||||
|
X=x,
|
||||||
|
W=W,
|
||||||
|
sorted_expert_idxs=sei,
|
||||||
|
sorted_scattered_idxs=ssi,
|
||||||
|
k=top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _call_dx(dy, W, sei, ssi, lA, lB):
|
||||||
|
return lora_ops.scatter2scatter_lora_dX(
|
||||||
|
DY=dy,
|
||||||
|
W=W,
|
||||||
|
sorted_expert_idxs=sei,
|
||||||
|
sorted_scattered_idxs=ssi,
|
||||||
|
k=1,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
scaling=2.0,
|
||||||
|
dy_grouped=True,
|
||||||
|
dx_grouped=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
|
||||||
|
return lora_ops.group_bwd_lora(
|
||||||
|
DY=dy,
|
||||||
|
X=gx,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
expert_offsets=eo,
|
||||||
|
E=num_experts,
|
||||||
|
scaling=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
"-m",
|
||||||
|
nargs="+",
|
||||||
|
help="Model names or HF IDs (default: all builtins)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
|
||||||
|
parser.add_argument("--seq-len", "-T", type=int, default=2048)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
T = args.seq_len
|
||||||
|
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||||
|
print(f"T={T}, ranks={args.ranks}\n")
|
||||||
|
|
||||||
|
if args.models:
|
||||||
|
configs = [_resolve_config(m) for m in args.models]
|
||||||
|
else:
|
||||||
|
configs = list(BUILTIN_CONFIGS.items())
|
||||||
|
|
||||||
|
for model_name, (num_experts, hidden, inter, top_k) in configs:
|
||||||
|
print(f"{'=' * 70}")
|
||||||
|
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
|
||||||
|
print(f"{'=' * 70}")
|
||||||
|
|
||||||
|
for R in args.ranks:
|
||||||
|
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
|
||||||
|
_clean()
|
||||||
|
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
|
||||||
|
num_experts, K, N, T, top_k, R
|
||||||
|
)
|
||||||
|
|
||||||
|
# Forward with LoRA (auto-dispatched: fused or split)
|
||||||
|
dispatch = (
|
||||||
|
"split"
|
||||||
|
if (
|
||||||
|
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||||
|
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||||
|
)
|
||||||
|
else "fused"
|
||||||
|
)
|
||||||
|
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
|
||||||
|
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
|
||||||
|
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
|
||||||
|
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
|
||||||
|
|
||||||
|
total = t_fwd + t_dx + t_bwd
|
||||||
|
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" R={R:>2} {proj:<8} "
|
||||||
|
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||||
|
f"base={t_base:>6.2f}ms "
|
||||||
|
f"(+{overhead * 100:.0f}%) "
|
||||||
|
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||||
|
f"total={total:>6.2f}ms"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Full autograd fwd+bwd with memory measurement
|
||||||
|
x_ag = x.clone().requires_grad_(True)
|
||||||
|
lA_ag = lA.clone().requires_grad_(True)
|
||||||
|
lB_ag = lB.clone().requires_grad_(True)
|
||||||
|
|
||||||
|
def _run_autograd(
|
||||||
|
_x=x_ag,
|
||||||
|
_W=W,
|
||||||
|
_k=top_k,
|
||||||
|
_sei=sei,
|
||||||
|
_ssi=ssi,
|
||||||
|
_eo=eo,
|
||||||
|
_lA=lA_ag,
|
||||||
|
_lB=lB_ag,
|
||||||
|
):
|
||||||
|
out = ScatterMoELoRA.apply(
|
||||||
|
_x,
|
||||||
|
_W,
|
||||||
|
_k,
|
||||||
|
_sei,
|
||||||
|
_ssi,
|
||||||
|
_eo,
|
||||||
|
_lA,
|
||||||
|
_lB,
|
||||||
|
2.0,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
out.sum().backward()
|
||||||
|
_x.grad = None
|
||||||
|
_lA.grad = None
|
||||||
|
_lB.grad = None
|
||||||
|
|
||||||
|
t_full = _bench(_run_autograd)
|
||||||
|
|
||||||
|
_clean()
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
mem_before = torch.cuda.memory_allocated()
|
||||||
|
_run_autograd()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||||
|
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -40,7 +40,7 @@
|
|||||||
"%%capture\n",
|
"%%capture\n",
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe\""
|
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"'
|
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
|
|
||||||
- If you are installing from pip
|
- If you are installing from pip
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"
|
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
|
|||||||
|
|
||||||
_CCE_INSTALL_MESSAGE = (
|
_CCE_INSTALL_MESSAGE = (
|
||||||
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
|
||||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fa9a7fe"`'
|
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"`'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -195,6 +195,36 @@ def _estimate_smem_usage(
|
|||||||
_SMEM_SLACK = 10_000
|
_SMEM_SLACK = 10_000
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_register_pressure(
|
||||||
|
num_warps: int,
|
||||||
|
*tile_sizes: tuple[int, int],
|
||||||
|
) -> float:
|
||||||
|
"""Rough estimate of per-thread register footprint from live tile sizes.
|
||||||
|
|
||||||
|
This is a heuristic, NOT an accurate register count. Triton uses tensor
|
||||||
|
core MMA fragments that pack multiple elements per register, and can spill
|
||||||
|
to local memory when the hardware limit (255 regs/thread) is exceeded.
|
||||||
|
|
||||||
|
The estimate is used to prune only truly extreme configs that would cause
|
||||||
|
excessive spilling or compilation failures. The threshold is set high
|
||||||
|
(``_MAX_REGS_SOFT_LIMIT``) because the heuristic overestimates — it
|
||||||
|
doesn't account for MMA fragment packing. Configs like M=64,N=64,K=64
|
||||||
|
(est ~520) work fine in practice via spilling.
|
||||||
|
|
||||||
|
Returns estimated registers per thread.
|
||||||
|
"""
|
||||||
|
# Each thread in a warp holds ~1/32 of the tile elements
|
||||||
|
tile_regs = sum(r * c for r, c in tile_sizes) / 32
|
||||||
|
scalar_overhead = 40
|
||||||
|
return tile_regs + scalar_overhead
|
||||||
|
|
||||||
|
|
||||||
|
# Soft limit for register pressure pruning. Only prune configs with extreme
|
||||||
|
# tile products (e.g. M=128,K=256,N=256) that reliably crash on Blackwell.
|
||||||
|
# Moderate configs (M=64,N=64,K=64, est ~520) work via register spilling.
|
||||||
|
_MAX_REGS_SOFT_LIMIT = 1024
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Forward Kernel: scatter2scatter with fused LoRA
|
# Forward Kernel: scatter2scatter with fused LoRA
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -313,12 +343,11 @@ def _compute_expert_block_lora(
|
|||||||
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
|
B_blk_ptrs, mask=N_mask[:, None] & R_mask[None, :], other=0.0
|
||||||
) # [BLOCK_N, BLOCK_R]
|
) # [BLOCK_N, BLOCK_R]
|
||||||
|
|
||||||
# Cast xa_acc and b to same dtype for tl.dot (required when input is bf16/fp16)
|
# tl.dot requires non-float32 inputs (tensor cores); cast back to input dtype
|
||||||
# Both operands must match; cast to float32 (accumulator type) for precision.
|
b_inp = b.to(INPUT_DTYPE)
|
||||||
b_f32 = b.to(tl.float32)
|
|
||||||
|
|
||||||
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
|
# (X @ A^T) @ B^T: [M, R] @ [R, N] -> [M, N]
|
||||||
lora_out = tl.dot(xa_acc, tl.trans(b_f32), allow_tf32=allow_tf32)
|
lora_out = tl.dot(xa_acc.to(INPUT_DTYPE), tl.trans(b_inp), allow_tf32=allow_tf32)
|
||||||
|
|
||||||
acc += scaling * lora_out
|
acc += scaling * lora_out
|
||||||
return acc
|
return acc
|
||||||
@@ -327,20 +356,21 @@ def _compute_expert_block_lora(
|
|||||||
def _scatter2scatter_lora_configs():
|
def _scatter2scatter_lora_configs():
|
||||||
"""Generate forward kernel autotune configs.
|
"""Generate forward kernel autotune configs.
|
||||||
|
|
||||||
Search space includes smaller tile sizes and fewer pipeline stages to
|
Search space includes BLOCK_M to allow trading token-tile size for
|
||||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
larger BLOCK_K/BLOCK_N tiles. On GPUs with ~99KB SMEM, BLOCK_M=128
|
||||||
|
forces BLOCK_K=32 and BLOCK_N=32; BLOCK_M=64 allows BLOCK_K=128
|
||||||
|
(4× fewer inner-loop iterations).
|
||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
|
BLOCK_M: {32, 64, 128}
|
||||||
BLOCK_N: {32, 64, 128, 256}
|
BLOCK_N: {32, 64, 128, 256}
|
||||||
BLOCK_K: {32, 64, 128}
|
BLOCK_K: {32, 64, 128}
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
|
|
||||||
BLOCK_M is fixed at 128 (module-level constant, not autotuned in the
|
|
||||||
scatter2scatter pattern).
|
|
||||||
"""
|
"""
|
||||||
configs = []
|
configs = []
|
||||||
for block_n, block_k, warps, stages in product(
|
for block_m, block_n, block_k, warps, stages in product(
|
||||||
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64, 128, 256], # BLOCK_N
|
[32, 64, 128, 256], # BLOCK_N
|
||||||
[32, 64, 128], # BLOCK_K
|
[32, 64, 128], # BLOCK_K
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
@@ -348,7 +378,7 @@ def _scatter2scatter_lora_configs():
|
|||||||
):
|
):
|
||||||
configs.append(
|
configs.append(
|
||||||
triton.Config(
|
triton.Config(
|
||||||
{"BLOCK_N": block_n, "BLOCK_K": block_k},
|
{"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
|
||||||
num_stages=stages,
|
num_stages=stages,
|
||||||
num_warps=warps,
|
num_warps=warps,
|
||||||
)
|
)
|
||||||
@@ -357,7 +387,7 @@ def _scatter2scatter_lora_configs():
|
|||||||
|
|
||||||
|
|
||||||
def _prune_fwd_configs(configs, named_args, **kwargs):
|
def _prune_fwd_configs(configs, named_args, **kwargs):
|
||||||
"""Prune forward configs based on SMEM capacity.
|
"""Prune forward configs based on SMEM capacity and register pressure.
|
||||||
|
|
||||||
The forward kernel inner loop loads three tiles per pipeline stage:
|
The forward kernel inner loop loads three tiles per pipeline stage:
|
||||||
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
|
X[BLOCK_M, BLOCK_K], W[BLOCK_K, BLOCK_N], A[BLOCK_R, BLOCK_K].
|
||||||
@@ -373,23 +403,49 @@ def _prune_fwd_configs(configs, named_args, **kwargs):
|
|||||||
|
|
||||||
scored = []
|
scored = []
|
||||||
for config in configs:
|
for config in configs:
|
||||||
|
block_m = config.kwargs["BLOCK_M"]
|
||||||
block_n = config.kwargs["BLOCK_N"]
|
block_n = config.kwargs["BLOCK_N"]
|
||||||
block_k = config.kwargs["BLOCK_K"]
|
block_k = config.kwargs["BLOCK_K"]
|
||||||
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
|
# Base: stages * BLOCK_K * (BLOCK_M + BLOCK_N) + BLOCK_M * BLOCK_N
|
||||||
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_n, block_k)
|
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_n, block_k)
|
||||||
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
|
# A tile [BLOCK_R, BLOCK_K] loaded per stage in the inner loop
|
||||||
smem_lora_loop = config.num_stages * block_r * block_k * 2
|
smem_lora_loop = config.num_stages * block_r * block_k * 2
|
||||||
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
|
# B tile [BLOCK_N, BLOCK_R] loaded once in epilogue
|
||||||
smem_lora_epilogue = block_n * block_r * 2
|
smem_lora_epilogue = block_n * block_r * 2
|
||||||
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
||||||
|
|
||||||
|
# Register pressure: live tiles are acc[M,N], xa_acc[M,R],
|
||||||
|
# x[M,K], w[K,N], a[R,K], plus epilogue b[N,R]
|
||||||
|
est_regs = _estimate_register_pressure(
|
||||||
|
config.num_warps,
|
||||||
|
(block_m, block_n), # acc
|
||||||
|
(block_m, block_r), # xa_acc
|
||||||
|
(block_m, block_k), # x tile
|
||||||
|
(block_k, block_n), # w tile
|
||||||
|
(block_r, block_k), # a tile
|
||||||
|
(block_n, block_r), # b tile (epilogue)
|
||||||
|
)
|
||||||
|
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||||
|
continue
|
||||||
|
|
||||||
scored.append((smem, config))
|
scored.append((smem, config))
|
||||||
|
|
||||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||||
if pruned:
|
if pruned:
|
||||||
return pruned
|
return pruned
|
||||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
if scored:
|
||||||
scored.sort(key=lambda x: x[0])
|
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||||
return [scored[0][1]]
|
scored.sort(key=lambda x: x[0])
|
||||||
|
return [scored[0][1]]
|
||||||
|
# All configs pruned by register pressure — fall back to smallest tiles
|
||||||
|
return [
|
||||||
|
min(
|
||||||
|
configs,
|
||||||
|
key=lambda c: (
|
||||||
|
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_N"] * c.kwargs["BLOCK_K"]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
@@ -531,6 +587,89 @@ def _scatter2scatter_lora(
|
|||||||
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
|
||||||
|
|
||||||
|
|
||||||
|
def _scatter2scatter_lora_split(
|
||||||
|
X: torch.Tensor,
|
||||||
|
W: torch.Tensor,
|
||||||
|
sorted_expert_idxs: torch.Tensor,
|
||||||
|
sorted_scattered_idxs: torch.Tensor,
|
||||||
|
k: int,
|
||||||
|
lora_A: torch.Tensor,
|
||||||
|
lora_B: torch.Tensor,
|
||||||
|
scaling: float,
|
||||||
|
b: Optional[torch.Tensor] = None,
|
||||||
|
x_grouped: bool = False,
|
||||||
|
y_grouped: bool = False,
|
||||||
|
out: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Split base+LoRA forward: 3 scatter2scatter calls, no fused LoRA kernel.
|
||||||
|
|
||||||
|
Faster for models with few large experts (e.g. Mixtral E=8, I=14336)
|
||||||
|
because the base kernel runs at full speed without LoRA SMEM overhead,
|
||||||
|
and the LoRA matmuls (R=16) are tiny separate passes.
|
||||||
|
|
||||||
|
Y = scatter(X, W) + scaling * scatter(scatter(X, A^T), B^T)
|
||||||
|
"""
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels.ops import (
|
||||||
|
scatter2scatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
E = W.size(0)
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
K = W.size(1)
|
||||||
|
N = W.size(2)
|
||||||
|
|
||||||
|
# 1. Base: Y_base = X @ W (uses base kernel with optimal tile sizes)
|
||||||
|
output = scatter2scatter(
|
||||||
|
X=X,
|
||||||
|
W=W,
|
||||||
|
b=b,
|
||||||
|
sorted_expert_idxs=sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||||
|
k=k,
|
||||||
|
x_grouped=x_grouped,
|
||||||
|
y_grouped=y_grouped,
|
||||||
|
out=out,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. XA = X @ A^T (tiny: output is [M*k, R])
|
||||||
|
# Reshape A: [R*E, K] → [E, K, R] (expert weights for scatter2scatter)
|
||||||
|
W_A = lora_A.reshape(E, R, K).permute(0, 2, 1).contiguous()
|
||||||
|
XA = scatter2scatter(
|
||||||
|
X=X,
|
||||||
|
W=W_A,
|
||||||
|
sorted_expert_idxs=sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||||
|
k=k,
|
||||||
|
x_grouped=x_grouped,
|
||||||
|
y_grouped=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Y_lora = XA @ B^T (R is tiny, so this is very fast)
|
||||||
|
# Reshape B: [N, R*E] → [E, R, N]
|
||||||
|
W_B = lora_B.T.reshape(E, R, N).contiguous()
|
||||||
|
Y_lora = scatter2scatter(
|
||||||
|
X=XA,
|
||||||
|
W=W_B,
|
||||||
|
sorted_expert_idxs=sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs=sorted_scattered_idxs,
|
||||||
|
k=1,
|
||||||
|
x_grouped=True,
|
||||||
|
y_grouped=y_grouped,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Y = Y_base + scaling * Y_lora
|
||||||
|
output.add_(Y_lora, alpha=scaling)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
# Threshold for switching from fused to split LoRA forward.
|
||||||
|
# Split wins when per-expert matmul is large (bandwidth-bound LoRA tile
|
||||||
|
# loads dominate in the fused kernel's inner loop).
|
||||||
|
# Empirically: split wins for E<=32 with K*N > 20M (e.g. Mixtral, Phi-MoE).
|
||||||
|
_SPLIT_LORA_FWD_THRESHOLD = 20_000_000 # per-expert K*N
|
||||||
|
_SPLIT_LORA_FWD_MAX_EXPERTS = 32
|
||||||
|
|
||||||
|
|
||||||
def scatter2scatter_lora(
|
def scatter2scatter_lora(
|
||||||
X: torch.Tensor,
|
X: torch.Tensor,
|
||||||
W: torch.Tensor,
|
W: torch.Tensor,
|
||||||
@@ -546,7 +685,13 @@ def scatter2scatter_lora(
|
|||||||
out: Optional[torch.Tensor] = None,
|
out: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Fused scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
Scatter2scatter with LoRA: Y[i] = X[i] @ W[e] + scaling * (X[i] @ A[e]^T) @ B[e]^T + b[e]
|
||||||
|
|
||||||
|
Automatically selects between:
|
||||||
|
- Fused kernel: single Triton kernel with LoRA in the inner loop.
|
||||||
|
Best for many small experts (E>=64, small K*N).
|
||||||
|
- Split dispatch: 3 separate scatter2scatter calls (base + XA + lora).
|
||||||
|
Best for few large experts (E<=32, large K*N like Mixtral).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
X: Input [M, K] or [M*k, K] if x_grouped
|
X: Input [M, K] or [M*k, K] if x_grouped
|
||||||
@@ -565,12 +710,30 @@ def scatter2scatter_lora(
|
|||||||
Returns:
|
Returns:
|
||||||
Y: Output [M*k, N]
|
Y: Output [M*k, N]
|
||||||
"""
|
"""
|
||||||
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
|
||||||
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
|
||||||
|
|
||||||
E = W.size(0)
|
E = W.size(0)
|
||||||
K = W.size(1)
|
K = W.size(1)
|
||||||
N = W.size(2)
|
N = W.size(2)
|
||||||
|
|
||||||
|
# Dispatch: split for few large experts, fused for many small experts
|
||||||
|
if E <= _SPLIT_LORA_FWD_MAX_EXPERTS and K * N >= _SPLIT_LORA_FWD_THRESHOLD:
|
||||||
|
return _scatter2scatter_lora_split(
|
||||||
|
X,
|
||||||
|
W,
|
||||||
|
sorted_expert_idxs,
|
||||||
|
sorted_scattered_idxs,
|
||||||
|
k,
|
||||||
|
lora_A,
|
||||||
|
lora_B,
|
||||||
|
scaling,
|
||||||
|
b,
|
||||||
|
x_grouped,
|
||||||
|
y_grouped,
|
||||||
|
out,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
|
||||||
|
assert sorted_scattered_idxs.size(0) == X.size(0) * k
|
||||||
|
|
||||||
R = lora_A.size(0) // E
|
R = lora_A.size(0) // E
|
||||||
|
|
||||||
# Pad R to power of 2 for Triton tile size
|
# Pad R to power of 2 for Triton tile size
|
||||||
@@ -610,11 +773,9 @@ def scatter2scatter_lora(
|
|||||||
b_ptr,
|
b_ptr,
|
||||||
stride_be,
|
stride_be,
|
||||||
stride_bn,
|
stride_bn,
|
||||||
# A: [r*E, K] -> stride(0) is r*E dim stride, stride(1) is K dim stride
|
|
||||||
lora_A,
|
lora_A,
|
||||||
lora_A.stride(0),
|
lora_A.stride(0),
|
||||||
lora_A.stride(1),
|
lora_A.stride(1),
|
||||||
# B: [N, r*E] -> stride(0) is N dim stride, stride(1) is r*E dim stride
|
|
||||||
lora_B,
|
lora_B,
|
||||||
lora_B.stride(0),
|
lora_B.stride(0),
|
||||||
lora_B.stride(1),
|
lora_B.stride(1),
|
||||||
@@ -625,9 +786,8 @@ def scatter2scatter_lora(
|
|||||||
K=K,
|
K=K,
|
||||||
N=N,
|
N=N,
|
||||||
E=E,
|
E=E,
|
||||||
ACTUAL_R=R, # True LoRA rank for weight indexing
|
ACTUAL_R=R,
|
||||||
BLOCK_M=BLOCK_M,
|
BLOCK_R=BLOCK_R,
|
||||||
BLOCK_R=BLOCK_R, # Padded tile size >= max(R, 16)
|
|
||||||
ACC_TYPE=tl.float32,
|
ACC_TYPE=tl.float32,
|
||||||
scaling=scaling,
|
scaling=scaling,
|
||||||
allow_tf32=ALLOW_TF32,
|
allow_tf32=ALLOW_TF32,
|
||||||
@@ -761,13 +921,13 @@ def _compute_expert_block_lora_dX(
|
|||||||
+ (A_expert_offset + R_block)[:, None] * stride_ar
|
+ (A_expert_offset + R_block)[:, None] * stride_ar
|
||||||
+ K_block[None, :] * stride_ak
|
+ K_block[None, :] * stride_ak
|
||||||
)
|
)
|
||||||
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0)
|
a_e = tl.load(A_blk_ptrs, mask=R_mask[:, None] & K_mask[None, :], other=0.0).to(
|
||||||
|
INPUT_DTYPE
|
||||||
# Cast to float32 for precision
|
)
|
||||||
a_f32 = a_e.to(tl.float32)
|
|
||||||
|
|
||||||
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
|
# (DY @ B) @ A: [M, R] @ [R, K] -> [M, K]
|
||||||
lora_dx = tl.dot(dy_b_acc, a_f32, allow_tf32=allow_tf32)
|
# tl.dot requires non-float32 inputs (tensor cores); cast accumulator back to input dtype
|
||||||
|
lora_dx = tl.dot(dy_b_acc.to(INPUT_DTYPE), a_e, allow_tf32=allow_tf32)
|
||||||
|
|
||||||
acc += scaling * lora_dx
|
acc += scaling * lora_dx
|
||||||
return acc
|
return acc
|
||||||
@@ -779,17 +939,18 @@ def _scatter2scatter_lora_dX_configs():
|
|||||||
The inner loop is over N (not K as in forward). The output dimension is K.
|
The inner loop is over N (not K as in forward). The output dimension is K.
|
||||||
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
|
So BLOCK_K tiles the output and BLOCK_N tiles the reduction.
|
||||||
|
|
||||||
Search space includes smaller tile sizes and fewer pipeline stages to
|
BLOCK_M is now autotunable (was fixed at 128).
|
||||||
support GPUs with limited shared memory (e.g. ~99KB on some GPUs).
|
|
||||||
|
|
||||||
Search space:
|
Search space:
|
||||||
|
BLOCK_M: {32, 64, 128} (token tile)
|
||||||
BLOCK_K: {32, 64, 128, 256} (output tile)
|
BLOCK_K: {32, 64, 128, 256} (output tile)
|
||||||
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
BLOCK_N: {32, 64, 128, 256} (reduction tile)
|
||||||
num_warps: {4, 8}
|
num_warps: {4, 8}
|
||||||
num_stages: {3, 4, 5}
|
num_stages: {3, 4, 5}
|
||||||
"""
|
"""
|
||||||
configs = []
|
configs = []
|
||||||
for block_k, block_n, warps, stages in product(
|
for block_m, block_k, block_n, warps, stages in product(
|
||||||
|
[32, 64, 128], # BLOCK_M
|
||||||
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
[32, 64, 128, 256], # BLOCK_K (output dimension)
|
||||||
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
[32, 64, 128, 256], # BLOCK_N (reduction dimension)
|
||||||
[4, 8], # num_warps
|
[4, 8], # num_warps
|
||||||
@@ -797,7 +958,7 @@ def _scatter2scatter_lora_dX_configs():
|
|||||||
):
|
):
|
||||||
configs.append(
|
configs.append(
|
||||||
triton.Config(
|
triton.Config(
|
||||||
{"BLOCK_K": block_k, "BLOCK_N": block_n},
|
{"BLOCK_M": block_m, "BLOCK_K": block_k, "BLOCK_N": block_n},
|
||||||
num_stages=stages,
|
num_stages=stages,
|
||||||
num_warps=warps,
|
num_warps=warps,
|
||||||
)
|
)
|
||||||
@@ -806,7 +967,7 @@ def _scatter2scatter_lora_dX_configs():
|
|||||||
|
|
||||||
|
|
||||||
def _prune_dX_configs(configs, named_args, **kwargs):
|
def _prune_dX_configs(configs, named_args, **kwargs):
|
||||||
"""Prune backward dX configs based on SMEM capacity.
|
"""Prune backward dX configs based on SMEM capacity and register pressure.
|
||||||
|
|
||||||
The dX kernel inner loop loads three tiles per pipeline stage:
|
The dX kernel inner loop loads three tiles per pipeline stage:
|
||||||
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
|
DY[BLOCK_M, BLOCK_N], W^T[BLOCK_N, BLOCK_K], B[BLOCK_N, BLOCK_R].
|
||||||
@@ -822,23 +983,49 @@ def _prune_dX_configs(configs, named_args, **kwargs):
|
|||||||
|
|
||||||
scored = []
|
scored = []
|
||||||
for config in configs:
|
for config in configs:
|
||||||
|
block_m = config.kwargs["BLOCK_M"]
|
||||||
block_k = config.kwargs["BLOCK_K"]
|
block_k = config.kwargs["BLOCK_K"]
|
||||||
block_n = config.kwargs["BLOCK_N"]
|
block_n = config.kwargs["BLOCK_N"]
|
||||||
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
|
# Base: stages * BLOCK_N * (BLOCK_M + BLOCK_K) + BLOCK_M * BLOCK_K
|
||||||
smem_base = _estimate_smem_usage(config.num_stages, BLOCK_M, block_k, block_n)
|
smem_base = _estimate_smem_usage(config.num_stages, block_m, block_k, block_n)
|
||||||
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
|
# B tile [BLOCK_N, BLOCK_R] loaded per stage in the inner loop
|
||||||
smem_lora_loop = config.num_stages * block_n * block_r * 2
|
smem_lora_loop = config.num_stages * block_n * block_r * 2
|
||||||
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
|
# A tile [BLOCK_R, BLOCK_K] loaded once in epilogue
|
||||||
smem_lora_epilogue = block_r * block_k * 2
|
smem_lora_epilogue = block_r * block_k * 2
|
||||||
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
smem = smem_base + smem_lora_loop + smem_lora_epilogue
|
||||||
|
|
||||||
|
# Register pressure: live tiles are acc[M,K], dy_b_acc[M,R],
|
||||||
|
# dy[M,N], wt[N,K], b[N,R], plus epilogue a[R,K]
|
||||||
|
est_regs = _estimate_register_pressure(
|
||||||
|
config.num_warps,
|
||||||
|
(block_m, block_k), # acc
|
||||||
|
(block_m, block_r), # dy_b_acc
|
||||||
|
(block_m, block_n), # dy tile
|
||||||
|
(block_n, block_k), # wt tile
|
||||||
|
(block_n, block_r), # b tile
|
||||||
|
(block_r, block_k), # a tile (epilogue)
|
||||||
|
)
|
||||||
|
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||||
|
continue
|
||||||
|
|
||||||
scored.append((smem, config))
|
scored.append((smem, config))
|
||||||
|
|
||||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||||
if pruned:
|
if pruned:
|
||||||
return pruned
|
return pruned
|
||||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
if scored:
|
||||||
scored.sort(key=lambda x: x[0])
|
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||||
return [scored[0][1]]
|
scored.sort(key=lambda x: x[0])
|
||||||
|
return [scored[0][1]]
|
||||||
|
# All configs pruned by register pressure — fall back to smallest tiles
|
||||||
|
return [
|
||||||
|
min(
|
||||||
|
configs,
|
||||||
|
key=lambda c: (
|
||||||
|
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
@@ -1067,7 +1254,7 @@ def scatter2scatter_lora_dX(
|
|||||||
N=N,
|
N=N,
|
||||||
E=E,
|
E=E,
|
||||||
ACTUAL_R=R,
|
ACTUAL_R=R,
|
||||||
BLOCK_M=BLOCK_M,
|
# BLOCK_M is autotuned (injected by triton.autotune from Config kwargs)
|
||||||
BLOCK_R=BLOCK_R,
|
BLOCK_R=BLOCK_R,
|
||||||
ACC_TYPE=tl.float32,
|
ACC_TYPE=tl.float32,
|
||||||
scaling=scaling,
|
scaling=scaling,
|
||||||
@@ -1119,7 +1306,7 @@ def _group_bwd_lora_configs():
|
|||||||
|
|
||||||
|
|
||||||
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
||||||
"""Prune backward configs based on SMEM capacity.
|
"""Prune backward configs based on SMEM capacity and register pressure.
|
||||||
|
|
||||||
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
|
The backward kernel loads X[BLOCK_M, BLOCK_K] and DY[BLOCK_M, BLOCK_N]
|
||||||
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
|
in the inner loop, plus holds A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R]
|
||||||
@@ -1138,14 +1325,40 @@ def _prune_bwd_lora_configs(configs, named_args, **kwargs):
|
|||||||
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
|
# A[BLOCK_R, BLOCK_K] and B[BLOCK_N, BLOCK_R] held for the full expert
|
||||||
smem_lora = (block_r * block_k + block_n * block_r) * 2
|
smem_lora = (block_r * block_k + block_n * block_r) * 2
|
||||||
smem = smem_base + smem_lora
|
smem = smem_base + smem_lora
|
||||||
|
|
||||||
|
# Register pressure: dA_acc[R,K], dB_acc[N,R], x[M,K], dy[M,N],
|
||||||
|
# a[R,K], b[N,R], xa[M,R], dy_b[M,R]
|
||||||
|
est_regs = _estimate_register_pressure(
|
||||||
|
config.num_warps,
|
||||||
|
(block_r, block_k), # dA_acc
|
||||||
|
(block_n, block_r), # dB_acc
|
||||||
|
(block_m, block_k), # x tile
|
||||||
|
(block_m, block_n), # dy tile
|
||||||
|
(block_r, block_k), # a tile
|
||||||
|
(block_n, block_r), # b tile
|
||||||
|
(block_m, block_r), # xa intermediate
|
||||||
|
)
|
||||||
|
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||||
|
continue
|
||||||
|
|
||||||
scored.append((smem, config))
|
scored.append((smem, config))
|
||||||
|
|
||||||
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
pruned = [c for s, c in scored if s <= smem_cap - _SMEM_SLACK]
|
||||||
if pruned:
|
if pruned:
|
||||||
return pruned
|
return pruned
|
||||||
# All configs exceed SMEM — return the one with smallest estimated usage
|
if scored:
|
||||||
scored.sort(key=lambda x: x[0])
|
# All surviving configs exceed SMEM — return the one with smallest usage
|
||||||
return [scored[0][1]]
|
scored.sort(key=lambda x: x[0])
|
||||||
|
return [scored[0][1]]
|
||||||
|
# All configs pruned by register pressure — fall back to smallest tiles
|
||||||
|
return [
|
||||||
|
min(
|
||||||
|
configs,
|
||||||
|
key=lambda c: (
|
||||||
|
c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_K"] * c.kwargs["BLOCK_N"]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@triton.autotune(
|
@triton.autotune(
|
||||||
@@ -1330,6 +1543,279 @@ def _group_bwd_lora(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _group_bwd_split_configs():
|
||||||
|
"""Autotune configs for split dA/dB kernels."""
|
||||||
|
configs = []
|
||||||
|
for block_m, block_dim, warps, stages in product(
|
||||||
|
[32, 64, 128], # BLOCK_M (token tile)
|
||||||
|
[32, 64, 128, 256], # BLOCK_DIM (K for dA, N for dB — output tile)
|
||||||
|
[4, 8], # num_warps
|
||||||
|
[3, 4, 5], # num_stages
|
||||||
|
):
|
||||||
|
configs.append(
|
||||||
|
triton.Config(
|
||||||
|
{"BLOCK_M": block_m, "BLOCK_DIM": block_dim},
|
||||||
|
num_stages=stages,
|
||||||
|
num_warps=warps,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_split_configs(configs, named_args, **kwargs):
|
||||||
|
"""Prune split kernel configs based on SMEM capacity and register pressure."""
|
||||||
|
smem_cap = _get_smem_capacity()
|
||||||
|
block_r = named_args.get("BLOCK_R", 64)
|
||||||
|
|
||||||
|
# Fixed inner tile for reduction dimension
|
||||||
|
BLOCK_INNER = 64
|
||||||
|
|
||||||
|
pruned = []
|
||||||
|
for config in configs:
|
||||||
|
block_m = config.kwargs["BLOCK_M"]
|
||||||
|
block_dim = config.kwargs["BLOCK_DIM"]
|
||||||
|
# Inner loop loads: input[M, INNER] and other[M, INNER_or_DIM]
|
||||||
|
smem = config.num_stages * BLOCK_INNER * (block_m + block_dim) * 2
|
||||||
|
# LoRA weights held in registers: [INNER, R] or [R, DIM]
|
||||||
|
smem += (block_r * max(block_dim, BLOCK_INNER)) * 2
|
||||||
|
|
||||||
|
# Register pressure check
|
||||||
|
est_regs = _estimate_register_pressure(
|
||||||
|
config.num_warps,
|
||||||
|
(block_r, block_dim), # acc
|
||||||
|
(block_m, BLOCK_INNER), # input tile
|
||||||
|
(block_m, block_dim), # other tile
|
||||||
|
(block_r, BLOCK_INNER), # lora weight
|
||||||
|
)
|
||||||
|
if est_regs > _MAX_REGS_SOFT_LIMIT:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if smem <= smem_cap - _SMEM_SLACK:
|
||||||
|
pruned.append(config)
|
||||||
|
|
||||||
|
if pruned:
|
||||||
|
return pruned
|
||||||
|
configs.sort(key=lambda c: c.kwargs["BLOCK_M"] * c.kwargs["BLOCK_DIM"])
|
||||||
|
return [configs[0]]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.autotune(
|
||||||
|
configs=_group_bwd_split_configs(),
|
||||||
|
key=["M", "K", "N"],
|
||||||
|
prune_configs_by={"early_config_prune": _prune_split_configs},
|
||||||
|
)
|
||||||
|
@triton.heuristics(
|
||||||
|
{
|
||||||
|
"NO_DIM_MASK": lambda args: (
|
||||||
|
(args["K"] % args["BLOCK_DIM"]) == 0
|
||||||
|
if args["COMPUTE_DA"]
|
||||||
|
else (args["N"] % args["BLOCK_DIM"]) == 0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def _group_bwd_lora_split(
|
||||||
|
# Data tensors (DY and X are always present)
|
||||||
|
DY_ptr,
|
||||||
|
stride_dym,
|
||||||
|
stride_dyn,
|
||||||
|
X_ptr,
|
||||||
|
stride_xm,
|
||||||
|
stride_xk,
|
||||||
|
# LoRA weight for the inner reduction (B for dA, A for dB)
|
||||||
|
LW_ptr,
|
||||||
|
stride_lw0,
|
||||||
|
stride_lw1,
|
||||||
|
# Output gradient tensor (dA or dB)
|
||||||
|
OUT_ptr,
|
||||||
|
stride_out0,
|
||||||
|
stride_out1,
|
||||||
|
# Expert offsets
|
||||||
|
expert_offsets_ptr,
|
||||||
|
# Dimensions
|
||||||
|
M,
|
||||||
|
K: tl.constexpr,
|
||||||
|
N: tl.constexpr,
|
||||||
|
ACTUAL_R: tl.constexpr,
|
||||||
|
BLOCK_R: tl.constexpr,
|
||||||
|
INNER_DIM: tl.constexpr, # reduction dimension (N for dA, K for dB)
|
||||||
|
scaling,
|
||||||
|
# Mode flag
|
||||||
|
COMPUTE_DA: tl.constexpr, # True = compute dA, False = compute dB
|
||||||
|
# Tile sizes
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_DIM: tl.constexpr,
|
||||||
|
ACC_TYPE: tl.constexpr,
|
||||||
|
allow_tf32: tl.constexpr,
|
||||||
|
NO_DIM_MASK: tl.constexpr,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Unified split kernel for LoRA gradient computation.
|
||||||
|
|
||||||
|
When COMPUTE_DA=True:
|
||||||
|
dA[e] = scaling * (dY @ B[e])^T @ X → [R, K]
|
||||||
|
Grid: (E, cdiv(K, BLOCK_DIM))
|
||||||
|
- outer_ptr/stride = X (read [M, K_block])
|
||||||
|
- inner reduction over N using DY and B
|
||||||
|
- output shape [BLOCK_R, BLOCK_DIM]
|
||||||
|
|
||||||
|
When COMPUTE_DA=False:
|
||||||
|
dB[e] = scaling * dY^T @ (X @ A[e]^T) → [N, R]
|
||||||
|
Grid: (E, cdiv(N, BLOCK_DIM))
|
||||||
|
- outer_ptr/stride = DY (read [M, N_block])
|
||||||
|
- inner reduction over K using X and A
|
||||||
|
- output shape [BLOCK_DIM, BLOCK_R]
|
||||||
|
|
||||||
|
No atomic adds — each (E, dim_block) pair is written by exactly one block.
|
||||||
|
"""
|
||||||
|
E_idx = tl.program_id(0)
|
||||||
|
dim_block_id = tl.program_id(1)
|
||||||
|
|
||||||
|
if E_idx == 0:
|
||||||
|
start_idx = 0
|
||||||
|
else:
|
||||||
|
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
|
||||||
|
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
|
||||||
|
num_tokens = end_idx - start_idx
|
||||||
|
|
||||||
|
# Output dimension tile (K for dA, N for dB)
|
||||||
|
if COMPUTE_DA:
|
||||||
|
OUT_DIM: tl.constexpr = K # type: ignore[no-redef]
|
||||||
|
else:
|
||||||
|
OUT_DIM: tl.constexpr = N # type: ignore[no-redef]
|
||||||
|
dim_block = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
|
||||||
|
dim_mask = dim_block < OUT_DIM
|
||||||
|
R_block = tl.arange(0, BLOCK_R)
|
||||||
|
R_mask = R_block < ACTUAL_R
|
||||||
|
lora_offset = E_idx * ACTUAL_R
|
||||||
|
|
||||||
|
# Output pointers — layout differs: dA is [R, K], dB is [N, R]
|
||||||
|
if COMPUTE_DA:
|
||||||
|
out_blk_ptrs = (
|
||||||
|
OUT_ptr
|
||||||
|
+ (lora_offset + R_block)[:, None] * stride_out0
|
||||||
|
+ dim_block[None, :] * stride_out1
|
||||||
|
)
|
||||||
|
out_mask = R_mask[:, None] & dim_mask[None, :]
|
||||||
|
else:
|
||||||
|
out_blk_ptrs = (
|
||||||
|
OUT_ptr
|
||||||
|
+ dim_block[:, None] * stride_out0
|
||||||
|
+ (lora_offset + R_block)[None, :] * stride_out1
|
||||||
|
)
|
||||||
|
out_mask = dim_mask[:, None] & R_mask[None, :]
|
||||||
|
|
||||||
|
if num_tokens > 0:
|
||||||
|
M_block = tl.arange(0, BLOCK_M)
|
||||||
|
INPUT_DTYPE = X_ptr.dtype.element_ty
|
||||||
|
BLOCK_INNER: tl.constexpr = 64
|
||||||
|
inner_iters = tl.cdiv(INNER_DIM, BLOCK_INNER)
|
||||||
|
|
||||||
|
if COMPUTE_DA:
|
||||||
|
acc = tl.zeros((BLOCK_R, BLOCK_DIM), dtype=ACC_TYPE)
|
||||||
|
else:
|
||||||
|
acc = tl.zeros((BLOCK_DIM, BLOCK_R), dtype=ACC_TYPE)
|
||||||
|
|
||||||
|
M_iters = tl.cdiv(num_tokens, BLOCK_M)
|
||||||
|
for i in range(M_iters):
|
||||||
|
M_idx = start_idx + i * BLOCK_M + M_block
|
||||||
|
M_mask = M_idx < end_idx
|
||||||
|
|
||||||
|
if COMPUTE_DA:
|
||||||
|
# Load X[M, K_block] (the "outer" tensor for dA)
|
||||||
|
outer = tl.load(
|
||||||
|
X_ptr + M_idx[:, None] * stride_xm + dim_block[None, :] * stride_xk,
|
||||||
|
mask=M_mask[:, None] & dim_mask[None, :],
|
||||||
|
other=0.0,
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
|
||||||
|
# Reduce DY[M, :] @ B[e][:, R] over N → [M, R]
|
||||||
|
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
|
||||||
|
inner_range = tl.arange(0, BLOCK_INNER)
|
||||||
|
for j in range(inner_iters):
|
||||||
|
inn_off = j * BLOCK_INNER + inner_range
|
||||||
|
inn_mask = inn_off < N
|
||||||
|
|
||||||
|
dy_tile = tl.load(
|
||||||
|
DY_ptr
|
||||||
|
+ M_idx[:, None] * stride_dym
|
||||||
|
+ inn_off[None, :] * stride_dyn,
|
||||||
|
mask=M_mask[:, None] & inn_mask[None, :],
|
||||||
|
other=0.0,
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
# B layout: [N, r*E] → stride_lw0=N stride, stride_lw1=r*E stride
|
||||||
|
lw_tile = tl.load(
|
||||||
|
LW_ptr
|
||||||
|
+ inn_off[:, None] * stride_lw0
|
||||||
|
+ (lora_offset + R_block)[None, :] * stride_lw1,
|
||||||
|
mask=inn_mask[:, None] & R_mask[None, :],
|
||||||
|
other=0.0,
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
reduced += tl.dot(dy_tile, lw_tile, allow_tf32=allow_tf32)
|
||||||
|
|
||||||
|
# dA += (DY@B)^T @ X: [R, M] @ [M, K_block] → [R, K_block]
|
||||||
|
acc += tl.dot(
|
||||||
|
tl.trans(reduced.to(INPUT_DTYPE)), outer, allow_tf32=allow_tf32
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Load DY[M, N_block] (the "outer" tensor for dB)
|
||||||
|
outer = tl.load(
|
||||||
|
DY_ptr
|
||||||
|
+ M_idx[:, None] * stride_dym
|
||||||
|
+ dim_block[None, :] * stride_dyn,
|
||||||
|
mask=M_mask[:, None] & dim_mask[None, :],
|
||||||
|
other=0.0,
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
|
||||||
|
# Reduce X[M, :] @ A[e][:, :].T over K → [M, R]
|
||||||
|
reduced = tl.zeros((BLOCK_M, BLOCK_R), dtype=ACC_TYPE)
|
||||||
|
inner_range = tl.arange(0, BLOCK_INNER)
|
||||||
|
for j in range(inner_iters):
|
||||||
|
inn_off = j * BLOCK_INNER + inner_range
|
||||||
|
inn_mask = inn_off < K
|
||||||
|
|
||||||
|
x_tile = tl.load(
|
||||||
|
X_ptr
|
||||||
|
+ M_idx[:, None] * stride_xm
|
||||||
|
+ inn_off[None, :] * stride_xk,
|
||||||
|
mask=M_mask[:, None] & inn_mask[None, :],
|
||||||
|
other=0.0,
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
# A layout: [r*E, K] → stride_lw0=r*E stride, stride_lw1=K stride
|
||||||
|
# We want A[e]^T: [K, R], so load as [K_inner, R]
|
||||||
|
lw_tile = tl.load(
|
||||||
|
LW_ptr
|
||||||
|
+ (lora_offset + R_block)[None, :] * stride_lw0
|
||||||
|
+ inn_off[:, None] * stride_lw1,
|
||||||
|
mask=inn_mask[:, None] & R_mask[None, :],
|
||||||
|
other=0.0,
|
||||||
|
).to(INPUT_DTYPE)
|
||||||
|
reduced += tl.dot(x_tile, lw_tile, allow_tf32=allow_tf32)
|
||||||
|
|
||||||
|
# dB += DY^T @ (X@A^T): [N_block, M] @ [M, R] → [N_block, R]
|
||||||
|
acc += tl.dot(
|
||||||
|
tl.trans(outer), reduced.to(INPUT_DTYPE), allow_tf32=allow_tf32
|
||||||
|
)
|
||||||
|
|
||||||
|
tl.store(
|
||||||
|
out_blk_ptrs, (acc * scaling).to(OUT_ptr.dtype.element_ty), mask=out_mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Zero out this expert's slice — needed because output uses empty_like
|
||||||
|
if COMPUTE_DA:
|
||||||
|
tl.store(
|
||||||
|
out_blk_ptrs,
|
||||||
|
tl.zeros((BLOCK_R, BLOCK_DIM), dtype=OUT_ptr.dtype.element_ty),
|
||||||
|
mask=out_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
tl.store(
|
||||||
|
out_blk_ptrs,
|
||||||
|
tl.zeros((BLOCK_DIM, BLOCK_R), dtype=OUT_ptr.dtype.element_ty),
|
||||||
|
mask=out_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def group_bwd_lora(
|
def group_bwd_lora(
|
||||||
DY: torch.Tensor,
|
DY: torch.Tensor,
|
||||||
X: torch.Tensor,
|
X: torch.Tensor,
|
||||||
@@ -1344,6 +1830,9 @@ def group_bwd_lora(
|
|||||||
"""
|
"""
|
||||||
Compute LoRA gradients for A and B on expert-grouped data.
|
Compute LoRA gradients for A and B on expert-grouped data.
|
||||||
|
|
||||||
|
Uses split dA/dB kernels that eliminate atomic adds by giving each
|
||||||
|
(expert, output_block) pair its own thread block.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
|
DY: Gradient w.r.t. output [M_total, N] (grouped by expert)
|
||||||
X: Input [M_total, K] (grouped by expert)
|
X: Input [M_total, K] (grouped by expert)
|
||||||
@@ -1361,19 +1850,46 @@ def group_bwd_lora(
|
|||||||
K = X.size(1)
|
K = X.size(1)
|
||||||
N = DY.size(1)
|
N = DY.size(1)
|
||||||
|
|
||||||
# Zero-init for atomic accumulation
|
# No zero-init needed: the split kernels write zeros for experts with
|
||||||
dA = torch.zeros_like(lora_A)
|
# zero routed tokens directly in the kernel (else branch).
|
||||||
dB = torch.zeros_like(lora_B)
|
dA = torch.empty_like(lora_A)
|
||||||
|
dB = torch.empty_like(lora_B)
|
||||||
|
|
||||||
BLOCK_R = _block_r_for_rank(R)
|
BLOCK_R = _block_r_for_rank(R)
|
||||||
|
|
||||||
def grid(META):
|
def grid_dA(META):
|
||||||
return (
|
return (E, triton.cdiv(K, META["BLOCK_DIM"]))
|
||||||
E * triton.cdiv(K, META["BLOCK_K"]),
|
|
||||||
triton.cdiv(N, META["BLOCK_N"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
_group_bwd_lora[grid](
|
_group_bwd_lora_split[grid_dA](
|
||||||
|
DY,
|
||||||
|
DY.stride(0),
|
||||||
|
DY.stride(1),
|
||||||
|
X,
|
||||||
|
X.stride(0),
|
||||||
|
X.stride(1),
|
||||||
|
lora_B,
|
||||||
|
lora_B.stride(0),
|
||||||
|
lora_B.stride(1),
|
||||||
|
dA,
|
||||||
|
dA.stride(0),
|
||||||
|
dA.stride(1),
|
||||||
|
expert_offsets,
|
||||||
|
M=DY.size(0),
|
||||||
|
K=K,
|
||||||
|
N=N,
|
||||||
|
ACTUAL_R=R,
|
||||||
|
BLOCK_R=BLOCK_R,
|
||||||
|
INNER_DIM=N,
|
||||||
|
scaling=scaling,
|
||||||
|
COMPUTE_DA=True,
|
||||||
|
ACC_TYPE=tl.float32,
|
||||||
|
allow_tf32=ALLOW_TF32,
|
||||||
|
)
|
||||||
|
|
||||||
|
def grid_dB(META):
|
||||||
|
return (E, triton.cdiv(N, META["BLOCK_DIM"]))
|
||||||
|
|
||||||
|
_group_bwd_lora_split[grid_dB](
|
||||||
DY,
|
DY,
|
||||||
DY.stride(0),
|
DY.stride(0),
|
||||||
DY.stride(1),
|
DY.stride(1),
|
||||||
@@ -1383,12 +1899,6 @@ def group_bwd_lora(
|
|||||||
lora_A,
|
lora_A,
|
||||||
lora_A.stride(0),
|
lora_A.stride(0),
|
||||||
lora_A.stride(1),
|
lora_A.stride(1),
|
||||||
lora_B,
|
|
||||||
lora_B.stride(0),
|
|
||||||
lora_B.stride(1),
|
|
||||||
dA,
|
|
||||||
dA.stride(0),
|
|
||||||
dA.stride(1),
|
|
||||||
dB,
|
dB,
|
||||||
dB.stride(0),
|
dB.stride(0),
|
||||||
dB.stride(1),
|
dB.stride(1),
|
||||||
@@ -1396,9 +1906,11 @@ def group_bwd_lora(
|
|||||||
M=DY.size(0),
|
M=DY.size(0),
|
||||||
K=K,
|
K=K,
|
||||||
N=N,
|
N=N,
|
||||||
ACTUAL_R=R, # True LoRA rank
|
ACTUAL_R=R,
|
||||||
BLOCK_R=BLOCK_R, # Padded tile size
|
BLOCK_R=BLOCK_R,
|
||||||
|
INNER_DIM=K,
|
||||||
scaling=scaling,
|
scaling=scaling,
|
||||||
|
COMPUTE_DA=False,
|
||||||
ACC_TYPE=tl.float32,
|
ACC_TYPE=tl.float32,
|
||||||
allow_tf32=ALLOW_TF32,
|
allow_tf32=ALLOW_TF32,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -489,20 +489,71 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
# ====================================================================
|
# ====================================================================
|
||||||
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
|
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
|
||||||
|
|
||||||
|
# ====================================================================
|
||||||
|
# Selective expert weight dequantization
|
||||||
|
# ====================================================================
|
||||||
|
# When experts are BnB-quantized (quantize_moe_experts), dequantize
|
||||||
|
# only the active experts instead of all E. This saves ~97% memory
|
||||||
|
# for the transient dequant buffer when few experts are active.
|
||||||
|
use_selective = (
|
||||||
|
getattr(self, "_use_selective_dequant", False)
|
||||||
|
and hasattr(experts, "parametrizations")
|
||||||
|
and "gate_up_proj" in experts.parametrizations
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_selective:
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant import (
|
||||||
|
get_active_experts,
|
||||||
|
remap_expert_indices,
|
||||||
|
selective_expert_weights,
|
||||||
|
selective_lora_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
active_experts = get_active_experts(sorted_expert_idxs, num_experts)
|
||||||
|
remapped_expert_idxs, compact_offsets = remap_expert_indices(
|
||||||
|
sorted_expert_idxs,
|
||||||
|
expert_offsets,
|
||||||
|
active_experts,
|
||||||
|
num_experts,
|
||||||
|
)
|
||||||
|
# Dequantize only active experts' weights
|
||||||
|
gate_up_W = selective_expert_weights(
|
||||||
|
experts,
|
||||||
|
"gate_up_proj",
|
||||||
|
active_experts,
|
||||||
|
).transpose(2, 1) # [num_active, hidden, 2*inter]
|
||||||
|
|
||||||
|
# Remap LoRA weights to match compact expert indices
|
||||||
|
if gup_lora is not None:
|
||||||
|
gup_A, gup_B, gup_scaling = gup_lora
|
||||||
|
gup_A, gup_B = selective_lora_weights(
|
||||||
|
gup_A,
|
||||||
|
gup_B,
|
||||||
|
active_experts,
|
||||||
|
num_experts,
|
||||||
|
)
|
||||||
|
gup_lora = (gup_A, gup_B, gup_scaling)
|
||||||
|
|
||||||
|
# Use remapped indices for ScatterMoE kernels
|
||||||
|
sei_gup = remapped_expert_idxs
|
||||||
|
eo_gup = compact_offsets
|
||||||
|
else:
|
||||||
|
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
|
||||||
|
sei_gup = sorted_expert_idxs
|
||||||
|
eo_gup = expert_offsets
|
||||||
|
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
# Gate + Up projection
|
# Gate + Up projection
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
|
|
||||||
|
|
||||||
if gup_lora is not None:
|
if gup_lora is not None:
|
||||||
gup_A, gup_B, gup_scaling = gup_lora
|
gup_A, gup_B, gup_scaling = gup_lora
|
||||||
gup = parallel_linear_lora(
|
gup = parallel_linear_lora(
|
||||||
hidden_states_flat,
|
hidden_states_flat,
|
||||||
gate_up_W,
|
gate_up_W,
|
||||||
top_k,
|
top_k,
|
||||||
sorted_expert_idxs,
|
sei_gup,
|
||||||
sorted_scattered_idxs,
|
sorted_scattered_idxs,
|
||||||
expert_offsets,
|
eo_gup,
|
||||||
lora_A=gup_A,
|
lora_A=gup_A,
|
||||||
lora_B=gup_B,
|
lora_B=gup_B,
|
||||||
scaling=gup_scaling,
|
scaling=gup_scaling,
|
||||||
@@ -516,9 +567,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
hidden_states_flat,
|
hidden_states_flat,
|
||||||
gate_up_W,
|
gate_up_W,
|
||||||
top_k,
|
top_k,
|
||||||
sorted_expert_idxs,
|
sei_gup,
|
||||||
sorted_scattered_idxs,
|
sorted_scattered_idxs,
|
||||||
expert_offsets,
|
eo_gup,
|
||||||
grouped_in=False,
|
grouped_in=False,
|
||||||
grouped_out=True,
|
grouped_out=True,
|
||||||
)
|
)
|
||||||
@@ -529,7 +580,29 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
# ====================================================================
|
# ====================================================================
|
||||||
# Down projection
|
# Down projection
|
||||||
# ====================================================================
|
# ====================================================================
|
||||||
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
|
if use_selective:
|
||||||
|
down_W = selective_expert_weights(
|
||||||
|
experts,
|
||||||
|
"down_proj",
|
||||||
|
active_experts,
|
||||||
|
).transpose(2, 1) # [num_active, inter, hidden]
|
||||||
|
|
||||||
|
if down_lora is not None:
|
||||||
|
down_A, down_B, down_scaling = down_lora
|
||||||
|
down_A, down_B = selective_lora_weights(
|
||||||
|
down_A,
|
||||||
|
down_B,
|
||||||
|
active_experts,
|
||||||
|
num_experts,
|
||||||
|
)
|
||||||
|
down_lora = (down_A, down_B, down_scaling)
|
||||||
|
|
||||||
|
sei_down = remapped_expert_idxs
|
||||||
|
eo_down = compact_offsets
|
||||||
|
else:
|
||||||
|
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
|
||||||
|
sei_down = sorted_expert_idxs
|
||||||
|
eo_down = expert_offsets
|
||||||
|
|
||||||
if down_lora is not None:
|
if down_lora is not None:
|
||||||
down_A, down_B, down_scaling = down_lora
|
down_A, down_B, down_scaling = down_lora
|
||||||
@@ -537,9 +610,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
h,
|
h,
|
||||||
down_W,
|
down_W,
|
||||||
1,
|
1,
|
||||||
sorted_expert_idxs,
|
sei_down,
|
||||||
sorted_scattered_idxs,
|
sorted_scattered_idxs,
|
||||||
expert_offsets,
|
eo_down,
|
||||||
lora_A=down_A,
|
lora_A=down_A,
|
||||||
lora_B=down_B,
|
lora_B=down_B,
|
||||||
scaling=down_scaling,
|
scaling=down_scaling,
|
||||||
@@ -554,9 +627,9 @@ class HFScatterMoEGatedMLP(nn.Module):
|
|||||||
h,
|
h,
|
||||||
down_W,
|
down_W,
|
||||||
1,
|
1,
|
||||||
sorted_expert_idxs,
|
sei_down,
|
||||||
sorted_scattered_idxs,
|
sorted_scattered_idxs,
|
||||||
expert_offsets,
|
eo_down,
|
||||||
grouped_in=True,
|
grouped_in=True,
|
||||||
grouped_out=False,
|
grouped_out=False,
|
||||||
gates=routing_weights,
|
gates=routing_weights,
|
||||||
|
|||||||
@@ -0,0 +1,282 @@
|
|||||||
|
"""
|
||||||
|
Selective Expert Dequantization
|
||||||
|
===============================
|
||||||
|
|
||||||
|
Instead of dequantizing all E expert weight matrices at once (which creates
|
||||||
|
a ~1 GB transient buffer for 256 experts), only dequantize the experts that
|
||||||
|
are actually routed to by the current batch's top-k selection.
|
||||||
|
|
||||||
|
For Qwen3.5-35B-A3B (E=256, top_k=8, hidden=2048, intermediate=512):
|
||||||
|
- Full dequant: [256, 2048, 1024] = 1,074 MB per projection
|
||||||
|
- Selective (8 active): [8, 2048, 1024] = 33.5 MB per projection
|
||||||
|
- Savings: ~97% memory reduction per layer
|
||||||
|
|
||||||
|
This module provides format-agnostic selective weight extraction:
|
||||||
|
- BnB 4-bit (nf4/fp4): slice quantized data + absmax per expert
|
||||||
|
- bf16/fp32: direct indexing (no dequant needed)
|
||||||
|
- FP8: slice + cast
|
||||||
|
|
||||||
|
The ScatterMoE kernel itself doesn't change — we remap expert indices
|
||||||
|
from global (0..E-1) to compact (0..num_active-1) and pass the smaller
|
||||||
|
weight tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_experts(sorted_expert_idxs: torch.Tensor, E: int) -> torch.Tensor:
|
||||||
|
"""Get sorted unique expert indices from the routing output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sorted_expert_idxs: Expert assignments sorted by expert id [T*k]
|
||||||
|
E: Total number of experts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
active: Sorted unique expert indices [num_active]
|
||||||
|
"""
|
||||||
|
return torch.unique(sorted_expert_idxs)
|
||||||
|
|
||||||
|
|
||||||
|
def remap_expert_indices(
|
||||||
|
sorted_expert_idxs: torch.Tensor,
|
||||||
|
expert_offsets: torch.Tensor,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
E: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Remap global expert indices to compact indices.
|
||||||
|
|
||||||
|
Maps expert ids from [0..E-1] to [0..num_active-1], preserving the
|
||||||
|
sort order. Also compacts expert_offsets to only active experts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sorted_expert_idxs: [T*k] expert ids in sorted order
|
||||||
|
expert_offsets: [E] cumulative token counts (original)
|
||||||
|
active_experts: [num_active] sorted unique expert ids
|
||||||
|
E: Total number of experts
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
remapped_idxs: [T*k] expert ids in [0..num_active-1]
|
||||||
|
compact_offsets: [num_active] cumulative token counts
|
||||||
|
"""
|
||||||
|
# Build remap table: global_id -> compact_id
|
||||||
|
remap = torch.empty(E, dtype=torch.long, device=sorted_expert_idxs.device)
|
||||||
|
remap[active_experts] = torch.arange(
|
||||||
|
len(active_experts), device=sorted_expert_idxs.device
|
||||||
|
)
|
||||||
|
|
||||||
|
remapped_idxs = remap[sorted_expert_idxs]
|
||||||
|
|
||||||
|
# Compact the expert_offsets: only keep active experts' cumulative counts
|
||||||
|
compact_offsets = expert_offsets[active_experts]
|
||||||
|
|
||||||
|
return remapped_idxs, compact_offsets
|
||||||
|
|
||||||
|
|
||||||
|
def _selective_dequant_bnb4(
|
||||||
|
raw_param: torch.Tensor,
|
||||||
|
quant_state,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
expert_shape: tuple[int, int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Dequantize only selected experts from BnB 4-bit packed data.
|
||||||
|
|
||||||
|
The raw parameter is a flattened 4-bit packed tensor. Each expert's
|
||||||
|
data is contiguous (stored in expert-major order), so we can gather
|
||||||
|
the packed data and absmax blocks for active experts, then dequantize
|
||||||
|
as one contiguous block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_param: Flattened uint8 tensor of packed 4-bit weights
|
||||||
|
quant_state: BnB QuantState with absmax, blocksize, code, etc.
|
||||||
|
active_experts: [num_active] expert indices to dequantize
|
||||||
|
expert_shape: (dim1, dim2) shape per expert (e.g. (1024, 2048))
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dequantized weights [num_active, dim1, dim2] in original dtype
|
||||||
|
"""
|
||||||
|
import bitsandbytes.functional as F # noqa: N812
|
||||||
|
from bitsandbytes.functional import QuantState
|
||||||
|
|
||||||
|
expert_numel = expert_shape[0] * expert_shape[1]
|
||||||
|
packed_per_expert = expert_numel // 2 # 4-bit = 2 values per byte
|
||||||
|
blocks_per_expert = expert_numel // quant_state.blocksize
|
||||||
|
num_active = len(active_experts)
|
||||||
|
|
||||||
|
if blocks_per_expert == 0:
|
||||||
|
# Expert is smaller than one quantization block — blocks span across
|
||||||
|
# expert boundaries, so per-expert slicing isn't possible.
|
||||||
|
# Fallback: full dequantize + index.
|
||||||
|
full = F.dequantize_4bit(raw_param, quant_state)
|
||||||
|
E_total = full.numel() // expert_numel
|
||||||
|
return full.reshape(E_total, *expert_shape)[active_experts]
|
||||||
|
|
||||||
|
# Use fused Triton kernel for NF4 (handles selective gather + dequant in one pass)
|
||||||
|
if quant_state.quant_type == "nf4" and raw_param.dtype == torch.uint8:
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.selective_dequant_kernel import (
|
||||||
|
selective_dequant_nf4_triton,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle nested (double) quantization: dequantize absmax first
|
||||||
|
# BnB uses dequantize_blockwise (not _4bit) for nested absmax + offset
|
||||||
|
if quant_state.nested:
|
||||||
|
absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
|
||||||
|
absmax += quant_state.offset
|
||||||
|
if absmax.dtype != torch.float32:
|
||||||
|
absmax = absmax.float()
|
||||||
|
else:
|
||||||
|
absmax = quant_state.absmax
|
||||||
|
|
||||||
|
return selective_dequant_nf4_triton(
|
||||||
|
packed_data=raw_param,
|
||||||
|
absmax=absmax,
|
||||||
|
active_experts=active_experts,
|
||||||
|
expert_shape=expert_shape,
|
||||||
|
blocksize=quant_state.blocksize,
|
||||||
|
dtype=quant_state.dtype,
|
||||||
|
codebook=quant_state.code,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback: gather + BnB dequant (for fp4 or non-uint8 packed formats)
|
||||||
|
raw_flat = raw_param.reshape(-1)
|
||||||
|
|
||||||
|
offsets_qt = (
|
||||||
|
active_experts.long()[:, None] * packed_per_expert
|
||||||
|
+ torch.arange(packed_per_expert, device=raw_param.device)[None, :]
|
||||||
|
).reshape(-1)
|
||||||
|
qt_gathered = raw_flat[offsets_qt]
|
||||||
|
|
||||||
|
offsets_abs = (
|
||||||
|
active_experts.long()[:, None] * blocks_per_expert
|
||||||
|
+ torch.arange(blocks_per_expert, device=raw_param.device)[None, :]
|
||||||
|
).reshape(-1)
|
||||||
|
|
||||||
|
if quant_state.nested:
|
||||||
|
full_absmax = F.dequantize_blockwise(quant_state.absmax, quant_state.state2)
|
||||||
|
full_absmax += quant_state.offset
|
||||||
|
if full_absmax.dtype != torch.float32:
|
||||||
|
full_absmax = full_absmax.float()
|
||||||
|
absmax_gathered = full_absmax[offsets_abs]
|
||||||
|
else:
|
||||||
|
absmax_gathered = quant_state.absmax[offsets_abs]
|
||||||
|
|
||||||
|
qt_gathered = qt_gathered.unsqueeze(1) if qt_gathered.dim() == 1 else qt_gathered
|
||||||
|
|
||||||
|
gathered_qs = QuantState(
|
||||||
|
absmax=absmax_gathered,
|
||||||
|
shape=torch.Size([num_active * expert_numel]),
|
||||||
|
blocksize=quant_state.blocksize,
|
||||||
|
quant_type=quant_state.quant_type,
|
||||||
|
code=quant_state.code,
|
||||||
|
dtype=quant_state.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
deq = F.dequantize_4bit(qt_gathered, gathered_qs)
|
||||||
|
return deq.reshape(num_active, *expert_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def _selective_index_dense(
|
||||||
|
param: torch.Tensor,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Select experts from a dense (bf16/fp32) weight tensor.
|
||||||
|
|
||||||
|
Simple indexing — no dequantization needed.
|
||||||
|
"""
|
||||||
|
return param[active_experts]
|
||||||
|
|
||||||
|
|
||||||
|
def selective_expert_weights(
|
||||||
|
experts_module: nn.Module,
|
||||||
|
param_name: str,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Extract and dequantize only the active experts' weights.
|
||||||
|
|
||||||
|
Format-agnostic: dispatches based on whether the parameter is
|
||||||
|
BnB 4-bit quantized (via parametrize), FP8, or dense bf16/fp32.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experts_module: The base experts module (e.g. Qwen3_5MoeExperts)
|
||||||
|
param_name: "gate_up_proj" or "down_proj"
|
||||||
|
active_experts: [num_active] sorted unique expert indices
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compact weight tensor [num_active, dim1, dim2] ready for ScatterMoE
|
||||||
|
"""
|
||||||
|
# Check if the parameter is BnB-quantized via parametrize
|
||||||
|
if (
|
||||||
|
hasattr(experts_module, "parametrizations")
|
||||||
|
and param_name in experts_module.parametrizations
|
||||||
|
):
|
||||||
|
param_list = experts_module.parametrizations[param_name]
|
||||||
|
parametrization = param_list[0]
|
||||||
|
|
||||||
|
# BnB 4-bit parametrization
|
||||||
|
if hasattr(parametrization, "quant_state"):
|
||||||
|
# The raw quantized data is on the ParametrizationList, not the
|
||||||
|
# individual Bnb4bitParametrization module
|
||||||
|
raw_param = param_list.original
|
||||||
|
qs = parametrization.quant_state
|
||||||
|
# qs.shape is the original tensor shape before flattening.
|
||||||
|
# For MoE experts it's [E, d1, d2] (3D) or [total_elements] (1D).
|
||||||
|
orig_shape = qs.shape
|
||||||
|
if isinstance(orig_shape, torch.Size) and len(orig_shape) == 3:
|
||||||
|
expert_shape = (orig_shape[1], orig_shape[2])
|
||||||
|
elif isinstance(orig_shape, torch.Size) and len(orig_shape) == 1:
|
||||||
|
# Flattened — need to infer from module attributes
|
||||||
|
E_total = getattr(experts_module, "num_experts", None)
|
||||||
|
if E_total is None:
|
||||||
|
E_total = int(active_experts.max().item()) + 1
|
||||||
|
expert_numel = orig_shape[0] // E_total
|
||||||
|
d2 = getattr(experts_module, "hidden_dim", None) or getattr(
|
||||||
|
experts_module, "intermediate_dim", None
|
||||||
|
)
|
||||||
|
if d2 and expert_numel % d2 == 0:
|
||||||
|
expert_shape = (expert_numel // d2, d2)
|
||||||
|
else:
|
||||||
|
full = getattr(experts_module, param_name)
|
||||||
|
return full[active_experts]
|
||||||
|
else:
|
||||||
|
full = getattr(experts_module, param_name)
|
||||||
|
return full[active_experts]
|
||||||
|
|
||||||
|
return _selective_dequant_bnb4(raw_param, qs, active_experts, expert_shape)
|
||||||
|
|
||||||
|
# Dense parameter (bf16/fp32) — direct indexing
|
||||||
|
param = getattr(experts_module, param_name)
|
||||||
|
if param.dim() == 3:
|
||||||
|
return param[active_experts]
|
||||||
|
|
||||||
|
# Fallback: full access
|
||||||
|
return param
|
||||||
|
|
||||||
|
|
||||||
|
def selective_lora_weights(
|
||||||
|
lora_A: torch.Tensor,
|
||||||
|
lora_B: torch.Tensor,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
E: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Select LoRA A and B weights for only the active experts.
|
||||||
|
|
||||||
|
LoRA layout (scattermoe format):
|
||||||
|
A: [r*E, K] — expert e occupies rows [e*r : (e+1)*r]
|
||||||
|
B: [N, r*E] — expert e occupies cols [e*r : (e+1)*r]
|
||||||
|
|
||||||
|
Returns compact:
|
||||||
|
A: [r*num_active, K]
|
||||||
|
B: [N, r*num_active]
|
||||||
|
"""
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
|
||||||
|
# Vectorized gather: active_experts[:, None] * R + arange(R)[None, :]
|
||||||
|
row_idx = (
|
||||||
|
active_experts.long()[:, None] * R
|
||||||
|
+ torch.arange(R, device=lora_A.device)[None, :]
|
||||||
|
).reshape(-1)
|
||||||
|
|
||||||
|
compact_A = lora_A[row_idx] # [r*num_active, K]
|
||||||
|
compact_B = lora_B[:, row_idx] # [N, r*num_active]
|
||||||
|
|
||||||
|
return compact_A, compact_B
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
"""
|
||||||
|
Triton kernel for fused selective expert gather + NF4 dequantization.
|
||||||
|
|
||||||
|
Instead of:
|
||||||
|
1. Gather packed uint8 data for active experts (memory copy)
|
||||||
|
2. Gather absmax for active experts (memory copy)
|
||||||
|
3. Call BnB dequantize_4bit CUDA kernel
|
||||||
|
|
||||||
|
This kernel does all three in one pass:
|
||||||
|
- Reads packed NF4 bytes from expert-strided positions
|
||||||
|
- Looks up the NF4 codebook
|
||||||
|
- Multiplies by the per-block absmax
|
||||||
|
- Writes bf16 output directly
|
||||||
|
|
||||||
|
This eliminates the intermediate gather buffer entirely.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
# NF4 codebook (16 values, precomputed by BnB)
|
||||||
|
# These are the normalized float4 reconstruction values
|
||||||
|
NF4_CODEBOOK = [
|
||||||
|
-1.0,
|
||||||
|
-0.6961928009986877,
|
||||||
|
-0.5250730514526367,
|
||||||
|
-0.39491748809814453,
|
||||||
|
-0.28444138169288635,
|
||||||
|
-0.18477343022823334,
|
||||||
|
-0.09105003625154495,
|
||||||
|
0.0,
|
||||||
|
0.07958029955625534,
|
||||||
|
0.16093020141124725,
|
||||||
|
0.24611230194568634,
|
||||||
|
0.33791524171829224,
|
||||||
|
0.44070982933044434,
|
||||||
|
0.5626170039176941,
|
||||||
|
0.7229568362236023,
|
||||||
|
1.0,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _selective_dequant_nf4_kernel(
|
||||||
|
# Input: packed NF4 data (flattened, expert-major order)
|
||||||
|
packed_ptr,
|
||||||
|
# Input: absmax values (flattened, expert-major order)
|
||||||
|
absmax_ptr,
|
||||||
|
# Input: active expert indices
|
||||||
|
active_experts_ptr,
|
||||||
|
# Input: NF4 codebook (16 float values)
|
||||||
|
codebook_ptr,
|
||||||
|
# Output: dequantized bf16 weights [num_active, expert_numel]
|
||||||
|
out_ptr,
|
||||||
|
stride_out_e, # stride for expert dim in output
|
||||||
|
# Dimensions
|
||||||
|
num_active,
|
||||||
|
packed_per_expert, # expert_numel // 2
|
||||||
|
blocks_per_expert, # expert_numel // blocksize
|
||||||
|
blocksize: tl.constexpr,
|
||||||
|
# Tile size
|
||||||
|
BLOCK_SIZE: tl.constexpr, # elements per thread block (must be multiple of 2)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Each program processes BLOCK_SIZE elements from one expert.
|
||||||
|
|
||||||
|
Grid: (num_active, cdiv(expert_numel, BLOCK_SIZE))
|
||||||
|
|
||||||
|
For each output element:
|
||||||
|
1. Compute which byte in packed data contains this element
|
||||||
|
2. Extract the 4-bit nibble (high or low)
|
||||||
|
3. Look up in NF4 codebook
|
||||||
|
4. Scale by absmax for this block
|
||||||
|
"""
|
||||||
|
expert_local_idx = tl.program_id(0) # which active expert (0..num_active-1)
|
||||||
|
block_id = tl.program_id(1) # which element block
|
||||||
|
|
||||||
|
# Load the global expert index
|
||||||
|
expert_global = tl.load(active_experts_ptr + expert_local_idx).to(tl.int64)
|
||||||
|
|
||||||
|
expert_numel = packed_per_expert * 2 # 2 elements per packed byte
|
||||||
|
elem_offset = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
mask = elem_offset < expert_numel
|
||||||
|
|
||||||
|
# Each element is packed as: byte[i//2], low nibble for even i, high for odd i
|
||||||
|
byte_idx = elem_offset // 2
|
||||||
|
is_high = (elem_offset % 2) == 1
|
||||||
|
|
||||||
|
# Read packed bytes from the global expert's region
|
||||||
|
packed_global_offset = expert_global * packed_per_expert + byte_idx
|
||||||
|
packed_bytes = tl.load(packed_ptr + packed_global_offset, mask=mask, other=0).to(
|
||||||
|
tl.int32
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract 4-bit nibble
|
||||||
|
# BnB packing: high nibble = even element, low nibble = odd element
|
||||||
|
nibble = tl.where(is_high, packed_bytes & 0xF, (packed_bytes >> 4) & 0xF)
|
||||||
|
|
||||||
|
# NF4 codebook lookup
|
||||||
|
# Load all 16 codebook values (small, fits in registers)
|
||||||
|
# Use gather from codebook pointer
|
||||||
|
code_val = tl.load(codebook_ptr + nibble, mask=mask, other=0.0)
|
||||||
|
|
||||||
|
# Load absmax for this element's quantization block
|
||||||
|
block_idx = elem_offset // blocksize
|
||||||
|
absmax_global_offset = expert_global * blocks_per_expert + block_idx
|
||||||
|
absmax_val = tl.load(absmax_ptr + absmax_global_offset, mask=mask, other=1.0)
|
||||||
|
|
||||||
|
# Dequantize: value = codebook[nibble] * absmax
|
||||||
|
result = code_val * absmax_val
|
||||||
|
|
||||||
|
# Store to output
|
||||||
|
out_offset = expert_local_idx * stride_out_e + elem_offset
|
||||||
|
tl.store(out_ptr + out_offset, result.to(out_ptr.dtype.element_ty), mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def selective_dequant_nf4_triton(
|
||||||
|
packed_data: torch.Tensor,
|
||||||
|
absmax: torch.Tensor,
|
||||||
|
active_experts: torch.Tensor,
|
||||||
|
expert_shape: tuple[int, int],
|
||||||
|
blocksize: int,
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
codebook: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Fused selective gather + NF4 dequantization via Triton kernel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
packed_data: Flattened packed NF4 data [total_packed] or [total_packed, 1]
|
||||||
|
absmax: Per-block scaling factors [total_blocks]
|
||||||
|
active_experts: Sorted indices of experts to dequantize [num_active]
|
||||||
|
expert_shape: (dim1, dim2) per expert
|
||||||
|
blocksize: Quantization block size
|
||||||
|
dtype: Output dtype (default bf16)
|
||||||
|
codebook: NF4 lookup table [16] (uses default NF4 codebook if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dequantized weights [num_active, dim1, dim2]
|
||||||
|
"""
|
||||||
|
num_active = active_experts.shape[0]
|
||||||
|
expert_numel = expert_shape[0] * expert_shape[1]
|
||||||
|
packed_per_expert = expert_numel // 2
|
||||||
|
blocks_per_expert = expert_numel // blocksize
|
||||||
|
|
||||||
|
# Prepare codebook on device
|
||||||
|
if codebook is None:
|
||||||
|
codebook = torch.tensor(
|
||||||
|
NF4_CODEBOOK, dtype=torch.float32, device=packed_data.device
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
codebook = codebook.to(device=packed_data.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Flatten inputs
|
||||||
|
packed_flat = packed_data.reshape(-1)
|
||||||
|
absmax_flat = absmax.reshape(-1).float() # absmax is usually fp32
|
||||||
|
|
||||||
|
# Output buffer
|
||||||
|
out = torch.empty(num_active, expert_numel, dtype=dtype, device=packed_data.device)
|
||||||
|
|
||||||
|
BLOCK_SIZE = 1024 # Process 1024 elements per thread block
|
||||||
|
|
||||||
|
grid = (num_active, triton.cdiv(expert_numel, BLOCK_SIZE))
|
||||||
|
|
||||||
|
_selective_dequant_nf4_kernel[grid](
|
||||||
|
packed_flat,
|
||||||
|
absmax_flat,
|
||||||
|
active_experts,
|
||||||
|
codebook,
|
||||||
|
out,
|
||||||
|
out.stride(0),
|
||||||
|
num_active=num_active,
|
||||||
|
packed_per_expert=packed_per_expert,
|
||||||
|
blocks_per_expert=blocks_per_expert,
|
||||||
|
blocksize=blocksize,
|
||||||
|
BLOCK_SIZE=BLOCK_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
return out.reshape(num_active, *expert_shape)
|
||||||
@@ -61,7 +61,16 @@ class KernelsPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.kernels.KernelsArgs"
|
return "axolotl.integrations.kernels.KernelsArgs"
|
||||||
|
|
||||||
def pre_model_load(self, cfg):
|
def pre_model_load(self, cfg):
|
||||||
|
from axolotl.integrations.kernels.constants import SPARSE_MOE_BLOCK
|
||||||
|
|
||||||
|
# Prefer text backbone type for VLMs, but fall back to base type
|
||||||
|
# when the text type isn't in the supported mapping (e.g. qwen3_5_moe_text)
|
||||||
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
moe_model_type = cfg.model_config_type_text or cfg.model_config_type
|
||||||
|
if (
|
||||||
|
moe_model_type not in SPARSE_MOE_BLOCK
|
||||||
|
and cfg.model_config_type in SPARSE_MOE_BLOCK
|
||||||
|
):
|
||||||
|
moe_model_type = cfg.model_config_type
|
||||||
|
|
||||||
if cfg.use_scattermoe:
|
if cfg.use_scattermoe:
|
||||||
self._register_kernels()
|
self._register_kernels()
|
||||||
|
|||||||
@@ -640,7 +640,9 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
del q_weight
|
del q_weight
|
||||||
del q_weight_t
|
del q_weight_t
|
||||||
if A_q is not None and B_q is not None:
|
if A_q is not None and B_q is not None:
|
||||||
grad_X.addmm_(q_grad, torch.mm(B_q_scaled, A_q_scaled))
|
# Stay decomposed: dQ @ B^T gives [T, R], then [T, R] @ (s*A) gives [T, in]
|
||||||
|
# This is 65x fewer FLOPs than materializing B@A into [out, in]
|
||||||
|
grad_X.addmm_(torch.mm(q_grad, B_q_scaled), A_q_scaled)
|
||||||
|
|
||||||
# K path
|
# K path
|
||||||
k_weight_t = dequantize(k_weight, k_quant)
|
k_weight_t = dequantize(k_weight, k_quant)
|
||||||
@@ -648,7 +650,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
del k_weight
|
del k_weight
|
||||||
del k_weight_t
|
del k_weight_t
|
||||||
if A_k is not None and B_k is not None:
|
if A_k is not None and B_k is not None:
|
||||||
grad_X.addmm_(k_grad, torch.mm(B_k_scaled, A_k_scaled))
|
grad_X.addmm_(torch.mm(k_grad, B_k_scaled), A_k_scaled)
|
||||||
|
|
||||||
# V path
|
# V path
|
||||||
v_weight_t = dequantize(v_weight, v_quant)
|
v_weight_t = dequantize(v_weight, v_quant)
|
||||||
@@ -656,7 +658,7 @@ class LoRA_QKV(torch.autograd.Function):
|
|||||||
del v_weight
|
del v_weight
|
||||||
del v_weight_t
|
del v_weight_t
|
||||||
if A_v is not None and B_v is not None:
|
if A_v is not None and B_v is not None:
|
||||||
grad_X.addmm_(v_grad, torch.mm(B_v_scaled, A_v_scaled))
|
grad_X.addmm_(torch.mm(v_grad, B_v_scaled), A_v_scaled)
|
||||||
|
|
||||||
# Transpose gradients if needed
|
# Transpose gradients if needed
|
||||||
if d_A_q is not None:
|
if d_A_q is not None:
|
||||||
@@ -819,7 +821,8 @@ class LoRA_O(torch.autograd.Function):
|
|||||||
del W
|
del W
|
||||||
|
|
||||||
A, B = A.to(dtype), B.to(dtype)
|
A, B = A.to(dtype), B.to(dtype)
|
||||||
dX += s * dY @ B @ A
|
# Stay decomposed: dY @ B gives [T, R], then [T, R] @ A gives [T, in]
|
||||||
|
dX.addmm_(torch.mm(dY, B), A, alpha=s)
|
||||||
|
|
||||||
# W, b, W_quant, A, B, s
|
# W, b, W_quant, A, B, s
|
||||||
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
|
return dX.view(batch, seq_len, hd), None, None, None, d_A.t(), d_B.t(), None
|
||||||
|
|||||||
@@ -505,6 +505,20 @@ class ModelLoader:
|
|||||||
elif not is_ds_zero3:
|
elif not is_ds_zero3:
|
||||||
self.model_kwargs["device_map"] = device_map
|
self.model_kwargs["device_map"] = device_map
|
||||||
|
|
||||||
|
# quantize_moe_experts quantizes expert weights on-the-fly during loading,
|
||||||
|
# so the actual VRAM usage is much less than bf16 estimates.
|
||||||
|
# When device_map is "auto", accelerate's infer_auto_device_map computes
|
||||||
|
# the device map at bf16 size (before quantization), causing it to offload
|
||||||
|
# layers to CPU, which BnB then rejects. Force single-GPU placement to
|
||||||
|
# prevent this. Only applies to the non-FSDP, non-ZeRO3 path (DDP/single).
|
||||||
|
if getattr(self.cfg, "quantize_moe_experts", False) and device_map in (
|
||||||
|
"auto",
|
||||||
|
None,
|
||||||
|
):
|
||||||
|
self.model_kwargs["device_map"] = {
|
||||||
|
"": int(os.environ.get("LOCAL_RANK", 0))
|
||||||
|
}
|
||||||
|
|
||||||
cur_device = get_device_type()
|
cur_device = get_device_type()
|
||||||
if "mps" in str(cur_device):
|
if "mps" in str(cur_device):
|
||||||
self.model_kwargs["device_map"] = "mps:0"
|
self.model_kwargs["device_map"] = "mps:0"
|
||||||
|
|||||||
@@ -51,6 +51,29 @@ QKV_PATCHES = [
|
|||||||
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||||
""".lstrip("\n"),
|
""".lstrip("\n"),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
"""
|
||||||
|
query_states, gate = torch.chunk(
|
||||||
|
self.q_proj(hidden_states).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
||||||
|
)
|
||||||
|
gate = gate.reshape(*input_shape, -1)
|
||||||
|
|
||||||
|
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
||||||
|
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
|
||||||
|
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||||
|
""".lstrip("\n"),
|
||||||
|
"""
|
||||||
|
query_states, key_states, value_states = self.apply_qkv(hidden_states)
|
||||||
|
query_states, gate = torch.chunk(
|
||||||
|
query_states.view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
||||||
|
)
|
||||||
|
gate = gate.reshape(*input_shape, -1)
|
||||||
|
|
||||||
|
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
|
||||||
|
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
|
||||||
|
value_states = value_states.view(hidden_shape).transpose(1, 2)
|
||||||
|
""".lstrip("\n"),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
ORIGINAL_O_CODE = """
|
ORIGINAL_O_CODE = """
|
||||||
@@ -299,6 +322,8 @@ def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
|
|||||||
if hasattr(pretrained_model, "language_model"):
|
if hasattr(pretrained_model, "language_model"):
|
||||||
return pretrained_model.language_model.layers
|
return pretrained_model.language_model.layers
|
||||||
if hasattr(pretrained_model, "model"):
|
if hasattr(pretrained_model, "model"):
|
||||||
|
if hasattr(pretrained_model.model, "language_model"):
|
||||||
|
return pretrained_model.model.language_model.layers
|
||||||
return pretrained_model.model.layers
|
return pretrained_model.model.layers
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ from transformers import (
|
|||||||
class PytorchProfilerCallback(TrainerCallback):
|
class PytorchProfilerCallback(TrainerCallback):
|
||||||
"""
|
"""
|
||||||
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
PyTorch Profiler callback to create snapshots of GPU memory usage at specified steps.
|
||||||
|
|
||||||
|
Also runs torch.profiler to produce a Chrome trace for timing analysis.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
|
def __init__(self, steps_to_profile: int = 5, profiler_steps_start: int = 0):
|
||||||
@@ -26,9 +28,10 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
if profiler_steps_start == 0:
|
if profiler_steps_start == 0:
|
||||||
# start recording memory allocations before everything is allocated, because if we start
|
# start recording memory allocations before everything is allocated, because if we start
|
||||||
# at the beginning of step 0, we won't have any memory allocations in the traces
|
# at the beginning of step 0, we won't have any memory allocations in the traces
|
||||||
torch.cuda.memory._record_memory_history(enabled="all")
|
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||||
profiler_steps_start = -1
|
profiler_steps_start = -1
|
||||||
self.profiler_steps_start = profiler_steps_start
|
self.profiler_steps_start = profiler_steps_start
|
||||||
|
self._profiler = None
|
||||||
|
|
||||||
def on_step_begin(
|
def on_step_begin(
|
||||||
self,
|
self,
|
||||||
@@ -38,7 +41,21 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if state.global_step == self.profiler_steps_start:
|
if state.global_step == self.profiler_steps_start:
|
||||||
torch.cuda.memory._record_memory_history(enabled="all")
|
torch.cuda.memory._record_memory_history(enabled="all", stacks="all")
|
||||||
|
|
||||||
|
# Start torch.profiler on the first profiled step
|
||||||
|
if state.global_step == max(self.profiler_steps_start, 0):
|
||||||
|
profiler = torch.profiler.profile(
|
||||||
|
activities=[
|
||||||
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
|
],
|
||||||
|
record_shapes=True,
|
||||||
|
profile_memory=True,
|
||||||
|
with_stack=True,
|
||||||
|
)
|
||||||
|
profiler.__enter__()
|
||||||
|
self._profiler = profiler
|
||||||
|
|
||||||
def on_step_end(
|
def on_step_end(
|
||||||
self,
|
self,
|
||||||
@@ -55,6 +72,13 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
# tell CUDA to stop recording memory allocations now
|
# tell CUDA to stop recording memory allocations now
|
||||||
torch.cuda.memory._record_memory_history(enabled=None)
|
torch.cuda.memory._record_memory_history(enabled=None)
|
||||||
|
|
||||||
|
# Stop and export torch.profiler trace
|
||||||
|
if self._profiler is not None:
|
||||||
|
self._profiler.__exit__(None, None, None)
|
||||||
|
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||||
|
self._profiler.export_chrome_trace(str(trace_path))
|
||||||
|
self._profiler = None
|
||||||
|
|
||||||
def on_train_end(
|
def on_train_end(
|
||||||
self,
|
self,
|
||||||
args: TrainingArguments,
|
args: TrainingArguments,
|
||||||
@@ -73,3 +97,9 @@ class PytorchProfilerCallback(TrainerCallback):
|
|||||||
|
|
||||||
# tell CUDA to stop recording memory allocations now
|
# tell CUDA to stop recording memory allocations now
|
||||||
torch.cuda.memory._record_memory_history(enabled=None)
|
torch.cuda.memory._record_memory_history(enabled=None)
|
||||||
|
|
||||||
|
if self._profiler is not None:
|
||||||
|
self._profiler.__exit__(None, None, None)
|
||||||
|
trace_path = Path(args.output_dir) / "profiler_trace.json"
|
||||||
|
self._profiler.export_chrome_trace(str(trace_path))
|
||||||
|
self._profiler = None
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ E2E tests for lora llama
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pytest
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
|
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
from axolotl.common.datasets import load_datasets
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
@@ -68,51 +67,3 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(temp_dir, cfg)
|
check_model_output_exists(temp_dir, cfg)
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
|
|
||||||
@with_temp_dir
|
|
||||||
def test_lora_gptq_packed(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
|
|
||||||
"model_type": "AutoModelForCausalLM",
|
|
||||||
"tokenizer_type": "AutoTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"sample_packing": True,
|
|
||||||
"flash_attention": True,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"gptq": True,
|
|
||||||
"gptq_disable_exllama": True,
|
|
||||||
"lora_r": 32,
|
|
||||||
"lora_alpha": 64,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.02,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 2,
|
|
||||||
"max_steps": 20,
|
|
||||||
"save_steps": 0.5,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch_fused",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
|
|||||||
407
tests/integrations/test_scattermoe_lora_kernels.py
Normal file
407
tests/integrations/test_scattermoe_lora_kernels.py
Normal file
@@ -0,0 +1,407 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for ScatterMoE LoRA Triton kernels.
|
||||||
|
|
||||||
|
Tests correctness of:
|
||||||
|
- scatter2scatter_lora (forward)
|
||||||
|
- scatter2scatter_lora_dX (backward input gradient)
|
||||||
|
- group_bwd_lora (backward LoRA weight gradients via split dA/dB)
|
||||||
|
- ScatterMoELoRA autograd function (full forward + backward)
|
||||||
|
|
||||||
|
Each kernel is tested against a pure PyTorch per-expert-loop reference
|
||||||
|
implementation at multiple model shapes and LoRA ranks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||||
|
lora_ops,
|
||||||
|
ops as base_ops,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||||
|
flatten_sort_count,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||||
|
ScatterMoELoRA,
|
||||||
|
)
|
||||||
|
|
||||||
|
DEVICE = "cuda"
|
||||||
|
DTYPE = torch.bfloat16
|
||||||
|
|
||||||
|
|
||||||
|
def _requires_cuda():
|
||||||
|
return pytest.mark.skipif(
|
||||||
|
not torch.cuda.is_available(), reason="CUDA not available"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = _requires_cuda()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Helpers ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _setup(E, K, N, T, top_k, R, seed=42):
|
||||||
|
"""Create synthetic expert weights, LoRA, routing, and grouped inputs."""
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
W = torch.randn(E, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||||
|
lora_A = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||||
|
lora_B = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE) * 0.01
|
||||||
|
logits = torch.randn(T, E, device=DEVICE)
|
||||||
|
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||||
|
sei, ssi, eo = flatten_sort_count(top_idx, E)
|
||||||
|
return x, W, lora_A, lora_B, sei, ssi, eo
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_fwd(x, W, sei, ssi, eo, k, lora_A, lora_B, scaling, E):
|
||||||
|
"""Per-expert loop reference: Y = X@W + scaling*(X@A^T)@B^T."""
|
||||||
|
grouped_x = base_ops.group(x, ssi, fan_out=k)
|
||||||
|
M, N = grouped_x.size(0), W.size(2)
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
out = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
|
||||||
|
for e in range(E):
|
||||||
|
s = eo[e - 1].item() if e > 0 else 0
|
||||||
|
end = eo[e].item()
|
||||||
|
if s == end:
|
||||||
|
continue
|
||||||
|
xe = grouped_x[s:end].float()
|
||||||
|
we = W[e].float()
|
||||||
|
ae = lora_A[e * R : (e + 1) * R].float()
|
||||||
|
be = lora_B[:, e * R : (e + 1) * R].float()
|
||||||
|
out[s:end] = (xe @ we + scaling * (xe @ ae.T) @ be.T).to(DTYPE)
|
||||||
|
result = torch.zeros(M, N, device=DEVICE, dtype=DTYPE)
|
||||||
|
result[ssi] = out
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_dX(dy_grouped, W, sei, ssi, eo, lora_A, lora_B, scaling, E):
|
||||||
|
"""Per-expert loop reference: dX = dY@W^T + scaling*(dY@B)@A."""
|
||||||
|
M, K = dy_grouped.size(0), W.size(1)
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
out = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
for e in range(E):
|
||||||
|
s = eo[e - 1].item() if e > 0 else 0
|
||||||
|
end = eo[e].item()
|
||||||
|
if s == end:
|
||||||
|
continue
|
||||||
|
dye = dy_grouped[s:end].float()
|
||||||
|
we = W[e].float()
|
||||||
|
ae = lora_A[e * R : (e + 1) * R].float()
|
||||||
|
be = lora_B[:, e * R : (e + 1) * R].float()
|
||||||
|
out[s:end] = (dye @ we.T + scaling * (dye @ be) @ ae).to(DTYPE)
|
||||||
|
result = torch.zeros(M, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
result[ssi] = out
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _reference_bwd_lora(dy, grouped_x, lora_A, lora_B, eo, E, scaling):
|
||||||
|
"""Per-expert loop reference: dA, dB for LoRA weight gradients."""
|
||||||
|
R = lora_A.size(0) // E
|
||||||
|
dA = torch.zeros_like(lora_A)
|
||||||
|
dB = torch.zeros_like(lora_B)
|
||||||
|
for e in range(E):
|
||||||
|
s = eo[e - 1].item() if e > 0 else 0
|
||||||
|
end = eo[e].item()
|
||||||
|
if s == end:
|
||||||
|
continue
|
||||||
|
xe = grouped_x[s:end].float()
|
||||||
|
dye = dy[s:end].float()
|
||||||
|
ae = lora_A[e * R : (e + 1) * R].float()
|
||||||
|
be = lora_B[:, e * R : (e + 1) * R].float()
|
||||||
|
dA[e * R : (e + 1) * R] = (scaling * (dye @ be).T @ xe).to(DTYPE)
|
||||||
|
dB[:, e * R : (e + 1) * R] = (scaling * dye.T @ (xe @ ae.T)).to(DTYPE)
|
||||||
|
return dA, dB
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Model shape configs ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# (E, K, N, T, top_k, R, description)
|
||||||
|
CONFIGS_SMALL = [
|
||||||
|
(32, 128, 64, 64, 2, 4, "tiny"),
|
||||||
|
(64, 256, 128, 128, 4, 8, "small"),
|
||||||
|
]
|
||||||
|
|
||||||
|
CONFIGS_REAL = [
|
||||||
|
(256, 2048, 1024, 2048, 8, 16, "qwen35_gate_up"),
|
||||||
|
(256, 512, 2048, 2048, 8, 16, "qwen35_down"),
|
||||||
|
(64, 2048, 2048, 2048, 8, 16, "olmoe_gate_up"),
|
||||||
|
(128, 2048, 1536, 2048, 8, 16, "qwen3_gate_up"),
|
||||||
|
]
|
||||||
|
|
||||||
|
SCALING = 2.0
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Forward tests ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestScatter2ScatterLoRAForward:
|
||||||
|
"""Test scatter2scatter_lora forward kernel vs reference."""
|
||||||
|
|
||||||
|
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
|
||||||
|
def config(self, request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
def test_matches_reference(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
kernel_out = lora_ops.scatter2scatter_lora(
|
||||||
|
X=x,
|
||||||
|
W=W,
|
||||||
|
sorted_expert_idxs=sei,
|
||||||
|
sorted_scattered_idxs=ssi,
|
||||||
|
k=k,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
scaling=SCALING,
|
||||||
|
)
|
||||||
|
ref_out = _reference_fwd(x, W, sei, ssi, eo, k, lA, lB, SCALING, E)
|
||||||
|
|
||||||
|
err = (kernel_out.float() - ref_out.float()).abs().max().item()
|
||||||
|
assert err < 1.0, f"[{desc}] fwd max_err={err}"
|
||||||
|
|
||||||
|
def test_output_shape(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
out = lora_ops.scatter2scatter_lora(
|
||||||
|
X=x,
|
||||||
|
W=W,
|
||||||
|
sorted_expert_idxs=sei,
|
||||||
|
sorted_scattered_idxs=ssi,
|
||||||
|
k=k,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
scaling=SCALING,
|
||||||
|
)
|
||||||
|
assert out.shape == (T * k, N)
|
||||||
|
assert out.dtype == DTYPE
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Backward dX tests ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestScatter2ScatterLoRADX:
|
||||||
|
"""Test scatter2scatter_lora_dX backward kernel vs reference."""
|
||||||
|
|
||||||
|
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
|
||||||
|
def config(self, request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
def test_matches_reference(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
gx = base_ops.group(x, ssi, fan_out=k)
|
||||||
|
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||||
|
|
||||||
|
kernel_dx = lora_ops.scatter2scatter_lora_dX(
|
||||||
|
DY=dy,
|
||||||
|
W=W,
|
||||||
|
sorted_expert_idxs=sei,
|
||||||
|
sorted_scattered_idxs=ssi,
|
||||||
|
k=1,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
scaling=SCALING,
|
||||||
|
dy_grouped=True,
|
||||||
|
dx_grouped=False,
|
||||||
|
)
|
||||||
|
ref_dx = _reference_dX(dy, W, sei, ssi, eo, lA, lB, SCALING, E)
|
||||||
|
|
||||||
|
err = (kernel_dx.float() - ref_dx.float()).abs().max().item()
|
||||||
|
assert err < 1.0, f"[{desc}] dX max_err={err}"
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Backward LoRA gradient tests ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestGroupBwdLoRA:
|
||||||
|
"""Test group_bwd_lora (split dA/dB kernel) vs reference."""
|
||||||
|
|
||||||
|
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL)
|
||||||
|
def config(self, request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
def test_matches_reference(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
gx = base_ops.group(x, ssi, fan_out=k)
|
||||||
|
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||||
|
|
||||||
|
kern_dA, kern_dB = lora_ops.group_bwd_lora(
|
||||||
|
DY=dy,
|
||||||
|
X=gx,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
expert_offsets=eo,
|
||||||
|
E=E,
|
||||||
|
scaling=SCALING,
|
||||||
|
)
|
||||||
|
ref_dA, ref_dB = _reference_bwd_lora(dy, gx, lA, lB, eo, E, SCALING)
|
||||||
|
|
||||||
|
# Use norm-relative error: bf16 accumulation order differs between
|
||||||
|
# kernel (tiled + different reduction order) and reference (per-expert
|
||||||
|
# fp32 loop), so max absolute error can be large on individual elements
|
||||||
|
# while the overall tensor is correct.
|
||||||
|
dA_norm_err = (
|
||||||
|
(kern_dA.float() - ref_dA.float()).norm() / (ref_dA.float().norm() + 1e-6)
|
||||||
|
).item()
|
||||||
|
dB_norm_err = (
|
||||||
|
(kern_dB.float() - ref_dB.float()).norm() / (ref_dB.float().norm() + 1e-6)
|
||||||
|
).item()
|
||||||
|
assert dA_norm_err < 0.01, f"[{desc}] dA norm_rel_err={dA_norm_err}"
|
||||||
|
assert dB_norm_err < 0.01, f"[{desc}] dB norm_rel_err={dB_norm_err}"
|
||||||
|
|
||||||
|
def test_zero_expert_tokens(self):
|
||||||
|
"""Experts with zero routed tokens produce zero gradients."""
|
||||||
|
E, K, N, R = 8, 64, 32, 4
|
||||||
|
torch.manual_seed(42)
|
||||||
|
# Route all tokens to expert 0 only
|
||||||
|
T, k = 16, 1
|
||||||
|
top_idx = torch.zeros(T, k, dtype=torch.long, device=DEVICE)
|
||||||
|
sei, ssi, eo = flatten_sort_count(top_idx, E)
|
||||||
|
gx = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
dy = torch.randn(T, N, device=DEVICE, dtype=DTYPE)
|
||||||
|
lA = torch.randn(R * E, K, device=DEVICE, dtype=DTYPE)
|
||||||
|
lB = torch.randn(N, R * E, device=DEVICE, dtype=DTYPE)
|
||||||
|
|
||||||
|
dA, dB = lora_ops.group_bwd_lora(
|
||||||
|
DY=dy,
|
||||||
|
X=gx,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
expert_offsets=eo,
|
||||||
|
E=E,
|
||||||
|
scaling=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Experts 1..7 should have zero gradients
|
||||||
|
for e in range(1, E):
|
||||||
|
assert dA[e * R : (e + 1) * R].abs().max() == 0, f"Expert {e} dA not zero"
|
||||||
|
assert dB[:, e * R : (e + 1) * R].abs().max() == 0, (
|
||||||
|
f"Expert {e} dB not zero"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Full autograd tests ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestScatterMoELoRAAutograd:
|
||||||
|
"""Test full forward + backward through ScatterMoELoRA autograd function."""
|
||||||
|
|
||||||
|
@pytest.fixture(params=CONFIGS_SMALL + CONFIGS_REAL[:2])
|
||||||
|
def config(self, request):
|
||||||
|
return request.param
|
||||||
|
|
||||||
|
def test_gradients_exist_and_finite(self, config):
|
||||||
|
E, K, N, T, k, R, desc = config
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
x = x.requires_grad_(True)
|
||||||
|
lA = lA.requires_grad_(True)
|
||||||
|
lB = lB.requires_grad_(True)
|
||||||
|
|
||||||
|
out = ScatterMoELoRA.apply(
|
||||||
|
x,
|
||||||
|
W,
|
||||||
|
k,
|
||||||
|
sei,
|
||||||
|
ssi,
|
||||||
|
eo,
|
||||||
|
lA,
|
||||||
|
lB,
|
||||||
|
SCALING,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
out.sum().backward()
|
||||||
|
|
||||||
|
assert x.grad is not None, f"[{desc}] x.grad is None"
|
||||||
|
assert lA.grad is not None, f"[{desc}] lA.grad is None"
|
||||||
|
assert lB.grad is not None, f"[{desc}] lB.grad is None"
|
||||||
|
assert torch.isfinite(x.grad).all(), f"[{desc}] x.grad has non-finite"
|
||||||
|
assert torch.isfinite(lA.grad).all(), f"[{desc}] lA.grad has non-finite"
|
||||||
|
assert torch.isfinite(lB.grad).all(), f"[{desc}] lB.grad has non-finite"
|
||||||
|
assert x.grad.abs().sum() > 0, f"[{desc}] x.grad all zero"
|
||||||
|
assert lA.grad.abs().sum() > 0, f"[{desc}] lA.grad all zero"
|
||||||
|
|
||||||
|
def test_split_matches_fused(self):
|
||||||
|
"""Split dispatch (for few large experts) matches fused kernel."""
|
||||||
|
# Use a shape where split would be dispatched (large K*N, few E)
|
||||||
|
E, K, N, T, k, R = 8, 512, 1024, 128, 2, 16
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
# Force fused path
|
||||||
|
orig = lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||||
|
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 10**18
|
||||||
|
out_fused = lora_ops.scatter2scatter_lora(
|
||||||
|
X=x,
|
||||||
|
W=W,
|
||||||
|
sorted_expert_idxs=sei,
|
||||||
|
sorted_scattered_idxs=ssi,
|
||||||
|
k=k,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
scaling=SCALING,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force split path
|
||||||
|
lora_ops._SPLIT_LORA_FWD_THRESHOLD = 0
|
||||||
|
out_split = lora_ops.scatter2scatter_lora(
|
||||||
|
X=x,
|
||||||
|
W=W,
|
||||||
|
sorted_expert_idxs=sei,
|
||||||
|
sorted_scattered_idxs=ssi,
|
||||||
|
k=k,
|
||||||
|
lora_A=lA,
|
||||||
|
lora_B=lB,
|
||||||
|
scaling=SCALING,
|
||||||
|
)
|
||||||
|
lora_ops._SPLIT_LORA_FWD_THRESHOLD = orig
|
||||||
|
|
||||||
|
norm_err = (
|
||||||
|
(out_fused.float() - out_split.float()).norm()
|
||||||
|
/ (out_fused.float().norm() + 1e-6)
|
||||||
|
).item()
|
||||||
|
assert norm_err < 0.01, f"split vs fused norm_err={norm_err}"
|
||||||
|
|
||||||
|
def test_scaling_zero_gives_base_only(self):
|
||||||
|
"""With scaling=0.0, LoRA contribution vanishes. Output = X@W."""
|
||||||
|
E, K, N, T, k, R = 16, 64, 32, 32, 2, 4
|
||||||
|
x, W, lA, lB, sei, ssi, eo = _setup(E, K, N, T, k, R)
|
||||||
|
|
||||||
|
out_lora = ScatterMoELoRA.apply(
|
||||||
|
x,
|
||||||
|
W,
|
||||||
|
k,
|
||||||
|
sei,
|
||||||
|
ssi,
|
||||||
|
eo,
|
||||||
|
lA,
|
||||||
|
lB,
|
||||||
|
0.0,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
True,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
out_base = base_ops.scatter2scatter(
|
||||||
|
X=x,
|
||||||
|
W=W,
|
||||||
|
sorted_expert_idxs=sei,
|
||||||
|
sorted_scattered_idxs=ssi,
|
||||||
|
k=k,
|
||||||
|
)
|
||||||
|
err = (out_lora.float() - out_base.float()).abs().max().item()
|
||||||
|
assert err < 0.01, f"scaling=0 should match base: err={err}"
|
||||||
Reference in New Issue
Block a user