Compare commits


37 Commits

Author SHA1 Message Date
Dan Saunders  dd85358543  default mg  2025-09-25 16:30:23 -04:00
Dan Saunders  55d98db0d0  fix  2025-09-25 16:08:35 -04:00
Dan Saunders  d90ade3b1b  fix  2025-09-25 15:55:08 -04:00
Dan Saunders  824a641cee  uniform routing default  2025-09-25 15:47:23 -04:00
Dan Saunders  e003a05177  narrow sweep; compare both backends  2025-09-25 14:54:03 -04:00
Dan Saunders  91393c4dc8  allocator  2025-09-25 14:27:34 -04:00
Dan Saunders  d578c53603  fix  2025-09-25 14:27:34 -04:00
Dan Saunders  4db7a21ff7  fix  2025-09-25 14:27:34 -04:00
Dan Saunders  3b2e05c563  update to new api  2025-09-25 14:27:34 -04:00
Dan Saunders  1037ca3a97  update to new api  2025-09-25 14:27:34 -04:00
Dan Saunders  6369dcd7b8  fix  2025-09-25 14:27:34 -04:00
Dan Saunders  a81612305c  fix?  2025-09-25 14:27:34 -04:00
Dan Saunders  d0da67eb17  add mg kernel backend  2025-09-25 14:27:34 -04:00
Dan Saunders  8a1f5ae940  fix  2025-09-25 14:27:34 -04:00
Dan Saunders  146ca48cba  vram  2025-09-25 14:27:34 -04:00
Dan Saunders  fd312f6058  dtype  2025-09-25 14:27:34 -04:00
Dan Saunders  ab8fa56b16  dtype  2025-09-25 14:27:34 -04:00
Dan Saunders  1640cd4006  delete config  2025-09-25 14:27:34 -04:00
Dan Saunders  3277d44d71  cfg value  2025-09-25 14:27:34 -04:00
Dan Saunders  d3e1b0ef1a  small deepseek script  2025-09-25 14:27:34 -04:00
Dan Saunders  5b97633faa  Fix  2025-09-25 14:27:34 -04:00
Dan Saunders  94cbc6d42d  log device, dtype  2025-09-25 14:27:34 -04:00
Dan Saunders  493616fc3d  reprod tt table  2025-09-25 14:27:34 -04:00
Dan Saunders  d2b25c7327  grid sweep  2025-09-25 14:27:34 -04:00
Dan Saunders  b670c45276  fix  2025-09-25 14:27:34 -04:00
Dan Saunders  61faf4cbe4  fix  2025-09-25 14:27:34 -04:00
Dan Saunders  8d8fa834a2  sweep  2025-09-25 14:27:34 -04:00
Dan Saunders  9d69c6fb3e  Fix  2025-09-25 14:27:34 -04:00
Dan Saunders  92f2f6e73c  dtype fix  2025-09-25 14:27:34 -04:00
Dan Saunders  e5d2aebe16  uniform routing:  2025-09-25 14:27:34 -04:00
Dan Saunders  4ab9e3f58b  add logs  2025-09-25 14:27:34 -04:00
Dan Saunders  5788832812  simplify  2025-09-25 14:27:34 -04:00
Dan Saunders  db782430f8  fix  2025-09-25 14:27:34 -04:00
Dan Saunders  5c74edeefe  token shuffle kernel  2025-09-25 14:27:34 -04:00
Dan Saunders  18269ee6a9  fix  2025-09-25 14:27:34 -04:00
Dan Saunders  6a45d804f9  glue  2025-09-25 14:27:34 -04:00
Dan Saunders  95e607574a  vendor torchtitan moe kernels  2025-09-25 14:27:34 -04:00
23 changed files with 3379 additions and 58 deletions


@@ -5,11 +5,10 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(including DDP, DeepSpeed, and FSDP2) training. These include (1) SwiGLU and GEGLU
activation function Triton kernels, and (2) LoRA MLP and attention custom autograd
functions. Our goal was to leverage operator fusion and tensor re-use in order to
improve speed and reduce memory usage during the forward and backward passes of these
calculations.
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to):
@@ -93,12 +92,13 @@ Currently, LoRA kernels are not supported for RLHF training, only SFT.
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
- Note: Set `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` to enable [memory-efficient attention on AMD GPUs](https://github.com/ROCm/aotriton/issues/16#issuecomment-2346675491)
- Targeted LoRA adapters must disable dropout (`lora_dropout: 0`)
- Targeted LoRA adapters cannot use Dropout
- This may limit model expressivity / cause overfitting
- Targeted LoRA adapters cannot have bias terms
- This may limit model expressivity
- Adapters that already include bias terms are supported.
Models with pre-existing LoRA adapters that use Dropout may need to be re-finetuned
without it in order to be as performant.
Models with pre-existing LoRA adapters that use Dropout or have bias terms may need to
be re-finetuned without these features in order to be useful.
## Implementation details
@@ -131,5 +131,6 @@ computation path.
## Future Work
- Support for additional model architectures
- Support for dropout
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions

scripts/__init__.py (new file)

@@ -0,0 +1 @@
"""Utility scripts package."""


@@ -0,0 +1,5 @@
"""Benchmark helpers."""
from .deepseek_v3_moe import ACCURACY_TOLERANCE, DTYPE_MAP, benchmark_deepseek_v3
__all__ = ["benchmark_deepseek_v3", "DTYPE_MAP", "ACCURACY_TOLERANCE"]


@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Instantiate a ~8.3B DeepSeek-V3 MoE model with random weights.
Run this on a GPU-equipped machine (e.g. 1× NVL H100) so the dense
initialization completes quickly:
python scripts/benchmarks/build_deepseek_v3_8b.py --output deepseek-v3-8b-moe
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from transformers import DeepseekV3Config, DeepseekV3ForCausalLM
DTYPE_MAP = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
"float16": torch.float16,
}
def build_config() -> DeepseekV3Config:
"""Return a DeepSeek V3 configuration totaling ~8.3B parameters."""
return DeepseekV3Config(
vocab_size=32_000,
hidden_size=3_072,
intermediate_size=8_192,
moe_intermediate_size=2_560,
num_hidden_layers=20,
num_attention_heads=24,
num_key_value_heads=24,
n_routed_experts=18,
num_experts_per_tok=4,
n_group=6,
topk_group=4,
kv_lora_rank=192,
q_lora_rank=384,
max_position_embeddings=2_048,
rope_theta=10_000.0,
rope_interleave=True,
hidden_act="silu",
initializer_range=0.02,
attention_dropout=0.0,
attention_bias=False,
n_shared_experts=1,
routed_scaling_factor=2.5,
norm_topk_prob=True,
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--output",
type=Path,
required=True,
help="Directory to save the generated model",
)
parser.add_argument(
"--dtype",
default="bfloat16",
choices=DTYPE_MAP.keys(),
help="Storage dtype for the checkpoint",
)
parser.add_argument(
"--seed",
type=int,
default=0,
help="Torch RNG seed for reproducibility",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
torch.manual_seed(args.seed)
output_dir = args.output
output_dir.mkdir(parents=True, exist_ok=True)
config = build_config()
model = DeepseekV3ForCausalLM(config)
dtype = DTYPE_MAP[args.dtype]
model.to(dtype=dtype)
param_count = sum(p.numel() for p in model.parameters())
print(f"Initialized DeepSeek-V3 MoE with {param_count / 1e9:.3f}B parameters")
model.save_pretrained(output_dir, safe_serialization=True)
config.save_pretrained(output_dir)
print(f"Saved model and config to {output_dir.resolve()}")
if __name__ == "__main__":
main()
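Once generated, the checkpoint loads back like any other transformers model. A minimal sketch (the directory name is just the example `--output` value from the docstring above):

import torch
from transformers import DeepseekV3ForCausalLM

# Reload the randomly initialized checkpoint for benchmarking or inspection.
model = DeepseekV3ForCausalLM.from_pretrained(
    "deepseek-v3-8b-moe",        # example --output directory
    torch_dtype=torch.bfloat16,  # match the dtype used when saving
)
model.eval()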


@@ -0,0 +1,190 @@
#!/usr/bin/env python
"""Reproduce TorchTitan CG GEMM timings for selected problem sizes."""
from __future__ import annotations
import argparse
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import torch
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else:
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.kernels.moe import (
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
@dataclass
class Scenario:
num_groups: int
m: int
n: int
k: int
SCENARIOS: tuple[Scenario, ...] = (
Scenario(num_groups=4, m=8192, n=4096, k=7168),
Scenario(num_groups=4, m=8192, n=7168, k=2048),
Scenario(num_groups=8, m=4096, n=4096, k=7168),
Scenario(num_groups=8, m=4096, n=7168, k=2048),
)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--device", default="cuda", choices=["cuda"], help="Execution device"
)
parser.add_argument(
"--dtype",
default="bf16",
choices=["bf16", "fp16", "fp32"],
help="Computation dtype",
)
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=20, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--group-size",
type=int,
default=128,
help="GROUP_SIZE_M expected by the kernel",
)
return parser.parse_args()
def pick_dtype(name: str) -> torch.dtype:
return {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}[name]
def make_indices(
num_groups: int, group_size: int, device: torch.device
) -> torch.Tensor:
indices = torch.arange(num_groups, device=device, dtype=torch.int32)
return indices.repeat_interleave(group_size)
def timed_call(fn, *args, warmup: int, iters: int) -> float:
for _ in range(warmup):
fn(*args)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iters):
fn(*args)
torch.cuda.synchronize()
return (time.perf_counter() - start) * 1000.0 / iters
def run_scenario(
scenario: Scenario,
*,
dtype: torch.dtype,
device: torch.device,
warmup: int,
iters: int,
group_size_m: int,
) -> dict:
if scenario.m % scenario.num_groups != 0:
raise ValueError(
f"M ({scenario.m}) not divisible by groups ({scenario.num_groups})"
)
group_size = scenario.m // scenario.num_groups
if group_size % group_size_m != 0:
raise ValueError(
f"Group size {group_size} must be a multiple of GROUP_SIZE_M ({group_size_m}) for the Triton kernel"
)
inputs = torch.randn(scenario.m, scenario.k, device=device, dtype=dtype)
weights = torch.randn(
scenario.num_groups, scenario.n, scenario.k, device=device, dtype=dtype
)
indices = make_indices(scenario.num_groups, group_size, device)
def persistent():
return cg_grouped_gemm_forward(inputs, weights, indices, group_size_m)
def baseline():
return cg_grouped_gemm_forward_dynamic(inputs, weights, indices, group_size_m)
persistent_ms = timed_call(persistent, warmup=warmup, iters=iters)
baseline_ms = timed_call(baseline, warmup=warmup, iters=iters)
return {
"scenario": scenario,
"persistent_ms": persistent_ms,
"baseline_ms": baseline_ms,
"speedup": baseline_ms / persistent_ms if persistent_ms > 0 else float("nan"),
}
def main() -> None: # pragma: no cover - utility script
args = parse_args()
torch.manual_seed(args.seed)
if args.device != "cuda" or not torch.cuda.is_available():
raise SystemExit("CUDA device required for this benchmark")
dtype = pick_dtype(args.dtype)
device = torch.device(args.device)
print(
f"device={device} dtype={dtype} warmup={args.warmup} iters={args.iters} group_size={args.group_size}"
)
print(
f"{'groups':>7} {'m':>7} {'n':>7} {'k':>7} {'persistent':>12} {'baseline':>12} {'speedup':>8}"
)
for result in run_all(
SCENARIOS,
dtype=dtype,
device=device,
warmup=args.warmup,
iters=args.iters,
group_size_m=args.group_size,
):
scen = result["scenario"]
print(
f"{scen.num_groups:>7} {scen.m:>7} {scen.n:>7} {scen.k:>7}"
f" {result['persistent_ms']:>11.3f} ms {result['baseline_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
)
def run_all(
scenarios: Iterable[Scenario],
*,
dtype: torch.dtype,
device: torch.device,
warmup: int,
iters: int,
group_size_m: int,
) -> Iterable[dict]:
for scenario in scenarios:
yield run_scenario(
scenario,
dtype=dtype,
device=device,
warmup=warmup,
iters=iters,
group_size_m=group_size_m,
)
if __name__ == "__main__":
main()


@@ -0,0 +1,301 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels."""
from __future__ import annotations
import argparse
import sys
import time
from pathlib import Path
from types import MethodType
import torch
try:
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
DeepseekV3Config,
)
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
except ImportError as exc: # pragma: no cover - utility script
raise SystemExit(
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
) from exc
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
ACCURACY_TOLERANCE = 5e-3
DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
parser.add_argument(
"--moe-intermediate-size",
type=int,
default=8192,
help="MoE intermediate projection size",
)
parser.add_argument(
"--n-experts",
type=int,
default=256,
help="Number of routed experts",
)
parser.add_argument(
"--top-k",
type=int,
default=8,
help="Number of experts per token",
)
parser.add_argument(
"--groups",
type=int,
default=8,
help="Router groups (must divide n-experts)",
)
parser.add_argument(
"--dtype",
choices=DTYPE_MAP.keys(),
default="bf16",
help="Computation dtype",
)
parser.add_argument(
"--device",
default="auto",
choices=["auto", "cpu", "cuda"],
help="Execution device",
)
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--uniform-routing",
action="store_true",
help="Override router to distribute tokens evenly across experts",
)
parser.add_argument(
"--group-size",
type=int,
default=128,
help="GROUP_SIZE_M used by the Triton kernel",
)
parser.add_argument(
"--backend",
choices=["cg", "mg"],
default="mg",
help="MoE kernel backend to benchmark",
)
return parser.parse_args()
def resolve_device(requested: str) -> torch.device:
if requested == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(requested)
def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
config = DeepseekV3Config(
hidden_size=args.hidden_size,
intermediate_size=args.moe_intermediate_size,
moe_intermediate_size=args.moe_intermediate_size,
n_routed_experts=args.n_experts,
num_experts_per_tok=args.top_k,
n_group=args.groups,
topk_group=max(1, min(args.groups, args.top_k)),
n_shared_experts=1,
)
module = DeepseekV3MoE(config)
module.eval()
return module
@torch.no_grad()
def timed_forward(
module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
) -> float:
for _ in range(warmup):
module(inputs)
if inputs.is_cuda:
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iters):
module(inputs)
if inputs.is_cuda:
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
return (elapsed / iters) * 1000.0
def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
torch.manual_seed(args.seed)
device = resolve_device(args.device)
dtype = DTYPE_MAP[args.dtype]
if args.n_experts % args.groups != 0:
raise SystemExit("n-experts must be divisible by groups")
if args.top_k > args.n_experts:
raise SystemExit("top-k cannot exceed number of experts")
if device.type == "cuda" and not torch.cuda.is_available():
raise SystemExit("CUDA requested but not available")
baseline_module = build_module(args)
original_moe = getattr(
DeepseekV3MoE,
"_axolotl_triton_original_moe",
DeepseekV3MoE.moe,
)
baseline_module.moe = MethodType(original_moe, baseline_module)
state_dict = baseline_module.state_dict()
patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
patched_module = build_module(args)
patched_module.load_state_dict(state_dict)
baseline_module.to(device=device, dtype=dtype)
patched_module.to(device=device, dtype=dtype)
tokens = args.batch * args.seq_len
routed_tokens = tokens * args.top_k
avg_tokens_per_expert = routed_tokens / args.n_experts
inputs = torch.randn(
args.batch,
args.seq_len,
args.hidden_size,
device=device,
dtype=dtype,
)
with torch.no_grad():
flat_inputs = inputs.view(-1, args.hidden_size)
if args.uniform_routing:
total_assignments = flat_inputs.size(0) * args.top_k
base = total_assignments // args.n_experts
remainder = total_assignments % args.n_experts
counts = torch.full(
(args.n_experts,),
base,
dtype=torch.int64,
device=device,
)
if remainder:
counts[:remainder] += 1
assignments = torch.repeat_interleave(
torch.arange(args.n_experts, device=device), counts
)
assignments = assignments[torch.randperm(assignments.size(0))]
topk_idx = assignments.view(flat_inputs.size(0), args.top_k)
else:
topk_idx, _ = patched_module.gate(flat_inputs)
tokens_per_expert = torch.bincount(
topk_idx.reshape(-1), minlength=args.n_experts
)
min_tokens = int(tokens_per_expert.min().item())
max_tokens = int(tokens_per_expert.max().item())
if args.uniform_routing:
weights = torch.full(
topk_idx.shape,
1.0 / args.top_k,
device=device,
dtype=torch.float32,
)
def _uniform_gate(self, hidden_states):
flat = hidden_states.view(-1, hidden_states.shape[-1])
token_count = flat.shape[0]
return topk_idx[:token_count], weights[:token_count]
patched_module.gate.forward = _uniform_gate.__get__(
patched_module.gate, patched_module.gate.__class__
)
baseline_module.gate.forward = _uniform_gate.__get__(
baseline_module.gate, baseline_module.gate.__class__
)
with torch.no_grad():
ref_output = baseline_module(inputs)
patched_output = patched_module(inputs)
max_diff = (ref_output - patched_output).abs().max().item()
baseline_vram = patched_vram = None
if device.type == "cuda":
torch.cuda.reset_peak_memory_stats(device)
baseline_ms = timed_forward(baseline_module, inputs, args.iters, args.warmup)
if device.type == "cuda":
baseline_vram = torch.cuda.max_memory_allocated(device)
torch.cuda.reset_peak_memory_stats(device)
patched_ms = timed_forward(patched_module, inputs, args.iters, args.warmup)
if device.type == "cuda":
patched_vram = torch.cuda.max_memory_allocated(device)
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
return {
"device": device,
"backend": args.backend,
"dtype": dtype,
"baseline_ms": baseline_ms,
"patched_ms": patched_ms,
"speedup": speedup,
"max_diff": max_diff,
"routed_tokens": routed_tokens,
"avg_tokens": avg_tokens_per_expert,
"min_tokens": min_tokens,
"max_tokens": max_tokens,
"baseline_vram": baseline_vram,
"patched_vram": patched_vram,
"accuracy_ok": max_diff <= ACCURACY_TOLERANCE,
}
def main() -> None: # pragma: no cover - CLI entrypoint
args = parse_args()
result = benchmark_deepseek_v3(args)
print(
f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
)
print(
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
)
print(f"min/max tokens per expert: {result['min_tokens']}/{result['max_tokens']}")
if result["baseline_vram"] is not None:
print(
f"VRAM baseline={result['baseline_vram'] / (1024**2):.1f} MiB | patched={result['patched_vram'] / (1024**2):.1f} MiB"
)
print(
f"Baseline: {result['baseline_ms']:.3f} ms | Patched: {result['patched_ms']:.3f} ms | x{result['speedup']:.2f}"
)
print(f"Max |Δ| between outputs: {result['max_diff']:.2e}")
if not result["accuracy_ok"]:
raise RuntimeError(
f"Accuracy check failed: max diff {result['max_diff']:.3e} exceeds tolerance {ACCURACY_TOLERANCE:.1e}"
)
if __name__ == "__main__":
main()
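The benchmark can also be driven programmatically, which is how the sweep script below calls it. A minimal sketch, assuming a CUDA device; every attribute mirrors an argparse option above, and the values are purely illustrative:

from types import SimpleNamespace

from scripts.benchmarks import benchmark_deepseek_v3

args = SimpleNamespace(
    batch=2, seq_len=512, hidden_size=1024, moe_intermediate_size=2048,
    n_experts=16, top_k=4, groups=8, dtype="bf16", device="cuda",
    warmup=2, iters=5, seed=0, uniform_routing=True,
    group_size=128, backend="mg",
)
result = benchmark_deepseek_v3(args)
print(result["speedup"], result["max_diff"], result["accuracy_ok"])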


@@ -0,0 +1,275 @@
#!/usr/bin/env python
# mypy: ignore-errors
"""Sweep a set of DeepSeek V3 MoE benchmark configurations."""
from __future__ import annotations
import argparse
import csv
import logging
import sys
from pathlib import Path
from types import SimpleNamespace
CURRENT_DIR = Path(__file__).resolve().parent
for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
repo_root = candidate / "axolotl"
if repo_root.exists():
if str(repo_root) not in sys.path:
sys.path.insert(0, str(repo_root))
break
else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports")
from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
ACCURACY_TOLERANCE,
DTYPE_MAP,
benchmark_deepseek_v3,
)
LOG = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--dtype",
choices=DTYPE_MAP.keys(),
default="bf16",
help="Computation dtype for all benchmarks",
)
parser.add_argument(
"--device",
default="auto",
choices=["auto", "cpu", "cuda"],
help="Execution device",
)
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--group-size",
type=int,
help="Override GROUP_SIZE_M for every configuration",
)
parser.add_argument(
"--backends",
default="mg",
help="Comma separated list of backends to benchmark (subset of cg,mg)",
)
parser.add_argument(
"--no-uniform-routing",
action="store_true",
help="Disable uniform routing for every configuration",
)
parser.add_argument(
"--include-mixtral-long",
action="store_true",
help="Add an 8×8192 Mixtral-style run to the sweep",
)
parser.add_argument(
"--output",
type=Path,
help="Optional CSV file to store results",
)
return parser.parse_args()
def make_namespace(
base: dict, args: argparse.Namespace, backend: str
) -> SimpleNamespace:
combined = dict(base)
combined.update(
{
"dtype": args.dtype,
"device": args.device,
"backend": backend,
"warmup": args.warmup,
"iters": args.iters,
"seed": args.seed,
"uniform_routing": not args.no_uniform_routing,
}
)
if args.group_size is not None:
combined["group_size"] = args.group_size
return SimpleNamespace(**combined)
ARCHETYPES = (
(
"mixtral",
{
"hidden_size": 4096,
"moe_intermediate_size": 14336,
"n_experts": 8,
"top_k": 2,
"groups": 1,
"group_size": 128,
},
[(4, 2048), (8, 4096)],
),
(
"qwen",
{
"hidden_size": 6144,
"moe_intermediate_size": 24576,
"n_experts": 16,
"top_k": 4,
"groups": 8,
"group_size": 128,
},
[(4, 4096), (8, 8192)],
),
(
"deepseek_v3",
{
"hidden_size": 12288,
"moe_intermediate_size": 49152,
"n_experts": 128,
"top_k": 8,
"groups": 16,
"group_size": 128,
},
[(4, 4096), (8, 8192)],
),
)
MIXTRAL_LONG_SHAPES = [(8, 8192)]
def main() -> None: # pragma: no cover - utility script
args = parse_args()
grid = []
for label, base_cfg, shapes in ARCHETYPES:
for batch, seq_len in shapes:
cfg = {
"label": label,
"batch": batch,
"seq_len": seq_len,
**base_cfg,
}
if cfg["n_experts"] % cfg["groups"] != 0 or cfg["top_k"] > cfg["n_experts"]:
continue
grid.append(cfg)
if args.include_mixtral_long:
base_cfg = ARCHETYPES[0][1]
for batch, seq_len in MIXTRAL_LONG_SHAPES:
grid.append(
{
"label": "mixtral_long",
"batch": batch,
"seq_len": seq_len,
**base_cfg,
}
)
if not grid:
raise SystemExit("No valid parameter combinations produced")
header = (
"model",
"batch",
"seq_len",
"hidden_size",
"moe_intermediate",
"n_experts",
"top_k",
"groups",
"backend",
"baseline_ms",
"patched_ms",
"speedup",
"baseline_vram_mib",
"patched_vram_mib",
"min_tokens",
"max_tokens",
"max_diff",
"accuracy_ok",
)
rows = []
raw_backends = [
token.strip() for token in args.backends.split(",") if token.strip()
]
if not raw_backends:
raw_backends = ["mg"]
valid_backends = []
for backend in raw_backends:
if backend not in {"cg", "mg"}:
raise SystemExit(f"Unsupported backend '{backend}' requested")
if backend not in valid_backends:
valid_backends.append(backend)
uniform_flag = not args.no_uniform_routing
print(
f"Running sweep on device={args.device} dtype={args.dtype} backends={tuple(valid_backends)} uniform_routing={uniform_flag}"
)
print(
f"{'model':>10} {'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'diff':>10} {'acc':>5}"
)
for cfg in grid:
for backend in valid_backends:
ns = make_namespace(cfg, args, backend)
result = benchmark_deepseek_v3(ns)
baseline_vram_mib = (
result["baseline_vram"] / (1024**2)
if result["baseline_vram"] is not None
else float("nan")
)
patched_vram_mib = (
result["patched_vram"] / (1024**2)
if result["patched_vram"] is not None
else float("nan")
)
rows.append(
(
cfg["label"],
cfg["batch"],
cfg["seq_len"],
cfg["hidden_size"],
cfg["moe_intermediate_size"],
cfg["n_experts"],
cfg["top_k"],
cfg["groups"],
backend,
result["baseline_ms"],
result["patched_ms"],
result["speedup"],
baseline_vram_mib,
patched_vram_mib,
result["min_tokens"],
result["max_tokens"],
result["max_diff"],
result["accuracy_ok"],
)
)
status = "OK" if result["accuracy_ok"] else "FAIL"
print(
f"{cfg['label']:>10} {cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {backend:>8}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {result['max_diff']:>10.3e} {status:>5}"
)
if not result["accuracy_ok"]:
LOG.warning(
"Accuracy tolerance exceeded for %s backend=%s: diff=%.3e (> %.1e)",
cfg["label"],
backend,
result["max_diff"],
ACCURACY_TOLERANCE,
)
if args.output:
args.output.parent.mkdir(parents=True, exist_ok=True)
with args.output.open("w", newline="") as fp:
writer = csv.writer(fp)
writer.writerow(header)
writer.writerows(rows)
print(f"Results written to {args.output}")
if __name__ == "__main__":
main()


@@ -0,0 +1,21 @@
"""Mixture-of-Experts kernel implementations."""
from .indices import generate_permute_indices
from .tt_cg_gemm import (
ContiguousGroupedGEMM,
ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
"generate_permute_indices",
"mg_grouped_gemm",
]


@@ -0,0 +1,5 @@
"""Token permutation utilities for grouped MoE kernels."""
from .indices import generate_permute_indices
__all__ = ["generate_permute_indices"]


@@ -0,0 +1,144 @@
"""Vendored token permutation kernels from TorchTitan."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
__all__ = ["generate_permute_indices"]
@triton.jit
def _fill_indices_kernel(
tokens_per_expert_group_ptr,
start_index_values_ptr,
write_offsets_ptr,
output_ptr,
experts_per_rank: tl.constexpr,
num_ranks: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_programs = tl.num_programs(axis=0)
for expert_id in range(pid, experts_per_rank, num_programs):
write_offset = tl.load(write_offsets_ptr + expert_id)
for r in range(num_ranks):
idx = r * experts_per_rank + expert_id
start_index = tl.load(start_index_values_ptr + idx)
length = tl.load(tokens_per_expert_group_ptr + idx)
offsets = tl.arange(0, BLOCK_SIZE)
for chunk_start in range(0, length, BLOCK_SIZE):
chunk_offsets = chunk_start + offsets
mask = chunk_offsets < length
values = start_index + chunk_offsets
dest_indices = write_offset + chunk_offsets
tl.store(output_ptr + dest_indices, values, mask=mask)
write_offset += length
def fill_indices_wrapper(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
block_size: int = 128,
max_blocks: int = 1024,
):
permuted_indices = torch.full(
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
)
num_blocks = min(experts_per_rank, max_blocks)
grid = (num_blocks,)
_fill_indices_kernel[grid](
tokens_per_expert_group,
start_index_values,
write_offsets,
permuted_indices,
experts_per_rank,
num_ranks,
BLOCK_SIZE=block_size,
)
return permuted_indices
def fill_indices_cpu(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
):
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
for expert_id in range(experts_per_rank):
write_start = write_offsets[expert_id].item()
for r in range(num_ranks):
idx = r * experts_per_rank + expert_id
start_index = start_index_values[idx].item()
length = tokens_per_expert_group[idx].item()
if length > 0:
end_idx = min(write_start + length, max_len)
permuted_indices[write_start:end_idx] = torch.arange(
start_index,
start_index + (end_idx - write_start),
dtype=torch.int32,
)
write_start += length
return permuted_indices
def generate_permute_indices(
tokens_per_expert_group: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
alignment: int,
use_cpu: bool = False,
):
start_index_values = (
torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
)
total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
torch.int32
)
m_offsets = torch.cumsum(m_sizes, 0)
write_offsets = m_offsets - m_sizes
if use_cpu:
permuted_indices = fill_indices_cpu(
tokens_per_expert_group,
start_index_values,
write_offsets,
experts_per_rank,
num_ranks,
max_len,
)
else:
permuted_indices = fill_indices_wrapper(
tokens_per_expert_group,
start_index_values,
write_offsets,
experts_per_rank,
num_ranks,
max_len,
)
return permuted_indices, m_sizes, m_offsets.to(torch.int32)
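A small worked example of the permutation layout (a sketch using the CPU path; assumes the vendored module is importable as `axolotl.kernels.moe`):

import torch

from axolotl.kernels.moe import generate_permute_indices

# One rank, two experts: expert 0 received 3 tokens, expert 1 received 2.
tokens_per_expert_group = torch.tensor([3, 2], dtype=torch.int32)
permuted, m_sizes, m_offsets = generate_permute_indices(
    tokens_per_expert_group,
    experts_per_rank=2,
    num_ranks=1,
    max_len=8,
    alignment=4,
    use_cpu=True,
)
# Each expert's rows are padded up to the alignment; -1 marks padding slots.
print(permuted.tolist())   # [0, 1, 2, -1, 3, 4, -1, -1]
print(m_sizes.tolist())    # [4, 4]
print(m_offsets.tolist())  # [4, 8]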


@@ -0,0 +1,17 @@
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""
from .cg_backward import ContiguousGroupedGEMM
from .cg_forward import (
ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
]


@@ -0,0 +1,290 @@
"""Vendored backward pass for Triton contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
from .cg_forward import cg_grouped_gemm_forward
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_backward_dx(
grad_output_ptr,
b_ptr,
grad_input_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
"""Compute gradients with respect to inputs."""
pid = tl.program_id(0)
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
tile_m = pid // num_k_tiles
tile_k = pid % num_k_tiles
m_start = tile_m * BLOCK_SIZE_M
k_start = tile_k * BLOCK_SIZE_K
if m_start < M_TOTAL:
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
mask_m = offs_m < M_TOTAL
mask_k = offs_k < K
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)
for n in range(0, N, BLOCK_SIZE_N):
offs_n = tl.arange(0, BLOCK_SIZE_N) + n
mask_n = offs_n < N
mask_go = mask_m[:, None] & mask_n[None, :]
mask_w = mask_n[:, None] & mask_k[None, :]
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
w = tl.load(w_ptrs, mask=mask_w, other=0.0).to(tl.float32)
grad_input += tl.dot(go, w)
grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_gi = mask_m[:, None] & mask_k[None, :]
tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
@triton.jit
def _kernel_cg_backward_dw(
grad_output_ptr,
inputs_ptr,
grad_weights_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
):
"""Simplified kernel for expert weight gradients."""
pid = tl.program_id(0)
expert_id = pid // ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
position_id = pid % ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
if expert_id < NUM_EXPERTS:
n_tiles = K // BLOCK_SIZE_K
tile_n = position_id // n_tiles
tile_k = position_id % n_tiles
n_start = tile_n * BLOCK_SIZE_N
k_start = tile_k * BLOCK_SIZE_K
if n_start < N and k_start < K:
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
mask_n = offs_n < N
mask_k = offs_k < K
grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)
for group_idx in range(0, M_TOTAL // GROUP_SIZE_M):
group_start = group_idx * GROUP_SIZE_M
group_expert = tl.load(indices_ptr + group_start)
if group_expert == expert_id:
for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
m_start = group_start + m_offset
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)
go_ptrs = (
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
)
mask_go = mask_m[:, None] & mask_n[None, :]
go = tl.load(go_ptrs, mask=mask_go, other=0.0).to(tl.float32)
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_in = mask_m[:, None] & mask_k[None, :]
inp = tl.load(in_ptrs, mask=mask_in, other=0.0).to(tl.float32)
go_t = tl.trans(go)
grad_weights += tl.dot(go_t, inp)
grad_w_ptrs = (
grad_weights_ptr
+ expert_id * N * K
+ offs_n[:, None] * K
+ offs_k[None, :]
)
mask_gw = mask_n[:, None] & mask_k[None, :]
tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
def cg_grouped_gemm_backward_weights(
grad_output: torch.Tensor,
inputs: torch.Tensor,
expert_indices: torch.Tensor,
num_experts: int,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Backward pass for expert weights."""
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, N = grad_output.shape
_, K = inputs.shape
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
grad_weights = torch.zeros(
(num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
)
block_size_n = min(128, N)
block_size_k = min(32, K)
block_size_m = min(32, group_size_m)
n_tiles = triton.cdiv(N, block_size_n)
k_tiles = triton.cdiv(K, block_size_k)
grid = (num_experts * n_tiles * k_tiles,)
_kernel_cg_backward_dw[grid](
grad_output,
inputs,
grad_weights,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
BLOCK_SIZE_M=block_size_m,
)
return grad_weights
def cg_grouped_gemm_backward_inputs(
grad_output: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Backward pass for inputs."""
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, N = grad_output.shape
num_experts, _, K = expert_weights.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
grad_inputs = torch.zeros(
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
)
grid = lambda meta: (
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
* triton.cdiv(K, meta["BLOCK_SIZE_K"]),
)
_kernel_cg_backward_dx[grid](
grad_output,
expert_weights,
grad_inputs,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
)
return grad_inputs
class ContiguousGroupedGEMM(torch.autograd.Function):
"""Autograd function with full backward support."""
@staticmethod
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
ctx.save_for_backward(inputs, expert_weights, expert_indices)
ctx.group_size_m = group_size_m
return cg_grouped_gemm_forward(
inputs=inputs,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
@staticmethod
def backward(ctx, grad_output):
inputs, expert_weights, expert_indices = ctx.saved_tensors
group_size_m = ctx.group_size_m
grad_output = grad_output.contiguous()
num_experts = expert_weights.shape[0]
grad_inputs = cg_grouped_gemm_backward_inputs(
grad_output=grad_output,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
grad_weights = cg_grouped_gemm_backward_weights(
grad_output=grad_output,
inputs=inputs,
expert_indices=expert_indices,
num_experts=num_experts,
group_size_m=group_size_m,
)
grad_indices = None
grad_group_size_m = None
return grad_inputs, grad_weights, grad_indices, grad_group_size_m
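A minimal usage sketch for the autograd function (assumes a CUDA device with Triton available; rows must be pre-sorted so that each contiguous block of `group_size_m` rows belongs to a single expert):

import torch

from axolotl.kernels.moe import ContiguousGroupedGEMM

num_experts, group_size_m, K, N = 4, 128, 256, 512
M_total = num_experts * group_size_m
x = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(num_experts, N, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Row i is processed by expert indices[i]; each group_size_m block maps to one expert.
indices = torch.arange(num_experts, device="cuda", dtype=torch.int32)
indices = indices.repeat_interleave(group_size_m)

out = ContiguousGroupedGEMM.apply(x, w, indices, group_size_m)  # (M_total, N)
out.sum().backward()
print(x.grad.shape, w.grad.shape)  # (M_total, K) and (num_experts, N, K)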


@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Vendored forward Triton contiguous grouped GEMM kernels."""
import torch
import triton
import triton.language as tl
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * super_group_m
group_size_m = min(num_pid_m - first_pid_m, super_group_m)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
return pid_m, pid_n
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_persistent_forward(
a_ptr,
b_ptr,
c_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
NUM_SMS: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
SUPER_GROUP_M: tl.constexpr = 32,
):
"""
Contiguous Grouped GEMM kernel forward (persistent variant).
"""
c_type = c_ptr.dtype.element_ty
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles = num_pid_m * num_pid_n
tile_id_c = start_pid - NUM_SMS
num_pid_in_group = SUPER_GROUP_M * num_pid_n
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
tile_m_idx, tile_n_idx = _compute_pid(
tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
)
m_start = tile_m_idx * BLOCK_SIZE_M
n_start = tile_n_idx * BLOCK_SIZE_N
if m_start < M_TOTAL:
offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for ki in range(k_tiles):
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
mask_k = offs_k < K
mask_a = mask_m[:, None] & mask_k[None, :]
mask_b = mask_n[:, None] & mask_k[None, :]
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
b_ptrs = (
b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
)
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
accumulator += tl.dot(a, b.T)
tile_id_c += NUM_SMS
tile_m_idx, tile_n_idx = _compute_pid(
tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
)
offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
mask_c = mask_m[:, None] & mask_n[None, :]
c = accumulator.to(tl.float32)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
tl.store(c_ptrs, c.to(c_type), mask=mask_c)
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_forward_aligned(
a_ptr,
b_ptr,
c_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
"""
Contiguous Grouped GEMM kernel forward for aligned inputs.
"""
pid = tl.program_id(0)
c_type = c_ptr.dtype.element_ty
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
tile_m = pid // num_n_tiles
tile_n = pid % num_n_tiles
m_start = tile_m * BLOCK_SIZE_M
n_start = tile_n * BLOCK_SIZE_N
if m_start < M_TOTAL:
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
offs_k = tl.arange(0, BLOCK_SIZE_K) + k
mask_k = offs_k < K
mask_a = mask_m[:, None] & mask_k[None, :]
mask_b = mask_n[:, None] & mask_k[None, :]
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
acc += tl.dot(a, b.T)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask_c = mask_m[:, None] & mask_n[None, :]
tl.store(c_ptrs, acc.to(c_type), mask=mask_c)
def cg_grouped_gemm_forward(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Contiguous grouped GEMM forward pass for MoE."""
assert inputs.is_contiguous(), "Input tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, K = inputs.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
num_experts, N, K_weights = expert_weights.shape
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
assert expert_indices.shape[0] == M_total, (
"Expert indices length must match M_total"
)
output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = (NUM_SMS, 1, 1)
_kernel_cg_persistent_forward[grid](
inputs,
expert_weights,
output,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
NUM_SMS=NUM_SMS,
)
return output
def cg_grouped_gemm_forward_dynamic(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Contiguous grouped GEMM forward pass for MoE with autotuned launch."""
assert inputs.is_contiguous(), "Input tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, K = inputs.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
num_experts, N, K_weights = expert_weights.shape
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
assert expert_indices.shape[0] == M_total, (
"Expert indices length must match M_total"
)
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
grid = lambda meta: (
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
* triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
_kernel_cg_forward_aligned[grid](
inputs,
expert_weights,
output,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
)
return output
class ContiguousGroupedGEMM(torch.autograd.Function):
"""Autograd function for contiguous grouped GEMM forward pass only."""
@staticmethod
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
return cg_grouped_gemm_forward(
inputs=inputs,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
@staticmethod
def backward(ctx, grad_output): # pragma: no cover - not implemented
raise NotImplementedError("Backward pass not implemented")
def cg_grouped_gemm(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Convenience wrapper for the forward-only autograd function."""
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
return ContiguousGroupedGEMM.apply(
inputs, expert_weights, expert_indices, group_size_m
)


@@ -0,0 +1,31 @@
"""Reference implementation for contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
def pytorch_reference(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = 128,
) -> torch.Tensor:
"""Simple PyTorch implementation for verification."""
M_total, K = inputs.shape
num_experts, N, _ = expert_weights.shape
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
for i in range(0, M_total, group_size_m):
end_idx = min(i + group_size_m, M_total)
expert_idx = expert_indices[i].item()
expert_weight = expert_weights[expert_idx]
output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T)
return output
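A quick numerical check against this reference (a sketch; `pytorch_reference` is the function defined above, `cg_grouped_gemm_forward` comes from the vendored package, and a CUDA device is assumed):

import torch

from axolotl.kernels.moe import cg_grouped_gemm_forward

torch.manual_seed(0)
num_experts, group_size_m, K, N = 4, 128, 64, 96
M_total = num_experts * group_size_m
inputs = torch.randn(M_total, K, device="cuda", dtype=torch.bfloat16)
weights = torch.randn(num_experts, N, K, device="cuda", dtype=torch.bfloat16)
indices = torch.arange(num_experts, device="cuda", dtype=torch.int32)
indices = indices.repeat_interleave(group_size_m)

triton_out = cg_grouped_gemm_forward(inputs, weights, indices, group_size_m)
ref_out = pytorch_reference(inputs, weights, indices, group_size_m)
print((triton_out - ref_out).abs().max().item())  # expect a small bf16-level difference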


@@ -0,0 +1,209 @@
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
import torch
import triton
import triton.language as tl
from triton.runtime import driver
class CudaUtils:
"""Helper utilities for CUDA specific Triton features."""
@staticmethod
def is_cuda() -> bool:
return driver.active.get_current_target().backend == "cuda"
@staticmethod
def verify_tma() -> bool:
return (
CudaUtils.is_cuda()
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
)
@staticmethod
def get_num_sms() -> int:
if not CudaUtils.is_cuda():
raise RuntimeError("Triton is not running on CUDA backend")
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
return torch.cuda.get_device_properties("cuda").multi_processor_count
class TmaDescriptorHelper:
"""Helper class for managing TMA descriptors in Triton kernels."""
class KernelParamWrapper:
def __init__(self, desc: torch.Tensor):
self.desc = desc
def tma_desc_cpu_ptr(self) -> int:
return self.desc.data_ptr()
def __init__(self, tma_size: int = 128):
if not CudaUtils.verify_tma():
raise RuntimeError(
"TMA not supported on this device (requires Hopper or newer)"
)
if "nv_tma_desc_type" not in dir(tl):
raise RuntimeError(
"TMA grid constant descriptors not supported in your Triton version"
)
self.tma_size = tma_size
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
self.descriptors: Dict[str, torch.Tensor] = {}
def init_tma_descriptor(self, name: str) -> None:
self.descriptors[name] = torch.empty(
self.tma_size, device="cpu", dtype=torch.int8
)
def fill_1d_tma_descriptor(
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
) -> None:
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
def fill_2d_tma_descriptor(
self,
name: str,
ptr: int,
dim1: int,
dim0: int,
block_dim1: int,
block_dim0: int,
element_size: int,
) -> None:
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
def get_tma_descriptor_kernel_param(
self, name: str
) -> "TmaDescriptorHelper.KernelParamWrapper":
if name not in self.descriptors or self.descriptors[name] is None:
raise ValueError(f"TMA descriptor '{name}' not initialized")
return self.KernelParamWrapper(self.descriptors[name])
HOPPER_CONFIGS = [
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
]
STANDARD_CONFIGS = [
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
]
def early_config_prune(configs, args, **kwargs):
"""Filter out configurations that would exceed shared memory capacity."""
k = kwargs.get("K", 0)
valid_configs = [
config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k
]
if not valid_configs and configs:
return [
min(
configs,
key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")),
)
]
return valid_configs
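For example, when the problem's K is smaller than every configured BLOCK_SIZE_K, the fallback keeps the single smallest-K config rather than returning an empty candidate list (a sketch using the names defined above):

pruned = early_config_prune(STANDARD_CONFIGS, {}, K=16)
print(len(pruned))                       # 1
print(pruned[0].kwargs["BLOCK_SIZE_K"])  # 32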


@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M
__all__ = [
"grouped_gemm_forward",
"ALIGN_SIZE_M",
]


@@ -0,0 +1,761 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - flat index forward kernel is derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import logging
from typing import Tuple
import torch
import triton
import triton.language as tl
from .tma_autotuning import (
_NV_CONFIGS,
CudaUtils,
early_config_prune,
)
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
_allocator_registered = False
def _torch_allocator(size: int, alignment: int, stream) -> torch.Tensor:
return torch.empty(size, device="cuda", dtype=torch.int8)
def _ensure_triton_allocator() -> None:
global _allocator_registered
if not _allocator_registered:
triton.set_allocator(_torch_allocator)
_allocator_registered = True
# ============== Start Triton Kernels ===============
@triton.autotune(
configs=_NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
a_ptr,
b_ptr,
c_ptr,
m_sizes,
M_TOTAL,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
# config
NUM_SMS: tl.constexpr,
USE_EPILOGUE_SUBTILING: tl.constexpr,
# tiles
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
"""Flat index style forward kernel for Hopper using tensor descriptors."""
tbidx = tl.program_id(0)
c_dtype = c_ptr.dtype.element_ty
n_size = N // G
a_desc = tl.make_tensor_descriptor(
a_ptr,
shape=[M_TOTAL, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
b_desc = tl.make_tensor_descriptor(
b_ptr,
shape=[N, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
M_end = tl.full([], 0, dtype=tl.int32)
processed_tiles = 0
for g in range(G):
M_start = M_end
m_size = tl.load(m_sizes + g)
M_end = M_start + m_size
if m_size > 0:
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
group_num_tiles = num_m_tiles * num_n_tiles
while (
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
):
group_index = tbidx - processed_tiles
tile_m_index = group_index % num_m_tiles
tile_n_index = group_index // num_m_tiles
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
cols_remaining = n_size - tile_n_index * BLOCK_SIZE_N
col_mask = tl.arange(0, BLOCK_SIZE_N) < cols_remaining
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
global_n_offset = (g * n_size + n_offset).to(tl.int32)
for k_offset in range(0, K, BLOCK_SIZE_K):
k_remaining = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
a = a_desc.load([m_offset, k_offset])
a_mask = row_mask[:, None] & k_mask[None, :]
a = tl.where(a_mask, a, tl.zeros_like(a))
b = b_desc.load([global_n_offset, k_offset])
b_mask = col_mask[:, None] & k_mask[None, :]
b = tl.where(b_mask, b, tl.zeros_like(b))
accumulator += tl.dot(a, b.T)
local_m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
local_row_offsets = local_m_offset + tl.arange(0, BLOCK_SIZE_M)
row_store_mask = local_row_offsets < m_size
global_row = (M_start + local_row_offsets).to(tl.int32)
local_col_offsets = tile_n_index * BLOCK_SIZE_N + tl.arange(
0, BLOCK_SIZE_N
)
col_store_mask = local_col_offsets < n_size
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
if USE_EPILOGUE_SUBTILING:
acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
acc = tl.permute(acc, (0, 2, 1))
acc0, acc1 = tl.split(acc)
col_offsets0 = local_col_offsets[: BLOCK_SIZE_N // 2]
col_mask0 = col_store_mask[: BLOCK_SIZE_N // 2]
ptr0 = c_ptr + global_row[:, None] * n_size + col_offsets0[None, :]
tl.store(
ptr0,
acc0.to(c_dtype),
mask=row_store_mask[:, None] & col_mask0[None, :],
)
col_offsets1 = local_col_offsets[BLOCK_SIZE_N // 2 :]
col_mask1 = col_store_mask[BLOCK_SIZE_N // 2 :]
ptr1 = c_ptr + global_row[:, None] * n_size + col_offsets1[None, :]
tl.store(
ptr1,
acc1.to(c_dtype),
mask=row_store_mask[:, None] & col_mask1[None, :],
)
else:
ptr = (
c_ptr
+ global_row[:, None] * n_size
+ local_col_offsets[None, :]
)
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
tbidx += NUM_SMS
processed_tiles += group_num_tiles
"""
Backward pass for grouped GEMM with Triton, where grouping is M*G
We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
"""
# ---- dx flat linear indexed ----
@triton.autotune(
configs=_NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
grad_output_ptr,
w_ptr,
grad_input_ptr,
m_sizes,
M_TOTAL,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
# config
NUM_SMS: tl.constexpr,
# tiles
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
) -> None:
"""Compute grad_input = grad_output @ w using tensor descriptors."""
tbidx = tl.program_id(0)
c_dtype = grad_input_ptr.dtype.element_ty
grad_output_desc = tl.make_tensor_descriptor(
grad_output_ptr,
shape=[M_TOTAL, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
w_desc = tl.make_tensor_descriptor(
w_ptr,
shape=[N, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
)
M_end = tl.full([], 0, dtype=tl.int32)
processed_tiles = 0
for g in range(G):
M_start = M_end
m_size = tl.load(m_sizes + g)
M_end = M_start + m_size
if m_size > 0:
num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
group_num_tiles = num_m_tiles * num_k_tiles
while (
tbidx >= processed_tiles and tbidx < processed_tiles + group_num_tiles
):
group_index = tbidx - processed_tiles
tile_m_index = group_index % num_m_tiles
tile_k_index = group_index // num_m_tiles
rows_remaining = m_size - tile_m_index * BLOCK_SIZE_M
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
k_offset = tile_k_index * BLOCK_SIZE_K
k_remaining_total = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining_total
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
m_offset = (M_start + tile_m_index * BLOCK_SIZE_M).to(tl.int32)
for n_offset in range(0, N, BLOCK_SIZE_N):
n_remaining = N - n_offset
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
grad_y = grad_output_desc.load([m_offset, n_offset])
grad_y_mask = row_mask[:, None] & n_mask[None, :]
grad_y = tl.where(grad_y_mask, grad_y, tl.zeros_like(grad_y))
w_tile = w_desc.load([n_offset, k_offset])
w_mask = n_mask[:, None] & k_mask[None, :]
w_tile = tl.where(w_mask, w_tile, tl.zeros_like(w_tile))
accumulator += tl.dot(grad_y, w_tile)
local_row_offsets = tile_m_index * BLOCK_SIZE_M + tl.arange(
0, BLOCK_SIZE_M
)
row_store_mask = local_row_offsets < m_size
global_row = (M_start + local_row_offsets).to(tl.int32)
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
col_store_mask = col_offsets < K
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
ptr = grad_input_ptr + global_row[:, None] * K + col_offsets[None, :]
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
tbidx += NUM_SMS
processed_tiles += group_num_tiles
@triton.autotune(
configs=_NV_CONFIGS,
key=["G", "M_BUCKET", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
x_ptr,
grad_output_ptr,
grad_weight_ptr,
m_sizes,
M_TOTAL,
# problem sizes
G: tl.constexpr,
M_BUCKET: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
# config
NUM_SMS: tl.constexpr,
# tiles
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
) -> None:
"""Compute grad_weight = grad_output.T @ x using tensor descriptors."""
tbidx = tl.program_id(0)
c_dtype = grad_weight_ptr.dtype.element_ty
x_desc = tl.make_tensor_descriptor(
x_ptr,
shape=[M_TOTAL, K],
strides=[K, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
)
grad_output_desc = tl.make_tensor_descriptor(
grad_output_ptr,
shape=[M_TOTAL, N],
strides=[N, 1],
block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
)
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
total_tiles = num_n_tiles * num_k_tiles
for tile_idx in range(tbidx, total_tiles, NUM_SMS):
tile_n_idx = tile_idx % num_n_tiles
tile_k_idx = tile_idx // num_n_tiles
n_offset = tile_n_idx * BLOCK_SIZE_N
n_remaining = N - n_offset
n_mask = tl.arange(0, BLOCK_SIZE_N) < n_remaining
k_offset = tile_k_idx * BLOCK_SIZE_K
k_remaining = K - k_offset
k_mask = tl.arange(0, BLOCK_SIZE_K) < k_remaining
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)
M_end = tl.full([], 0, dtype=tl.int32)
for g in range(G):
M_start = M_end
m_size = tl.load(m_sizes + g)
M_end = M_start + m_size
if m_size > 0:
for m_offset_local in range(0, m_size, BLOCK_SIZE_M):
rows_remaining = m_size - m_offset_local
rows_remaining = tl.maximum(rows_remaining, 0)
row_mask = tl.arange(0, BLOCK_SIZE_M) < rows_remaining
m_offset = (M_start + m_offset_local).to(tl.int32)
x_block = x_desc.load([m_offset, k_offset])
x_mask = row_mask[:, None] & k_mask[None, :]
x_block = tl.where(x_mask, x_block, tl.zeros_like(x_block))
grad_block = grad_output_desc.load([m_offset, n_offset])
grad_mask = row_mask[:, None] & n_mask[None, :]
grad_block = tl.where(
grad_mask, grad_block, tl.zeros_like(grad_block)
)
contribution = tl.dot(
grad_block.to(tl.float32).T,
x_block.to(tl.float32),
)
accumulator += contribution
row_offsets = n_offset + tl.arange(0, BLOCK_SIZE_N)
row_store_mask = row_offsets < N
col_offsets = k_offset + tl.arange(0, BLOCK_SIZE_K)
col_store_mask = col_offsets < K
store_mask = row_store_mask[:, None] & col_store_mask[None, :]
ptr = grad_weight_ptr + row_offsets[:, None] * K + col_offsets[None, :]
tl.store(ptr, accumulator.to(c_dtype), mask=store_mask)
# ======== End Triton kernels ========
# ======== Triton wrapper functions ========
# ----- main forward pass wrapper -----
def grouped_gemm_forward(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
tma_size: int = 128,
using_fp8: bool = False,
) -> torch.Tensor:
"""Grouped GEMM forward using Hopper TMA kernels."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
if using_fp8:
raise NotImplementedError(
"FP8 path not implemented with the new Triton API yet"
)
G = m_sizes.shape[0]
assert x.is_contiguous()
assert w.is_contiguous()
assert m_sizes.is_contiguous()
M_total, K = x.shape
N = w.shape[0]
assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"
y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)
if M_total == 0:
return y
NUM_SMS = CudaUtils.get_num_sms()
USE_EPILOGUE_SUBTILING = False
def grid(_meta):
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M_total)
_kernel_mg_forward_hopper[grid](
x,
w,
y,
m_sizes,
M_total,
G,
M_BUCKET,
N,
K,
NUM_SMS,
USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
)
return y
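# The wrapper expects x grouped by expert along dim 0, w as the per-group weight
# matrices concatenated along dim 0 (shape [G * n_out, K]), and m_sizes holding one
# row count per group. A minimal, hedged usage sketch with illustrative shapes
# (assumes a Hopper GPU; these names and values are not used elsewhere in this module):
def _example_grouped_gemm_forward() -> torch.Tensor:
    g, k, n_out = 4, 256, 512
    m_sizes = torch.tensor([128, 64, 0, 192], device="cuda", dtype=torch.int32)
    x = torch.randn(int(m_sizes.sum().item()), k, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(g * n_out, k, device="cuda", dtype=torch.bfloat16)
    # The output has one row per input row and n_out (= N // G) columns.
    return grouped_gemm_forward(x, w, m_sizes)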
# ======== Improved Backward =============
def grouped_gemm_backward(
grad_output: torch.Tensor,
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
use_tma: bool = True,
tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Unified backward pass for grouped GEMM with M*G grouping.
Uses optimized TMA-based implementations for both dx and dw when available.
Args:
grad_output: Gradient of output, shape [M_total, N]
x: Input tensor from forward pass, shape [M_total, K]
w: Weight tensor from forward pass, shape [N, K]
m_sizes: Group sizes tensor, shape [G]
use_tma: Whether to try using TMA acceleration (if available)
tma_size: Size of TMA descriptor in bytes
Returns:
Tuple of gradients with respect to x and w: (grad_x, grad_w)
"""
logging.info("Starting unified grouped_gemm_backward")
# do this once, seems expensive
NUM_SMS = CudaUtils.get_num_sms()
# Basic validation
M_total, K_x = x.shape
M_grad, N = grad_output.shape
N_w, K_w = w.shape
# Check dimensions
if K_x != K_w:
raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
if M_total != M_grad:
raise ValueError(
f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
)
# Check total M matches sum of group sizes
sum_m_sizes = m_sizes.sum().item()
if M_total != sum_m_sizes:
raise ValueError(
f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
)
# Make sure inputs are contiguous
grad_output = grad_output.contiguous()
x = x.contiguous()
w = w.contiguous()
m_sizes = m_sizes.contiguous()
# Check TMA support
if use_tma and not CudaUtils.verify_tma():
logging.info("TMA requested but not supported on this device")
use_tma = False
# Compute grad_x using flat linear implementation
try:
logging.info("Computing grad_x with flat linear kernel")
# Use TMA-optimized implementation
grad_x = grouped_gemm_dx_tma(
grad_output=grad_output,
w=w,
m_sizes=m_sizes,
num_sms=NUM_SMS,
tma_size=tma_size,
)
except Exception as e:
logging.error(f"Error in grad_x computation: {e}")
raise
# Compute grad_w using flat linear style implementation
try:
logging.info("Computing grad_w with flat linear kernel")
grad_w = grouped_gemm_dw_tma(
x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
)
except Exception as e:
logging.error(f"Error in grad_w computation: {e}")
raise
return grad_x, grad_w
# ----- dx backward pass wrapper -----
def grouped_gemm_dx_tma(
grad_output: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
num_sms: int = 132,
tma_size: int = 128,
) -> torch.Tensor:
"""Compute grad_x using the Hopper grouped GEMM kernel."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise NotImplementedError("Optimized dx computation requires TMA support")
grad_output = grad_output.contiguous()
w = w.contiguous()
m_sizes = m_sizes.contiguous()
M_total, N = grad_output.shape
N_w, K = w.shape
if N != N_w:
raise ValueError(f"Grad_output N ({N}) must match weight N ({N_w})")
if m_sizes.sum().item() != M_total:
raise ValueError("Sum of m_sizes must equal the number of rows in grad_output")
grad_x = torch.empty(
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
)
NUM_SMS = num_sms
def grid(_meta):
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M_total)
_kernel_mg_dx_tma[grid](
grad_output,
w,
grad_x,
m_sizes,
M_total,
m_sizes.shape[0],
M_BUCKET,
N,
K,
NUM_SMS,
)
return grad_x
# ======== dw wrapper function ==========
def grouped_gemm_dw_tma(
x: torch.Tensor,
grad_output: torch.Tensor,
m_sizes: torch.Tensor,
num_sms: int = 132,
tma_size: int = 128,
) -> torch.Tensor:
"""Compute grad_w using the Hopper grouped GEMM kernel."""
_ensure_triton_allocator()
if not CudaUtils.verify_tma():
raise RuntimeError("TMA grouped GEMM requested on a device without TMA support")
x = x.contiguous()
grad_output = grad_output.contiguous()
m_sizes = m_sizes.contiguous()
M_total, K = x.shape
M_grad, N = grad_output.shape
if M_total != M_grad:
raise ValueError("x and grad_output must have matching batch dimension")
if m_sizes.sum().item() != M_total:
raise ValueError("Sum of m_sizes must equal the number of rows in the inputs")
grad_w = torch.zeros((N, K), device=x.device, dtype=x.dtype)
NUM_SMS = num_sms
def grid(_meta):
return (NUM_SMS,)
M_BUCKET = triton.next_power_of_2(M_total)
_kernel_mg_dw_tma[grid](
x,
grad_output,
grad_w,
m_sizes,
M_total,
m_sizes.shape[0],
M_BUCKET,
N,
K,
NUM_SMS,
)
return grad_w
# ======== End Backwards Wrapper Functions =============
# ======== PyTorch wrapper functions ========
class GroupedGemmMg(torch.autograd.Function):
"""
Autograd function for GroupedGEMM with M*G grouping.
FP8 quantized operation is not yet implemented; only standard-precision grouped GEMM is currently supported.
"""
@staticmethod
def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
"""
Forward pass of GroupedGEMM.
Args:
x: Input tensor, shape [M_total, K]
w: Weight tensor, shape [N, K]
m_sizes: Tensor of shape [G] containing the size of each group
use_tma: Whether to try using TMA acceleration (if available)
tma_size: Size of TMA descriptor in bytes
using_fp8: Whether to use FP8 quantization (not yet implemented)
Returns:
Output tensor, shape [M_total, N // G]
"""
# Use regular forward without quantization
output = grouped_gemm_forward(
x=x, w=w, m_sizes=m_sizes, tma_size=tma_size, using_fp8=False
)
# Save inputs and parameters for backward pass
ctx.save_for_backward(x, w, m_sizes)
ctx.use_tma = use_tma
ctx.tma_size = tma_size
return output
@staticmethod
def backward(ctx, grad_output):
"""
Backward pass of M*G GroupedGEMM.
Args:
grad_output: Gradient of output, shape [M_total, N]
Returns:
Tuple of gradients:
- grad_x: Gradient with respect to x, shape [M_total, K]
- grad_w: Gradient with respect to w, shape [N, K]
- None: Gradient with respect to m_sizes (not differentiable)
- None: Gradient with respect to use_tma (not differentiable)
- None: Gradient with respect to tma_size (not differentiable)
- None: Gradient with respect to using_fp8 (not differentiable)
"""
# Retrieve saved tensors and parameters
x, w, m_sizes = ctx.saved_tensors
use_tma = ctx.use_tma
tma_size = ctx.tma_size
# Compute gradients using the unified implementation
grad_x, grad_w = grouped_gemm_backward(
grad_output=grad_output,
x=x,
w=w,
m_sizes=m_sizes,
use_tma=use_tma,
tma_size=tma_size,
)
# Return gradients for all inputs (None for non-differentiable parameters)
return grad_x, grad_w, None, None, None, None
def mg_grouped_gemm(
x: torch.Tensor,
w: torch.Tensor,
m_sizes: torch.Tensor,
use_tma: bool = True,
tma_size: int = 128,
using_fp8: bool = False,
) -> torch.Tensor:
"""
Unified differentiable grouped GEMM operation for M*G grouped GEMM.
FP8 quantized operation is not yet implemented; only standard-precision grouped GEMM is currently supported.
Args:
x: Input tensor, shape [M_total, K]
w: Weight tensor, shape [N, K]
m_sizes: Tensor of shape [G] containing the size of each group
use_tma: Whether to try using TMA acceleration (if available)
tma_size: Size of TMA descriptor in bytes
using_fp8: Whether to use FP8 quantization (not yet implemented)
Returns:
Output tensor, shape [M_total, N // G]
"""
return GroupedGemmMg.apply(x, w, m_sizes, use_tma, tma_size, using_fp8)
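# A hedged end-to-end sketch of the differentiable wrapper (illustrative shapes,
# single group, TMA-capable GPU assumed); gradients flow through GroupedGemmMg's
# backward and the dx/dw kernels above:
def _example_mg_grouped_gemm() -> torch.Tensor:
    k, n_out = 128, 256
    m_sizes = torch.tensor([256], device="cuda", dtype=torch.int32)
    x = torch.randn(256, k, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    w = torch.randn(n_out, k, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    out = mg_grouped_gemm(x, w, m_sizes)  # shape [256, n_out]
    out.sum().backward()  # populates x.grad ([256, k]) and w.grad ([n_out, k])
    return out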

View File

@@ -0,0 +1,232 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit: the TMA descriptor helper class and autotuning utilities are derived from FBGEMM:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import os
import sys
from typing import Dict
import torch
import triton
from triton.runtime import driver # @manual
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# ===== Supporting utils, CUDA and TMA =====
class CudaUtils:
@staticmethod
def is_cuda() -> bool:
"""Check if Triton is running on CUDA backend."""
return driver.active.get_current_target().backend == "cuda"
@staticmethod
def verify_tma() -> bool:
"""Check if TMA is supported on the current device."""
return (
CudaUtils.is_cuda()
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
)
@staticmethod
def get_num_sms() -> int:
"""Get the number of streaming multiprocessors on the current device."""
if not CudaUtils.is_cuda():
raise RuntimeError("Triton is not running on CUDA backend")
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
return torch.cuda.get_device_properties("cuda").multi_processor_count
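# A small, hedged sketch of how these utilities gate the kernels in this package:
# verify TMA support first (Hopper / SM90 or newer), then size the persistent grid
# by the SM count, since each kernel launches one program per SM and strides over tiles.
def _example_device_gate() -> int:
    if not CudaUtils.verify_tma():
        raise RuntimeError("Grouped GEMM TMA kernels require Hopper (SM90) or newer")
    return CudaUtils.get_num_sms()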
class TmaDescriptorHelper:
"""Helper class for managing TMA descriptors in Triton kernels.
Args:
tma_size: Size of the TMA descriptor in bytes
"""
class KernelParamWrapper:
"""Wrapper to implement the TmaDescKernelParam interface."""
def __init__(self, desc: torch.Tensor):
self.desc = desc
def tma_desc_cpu_ptr(self) -> int:
"""Return the CPU pointer to the TMA descriptor."""
return self.desc.data_ptr()
def __init__(self, tma_size: int = 128):
if not CudaUtils.verify_tma():
raise RuntimeError(
"TMA not supported on this device (requires Hopper or newer)"
)
self.tma_size = tma_size
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_tma_descriptor
self.descriptors: Dict[str, torch.Tensor] = {}
def init_tma_descriptor(self, name: str) -> None:
"""Initialize a TMA descriptor with the given name.
Call this method outside of the lambda function for grid size.
"""
self.descriptors[name] = torch.empty(
self.tma_size, device="cpu", dtype=torch.int8
)
def fill_1d_tma_descriptor(
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
) -> None:
"""Fill a 1D TMA descriptor.
Call this method inside the lambda function for grid size.
"""
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
def fill_2d_tma_descriptor(
self,
name: str,
ptr: int,
dim1: int,
dim0: int,
block_dim1: int,
block_dim0: int,
element_size: int,
) -> None:
"""Fill a 2D TMA descriptor.
Call this method inside the lambda function for grid size.
"""
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
"""Get the TMA descriptor kernel parameter for the given name."""
if name not in self.descriptors or self.descriptors[name] is None:
raise ValueError(f"TMA descriptor '{name}' not initialized")
return self.KernelParamWrapper(self.descriptors[name])
# ====== Autotuning utilities ======
ALIGN_SIZE_M = 128
_NV_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
},
num_stages=num_stages,
num_warps=num_warps,
num_ctas=num_ctas,
)
for block_size_m in [
ALIGN_SIZE_M,
]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
]
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
device = torch.cuda.current_device()
# Check for all possible pointer parameter names
if "grad_input_ptr" in named_args:
ptr_name = "grad_input_ptr"
elif "c_ptr" in named_args:
ptr_name = "c_ptr"
elif "grad_weight_ptr" in named_args:
ptr_name = "grad_weight_ptr"
else:
raise KeyError("No recognized pointer parameter found in kernel arguments")
if dtsize is None:
dtsize = named_args[ptr_name].element_size()
if dtype is None:
dtype = named_args[ptr_name].dtype
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
kw["BLOCK_SIZE_M"],
kw["BLOCK_SIZE_N"],
kw["BLOCK_SIZE_K"],
config.num_stages,
)
G, M, N, K = (
named_args["G"],
named_args["M_BUCKET"],
named_args["N"],
named_args["K"],
)
# 1. make sure we have enough smem
max_shared_memory = driver.active.utils.get_device_properties(device)[
"max_shared_mem"
]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory > max_shared_memory:
continue
M_PER_GROUP = M // G
MIN_M_TILES = 64
# 2. make sure we don't load M tiles that are too big
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
continue
# 3. make sure we don't load M tiles that are too small
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
continue
num_sm = driver.active.utils.get_device_properties(device)[
"multiprocessor_count"
]
N_TILES = N // BLOCK_N
MIN_N_TILES = 64
# 4. make sure we don't load N tiles that are too big
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
continue
# 5. make sure we don't load N tiles that are too small
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
continue
# 6. make sure K can be evenly divided
if K % BLOCK_K != 0:
continue
pruned_configs.append(config)
return pruned_configs
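# Worked example of prune rule 1 (shared memory), using the same estimate as above:
# for a bf16 kernel (dtsize=2) with BLOCK_SIZE_M=128, BLOCK_SIZE_N=128,
# BLOCK_SIZE_K=128 and num_stages=4, the requirement is
# (128 + 128) * 128 * 4 * 2 = 262144 bytes, which exceeds the roughly 228 KB of
# shared memory per SM on H100, so that candidate config would be pruned.
def _example_smem_estimate(
    block_m: int = 128,
    block_n: int = 128,
    block_k: int = 128,
    num_stages: int = 4,
    dtsize: int = 2,
) -> int:
    return (block_m + block_n) * block_k * num_stages * dtsize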
# ======== End Autotuning utilities ========

View File

@@ -190,6 +190,15 @@ class PatchManager:
apply_mistral_tokenizer_image_patch()
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
LOG.info(
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
)
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:

View File

@@ -0,0 +1,401 @@
"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels."""
from __future__ import annotations
from typing import Callable
import torch
from axolotl.kernels.moe import ContiguousGroupedGEMM
from axolotl.kernels.moe.indices import generate_permute_indices
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
from axolotl.utils.logging import get_logger
_GROUP_SIZE_M = 128
_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
LOG = get_logger(__name__)
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
if not hidden_states.is_cuda or hidden_states.shape[0] == 0:
return False
major, _ = torch.cuda.get_device_capability(hidden_states.device)
if major < 9:
LOG.debug(
"Skipping Triton MoE kernels: requires compute capability >= 90, found %s",
major,
)
return False
return True
def _ensure_combined_expert_weights(
module, dtype: torch.dtype, device: torch.device, backend: str
) -> None:
if not hasattr(module, "_axolotl_original_specs"):
module._axolotl_original_specs = {}
if not hasattr(module, "_axolotl_mg_shapes"):
module._axolotl_mg_shapes = {}
prev_backend = getattr(module, "_axolotl_combined_backend", None)
if getattr(module, "_axolotl_combined_weights", False):
if prev_backend != backend:
_restore_expert_weights(module)
else:
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
param = module.get_parameter(param_name)
if param.device != device or param.dtype != dtype:
module._parameters[param_name] = torch.nn.Parameter(
param.to(device=device, dtype=dtype).contiguous()
)
module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device
module._axolotl_combined_backend = backend
return
module._axolotl_mg_shapes = {}
for name in _COMBINED_SUBMODULES:
weights = []
orig_device = None
orig_dtype = None
orig_shape = None
for expert in module.experts:
lin = expert.get_submodule(name)
weight_param = lin._parameters.get("weight")
if weight_param is None:
raise RuntimeError("Expected expert linear layers to have weights")
if orig_device is None:
orig_device = weight_param.device
orig_dtype = weight_param.dtype
orig_shape = tuple(weight_param.shape)
weights.append(weight_param.detach().to(device=device, dtype=dtype))
if "weight" in lin._parameters:
del lin._parameters["weight"]
if "bias" in lin._parameters:
del lin._parameters["bias"]
if backend == "cg":
combined_weight = torch.stack(weights, dim=0).contiguous()
else:
combined_weight = torch.cat(weights, dim=0).contiguous()
module._axolotl_mg_shapes[name] = orig_shape
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)
module._axolotl_combined_weights = True
module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device
module._axolotl_combined_backend = backend
def _restore_expert_weights(module) -> None:
if not getattr(module, "_axolotl_combined_weights", False):
return
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
combined = module._parameters.pop(param_name)
orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
name, (combined.device, combined.dtype, None)
)
rows_per = orig_shape[0] if orig_shape else None
for idx, expert in enumerate(module.experts):
lin = expert.get_submodule(name)
if combined.dim() == 3:
slice_tensor = combined[idx]
elif rows_per is not None:
start = idx * rows_per
end = start + rows_per
slice_tensor = combined[start:end]
else:
raise RuntimeError(
"Unable to recover expert weight shape during restore"
)
lin._parameters["weight"] = torch.nn.Parameter(
slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
)
module._axolotl_combined_weights = False
module._axolotl_combined_dtype = None
module._axolotl_combined_device = None
module._axolotl_combined_backend = None
module._axolotl_original_specs = {}
module._axolotl_mg_shapes = {}
def _run_cg_grouped_gemm(
module,
grouped_hidden: torch.Tensor,
m_sizes: torch.Tensor,
num_experts: int,
group_size_m: int,
hidden_dtype: torch.dtype,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
_ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
expert_index_tensor = torch.repeat_interleave(
torch.arange(num_experts, device=device, dtype=torch.int32),
m_sizes.to(torch.int64),
)
gate_weights = module.get_parameter("gate_proj_weight")
if gate_weights.dim() == 2:
out_dim = gate_weights.shape[0] // num_experts
gate_weights = gate_weights.view(num_experts, out_dim, gate_weights.shape[1])
up_weights = module.get_parameter("up_proj_weight")
if up_weights.dim() == 2:
out_dim = up_weights.shape[0] // num_experts
up_weights = up_weights.view(num_experts, out_dim, up_weights.shape[1])
down_weights = module.get_parameter("down_proj_weight")
if down_weights.dim() == 2:
out_dim = down_weights.shape[0] // num_experts
down_weights = down_weights.view(num_experts, out_dim, down_weights.shape[1])
gate_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
gate_weights,
expert_index_tensor,
group_size_m,
)
up_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
up_weights,
expert_index_tensor,
group_size_m,
)
return (
gate_out.to(hidden_dtype),
up_out.to(hidden_dtype),
down_weights,
expert_index_tensor,
)
def _moe_triton_forward(
module,
hidden_states: torch.Tensor,
topk_indices: torch.Tensor,
topk_weights: torch.Tensor,
group_size_m: int,
backend: str,
fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
if not _is_triton_eligible(hidden_states):
return fallback(hidden_states, topk_indices, topk_weights)
device = hidden_states.device
hidden_dtype = hidden_states.dtype
num_tokens, hidden_dim = hidden_states.shape
top_k = topk_indices.size(-1)
expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0)
expert_assignments = topk_indices.reshape(-1)
if expanded_hidden.numel() == 0:
return torch.zeros_like(hidden_states)
sort_perm = torch.argsort(expert_assignments)
sorted_hidden = expanded_hidden.index_select(0, sort_perm)
sorted_assignments = expert_assignments.index_select(0, sort_perm)
num_experts = len(module.experts)
counts = torch.bincount(sorted_assignments, minlength=num_experts)
total_actual = int(counts.sum().item())
if total_actual == 0:
return torch.zeros_like(hidden_states)
if not getattr(module, "_axolotl_triton_logged", False):
min_tokens = int(counts.min().item())
max_tokens = int(counts.max().item())
LOG.info(
"DeepseekV3MoE Triton: tokens per expert (min=%s, max=%s, avg=%.1f) with group_size=%s",
min_tokens,
max_tokens,
total_actual / max(1, num_experts),
group_size_m,
)
module._axolotl_triton_logged = True
counts_int = counts.to(torch.int32)
aligned_counts = (
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
) * group_size_m
max_len = int(aligned_counts.sum().item())
permuted_indices, m_sizes, _ = generate_permute_indices(
counts_int.to(device),
experts_per_rank=num_experts,
num_ranks=1,
max_len=max_len,
alignment=group_size_m,
use_cpu=not hidden_states.is_cuda,
)
permuted_indices = permuted_indices.to(device)
m_sizes = m_sizes.to(device)
permuted_indices_long = permuted_indices.to(torch.int64)
valid_mask = permuted_indices_long >= 0
valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
source_indices = permuted_indices_long[valid_mask]
padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)
grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
if valid_positions.numel() > 0:
grouped_hidden.index_copy_(
0,
valid_positions,
sorted_hidden.index_select(0, source_indices),
)
if valid_positions.numel() < max_len:
grouped_hidden.index_fill_(0, padded_positions, 0)
m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)
if backend == "mg":
_ensure_combined_expert_weights(module, hidden_dtype, device, backend)
gate_out = mg_grouped_gemm(
grouped_hidden,
module.get_parameter("gate_proj_weight"),
m_sizes_tensor,
).to(hidden_dtype)
up_out = mg_grouped_gemm(
grouped_hidden,
module.get_parameter("up_proj_weight"),
m_sizes_tensor,
).to(hidden_dtype)
else:
gate_out, up_out, down_weights, expert_index_tensor = _run_cg_grouped_gemm(
module,
grouped_hidden,
m_sizes,
num_experts,
group_size_m,
hidden_dtype,
device,
)
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
if valid_positions.numel() > 0:
gate_valid = gate_out.index_select(0, valid_positions)
up_valid = up_out.index_select(0, valid_positions)
hidden_concat = act_fn(gate_valid) * up_valid
else:
hidden_concat = torch.empty(
(0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
)
intermediate_dim = hidden_concat.shape[-1]
hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
if valid_positions.numel() > 0:
hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
if valid_positions.numel() < max_len:
hidden_grouped.index_fill_(0, padded_positions, 0)
if backend == "mg":
down_out = mg_grouped_gemm(
hidden_grouped,
module.get_parameter("down_proj_weight"),
m_sizes_tensor,
).to(hidden_dtype)
else:
down_out = ContiguousGroupedGEMM.apply(
hidden_grouped,
down_weights,
expert_index_tensor,
group_size_m,
).to(hidden_dtype)
if valid_positions.numel() > 0:
down_valid = down_out.index_select(0, valid_positions)
else:
down_valid = torch.empty(
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
)
sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
if down_valid.numel() > 0:
sorted_outputs.index_copy_(0, source_indices, down_valid)
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
expanded_output.index_copy_(0, sort_perm, sorted_outputs)
expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
return weighted.sum(dim=1)
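# A hedged numeric sketch of the padding arithmetic used above: per-expert token
# counts are clamped up to at least group_size_m and then rounded up to a multiple
# of it, so every expert's rows form whole, block-aligned groups (padded rows are
# zero-filled before the grouped GEMMs). The counts below are illustrative only.
def _example_aligned_counts() -> torch.Tensor:
    counts = torch.tensor([5, 0, 130, 64], dtype=torch.int32)  # tokens per expert
    group_size_m = 128
    aligned = (
        (torch.clamp_min(counts, group_size_m) + group_size_m - 1) // group_size_m
    ) * group_size_m
    # aligned == tensor([128, 128, 256, 128]); max_len is its sum (640 padded rows)
    return aligned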
def patch_deepseek_v3_moe(
group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
) -> None:
"""Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
if backend not in {"cg", "mg"}:
raise ValueError(f"Unsupported MoE kernel backend: {backend}")
# Record the unpatched implementation so callers can access a true baseline even
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
DeepseekV3MoE._axolotl_triton_original_moe = DeepseekV3MoE.moe
if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
return
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
DeepseekV3MoE._axolotl_triton_backend = backend
DeepseekV3MoE._axolotl_group_size_m = group_size_m
def patched_moe(self, hidden_states, topk_indices, topk_weights):
backend_sel = getattr(self, "_axolotl_triton_backend", backend)
group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
if backend_sel == "cg" and group_size_sel != _GROUP_SIZE_M:
LOG.debug(
"Adjusting group_size_m=%s to %s for CG backend",
group_size_sel,
_GROUP_SIZE_M,
)
group_size_sel = _GROUP_SIZE_M
try:
return _moe_triton_forward(
self,
hidden_states,
topk_indices,
topk_weights,
group_size_sel,
backend_sel,
original_moe,
)
except Exception as err: # surface Triton failures explicitly
_restore_expert_weights(self)
LOG.error("DeepseekV3MoE Triton path failed: %s", err)
raise
DeepseekV3MoE.moe = patched_moe
DeepseekV3MoE._axolotl_triton_patch = True
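# A hedged usage sketch: in Axolotl this patch is applied by PatchManager when
# `moe_kernels: true` is set (see the PatchManager hunk above). Applying it directly
# looks like the following, with the patch installed before the model is built
# (the checkpoint name is illustrative):
def _example_apply_patch() -> None:
    patch_deepseek_v3_moe(backend="mg")  # "cg" selects the contiguous grouped GEMM
    # from transformers import AutoModelForCausalLM
    # model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3")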

View File

@@ -323,8 +323,8 @@ def apply_lora_kernel_patches(
AssertionError: If multiple adapters are active (currently unsupported).
Note:
The optimizations require LoRA adapters with no dropout. The function will skip
patching if that condition isn't met.
The optimizations require LoRA adapters with no dropout and no bias terms. The
function will skip patching if these conditions aren't met.
"""
if not isinstance(model, PeftModelForCausalLM):
raise TypeError("Model must be a PeftModelForCausalLM")
@@ -340,10 +340,10 @@ def apply_lora_kernel_patches(
lora_config = model.model.peft_config[active_adapter]
# Only patch if conditions are met
can_patch = lora_config.lora_dropout == 0
can_patch = lora_config.lora_dropout == 0 and lora_config.bias == "none"
if not can_patch:
LOG.warning("Cannot patch layers - requires `lora_dropout: 0`")
LOG.warning("Cannot patch layers - requires no dropout and no bias")
LOG.warning("Please specify `lora_dropout: 0` in your axolotl config file")
return model

View File

@@ -113,6 +113,19 @@ class AxolotlInputConfig(
},
)
moe_kernels: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
},
)
moe_kernel_backend: Literal["cg", "mg"] | None = Field(
default="mg",
json_schema_extra={
"description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
},
)
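    # A hedged sketch of how these two fields are typically set in an Axolotl YAML
    # config (values illustrative):
    #
    #   moe_kernels: true        # enable the vendored MoE kernels
    #   moe_kernel_backend: mg   # Hopper TMA backend; "cg" selects the contiguous kernel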
trainer_cls: str | None = Field(
default=None,
json_schema_extra={

View File

@@ -221,53 +221,44 @@ def test_model_specific_activation(model_name, expected_activation):
assert layer.mlp.forward.__func__ is expected_activation
def test_kernel_patch_requires_zero_dropout():
"""Kernel patching should be skipped when dropout is enabled."""
config = {
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0.1,
"bias": "none",
}
def test_kernel_patch_conditions():
"""Test various conditions that should prevent kernel patching."""
test_configs = [
# Dropout prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0.1,
"bias": "none",
},
# Bias prevents patching
{
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "lora_only",
},
]
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
for config in test_configs:
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Should not patch
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify no patches applied when dropout is non-zero
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
def test_kernel_patch_with_bias_enabled():
"""Kernel patching should succeed when LoRA bias is enabled."""
config = {
"peft_type": "LORA",
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 16,
"target_modules": ["gate_proj", "up_proj", "down_proj"],
"lora_dropout": 0,
"bias": "lora_only",
}
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
peft_config = get_peft_config(config)
model = PeftModelForCausalLM(model, peft_config)
cfg = DictDefault({"lora_mlp_kernel": True})
patched_model = apply_lora_kernel_patches(model, cfg)
layer = patched_model.model.model.layers[0].mlp
# Verify patches applied when bias support is enabled
assert layer.forward.__func__ is apply_lora_mlp_swiglu
# Verify no patches applied
assert layer.forward.__func__ is not apply_lora_mlp_swiglu
assert layer.forward.__func__ is not apply_lora_mlp_geglu
def test_kernel_config_options():