Compare commits

...

37 Commits

Author SHA1 Message Date
Dan Saunders
dd85358543 default mg 2025-09-25 16:30:23 -04:00
Dan Saunders
55d98db0d0 fix 2025-09-25 16:08:35 -04:00
Dan Saunders
d90ade3b1b fix 2025-09-25 15:55:08 -04:00
Dan Saunders
824a641cee uniform routing default 2025-09-25 15:47:23 -04:00
Dan Saunders
e003a05177 narrow sweep; compare both backends 2025-09-25 14:54:03 -04:00
Dan Saunders
91393c4dc8 allocator 2025-09-25 14:27:34 -04:00
Dan Saunders
d578c53603 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
4db7a21ff7 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
3b2e05c563 update to new api 2025-09-25 14:27:34 -04:00
Dan Saunders
1037ca3a97 update to new api 2025-09-25 14:27:34 -04:00
Dan Saunders
6369dcd7b8 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
a81612305c fix? 2025-09-25 14:27:34 -04:00
Dan Saunders
d0da67eb17 add mg kernel backend 2025-09-25 14:27:34 -04:00
Dan Saunders
8a1f5ae940 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
146ca48cba vram 2025-09-25 14:27:34 -04:00
Dan Saunders
fd312f6058 dtype 2025-09-25 14:27:34 -04:00
Dan Saunders
ab8fa56b16 dtype 2025-09-25 14:27:34 -04:00
Dan Saunders
1640cd4006 delete config 2025-09-25 14:27:34 -04:00
Dan Saunders
3277d44d71 cfg value 2025-09-25 14:27:34 -04:00
Dan Saunders
d3e1b0ef1a small deepseek script 2025-09-25 14:27:34 -04:00
Dan Saunders
5b97633faa Fix 2025-09-25 14:27:34 -04:00
Dan Saunders
94cbc6d42d log device, dtype 2025-09-25 14:27:34 -04:00
Dan Saunders
493616fc3d reprod tt table 2025-09-25 14:27:34 -04:00
Dan Saunders
d2b25c7327 grid sweep 2025-09-25 14:27:34 -04:00
Dan Saunders
b670c45276 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
61faf4cbe4 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
8d8fa834a2 sweep 2025-09-25 14:27:34 -04:00
Dan Saunders
9d69c6fb3e Fix 2025-09-25 14:27:34 -04:00
Dan Saunders
92f2f6e73c dtype fix 2025-09-25 14:27:34 -04:00
Dan Saunders
e5d2aebe16 uniform routing: 2025-09-25 14:27:34 -04:00
Dan Saunders
4ab9e3f58b add logs 2025-09-25 14:27:34 -04:00
Dan Saunders
5788832812 simplify 2025-09-25 14:27:34 -04:00
Dan Saunders
db782430f8 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
5c74edeefe token shuffle kernel 2025-09-25 14:27:34 -04:00
Dan Saunders
18269ee6a9 fix 2025-09-25 14:27:34 -04:00
Dan Saunders
6a45d804f9 glue 2025-09-25 14:27:34 -04:00
Dan Saunders
95e607574a vendor torchtitan moe kernels 2025-09-25 14:27:34 -04:00
20 changed files with 3329 additions and 0 deletions

1
scripts/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Utility scripts package."""

View File

@@ -0,0 +1,5 @@
"""Benchmark helpers."""
from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]

View File

@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Instantiate a ~8.3B DeepSeek-V3 MoE model with random weights.
Run this on a GPU-equipped machine (e.g. 1× NVL H100) so the dense
initialization completes quickly:
python scripts/benchmarks/build_deepseek_v3_8b.py --output deepseek-v3-8b-moe
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from transformers import DeepseekV3Config, DeepseekV3ForCausalLM
DTYPE_MAP = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}
def build_config() -> DeepseekV3Config:
    """Return a DeepSeek V3 configuration totaling ~8.3B parameters."""
    # Keep all hyper-parameters in one mapping so this preset is easy to
    # diff against other configurations.
    settings = {
        "vocab_size": 32_000,
        "hidden_size": 3_072,
        "intermediate_size": 8_192,
        "moe_intermediate_size": 2_560,
        "num_hidden_layers": 20,
        "num_attention_heads": 24,
        "num_key_value_heads": 24,
        "n_routed_experts": 18,
        "num_experts_per_tok": 4,
        "n_group": 6,
        "topk_group": 4,
        "kv_lora_rank": 192,
        "q_lora_rank": 384,
        "max_position_embeddings": 2_048,
        "rope_theta": 10_000.0,
        "rope_interleave": True,
        "hidden_act": "silu",
        "initializer_range": 0.02,
        "attention_dropout": 0.0,
        "attention_bias": False,
        "n_shared_experts": 1,
        "routed_scaling_factor": 2.5,
        "norm_topk_prob": True,
    }
    return DeepseekV3Config(**settings)
def parse_args() -> argparse.Namespace:
    """Parse the checkpoint-builder command line."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--output", type=Path, required=True,
        help="Directory to save the generated model",
    )
    cli.add_argument(
        "--dtype", default="bfloat16", choices=DTYPE_MAP.keys(),
        help="Storage dtype for the checkpoint",
    )
    cli.add_argument(
        "--seed", type=int, default=0,
        help="Torch RNG seed for reproducibility",
    )
    return cli.parse_args()
def main() -> None:
    """Build the randomly-initialized checkpoint and write it to disk.

    Parses CLI arguments, seeds torch for reproducibility, instantiates the
    model from :func:`build_config`, casts it to the requested dtype, and
    saves both weights (safetensors) and config under ``--output``.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    output_dir = args.output
    output_dir.mkdir(parents=True, exist_ok=True)
    config = build_config()
    model = DeepseekV3ForCausalLM(config)
    dtype = DTYPE_MAP[args.dtype]
    # Cast the whole model to the requested storage dtype before saving.
    model.to(dtype=dtype)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Initialized DeepSeek-V3 MoE with {param_count / 1e9:.3f}B parameters")
    model.save_pretrained(output_dir, safe_serialization=True)
    config.save_pretrained(output_dir)
    print(f"Saved model and config to {output_dir.resolve()}")
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,190 @@
#!/usr/bin/env python
"""Reproduce TorchTitan CG GEMM timings for selected problem sizes."""
from __future__ import annotations
import argparse
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import torch
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else:
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.kernels.moe import (
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
@dataclass
class Scenario:
    """One grouped-GEMM problem size for the benchmark."""

    num_groups: int  # number of groups; each spans m // num_groups rows
    m: int  # total rows across all groups
    n: int  # output features per group's weight matrix
    k: int  # input features per group's weight matrix

# Problem sizes used to reproduce the TorchTitan CG GEMM timings
# referenced in the module docstring.
SCENARIOS: tuple[Scenario, ...] = (
    Scenario(num_groups=4, m=8192, n=4096, k=7168),
    Scenario(num_groups=4, m=8192, n=7168, k=2048),
    Scenario(num_groups=8, m=4096, n=4096, k=7168),
    Scenario(num_groups=8, m=4096, n=7168, k=2048),
)
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the grouped-GEMM timing benchmark."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--device", default="cuda", choices=["cuda"], help="Execution device"
    )
    cli.add_argument(
        "--dtype", default="bf16", choices=["bf16", "fp16", "fp32"],
        help="Computation dtype",
    )
    cli.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
    cli.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
    cli.add_argument("--seed", type=int, default=0, help="Random seed")
    cli.add_argument(
        "--group-size", type=int, default=128,
        help="GROUP_SIZE_M expected by the kernel",
    )
    return cli.parse_args()
def pick_dtype(name: str) -> torch.dtype:
    """Translate a CLI dtype token into the matching torch dtype."""
    if name == "bf16":
        return torch.bfloat16
    if name == "fp16":
        return torch.float16
    if name == "fp32":
        return torch.float32
    # Mirror the original dict lookup's failure mode for unknown tokens.
    raise KeyError(name)
def make_indices(
    num_groups: int, group_size: int, device: torch.device
) -> torch.Tensor:
    """Return an int32 vector assigning ``group_size`` consecutive rows to each group.

    Row ``i`` maps to group ``i // group_size``, so group g owns the
    contiguous rows ``[g * group_size, (g + 1) * group_size)``.
    """
    row_ids = torch.arange(num_groups * group_size, device=device)
    return (row_ids // group_size).to(torch.int32)
def timed_call(fn, *args, warmup: int, iters: int) -> float:
    """Return the mean latency of ``fn(*args)`` in milliseconds.

    Runs ``warmup`` untimed calls first, then times ``iters`` calls with a
    wall clock.  CUDA kernels launch asynchronously, so the device is
    synchronized before starting and after finishing the timed loop.  The
    ``is_available`` guard makes the helper usable on CPU-only builds as
    well; the original called ``torch.cuda.synchronize()`` unconditionally,
    which fails when no CUDA runtime is present.
    """
    for _ in range(warmup):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000.0 / iters
def run_scenario(
    scenario: Scenario,
    *,
    dtype: torch.dtype,
    device: torch.device,
    warmup: int,
    iters: int,
    group_size_m: int,
) -> dict:
    """Benchmark one problem size with both grouped-GEMM variants.

    Builds random inputs/weights, times the persistent kernel against the
    dynamic baseline, and returns a dict with both latencies (ms), the
    scenario, and their speedup ratio.

    Raises:
        ValueError: if M is not divisible by the group count, or the derived
            per-group row count is not a multiple of ``group_size_m``.
    """
    if scenario.m % scenario.num_groups != 0:
        raise ValueError(
            f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
        )
    group_size = scenario.m // scenario.num_groups
    if group_size % group_size_m != 0:
        raise ValueError(
            f"Group size {group_size} must be a multiple of GROUP_SIZE_M ({group_size_m}) for the Triton kernel"
        )
    inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
    weights = torch.randn(
        scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
    )
    # Contiguous layout: each group owns group_size consecutive rows.
    indices = make_indices(scenario.num_groups, group_size, device)

    def persistent():
        return cg_grouped_gemm_forward(inputs, weights, indices, group_size_m)

    def baseline():
        return cg_grouped_gemm_forward_dynamic(inputs, weights, indices, group_size_m)

    persistent_ms = timed_call(persistent, warmup=warmup, iters=iters)
    baseline_ms = timed_call(baseline, warmup=warmup, iters=iters)
    return {
        "scenario": scenario,
        "persistent_ms": persistent_ms,
        "baseline_ms": baseline_ms,
        # > 1.0 means the persistent kernel is faster than the baseline.
        "speedup": baseline_ms / persistent_ms if persistent_ms > 0 else float("nan"),
    }
def main() -> None:  # pragma: no cover - utility script
    """CLI entry point: run every scenario and print a timing table.

    Requires a CUDA device; exits via SystemExit otherwise.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    if args.device != "cuda" or not torch.cuda.is_available():
        raise SystemExit("CUDA device required for this benchmark")
    dtype = pick_dtype(args.dtype)
    device = torch.device(args.device)
    print(
        f"device={device} dtype={dtype} warmup={args.warmup} iters={args.iters} group_size={args.group_size}"
    )
    # Column headers for the per-scenario rows printed below.
    print(
        f"{'groups':>7} {'m':>7} {'n':>7} {'k':>7} {'persistent':>12} {'baseline':>12} {'speedup':>8}"
    )
    for result in run_all(
        SCENARIOS,
        dtype=dtype,
        device=device,
        warmup=args.warmup,
        iters=args.iters,
        group_size_m=args.group_size,
    ):
        scen = result["scenario"]
        print(
            f"{scen.num_groups:>7} {scen.m:>7} {scen.n:>7} {scen.k:>7}"
            f" {result['persistent_ms']:>11.3f} ms {result['baseline_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
        )
def run_all(
    scenarios: Iterable[Scenario],
    *,
    dtype: torch.dtype,
    device: torch.device,
    warmup: int,
    iters: int,
    group_size_m: int,
) -> Iterable[dict]:
    """Lazily benchmark each scenario, yielding one result dict apiece."""
    shared = {
        "dtype": dtype,
        "device": device,
        "warmup": warmup,
        "iters": iters,
        "group_size_m": group_size_m,
    }
    for item in scenarios:
        yield run_scenario(item, **shared)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,301 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
from types import MethodType
import torch
try:
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
DeepseekV3Config,
)
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
except ImportError as exc: # pragma: no cover - utility script
raise SystemExit(
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
) from exc
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
ACCURACY_TOLERANCE = 5e-3
DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}
def parse_args() -> argparse.Namespace:
    """Parse CLI options for the DeepSeek-V3 MoE microbenchmark."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--batch", type=int, default=8, help="batch size")
    cli.add_argument("--seq-len", type=int, default=2048, help="sequence length")
    cli.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
    cli.add_argument(
        "--moe-intermediate-size", type=int, default=8192,
        help="MoE intermediate projection size",
    )
    cli.add_argument(
        "--n-experts", type=int, default=256,
        help="Number of routed experts",
    )
    cli.add_argument(
        "--top-k", type=int, default=8,
        help="Number of experts per token",
    )
    cli.add_argument(
        "--groups", type=int, default=8,
        help="Router groups (must divide n-experts)",
    )
    cli.add_argument(
        "--dtype", choices=DTYPE_MAP.keys(), default="bf16",
        help="Computation dtype",
    )
    cli.add_argument(
        "--device", default="auto", choices=["auto", "cpu", "cuda"],
        help="Execution device",
    )
    cli.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
    cli.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
    cli.add_argument("--seed", type=int, default=0, help="Random seed")
    cli.add_argument(
        "--uniform-routing", action="store_true",
        help="Override router to distribute tokens evenly across experts",
    )
    cli.add_argument(
        "--group-size", type=int, default=128,
        help="GROUP_SIZE_M used by the Triton kernel",
    )
    cli.add_argument(
        "--backend", choices=["cg", "mg"], default="mg",
        help="MoE kernel backend to benchmark",
    )
    return cli.parse_args()
def resolve_device(requested: str) -> torch.device:
    """Resolve 'auto' to CUDA when available, otherwise honor the request."""
    if requested != "auto":
        return torch.device(requested)
    target = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(target)
def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
    """Construct an eval-mode DeepseekV3MoE block from the CLI arguments."""
    cfg_kwargs = {
        "hidden_size": args.hidden_size,
        "intermediate_size": args.moe_intermediate_size,
        "moe_intermediate_size": args.moe_intermediate_size,
        "n_routed_experts": args.n_experts,
        "num_experts_per_tok": args.top_k,
        # topk_group must stay within [1, groups] regardless of top_k.
        "topk_group": max(1, min(args.groups, args.top_k)),
        "n_group": args.groups,
        "n_shared_experts": 1,
    }
    moe = DeepseekV3MoE(DeepseekV3Config(**cfg_kwargs))
    moe.eval()
    return moe
@torch.no_grad()
def timed_forward(
    module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
) -> float:
    """Return the mean forward latency of ``module(inputs)`` in milliseconds."""
    on_gpu = inputs.is_cuda

    def _sync() -> None:
        # CUDA work is asynchronous; drain the stream before reading the clock.
        if on_gpu:
            torch.cuda.synchronize()

    for _ in range(warmup):
        module(inputs)
    _sync()
    begin = time.perf_counter()
    for _ in range(iters):
        module(inputs)
    _sync()
    total = time.perf_counter() - begin
    return total / iters * 1000.0
def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
    """Compare the stock DeepseekV3MoE forward against the Triton-patched one.

    Builds two modules sharing identical weights: ``baseline_module`` keeps
    the original ``moe`` implementation while ``patched_module`` is created
    after ``patch_deepseek_v3_moe`` is applied.  With ``--uniform-routing``
    both gates are overridden so experts receive (nearly) equal token counts.
    Outputs are compared for accuracy, then each module is timed; on CUDA,
    peak memory is also recorded per module.

    Returns a dict of latencies (ms), speedup, max output difference, routing
    statistics, VRAM peaks and the accuracy verdict.
    """
    torch.manual_seed(args.seed)
    device = resolve_device(args.device)
    dtype = DTYPE_MAP[args.dtype]
    # Validate CLI combinations up front before any allocation.
    if args.n_experts % args.groups != 0:
        raise SystemExit("n-experts must be divisible by groups")
    if args.top_k > args.n_experts:
        raise SystemExit("top-k cannot exceed number of experts")
    if device.type == "cuda" and not torch.cuda.is_available():
        raise SystemExit("CUDA requested but not available")
    baseline_module = build_module(args)
    # If the class was patched earlier in this process, recover the original
    # `moe` implementation the monkeypatch stashed on the class; otherwise
    # the current class attribute is already the original.
    original_moe = getattr(
        DeepseekV3MoE,
        "_axolotl_triton_original_moe",
        DeepseekV3MoE.moe,
    )
    baseline_module.moe = MethodType(original_moe, baseline_module)
    # Copy weights so both modules compute with identical parameters.
    state_dict = baseline_module.state_dict()
    patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
    patched_module = build_module(args)
    patched_module.load_state_dict(state_dict)
    baseline_module.to(device=device, dtype=dtype)
    patched_module.to(device=device, dtype=dtype)
    tokens = args.batch * args.seq_len
    routed_tokens = tokens * args.top_k
    avg_tokens_per_expert = routed_tokens / args.n_experts
    inputs = torch.randn(
        args.batch,
        args.seq_len,
        args.hidden_size,
        device=device,
        dtype=dtype,
    )
    with torch.no_grad():
        flat_inputs = inputs.view(-1, args.hidden_size)
        if args.uniform_routing:
            # Build a perfectly balanced assignment: every expert gets `base`
            # slots (+1 for the first `remainder` experts), then shuffle the
            # slots so expert ids are spread across token positions.
            total_assignments = flat_inputs.size(0) * args.top_k
            base = total_assignments // args.n_experts
            remainder = total_assignments % args.n_experts
            counts = torch.full(
                (args.n_experts,),
                base,
                dtype=torch.int64,
                device=device,
            )
            if remainder:
                counts[:remainder] += 1
            assignments = torch.repeat_interleave(
                torch.arange(args.n_experts, device=device), counts
            )
            assignments = assignments[torch.randperm(assignments.size(0))]
            topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
        else:
            # Use the patched module's own gate to obtain realistic routing.
            topk_idx, _ = patched_module.gate(flat_inputs)
        tokens_per_expert = torch.bincount(
            topk_idx.reshape(-1), minlength=args.n_experts
        )
        min_tokens = int(tokens_per_expert.min().item())
        max_tokens = int(tokens_per_expert.max().item())
    if args.uniform_routing:
        # Equal routing weight for every selected expert.
        weights = torch.full(
            topk_idx.shape,
            1.0 / args.top_k,
            device=device,
            dtype=torch.float32,
        )

        def _uniform_gate(self, hidden_states):
            # Replacement gate forward: fixed balanced assignments with
            # uniform weights, truncated to the incoming token count.
            flat = hidden_states.view(-1, hidden_states.shape[-1])
            token_count = flat.shape[0]
            return topk_idx[:token_count], weights[:token_count]

        # Bind the override to both gate instances so baseline and patched
        # modules see identical routing decisions.
        patched_module.gate.forward = _uniform_gate.__get__(
            patched_module.gate, patched_module.gate.__class__
        )
        baseline_module.gate.forward = _uniform_gate.__get__(
            baseline_module.gate, baseline_module.gate.__class__
        )
    with torch.no_grad():
        ref_output = baseline_module(inputs)
        patched_output = patched_module(inputs)
    max_diff = (ref_output - patched_output).abs().max().item()
    baseline_vram = patched_vram = None
    if device.type == "cuda":
        # Reset stats so each timing run reports its own peak allocation.
        torch.cuda.reset_peak_memory_stats(device)
    baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
    if device.type == "cuda":
        baseline_vram = torch.cuda.max_memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
    patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
    if device.type == "cuda":
        patched_vram = torch.cuda.max_memory_allocated(device)
    speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
    return {
        "device": device,
        "backend": args.backend,
        "dtype": dtype,
        "baseline_ms": baseline_ms,
        "patched_ms": patched_ms,
        "speedup": speedup,
        "max_diff": max_diff,
        "routed_tokens": routed_tokens,
        "avg_tokens": avg_tokens_per_expert,
        "min_tokens": min_tokens,
        "max_tokens": max_tokens,
        "baseline_vram": baseline_vram,
        "patched_vram": patched_vram,
        "accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
    }
def main() -> None:  # pragma: no cover - CLI entrypoint
    """Run the benchmark and print a human-readable report.

    Raises RuntimeError when the max output difference between baseline and
    patched modules exceeds ``ACCURACY_TOLERANCE``.
    """
    args = parse_args()
    result = benchmark_deepseek_v3(args)
    print(
        f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
    )
    print(
        f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
    )
    print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
    # VRAM stats are only collected on CUDA devices.
    if result["baseline_vram"] is not None:
        print(
            f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
        )
    print(
        f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
    )
    print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
    if not result["accuracy_ok"]:
        raise RuntimeError(
            f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
        )
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,275 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Sweep a set of DeepSeek V3 MoE benchmark configurations."""
from __future__ import annotations
import argparse
import csv
import logging
import sys
from pathlib import Path
from types import SimpleNamespace
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
ACCURACY_TOLERANCE,
DTYPE_MAP,
benchmark_deepseek_v3,
)
LOG = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
    """Parse CLI options shared by every sweep configuration."""
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument(
        "--dtype", choices=DTYPE_MAP.keys(), default="bf16",
        help="Computation dtype for all benchmarks",
    )
    cli.add_argument(
        "--device", default="auto", choices=["auto", "cpu", "cuda"],
        help="Execution device",
    )
    cli.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
    cli.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
    cli.add_argument("--seed", type=int, default=0, help="Random seed")
    cli.add_argument(
        "--group-size", type=int,
        help="Override GROUP_SIZE_M for every configuration",
    )
    cli.add_argument(
        "--backends", default="mg",
        help="Comma separated list of backends to benchmark (subset of cg,mg)",
    )
    cli.add_argument(
        "--no-uniform-routing", action="store_true",
        help="Disable uniform routing for every configuration",
    )
    cli.add_argument(
        "--include-mixtral-long", action="store_true",
        help="Add an 8×8192 Mixtral-style run to the sweep",
    )
    cli.add_argument(
        "--output", type=Path,
        help="Optional CSV file to store results",
    )
    return cli.parse_args()
def make_namespace(
    base: dict, args: argparse.Namespace, backend: str
) -> SimpleNamespace:
    """Merge one sweep configuration with the global CLI options.

    The returned namespace mimics the argparse result expected by
    ``benchmark_deepseek_v3``: per-archetype keys from ``base`` plus the
    shared dtype/device/backend/iteration settings.
    """
    merged = {
        **base,
        "dtype": args.dtype,
        "device": args.device,
        "backend": backend,
        "warmup": args.warmup,
        "iters": args.iters,
        "seed": args.seed,
        "uniform_routing": not args.no_uniform_routing,
    }
    # An explicit --group-size overrides any per-archetype default.
    if args.group_size is not None:
        merged["group_size"] = args.group_size
    return SimpleNamespace(**merged)
# Model-family archetypes to sweep.  Each entry is
# (label, MoE hyper-parameters, list of (batch, seq_len) shapes to run).
ARCHETYPES = (
    (
        "mixtral",
        {
            "hidden_size": 4096,
            "moe_intermediate_size": 14336,
            "n_experts": 8,
            "top_k": 2,
            "groups": 1,
            "group_size": 128,
        },
        [(4, 2048), (8, 4096)],
    ),
    (
        "qwen",
        {
            "hidden_size": 6144,
            "moe_intermediate_size": 24576,
            "n_experts": 16,
            "top_k": 4,
            "groups": 8,
            "group_size": 128,
        },
        [(4, 4096), (8, 8192)],
    ),
    (
        "deepseek_v3",
        {
            "hidden_size": 12288,
            "moe_intermediate_size": 49152,
            "n_experts": 128,
            "top_k": 8,
            "groups": 16,
            "group_size": 128,
        },
        [(4, 4096), (8, 8192)],
    ),
)
# Extra long-sequence Mixtral shape, added only with --include-mixtral-long.
MIXTRAL_LONG_SHAPES = [(8, 8192)]
def main() -> None:  # pragma: no cover - utility script
    """Run the full benchmark grid and report one row per (config, backend).

    Expands every archetype into concrete configurations, runs each with
    every requested backend, prints a formatted table, warns on accuracy
    failures, and optionally persists the results to a CSV file.
    """
    args = parse_args()
    # Expand archetypes into concrete benchmark configurations.
    grid = []
    for label, base_cfg, shapes in ARCHETYPES:
        for batch, seq_len in shapes:
            cfg = {
                "label": label,
                "batch": batch,
                "seq_len": seq_len,
                **base_cfg,
            }
            # Skip combinations the benchmark itself would reject.
            if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]:
                continue
            grid.append(cfg)
    if args.include_mixtral_long:
        base_cfg = ARCHETYPES[0][1]
        for batch, seq_len in MIXTRAL_LONG_SHAPES:
            grid.append(
                {
                    "label": "mixtral_long",
                    "batch": batch,
                    "seq_len": seq_len,
                    **base_cfg,
                }
            )
    if not grid:
        raise SystemExit("No valid parameter combinations produced")
    # CSV column names; order must match the row tuples appended below.
    header = (
        "model",
        "batch",
        "seq_len",
        "hidden_size",
        "moe_intermediate",
        "n_experts",
        "top_k",
        "groups",
        "backend",
        "baseline_ms",
        "patched_ms",
        "speedup",
        "baseline_vram_mib",
        "patched_vram_mib",
        "min_tokens",
        "max_tokens",
        "max_diff",
        "accuracy_ok",
    )
    rows = []
    # Parse and de-duplicate the requested backends, preserving order.
    raw_backends = [
        token.strip() for token in args.backends.split(",") if token.strip()
    ]
    if not raw_backends:
        raw_backends = ["mg"]
    valid_backends = []
    for backend in raw_backends:
        if backend not in {"cg", "mg"}:
            raise SystemExit(f"Unsupported backend '{backend}' requested")
        if backend not in valid_backends:
            valid_backends.append(backend)
    uniform_flag = not args.no_uniform_routing
    print(
        f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}"
    )
    print(
        f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
        f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'diff':>10} {'acc':>5}"
    )
    for cfg in grid:
        for backend in valid_backends:
            ns = make_namespace(cfg, args, backend)
            result = benchmark_deepseek_v3(ns)
            # VRAM is None on CPU runs; report NaN in that case.
            baseline_vram_mib = (
                result["baseline_vram"] / (1024**2)
                if result["baseline_vram"] is not None
                else float("nan")
            )
            patched_vram_mib = (
                result["patched_vram"] / (1024**2)
                if result["patched_vram"] is not None
                else float("nan")
            )
            rows.append(
                (
                    cfg["label"],
                    cfg["batch"],
                    cfg["seq_len"],
                    cfg["hidden_size"],
                    cfg["moe_intermediate_size"],
                    cfg["n_experts"],
                    cfg["top_k"],
                    cfg["groups"],
                    backend,
                    result["baseline_ms"],
                    result["patched_ms"],
                    result["speedup"],
                    baseline_vram_mib,
                    patched_vram_mib,
                    result["min_tokens"],
                    result["max_tokens"],
                    result["max_diff"],
                    result["accuracy_ok"],
                )
            )
            status = "OK" if result["accuracy_ok"] else "FAIL"
            print(
                f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
                f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
                f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {result['max_diff']:>10.3e} {status:>5}"
            )
            # Warn (but keep sweeping) when a configuration fails accuracy.
            if not result["accuracy_ok"]:
                LOG.warning(
                    "Accuracy tolerance exceeded for %s backend=%s: diff=%.3e (> %.1e)",
                    cfg["label"],
                    backend,
                    result["max_diff"],
                    ACCURACY_TOLERANCE,
                )
    if args.output:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        with args.output.open("w", newline="") as fp:
            writer = csv.writer(fp)
            writer.writerow(header)
            writer.writerows(rows)
        print(f"Results written to {args.output}")
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,21 @@
"""Mixture-of-Experts kernel implementations."""
from .indices import generate_permute_indices
from .tt_cg_gemm import (
ContiguousGroupedGEMM,
ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
"generate_permute_indices",
"mg_grouped_gemm",
]

View File

@@ -0,0 +1,5 @@
"""Token permutation utilities for grouped MoE kernels."""
from .indices import generate_permute_indices
__all__ = ["generate_permute_indices"]

View File

@@ -0,0 +1,144 @@
"""Vendored token permutation kernels from TorchTitan."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
__all__ = ["generate_permute_indices"]
@triton.jit
def _fill_indices_kernel(
    tokens_per_expert_group_ptr,  # token count per (rank, expert) bucket
    start_index_values_ptr,  # source start index per bucket (exclusive prefix sum)
    write_offsets_ptr,  # destination start per local expert
    output_ptr,  # int32 permuted indices (pre-filled with -1 by the wrapper)
    experts_per_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter per-bucket source index ranges into each expert's output region."""
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)
    # Grid-stride over experts: each program handles every num_programs-th expert.
    for expert_id in range(pid, experts_per_rank, num_programs):
        write_offset = tl.load(write_offsets_ptr + expert_id)
        # Concatenate this expert's token ranges from every rank.
        for r in range(num_ranks):
            idx = r * experts_per_rank + expert_id
            start_index = tl.load(start_index_values_ptr + idx)
            length = tl.load(tokens_per_expert_group_ptr + idx)
            offsets = tl.arange(0, BLOCK_SIZE)
            # Copy the range in BLOCK_SIZE chunks, masking the tail.
            for chunk_start in range(0, length, BLOCK_SIZE):
                chunk_offsets = chunk_start + offsets
                mask = chunk_offsets < length
                values = start_index + chunk_offsets
                dest_indices = write_offset + chunk_offsets
                tl.store(output_ptr + dest_indices, values, mask=mask)
            write_offset += length
def fill_indices_wrapper(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,
):
    """Launch the Triton fill kernel; unwritten slots remain -1 padding."""
    out = torch.full(
        (max_len,),
        -1,
        dtype=torch.int32,
        device=tokens_per_expert_group.device,
    )
    # One program per expert, capped at max_blocks; the kernel itself
    # strides over experts when there are more experts than programs.
    launch_grid = (min(experts_per_rank, max_blocks),)
    _fill_indices_kernel[launch_grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        out,
        experts_per_rank,
        num_ranks,
        BLOCK_SIZE=block_size,
    )
    return out
def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    """Pure-PyTorch reference for the Triton fill kernel.

    For each local expert, concatenates that expert's token index ranges from
    every rank into one contiguous destination region starting at the
    expert's write offset; untouched slots keep the -1 padding value.
    """
    out = torch.full((max_len,), -1, dtype=torch.int32)
    for local_expert in range(experts_per_rank):
        cursor = int(write_offsets[local_expert].item())
        for rank in range(num_ranks):
            bucket = rank * experts_per_rank + local_expert
            src_start = int(start_index_values[bucket].item())
            count = int(tokens_per_expert_group[bucket].item())
            if count > 0:
                # Clip to max_len so oversized buckets never write past the end.
                stop = min(cursor + count, max_len)
                out[cursor:stop] = torch.arange(
                    src_start,
                    src_start + (stop - cursor),
                    dtype=torch.int32,
                )
            cursor += count
    return out
def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    """Build alignment-padded permutation indices for grouped MoE GEMMs.

    Returns ``(permuted_indices, m_sizes, m_offsets)`` where ``m_sizes`` is
    each expert's cross-rank token count rounded up to ``alignment`` (at
    least one full alignment block), ``m_offsets`` is its inclusive prefix
    sum, and unused slots in ``permuted_indices`` are -1.
    """
    # Exclusive prefix sum: first source index of each (rank, expert) bucket.
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )
    # Sum each expert's tokens across ranks, then pad up to the alignment.
    per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
    per_expert = torch.clamp_min(per_expert, alignment)
    m_sizes = ((per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )
    m_offsets = torch.cumsum(m_sizes, 0)
    write_offsets = m_offsets - m_sizes
    fill = fill_indices_cpu if use_cpu else fill_indices_wrapper
    permuted_indices = fill(
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        experts_per_rank,
        num_ranks,
        max_len,
    )
    return permuted_indices, m_sizes, m_offsets.to(torch.int32)

View File

@@ -0,0 +1,17 @@
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""
from .cg_backward import ContiguousGroupedGEMM
from .cg_forward import (
ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
]

View File

@@ -0,0 +1,290 @@
"""Vendored backward pass for Triton contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
from .cg_forward import cg_grouped_gemm_forward
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_backward_dx(
    grad_output_ptr,  # upstream gradient, (M_TOTAL, N) row-major
    b_ptr,  # expert weights, (NUM_EXPERTS, N, K) row-major
    grad_input_ptr,  # output buffer, (M_TOTAL, K) row-major
    indices_ptr,  # per-row expert id; only each group's first row is read
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
    """Compute gradients with respect to inputs: grad_x = grad_out @ W."""
    pid = tl.program_id(0)
    # Linear pid decomposed into an (M-tile, K-tile) pair.
    num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    tile_m = pid // num_k_tiles
    tile_k = pid % num_k_tiles
    m_start = tile_m * BLOCK_SIZE_M
    k_start = tile_k * BLOCK_SIZE_K
    if m_start < M_TOTAL:
        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
        offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
        mask_m = offs_m < M_TOTAL
        mask_k = offs_k < K
        # The expert id is read from the first row of the group; this relies
        # on all GROUP_SIZE_M rows of a group routing to the same expert
        # (the contiguous-grouped layout) — NOTE(review): confirm upstream.
        group_idx = m_start // GROUP_SIZE_M
        expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
        # Accumulate in fp32 for numerical stability.
        grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)
        # Reduce over N in BLOCK_SIZE_N slabs: grad_x += go_tile @ W_tile.
        for n in range(0, N, BLOCK_SIZE_N):
            offs_n = tl.arange(0, BLOCK_SIZE_N) + n
            mask_n = offs_n < N
            mask_go = mask_m[:, None] & mask_n[None, :]
            mask_w = mask_n[:, None] & mask_k[None, :]
            go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
            go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
            w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
            w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
            grad_input += tl.dot(go, w)
        grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
        mask_gi = mask_m[:, None] & mask_k[None, :]
        tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
@triton.jit
def _kernel_cg_backward_dw(
    grad_output_ptr,  # upstream gradient, (M_TOTAL, N) row-major
    inputs_ptr,  # forward inputs, (M_TOTAL, K) row-major
    grad_weights_ptr,  # output buffer, (NUM_EXPERTS, N, K) row-major
    indices_ptr,  # per-row expert id; only each group's first row is read
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Simplified kernel for expert weight gradients: grad_W = go^T @ x.

    NOTE(review): the pid decomposition divides N*K by the block product, so
    it assumes N % BLOCK_SIZE_N == 0 and K % BLOCK_SIZE_K == 0 — the launcher
    uses cdiv for the grid; confirm non-divisible shapes never occur.
    """
    pid = tl.program_id(0)
    # Decompose pid into (expert, tile-within-expert).
    expert_id = pid // ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
    position_id = pid % ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
    if expert_id < NUM_EXPERTS:
        # n_tiles is actually the K-tile count; tile index is (n, k) ordered.
        n_tiles = K // BLOCK_SIZE_K
        tile_n = position_id // n_tiles
        tile_k = position_id % n_tiles
        n_start = tile_n * BLOCK_SIZE_N
        k_start = tile_k * BLOCK_SIZE_K
        if n_start < N and k_start < K:
            offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
            offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
            mask_n = offs_n < N
            mask_k = offs_k < K
            # Accumulate in fp32 across every group routed to this expert.
            grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)
            for group_idx in range(0, M_TOTAL // GROUP_SIZE_M):
                group_start = group_idx * GROUP_SIZE_M
                # Expert ownership is read from the group's first row.
                group_expert = tl.load(indices_ptr + group_start)
                if group_expert == expert_id:
                    # Walk the group's rows in BLOCK_SIZE_M chunks.
                    for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
                        m_start = group_start + m_offset
                        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
                        mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)
                        go_ptrs = (
                            grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
                        )
                        mask_go = mask_m[:, None] & mask_n[None, :]
                        go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
                        in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
                        mask_in = mask_m[:, None] & mask_k[None, :]
                        inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32)
                        # grad_W tile += go^T @ x for this row chunk.
                        go_t = tl.trans(go)
                        grad_weights += tl.dot(go_t, inp)
            grad_w_ptrs = (
                grad_weights_ptr
                + expert_id * N * K
                + offs_n[:, None] * K
                + offs_k[None, :]
            )
            mask_gw = mask_n[:, None] & mask_k[None, :]
            tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
def cg_grouped_gemm_backward_weights(
    grad_output: torch.Tensor,
    inputs: torch.Tensor,
    expert_indices: torch.Tensor,
    num_experts: int,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Backward pass for expert weights.

    Launches one Triton program per (expert, N-tile, K-tile) and returns the
    per-expert weight gradients as a tensor of shape (num_experts, N, K).
    """
    assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
    assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, N = grad_output.shape
    _m, K = inputs.shape

    # The kernel reads routing indices as int32.
    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    # Zero-init: experts that received no rows keep a zero gradient.
    grad_weights = torch.zeros(
        (num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
    )

    block_size_n = min(128, N)
    block_size_k = min(32, K)
    block_size_m = min(32, group_size_m)
    tiles_n = triton.cdiv(N, block_size_n)
    tiles_k = triton.cdiv(K, block_size_k)

    _kernel_cg_backward_dw[(num_experts * tiles_n * tiles_k,)](
        grad_output,
        inputs,
        grad_weights,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
        BLOCK_SIZE_M=block_size_m,
    )
    return grad_weights
def cg_grouped_gemm_backward_inputs(
    grad_output: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Backward pass for inputs.

    Computes grad_X = dY @ W[e] per contiguous row group and returns a tensor
    of shape (M_total, K).
    """
    assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, N = grad_output.shape
    num_experts, _n, K = expert_weights.shape
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )

    grad_inputs = torch.zeros(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    def grid(meta):
        # One program per (M-tile, K-tile); block sizes come from autotuning.
        return (
            triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
            * triton.cdiv(K, meta["BLOCK_SIZE_K"]),
        )

    _kernel_cg_backward_dx[grid](
        grad_output,
        expert_weights,
        grad_inputs,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
    )
    return grad_inputs
class ContiguousGroupedGEMM(torch.autograd.Function):
    """Autograd function with full backward support."""

    @staticmethod
    def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
        # Stash everything backward needs; group_size_m is a plain int.
        ctx.save_for_backward(inputs, expert_weights, expert_indices)
        ctx.group_size_m = group_size_m
        return cg_grouped_gemm_forward(
            inputs=inputs,
            expert_weights=expert_weights,
            expert_indices=expert_indices,
            group_size_m=group_size_m,
        )

    @staticmethod
    def backward(ctx, grad_output):
        inputs, expert_weights, expert_indices = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        grad_inputs = cg_grouped_gemm_backward_inputs(
            grad_output=grad_output,
            expert_weights=expert_weights,
            expert_indices=expert_indices,
            group_size_m=ctx.group_size_m,
        )
        grad_weights = cg_grouped_gemm_backward_weights(
            grad_output=grad_output,
            inputs=inputs,
            expert_indices=expert_indices,
            num_experts=expert_weights.shape[0],
            group_size_m=ctx.group_size_m,
        )
        # No gradients flow to the integer routing indices or the group size.
        return grad_inputs, grad_weights, None, None

View File

@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Vendored forward Triton contiguous grouped GEMM kernels."""
import torch
import triton
import triton.language as tl
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
    # Map a flat tile id to (pid_m, pid_n) using grouped ("super-group")
    # ordering: consecutive programs cover nearby M tiles so they reuse the
    # same rows of A, improving cache locality.
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * super_group_m
    # The last super-group along M may be ragged.
    group_size_m = min(num_pid_m - first_pid_m, super_group_m)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n
@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_persistent_forward(
    a_ptr,  # [M_TOTAL, K] inputs
    b_ptr,  # [NUM_EXPERTS, N, K] expert weights
    c_ptr,  # [M_TOTAL, N] output
    indices_ptr,  # [M_TOTAL] expert id per row
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
    SUPER_GROUP_M: tl.constexpr = 32,
):
    """
    Contiguous Grouped GEMM kernel forward (persistent variant).

    Launched with one program per SM; each program strides over output tiles
    with step NUM_SMS. Rows come in contiguous runs of GROUP_SIZE_M that all
    route to one expert, so the expert id is read from the group's first row.
    """
    c_type = c_ptr.dtype.element_ty
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
    # tile_id_c trails the loop index by one stride and is advanced just
    # before the epilogue, so the store addresses the tile whose accumulator
    # was just computed.
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = SUPER_GROUP_M * num_pid_n
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
        tile_m_idx, tile_n_idx = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
        )
        m_start = tile_m_idx * BLOCK_SIZE_M
        n_start = tile_n_idx * BLOCK_SIZE_N
        if m_start < M_TOTAL:
            offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
            offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
            # fp32 accumulator regardless of the I/O dtype.
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for ki in range(k_tiles):
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
                mask_m = offs_m < M_TOTAL
                mask_n = offs_n < N
                mask_k = offs_k < K
                mask_a = mask_m[:, None] & mask_k[None, :]
                mask_b = mask_n[:, None] & mask_k[None, :]
                # NOTE(review): group_idx / expert_idx are invariant in this
                # K loop and could be hoisted above it.
                group_idx = m_start // GROUP_SIZE_M
                expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
                a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
                a = tl.load(a_ptrs, mask=mask_a, other=0.0)
                b_ptrs = (
                    b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
                )
                b = tl.load(b_ptrs, mask=mask_b, other=0.0)
                # B tile is loaded [N, K]-major, hence A @ B.T.
                accumulator += tl.dot(a, b.T)
            tile_id_c += NUM_SMS
            tile_m_idx, tile_n_idx = _compute_pid(
                tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
            )
            offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            mask_m = offs_m < M_TOTAL
            mask_n = offs_n < N
            mask_c = mask_m[:, None] & mask_n[None, :]
            c = accumulator.to(tl.float32)
            c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
            tl.store(c_ptrs, c.to(c_type), mask=mask_c)
@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_forward_aligned(
    a_ptr,  # [M_TOTAL, K] inputs
    b_ptr,  # [NUM_EXPERTS, N, K] expert weights
    c_ptr,  # [M_TOTAL, N] output
    indices_ptr,  # [M_TOTAL] expert id per row
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
    """
    Contiguous Grouped GEMM kernel forward for aligned inputs.

    Non-persistent variant: one program per (M-tile, N-tile). The expert id
    is read once from the first row of the tile's GROUP_SIZE_M-aligned group,
    which assumes M tiles never straddle a group boundary ("aligned").
    """
    pid = tl.program_id(0)
    c_type = c_ptr.dtype.element_ty
    num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    # Row-major decomposition of the flat program id.
    tile_m = pid // num_n_tiles
    tile_n = pid % num_n_tiles
    m_start = tile_m * BLOCK_SIZE_M
    n_start = tile_n * BLOCK_SIZE_N
    if m_start < M_TOTAL:
        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
        offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
        mask_m = offs_m < M_TOTAL
        mask_n = offs_n < N
        group_idx = m_start // GROUP_SIZE_M
        expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
        # fp32 accumulator regardless of the I/O dtype.
        acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)
        for k in range(0, K, BLOCK_SIZE_K):
            offs_k = tl.arange(0, BLOCK_SIZE_K) + k
            mask_k = offs_k < K
            mask_a = mask_m[:, None] & mask_k[None, :]
            mask_b = mask_n[:, None] & mask_k[None, :]
            a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)
            b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)
            # B tile is loaded [N, K]-major, hence A @ B.T.
            acc += tl.dot(a, b.T)
        c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
        mask_c = mask_m[:, None] & mask_n[None, :]
        tl.store(c_ptrs, acc.to(c_type), mask=mask_c)
def cg_grouped_gemm_forward(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Contiguous grouped GEMM forward pass for MoE.

    Launches the persistent kernel with one program per SM. NOTE(review): the
    output is allocated as bfloat16 here, while the dynamic variant follows
    the input dtype — confirm this asymmetry is intended.
    """
    assert inputs.is_contiguous(), "Input tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, K = inputs.shape
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )
    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    num_experts, N, K_weights = expert_weights.shape
    assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)
    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    _kernel_cg_persistent_forward[(num_sms, 1, 1)](
        inputs,
        expert_weights,
        output,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
        NUM_SMS=num_sms,
    )
    return output
def cg_grouped_gemm_forward_dynamic(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Contiguous grouped GEMM forward pass for MoE with autotuned launch.

    Unlike the persistent variant, this launches one program per output tile
    and lets autotuning pick block sizes; the output keeps the input dtype.
    """
    assert inputs.is_contiguous(), "Input tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, K = inputs.shape
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )
    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    num_experts, N, K_weights = expert_weights.shape
    assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)

    def grid(meta):
        # One program per (M-tile, N-tile); sizes chosen by the autotuner.
        return (
            triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
            * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    _kernel_cg_forward_aligned[grid](
        inputs,
        expert_weights,
        output,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
    )
    return output
class ContiguousGroupedGEMM(torch.autograd.Function):
    """Autograd function for contiguous grouped GEMM forward pass only."""

    @staticmethod
    def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
        # Inference-only op: nothing is saved on ctx for backward.
        return cg_grouped_gemm_forward(
            inputs, expert_weights, expert_indices, group_size_m
        )

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover - not implemented
        raise NotImplementedError("Backward pass not implemented")
def cg_grouped_gemm(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Convenience wrapper for the forward-only autograd function."""
    # autograd.Function.apply only takes positional args, hence the cast here.
    indices = (
        expert_indices
        if expert_indices.dtype == torch.int32
        else expert_indices.to(torch.int32)
    )
    return ContiguousGroupedGEMM.apply(inputs, expert_weights, indices, group_size_m)

View File

@@ -0,0 +1,31 @@
"""Reference implementation for contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
def pytorch_reference(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = 128,
) -> torch.Tensor:
    """Simple PyTorch implementation for verification.

    Processes rows in contiguous chunks of ``group_size_m``; each chunk is
    multiplied by the expert selected by the chunk's first routing index.
    """
    M_total = inputs.shape[0]
    N = expert_weights.shape[1]
    output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
    for start in range(0, M_total, group_size_m):
        stop = min(start + group_size_m, M_total)
        expert = int(expert_indices[start].item())
        output[start:stop] = inputs[start:stop] @ expert_weights[expert].T
    return output

View File

@@ -0,0 +1,209 @@
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
import torch
import triton
import triton.language as tl
from triton.runtime import driver
class CudaUtils:
    """Helper utilities for CUDA specific Triton features."""

    @staticmethod
    def is_cuda() -> bool:
        # True when Triton's active driver targets the CUDA backend.
        return driver.active.get_current_target().backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        # TMA requires the CUDA backend plus Hopper (SM 9.x) or newer.
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability()
        return major >= 9

    @staticmethod
    def get_num_sms() -> int:
        # Streaming multiprocessor count of the current CUDA device.
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        return torch.cuda.get_device_properties("cuda").multi_processor_count
class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels."""

    class KernelParamWrapper:
        # Adapter exposing the CPU pointer Triton expects for a TMA descriptor.
        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if "nv_tma_desc_type" not in dir(tl):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )
        self.tma_size = tma_size
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        self.descriptors: Dict[str, torch.Tensor] = {}

    def _checked_descriptor(self, name: str) -> torch.Tensor:
        # Fetch an initialized descriptor buffer, validating its alignment.
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        desc = self.descriptors[name]
        if desc.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        return desc

    def init_tma_descriptor(self, name: str) -> None:
        # Descriptors are raw byte buffers living in CPU memory.
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        desc = self._checked_descriptor(name)
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, desc.data_ptr()
        )

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        desc = self._checked_descriptor(name)
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc.data_ptr()
        )

    def get_tma_descriptor_kernel_param(
        self, name: str
    ) -> "TmaDescriptorHelper.KernelParamWrapper":
        if self.descriptors.get(name) is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(self.descriptors[name])
# Autotuning search space for Hopper-class targets: larger tiles and deeper
# software pipelining (num_stages) than the standard configs below.
HOPPER_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
        num_stages=2,
        num_warps=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
        num_stages=3,
        num_warps=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
        num_stages=3,
        num_warps=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
        num_stages=4,
        num_warps=4,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
        num_stages=4,
        num_warps=4,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
        num_stages=3,
        num_warps=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
        num_stages=3,
        num_warps=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
        num_stages=3,
        num_warps=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
        num_stages=4,
        num_warps=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
        num_stages=4,
        num_warps=8,
    ),
]
# Conservative autotuning search space (smaller tiles / fewer stages),
# usable on non-Hopper GPUs; pruned further by early_config_prune below.
STANDARD_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
        num_stages=2,
        num_warps=4,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
        num_stages=2,
        num_warps=4,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
        num_stages=2,
        num_warps=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
        num_stages=3,
        num_warps=8,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
        num_stages=4,
        num_warps=8,
    ),
]
def early_config_prune(configs, args, **kwargs):
    """Filter out configurations that would exceed shared memory capacity.

    Drops configs whose ``BLOCK_SIZE_K`` exceeds the problem's K dimension.
    If nothing survives, the single config with the smallest ``BLOCK_SIZE_K``
    is kept so autotuning always has a candidate.
    """
    k = kwargs.get("K", 0)
    surviving = [cfg for cfg in configs if cfg.kwargs.get("BLOCK_SIZE_K", 0) <= k]
    if surviving or not configs:
        return surviving
    smallest = min(configs, key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")))
    return [smallest]

View File

@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M
__all__ = [
"grouped_gemm_forward",
"ALIGN_SIZE_M",
]

View File

@@ -0,0 +1,761 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - flat index forward kernel is derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import logging
from typing import Tuple
import torch
import triton
import triton.language as tl
from .tma_autotuning import (
_NV_CONFIGS,
CudaUtils,
early_config_prune,
)
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# Process-wide guard so triton.set_allocator is installed at most once.
_allocator_registered = False
def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
    # Scratch allocator backed by PyTorch's caching CUDA allocator.
    # NOTE(review): `alignment` and `stream` are ignored — torch.empty
    # allocations are assumed sufficiently aligned; confirm against the
    # alignment Triton requests.
    return torch.empty(size, device="cuda", dtype=torch.int8)
def _ensure_triton_allocator() -> None:
    # Idempotently register the torch-backed allocator with Triton.
    global _allocator_registered
    if not _allocator_registered:
        triton.set_allocator(_torch_allocator)
        _allocator_registered = True
# ============== Start Triton Kernels ===============
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
    a_ptr,  # [M_TOTAL, K] stacked group inputs
    b_ptr,  # [N, K] combined weights; group g owns rows [g*n_size, (g+1)*n_size)
    c_ptr,  # [M_TOTAL, N // G] output
    m_sizes,  # [G] rows per group; groups are contiguous along M
    M_TOTAL,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # power-of-2 bucket of M_TOTAL; autotune cache key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_EPILOGUE_SUBTILING: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """Flat index style forward kernel for Hopper using tensor descriptors.

    Persistent launch: NUM_SMS programs stride through a flat tile index that
    spans every group's (M-tile, N-tile) grid in sequence; each program
    advances tbidx by NUM_SMS after finishing a tile.
    """
    tbidx = tl.program_id(0)
    c_dtype = c_ptr.dtype.element_ty
    # Per-group output width: the N axis of b is split evenly across G groups.
    n_size = N // G
    a_desc = tl.make_tensor_descriptor(
        a_ptr,
        shape=[M_TOTAL, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    b_desc = tl.make_tensor_descriptor(
        b_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )
    # Running row offset of the current group within the stacked M dimension.
    M_end = tl.full([], 0, dtype=tl.int32)
    processed_tiles = 0
    for g in range(G):
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size
        if m_size > 0:
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles
            # Claim every tile of this group assigned to this program.
            while (
                tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles
                # M-major enumeration of tiles inside the group.
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles
                rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
                rows_remaining = tl.maximum(rows_remaining, 0)
                row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
                cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
                col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
                # fp32 accumulator regardless of the I/O dtype.
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
                m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
                # Offset into b's combined N axis for this group's weight slice.
                global_n_offset = (g * n_size + n_offset).to(tl.int32)
                for k_offset in range(0, K, BLOCK_SIZE_K):
                    k_remaining = K - k_offset
                    k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
                    a = a_desc.load([m_offset, k_offset])
                    # Descriptor loads are unmasked; zero out-of-range lanes
                    # so they do not pollute the dot product.
                    a_mask = row_mask[:, None] & k_mask[None, :]
                    a = tl.where(a_mask, a, tl.zeros_like(a))
                    b = b_desc.load([global_n_offset, k_offset])
                    b_mask = col_mask[:, None] & k_mask[None, :]
                    b = tl.where(b_mask, b, tl.zeros_like(b))
                    accumulator += tl.dot(a, b.T)
                local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
                row_store_mask = local_row_offsets < m_size
                global_row = (M_start + local_row_offsets).to(tl.int32)
                local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
                    0, BLOCK_SIZE_N
                )
                col_store_mask = local_col_offsets < n_size
                store_mask = row_store_mask[:, None] & col_store_mask[None, :]
                if USE_EPILOGUE_SUBTILING:
                    # Split the accumulator into two half-width stores to
                    # reduce epilogue register pressure.
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)
                    col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
                    col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
                    ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
                    tl.store(
                        ptr0,
                        acc0.to(c_dtype),
                        mask=row_store_mask[:, None] & col_mask0[None, :],
                    )
                    col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
                    col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
                    ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
                    tl.store(
                        ptr1,
                        acc1.to(c_dtype),
                        mask=row_store_mask[:, None] & col_mask1[None, :],
                    )
                else:
                    ptr = (
                        c_ptr
                        + global_row[:, None] * n_size
                        + local_col_offsets[None, :]
                    )
                    tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
                # Advance to this program's next tile in the flat index space.
                tbidx += NUM_SMS
            processed_tiles += group_num_tiles
"""
Backward pass for grouped GEMM with Triton, where grouping is M*G
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
"""
# ---- dx flat linear indexed ----
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
    grad_output_ptr,  # [M_TOTAL, N] upstream gradient
    w_ptr,  # [N, K] weights
    grad_input_ptr,  # [M_TOTAL, K] output buffer
    m_sizes,  # [G] rows per group; groups are contiguous along M
    M_TOTAL,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune cache key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """Compute grad_input = grad_output @ w using tensor descriptors.

    Persistent launch mirroring the forward kernel: a flat tile index spans
    every group's (M-tile, K-tile) grid; each program strides by NUM_SMS.
    """
    tbidx = tl.program_id(0)
    c_dtype = grad_input_ptr.dtype.element_ty
    grad_output_desc = tl.make_tensor_descriptor(
        grad_output_ptr,
        shape=[M_TOTAL, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )
    w_desc = tl.make_tensor_descriptor(
        w_ptr,
        shape=[N, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
    )
    # Running row offset of the current group within the stacked M dimension.
    M_end = tl.full([], 0, dtype=tl.int32)
    processed_tiles = 0
    for g in range(G):
        M_start = M_end
        m_size = tl.load(m_sizes + g)
        M_end = M_start + m_size
        if m_size > 0:
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            group_num_tiles = num_m_tiles * num_k_tiles
            # Claim every tile of this group assigned to this program.
            while (
                tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles
                tile_m_index = group_index % num_m_tiles
                tile_k_index = group_index // num_m_tiles
                rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
                rows_remaining = tl.maximum(rows_remaining, 0)
                row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
                k_offset = tile_k_index * BLOCK_SIZE_K
                k_remaining_total = K - k_offset
                k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
                # fp32 accumulator regardless of the I/O dtype.
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
                m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # Reduce over N: dX tile += dY[m, n] @ W[n, k].
                for n_offset in range(0, N, BLOCK_SIZE_N):
                    n_remaining = N - n_offset
                    n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
                    grad_y = grad_output_desc.load([m_offset, n_offset])
                    # Descriptor loads are unmasked; zero out-of-range lanes.
                    grad_y_mask = row_mask[:, None] & n_mask[None, :]
                    grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
                    w_tile = w_desc.load([n_offset, k_offset])
                    w_mask = n_mask[:, None] & k_mask[None, :]
                    w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
                    accumulator += tl.dot(grad_y, w_tile)
                local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
                    0, BLOCK_SIZE_M
                )
                row_store_mask = local_row_offsets < m_size
                global_row = (M_start + local_row_offsets).to(tl.int32)
                col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
                col_store_mask = col_offsets < K
                store_mask = row_store_mask[:, None] & col_store_mask[None, :]
                ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
                tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
                # Advance to this program's next tile in the flat index space.
                tbidx += NUM_SMS
            processed_tiles += group_num_tiles
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
    x_ptr,  # [M_TOTAL, K] forward-pass inputs
    grad_output_ptr,  # [M_TOTAL, N] upstream gradient
    grad_weight_ptr,  # [N, K] output buffer
    m_sizes,  # [G] rows per group; groups are contiguous along M
    M_TOTAL,
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,  # autotune cache key only
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    # tiles
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
) -> None:
    """Compute grad_weight = grad_output.T @ x using tensor descriptors.

    One (N-tile, K-tile) of the combined [N, K] gradient per loop step; the
    reduction walks every group's rows in order, so the whole M axis
    contributes to each tile. NOTE(review): unlike the forward kernel, no
    per-group N slicing happens here — confirm N here is the full combined
    width expected by the wrapper.
    """
    tbidx = tl.program_id(0)
    c_dtype = grad_weight_ptr.dtype.element_ty
    x_desc = tl.make_tensor_descriptor(
        x_ptr,
        shape=[M_TOTAL, K],
        strides=[K, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
    )
    grad_output_desc = tl.make_tensor_descriptor(
        grad_output_ptr,
        shape=[M_TOTAL, N],
        strides=[N, 1],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
    )
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    total_tiles = num_n_tiles * num_k_tiles
    # Persistent stride over output tiles.
    for tile_idx in range(tbidx, total_tiles, NUM_SMS):
        tile_n_idx = tile_idx % num_n_tiles
        tile_k_idx = tile_idx // num_n_tiles
        n_offset = tile_n_idx * BLOCK_SIZE_N
        n_remaining = N - n_offset
        n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
        k_offset = tile_k_idx * BLOCK_SIZE_K
        k_remaining = K - k_offset
        k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
        # fp32 accumulator regardless of the I/O dtype.
        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
        M_end = tl.full([], 0, dtype=tl.int32)
        # Reduce over all rows, group by group.
        for g in range(G):
            M_start = M_end
            m_size = tl.load(m_sizes + g)
            M_end = M_start + m_size
            if m_size > 0:
                for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
                    rows_remaining = m_size - m_offset_local
                    rows_remaining = tl.maximum(rows_remaining, 0)
                    row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
                    m_offset = (M_start + m_offset_local).to(tl.int32)
                    x_block = x_desc.load([m_offset, k_offset])
                    # Descriptor loads are unmasked; zero out-of-range lanes.
                    x_mask = row_mask[:, None] & k_mask[None, :]
                    x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
                    grad_block = grad_output_desc.load([m_offset, n_offset])
                    grad_mask = row_mask[:, None] & n_mask[None, :]
                    grad_block = tl.where(
                        grad_mask, grad_block, tl.zeros_like(grad_block)
                    )
                    # dW tile += dY.T @ X in fp32.
                    contribution = tl.dot(
                        grad_block.to(tl.float32).T,
                        x_block.to(tl.float32),
                    )
                    accumulator += contribution
        row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
        row_store_mask = row_offsets < N
        col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
        col_store_mask = col_offsets < K
        store_mask = row_store_mask[:, None] & col_store_mask[None, :]
        ptr = grad_weight_ptr + row_offsets[:, None] * K + col_offsets[None, :]
        tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
# ======== End Triton kernels ========
# ======== Triton wrapper functions ========
# ----- main forward pass wrapper -----
def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """Grouped GEMM forward using Hopper TMA kernels.

    Args:
        x: Stacked inputs of shape [M_total, K], groups concatenated along M.
        w: Combined weight matrix of shape [N, K]; each group owns N // G rows.
        m_sizes: Per-group row counts, shape [G].
        tma_size: Unused here; kept for interface compatibility.
        using_fp8: FP8 path; not implemented.

    Returns:
        Output tensor of shape [M_total, N // G] in x's dtype.
    """
    _ensure_triton_allocator()
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
    if using_fp8:
        raise NotImplementedError(
            "FP8 path not implemented with the new Triton API yet"
        )

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    G = m_sizes.shape[0]
    M_total, K = x.shape
    N = w.shape[0]
    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
    if M_total == 0:
        # Nothing to compute; return the empty output directly.
        return y

    num_sms = CudaUtils.get_num_sms()
    _kernel_mg_forward_hopper[(num_sms,)](
        x,
        w,
        y,
        m_sizes,
        M_total,
        G,
        triton.next_power_of_2(M_total),
        N,
        K,
        num_sms,
        USE_EPILOGUE_SUBTILING=False,
    )
    return y
# ======== Improved Backward =============
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GeMM with M*G grouping.

    Uses optimized TMA-based implementations for both dx and dw when available.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)
    """
    logging.info("Starting unified grouped_gemm_backward")
    # Query once up front; the device-property lookup is relatively expensive.
    num_sms = CudaUtils.get_num_sms()

    # Shape validation before any kernel launch.
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape
    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # The kernels require contiguous layouts.
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    if use_tma and not CudaUtils.verify_tma():
        logging.info("TMA requested but not supported on this device")
        use_tma = False

    try:
        logging.info("Computing grad_x with flat linear kernel")
        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=num_sms,
            tma_size=tma_size,
        )
    except Exception as e:
        logging.error(f"Error in grad_x computation: {e}")
        raise

    try:
        logging.info("Computing grad_w with flat linear kernel")
        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=num_sms, tma_size=tma_size
        )
    except Exception as e:
        logging.error(f"Error in grad_w computation: {e}")
        raise

    return grad_x, grad_w
# ----- dx backward pass wrapper -----
def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """Compute grad_x using the Hopper grouped GEMM kernel.

    Args:
        grad_output: Upstream gradient, shape [M_total, N].
        w: Weight tensor, shape [N, K].
        m_sizes: Per-group row counts, shape [G].
        num_sms: Number of persistent programs to launch.
        tma_size: Unused here; kept for interface compatibility.

    Returns:
        grad_x of shape [M_total, K] in grad_output's dtype.
    """
    _ensure_triton_allocator()
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    grad_output = grad_output.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    M_total, N = grad_output.shape
    N_w, K = w.shape
    if N != N_w:
        raise ValueError(f"Grad_output N ({N}) must match weight N ({N_w})")
    if m_sizes.sum().item() != M_total:
        raise ValueError("Sum of m_sizes must equal the number of rows in grad_output")

    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )
    _kernel_mg_dx_tma[(num_sms,)](
        grad_output,
        w,
        grad_x,
        m_sizes,
        M_total,
        m_sizes.shape[0],
        triton.next_power_of_2(M_total),
        N,
        K,
        num_sms,
    )
    return grad_x
# ======== dw wrapper function ==========
def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """Compute grad_w for a grouped GEMM via the Hopper TMA kernel.

    Launches ``_kernel_mg_dw_tma`` as a persistent kernel on a fixed 1D grid
    of ``num_sms`` programs.

    Args:
        x: Forward input, shape [M_total, K].
        grad_output: Upstream gradient, shape [M_total, N].
        m_sizes: Per-group row counts, shape [G]; must sum to M_total.
        num_sms: Number of persistent programs to launch.
        tma_size: TMA descriptor size in bytes (kept for API compatibility).

    Returns:
        grad_w of shape [N, K] with x's device and dtype.

    Raises:
        RuntimeError: If the device lacks TMA support.
        ValueError: If shapes of the inputs are inconsistent.
    """
    _ensure_triton_allocator()
    if not CudaUtils.verify_tma():
        raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
    # The kernel requires densely packed operands.
    x, grad_output, m_sizes = (
        x.contiguous(),
        grad_output.contiguous(),
        m_sizes.contiguous(),
    )
    total_rows, k_dim = x.shape
    grad_rows, n_dim = grad_output.shape
    if total_rows != grad_rows:
        raise ValueError("x and grad_output must have matching batch dimension")
    if m_sizes.sum().item() != total_rows:
        raise ValueError("Sum of m_sizes must equal the number of rows in the inputs")
    # grad_w is accumulated across groups, so it must start zeroed.
    grad_w = torch.zeros((n_dim, k_dim), device=x.device, dtype=x.dtype)

    # Persistent launch: one program per SM regardless of problem size.
    def grid(_meta):
        return (num_sms,)

    m_bucket = triton.next_power_of_2(total_rows)
    _kernel_mg_dw_tma[grid](
        x,
        grad_output,
        grad_w,
        m_sizes,
        total_rows,
        m_sizes.shape[0],
        m_bucket,
        n_dim,
        k_dim,
        num_sms,
    )
    return grad_w
# ======== End Backwards Wrapper Functions =============
# ======== PyTorch wrapper functions ========
class GroupedGemmMg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.

    Supports both standard and FP8 quantized operations.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
        """
        Forward pass of GroupedGEMM.

        Args:
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available)
            tma_size: Size of TMA descriptor in bytes
            using_fp8: Whether to use FP8 quantization (currently ignored;
                the forward always runs without quantization)

        Returns:
            Output tensor, shape [M_total, N]
        """
        # Use regular forward without quantization.
        # NOTE(review): `using_fp8` is deliberately forced to False here; the
        # FP8 path is not wired through this entry point yet.
        output = grouped_gemm_forward(
            x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
        )
        # Save inputs and parameters for backward pass.
        # Fix: the original called save_for_backward twice with identical
        # arguments; a single call is sufficient (later calls just replace
        # the earlier saved set).
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            Tuple of gradients:
            - grad_x: Gradient with respect to x, shape [M_total, K]
            - grad_w: Gradient with respect to w, shape [N, K]
            - None: Gradient with respect to m_sizes (not differentiable)
            - None: Gradient with respect to use_tma (not differentiable)
            - None: Gradient with respect to tma_size (not differentiable)
            - None: Gradient with respect to using_fp8 (not differentiable)
        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors
        use_tma = ctx.use_tma
        tma_size = ctx.tma_size
        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=use_tma,
            tma_size=tma_size,
        )
        # Return gradients for all inputs (None for non-differentiable parameters)
        return grad_x, grad_w, None, None, None, None
def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """Differentiable grouped GEMM over M*G-grouped inputs.

    Thin functional wrapper around :class:`GroupedGemmMg` so callers get
    autograd support without dealing with the Function class directly.
    Supports both standard precision and FP8 quantized operations.

    Args:
        x: Input tensor of shape [M_total, K].
        w: Weight tensor of shape [N, K].
        m_sizes: Tensor of shape [G] holding the row count of each group.
        use_tma: Attempt TMA acceleration when the hardware supports it.
        tma_size: TMA descriptor size in bytes.
        using_fp8: Request FP8 quantized execution.

    Returns:
        Output tensor of shape [M_total, N].
    """
    return GroupedGemmMg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)

View File

@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import os
import sys
from typing import Dict
import torch
import triton
from triton.runtime import driver # @manual
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# ===== Supporting utils, CUDA and TMA =====
class CudaUtils:
    """Static helpers for querying the CUDA/Triton execution environment."""

    @staticmethod
    def is_cuda() -> bool:
        """Check if Triton is running on CUDA backend."""
        return driver.active.get_current_target().backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Check if TMA is supported on the current device (Hopper, SM90+)."""
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        return torch.cuda.get_device_capability()[0] >= 9

    @staticmethod
    def get_num_sms() -> int:
        """Get the number of streaming multiprocessors on the current device.

        Raises:
            RuntimeError: If Triton is not on CUDA or CUDA is unavailable.
        """
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        return torch.cuda.get_device_properties("cuda").multi_processor_count
class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels.

    Owns a set of named host-side descriptor buffers that are filled via the
    active Triton driver and handed to kernels as parameters.

    Args:
        tma_size: Size of the TMA descriptor in bytes
    """

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            # Host-side int8 buffer holding the raw descriptor bytes.
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        self.tma_size = tma_size
        # NOTE(review): both the 1D and 2D fillers are bound to the same
        # driver entry point (`fill_tma_descriptor`); presumably the newer
        # driver API accepts either argument layout -- confirm against the
        # installed Triton version.
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
        # Maps descriptor name -> CPU int8 tensor of `tma_size` bytes.
        self.descriptors: Dict[str, torch.Tensor] = {}

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.

        Raises:
            ValueError: If the descriptor was not initialized or its buffer
                is not 64-byte aligned.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        desc_x = self.descriptors[name]
        # Descriptor buffers must sit on a 64-byte boundary.
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, desc_x.data_ptr()
        )

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.

        Raises:
            ValueError: If the descriptor was not initialized or its buffer
                is not 64-byte aligned.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        desc_x = self.descriptors[name]
        # Descriptor buffers must sit on a 64-byte boundary.
        if desc_x.data_ptr() % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
        )

    def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
        """Get the TMA descriptor kernel parameter for the given name."""
        if name not in self.descriptors or self.descriptors[name] is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(self.descriptors[name])
# ====== Autotuning utilities ======
ALIGN_SIZE_M = 128  # single BLOCK_SIZE_M used for every config below
# Autotuning search space for NVIDIA kernels: one BLOCK_SIZE_M, a sweep of
# N/K tile sizes, and a small range of pipeline stages / warp counts.
_NV_CONFIGS = [
    triton.Config(
        {
            "BLOCK_SIZE_M": block_size_m,
            "BLOCK_SIZE_N": block_size_n,
            "BLOCK_SIZE_K": block_size_k,
        },
        num_stages=num_stages,
        num_warps=num_warps,
        num_ctas=num_ctas,
    )
    for block_size_m in [
        ALIGN_SIZE_M,
    ]
    for block_size_n in [64, 128, 256]
    for block_size_k in [64, 128, 256]
    for num_stages in [3, 4]
    for num_warps in [4, 8]
    for num_ctas in [1]
]
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Prune autotuning configs that cannot work for the given kernel args.

    Drops configs that exceed shared memory, whose M/N tile sizes are
    mismatched to the per-group problem size, or whose K tile does not divide
    K evenly.

    Args:
        configs: Candidate ``triton.Config`` objects.
        named_args: Kernel arguments by name; must contain one of the known
            output pointers plus ``G``, ``M_BUCKET``, ``N`` and ``K``.
        dtsize: Element size in bytes; inferred from the output pointer if None.
        dtype: Element dtype; inferred from the output pointer if None.

    Returns:
        The filtered list of configs.

    Raises:
        KeyError: If no recognized output pointer is present in named_args.
    """
    device = torch.cuda.current_device()
    # Check for all possible pointer parameter names
    if "grad_input_ptr" in named_args:
        ptr_name = "grad_input_ptr"
    elif "c_ptr" in named_args:
        ptr_name = "c_ptr"
    elif "grad_weight_ptr" in named_args:
        ptr_name = "grad_weight_ptr"
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")
    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()
    if dtype is None:
        dtype = named_args[ptr_name].dtype

    # Fix: device properties and problem dims are loop-invariant -- the
    # original queried the driver twice per candidate config; hoist them.
    device_props = driver.active.utils.get_device_properties(device)
    max_shared_memory = device_props["max_shared_mem"]
    num_sm = device_props["multiprocessor_count"]
    G, M, N, K = (
        named_args["G"],
        named_args["M_BUCKET"],
        named_args["N"],
        named_args["K"],
    )
    M_PER_GROUP = M // G
    MIN_M_TILES = 64
    MIN_N_TILES = 64

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )
        # 1. make sure we have enough smem for the pipelined A/B tiles
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue
        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        # (original comment mislabeled this as an N-tile check)
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue
        N_TILES = N // BLOCK_N
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue
        pruned_configs.append(config)
    return pruned_configs
# ======== End Autotuning utilities ========

View File

@@ -190,6 +190,15 @@ class PatchManager:
apply_mistral_tokenizer_image_patch()
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
LOG.info(
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
)
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:

View File

@@ -0,0 +1,401 @@
"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels."""
from __future__ import annotations
from typing import Callable
import torch
from axolotl.kernels.moe import ContiguousGroupedGEMM
from axolotl.kernels.moe.indices import generate_permute_indices
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
from axolotl.utils.logging import get_logger
# Default row-group alignment handed to the grouped GEMM kernels; tokens per
# expert are padded up to a multiple of this.
_GROUP_SIZE_M = 128
# Expert sub-layers whose per-expert weights get fused into one parameter.
_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
LOG = get_logger(__name__)
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
return False
major, _ = torch.cuda.get_device_capability(hidden_states.device)
if major < 9:
LOG.debug(
"Skipping Triton MoE kernels: requires compute capability >= 90, found %s",
major,
)
return False
return True
def _ensure_combined_expert_weights(
    module, dtype: torch.dtype, device: torch.device, backend: str
) -> None:
    """Fuse per-expert linear weights into single combined parameters.

    For each name in ``_COMBINED_SUBMODULES`` the individual expert weights
    are detached from the expert linear layers and registered on the MoE
    module as one parameter named ``"<name>_weight"``:

    - backend "cg": stacked into a 3D tensor (one leading index per expert)
    - otherwise ("mg"): concatenated along dim 0 into a 2D tensor

    Idempotent: if weights are already combined for the same backend this
    only migrates them to the requested device/dtype; a backend switch first
    restores the per-expert layout and then re-combines.
    """
    # Bookkeeping for restoring the original per-expert layout later.
    if not hasattr(module, "_axolotl_original_specs"):
        module._axolotl_original_specs = {}
    if not hasattr(module, "_axolotl_mg_shapes"):
        module._axolotl_mg_shapes = {}
    prev_backend = getattr(module, "_axolotl_combined_backend", None)
    if getattr(module, "_axolotl_combined_weights", False):
        if prev_backend != backend:
            # Backend changed: undo the previous fusion and fall through to
            # rebuild the combined parameters below.
            _restore_expert_weights(module)
        else:
            # Already combined for this backend; just migrate device/dtype.
            for name in _COMBINED_SUBMODULES:
                param_name = f"{name}_weight"
                param = module.get_parameter(param_name)
                if param.device != device or param.dtype != dtype:
                    module._parameters[param_name] = torch.nn.Parameter(
                        param.to(device=device, dtype=dtype).contiguous()
                    )
            module._axolotl_combined_dtype = dtype
            module._axolotl_combined_device = device
            module._axolotl_combined_backend = backend
            return
    module._axolotl_mg_shapes = {}
    for name in _COMBINED_SUBMODULES:
        weights = []
        orig_device = None
        orig_dtype = None
        orig_shape = None
        for expert in module.experts:
            lin = expert.get_submodule(name)
            weight_param = lin._parameters.get("weight")
            if weight_param is None:
                raise RuntimeError("Expected expert linear layers to have weights")
            if orig_device is None:
                # Record the first expert's spec for restore; assumes all
                # experts share device/dtype/shape -- TODO(review) confirm.
                orig_device = weight_param.device
                orig_dtype = weight_param.dtype
                orig_shape = tuple(weight_param.shape)
            weights.append(weight_param.detach().to(device=device, dtype=dtype))
            # Drop the per-expert parameters so they are no longer tracked
            # separately by the module.
            if "weight" in lin._parameters:
                del lin._parameters["weight"]
            if "bias" in lin._parameters:
                del lin._parameters["bias"]
        if backend == "cg":
            # 3D layout: [num_experts, out_features, in_features]
            combined_weight = torch.stack(weights, dim=0).contiguous()
        else:
            # 2D layout: expert rows concatenated along dim 0
            combined_weight = torch.cat(weights, dim=0).contiguous()
        module._axolotl_mg_shapes[name] = orig_shape
        module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
        module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)
    module._axolotl_combined_weights = True
    module._axolotl_combined_dtype = dtype
    module._axolotl_combined_device = device
    module._axolotl_combined_backend = backend
def _restore_expert_weights(module) -> None:
    """Undo ``_ensure_combined_expert_weights``: re-attach per-expert weights.

    Splits each combined ``"<name>_weight"`` parameter back into individual
    expert linear weights (3D stacked layout for the cg backend, 2D
    row-concatenated layout for mg) and clears all fusion bookkeeping.
    No-op when nothing is combined.
    """
    if not getattr(module, "_axolotl_combined_weights", False):
        return
    for name in _COMBINED_SUBMODULES:
        param_name = f"{name}_weight"
        combined = module._parameters.pop(param_name)
        orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
            name, (combined.device, combined.dtype, None)
        )
        # In the 2D layout each expert owns `rows_per` consecutive rows.
        rows_per = orig_shape[0] if orig_shape else None
        for idx, expert in enumerate(module.experts):
            lin = expert.get_submodule(name)
            if combined.dim() == 3:
                # cg layout: one expert per leading index
                slice_tensor = combined[idx]
            elif rows_per is not None:
                start = idx * rows_per
                end = start + rows_per
                slice_tensor = combined[start:end]
            else:
                raise RuntimeError(
                    "Unable to recover expert weight shape during restore"
                )
            lin._parameters["weight"] = torch.nn.Parameter(
                slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
            )
    module._axolotl_combined_weights = False
    module._axolotl_combined_dtype = None
    module._axolotl_combined_device = None
    module._axolotl_combined_backend = None
    module._axolotl_original_specs = {}
    module._axolotl_mg_shapes = {}
def _run_cg_grouped_gemm(
    module,
    grouped_hidden: torch.Tensor,
    m_sizes: torch.Tensor,
    num_experts: int,
    group_size_m: int,
    hidden_dtype: torch.dtype,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the gate/up projections through the contiguous grouped GEMM backend.

    Args:
        module: MoE module holding the (possibly fused) expert weights.
        grouped_hidden: Expert-grouped activations, shape [max_len, hidden_dim].
        m_sizes: Padded rows assigned to each expert, shape [num_experts].
        num_experts: Number of experts on this rank.
        group_size_m: Row-group alignment used by the CG kernel.
        hidden_dtype: dtype to cast projection outputs back to.
        device: Device the computation runs on.

    Returns:
        Tuple of (gate_out, up_out, down_weights, expert_index_tensor); the
        caller applies the activation and runs the down projection itself.
    """
    # Fix: the original carried ~20 lines of unreachable dead code after the
    # return below, referencing undefined names (`m_sizes_tensor`,
    # `hidden_grouped`) left over from a refactor. It has been removed.
    _ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
    # One expert id per (padded) row, matching grouped_hidden's layout.
    expert_index_tensor = torch.repeat_interleave(
        torch.arange(num_experts, device=device, dtype=torch.int32),
        m_sizes.to(torch.int64),
    )
    # Fused weights may be stored 2D (row-concatenated); reshape them to the
    # [num_experts, out, in] layout the CG kernel expects.
    gate_weights = module.get_parameter("gate_proj_weight")
    if gate_weights.dim() == 2:
        out_dim = gate_weights.shape[0] // num_experts
        gate_weights = gate_weights.view(num_experts, out_dim, gate_weights.shape[1])
    up_weights = module.get_parameter("up_proj_weight")
    if up_weights.dim() == 2:
        out_dim = up_weights.shape[0] // num_experts
        up_weights = up_weights.view(num_experts, out_dim, up_weights.shape[1])
    down_weights = module.get_parameter("down_proj_weight")
    if down_weights.dim() == 2:
        out_dim = down_weights.shape[0] // num_experts
        down_weights = down_weights.view(num_experts, out_dim, down_weights.shape[1])
    gate_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        gate_weights,
        expert_index_tensor,
        group_size_m,
    )
    up_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        up_weights,
        expert_index_tensor,
        group_size_m,
    )
    return (
        gate_out.to(hidden_dtype),
        up_out.to(hidden_dtype),
        down_weights,
        expert_index_tensor,
    )
def _moe_triton_forward(
    module,
    hidden_states: torch.Tensor,
    topk_indices: torch.Tensor,
    topk_weights: torch.Tensor,
    group_size_m: int,
    backend: str,
    fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Dispatch the MoE forward through Triton grouped GEMM kernels.

    Each token is expanded top_k times, sorted by expert, padded to
    ``group_size_m`` alignment, pushed through gate/up/down projections with
    the selected grouped-GEMM backend, then unsorted and combined with the
    router weights.

    Args:
        module: DeepseekV3MoE module (provides ``.experts``).
        hidden_states: [num_tokens, hidden_dim] activations.
        topk_indices: [num_tokens, top_k] selected expert ids per token.
        topk_weights: [num_tokens, top_k] router weights per token.
        group_size_m: Row-group alignment for the grouped kernels.
        backend: "mg" (TMA kernel) or "cg" (contiguous kernel).
        fallback: Original eager implementation, used when this input is not
            Triton-eligible.

    Returns:
        [num_tokens, hidden_dim] combined expert outputs.
    """
    if not _is_triton_eligible(hidden_states):
        return fallback(hidden_states, topk_indices, topk_weights)
    device = hidden_states.device
    hidden_dtype = hidden_states.dtype
    num_tokens, hidden_dim = hidden_states.shape
    top_k = topk_indices.size(-1)
    # Duplicate each token once per selected expert; flatten assignments.
    expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0)
    expert_assignments = topk_indices.reshape(-1)
    if expanded_hidden.numel() == 0:
        # Fix: torch.Tensor has no `new_zeros_like` method; the original
        # raised AttributeError on this path. Use torch.zeros_like instead.
        return torch.zeros_like(hidden_states)
    # Group rows by expert id so each expert sees a contiguous slab.
    sort_perm = torch.argsort(expert_assignments)
    sorted_hidden = expanded_hidden.index_select(0, sort_perm)
    sorted_assignments = expert_assignments.index_select(0, sort_perm)
    num_experts = len(module.experts)
    counts = torch.bincount(sorted_assignments, minlength=num_experts)
    total_actual = int(counts.sum().item())
    if total_actual == 0:
        # Fix: same `new_zeros_like` bug as above.
        return torch.zeros_like(hidden_states)
    # One-time diagnostic of the routing balance for this module.
    if not getattr(module, "_axolotl_triton_logged", False):
        min_tokens = int(counts.min().item())
        max_tokens = int(counts.max().item())
        LOG.info(
            "DeepseekV3MoE Triton: tokens per expert (min=%s, max=%s, avg=%.1f) with group_size=%s",
            min_tokens,
            max_tokens,
            total_actual / max(1, num_experts),
            group_size_m,
        )
        module._axolotl_triton_logged = True
    # Round each expert's row count up to a multiple of group_size_m
    # (with at least one full group per expert).
    counts_int = counts.to(torch.int32)
    aligned_counts = (
        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
    ) * group_size_m
    max_len = int(aligned_counts.sum().item())
    permuted_indices, m_sizes, _ = generate_permute_indices(
        counts_int.to(device),
        experts_per_rank=num_experts,
        num_ranks=1,
        max_len=max_len,
        alignment=group_size_m,
        use_cpu=not hidden_states.is_cuda,
    )
    permuted_indices = permuted_indices.to(device)
    m_sizes = m_sizes.to(device)
    # Negative entries in permuted_indices mark padding rows.
    permuted_indices_long = permuted_indices.to(torch.int64)
    valid_mask = permuted_indices_long >= 0
    valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
    source_indices = permuted_indices_long[valid_mask]
    padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)
    # Scatter real rows into the aligned buffer; zero-fill the padding.
    grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
    if valid_positions.numel() > 0:
        grouped_hidden.index_copy_(
            0,
            valid_positions,
            sorted_hidden.index_select(0, source_indices),
        )
    if valid_positions.numel() < max_len:
        grouped_hidden.index_fill_(0, padded_positions, 0)
    m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)
    # Gate/up projections via the selected backend.
    if backend == "mg":
        _ensure_combined_expert_weights(module, hidden_dtype, device, backend)
        gate_out = mg_grouped_gemm(
            grouped_hidden,
            module.get_parameter("gate_proj_weight"),
            m_sizes_tensor,
        ).to(hidden_dtype)
        up_out = mg_grouped_gemm(
            grouped_hidden,
            module.get_parameter("up_proj_weight"),
            m_sizes_tensor,
        ).to(hidden_dtype)
    else:
        gate_out, up_out, down_weights, expert_index_tensor = _run_cg_grouped_gemm(
            module,
            grouped_hidden,
            m_sizes,
            num_experts,
            group_size_m,
            hidden_dtype,
            device,
        )
    # Apply the gated activation only on real rows.
    act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
    if valid_positions.numel() > 0:
        gate_valid = gate_out.index_select(0, valid_positions)
        up_valid = up_out.index_select(0, valid_positions)
        hidden_concat = act_fn(gate_valid) * up_valid
    else:
        hidden_concat = torch.empty(
            (0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
        )
    # Re-pack activated rows into the aligned layout for the down projection.
    intermediate_dim = hidden_concat.shape[-1]
    hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
    if valid_positions.numel() > 0:
        hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
    if valid_positions.numel() < max_len:
        hidden_grouped.index_fill_(0, padded_positions, 0)
    if backend == "mg":
        down_out = mg_grouped_gemm(
            hidden_grouped,
            module.get_parameter("down_proj_weight"),
            m_sizes_tensor,
        ).to(hidden_dtype)
    else:
        down_out = ContiguousGroupedGEMM.apply(
            hidden_grouped,
            down_weights,
            expert_index_tensor,
            group_size_m,
        ).to(hidden_dtype)
    if valid_positions.numel() > 0:
        down_valid = down_out.index_select(0, valid_positions)
    else:
        down_valid = torch.empty(
            (0, down_out.shape[-1]), device=device, dtype=hidden_dtype
        )
    # Undo the expert sort, fold the top_k axis back, and apply router weights.
    sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
    if down_valid.numel() > 0:
        sorted_outputs.index_copy_(0, source_indices, down_valid)
    expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
    expanded_output.index_copy_(0, sort_perm, sorted_outputs)
    expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
    weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
    return weighted.sum(dim=1)
def patch_deepseek_v3_moe(
    group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
) -> None:
    """Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels.

    Args:
        group_size_m: Row-group alignment handed to the Triton path.
        backend: Grouped GEMM backend, "mg" (TMA) or "cg" (contiguous).

    Raises:
        ValueError: If ``backend`` is not "cg" or "mg".
    """
    from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
    if backend not in {"cg", "mg"}:
        raise ValueError(f"Unsupported MoE kernel backend: {backend}")
    # Record the unpatched implementation so callers can access a true baseline even
    # after the Triton patch has been applied (e.g. repeated microbenchmarks).
    if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
        DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe
    # Idempotent: a repeated call is a no-op. NOTE(review): it also skips
    # updating the backend/group-size class attributes set below -- confirm
    # that re-patching with different arguments is not expected.
    if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
        return
    original_moe = DeepseekV3MoE._axolotl_triton_original_moe
    # Class-level defaults; individual instances may override via the same
    # attribute names (see the getattr fallbacks in patched_moe).
    DeepseekV3MoE._axolotl_triton_backend = backend
    DeepseekV3MoE._axolotl_group_size_m = group_size_m
    def patched_moe(self, hidden_states, topk_indices, topk_weights):
        backend_sel = getattr(self, "_axolotl_triton_backend", backend)
        group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
        # The CG backend is pinned to the fixed _GROUP_SIZE_M alignment.
        if backend_sel == "cg" and group_size_sel != _GROUP_SIZE_M:
            LOG.debug(
                "Adjusting group_size_m=%s to %s for CG backend",
                group_size_sel,
                _GROUP_SIZE_M,
            )
            group_size_sel = _GROUP_SIZE_M
        try:
            return _moe_triton_forward(
                self,
                hidden_states,
                topk_indices,
                topk_weights,
                group_size_sel,
                backend_sel,
                original_moe,
            )
        except Exception as err:  # surface Triton failures explicitly
            # Put the module back into its per-expert weight layout before
            # re-raising so the model stays usable via the original path.
            _restore_expert_weights(self)
            LOG.error("DeepseekV3MoE Triton path failed: %s", err)
            raise
    DeepseekV3MoE.moe = patched_moe
    DeepseekV3MoE._axolotl_triton_patch = True

View File

@@ -113,6 +113,19 @@ class AxolotlInputConfig(
},
)
moe_kernels: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
},
)
moe_kernel_backend: Literal["cg", "mg"] | None = Field(
default="mg",
json_schema_extra={
"description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
},
)
trainer_cls: str | None = Field(
default=None,
json_schema_extra={