vendor torchtitan moe kernels

This commit is contained in:
Dan Saunders
2025-09-21 12:52:25 -04:00
parent f9748c4dc5
commit 95e607574a
9 changed files with 1251 additions and 0 deletions

View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels.
Example usage (run from project root):
PYTHONPATH=./src:../transformers/src \
python scripts/benchmarks/deepseek_v3_moe.py --device cuda --iters 20
"""
from __future__ import annotations
import argparse
import time
from types import MethodType
import torch
try:
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
DeepseekV3Config,
)
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
except ImportError as exc: # pragma: no cover - utility script
raise SystemExit(
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
) from exc
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
DTYPE_MAP = {
"bf16": torch.bfloat16,
"fp16": torch.float16,
"fp32": torch.float32,
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--batch", type=int, default=2, help="batch size")
parser.add_argument("--seq-len", type=int, default=256, help="sequence length")
parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
parser.add_argument(
"--moe-intermediate-size",
type=int,
default=8192,
help="MoE intermediate projection size",
)
parser.add_argument(
"--n-experts",
type=int,
default=64,
help="Number of routed experts",
)
parser.add_argument(
"--top-k",
type=int,
default=4,
help="Number of experts per token",
)
parser.add_argument(
"--groups",
type=int,
default=8,
help="Router groups (must divide n-experts)",
)
parser.add_argument(
"--dtype",
choices=DTYPE_MAP.keys(),
default="bf16",
help="Computation dtype",
)
parser.add_argument(
"--device",
default="auto",
choices=["auto", "cpu", "cuda"],
help="Execution device",
)
parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--group-size",
type=int,
default=128,
help="GROUP_SIZE_M used by the Triton kernel",
)
return parser.parse_args()
def resolve_device(requested: str) -> torch.device:
if requested == "auto":
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return torch.device(requested)
def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
config = DeepseekV3Config(
hidden_size=args.hidden_size,
intermediate_size=args.moe_intermediate_size,
moe_intermediate_size=args.moe_intermediate_size,
n_routed_experts=args.n_experts,
num_experts_per_tok=args.top_k,
n_group=args.groups,
topk_group=max(1, min(args.groups, args.top_k)),
n_shared_experts=1,
)
module = DeepseekV3MoE(config)
module.eval()
return module
@torch.no_grad()
def benchmark(
module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
) -> float:
for _ in range(warmup):
module(inputs)
if inputs.is_cuda:
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iters):
module(inputs)
if inputs.is_cuda:
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
return (elapsed / iters) * 1000.0
def main() -> None: # pragma: no cover - CLI entrypoint
args = parse_args()
torch.manual_seed(args.seed)
device = resolve_device(args.device)
dtype = DTYPE_MAP[args.dtype]
if args.n_experts % args.groups != 0:
raise SystemExit("n-experts must be divisible by groups")
if args.top_k > args.n_experts:
raise SystemExit("top-k cannot exceed number of experts")
if device.type == "cuda" and not torch.cuda.is_available():
raise SystemExit("CUDA requested but not available")
baseline_module = build_module(args)
original_moe = DeepseekV3MoE.moe
baseline_module.moe = MethodType(original_moe, baseline_module)
state_dict = baseline_module.state_dict()
patch_deepseek_v3_moe(group_size_m=args.group_size)
patched_module = build_module(args)
patched_module.load_state_dict(state_dict)
baseline_module.to(device=device, dtype=dtype)
patched_module.to(device=device, dtype=dtype)
inputs = torch.randn(
args.batch,
args.seq_len,
args.hidden_size,
device=device,
dtype=dtype,
)
with torch.no_grad():
ref_output = baseline_module(inputs)
patched_output = patched_module(inputs)
max_diff = (ref_output - patched_output).abs().max().item()
baseline_ms = benchmark(baseline_module, inputs, args.iters, args.warmup)
patched_ms = benchmark(patched_module, inputs, args.iters, args.warmup)
speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan")
print(
f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
)
print(
f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
)
print(f"Max |Δ| between outputs: {max_diff:.2e}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,17 @@
"""Mixture-of-Experts kernel implementations."""
from .tt_cg_gemm import (
ContiguousGroupedGEMM,
ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
]

View File

@@ -0,0 +1,17 @@
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""
from .cg_backward import ContiguousGroupedGEMM
from .cg_forward import (
ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
cg_grouped_gemm,
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
__all__ = [
"cg_grouped_gemm",
"cg_grouped_gemm_forward",
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
]

View File

@@ -0,0 +1,290 @@
"""Vendored backward pass for Triton contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
from .cg_forward import cg_grouped_gemm_forward
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_backward_dx(
grad_output_ptr,
b_ptr,
grad_input_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
"""Compute gradients with respect to inputs."""
pid = tl.program_id(0)
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
tile_m = pid // num_k_tiles
tile_k = pid % num_k_tiles
m_start = tile_m * BLOCK_SIZE_M
k_start = tile_k * BLOCK_SIZE_K
if m_start < M_TOTAL:
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
mask_m = offs_m < M_TOTAL
mask_k = offs_k < K
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)
for n in range(0, N, BLOCK_SIZE_N):
offs_n = tl.arange(0, BLOCK_SIZE_N) + n
mask_n = offs_n < N
mask_go = mask_m[:, None] & mask_n[None, :]
mask_w = mask_n[:, None] & mask_k[None, :]
go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
go = tl.load(go_ptrs, mask=mask_go, other=0.0)
w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
w = tl.load(w_ptrs, mask=mask_w, other=0.0)
grad_input += tl.dot(go, w)
grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_gi = mask_m[:, None] & mask_k[None, :]
tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
@triton.jit
def _kernel_cg_backward_dw(
grad_output_ptr,
inputs_ptr,
grad_weights_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
):
"""Simplified kernel for expert weight gradients."""
pid = tl.program_id(0)
expert_id = pid // ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
position_id = pid % ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K))
if expert_id < NUM_EXPERTS:
n_tiles = K // BLOCK_SIZE_K
tile_n = position_id // n_tiles
tile_k = position_id % n_tiles
n_start = tile_n * BLOCK_SIZE_N
k_start = tile_k * BLOCK_SIZE_K
if n_start < N and k_start < K:
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start
mask_n = offs_n < N
mask_k = offs_k < K
grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)
for group_idx in range(0, M_TOTAL // GROUP_SIZE_M):
group_start = group_idx * GROUP_SIZE_M
group_expert = tl.load(indices_ptr + group_start)
if group_expert == expert_id:
for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
m_start = group_start + m_offset
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)
go_ptrs = (
grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
)
mask_go = mask_m[:, None] & mask_n[None, :]
go = tl.load(go_ptrs, mask=mask_go, other=0.0)
in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
mask_in = mask_m[:, None] & mask_k[None, :]
inp = tl.load(in_ptrs, mask=mask_in, other=0.0)
go_t = tl.trans(go)
grad_weights += tl.dot(go_t, inp)
grad_w_ptrs = (
grad_weights_ptr
+ expert_id * N * K
+ offs_n[:, None] * K
+ offs_k[None, :]
)
mask_gw = mask_n[:, None] & mask_k[None, :]
tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
def cg_grouped_gemm_backward_weights(
grad_output: torch.Tensor,
inputs: torch.Tensor,
expert_indices: torch.Tensor,
num_experts: int,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Backward pass for expert weights."""
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, N = grad_output.shape
_, K = inputs.shape
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
grad_weights = torch.zeros(
(num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
)
block_size_n = min(128, N)
block_size_k = min(32, K)
block_size_m = min(32, group_size_m)
n_tiles = triton.cdiv(N, block_size_n)
k_tiles = triton.cdiv(K, block_size_k)
grid = (num_experts * n_tiles * k_tiles,)
_kernel_cg_backward_dw[grid](
grad_output,
inputs,
grad_weights,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
BLOCK_SIZE_M=block_size_m,
)
return grad_weights
def cg_grouped_gemm_backward_inputs(
grad_output: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Backward pass for inputs."""
assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, N = grad_output.shape
num_experts, _, K = expert_weights.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
grad_inputs = torch.zeros(
(M_total, K), device=grad_output.device, dtype=grad_output.dtype
)
grid = lambda meta: (
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
* triton.cdiv(K, meta["BLOCK_SIZE_K"]),
)
_kernel_cg_backward_dx[grid](
grad_output,
expert_weights,
grad_inputs,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
)
return grad_inputs
class ContiguousGroupedGEMM(torch.autograd.Function):
"""Autograd function with full backward support."""
@staticmethod
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
ctx.save_for_backward(inputs, expert_weights, expert_indices)
ctx.group_size_m = group_size_m
return cg_grouped_gemm_forward(
inputs=inputs,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
@staticmethod
def backward(ctx, grad_output):
inputs, expert_weights, expert_indices = ctx.saved_tensors
group_size_m = ctx.group_size_m
grad_output = grad_output.contiguous()
num_experts = expert_weights.shape[0]
grad_inputs = cg_grouped_gemm_backward_inputs(
grad_output=grad_output,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
grad_weights = cg_grouped_gemm_backward_weights(
grad_output=grad_output,
inputs=inputs,
expert_indices=expert_indices,
num_experts=num_experts,
group_size_m=group_size_m,
)
grad_indices = None
grad_group_size_m = None
return grad_inputs, grad_weights, grad_indices, grad_group_size_m

View File

@@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Vendored forward Triton contiguous grouped GEMM kernels."""
import torch
import triton
import triton.language as tl
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
GROUP_SIZE_M = 128
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * super_group_m
group_size_m = min(num_pid_m - first_pid_m, super_group_m)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
return pid_m, pid_n
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_persistent_forward(
a_ptr,
b_ptr,
c_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
NUM_SMS: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
SUPER_GROUP_M: tl.constexpr = 32,
):
"""
Contiguous Grouped GEMM kernel forward (persistent variant).
"""
c_type = c_ptr.dtype.element_ty
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles = num_pid_m * num_pid_n
tile_id_c = start_pid - NUM_SMS
num_pid_in_group = SUPER_GROUP_M * num_pid_n
for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
tile_m_idx, tile_n_idx = _compute_pid(
tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
)
m_start = tile_m_idx * BLOCK_SIZE_M
n_start = tile_n_idx * BLOCK_SIZE_N
if m_start < M_TOTAL:
offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for ki in range(k_tiles):
offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
mask_k = offs_k < K
mask_a = mask_m[:, None] & mask_k[None, :]
mask_b = mask_n[:, None] & mask_k[None, :]
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
b_ptrs = (
b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
)
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
accumulator += tl.dot(a, b.T)
tile_id_c += NUM_SMS
tile_m_idx, tile_n_idx = _compute_pid(
tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
)
offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
mask_c = mask_m[:, None] & mask_n[None, :]
c = accumulator.to(tl.float32)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
tl.store(c_ptrs, c.to(c_type), mask=mask_c)
@triton.autotune(
configs=STANDARD_CONFIGS,
key=["M_TOTAL", "N", "K"],
prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_forward_aligned(
a_ptr,
b_ptr,
c_ptr,
indices_ptr,
M_TOTAL: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
"""
Contiguous Grouped GEMM kernel forward for aligned inputs.
"""
pid = tl.program_id(0)
c_type = c_ptr.dtype.element_ty
num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
tile_m = pid // num_n_tiles
tile_n = pid % num_n_tiles
m_start = tile_m * BLOCK_SIZE_M
n_start = tile_n * BLOCK_SIZE_N
if m_start < M_TOTAL:
offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
mask_m = offs_m < M_TOTAL
mask_n = offs_n < N
group_idx = m_start // GROUP_SIZE_M
expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)
acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
offs_k = tl.arange(0, BLOCK_SIZE_K) + k
mask_k = offs_k < K
mask_a = mask_m[:, None] & mask_k[None, :]
mask_b = mask_n[:, None] & mask_k[None, :]
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
a = tl.load(a_ptrs, mask=mask_a, other=0.0)
b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
b = tl.load(b_ptrs, mask=mask_b, other=0.0)
acc += tl.dot(a, b.T)
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
mask_c = mask_m[:, None] & mask_n[None, :]
tl.store(c_ptrs, acc.to(c_type), mask=mask_c)
def cg_grouped_gemm_forward(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Contiguous grouped GEMM forward pass for MoE."""
assert inputs.is_contiguous(), "Input tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, K = inputs.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
num_experts, N, K_weights = expert_weights.shape
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
assert expert_indices.shape[0] == M_total, (
"Expert indices length must match M_total"
)
output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
grid = (NUM_SMS, 1, 1)
_kernel_cg_persistent_forward[grid](
inputs,
expert_weights,
output,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
NUM_SMS=NUM_SMS,
)
return output
def cg_grouped_gemm_forward_dynamic(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Contiguous grouped GEMM forward pass for MoE with autotuned launch."""
assert inputs.is_contiguous(), "Input tensor must be contiguous"
assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"
M_total, K = inputs.shape
assert M_total % group_size_m == 0, (
f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
)
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
num_experts, N, K_weights = expert_weights.shape
assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
assert expert_indices.shape[0] == M_total, (
"Expert indices length must match M_total"
)
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
grid = lambda meta: (
triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
* triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
_kernel_cg_forward_aligned[grid](
inputs,
expert_weights,
output,
expert_indices,
M_TOTAL=M_total,
N=N,
K=K,
NUM_EXPERTS=num_experts,
GROUP_SIZE_M=group_size_m,
)
return output
class ContiguousGroupedGEMM(torch.autograd.Function):
"""Autograd function for contiguous grouped GEMM forward pass only."""
@staticmethod
def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
return cg_grouped_gemm_forward(
inputs=inputs,
expert_weights=expert_weights,
expert_indices=expert_indices,
group_size_m=group_size_m,
)
@staticmethod
def backward(ctx, grad_output): # pragma: no cover - not implemented
raise NotImplementedError("Backward pass not implemented")
def cg_grouped_gemm(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
"""Convenience wrapper for the forward-only autograd function."""
if expert_indices.dtype != torch.int32:
expert_indices = expert_indices.to(torch.int32)
return ContiguousGroupedGEMM.apply(
inputs, expert_weights, expert_indices, group_size_m
)

View File

@@ -0,0 +1,31 @@
"""Reference implementation for contiguous grouped GEMM."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
def pytorch_reference(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
expert_indices: torch.Tensor,
group_size_m: int = 128,
) -> torch.Tensor:
"""Simple PyTorch implementation for verification."""
M_total, K = inputs.shape
num_experts, N, _ = expert_weights.shape
output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)
for i in range(0, M_total, group_size_m):
end_idx = min(i + group_size_m, M_total)
expert_idx = expert_indices[i].item()
expert_weight = expert_weights[expert_idx]
output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T)
return output

View File

@@ -0,0 +1,209 @@
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict
import torch
import triton
import triton.language as tl
from triton.runtime import driver
class CudaUtils:
"""Helper utilities for CUDA specific Triton features."""
@staticmethod
def is_cuda() -> bool:
return driver.active.get_current_target().backend == "cuda"
@staticmethod
def verify_tma() -> bool:
return (
CudaUtils.is_cuda()
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
)
@staticmethod
def get_num_sms() -> int:
if not CudaUtils.is_cuda():
raise RuntimeError("Triton is not running on CUDA backend")
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
return torch.cuda.get_device_properties("cuda").multi_processor_count
class TmaDescriptorHelper:
"""Helper class for managing TMA descriptors in Triton kernels."""
class KernelParamWrapper:
def __init__(self, desc: torch.Tensor):
self.desc = desc
def tma_desc_cpu_ptr(self) -> int:
return self.desc.data_ptr()
def __init__(self, tma_size: int = 128):
if not CudaUtils.verify_tma():
raise RuntimeError(
"TMA not supported on this device (requires Hopper or newer)"
)
if "nv_tma_desc_type" not in dir(tl):
raise RuntimeError(
"TMA grid constant descriptors not supported in your Triton version"
)
self.tma_size = tma_size
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
self.descriptors: Dict[str, torch.Tensor] = {}
def init_tma_descriptor(self, name: str) -> None:
self.descriptors[name] = torch.empty(
self.tma_size, device="cpu", dtype=torch.int8
)
def fill_1d_tma_descriptor(
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
) -> None:
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
def fill_2d_tma_descriptor(
self,
name: str,
ptr: int,
dim1: int,
dim0: int,
block_dim1: int,
block_dim0: int,
element_size: int,
) -> None:
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
def get_tma_descriptor_kernel_param(
self, name: str
) -> "TmaDescriptorHelper.KernelParamWrapper":
if name not in self.descriptors or self.descriptors[name] is None:
raise ValueError(f"TMA descriptor '{name}' not initialized")
return self.KernelParamWrapper(self.descriptors[name])
HOPPER_CONFIGS = [
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
]
STANDARD_CONFIGS = [
triton.Config(
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=4,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=2,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8,
),
triton.Config(
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=8,
),
]
def early_config_prune(configs, args, **kwargs):
"""Filter out configurations that would exceed shared memory capacity."""
k = kwargs.get("K", 0)
valid_configs = [
config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k
]
if not valid_configs and configs:
return [
min(
configs,
key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")),
)
]
return valid_configs

View File

@@ -190,6 +190,11 @@ class PatchManager:
apply_mistral_tokenizer_image_patch()
if self.cfg.model_config_type == "deepseek_v3":
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
patch_deepseek_v3_moe()
def _apply_fp8_patches(self):
"""Apply patches for FP8 support."""
if self.cfg.fp8:

View File

@@ -0,0 +1,187 @@
"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels."""
from __future__ import annotations
import contextlib
import math
from typing import Callable
import torch
from axolotl.kernels.moe import ContiguousGroupedGEMM
_GROUP_SIZE_M = 128
def _align_to(value: int, alignment: int) -> int:
if value <= 0:
return 0
return math.ceil(value / alignment) * alignment
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
return hidden_states.is_cuda and hidden_states.shape[0] > 0
def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
gate_weights = []
up_weights = []
down_weights = []
for expert in module.experts:
gate_weights.append(expert.gate_proj.weight)
up_weights.append(expert.up_proj.weight)
down_weights.append(expert.down_proj.weight)
gate = torch.stack(gate_weights, dim=0).contiguous()
up = torch.stack(up_weights, dim=0).contiguous()
down = torch.stack(down_weights, dim=0).contiguous()
return gate, up, down
def _moe_triton_forward(
module,
hidden_states: torch.Tensor,
topk_indices: torch.Tensor,
topk_weights: torch.Tensor,
group_size_m: int,
fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
if not _is_triton_eligible(hidden_states):
return fallback(hidden_states, topk_indices, topk_weights)
device = hidden_states.device
hidden_dtype = hidden_states.dtype
num_tokens, hidden_dim = hidden_states.shape
top_k = topk_indices.size(-1)
expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0)
expert_assignments = topk_indices.reshape(-1)
if expanded_hidden.numel() == 0:
return hidden_states.new_zeros_like(hidden_states)
sort_perm = torch.argsort(expert_assignments)
sorted_hidden = expanded_hidden.index_select(0, sort_perm)
sorted_assignments = expert_assignments.index_select(0, sort_perm)
num_experts = len(module.experts)
counts = torch.bincount(sorted_assignments, minlength=num_experts)
counts_cpu = counts.to(torch.int64).cpu().tolist()
padded_counts = [_align_to(c, group_size_m) for c in counts_cpu]
total_actual = sum(counts_cpu)
total_padded = sum(padded_counts)
if total_actual == 0 or total_padded == 0:
return hidden_states.new_zeros_like(hidden_states)
actual_offsets = [0]
padded_offsets = [0]
for count, padded in zip(counts_cpu, padded_counts, strict=False):
actual_offsets.append(actual_offsets[-1] + count)
padded_offsets.append(padded_offsets[-1] + padded)
grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
expert_index_tensor = torch.empty(total_padded, dtype=torch.int32, device=device)
for idx, (count, padded) in enumerate(zip(counts_cpu, padded_counts, strict=False)):
dst_start = padded_offsets[idx]
dst_end = dst_start + padded
if padded == 0:
continue
expert_index_tensor[dst_start:dst_end] = idx
if count > 0:
src_start = actual_offsets[idx]
src_end = src_start + count
grouped_hidden[dst_start : dst_start + count].copy_(
sorted_hidden[src_start:src_end]
)
gate_weights, up_weights, down_weights = _collect_expert_weights(module)
gate_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
gate_weights,
expert_index_tensor,
group_size_m,
)
up_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
up_weights,
expert_index_tensor,
group_size_m,
)
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
hidden_chunks = []
for idx, count in enumerate(counts_cpu):
if count == 0:
continue
pad_start = padded_offsets[idx]
pad_end = pad_start + count
gate_slice = gate_out[pad_start:pad_end].to(hidden_dtype)
up_slice = up_out[pad_start:pad_end].to(hidden_dtype)
hidden_chunks.append(act_fn(gate_slice) * up_slice)
hidden_concat = torch.cat(hidden_chunks, dim=0)
intermediate_dim = hidden_concat.shape[-1]
hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim))
for idx, count in enumerate(counts_cpu):
if count == 0:
continue
pad_start = padded_offsets[idx]
src_start = actual_offsets[idx]
src_end = src_start + count
hidden_grouped[pad_start : pad_start + count].copy_(
hidden_concat[src_start:src_end]
)
down_out = ContiguousGroupedGEMM.apply(
hidden_grouped,
down_weights,
expert_index_tensor,
group_size_m,
)
down_chunks = []
for idx, count in enumerate(counts_cpu):
if count == 0:
continue
pad_start = padded_offsets[idx]
pad_end = pad_start + count
down_chunks.append(down_out[pad_start:pad_end].to(hidden_dtype))
down_concat = torch.cat(down_chunks, dim=0)
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
expanded_output.index_copy_(0, sort_perm, down_concat.to(hidden_dtype))
expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
return weighted.sum(dim=1)
def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
"""Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
return
original_moe = DeepseekV3MoE.moe
def patched_moe(self, hidden_states, topk_indices, topk_weights):
with contextlib.suppress(RuntimeError):
return _moe_triton_forward(
self,
hidden_states,
topk_indices,
topk_weights,
group_size_m,
original_moe,
)
return original_moe(self, hidden_states, topk_indices, topk_weights)
DeepseekV3MoE.moe = patched_moe
DeepseekV3MoE._axolotl_triton_patch = True