diff --git a/scripts/benchmarks/deepseek_v3_moe.py b/scripts/benchmarks/deepseek_v3_moe.py new file mode 100644 index 000000000..3d309e723 --- /dev/null +++ b/scripts/benchmarks/deepseek_v3_moe.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python +"""Microbenchmark for DeepSeek V3 MoE block comparing baseline vs Triton CG kernels. + +Example usage (run from project root): + + PYTHONPATH=./src:../transformers/src \ + python scripts/benchmarks/deepseek_v3_moe.py --device cuda --iters 20 +""" + +from __future__ import annotations + +import argparse +import time +from types import MethodType + +import torch + +try: + from transformers.models.deepseek_v3.configuration_deepseek_v3 import ( + DeepseekV3Config, + ) + from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE +except ImportError as exc: # pragma: no cover - utility script + raise SystemExit( + "Transformers with DeepSeek-V3 support must be available in PYTHONPATH" + ) from exc + +from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe + +DTYPE_MAP = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--batch", type=int, default=2, help="batch size") + parser.add_argument("--seq-len", type=int, default=256, help="sequence length") + parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size") + parser.add_argument( + "--moe-intermediate-size", + type=int, + default=8192, + help="MoE intermediate projection size", + ) + parser.add_argument( + "--n-experts", + type=int, + default=64, + help="Number of routed experts", + ) + parser.add_argument( + "--top-k", + type=int, + default=4, + help="Number of experts per token", + ) + parser.add_argument( + "--groups", + type=int, + default=8, + help="Router groups (must divide n-experts)", + ) + parser.add_argument( + "--dtype", + choices=DTYPE_MAP.keys(), + default="bf16", + help="Computation dtype", + ) + parser.add_argument( + "--device", + default="auto", + choices=["auto", "cpu", "cuda"], + help="Execution device", + ) + parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations") + parser.add_argument("--iters", type=int, default=25, help="Benchmark iterations") + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--group-size", + type=int, + default=128, + help="GROUP_SIZE_M used by the Triton kernel", + ) + return parser.parse_args() + + +def resolve_device(requested: str) -> torch.device: + if requested == "auto": + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + return torch.device(requested) + + +def build_module(args: argparse.Namespace) -> DeepseekV3MoE: + config = DeepseekV3Config( + hidden_size=args.hidden_size, + intermediate_size=args.moe_intermediate_size, + moe_intermediate_size=args.moe_intermediate_size, + n_routed_experts=args.n_experts, + num_experts_per_tok=args.top_k, + n_group=args.groups, + topk_group=max(1, min(args.groups, args.top_k)), + n_shared_experts=1, + ) + module = DeepseekV3MoE(config) + module.eval() + return module + + +@torch.no_grad() +def benchmark( + module: DeepseekV3MoE, inputs: torch.Tensor, iters: int, warmup: int +) -> float: + for _ in range(warmup): + module(inputs) + if inputs.is_cuda: + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(iters): + module(inputs) + if inputs.is_cuda: + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + return (elapsed / iters) * 1000.0 + + +def main() -> None: # pragma: no cover - CLI entrypoint + args = parse_args() + torch.manual_seed(args.seed) + + device = resolve_device(args.device) + dtype = DTYPE_MAP[args.dtype] + + if args.n_experts % args.groups != 0: + raise SystemExit("n-experts must be divisible by groups") + if args.top_k > args.n_experts: + raise SystemExit("top-k cannot exceed number of experts") + + if device.type == "cuda" and not torch.cuda.is_available(): + raise SystemExit("CUDA requested but not available") + + baseline_module = build_module(args) + original_moe = DeepseekV3MoE.moe + baseline_module.moe = MethodType(original_moe, baseline_module) + state_dict = baseline_module.state_dict() + + patch_deepseek_v3_moe(group_size_m=args.group_size) + patched_module = build_module(args) + patched_module.load_state_dict(state_dict) + + baseline_module.to(device=device, dtype=dtype) + patched_module.to(device=device, dtype=dtype) + + inputs = torch.randn( + args.batch, + args.seq_len, + args.hidden_size, + device=device, + dtype=dtype, + ) + + with torch.no_grad(): + ref_output = baseline_module(inputs) + patched_output = patched_module(inputs) + max_diff = (ref_output - patched_output).abs().max().item() + + baseline_ms = benchmark(baseline_module, inputs, args.iters, args.warmup) + patched_ms = benchmark(patched_module, inputs, args.iters, args.warmup) + + speedup = baseline_ms / patched_ms if patched_ms > 0 else float("nan") + + print( + f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}" + ) + print( + f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}" + ) + print(f"Max |Δ| between outputs: {max_diff:.2e}") + + +if __name__ == "__main__": + main() diff --git a/src/axolotl/kernels/moe/__init__.py b/src/axolotl/kernels/moe/__init__.py new file mode 100644 index 000000000..362b2cdb4 --- /dev/null +++ b/src/axolotl/kernels/moe/__init__.py @@ -0,0 +1,17 @@ +"""Mixture-of-Experts kernel implementations.""" + +from .tt_cg_gemm import ( + ContiguousGroupedGEMM, + ContiguousGroupedGEMMForwardOnly, + cg_grouped_gemm, + cg_grouped_gemm_forward, + cg_grouped_gemm_forward_dynamic, +) + +__all__ = [ + "cg_grouped_gemm", + "cg_grouped_gemm_forward", + "cg_grouped_gemm_forward_dynamic", + "ContiguousGroupedGEMM", + "ContiguousGroupedGEMMForwardOnly", +] diff --git a/src/axolotl/kernels/moe/tt_cg_gemm/__init__.py b/src/axolotl/kernels/moe/tt_cg_gemm/__init__.py new file mode 100644 index 000000000..64ea4b8ea --- /dev/null +++ b/src/axolotl/kernels/moe/tt_cg_gemm/__init__.py @@ -0,0 +1,17 @@ +"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan.""" + +from .cg_backward import ContiguousGroupedGEMM +from .cg_forward import ( + ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly, + cg_grouped_gemm, + cg_grouped_gemm_forward, + cg_grouped_gemm_forward_dynamic, +) + +__all__ = [ + "cg_grouped_gemm", + "cg_grouped_gemm_forward", + "cg_grouped_gemm_forward_dynamic", + "ContiguousGroupedGEMM", + "ContiguousGroupedGEMMForwardOnly", +] diff --git a/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py b/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py new file mode 100644 index 000000000..eec527ef9 --- /dev/null +++ b/src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py @@ -0,0 +1,290 @@ +"""Vendored backward pass for Triton contiguous grouped GEMM.""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import triton +import triton.language as tl + +from .cg_forward import cg_grouped_gemm_forward +from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune + +GROUP_SIZE_M = 128 + + +@triton.autotune( + configs=STANDARD_CONFIGS, + key=["M_TOTAL", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_cg_backward_dx( + grad_output_ptr, + b_ptr, + grad_input_ptr, + indices_ptr, + M_TOTAL: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + NUM_EXPERTS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M, +): + """Compute gradients with respect to inputs.""" + + pid = tl.program_id(0) + + num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M) + num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + + tile_m = pid // num_k_tiles + tile_k = pid % num_k_tiles + + m_start = tile_m * BLOCK_SIZE_M + k_start = tile_k * BLOCK_SIZE_K + + if m_start < M_TOTAL: + offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start + offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start + + mask_m = offs_m < M_TOTAL + mask_k = offs_k < K + + group_idx = m_start // GROUP_SIZE_M + expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M) + + grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32) + + for n in range(0, N, BLOCK_SIZE_N): + offs_n = tl.arange(0, BLOCK_SIZE_N) + n + mask_n = offs_n < N + + mask_go = mask_m[:, None] & mask_n[None, :] + mask_w = mask_n[:, None] & mask_k[None, :] + + go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :] + go = tl.load(go_ptrs, mask=mask_go, other=0.0) + + w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :] + w = tl.load(w_ptrs, mask=mask_w, other=0.0) + + grad_input += tl.dot(go, w) + + grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :] + mask_gi = mask_m[:, None] & mask_k[None, :] + tl.store(grad_input_ptrs, grad_input, mask=mask_gi) + + +@triton.jit +def _kernel_cg_backward_dw( + grad_output_ptr, + inputs_ptr, + grad_weights_ptr, + indices_ptr, + M_TOTAL: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + NUM_EXPERTS: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, +): + """Simplified kernel for expert weight gradients.""" + + pid = tl.program_id(0) + + expert_id = pid // ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K)) + position_id = pid % ((N * K) // (BLOCK_SIZE_N * BLOCK_SIZE_K)) + + if expert_id < NUM_EXPERTS: + n_tiles = K // BLOCK_SIZE_K + tile_n = position_id // n_tiles + tile_k = position_id % n_tiles + + n_start = tile_n * BLOCK_SIZE_N + k_start = tile_k * BLOCK_SIZE_K + + if n_start < N and k_start < K: + offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start + offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start + + mask_n = offs_n < N + mask_k = offs_k < K + + grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32) + + for group_idx in range(0, M_TOTAL // GROUP_SIZE_M): + group_start = group_idx * GROUP_SIZE_M + group_expert = tl.load(indices_ptr + group_start) + + if group_expert == expert_id: + for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M): + m_start = group_start + m_offset + offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start + + mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL) + + go_ptrs = ( + grad_output_ptr + offs_m[:, None] * N + offs_n[None, :] + ) + mask_go = mask_m[:, None] & mask_n[None, :] + go = tl.load(go_ptrs, mask=mask_go, other=0.0) + + in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :] + mask_in = mask_m[:, None] & mask_k[None, :] + inp = tl.load(in_ptrs, mask=mask_in, other=0.0) + + go_t = tl.trans(go) + grad_weights += tl.dot(go_t, inp) + + grad_w_ptrs = ( + grad_weights_ptr + + expert_id * N * K + + offs_n[:, None] * K + + offs_k[None, :] + ) + mask_gw = mask_n[:, None] & mask_k[None, :] + tl.store(grad_w_ptrs, grad_weights, mask=mask_gw) + + +def cg_grouped_gemm_backward_weights( + grad_output: torch.Tensor, + inputs: torch.Tensor, + expert_indices: torch.Tensor, + num_experts: int, + group_size_m: int = GROUP_SIZE_M, +) -> torch.Tensor: + """Backward pass for expert weights.""" + + assert grad_output.is_contiguous(), "Grad output tensor must be contiguous" + assert inputs.is_contiguous(), "Inputs tensor must be contiguous" + assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous" + + M_total, N = grad_output.shape + _, K = inputs.shape + + if expert_indices.dtype != torch.int32: + expert_indices = expert_indices.to(torch.int32) + + grad_weights = torch.zeros( + (num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype + ) + + block_size_n = min(128, N) + block_size_k = min(32, K) + block_size_m = min(32, group_size_m) + + n_tiles = triton.cdiv(N, block_size_n) + k_tiles = triton.cdiv(K, block_size_k) + grid = (num_experts * n_tiles * k_tiles,) + + _kernel_cg_backward_dw[grid]( + grad_output, + inputs, + grad_weights, + expert_indices, + M_TOTAL=M_total, + N=N, + K=K, + NUM_EXPERTS=num_experts, + GROUP_SIZE_M=group_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + BLOCK_SIZE_M=block_size_m, + ) + + return grad_weights + + +def cg_grouped_gemm_backward_inputs( + grad_output: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + group_size_m: int = GROUP_SIZE_M, +) -> torch.Tensor: + """Backward pass for inputs.""" + + assert grad_output.is_contiguous(), "Grad output tensor must be contiguous" + assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous" + assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous" + + M_total, N = grad_output.shape + num_experts, _, K = expert_weights.shape + + assert M_total % group_size_m == 0, ( + f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})" + ) + + grad_inputs = torch.zeros( + (M_total, K), device=grad_output.device, dtype=grad_output.dtype + ) + + grid = lambda meta: ( + triton.cdiv(M_total, meta["BLOCK_SIZE_M"]) + * triton.cdiv(K, meta["BLOCK_SIZE_K"]), + ) + + _kernel_cg_backward_dx[grid]( + grad_output, + expert_weights, + grad_inputs, + expert_indices, + M_TOTAL=M_total, + N=N, + K=K, + NUM_EXPERTS=num_experts, + GROUP_SIZE_M=group_size_m, + ) + + return grad_inputs + + +class ContiguousGroupedGEMM(torch.autograd.Function): + """Autograd function with full backward support.""" + + @staticmethod + def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M): + ctx.save_for_backward(inputs, expert_weights, expert_indices) + ctx.group_size_m = group_size_m + + return cg_grouped_gemm_forward( + inputs=inputs, + expert_weights=expert_weights, + expert_indices=expert_indices, + group_size_m=group_size_m, + ) + + @staticmethod + def backward(ctx, grad_output): + inputs, expert_weights, expert_indices = ctx.saved_tensors + group_size_m = ctx.group_size_m + + grad_output = grad_output.contiguous() + num_experts = expert_weights.shape[0] + + grad_inputs = cg_grouped_gemm_backward_inputs( + grad_output=grad_output, + expert_weights=expert_weights, + expert_indices=expert_indices, + group_size_m=group_size_m, + ) + + grad_weights = cg_grouped_gemm_backward_weights( + grad_output=grad_output, + inputs=inputs, + expert_indices=expert_indices, + num_experts=num_experts, + group_size_m=group_size_m, + ) + + grad_indices = None + grad_group_size_m = None + + return grad_inputs, grad_weights, grad_indices, grad_group_size_m diff --git a/src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py b/src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py new file mode 100644 index 000000000..df0c00947 --- /dev/null +++ b/src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py @@ -0,0 +1,311 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Vendored forward Triton contiguous grouped GEMM kernels.""" + +import torch +import triton +import triton.language as tl + +from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune + +GROUP_SIZE_M = 128 + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * super_group_m + group_size_m = min(num_pid_m - first_pid_m, super_group_m) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton.autotune( + configs=STANDARD_CONFIGS, + key=["M_TOTAL", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_cg_persistent_forward( + a_ptr, + b_ptr, + c_ptr, + indices_ptr, + M_TOTAL: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + NUM_EXPERTS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + NUM_SMS: tl.constexpr, + GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M, + SUPER_GROUP_M: tl.constexpr = 32, +): + """ + Contiguous Grouped GEMM kernel forward (persistent variant). + """ + + c_type = c_ptr.dtype.element_ty + + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + tile_id_c = start_pid - NUM_SMS + num_pid_in_group = SUPER_GROUP_M * num_pid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): + tile_m_idx, tile_n_idx = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M + ) + + m_start = tile_m_idx * BLOCK_SIZE_M + n_start = tile_n_idx * BLOCK_SIZE_N + + if m_start < M_TOTAL: + offs_m = m_start + tl.arange(0, BLOCK_SIZE_M) + offs_n = n_start + tl.arange(0, BLOCK_SIZE_N) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + + mask_m = offs_m < M_TOTAL + mask_n = offs_n < N + mask_k = offs_k < K + + mask_a = mask_m[:, None] & mask_k[None, :] + mask_b = mask_n[:, None] & mask_k[None, :] + + group_idx = m_start // GROUP_SIZE_M + expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M) + + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + a = tl.load(a_ptrs, mask=mask_a, other=0.0) + + b_ptrs = ( + b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :] + ) + b = tl.load(b_ptrs, mask=mask_b, other=0.0) + + accumulator += tl.dot(a, b.T) + + tile_id_c += NUM_SMS + tile_m_idx, tile_n_idx = _compute_pid( + tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M + ) + + offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + mask_m = offs_m < M_TOTAL + mask_n = offs_n < N + mask_c = mask_m[:, None] & mask_n[None, :] + + c = accumulator.to(tl.float32) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + tl.store(c_ptrs, c.to(c_type), mask=mask_c) + + +@triton.autotune( + configs=STANDARD_CONFIGS, + key=["M_TOTAL", "N", "K"], + prune_configs_by={"early_config_prune": early_config_prune}, +) +@triton.jit +def _kernel_cg_forward_aligned( + a_ptr, + b_ptr, + c_ptr, + indices_ptr, + M_TOTAL: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + NUM_EXPERTS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M, +): + """ + Contiguous Grouped GEMM kernel forward for aligned inputs. + """ + + pid = tl.program_id(0) + + c_type = c_ptr.dtype.element_ty + + num_m_tiles = tl.cdiv(M_TOTAL, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N) + + tile_m = pid // num_n_tiles + tile_n = pid % num_n_tiles + + m_start = tile_m * BLOCK_SIZE_M + n_start = tile_n * BLOCK_SIZE_N + + if m_start < M_TOTAL: + offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start + offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start + + mask_m = offs_m < M_TOTAL + mask_n = offs_n < N + + group_idx = m_start // GROUP_SIZE_M + expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M) + + acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32) + + for k in range(0, K, BLOCK_SIZE_K): + offs_k = tl.arange(0, BLOCK_SIZE_K) + k + mask_k = offs_k < K + + mask_a = mask_m[:, None] & mask_k[None, :] + mask_b = mask_n[:, None] & mask_k[None, :] + + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + a = tl.load(a_ptrs, mask=mask_a, other=0.0) + + b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :] + b = tl.load(b_ptrs, mask=mask_b, other=0.0) + + acc += tl.dot(a, b.T) + + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask_c = mask_m[:, None] & mask_n[None, :] + tl.store(c_ptrs, acc.to(c_type), mask=mask_c) + + +def cg_grouped_gemm_forward( + inputs: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + group_size_m: int = GROUP_SIZE_M, +) -> torch.Tensor: + """Contiguous grouped GEMM forward pass for MoE.""" + + assert inputs.is_contiguous(), "Input tensor must be contiguous" + assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous" + assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous" + + M_total, K = inputs.shape + assert M_total % group_size_m == 0, ( + f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})" + ) + + if expert_indices.dtype != torch.int32: + expert_indices = expert_indices.to(torch.int32) + + num_experts, N, K_weights = expert_weights.shape + assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})" + assert expert_indices.shape[0] == M_total, ( + "Expert indices length must match M_total" + ) + + output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16) + + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + grid = (NUM_SMS, 1, 1) + _kernel_cg_persistent_forward[grid]( + inputs, + expert_weights, + output, + expert_indices, + M_TOTAL=M_total, + N=N, + K=K, + NUM_EXPERTS=num_experts, + GROUP_SIZE_M=group_size_m, + NUM_SMS=NUM_SMS, + ) + + return output + + +def cg_grouped_gemm_forward_dynamic( + inputs: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + group_size_m: int = GROUP_SIZE_M, +) -> torch.Tensor: + """Contiguous grouped GEMM forward pass for MoE with autotuned launch.""" + + assert inputs.is_contiguous(), "Input tensor must be contiguous" + assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous" + assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous" + + M_total, K = inputs.shape + assert M_total % group_size_m == 0, ( + f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})" + ) + + if expert_indices.dtype != torch.int32: + expert_indices = expert_indices.to(torch.int32) + + num_experts, N, K_weights = expert_weights.shape + assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})" + assert expert_indices.shape[0] == M_total, ( + "Expert indices length must match M_total" + ) + + output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype) + + grid = lambda meta: ( + triton.cdiv(M_total, meta["BLOCK_SIZE_M"]) + * triton.cdiv(N, meta["BLOCK_SIZE_N"]), + ) + + _kernel_cg_forward_aligned[grid]( + inputs, + expert_weights, + output, + expert_indices, + M_TOTAL=M_total, + N=N, + K=K, + NUM_EXPERTS=num_experts, + GROUP_SIZE_M=group_size_m, + ) + + return output + + +class ContiguousGroupedGEMM(torch.autograd.Function): + """Autograd function for contiguous grouped GEMM forward pass only.""" + + @staticmethod + def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M): + return cg_grouped_gemm_forward( + inputs=inputs, + expert_weights=expert_weights, + expert_indices=expert_indices, + group_size_m=group_size_m, + ) + + @staticmethod + def backward(ctx, grad_output): # pragma: no cover - not implemented + raise NotImplementedError("Backward pass not implemented") + + +def cg_grouped_gemm( + inputs: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + group_size_m: int = GROUP_SIZE_M, +) -> torch.Tensor: + """Convenience wrapper for the forward-only autograd function.""" + + if expert_indices.dtype != torch.int32: + expert_indices = expert_indices.to(torch.int32) + + return ContiguousGroupedGEMM.apply( + inputs, expert_weights, expert_indices, group_size_m + ) diff --git a/src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py b/src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py new file mode 100644 index 000000000..2328940ea --- /dev/null +++ b/src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py @@ -0,0 +1,31 @@ +"""Reference implementation for contiguous grouped GEMM.""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def pytorch_reference( + inputs: torch.Tensor, + expert_weights: torch.Tensor, + expert_indices: torch.Tensor, + group_size_m: int = 128, +) -> torch.Tensor: + """Simple PyTorch implementation for verification.""" + + M_total, K = inputs.shape + num_experts, N, _ = expert_weights.shape + + output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype) + + for i in range(0, M_total, group_size_m): + end_idx = min(i + group_size_m, M_total) + expert_idx = expert_indices[i].item() + expert_weight = expert_weights[expert_idx] + output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T) + + return output diff --git a/src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py b/src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py new file mode 100644 index 000000000..168d5dd60 --- /dev/null +++ b/src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py @@ -0,0 +1,209 @@ +"""Autotuning utilities for Triton contiguous grouped GEMM kernels.""" + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import torch +import triton +import triton.language as tl +from triton.runtime import driver + + +class CudaUtils: + """Helper utilities for CUDA specific Triton features.""" + + @staticmethod + def is_cuda() -> bool: + return driver.active.get_current_target().backend == "cuda" + + @staticmethod + def verify_tma() -> bool: + return ( + CudaUtils.is_cuda() + and torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 9 + ) + + @staticmethod + def get_num_sms() -> int: + if not CudaUtils.is_cuda(): + raise RuntimeError("Triton is not running on CUDA backend") + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + return torch.cuda.get_device_properties("cuda").multi_processor_count + + +class TmaDescriptorHelper: + """Helper class for managing TMA descriptors in Triton kernels.""" + + class KernelParamWrapper: + def __init__(self, desc: torch.Tensor): + self.desc = desc + + def tma_desc_cpu_ptr(self) -> int: + return self.desc.data_ptr() + + def __init__(self, tma_size: int = 128): + if not CudaUtils.verify_tma(): + raise RuntimeError( + "TMA not supported on this device (requires Hopper or newer)" + ) + if "nv_tma_desc_type" not in dir(tl): + raise RuntimeError( + "TMA grid constant descriptors not supported in your Triton version" + ) + + self.tma_size = tma_size + self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor + self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor + self.descriptors: Dict[str, torch.Tensor] = {} + + def init_tma_descriptor(self, name: str) -> None: + self.descriptors[name] = torch.empty( + self.tma_size, device="cpu", dtype=torch.int8 + ) + + def fill_1d_tma_descriptor( + self, name: str, ptr: int, dim: int, block_dim: int, element_size: int + ) -> None: + if name not in self.descriptors: + raise ValueError(f"TMA descriptor '{name}' not initialized") + + desc_x = self.descriptors[name] + if desc_x.data_ptr() % 64 != 0: + raise ValueError("TMA descriptor must be 64-byte aligned") + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, desc_x.data_ptr() + ) + + def fill_2d_tma_descriptor( + self, + name: str, + ptr: int, + dim1: int, + dim0: int, + block_dim1: int, + block_dim0: int, + element_size: int, + ) -> None: + if name not in self.descriptors: + raise ValueError(f"TMA descriptor '{name}' not initialized") + + desc_x = self.descriptors[name] + if desc_x.data_ptr() % 64 != 0: + raise ValueError("TMA descriptor must be 64-byte aligned") + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() + ) + + def get_tma_descriptor_kernel_param( + self, name: str + ) -> "TmaDescriptorHelper.KernelParamWrapper": + if name not in self.descriptors or self.descriptors[name] is None: + raise ValueError(f"TMA descriptor '{name}' not initialized") + return self.KernelParamWrapper(self.descriptors[name]) + + +HOPPER_CONFIGS = [ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=4, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, + num_stages=4, + num_warps=8, + ), +] + + +STANDARD_CONFIGS = [ + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=2, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=2, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, + num_stages=4, + num_warps=8, + ), +] + + +def early_config_prune(configs, args, **kwargs): + """Filter out configurations that would exceed shared memory capacity.""" + k = kwargs.get("K", 0) + valid_configs = [ + config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k + ] + if not valid_configs and configs: + return [ + min( + configs, + key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")), + ) + ] + + return valid_configs diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 1e46f5c34..dafa8a28c 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -190,6 +190,11 @@ class PatchManager: apply_mistral_tokenizer_image_patch() + if self.cfg.model_config_type == "deepseek_v3": + from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe + + patch_deepseek_v3_moe() + def _apply_fp8_patches(self): """Apply patches for FP8 support.""" if self.cfg.fp8: diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py new file mode 100644 index 000000000..d32793656 --- /dev/null +++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py @@ -0,0 +1,187 @@ +"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels.""" + +from __future__ import annotations + +import contextlib +import math +from typing import Callable + +import torch + +from axolotl.kernels.moe import ContiguousGroupedGEMM + +_GROUP_SIZE_M = 128 + + +def _align_to(value: int, alignment: int) -> int: + if value <= 0: + return 0 + return math.ceil(value / alignment) * alignment + + +def _is_triton_eligible(hidden_states: torch.Tensor) -> bool: + return hidden_states.is_cuda and hidden_states.shape[0] > 0 + + +def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + gate_weights = [] + up_weights = [] + down_weights = [] + for expert in module.experts: + gate_weights.append(expert.gate_proj.weight) + up_weights.append(expert.up_proj.weight) + down_weights.append(expert.down_proj.weight) + gate = torch.stack(gate_weights, dim=0).contiguous() + up = torch.stack(up_weights, dim=0).contiguous() + down = torch.stack(down_weights, dim=0).contiguous() + return gate, up, down + + +def _moe_triton_forward( + module, + hidden_states: torch.Tensor, + topk_indices: torch.Tensor, + topk_weights: torch.Tensor, + group_size_m: int, + fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], +) -> torch.Tensor: + if not _is_triton_eligible(hidden_states): + return fallback(hidden_states, topk_indices, topk_weights) + + device = hidden_states.device + hidden_dtype = hidden_states.dtype + num_tokens, hidden_dim = hidden_states.shape + top_k = topk_indices.size(-1) + + expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0) + expert_assignments = topk_indices.reshape(-1) + if expanded_hidden.numel() == 0: + return hidden_states.new_zeros_like(hidden_states) + + sort_perm = torch.argsort(expert_assignments) + sorted_hidden = expanded_hidden.index_select(0, sort_perm) + sorted_assignments = expert_assignments.index_select(0, sort_perm) + + num_experts = len(module.experts) + counts = torch.bincount(sorted_assignments, minlength=num_experts) + counts_cpu = counts.to(torch.int64).cpu().tolist() + padded_counts = [_align_to(c, group_size_m) for c in counts_cpu] + + total_actual = sum(counts_cpu) + total_padded = sum(padded_counts) + if total_actual == 0 or total_padded == 0: + return hidden_states.new_zeros_like(hidden_states) + + actual_offsets = [0] + padded_offsets = [0] + for count, padded in zip(counts_cpu, padded_counts, strict=False): + actual_offsets.append(actual_offsets[-1] + count) + padded_offsets.append(padded_offsets[-1] + padded) + + grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim)) + expert_index_tensor = torch.empty(total_padded, dtype=torch.int32, device=device) + + for idx, (count, padded) in enumerate(zip(counts_cpu, padded_counts, strict=False)): + dst_start = padded_offsets[idx] + dst_end = dst_start + padded + if padded == 0: + continue + expert_index_tensor[dst_start:dst_end] = idx + if count > 0: + src_start = actual_offsets[idx] + src_end = src_start + count + grouped_hidden[dst_start : dst_start + count].copy_( + sorted_hidden[src_start:src_end] + ) + + gate_weights, up_weights, down_weights = _collect_expert_weights(module) + + gate_out = ContiguousGroupedGEMM.apply( + grouped_hidden, + gate_weights, + expert_index_tensor, + group_size_m, + ) + up_out = ContiguousGroupedGEMM.apply( + grouped_hidden, + up_weights, + expert_index_tensor, + group_size_m, + ) + + act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn + + hidden_chunks = [] + for idx, count in enumerate(counts_cpu): + if count == 0: + continue + pad_start = padded_offsets[idx] + pad_end = pad_start + count + gate_slice = gate_out[pad_start:pad_end].to(hidden_dtype) + up_slice = up_out[pad_start:pad_end].to(hidden_dtype) + hidden_chunks.append(act_fn(gate_slice) * up_slice) + + hidden_concat = torch.cat(hidden_chunks, dim=0) + + intermediate_dim = hidden_concat.shape[-1] + hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim)) + + for idx, count in enumerate(counts_cpu): + if count == 0: + continue + pad_start = padded_offsets[idx] + src_start = actual_offsets[idx] + src_end = src_start + count + hidden_grouped[pad_start : pad_start + count].copy_( + hidden_concat[src_start:src_end] + ) + + down_out = ContiguousGroupedGEMM.apply( + hidden_grouped, + down_weights, + expert_index_tensor, + group_size_m, + ) + + down_chunks = [] + for idx, count in enumerate(counts_cpu): + if count == 0: + continue + pad_start = padded_offsets[idx] + pad_end = pad_start + count + down_chunks.append(down_out[pad_start:pad_end].to(hidden_dtype)) + + down_concat = torch.cat(down_chunks, dim=0) + + expanded_output = expanded_hidden.new_empty(expanded_hidden.shape) + expanded_output.index_copy_(0, sort_perm, down_concat.to(hidden_dtype)) + expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim) + + weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype) + return weighted.sum(dim=1) + + +def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None: + """Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels.""" + + from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE + + if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False): + return + + original_moe = DeepseekV3MoE.moe + + def patched_moe(self, hidden_states, topk_indices, topk_weights): + with contextlib.suppress(RuntimeError): + return _moe_triton_forward( + self, + hidden_states, + topk_indices, + topk_weights, + group_size_m, + original_moe, + ) + return original_moe(self, hidden_states, topk_indices, topk_weights) + + DeepseekV3MoE.moe = patched_moe + DeepseekV3MoE._axolotl_triton_patch = True