vendor torchtitan moe kernels
This commit is contained in:
184
scripts/benchmarks/deepseek_v3_moe.py
Normal file
184
scripts/benchmarks/deepseek_v3_moe.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python
|
||||
"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels.
|
||||
|
||||
Example usage (run from project root):
|
||||
|
||||
PYTHONPATH=./src:../transformers/src \
|
||||
python scripts/benchmarks/deepseek_v3_moe.py --device cuda --iters 20
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import time
|
||||
from types import MethodType
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from transformers.models.deepseek_v3.configuration_deepseek_v3 import (
|
||||
DeepseekV3Config,
|
||||
)
|
||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||
except ImportError as exc: # pragma: no cover - utility script
|
||||
raise SystemExit(
|
||||
"Transformers with DeepSeek-V3 support must be available in PYTHONPATH"
|
||||
) from exc
|
||||
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||
|
||||
# Translates the value of the ``--dtype`` CLI option into the torch dtype that
# is used for both the module weights and the random benchmark inputs.
DTYPE_MAP = dict(
    bf16=torch.bfloat16,
    fp16=torch.float16,
    fp32=torch.float32,
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    Returns:
        The populated namespace. Integer options cover the problem shape
        (batch, sequence length, hidden and intermediate sizes), the router
        layout (experts, top-k, groups), the timing loop (warmup, iters),
        the RNG seed and the Triton ``GROUP_SIZE_M``; ``--dtype`` and
        ``--device`` are restricted to a fixed set of choices.
    """
    cli = argparse.ArgumentParser(description=__doc__)

    def add_int(flag: str, default: int, help_text: str) -> None:
        # Every numeric option shares the same shape: an int with a default.
        cli.add_argument(flag, type=int, default=default, help=help_text)

    add_int("--batch", 2, "batch size")
    add_int("--seq-len", 256, "sequence length")
    add_int("--hidden-size", 4096, "MoE hidden size")
    add_int("--moe-intermediate-size", 8192, "MoE intermediate projection size")
    add_int("--n-experts", 64, "Number of routed experts")
    add_int("--top-k", 4, "Number of experts per token")
    add_int("--groups", 8, "Router groups (must divide n-experts)")

    cli.add_argument(
        "--dtype",
        choices=DTYPE_MAP.keys(),
        default="bf16",
        help="Computation dtype",
    )
    cli.add_argument(
        "--device",
        default="auto",
        choices=["auto", "cpu", "cuda"],
        help="Execution device",
    )

    add_int("--warmup", 5, "Warmup iterations")
    add_int("--iters", 25, "Benchmark iterations")
    add_int("--seed", 0, "Random seed")
    add_int("--group-size", 128, "GROUP_SIZE_M used by the Triton kernel")

    return cli.parse_args()
|
||||
|
||||
|
||||
def resolve_device(requested: str) -> torch.device:
    """Convert the ``--device`` option into a ``torch.device``.

    Args:
        requested: ``"auto"`` or a device string understood by
            ``torch.device``.

    Returns:
        For ``"auto"``, the CUDA device when ``torch.cuda.is_available()``
        and the CPU device otherwise; for anything else, the device named
        by ``requested``.
    """
    if requested != "auto":
        return torch.device(requested)

    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
|
||||
|
||||
|
||||
def build_module(args: argparse.Namespace) -> DeepseekV3MoE:
    """Instantiate a ``DeepseekV3MoE`` block sized from the CLI options.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        A freshly initialised MoE module switched to eval mode.
    """
    moe_width = args.moe_intermediate_size
    group_count = args.groups

    # topk_group follows top_k but is clamped into the range [1, groups].
    topk_group = max(1, min(group_count, args.top_k))

    config_kwargs = dict(
        hidden_size=args.hidden_size,
        intermediate_size=moe_width,
        moe_intermediate_size=moe_width,
        n_routed_experts=args.n_experts,
        num_experts_per_tok=args.top_k,
        n_group=group_count,
        topk_group=topk_group,
        n_shared_experts=1,
    )

    moe_block = DeepseekV3MoE(DeepseekV3Config(**config_kwargs))
    moe_block.eval()
    return moe_block
|
||||
|
||||
|
||||
@torch.no_grad()
def benchmark(
    module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int
) -> float:
    """Measure the mean forward latency of ``module`` on ``inputs``.

    Args:
        module: Callable module to time; invoked as ``module(inputs)``.
        inputs: Tensor passed unchanged to every forward call.
        iters: Number of timed forward calls; must be positive.
        warmup: Number of untimed forward calls executed first.

    Returns:
        Average wall-clock time per timed call, in milliseconds.

    Raises:
        ValueError: If ``iters`` is not positive (the average would otherwise
            divide by zero or come out negative).
    """
    if iters <= 0:
        raise ValueError(f"iters must be a positive integer, got {iters}")

    on_cuda = inputs.is_cuda

    for _ in range(warmup):
        module(inputs)

    # CUDA launches are asynchronous: drain the queue on the input's device
    # so warmup work is not attributed to the timed region.
    if on_cuda:
        torch.cuda.synchronize(inputs.device)

    start = time.perf_counter()
    for _ in range(iters):
        module(inputs)

    # Wait for the timed kernels to finish before reading the clock.
    if on_cuda:
        torch.cuda.synchronize(inputs.device)

    elapsed = time.perf_counter() - start
    return (elapsed / iters) * 1000.0
|
||||
|
||||
|
||||
def main() -> None:  # pragma: no cover - CLI entrypoint
    """Run the baseline-vs-patched MoE comparison and print the results.

    Builds one module bound to the unpatched ``DeepseekV3MoE.moe`` method,
    applies the monkeypatch, builds a second module that shares the first
    module's weights, then reports mean latency for both, the speedup and the
    largest absolute difference between their outputs.

    Raises:
        SystemExit: If the expert/group/top-k options are inconsistent or
            CUDA was requested but is unavailable.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    device = resolve_device(args.device)
    dtype = DTYPE_MAP[args.dtype]

    if args.n_experts % args.groups:
        raise SystemExit("n-experts must be divisible by groups")
    if args.n_experts < args.top_k:
        raise SystemExit("top-k cannot exceed number of experts")

    cuda_missing = device.type == "cuda" and not torch.cuda.is_available()
    if cuda_missing:
        raise SystemExit("CUDA requested but not available")

    # Capture the unpatched implementation and bind it to the baseline
    # instance before the class is patched below.
    baseline_module = build_module(args)
    unpatched_moe = DeepseekV3MoE.moe
    baseline_module.moe = MethodType(unpatched_moe, baseline_module)
    shared_weights = baseline_module.state_dict()

    patch_deepseek_v3_moe(group_size_m=args.group_size)
    patched_module = build_module(args)
    # Both modules must hold identical weights for the outputs to be comparable.
    patched_module.load_state_dict(shared_weights)

    for candidate in (baseline_module, patched_module):
        candidate.to(device=device, dtype=dtype)

    input_shape = (args.batch, args.seq_len, args.hidden_size)
    inputs = torch.randn(*input_shape, device=device, dtype=dtype)

    with torch.no_grad():
        ref_output = baseline_module(inputs)
        patched_output = patched_module(inputs)
        max_diff = (ref_output - patched_output).abs().max().item()

    baseline_ms = benchmark(baseline_module, inputs, args.iters, args.warmup)
    patched_ms = benchmark(patched_module, inputs, args.iters, args.warmup)

    if patched_ms > 0:
        speedup = baseline_ms / patched_ms
    else:
        speedup = float("nan")

    print(
        f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
    )
    print(
        f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
    )
    print(f"Max |Δ| between outputs: {max_diff:.2e}")


if __name__ == "__main__":
    main()
|
||||
17
src/axolotl/kernels/moe/__init__.py
Normal file
17
src/axolotl/kernels/moe/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Mixture-of-Experts kernel implementations."""
|
||||
|
||||
from .tt_cg_gemm import (
|
||||
ContiguousGroupedGEMM,
|
||||
ContiguousGroupedGEMMForwardOnly,
|
||||
cg_grouped_gemm,
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"cg_grouped_gemm",
|
||||
"cg_grouped_gemm_forward",
|
||||
"cg_grouped_gemm_forward_dynamic",
|
||||
"ContiguousGroupedGEMM",
|
||||
"ContiguousGroupedGEMMForwardOnly",
|
||||
]
|
||||
17
src/axolotl/kernels/moe/tt_cg_gemm/__init__.py
Normal file
17
src/axolotl/kernels/moe/tt_cg_gemm/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""
|
||||
|
||||
from .cg_backward import ContiguousGroupedGEMM
|
||||
from .cg_forward import (
|
||||
ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
|
||||
cg_grouped_gemm,
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"cg_grouped_gemm",
|
||||
"cg_grouped_gemm_forward",
|
||||
"cg_grouped_gemm_forward_dynamic",
|
||||
"ContiguousGroupedGEMM",
|
||||
"ContiguousGroupedGEMMForwardOnly",
|
||||
]
|
||||
290
src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py
Normal file
290
src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Vendored backward pass for Triton contiguous grouped GEMM."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .cg_forward import cg_grouped_gemm_forward
|
||||
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
|
||||
|
||||
GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_backward_dx(
    grad_output_ptr,
    b_ptr,
    grad_input_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
    """Compute gradients with respect to inputs.

    Each program produces one [BLOCK_SIZE_M, BLOCK_SIZE_K] tile of
    ``grad_input`` by accumulating ``grad_output_tile @ W_expert`` over N.
    ``grad_output`` is addressed as row-major [M_TOTAL, N], the weights as
    row-major [NUM_EXPERTS, N, K] and ``grad_input`` as row-major
    [M_TOTAL, K]. The expert of a tile is read from ``indices_ptr`` at the
    first row of the GROUP_SIZE_M-row group that contains the tile's first
    row.
    """

    pid = tl.program_id(0)

    # Programs are laid out row-tile major: pid = tile_m * num_k_tiles + tile_k.
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    tile_m = pid // num_k_tiles
    tile_k = pid % num_k_tiles

    m_start = tile_m * BLOCK_SIZE_M
    k_start = tile_k * BLOCK_SIZE_K

    if m_start < M_TOTAL:
        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
        offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start

        mask_m = offs_m < M_TOTAL
        mask_k = offs_k < K

        group_idx = m_start // GROUP_SIZE_M
        expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)

        # Promote to int64 before scaling by N * K: with int32 indices the
        # product wraps once the weight tensor holds 2**31 or more elements.
        expert_offset = expert_idx.to(tl.int64) * N * K

        grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)

        for n in range(0, N, BLOCK_SIZE_N):
            offs_n = tl.arange(0, BLOCK_SIZE_N) + n
            mask_n = offs_n < N

            mask_go = mask_m[:, None] & mask_n[None, :]
            mask_w = mask_n[:, None] & mask_k[None, :]

            go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
            go = tl.load(go_ptrs, mask=mask_go, other=0.0)

            w_ptrs = b_ptr + expert_offset + offs_n[:, None] * K + offs_k[None, :]
            w = tl.load(w_ptrs, mask=mask_w, other=0.0)

            # [BM, BN] @ [BN, BK]; out-of-range elements were loaded as 0.
            grad_input += tl.dot(go, w)

        grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
        mask_gi = mask_m[:, None] & mask_k[None, :]
        tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
|
||||
|
||||
|
||||
@triton.jit
def _kernel_cg_backward_dw(
    grad_output_ptr,
    inputs_ptr,
    grad_weights_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Simplified kernel for expert weight gradients.

    Each program owns one [BLOCK_SIZE_N, BLOCK_SIZE_K] tile of one expert's
    weight gradient. It scans every GROUP_SIZE_M-row group, and for groups
    whose first index equals its expert it accumulates
    ``grad_output_rows^T @ input_rows``. ``grad_output`` is addressed as
    row-major [M_TOTAL, N], ``inputs`` as row-major [M_TOTAL, K] and the
    gradient buffer as row-major [NUM_EXPERTS, N, K].

    The program id is decomposed with ceil-divided tile counts so that it
    agrees with a launch grid of
    ``NUM_EXPERTS * cdiv(N, BLOCK_SIZE_N) * cdiv(K, BLOCK_SIZE_K)`` programs
    even when N or K is not a multiple of its block size.
    """

    pid = tl.program_id(0)

    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    tiles_per_expert = num_n_tiles * num_k_tiles

    expert_id = pid // tiles_per_expert
    position_id = pid % tiles_per_expert

    if expert_id < NUM_EXPERTS:
        tile_n = position_id // num_k_tiles
        tile_k = position_id % num_k_tiles

        n_start = tile_n * BLOCK_SIZE_N
        k_start = tile_k * BLOCK_SIZE_K

        if n_start < N and k_start < K:
            offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
            offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start

            mask_n = offs_n < N
            mask_k = offs_k < K

            grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)

            # Ceil division so a trailing partial group is still visited;
            # mask_m below clips its rows to M_TOTAL.
            for group_idx in range(0, tl.cdiv(M_TOTAL, GROUP_SIZE_M)):
                group_start = group_idx * GROUP_SIZE_M
                group_expert = tl.load(indices_ptr + group_start)

                if group_expert == expert_id:
                    for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
                        m_start = group_start + m_offset
                        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start

                        # Rows must stay inside both this group and the tensor.
                        mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)

                        go_ptrs = (
                            grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
                        )
                        mask_go = mask_m[:, None] & mask_n[None, :]
                        go = tl.load(go_ptrs, mask=mask_go, other=0.0)

                        in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
                        mask_in = mask_m[:, None] & mask_k[None, :]
                        inp = tl.load(in_ptrs, mask=mask_in, other=0.0)

                        # [BN, BM] @ [BM, BK] -> [BN, BK]
                        go_t = tl.trans(go)
                        grad_weights += tl.dot(go_t, inp)

            # int64 so the expert offset cannot wrap for weight tensors with
            # 2**31 or more elements.
            expert_offset = expert_id.to(tl.int64) * N * K
            grad_w_ptrs = (
                grad_weights_ptr
                + expert_offset
                + offs_n[:, None] * K
                + offs_k[None, :]
            )
            mask_gw = mask_n[:, None] & mask_k[None, :]
            tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_weights(
    grad_output: torch.Tensor,
    inputs: torch.Tensor,
    expert_indices: torch.Tensor,
    num_experts: int,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Backward pass for expert weights.

    Args:
        grad_output: Contiguous 2-D gradient of the GEMM output, ``(M_total, N)``.
        inputs: Contiguous 2-D forward inputs, ``(M_total, K)``.
        expert_indices: Contiguous 1-D tensor with one expert id per row;
            converted to int32 when given in another dtype.
        num_experts: Number of experts, i.e. the leading dimension of the
            returned gradient.
        group_size_m: Number of consecutive rows that share one expert.

    Returns:
        Tensor of shape ``(num_experts, N, K)`` with ``grad_output``'s device
        and dtype.

    Raises:
        AssertionError: If a tensor is not contiguous, the row counts of the
            arguments disagree, or ``M_total`` is not a multiple of
            ``group_size_m``.
    """

    assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
    assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, N = grad_output.shape
    M_inputs, K = inputs.shape

    # The kernel indexes all three tensors with the same row offsets, so
    # mismatched shapes would read out of bounds instead of failing loudly.
    assert M_inputs == M_total, (
        f"Inputs rows ({M_inputs}) must match grad output rows ({M_total})"
    )
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    grad_weights = torch.zeros(
        (num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
    )

    block_size_n = min(128, N)
    block_size_k = min(32, K)
    block_size_m = min(32, group_size_m)

    # One program per (expert, N tile, K tile).
    n_tiles = triton.cdiv(N, block_size_n)
    k_tiles = triton.cdiv(K, block_size_k)
    grid = (num_experts * n_tiles * k_tiles,)

    _kernel_cg_backward_dw[grid](
        grad_output,
        inputs,
        grad_weights,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
        BLOCK_SIZE_M=block_size_m,
    )

    return grad_weights
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_inputs(
    grad_output: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Backward pass for inputs.

    Args:
        grad_output: Contiguous 2-D gradient of the GEMM output, ``(M_total, N)``.
        expert_weights: Contiguous 3-D expert weights, ``(num_experts, N, K)``.
        expert_indices: Contiguous 1-D tensor with one expert id per row;
            converted to int32 when given in another dtype.
        group_size_m: Number of consecutive rows that share one expert.

    Returns:
        Tensor of shape ``(M_total, K)`` with ``grad_output``'s device and
        dtype.

    Raises:
        AssertionError: If a tensor is not contiguous, ``M_total`` is not a
            multiple of ``group_size_m``, or the shapes of the arguments
            disagree.
    """

    assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, N = grad_output.shape
    num_experts, N_weights, K = expert_weights.shape

    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )
    # The kernel strides the weights with grad_output's N, so a mismatch would
    # silently address the wrong elements.
    assert N == N_weights, (
        f"Grad output N ({N}) must match weight N ({N_weights})"
    )
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    # Same index normalisation as the forward and weight-gradient entry points.
    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    grad_inputs = torch.zeros(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    def grid(meta):
        # One program per (M tile, K tile) of the autotuned block sizes.
        m_tiles = triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
        k_tiles = triton.cdiv(K, meta["BLOCK_SIZE_K"])
        return (m_tiles * k_tiles,)

    _kernel_cg_backward_dx[grid](
        grad_output,
        expert_weights,
        grad_inputs,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
    )

    return grad_inputs
|
||||
|
||||
|
||||
class ContiguousGroupedGEMM(torch.autograd.Function):
    """Autograd function with full backward support."""

    @staticmethod
    def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
        """Run the grouped GEMM and stash what the backward pass needs.

        Args:
            ctx: Autograd context.
            inputs: Forward inputs passed to ``cg_grouped_gemm_forward``.
            expert_weights: Expert weights passed to ``cg_grouped_gemm_forward``.
            expert_indices: Per-row expert ids.
            group_size_m: Number of consecutive rows that share one expert.

        Returns:
            The result of ``cg_grouped_gemm_forward``.
        """
        ctx.save_for_backward(inputs, expert_weights, expert_indices)
        ctx.group_size_m = group_size_m

        return cg_grouped_gemm_forward(
            inputs=inputs,
            expert_weights=expert_weights,
            expert_indices=expert_indices,
            group_size_m=group_size_m,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(inputs, expert_weights, indices, group_size_m)``.

        Only the gradients that autograd reports as needed are computed; the
        others are returned as ``None``. The indices and the group size never
        receive a gradient.
        """
        inputs, expert_weights, expert_indices = ctx.saved_tensors
        group_size_m = ctx.group_size_m

        # The backward kernels assert contiguity; upstream grads may be views.
        grad_output = grad_output.contiguous()

        grad_inputs = None
        if ctx.needs_input_grad[0]:
            grad_inputs = cg_grouped_gemm_backward_inputs(
                grad_output=grad_output,
                expert_weights=expert_weights,
                expert_indices=expert_indices,
                group_size_m=group_size_m,
            )

        grad_weights = None
        if ctx.needs_input_grad[1]:
            grad_weights = cg_grouped_gemm_backward_weights(
                grad_output=grad_output,
                inputs=inputs,
                expert_indices=expert_indices,
                num_experts=expert_weights.shape[0],
                group_size_m=group_size_m,
            )

        return grad_inputs, grad_weights, None, None
|
||||
311
src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py
Normal file
311
src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Vendored forward Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
|
||||
|
||||
GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
    """Map a linear tile id to ``(pid_m, pid_n)`` tile coordinates.

    Tiles are numbered in super-groups of up to ``super_group_m`` row tiles;
    within a super-group consecutive ids walk the rows first and advance the
    column once every row of the super-group has been visited.
    """
    super_group = tile_id // num_pid_in_group
    row_base = super_group * super_group_m
    # The last super-group may hold fewer than super_group_m row tiles.
    rows_in_group = min(num_pid_m - row_base, super_group_m)
    within_group = tile_id % num_pid_in_group
    pid_m = row_base + (tile_id % rows_in_group)
    pid_n = within_group // rows_in_group
    return pid_m, pid_n
|
||||
|
||||
|
||||
@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_persistent_forward(
    a_ptr,
    b_ptr,
    c_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
    SUPER_GROUP_M: tl.constexpr = 32,
):
    """
    Contiguous Grouped GEMM kernel forward (persistent variant).

    Each program starts at the tile given by its program id and advances by
    NUM_SMS tiles until all ``cdiv(M_TOTAL, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)``
    output tiles are done. For each tile it accumulates ``A_tile @ W_expert^T``
    over K, with ``a`` addressed as row-major [M_TOTAL, K], the weights as
    row-major [NUM_EXPERTS, N, K] and ``c`` as row-major [M_TOTAL, N]. The
    expert of a tile is read from ``indices_ptr`` at the first row of the
    GROUP_SIZE_M-row group that contains the tile's first row.
    """

    c_type = c_ptr.dtype.element_ty

    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
    # Epilogue tile counter; it is advanced by NUM_SMS before each store, so
    # it starts one stride behind the first tile of this program.
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = SUPER_GROUP_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
        tile_m_idx, tile_n_idx = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
        )

        m_start = tile_m_idx * BLOCK_SIZE_M
        n_start = tile_n_idx * BLOCK_SIZE_N

        if m_start < M_TOTAL:
            offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
            offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)

            # Row/column masks and the expert lookup do not depend on the K
            # step, so they are computed once per tile.
            mask_m = offs_m < M_TOTAL
            mask_n = offs_n < N

            group_idx = m_start // GROUP_SIZE_M
            expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)

            # Promote to int64 before scaling by N * K: with int32 indices
            # the product wraps once the weight tensor holds 2**31 or more
            # elements.
            expert_offset = expert_idx.to(tl.int64) * N * K

            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

            for ki in range(k_tiles):
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
                mask_k = offs_k < K

                mask_a = mask_m[:, None] & mask_k[None, :]
                mask_b = mask_n[:, None] & mask_k[None, :]

                a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
                a = tl.load(a_ptrs, mask=mask_a, other=0.0)

                b_ptrs = (
                    b_ptr + expert_offset + offs_n[:, None] * K + offs_k[None, :]
                )
                b = tl.load(b_ptrs, mask=mask_b, other=0.0)

                # [BM, BK] @ [BK, BN]; out-of-range elements were loaded as 0.
                accumulator += tl.dot(a, b.T)

            # Recompute the output coordinates from the epilogue counter.
            tile_id_c += NUM_SMS
            tile_m_idx, tile_n_idx = _compute_pid(
                tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
            )

            offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

            mask_m = offs_m < M_TOTAL
            mask_n = offs_n < N
            mask_c = mask_m[:, None] & mask_n[None, :]

            # Accumulation is in float32; cast once to the output element type.
            c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
            tl.store(c_ptrs, accumulator.to(c_type), mask=mask_c)
|
||||
|
||||
|
||||
@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_forward_aligned(
    a_ptr,
    b_ptr,
    c_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
    """
    Contiguous Grouped GEMM kernel forward for aligned inputs.

    Each program computes one [BLOCK_SIZE_M, BLOCK_SIZE_N] output tile as
    ``A_tile @ W_expert^T``, with ``a`` addressed as row-major [M_TOTAL, K],
    the weights as row-major [NUM_EXPERTS, N, K] and ``c`` as row-major
    [M_TOTAL, N]. The expert of a tile is read from ``indices_ptr`` at the
    first row of the GROUP_SIZE_M-row group that contains the tile's first
    row.
    """

    pid = tl.program_id(0)

    c_type = c_ptr.dtype.element_ty

    # Programs are laid out row-tile major: pid = tile_m * num_n_tiles + tile_n.
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    tile_m = pid // num_n_tiles
    tile_n = pid % num_n_tiles

    m_start = tile_m * BLOCK_SIZE_M
    n_start = tile_n * BLOCK_SIZE_N

    if m_start < M_TOTAL:
        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
        offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start

        mask_m = offs_m < M_TOTAL
        mask_n = offs_n < N

        group_idx = m_start // GROUP_SIZE_M
        expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M)

        # Promote to int64 before scaling by N * K: with int32 indices the
        # product wraps once the weight tensor holds 2**31 or more elements.
        expert_offset = expert_idx.to(tl.int64) * N * K

        acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)

        for k in range(0, K, BLOCK_SIZE_K):
            offs_k = tl.arange(0, BLOCK_SIZE_K) + k
            mask_k = offs_k < K

            mask_a = mask_m[:, None] & mask_k[None, :]
            mask_b = mask_n[:, None] & mask_k[None, :]

            a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)

            b_ptrs = b_ptr + expert_offset + offs_n[:, None] * K + offs_k[None, :]
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)

            # [BM, BK] @ [BK, BN]; out-of-range elements were loaded as 0.
            acc += tl.dot(a, b.T)

        c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
        mask_c = mask_m[:, None] & mask_n[None, :]
        tl.store(c_ptrs, acc.to(c_type), mask=mask_c)
|
||||
|
||||
|
||||
def cg_grouped_gemm_forward(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Contiguous grouped GEMM forward pass for MoE.

    Launches the persistent kernel with one program per streaming
    multiprocessor of the device that holds ``inputs``.

    Args:
        inputs: Contiguous 2-D tensor ``(M_total, K)``.
        expert_weights: Contiguous 3-D tensor ``(num_experts, N, K)``.
        expert_indices: Contiguous 1-D tensor of length ``M_total`` with one
            expert id per row; converted to int32 when given in another dtype.
        group_size_m: Number of consecutive rows that share one expert;
            ``M_total`` must be a multiple of it.

    Returns:
        Tensor of shape ``(M_total, N)`` on ``inputs``' device and in
        ``inputs``' dtype.

    Raises:
        AssertionError: If a tensor is not contiguous or the shapes are
            inconsistent.
    """

    assert inputs.is_contiguous(), "Input tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, K = inputs.shape
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )

    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    num_experts, N, K_weights = expert_weights.shape
    assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    # Follow the input dtype (as the dynamic variant does) rather than
    # forcing bfloat16, which silently downcast fp16/fp32 callers.
    output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)

    # Size the persistent grid from the device that holds the data, not from
    # whichever CUDA device happens to be current.
    NUM_SMS = torch.cuda.get_device_properties(inputs.device).multi_processor_count

    grid = (NUM_SMS, 1, 1)
    _kernel_cg_persistent_forward[grid](
        inputs,
        expert_weights,
        output,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
        NUM_SMS=NUM_SMS,
    )

    return output
|
||||
|
||||
|
||||
def cg_grouped_gemm_forward_dynamic(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Contiguous grouped GEMM forward pass for MoE with autotuned launch.

    Args:
        inputs: Contiguous 2-D tensor ``(M_total, K)``.
        expert_weights: Contiguous 3-D tensor ``(num_experts, N, K)``.
        expert_indices: Contiguous 1-D tensor of length ``M_total`` with one
            expert id per row; converted to int32 when given in another dtype.
        group_size_m: Number of consecutive rows that share one expert;
            ``M_total`` must be a multiple of it.

    Returns:
        Tensor of shape ``(M_total, N)`` on ``inputs``' device and in
        ``inputs``' dtype.

    Raises:
        AssertionError: If a tensor is not contiguous or the shapes are
            inconsistent.
    """

    assert inputs.is_contiguous(), "Input tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, K = inputs.shape
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )

    indices_i32 = expert_indices
    if indices_i32.dtype != torch.int32:
        indices_i32 = indices_i32.to(torch.int32)

    num_experts, N, K_weights = expert_weights.shape
    assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
    assert indices_i32.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)

    def launch_grid(meta):
        # One program per output tile of the block sizes the autotuner picked.
        m_tiles = triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
        n_tiles = triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (m_tiles * n_tiles,)

    _kernel_cg_forward_aligned[launch_grid](
        inputs,
        expert_weights,
        output,
        indices_i32,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
    )

    return output
|
||||
|
||||
|
||||
class ContiguousGroupedGEMM(torch.autograd.Function):
    """Autograd function for contiguous grouped GEMM forward pass only."""

    @staticmethod
    def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
        """Delegate to ``cg_grouped_gemm_forward``; nothing is saved on ``ctx``."""
        result = cg_grouped_gemm_forward(
            inputs=inputs,
            expert_weights=expert_weights,
            expert_indices=expert_indices,
            group_size_m=group_size_m,
        )
        return result

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover - not implemented
        """Always raise: this variant does not provide gradients."""
        raise NotImplementedError("Backward pass not implemented")
|
||||
|
||||
|
||||
def cg_grouped_gemm(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Convenience wrapper for the forward-only autograd function.

    Args:
        inputs: Forward inputs handed to ``ContiguousGroupedGEMM``.
        expert_weights: Expert weights handed to ``ContiguousGroupedGEMM``.
        expert_indices: Per-row expert ids; converted to int32 when given in
            another dtype.
        group_size_m: Number of consecutive rows that share one expert.

    Returns:
        The result of ``ContiguousGroupedGEMM.apply``.
    """

    already_int32 = expert_indices.dtype == torch.int32
    indices_i32 = expert_indices if already_int32 else expert_indices.to(torch.int32)

    return ContiguousGroupedGEMM.apply(
        inputs, expert_weights, indices_i32, group_size_m
    )
|
||||
31
src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py
Normal file
31
src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Reference implementation for contiguous grouped GEMM."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def pytorch_reference(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = 128,
) -> torch.Tensor:
    """Simple PyTorch implementation for verification.

    Walks the rows in chunks of ``group_size_m`` and multiplies each chunk by
    the transposed weight of the expert named by the chunk's first index.

    Args:
        inputs: 2-D tensor ``(M_total, K)``.
        expert_weights: 3-D tensor ``(num_experts, N, K)``.
        expert_indices: Per-row expert ids; only the entry at the start of
            each chunk is read.
        group_size_m: Number of consecutive rows that share one expert. The
            final chunk may be shorter.

    Returns:
        Tensor of shape ``(M_total, N)`` on ``inputs``' device and in
        ``inputs``' dtype.
    """

    M_total, K = inputs.shape
    num_experts, N, _ = expert_weights.shape

    result = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)

    for group_start in range(0, M_total, group_size_m):
        group_stop = min(group_start + group_size_m, M_total)
        rows = slice(group_start, group_stop)
        # One expert per chunk, taken from the chunk's first row.
        weight = expert_weights[expert_indices[group_start].item()]
        result[rows] = inputs[rows] @ weight.T

    return result
|
||||
209
src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py
Normal file
209
src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.runtime import driver
|
||||
|
||||
|
||||
class CudaUtils:
    """Helper utilities for CUDA specific Triton features."""

    @staticmethod
    def is_cuda() -> bool:
        """Return True when the active Triton target reports the ``cuda`` backend."""
        target = driver.active.get_current_target()
        return target.backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Return True when the TMA prerequisites checked here are all met.

        Requires the Triton CUDA backend, an available CUDA runtime, and a
        device whose compute-capability major version is at least 9.
        """
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        major = torch.cuda.get_device_capability()[0]
        return major >= 9

    @staticmethod
    def get_num_sms() -> int:
        """Return the multiprocessor count reported for the ``"cuda"`` device.

        Raises:
            RuntimeError: If Triton is not on the CUDA backend or CUDA is not
                available.
        """
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        properties = torch.cuda.get_device_properties("cuda")
        return properties.multi_processor_count
|
||||
|
||||
|
||||
class TmaDescriptorHelper:
|
||||
"""Helper class for managing TMA descriptors in Triton kernels."""
|
||||
|
||||
class KernelParamWrapper:
|
||||
def __init__(self, desc: torch.Tensor):
|
||||
self.desc = desc
|
||||
|
||||
def tma_desc_cpu_ptr(self) -> int:
|
||||
return self.desc.data_ptr()
|
||||
|
||||
def __init__(self, tma_size: int = 128):
|
||||
if not CudaUtils.verify_tma():
|
||||
raise RuntimeError(
|
||||
"TMA not supported on this device (requires Hopper or newer)"
|
||||
)
|
||||
if "nv_tma_desc_type" not in dir(tl):
|
||||
raise RuntimeError(
|
||||
"TMA grid constant descriptors not supported in your Triton version"
|
||||
)
|
||||
|
||||
self.tma_size = tma_size
|
||||
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
|
||||
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
|
||||
self.descriptors: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def init_tma_descriptor(self, name: str) -> None:
|
||||
self.descriptors[name] = torch.empty(
|
||||
self.tma_size, device="cpu", dtype=torch.int8
|
||||
)
|
||||
|
||||
def fill_1d_tma_descriptor(
|
||||
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
|
||||
) -> None:
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_1d_tma_descriptor_inner(
|
||||
ptr, dim, block_dim, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def fill_2d_tma_descriptor(
|
||||
self,
|
||||
name: str,
|
||||
ptr: int,
|
||||
dim1: int,
|
||||
dim0: int,
|
||||
block_dim1: int,
|
||||
block_dim0: int,
|
||||
element_size: int,
|
||||
) -> None:
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_2d_tma_descriptor_inner(
|
||||
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def get_tma_descriptor_kernel_param(
    self, name: str
) -> "TmaDescriptorHelper.KernelParamWrapper":
    """Wrap the descriptor registered as ``name`` in a ``KernelParamWrapper``.

    Raises:
        ValueError: If ``name`` is not registered or its entry is ``None``.
    """
    # ``dict.get`` folds the "missing key" and "stored None" cases together.
    descriptor = self.descriptors.get(name)
    if descriptor is None:
        raise ValueError(f"TMA descriptor '{name}' not initialized")
    return self.KernelParamWrapper(descriptor)
|
||||
|
||||
|
||||
# Candidate Triton launch configurations for the Hopper code path
# (presumably fed to ``triton.autotune`` -- confirm at the kernel definition).
# Each row: (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps).
HOPPER_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": block_k},
        num_stages=stages,
        num_warps=warps,
    )
    for block_m, block_n, block_k, stages, warps in (
        (128, 128, 32, 2, 8),
        (64, 256, 32, 3, 8),
        (256, 64, 32, 3, 8),
        (128, 64, 32, 4, 4),
        (64, 128, 32, 4, 4),
        (128, 128, 64, 3, 8),
        (64, 256, 64, 3, 8),
        (256, 64, 64, 3, 8),
        (128, 256, 64, 4, 8),
        (256, 128, 64, 4, 8),
    )
]
|
||||
|
||||
|
||||
# Candidate Triton launch configurations for the non-Hopper code path
# (presumably fed to ``triton.autotune`` -- confirm at the kernel definition).
# Each row: (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps).
STANDARD_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": block_k},
        num_stages=stages,
        num_warps=warps,
    )
    for block_m, block_n, block_k, stages, warps in (
        (64, 64, 32, 2, 4),
        (128, 64, 32, 2, 4),
        (128, 128, 32, 2, 8),
        (128, 128, 64, 3, 8),
        (128, 128, 64, 4, 8),
    )
]
|
||||
|
||||
|
||||
def early_config_prune(configs, args, **kwargs):
    """Drop configurations whose ``BLOCK_SIZE_K`` exceeds the problem's ``K``.

    ``K`` is taken from ``kwargs`` when present; otherwise from ``args`` when
    that is a dict (Triton's autotuner hands positional kernel arguments to
    the pruning hook as a name -> value dict, so a positionally-passed ``K``
    never shows up in ``kwargs``). If neither carries ``K`` it defaults to 0.

    Args:
        configs: Sequence of objects exposing a ``kwargs`` dict, e.g.
            ``triton.Config``.
        args: Named positional kernel arguments; only consulted for ``"K"``.
        **kwargs: Keyword kernel arguments; ``"K"`` here takes precedence.

    Returns:
        The configs with ``BLOCK_SIZE_K <= K`` in their original order. When
        none qualifies and ``configs`` is non-empty, a one-element list
        holding the config with the smallest ``BLOCK_SIZE_K`` so the caller
        always has something to launch.
    """
    if "K" in kwargs:
        k = kwargs["K"]
    elif isinstance(args, dict):
        k = args.get("K", 0)
    else:
        k = 0

    valid_configs = [
        config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k
    ]
    if valid_configs or not configs:
        return valid_configs

    # Nothing fits: fall back to the single smallest-K tile.
    return [
        min(
            configs,
            key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")),
        )
    ]
|
||||
@@ -190,6 +190,11 @@ class PatchManager:
|
||||
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
if self.cfg.model_config_type == "deepseek_v3":
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||
|
||||
patch_deepseek_v3_moe()
|
||||
|
||||
def _apply_fp8_patches(self):
|
||||
"""Apply patches for FP8 support."""
|
||||
if self.cfg.fp8:
|
||||
|
||||
187
src/axolotl/monkeypatch/deepseek_v3/__init__.py
Normal file
187
src/axolotl/monkeypatch/deepseek_v3/__init__.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.kernels.moe import ContiguousGroupedGEMM
|
||||
|
||||
# Default ``group_size_m`` for ``patch_deepseek_v3_moe``: per-expert token
# counts are padded up to a multiple of this before the grouped GEMM runs.
_GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
def _align_to(value: int, alignment: int) -> int:
|
||||
if value <= 0:
|
||||
return 0
|
||||
return math.ceil(value / alignment) * alignment
|
||||
|
||||
|
||||
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
||||
return hidden_states.is_cuda and hidden_states.shape[0] > 0
|
||||
|
||||
|
||||
def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
gate_weights = []
|
||||
up_weights = []
|
||||
down_weights = []
|
||||
for expert in module.experts:
|
||||
gate_weights.append(expert.gate_proj.weight)
|
||||
up_weights.append(expert.up_proj.weight)
|
||||
down_weights.append(expert.down_proj.weight)
|
||||
gate = torch.stack(gate_weights, dim=0).contiguous()
|
||||
up = torch.stack(up_weights, dim=0).contiguous()
|
||||
down = torch.stack(down_weights, dim=0).contiguous()
|
||||
return gate, up, down
|
||||
|
||||
|
||||
def _moe_triton_forward(
    module,
    hidden_states: torch.Tensor,
    topk_indices: torch.Tensor,
    topk_weights: torch.Tensor,
    group_size_m: int,
    fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Evaluate the routed experts with ``ContiguousGroupedGEMM``.

    Every token is replicated once per selected expert, the replicas are
    sorted by expert id, and each expert's rows are copied into a zero-filled
    buffer whose per-expert segment is rounded up to a multiple of
    ``group_size_m``. Gate, up and down projections each run as one
    ``ContiguousGroupedGEMM.apply`` call over that buffer. Padding rows are
    dropped again before the outputs are scattered back to token order and
    combined with ``topk_weights``.

    Args:
        module: Object exposing ``experts``, a sequence whose items carry
            ``gate_proj``/``up_proj``/``down_proj`` (each with ``.weight``)
            and ``act_fn``.
        hidden_states: 2-D tensor ``(num_tokens, hidden_dim)``.
        topk_indices: Expert id per (token, slot); the last dimension is
            ``top_k``. Presumably ``(num_tokens, top_k)`` -- TODO confirm
            against the caller.
        topk_weights: Routing weights; after ``unsqueeze(-1)`` they are
            multiplied with the ``(num_tokens, top_k, hidden_dim)`` outputs.
        group_size_m: Row alignment of each expert's segment; also forwarded
            to the GEMM.
        fallback: Called as ``fallback(hidden_states, topk_indices,
            topk_weights)`` when the input is not eligible for the Triton
            path (non-CUDA or zero rows).

    Returns:
        Tensor of shape ``(num_tokens, hidden_dim)``.
    """
    if not _is_triton_eligible(hidden_states):
        return fallback(hidden_states, topk_indices, topk_weights)

    device = hidden_states.device
    hidden_dtype = hidden_states.dtype
    num_tokens, hidden_dim = hidden_states.shape
    top_k = topk_indices.size(-1)

    # One row per (token, selected expert) pair, token-major.
    expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0)
    expert_assignments = topk_indices.reshape(-1)
    if expanded_hidden.numel() == 0:
        # torch.Tensor has no ``new_zeros_like``; ``torch.zeros_like`` is the
        # call that yields zeros with the input's shape, dtype and device.
        return torch.zeros_like(hidden_states)

    # Sort the rows so that each expert's tokens are adjacent.
    sort_perm = torch.argsort(expert_assignments)
    sorted_hidden = expanded_hidden.index_select(0, sort_perm)
    sorted_assignments = expert_assignments.index_select(0, sort_perm)

    num_experts = len(module.experts)
    counts = torch.bincount(sorted_assignments, minlength=num_experts)
    # Copy the counts to the host so the offsets below are plain Python ints.
    counts_cpu = counts.to(torch.int64).cpu().tolist()
    padded_counts = [_align_to(c, group_size_m) for c in counts_cpu]

    total_actual = sum(counts_cpu)
    total_padded = sum(padded_counts)
    if total_actual == 0 or total_padded == 0:
        return torch.zeros_like(hidden_states)

    # Prefix sums: where each expert starts in the dense (sorted) layout and
    # in the padded layout.
    actual_offsets = [0]
    padded_offsets = [0]
    for count, padded in zip(counts_cpu, padded_counts, strict=False):
        actual_offsets.append(actual_offsets[-1] + count)
        padded_offsets.append(padded_offsets[-1] + padded)

    grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
    expert_index_tensor = torch.empty(total_padded, dtype=torch.int32, device=device)

    for idx, (count, padded) in enumerate(zip(counts_cpu, padded_counts, strict=False)):
        dst_start = padded_offsets[idx]
        dst_end = dst_start + padded
        if padded == 0:
            continue
        # Every row of the segment, padding included, is tagged with its
        # expert; experts with zero tokens own no rows, so the whole index
        # tensor ends up written.
        expert_index_tensor[dst_start:dst_end] = idx
        if count > 0:
            src_start = actual_offsets[idx]
            src_end = src_start + count
            grouped_hidden[dst_start : dst_start + count].copy_(
                sorted_hidden[src_start:src_end]
            )

    gate_weights, up_weights, down_weights = _collect_expert_weights(module)

    gate_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        gate_weights,
        expert_index_tensor,
        group_size_m,
    )
    up_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        up_weights,
        expert_index_tensor,
        group_size_m,
    )

    # The first expert's activation is applied to all experts' rows.
    # NOTE(review): assumes every expert shares the same ``act_fn`` -- confirm.
    act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn

    # Gated activation on the real (non-padding) rows only.
    hidden_chunks = []
    for idx, count in enumerate(counts_cpu):
        if count == 0:
            continue
        pad_start = padded_offsets[idx]
        pad_end = pad_start + count
        gate_slice = gate_out[pad_start:pad_end].to(hidden_dtype)
        up_slice = up_out[pad_start:pad_end].to(hidden_dtype)
        hidden_chunks.append(act_fn(gate_slice) * up_slice)

    hidden_concat = torch.cat(hidden_chunks, dim=0)

    # Re-pad the activations into the same segment layout for the down GEMM.
    intermediate_dim = hidden_concat.shape[-1]
    hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim))

    for idx, count in enumerate(counts_cpu):
        if count == 0:
            continue
        pad_start = padded_offsets[idx]
        src_start = actual_offsets[idx]
        src_end = src_start + count
        hidden_grouped[pad_start : pad_start + count].copy_(
            hidden_concat[src_start:src_end]
        )

    down_out = ContiguousGroupedGEMM.apply(
        hidden_grouped,
        down_weights,
        expert_index_tensor,
        group_size_m,
    )

    # Strip padding again; rows are back in sorted (dense) order afterwards.
    down_chunks = []
    for idx, count in enumerate(counts_cpu):
        if count == 0:
            continue
        pad_start = padded_offsets[idx]
        pad_end = pad_start + count
        down_chunks.append(down_out[pad_start:pad_end].to(hidden_dtype))

    down_concat = torch.cat(down_chunks, dim=0)

    # ``sort_perm`` is a permutation of all rows, so index_copy_ overwrites
    # every row of the uninitialised buffer while undoing the sort.
    expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
    expanded_output.index_copy_(0, sort_perm, down_concat.to(hidden_dtype))
    expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)

    weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
    return weighted.sum(dim=1)
|
||||
|
||||
|
||||
def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
    """Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels.

    Replaces ``DeepseekV3MoE.moe`` on the class with a wrapper that first
    tries ``_moe_triton_forward`` and uses the stock implementation when the
    input is not eligible or the Triton path raises ``RuntimeError``. The
    class is tagged with ``_axolotl_triton_patch`` so repeated calls are
    no-ops.

    Args:
        group_size_m: Row alignment forwarded to ``_moe_triton_forward``.
    """

    from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

    if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
        return

    # Plain function taken from the class: it still expects ``self`` first.
    original_moe = DeepseekV3MoE.moe

    def patched_moe(self, hidden_states, topk_indices, topk_weights):
        def stock_moe(states, indices, weights):
            # ``_moe_triton_forward`` invokes its fallback with three tensors
            # only, so the module instance has to be bound here.
            return original_moe(self, states, indices, weights)

        # Deliberate best effort: a RuntimeError on the Triton path is
        # swallowed and the stock implementation is used instead.
        with contextlib.suppress(RuntimeError):
            return _moe_triton_forward(
                self,
                hidden_states,
                topk_indices,
                topk_weights,
                group_size_m,
                stock_moe,
            )
        return original_moe(self, hidden_states, topk_indices, topk_weights)

    DeepseekV3MoE.moe = patched_moe
    DeepseekV3MoE._axolotl_triton_patch = True
|
||||
Reference in New Issue
Block a user