Compare commits

...

68 Commits

Author SHA1 Message Date
Dan Saunders  8564961423  fix compile  2025-09-19 13:59:57 -04:00
Dan Saunders  ce21da9177  fix compile  2025-09-19 13:55:54 -04:00
Dan Saunders  b5dc58373f  fix compile  2025-09-19 13:52:42 -04:00
Dan Saunders  7327144344  compile  2025-09-19 13:41:12 -04:00
Dan Saunders  fb11f696e9  bench sweep  2025-09-19 13:24:40 -04:00
Dan Saunders  105c817b0b  default fix  2025-09-19 16:59:20 +00:00
Dan Saunders  64345e7707  recurse fix  2025-09-19 12:58:58 -04:00
Dan Saunders  0f8b921399  contig  2025-09-19 12:47:53 -04:00
Dan Saunders  336616d659  defaults  2025-09-19 16:45:39 +00:00
Dan Saunders  d2f1e23bcd  fix  2025-09-19 12:45:18 -04:00
Dan Saunders  42aadc5069  bench fix  2025-09-19 12:34:08 -04:00
Dan Saunders  1e7302d30a  bench fix  2025-09-19 12:20:35 -04:00
Dan Saunders  63544ce709  fix  2025-09-19 11:34:27 -04:00
Dan Saunders  3bfed0aac8  shared expert detection  2025-09-19 11:24:26 -04:00
Dan Saunders  bfc848f81d  bits and pieces  2025-09-19 02:12:57 +00:00
Dan Saunders  abe1cad6bc  another bench  2025-09-18 13:45:19 -04:00
Dan Saunders  354389caef  torchtitan bench  2025-09-18 13:29:20 -04:00
Dan Saunders  efcd032fce  yet another refactor  2025-09-18 13:03:28 -04:00
Dan Saunders  7500641601  yet another refactor  2025-09-18 12:47:15 -04:00
Dan Saunders  0295df5bca  precompute fuse  2025-09-18 12:10:46 -04:00
Dan Saunders  b39ef54833  combine mult  2025-09-18 12:08:03 -04:00
Dan Saunders  ad4cd39bcd  remove contig  2025-09-18 11:55:15 -04:00
Dan Saunders  5c197275ad  inplace  2025-09-18 11:51:17 -04:00
Dan Saunders  19c91e3675  refactor  2025-09-18 11:44:21 -04:00
Dan Saunders  2a176e4923  fix  2025-09-18 11:29:33 -04:00
Dan Saunders  7d867de9b2  refactor  2025-09-18 11:23:15 -04:00
Dan Saunders  01b6792c2e  refactor  2025-09-18 11:20:08 -04:00
Dan Saunders  bbf1f14ca4  dtype issues  2025-09-17 23:52:18 +00:00
Dan Saunders  c6878beb7d  simplify  2025-09-17 19:15:34 -04:00
Dan Saunders  e62979d11d  fix  2025-09-17 18:53:07 -04:00
Dan Saunders  d57b9c67c2  log  2025-09-17 18:52:27 -04:00
Dan Saunders  eaaf16aa00  cumulative offsets  2025-09-17 18:45:15 -04:00
Dan Saunders  f3b953e222  fix?  2025-09-17 18:42:10 -04:00
Dan Saunders  7935dc0911  dtype fix  2025-09-17 18:36:22 -04:00
Dan Saunders  d2b49b2670  error msg  2025-09-17 18:29:30 -04:00
Dan Saunders  b5cb345ca4  fix test  2025-09-17 18:24:00 -04:00
Dan Saunders  03d4c2683e  fix perf degradation  2025-09-17 18:20:37 -04:00
Dan Saunders  fd87eed501  minify  2025-09-17 16:42:35 -04:00
Dan Saunders  129db67705  fix  2025-09-17 16:24:29 -04:00
Dan Saunders  38b890a36b  fix  2025-09-17 16:16:41 -04:00
Dan Saunders  180920c7bf  simplify  2025-09-17 19:49:18 +00:00
Dan Saunders  d024048d74  logs + fix  2025-09-17 14:50:49 -04:00
Dan Saunders  98dc945838  fix  2025-09-17 14:42:53 -04:00
Dan Saunders  108600cd69  update config  2025-09-17 14:36:24 -04:00
Dan Saunders  0e9387c395  fix  2025-09-17 14:35:36 -04:00
Dan Saunders  db61e0d4ff  fix  2025-09-17 14:26:25 -04:00
Dan Saunders  51e565f60a  logs  2025-09-17 14:15:51 -04:00
Dan Saunders  c774dd0409  refactor + fix  2025-09-17 14:01:39 -04:00
Dan Saunders  7289e0cb55  more logs  2025-09-17 13:44:26 -04:00
Dan Saunders  8d483c11f7  more logs  2025-09-17 13:44:26 -04:00
Dan Saunders  9c1829cf57  more logs  2025-09-17 13:44:26 -04:00
Dan Saunders  135b09d1de  logs, qwen2 support  2025-09-17 13:44:26 -04:00
Dan Saunders  de4344a56e  patch  2025-09-17 13:44:26 -04:00
Dan Saunders  7d572b58d1  just grouped_mm for now  2025-09-17 13:44:26 -04:00
Dan Saunders  773d7e4291  update  2025-09-17 13:44:26 -04:00
Dan Saunders  fef47a5b7c  hardening  2025-09-17 13:44:26 -04:00
Dan Saunders  f6ed8ddc01  fix  2025-09-17 13:44:26 -04:00
Dan Saunders  556d6448fe  fix  2025-09-17 13:44:26 -04:00
Dan Saunders  5c2229721d  diag  2025-09-17 13:44:26 -04:00
Dan Saunders  d7de6b0e96  grouped_mm  2025-09-17 13:44:26 -04:00
Dan Saunders  3c6648678f  numerics  2025-09-17 13:44:26 -04:00
Dan Saunders  5b19a1ea9c  improve  2025-09-17 13:44:26 -04:00
Dan Saunders  cfefad1eea  fix  2025-09-17 13:44:26 -04:00
Dan Saunders  125e7b5fe6  fast path  2025-09-17 13:44:26 -04:00
Dan Saunders  479b6144df  tflops  2025-09-17 13:44:26 -04:00
Dan Saunders  68da65cba2  update  2025-09-17 13:44:26 -04:00
Dan Saunders  0d689bb421  cache, example  2025-09-17 13:44:26 -04:00
Dan Saunders  43ada1278a  moe kernels init scaffold  2025-09-17 13:44:26 -04:00
19 changed files with 8637 additions and 8 deletions


@@ -285,6 +285,7 @@ website:
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/moe_backends.md
- docs/nd_parallelism.qmd
- section: "Troubleshooting"

docs/moe_backends.md (new file)

@@ -0,0 +1,18 @@
# MoE Backends in Axolotl

Axolotl supports selecting a Mixture-of-Experts (MoE) compute backend via the training config (YAML):

- Set `moe_backend: auto|torch_grouped|naive`

## Behavior

- `auto` (default): prefers the PyTorch 2.8+ grouped GEMM path; otherwise falls back to `naive`.
- `torch_grouped`: targets the PyTorch 2.8+ grouped GEMM (H100/SM90+ recommended).
- `naive`: keeps the reference per-expert loop.

## Notes

- The current implementation wires up the backend selector and routes the Mixtral MoE block through it. The `torch_grouped` path uses cuBLASLt grouped GEMM when available; otherwise the code falls back to the naive per-expert loop.
- No changes to training scripts are required; backend selection happens inside the model forward.

## Example

moe_backend: torch_grouped

accelerate launch -m axolotl.cli.train path/to/config.yaml
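
For orientation, here is a minimal sketch of checking the resolved backend from Python. It uses the `get_moe_backend_name` helper added by this change in `axolotl/kernels/moe/backends.py`; the call is illustrative only, since in practice selection happens inside the patched model forward.

from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name

# "auto" resolves to TORCH_GROUPED when PyTorch >= 2.8 is importable,
# otherwise to NAIVE; an explicit value takes precedence over auto-detection.
backend = get_moe_backend_name("auto")
if backend is MOEBackend.TORCH_GROUPED:
    print("grouped GEMM path selected")
else:
    print("using the naive per-expert loop")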


@@ -0,0 +1,53 @@
base_model: Qwen/Qwen1.5-MoE-A2.7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
# Keep VRAM low
load_in_8bit: false
load_in_4bit: true
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/qwen2-moe-qlora-10gb
# Train small to fit 10GB
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 5
flash_attention: true
warmup_ratio: 0.03
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
model_config:
  output_router_logits: true
special_tokens:

scripts/bench_moe.py (new file)

@@ -0,0 +1,209 @@
#!/usr/bin/env python
"""Benchmark Hugging Face Qwen2 MoE block with and without grouped_mm."""
from __future__ import annotations
import argparse
import sys
import time
import weakref
from pathlib import Path
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def bench(run, *, iters: int, warmup: int, sync: bool = True) -> float:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for _ in range(warmup):
run()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(iters):
if sync and device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
    # Gate, up and down projections each cost hidden * inter MACs per routed token
    # (2 FLOPs per MAC), giving the factor of 3 * 2 = 6.
    return 6.0 * tokens * top_k * hidden * inter
def load_hf_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
def main() -> None:
p = argparse.ArgumentParser(description="Qwen2 MoE grouped_mm benchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=32)
p.add_argument("--top_k", type=int, default=4)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--profile", action="store_true")
p.add_argument(
"--compile",
action="store_true",
help="Torch.compile both paths before benchmarking",
)
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = load_hf_block(
args.hidden,
args.inter,
args.experts,
args.top_k,
device=device,
dtype=dtype,
)
tokens = args.bsz * args.seq
flops_total = estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} inter={args.inter} "
f"experts={args.experts} top_k={args.top_k}"
)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
# Optional torch.compile
run_grouped_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile naive failed ({exc}); using eager")
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp, block.gate, block.experts, block.top_k
)
return y
try:
run_grouped_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc: # pragma: no cover
print(f"torch.compile grouped failed ({exc}); using eager")
run_grouped_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(block=block_grouped, data=x, impl=run_grouped_impl):
if impl is not None:
return impl(data)
if tg is None or not tg.available():
return torch.empty(0)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(data, block.gate, block.experts, block.top_k)
return y if y is not None else torch.empty(0)
t_naive = bench(run_naive, iters=args.iters, warmup=args.warmup)
tflops_naive = flops_total / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t{tflops_naive:.2f} TFLOP/s"
)
with torch.no_grad():
y_ref = run_naive()
if tg is None or not tg.available():
print("torch_grouped\tN/A (unavailable)")
return
y_grouped = run_grouped()
if y_grouped.numel() == 0:
print("torch_grouped\tN/A (op not callable)")
return
t_grouped = bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops_total / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped
print(
f"torch_grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
diff = (y_ref.float() - y_grouped.float()).abs()
print(
"torch_grouped_check: "
f"max_abs={diff.max().item():.3e} mean_abs={diff.mean().item():.3e} "
f"rel_l2={(diff.pow(2).sum() / (y_ref.float().pow(2).sum() + 1e-12)).sqrt().item():.3e}"
)
if args.profile:
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_naive()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], record_shapes=True
) as prof:
run_grouped()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
if __name__ == "__main__":
main()

scripts/bench_moe_sweep.py (new file)

@@ -0,0 +1,311 @@
#!/usr/bin/env python
"""Sweep grouped_mm vs naive performance for Qwen2 MoE block."""
from __future__ import annotations
import argparse
import csv
import sys
import time
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import List
import torch
import torch._dynamo as dynamo
try:
from axolotl.kernels.moe import torch_grouped as tg
except Exception: # pragma: no cover
tg = None
def _parse_list(arg: str) -> List[int]:
return [int(v) for v in arg.split(",") if v]
def _bench(run, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
run()
if device.type == "cuda":
torch.cuda.synchronize()
times: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
run()
if device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _load_block(
hidden: int,
inter: int,
experts: int,
top_k: int,
*,
device: torch.device,
dtype: torch.dtype,
):
project_root = Path(__file__).resolve().parents[2]
transformers_src = project_root / "transformers" / "src"
if transformers_src.exists() and str(transformers_src) not in sys.path:
sys.path.append(str(transformers_src))
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
cfg = Qwen2MoeConfig(
hidden_size=hidden,
moe_intermediate_size=inter,
shared_expert_intermediate_size=inter,
num_experts=experts,
num_experts_per_tok=top_k,
norm_topk_prob=True,
qkv_bias=True,
)
block = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped = Qwen2MoeSparseMoeBlock(cfg).to(device=device, dtype=dtype)
block_grouped.load_state_dict(block.state_dict())
return block, block_grouped
@dataclass
class Result:
bsz: int
seq: int
hidden: int
inter: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def main() -> None:
p = argparse.ArgumentParser(description="Grouped MoE sweep")
p.add_argument("--batch-sizes", default="4,8,16")
p.add_argument("--seq-lens", default="512,1024,2048")
p.add_argument("--hidden", default="2048,4096")
p.add_argument("--inter", default="5632,8192,14336")
p.add_argument("--experts", default="8,16,32")
p.add_argument("--top-k", default="1,2,4")
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--csv", type=Path, default=None)
p.add_argument("--compile", action="store_true")
args = p.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[args.dtype]
if tg is None or not tg.available():
print("torch_grouped unavailable; sweep aborted")
return
bs_list = _parse_list(args.batch_sizes)
seq_list = _parse_list(args.seq_lens)
hidden_list = _parse_list(args.hidden)
inter_list = _parse_list(args.inter)
expert_list = _parse_list(args.experts)
topk_list = _parse_list(args.top_k)
results: List[Result] = []
print(
"bsz\tseq\thidden\tinter\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
for bsz in bs_list:
for seq in seq_list:
tokens = bsz * seq
for hidden in hidden_list:
for inter in inter_list:
for experts in expert_list:
for top_k in topk_list:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
block_naive, block_grouped = _load_block(
hidden,
inter,
experts,
top_k,
device=device,
dtype=dtype,
)
x = torch.randn(
bsz, seq, hidden, device=device, dtype=dtype
)
compiled_impl = None
if args.compile:
dynamo.config.capture_scalar_outputs = True
dynamo.config.allow_unspec_int_on_nn_module = True
try:
block_naive = torch.compile(block_naive) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile naive failed ({exc}); using eager"
)
else:
def grouped_forward(inp, *, block=block_grouped):
block.experts._ax_parent_block_ref = (
weakref.ref(block)
) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
inp,
block.gate,
block.experts,
block.top_k,
)
return y
try:
compiled_impl = torch.compile(grouped_forward) # type: ignore[arg-type]
except Exception as exc:
print(
f"torch.compile grouped failed ({exc}); using eager"
)
compiled_impl = None
def run_naive(block=block_naive, data=x):
y, _ = block(data)
return y
def run_grouped(
block=block_grouped, data=x, impl=compiled_impl
):
if impl is not None:
return impl(data)
block.experts._ax_parent_block_ref = weakref.ref(block) # type: ignore[attr-defined]
y, _ = tg.moe_ffn_forward_grouped(
data,
block.gate,
block.experts,
block.top_k,
)
return y
naive_ms = _bench(
run_naive,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_naive = run_naive()
grouped_ms = _bench(
run_grouped,
iters=args.iters,
warmup=args.warmup,
device=device,
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
res = Result(
bsz,
seq,
hidden,
inter,
experts,
top_k,
args.dtype,
naive_ms,
grouped_ms,
naive_ms / grouped_ms,
_estimate_flops(tokens, hidden, inter, top_k)
/ ((naive_ms / 1000.0) * 1e12),
_estimate_flops(tokens, hidden, inter, top_k)
/ ((grouped_ms / 1000.0) * 1e12),
diff.max().item(),
diff.mean().item(),
(
(
diff.pow(2).sum()
/ (y_naive.float().pow(2).sum() + 1e-12)
)
.sqrt()
.item()
),
)
results.append(res)
print(
f"{bsz}\t{seq}\t{hidden}\t{inter}\t{experts}\t{top_k}\t{res.naive_ms:.2f}\t"
f"{res.grouped_ms:.2f}\t{res.speedup:.2f}\t{res.naive_tflops:.2f}\t"
f"{res.grouped_tflops:.2f}\t{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
if args.csv:
fieldnames = [
"bsz",
"seq",
"hidden",
"inter",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with args.csv.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"bsz": r.bsz,
"seq": r.seq,
"hidden": r.hidden,
"inter": r.inter,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
if __name__ == "__main__":
    main()


@@ -0,0 +1,205 @@
#!/usr/bin/env python
"""Benchmark Torchtitan MoE grouped vs naive expert execution."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
import torch
# Ensure torchtitan is importable when running from the axolotl tree
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE microbenchmark")
p.add_argument("--bsz", type=int, default=8)
p.add_argument("--seq", type=int, default=1024)
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--experts", type=int, default=8)
p.add_argument("--top_k", type=int, default=2)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=50)
p.add_argument("--warmup", type=int, default=10)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument(
"--score-before",
action="store_true",
help="Apply routing scores before expert computation (default: after)",
)
p.add_argument(
"--score-func",
choices=["softmax", "sigmoid"],
default="softmax",
)
p.add_argument(
"--route-norm",
action="store_true",
help="Enable Torchtitan router normalization when using sigmoid scores.",
)
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_moe_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
# Two up projections + one down projection per expert/token combination.
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(
moe: MoE,
*,
device: torch.device,
dtype: torch.dtype,
) -> MoE:
moe = moe.to(device=device)
for param in moe.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
buffers = dict(moe.named_buffers())
for name, buf in buffers.items():
if name == "tokens_per_expert":
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
moe._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
moe._buffers[name] = buf.to(device=device, dtype=dtype)
moe.eval()
return moe
@torch.inference_mode()
def _forward_fn(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(fn, *, iters: int, warmup: int, sync: bool = True) -> float:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
for _ in range(warmup):
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times = []
for _ in range(iters):
if sync and device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
fn()
if sync and device.type == "cuda":
torch.cuda.synchronize()
times.append((time.perf_counter() - start) * 1000.0)
return sum(times) / len(times)
def main() -> None:
args = _parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
dtype = _map_dtype(args.dtype)
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=True,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=args.hidden, hidden_dim=args.inter)
moe_grouped.init_weights(args.init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=args.experts,
num_shared_experts=0,
score_func=args.score_func,
route_norm=args.route_norm,
top_k=args.top_k,
use_grouped_mm=False,
score_before_experts=args.score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=args.hidden, hidden_dim=args.inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(args.bsz, args.seq, args.hidden, device=device, dtype=dtype)
tokens = args.bsz * args.seq
print(
f"Device={device} dtype={dtype} tokens={tokens} hidden={args.hidden} "
f"inter={args.inter} experts={args.experts} top_k={args.top_k}"
)
def run_naive():
return _forward_fn(moe_naive, x)
def run_grouped():
return _forward_fn(moe_grouped, x)
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_naive = _bench(run_naive, iters=args.iters, warmup=args.warmup)
flops = _estimate_moe_flops(tokens, args.hidden, args.inter, args.top_k)
tflops_naive = flops / ((t_naive / 1000.0) * 1e12)
print(
f"naive\t{t_naive:.2f} ms\t{tokens / (t_naive / 1000.0):.1f} tok/s\t"
f"{tflops_naive:.2f} TFLOP/s"
)
y_naive = run_naive()
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
t_grouped = _bench(run_grouped, iters=args.iters, warmup=args.warmup)
tflops_grouped = flops / ((t_grouped / 1000.0) * 1e12)
speedup = t_naive / t_grouped if t_grouped > 0 else float("nan")
print(
f"grouped\t{t_grouped:.2f} ms\t{tokens / (t_grouped / 1000.0):.1f} tok/s\t"
f"{tflops_grouped:.2f} TFLOP/s\t{speedup:.2f}×"
)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
print(
f"grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
)
if __name__ == "__main__":
main()
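
A worked example of the throughput math above, using this script's default shapes (bsz=8, seq=1024, hidden=4096, inter=14336, top_k=2); the latency value is hypothetical and only illustrates how the TFLOP/s figure is derived:

# Purely illustrative arithmetic mirroring _estimate_moe_flops and the print-outs above.
tokens = 8 * 1024                              # bsz * seq
flops = 6.0 * tokens * 2 * 4096 * 14336        # ~5.77e12 FLOPs per forward pass
elapsed_ms = 50.0                              # hypothetical measured latency
tflops = flops / ((elapsed_ms / 1000.0) * 1e12)
print(f"{tflops:.1f} TFLOP/s")                 # ~115.4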


@@ -0,0 +1,328 @@
#!/usr/bin/env python
"""Sweep Torchtitan MoE grouped vs naive configurations and report performance."""
from __future__ import annotations
import argparse
import csv
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List
import torch
_PROJECT_ROOT = Path(__file__).resolve().parents[2]
_TITAN_PATH = _PROJECT_ROOT / "torchtitan"
if str(_TITAN_PATH) not in sys.path:
sys.path.append(str(_TITAN_PATH))
from torchtitan.models.moe import MoE, MoEArgs
def _parse_int_list(value: str) -> List[int]:
return [int(v) for v in value.split(",") if v]
def _parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Torchtitan MoE grouped vs naive sweep")
p.add_argument(
"--batch-sizes", default="4,8,16", help="Comma separated batch sizes"
)
p.add_argument(
"--seq-lens", default="1024,2048", help="Comma separated sequence lengths"
)
p.add_argument(
"--experts", default="8,16,32,64", help="Comma separated expert counts"
)
p.add_argument("--top-ks", default="1,2,4", help="Comma separated top_k choices")
p.add_argument("--hidden", type=int, default=4096)
p.add_argument("--inter", type=int, default=14336)
p.add_argument("--dtype", choices=["bf16", "fp16", "fp32"], default="bf16")
p.add_argument("--iters", type=int, default=25)
p.add_argument("--warmup", type=int, default=5)
p.add_argument("--init-std", type=float, default=0.02)
p.add_argument("--score-before", action="store_true")
p.add_argument("--score-func", choices=["softmax", "sigmoid"], default="softmax")
p.add_argument("--route-norm", action="store_true")
p.add_argument("--csv", type=Path, default=None, help="Optional CSV output path")
return p.parse_args()
def _map_dtype(arg: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[arg]
def _estimate_flops(tokens: int, hidden: int, inter: int, top_k: int) -> float:
return 6.0 * tokens * top_k * hidden * inter
def _prepare_module(module: MoE, *, device: torch.device, dtype: torch.dtype) -> MoE:
module = module.to(device=device)
for param in module.parameters():
param.data = param.data.to(dtype)
if param.grad is not None:
param.grad = None
for name, buf in module.named_buffers():
if name == "tokens_per_expert":
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
elif name == "expert_bias" and buf is not None:
module._buffers[name] = torch.zeros_like(
buf, dtype=torch.float32, device=device
)
else:
module._buffers[name] = buf.to(device=device, dtype=dtype)
module.eval()
return module
@torch.inference_mode()
def _forward(module: MoE, x: torch.Tensor) -> torch.Tensor:
return module(x)
def _bench(callable_, *, iters: int, warmup: int, device: torch.device) -> float:
for _ in range(warmup):
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings: List[float] = []
for _ in range(iters):
if device.type == "cuda":
torch.cuda.synchronize()
start = time.perf_counter()
callable_()
if device.type == "cuda":
torch.cuda.synchronize()
timings.append((time.perf_counter() - start) * 1000.0)
return sum(timings) / len(timings)
@dataclass
class SweepResult:
bsz: int
seq: int
experts: int
top_k: int
dtype: str
naive_ms: float
grouped_ms: float
speedup: float
naive_tflops: float
grouped_tflops: float
max_abs: float
mean_abs: float
rel_l2: float
def _run_case(
*,
bsz: int,
seq: int,
experts: int,
top_k: int,
hidden: int,
inter: int,
dtype: torch.dtype,
device: torch.device,
iters: int,
warmup: int,
init_std: float,
score_before: bool,
score_func: str,
route_norm: bool,
) -> SweepResult:
torch.manual_seed(0)
if device.type == "cuda":
torch.cuda.manual_seed(0)
moe_args_grouped = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=True,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_grouped = MoE(moe_args_grouped, dim=hidden, hidden_dim=inter)
moe_grouped.init_weights(init_std, buffer_device=device)
moe_args_naive = MoEArgs(
num_experts=experts,
num_shared_experts=0,
score_func=score_func,
route_norm=route_norm,
top_k=top_k,
use_grouped_mm=False,
score_before_experts=score_before,
load_balance_coeff=None,
)
moe_naive = MoE(moe_args_naive, dim=hidden, hidden_dim=inter)
moe_naive.load_state_dict(moe_grouped.state_dict(), strict=True)
moe_grouped = _prepare_module(moe_grouped, device=device, dtype=dtype)
moe_naive = _prepare_module(moe_naive, device=device, dtype=dtype)
x = torch.randn(bsz, seq, hidden, device=device, dtype=dtype)
def run_naive():
if hasattr(moe_naive, "tokens_per_expert"):
moe_naive.tokens_per_expert.zero_()
return _forward(moe_naive, x)
def run_grouped():
if hasattr(moe_grouped, "tokens_per_expert"):
moe_grouped.tokens_per_expert.zero_()
return _forward(moe_grouped, x)
naive_ms = _bench(run_naive, iters=iters, warmup=warmup, device=device)
y_naive = run_naive()
grouped_ms = _bench(run_grouped, iters=iters, warmup=warmup, device=device)
y_grouped = run_grouped()
diff = (y_naive.float() - y_grouped.float()).abs()
max_abs = diff.max().item()
mean_abs = diff.mean().item()
rel_l2 = (diff.pow(2).sum() / (y_naive.float().pow(2).sum() + 1e-12)).sqrt().item()
tokens = bsz * seq
flops = _estimate_flops(tokens, hidden, inter, top_k)
naive_tflops = flops / ((naive_ms / 1000.0) * 1e12)
grouped_tflops = flops / ((grouped_ms / 1000.0) * 1e12)
speedup = naive_ms / grouped_ms if grouped_ms > 0 else float("nan")
return SweepResult(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
dtype=str(dtype),
naive_ms=naive_ms,
grouped_ms=grouped_ms,
speedup=speedup,
naive_tflops=naive_tflops,
grouped_tflops=grouped_tflops,
max_abs=max_abs,
mean_abs=mean_abs,
rel_l2=rel_l2,
)
def _print_header(
hidden: int, inter: int, dtype: torch.dtype, device: torch.device
) -> None:
print(f"Device={device} dtype={dtype} hidden={hidden} inter={inter}")
print(
"bsz\tseq\texperts\ttop_k\tnaive(ms)\tgrouped(ms)\tspeedup\t"
"naive TF/s\tgrouped TF/s\tmax_abs\tmean_abs\trel_l2"
)
def _print_result(res: SweepResult) -> None:
print(
f"{res.bsz}\t{res.seq}\t{res.experts}\t{res.top_k}\t"
f"{res.naive_ms:.2f}\t{res.grouped_ms:.2f}\t{res.speedup:.2f}\t"
f"{res.naive_tflops:.2f}\t{res.grouped_tflops:.2f}\t"
f"{res.max_abs:.2e}\t{res.mean_abs:.2e}\t{res.rel_l2:.2e}"
)
def _write_csv(path: Path, results: Iterable[SweepResult]) -> None:
fieldnames = [
"batch_size",
"seq_len",
"experts",
"top_k",
"dtype",
"naive_ms",
"grouped_ms",
"speedup",
"naive_tflops",
"grouped_tflops",
"max_abs",
"mean_abs",
"rel_l2",
]
with path.open("w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=fieldnames)
writer.writeheader()
for r in results:
writer.writerow(
{
"batch_size": r.bsz,
"seq_len": r.seq,
"experts": r.experts,
"top_k": r.top_k,
"dtype": r.dtype,
"naive_ms": f"{r.naive_ms:.4f}",
"grouped_ms": f"{r.grouped_ms:.4f}",
"speedup": f"{r.speedup:.4f}",
"naive_tflops": f"{r.naive_tflops:.4f}",
"grouped_tflops": f"{r.grouped_tflops:.4f}",
"max_abs": f"{r.max_abs:.6e}",
"mean_abs": f"{r.mean_abs:.6e}",
"rel_l2": f"{r.rel_l2:.6e}",
}
)
def main() -> None:
args = _parse_args()
dtype = _map_dtype(args.dtype)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_sizes = _parse_int_list(args.batch_sizes)
seq_lens = _parse_int_list(args.seq_lens)
experts_list = _parse_int_list(args.experts)
top_ks = _parse_int_list(args.top_ks)
results: List[SweepResult] = []
_print_header(args.hidden, args.inter, dtype, device)
for bsz in batch_sizes:
for seq in seq_lens:
for experts in experts_list:
for top_k in top_ks:
try:
res = _run_case(
bsz=bsz,
seq=seq,
experts=experts,
top_k=top_k,
hidden=args.hidden,
inter=args.inter,
dtype=dtype,
device=device,
iters=args.iters,
warmup=args.warmup,
init_std=args.init_std,
score_before=args.score_before,
score_func=args.score_func,
route_norm=args.route_norm,
)
except RuntimeError as err:
print(
f"{bsz}\t{seq}\t{experts}\t{top_k}\tERROR: {err}",
file=sys.stderr,
)
continue
results.append(res)
_print_result(res)
if args.csv and results:
_write_csv(args.csv, results)
print(f"Wrote {len(results)} rows to {args.csv}")
if __name__ == "__main__":
main()


@@ -0,0 +1,53 @@
#!/usr/bin/env python
"""Inspect Qwen2 MoE expert implementations for grouped-mm debugging."""
from __future__ import annotations
import sys
from pathlib import Path
import torch
ROOT = Path(__file__).resolve().parents[2]
sys.path.extend(
[
str(ROOT / "transformers" / "src"),
str(ROOT / "src"),
]
)
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
from axolotl.kernels.moe.torch_grouped import _iter_expert_impls
def main() -> None:
cfg = Qwen2MoeConfig(
hidden_size=4096,
moe_intermediate_size=14336,
shared_expert_intermediate_size=14336,
num_experts=32,
num_experts_per_tok=4,
)
block = Qwen2MoeSparseMoeBlock(cfg).to("cuda", dtype=torch.bfloat16)
experts = block.experts
experts._ax_parent_block = block
impls = _iter_expert_impls(experts)
print(f"impl count: {len(impls)}")
for idx, impl in enumerate(impls[:8]):
has_gate = hasattr(impl, "gate_proj")
has_up = hasattr(impl, "up_proj")
print(
f"impl[{idx}] type={impl.__class__.__name__} has_gate={has_gate} has_up={has_up}"
)
if has_gate:
print(f" gate shape {tuple(impl.gate_proj.weight.shape)}")
print(f" up shape {tuple(impl.up_proj.weight.shape)}")
print(f" down shape {tuple(impl.down_proj.weight.shape)}")
if __name__ == "__main__":
main()


@@ -0,0 +1,47 @@
#!/usr/bin/env python
"""
Probe PyTorch for grouped GEMM operator names and namespaces.
Run: python scripts/probe_torch_grouped_ops.py
"""
import sys
def main():
try:
import torch
except Exception as e:
print("Failed to import torch:", e)
sys.exit(1)
print("torch version:", torch.__version__)
namespaces = [n for n in dir(torch.ops) if not n.startswith("_")]
print("ops namespaces:", namespaces)
found_any = False
for ns in namespaces:
obj = getattr(torch.ops, ns, None)
ops = []
if obj is not None:
try:
ops = dir(obj)
except Exception as e:
print(f"warning: failed to list ops for namespace {ns}: {e}")
cands = [
o
for o in ops
if ("group" in o.lower())
or ("mm_grouped" in o.lower())
or ("matmul_grouped" in o.lower())
or ("grouped" in o.lower())
]
if cands:
found_any = True
print(f"namespace {ns} candidates:", cands)
if not found_any:
print("No grouped GEMM candidates found. PyTorch >= 2.8 is recommended.")
if __name__ == "__main__":
main()


@@ -0,0 +1,3 @@
from .backends import MOEBackend, get_moe_backend_name
__all__ = ["get_moe_backend_name", "MOEBackend"]


@@ -0,0 +1,47 @@
import warnings
from enum import Enum
class MOEBackend(str, Enum):
AUTO = "auto"
TORCH_GROUPED = "torch_grouped"
NAIVE = "naive"
def _probe_torch_grouped() -> bool:
try:
import torch # noqa: F401
# Prefer a simple version check; exact APIs may vary across 2.8+.
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
return ver >= (2, 8)
except Exception:
return False
def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
"""
Resolve the desired MoE backend using, in order of precedence:
- explicit preferred argument (e.g., from config)
- auto detection
"""
choice = (preferred or "auto").lower()
try:
selected = MOEBackend(choice)
except ValueError:
warnings.warn(
f"Unknown moe backend '{choice}', falling back to auto", stacklevel=2
)
selected = MOEBackend.AUTO
if selected == MOEBackend.AUTO:
if _probe_torch_grouped():
return MOEBackend.TORCH_GROUPED
return MOEBackend.NAIVE
if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
warnings.warn(
"torch_grouped requested but torch>=2.8 not detected; falling back to naive",
stacklevel=2,
)
return MOEBackend.NAIVE
return selected
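
A quick sanity check of the fallback behavior above (a sketch; it assumes only that this module is importable):

import warnings

from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name

# Requesting torch_grouped on torch < 2.8 warns and returns NAIVE;
# on torch >= 2.8 the request is honored and no warning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    backend = get_moe_backend_name("torch_grouped")

print(backend in (MOEBackend.TORCH_GROUPED, MOEBackend.NAIVE))  # True
print([str(w.message) for w in caught])                         # fallback warning, if any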


@@ -0,0 +1,371 @@
"""Minimal grouped GEMM fast path for MoE experts using PyTorch _grouped_mm."""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
_LOGGER = logging.getLogger("axolotl.moe.grouped")
def available() -> bool:
try:
major, minor = map(int, torch.__version__.split("+")[0].split(".")[:2])
if (major, minor) < (2, 8):
return False
if not torch.cuda.is_available():
return False
sm, _ = torch.cuda.get_device_capability()
if sm < 9:
return False
return hasattr(torch.ops, "_grouped_mm")
except Exception:
return False
def _iter_expert_impls(
experts_module, visited: Optional[set[int]] = None
) -> List[torch.nn.Module]:
if visited is None:
visited = set()
module_id = id(experts_module)
if module_id in visited:
return []
visited.add(module_id)
impls: List[torch.nn.Module] = []
for exp in experts_module:
candidate = getattr(exp, "mlp", getattr(exp, "ffn", exp))
if hasattr(candidate, "gate_proj") and hasattr(candidate, "up_proj"):
impls.append(candidate)
continue
nested = getattr(candidate, "experts", None)
if nested is not None:
impls.extend(_iter_expert_impls(nested, visited))
continue
raise RuntimeError(
"torch_grouped: unable to resolve expert implementation for module"
)
return impls
@dataclass
class _GroupedWeightStorage:
pattern: str
gate: torch.Tensor
up: torch.Tensor
down: torch.Tensor
fused_gate_up: torch.Tensor
dtype: torch.dtype
device: torch.device
def _allocate_fused_gate_up(
num_experts: int,
gate_shape: torch.Size,
up_shape: torch.Size,
*,
device: torch.device,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if gate_shape[1] != up_shape[1]:
raise RuntimeError(
"torch_grouped: gate and up projections must share the hidden dimension"
)
fused = torch.empty(
(num_experts, gate_shape[0] + up_shape[0], gate_shape[1]),
device=device,
dtype=dtype,
)
gate_view = fused[:, : gate_shape[0]]
up_view = fused[:, gate_shape[0] : gate_shape[0] + up_shape[0]]
return fused, gate_view, up_view
def _ensure_grouped_weights(
experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
) -> _GroupedWeightStorage:
storage: Optional[_GroupedWeightStorage] = getattr(
experts_module, "_ax_grouped_storage", None
)
def _store(new_storage: _GroupedWeightStorage) -> _GroupedWeightStorage:
experts_module._ax_grouped_storage = new_storage
return new_storage
# Identify expert parameter layout
if (
hasattr(sample_mod, "w1")
and hasattr(sample_mod, "w3")
and hasattr(sample_mod, "w2")
):
pattern = "swi_glu"
num_experts = len(expert_impls)
w1_shape = sample_mod.w1.weight.shape
w3_shape = sample_mod.w3.weight.shape
w2_shape = sample_mod.w2.weight.shape
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == sample_mod.w1.weight.dtype
and storage.device == sample_mod.w1.weight.device
and storage.gate.shape[1:] == w1_shape
):
return storage
fused, gate, up = _allocate_fused_gate_up(
num_experts,
w1_shape,
w3_shape,
device=sample_mod.w1.weight.device,
dtype=sample_mod.w1.weight.dtype,
)
down = torch.empty(
(num_experts, *w2_shape),
device=sample_mod.w2.weight.device,
dtype=sample_mod.w2.weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate[idx].copy_(mod.w1.weight.detach())
up[idx].copy_(mod.w3.weight.detach())
down[idx].copy_(mod.w2.weight.detach())
mod.w1.weight.detach_()
mod.w1.weight.set_(gate[idx])
mod.w3.weight.detach_()
mod.w3.weight.set_(up[idx])
mod.w2.weight.detach_()
mod.w2.weight.set_(down[idx])
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
)
if hasattr(sample_mod, "gate_up_proj") and hasattr(sample_mod, "down_proj"):
pattern = "fused_gate_up"
gate_weight = sample_mod.gate_up_proj.weight
down_weight = sample_mod.down_proj.weight
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == gate_weight.dtype
and storage.device == gate_weight.device
and storage.gate.shape[1:]
== (gate_weight.shape[0] // 2, gate_weight.shape[1])
):
return storage
num_experts = len(expert_impls)
gate_full = torch.empty(
(num_experts, *gate_weight.shape),
device=gate_weight.device,
dtype=gate_weight.dtype,
)
down = torch.empty(
(num_experts, *down_weight.shape),
device=down_weight.device,
dtype=down_weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate_full[idx].copy_(mod.gate_up_proj.weight.detach())
down[idx].copy_(mod.down_proj.weight.detach())
mod.gate_up_proj.weight.detach_()
mod.gate_up_proj.weight.set_(gate_full[idx])
mod.down_proj.weight.detach_()
mod.down_proj.weight.set_(down[idx])
inter = gate_weight.shape[0] // 2
gate = gate_full[:, :inter]
up = gate_full[:, inter:]
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=gate_full,
dtype=gate.dtype,
device=gate.device,
)
)
if (
hasattr(sample_mod, "up_proj")
and hasattr(sample_mod, "gate_proj")
and hasattr(sample_mod, "down_proj")
):
pattern = "dual_proj"
up_weight = sample_mod.up_proj.weight
gate_weight = sample_mod.gate_proj.weight
down_weight = sample_mod.down_proj.weight
if (
storage is not None
and storage.pattern == pattern
and storage.dtype == sample_mod.up_proj.weight.dtype
and storage.device == sample_mod.up_proj.weight.device
and storage.gate.shape[1:] == gate_weight.shape
):
return storage
num_experts = len(expert_impls)
fused, gate, up = _allocate_fused_gate_up(
num_experts,
gate_weight.shape,
up_weight.shape,
device=gate_weight.device,
dtype=gate_weight.dtype,
)
down = torch.empty(
(num_experts, *down_weight.shape),
device=down_weight.device,
dtype=down_weight.dtype,
)
with torch.no_grad():
for idx, mod in enumerate(expert_impls):
gate[idx].copy_(mod.gate_proj.weight.detach())
up[idx].copy_(mod.up_proj.weight.detach())
down[idx].copy_(mod.down_proj.weight.detach())
mod.up_proj.weight.detach_()
mod.up_proj.weight.set_(up[idx])
mod.gate_proj.weight.detach_()
mod.gate_proj.weight.set_(gate[idx])
mod.down_proj.weight.detach_()
mod.down_proj.weight.set_(down[idx])
return _store(
_GroupedWeightStorage(
pattern=pattern,
gate=gate,
up=up,
down=down,
fused_gate_up=fused,
dtype=gate.dtype,
device=gate.device,
)
)
raise RuntimeError(
"torch_grouped: unsupported expert module layout for grouped weights"
)
def moe_ffn_forward_grouped(
hidden_states: torch.Tensor,
gate_linear: torch.nn.Linear,
experts_module,
top_k: int,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if not available():
return None, None
bsz, seqlen, hdim = hidden_states.shape
tokens = bsz * seqlen
device = hidden_states.device
routing_dtype = gate_linear.weight.dtype
expert_dtype = hidden_states.dtype
if expert_dtype not in (torch.bfloat16, torch.float16):
_LOGGER.debug(
"torch_grouped: unsupported expert dtype %s; falling back to naive",
expert_dtype,
)
return None, None
parent_block = None
parent_ref = getattr(experts_module, "_ax_parent_block_ref", None)
if parent_ref is not None:
try:
parent_block = parent_ref()
except TypeError:
parent_block = None
expert_container = getattr(experts_module, "experts", experts_module)
expert_impls = _iter_expert_impls(expert_container)
sample_mod = expert_impls[0]
storage = _ensure_grouped_weights(expert_container, expert_impls, sample_mod)
w_gate = storage.gate
w_up = storage.up
w2 = storage.down
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
router_logits = gate_linear(x_flat.to(routing_dtype))
shared_out_flat: Optional[torch.Tensor] = None
shared_owner = parent_block if parent_block is not None else experts_module
if hasattr(shared_owner, "shared_expert"):
shared_expert = shared_owner.shared_expert
shared_out_flat = shared_expert(x_flat)
shared_out_flat = shared_out_flat.to(expert_dtype)
shared_gate = getattr(shared_owner, "shared_expert_gate", None)
if shared_gate is not None:
gate_input = shared_gate(x_flat.to(shared_gate.weight.dtype))
gate_vals = torch.sigmoid(gate_input)
shared_out_flat.mul_(gate_vals.to(expert_dtype))
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
flat_idx = topk_idx.view(-1)
num_experts = len(expert_impls)
if flat_idx.numel() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
sorted_experts, perm = torch.sort(flat_idx)
assignments = torch.bincount(sorted_experts, minlength=num_experts)
if assignments.sum() == 0:
zero = torch.zeros_like(x_flat)
return zero.view(bsz, seqlen, hdim), router_logits
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor").contiguous()
scores_sorted = topk_weight.reshape(-1).index_select(0, perm)
gather_index = token_indices_sorted.unsqueeze(-1).expand(-1, hdim)
routed_input = torch.gather(x_flat, 0, gather_index)
counts_i32 = assignments.to(device=device, dtype=torch.int32)
offsets = torch.cumsum(counts_i32, dim=0).to(dtype=torch.int32)
mm_dtype = torch.bfloat16 if expert_dtype == torch.bfloat16 else expert_dtype
routed_in = routed_input.to(mm_dtype)
w_gate_t = w_gate.transpose(-2, -1).to(mm_dtype)
w_up_t = w_up.transpose(-2, -1).to(mm_dtype)
w2_t = w2.transpose(-2, -1).to(mm_dtype)
routed_in = routed_in.contiguous()
w_gate_t = w_gate_t.contiguous()
gate_out = torch._grouped_mm(routed_in, w_gate_t, offs=offsets)
torch.ops.aten.silu_(gate_out)
w_up_t = w_up_t.contiguous()
up_out = torch._grouped_mm(routed_in, w_up_t, offs=offsets)
gate_out.mul_(up_out)
gate_out = gate_out.contiguous()
w2_t = w2_t.contiguous()
down_out = torch._grouped_mm(gate_out, w2_t, offs=offsets).to(expert_dtype)
weights = scores_sorted.unsqueeze(-1).to(expert_dtype)
down_out.mul_(weights)
combined = torch.zeros_like(x_flat)
combined.scatter_add_(0, gather_index, down_out)
output = combined.view(bsz, seqlen, hdim)
if shared_out_flat is not None:
output = output + shared_out_flat.view(bsz, seqlen, hdim)
return output, router_logits
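
To make the routing bookkeeping above easier to follow, here is a small self-contained sketch (illustrative shapes only, standard torch ops) of how token assignments are sorted by expert, turned into the cumulative offsets passed as `offs=`, and gathered into contiguous per-expert slices:

import torch

tokens, hdim, num_experts, top_k = 6, 4, 3, 2
x_flat = torch.randn(tokens, hdim)
topk_idx = torch.randint(0, num_experts, (tokens, top_k))

flat_idx = topk_idx.view(-1)                                     # (tokens * top_k,)
sorted_experts, perm = torch.sort(flat_idx)                      # group assignments by expert
counts = torch.bincount(sorted_experts, minlength=num_experts)   # rows routed to each expert
offsets = torch.cumsum(counts, dim=0).to(torch.int32)            # cumulative row boundaries
token_indices_sorted = torch.div(perm, top_k, rounding_mode="floor")
routed_input = x_flat[token_indices_sorted]                      # contiguous rows per expert

# rows offsets[e-1]:offsets[e] of routed_input all belong to expert e
start = 0
for expert_id, end in enumerate(offsets.tolist()):
    assert (sorted_experts[start:end] == expert_id).all()
    start = end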


@@ -12,6 +12,7 @@ import transformers
from transformers import PretrainedConfig, PreTrainedModel
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.moe_grouped import apply_grouped_to_moe_blocks
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
@@ -57,6 +58,8 @@ class PatchManager:
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
# Apply MoE grouped GEMM patches (cfg.moe_backend)
apply_grouped_to_moe_blocks(self.cfg)
self._apply_fp8_patches()
self._apply_flash_attention_peft_patches()
self._apply_gradient_checkpointing_patches()
@@ -269,6 +272,7 @@ class PatchManager:
self.cfg.model_config_type,
model_name=self.cfg.base_model,
has_remote_code=has_remote_code,
cfg=self.cfg,
)
if self.cfg.sample_packing:


@@ -5,9 +5,14 @@ Patches to support multipack for mixtral
import torch
def patch_mixtral_moe_forward_zero3() -> None:
def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
import warnings
import torch.nn.functional as F
from axolotl.kernels.moe import backends as _moe_backends
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
def mlp_forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
hidden_states
@@ -21,21 +26,32 @@ def patch_mixtral_moe_forward_zero3() -> None:
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if (
backend == MOEBackend.TORCH_GROUPED
and not _moe_backends._probe_torch_grouped()
):
warnings.warn(
"torch_grouped selected but not available; falling back to naive",
stacklevel=2,
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
topk_weight, topk_idx = torch.topk(
routing_weights, self.top_k, dim=-1, sorted=False
)
topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
topk_weight = topk_weight.to(hidden_states.dtype)
hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states)
hidden_states_rep = hidden_states.repeat_interleave(self.top_k, dim=0)
y = torch.empty_like(hidden_states_rep)
flat_topk_idx = topk_idx.view(-1)
for i in range(self.num_experts):
expert = self.experts[i]
y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
sel = flat_topk_idx == i
if sel.any():
y[sel] = expert(hidden_states_rep[sel])
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
@@ -46,4 +62,23 @@ def patch_mixtral_moe_forward_zero3() -> None:
)
MixtralBlockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
# Wrap forward to support optional torch_grouped backend via config
from axolotl.kernels.moe import torch_grouped as _tg
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend == MOEBackend.TORCH_GROUPED and _tg.available():
def moe_forward_grouped(self, hidden_states: torch.Tensor) -> torch.Tensor:
bsz, seqlen, hdim = hidden_states.shape
y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k
)
if y is None:
return moe_forward(self, hidden_states)
return y, router_logits
MixtralSparseMoeBlock.forward = moe_forward_grouped
else:
MixtralSparseMoeBlock.forward = moe_forward


@@ -0,0 +1,133 @@
import logging
import weakref
from functools import wraps
import torch
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
_LOG = logging.getLogger("axolotl.moe.patch")
def _patch_block_forward(block_cls, grouped_fn):
"""Replace block_cls.forward with grouped_fn preserving signature."""
block_cls.forward = grouped_fn
def apply_grouped_to_moe_blocks(cfg=None) -> None:
"""
Attempt to patch all known MoE block classes to use the torch_grouped backend
when cfg.moe_backend resolves to 'torch_grouped' and the op is available.
Falls back to original forwards otherwise.
"""
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
backend = get_moe_backend_name(preferred)
if backend != MOEBackend.TORCH_GROUPED:
_LOG.info(
f"moe_backend is '{backend}', not 'torch_grouped'; skipping grouped patches"
)
return
try:
from axolotl.kernels.moe import torch_grouped as _tg
except Exception:
_LOG.warning("torch_grouped backend import failed; skipping grouped patches")
return
if not _tg.available():
_LOG.warning(
"torch_grouped requested but unavailable (op smoke test failed); skipping grouped patches"
)
return
# Map of architecture key to (modeling module path, class name or list of class names)
model_mods = {
"mixtral": (
"transformers.models.mixtral.modeling_mixtral",
MOE_ARCH_BLOCK.get("mixtral"),
),
"qwen2_moe": (
"transformers.models.qwen2_moe.modeling_qwen2_moe",
MOE_ARCH_BLOCK.get("qwen2_moe"),
),
"qwen3_moe": (
"transformers.models.qwen3_moe.modeling_qwen3_moe",
MOE_ARCH_BLOCK.get("qwen3_moe"),
),
"jamba": (
"transformers.models.jamba.modeling_jamba",
MOE_ARCH_BLOCK.get("jamba"),
),
"deepseek_v2": (
"transformers.models.deepseek_v2.modeling_deepseek_v2",
MOE_ARCH_BLOCK.get("deepseek_v2"),
),
# Others may not follow standard paths; best-effort import
"dbrx": ("transformers.models.dbrx.modeling_dbrx", MOE_ARCH_BLOCK.get("dbrx")),
"jetmoe": (
"transformers.models.jetmoe.modeling_jetmoe",
MOE_ARCH_BLOCK.get("jetmoe"),
),
"gpt_oss": (
"transformers.models.gpt_oss.modeling_gpt_oss",
MOE_ARCH_BLOCK.get("gpt_oss"),
),
}
def make_grouped_forward(orig_forward):
@wraps(orig_forward)
def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
bsz, seqlen, hdim = hidden_states.shape
# expose parent block so grouped backend can access shared expert context
try:
self.experts._ax_parent_block_ref = weakref.ref(self)
except Exception:
pass
y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k
)
# One-time log per block instance indicating whether grouped engaged or fallback occurred
if not getattr(self, "_ax_grouped_wrapper_logged", False):
if y is None:
_LOG.warning(
"Grouped wrapper active but fell back to naive for %s",
self.__class__.__name__,
)
else:
_LOG.info(
f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
)
self._ax_grouped_wrapper_logged = True
if y is None:
return orig_forward(self, hidden_states, *args, **kwargs)
return y, router_logits
return _grouped_forward
patched = 0
for key, (mod_path, cls_names) in model_mods.items():
if not cls_names:
continue
try:
import importlib
modeling = importlib.import_module(mod_path)
names = cls_names if isinstance(cls_names, list) else [cls_names]
for name in names:
if not hasattr(modeling, name):
continue
block_cls = getattr(modeling, name)
orig_forward = getattr(block_cls, "forward", None)
if orig_forward is None:
continue
_patch_block_forward(block_cls, make_grouped_forward(orig_forward))
patched += 1
_LOG.info(f"Patched MoE block for grouped GEMM: {mod_path}.{name}")
except Exception as e:
# Best effort; log and skip this entry
_LOG.warning(f"Skipping MoE patch for arch '{key}' ({mod_path}): {e}")
if patched == 0:
_LOG.warning(
"No MoE blocks patched for grouped GEMM; model may not use known MoE classes"
)
else:
_LOG.info(f"Grouped GEMM patches applied to {patched} MoE block class(es)")


@@ -46,7 +46,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False, cfg=None):
if has_remote_code:
patch_remote(model_name)
elif hasattr(transformers, "modeling_flash_attention_utils"):
@@ -57,7 +57,7 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
patch_mixtral_moe_forward_zero3(cfg)
def patch_remote(model_name):


@@ -132,6 +132,14 @@ class AxolotlInputConfig(
vllm: VllmConfig | None = Field(
default_factory=lambda: VllmConfig(),
)
moe_backend: Literal["auto", "torch_grouped", "naive"] | None = Field(
default=None,
json_schema_extra={
"description": "Mixture-of-Experts backend to use: 'auto', 'torch_grouped', or 'naive'. If not set, defaults to 'auto'.",
},
)
# Value is constrained by the Literal type; no normalization needed.
qat: QATConfig | None = None
quantization: PTQConfig | None = None
reward_model: bool | None = Field(


@@ -0,0 +1,258 @@
import sys
import types
import torch
import torch.nn as nn
from axolotl.kernels.moe import (
backends as moe_backends,
torch_grouped as torch_grouped_module,
)
from axolotl.monkeypatch import moe_grouped
class DummyExperts(nn.Module):
def __init__(self, layers):
super().__init__()
self.layers = nn.ModuleList(layers)
self.num_experts = len(layers)
def __getitem__(self, idx):
return self.layers[idx]
class DummyQwenMLP(nn.Module):
def __init__(self, idx: int, hidden: int, intermediate: int):
super().__init__()
self.gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
self.down_proj = nn.Linear(intermediate, hidden, bias=False)
nn.init.constant_(self.gate_up_proj.weight, float(idx + 1))
nn.init.constant_(self.down_proj.weight, float((idx + 1) * 10))
class DummyQwenExpert(nn.Module):
def __init__(self, idx: int, hidden: int, intermediate: int):
super().__init__()
self.mlp = DummyQwenMLP(idx, hidden, intermediate)
def _make_transformers_stub(monkeypatch, block_cls):
# ensure we start from the original forward for each test
if block_cls is DummyMixtralBlock:
DummyMixtralBlock.forward = _DUMMY_MIXTRAL_ORIG_FORWARD
transformers_mod = types.ModuleType("transformers")
models_mod = types.ModuleType("transformers.models")
mixtral_mod = types.ModuleType("transformers.models.mixtral")
modeling_mixtral = types.ModuleType("transformers.models.mixtral.modeling_mixtral")
modeling_mixtral.MixtralSparseMoeBlock = block_cls
transformers_mod.models = models_mod
models_mod.mixtral = mixtral_mod
mixtral_mod.modeling_mixtral = modeling_mixtral
monkeypatch.setitem(sys.modules, "transformers", transformers_mod)
monkeypatch.setitem(sys.modules, "transformers.models", models_mod)
monkeypatch.setitem(sys.modules, "transformers.models.mixtral", mixtral_mod)
monkeypatch.setitem(
sys.modules,
"transformers.models.mixtral.modeling_mixtral",
modeling_mixtral,
)
def test_grouped_uses_per_expert_nested_modules(monkeypatch):
hidden = 4
intermediate = 2
num_experts = 2
experts = DummyExperts(
[DummyQwenExpert(i, hidden, intermediate) for i in range(num_experts)]
)
gate = nn.Linear(hidden, num_experts, bias=False)
nn.init.zeros_(gate.weight)
captured = []
def fake_grouped_mm(As, Bs, dtype):
captured.append([b.detach().clone() for b in Bs])
return [
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
for a, b in zip(As, Bs, strict=False)
]
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
hidden_states = torch.randn(1, 2, hidden)
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
hidden_states, gate, experts, top_k=2
)
assert y is not None
assert router_logits is not None
assert captured, "Grouped GEMM path should have been invoked"
first_call = captured[0]
expected0 = experts[0].mlp.gate_up_proj.weight.t()
expected1 = experts[1].mlp.gate_up_proj.weight.t()
assert torch.equal(first_call[0], expected0)
assert torch.equal(first_call[1], expected1)
assert not torch.equal(first_call[0], first_call[1])
def test_grouped_accepts_module_list_experts(monkeypatch):
hidden = 4
intermediate = 2
experts = nn.ModuleList(
[DummyQwenExpert(i, hidden, intermediate) for i in range(2)]
)
gate = nn.Linear(hidden, len(experts), bias=False)
nn.init.zeros_(gate.weight)
calls = {"count": 0}
def fake_grouped_mm(As, Bs, dtype):
calls["count"] += 1
return [
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
for a, b in zip(As, Bs, strict=False)
]
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
hidden_states = torch.randn(1, 2, hidden)
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
hidden_states, gate, experts, top_k=2
)
assert y is not None
assert router_logits is not None
assert calls["count"] > 0
class _DummyCfg:
moe_backend = "torch_grouped"
class DummyMixtralBlock(nn.Module):
def __init__(self):
super().__init__()
self.top_k = 1
self.gate = lambda x: x
self.experts = object()
self._calls = []
def forward(self, hidden_states: torch.Tensor, attention_mask=None):
self._calls.append((hidden_states, attention_mask))
tokens = hidden_states.shape[0] * hidden_states.shape[1]
router = torch.ones(
tokens, 2, device=hidden_states.device, dtype=hidden_states.dtype
)
return hidden_states + 5, router
_DUMMY_MIXTRAL_ORIG_FORWARD = DummyMixtralBlock.forward
def test_apply_grouped_forward_handles_args(monkeypatch):
_make_transformers_stub(monkeypatch, DummyMixtralBlock)
import axolotl.common.architectures as arch
original_map = arch.MOE_ARCH_BLOCK.copy()
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
for key in list(original_map.keys()):
if key != "mixtral":
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
monkeypatch.setattr(
moe_grouped,
"get_moe_backend_name",
lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
)
results = {}
def fake_grouped_forward(hidden_states, gate, experts, top_k):
results["called"] = True
router = torch.zeros(
hidden_states.shape[0] * hidden_states.shape[1],
2,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
return hidden_states + 1, router
monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
monkeypatch.setattr(
torch_grouped_module,
"moe_ffn_forward_grouped",
fake_grouped_forward,
)
cfg = _DummyCfg()
moe_grouped.apply_grouped_to_moe_blocks(cfg)
block = DummyMixtralBlock()
hidden_states = torch.ones(1, 2, 3)
mask = torch.zeros(1, 2)
out, router = block.forward(hidden_states, attention_mask=mask)
assert results.get("called") is True
assert torch.equal(out, hidden_states + 1)
assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
def test_apply_grouped_forward_fallback(monkeypatch):
_make_transformers_stub(monkeypatch, DummyMixtralBlock)
import axolotl.common.architectures as arch
original_map = arch.MOE_ARCH_BLOCK.copy()
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
for key in list(original_map.keys()):
if key != "mixtral":
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
monkeypatch.setattr(
moe_grouped,
"get_moe_backend_name",
lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
)
monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
monkeypatch.setattr(
torch_grouped_module,
"moe_ffn_forward_grouped",
lambda *args, **kwargs: (None, None),
)
cfg = _DummyCfg()
moe_grouped.apply_grouped_to_moe_blocks(cfg)
block = DummyMixtralBlock()
hidden_states = torch.ones(1, 2, 3)
mask = torch.zeros(1, 2)
out, router = block.forward(hidden_states, attention_mask=mask)
assert torch.equal(out, hidden_states + 5)
assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
assert block._calls, "Original forward should have been invoked"
call_hidden, call_mask = block._calls[-1]
assert torch.equal(call_hidden, hidden_states)
assert torch.equal(call_mask, mask)
def test_get_moe_backend_name_prefers_probe(monkeypatch):
monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: True)
assert moe_backends.get_moe_backend_name() == moe_backends.MOEBackend.TORCH_GROUPED
def test_get_moe_backend_name_falls_back(monkeypatch):
warnings_captured = []
def fake_warn(msg, *, stacklevel=None): # noqa: ARG001
warnings_captured.append(msg)
monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: False)
monkeypatch.setattr(moe_backends.warnings, "warn", fake_warn)
backend = moe_backends.get_moe_backend_name("torch_grouped")
assert backend == moe_backends.MOEBackend.NAIVE
assert warnings_captured, "Expected warning when torch_grouped unavailable"

uv.lock (generated file)

File diff suppressed because it is too large.