add mg kernel backend

This commit is contained in:
Dan Saunders
2025-09-23 15:43:16 -04:00
parent 8a1f5ae940
commit d0da67eb17
9 changed files with 1753 additions and 77 deletions

View File

@@ -32,7 +32,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
else: # pragma: no cover - defensive guard else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports") raise SystemExit("Unable to locate axolotl repository root for imports")
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
ACCURACY_TOLERANCE = 5e-3 ACCURACY_TOLERANCE = 5e-3
@@ -98,6 +98,12 @@ def parse_args() -> argparse.Namespace:
default=128, default=128,
help="GROUP_SIZE_M used by the Triton kernel", help="GROUP_SIZE_M used by the Triton kernel",
) )
parser.add_argument(
"--backend",
choices=["cg", "mg"],
default="mg",
help="MoE kernel backend to benchmark",
)
return parser.parse_args() return parser.parse_args()
@@ -163,7 +169,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
baseline_module.moe = MethodType(original_moe, baseline_module) baseline_module.moe = MethodType(original_moe, baseline_module)
state_dict = baseline_module.state_dict() state_dict = baseline_module.state_dict()
patch_deepseek_v3_moe(group_size_m=args.group_size) patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
patched_module = build_module(args) patched_module = build_module(args)
patched_module.load_state_dict(state_dict) patched_module.load_state_dict(state_dict)
@@ -250,6 +256,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
return { return {
"device": device, "device": device,
"backend": args.backend,
"dtype": dtype, "dtype": dtype,
"baseline_ms": baseline_ms, "baseline_ms": baseline_ms,
"patched_ms": patched_ms, "patched_ms": patched_ms,
@@ -270,7 +277,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
result = benchmark_deepseek_v3(args) result = benchmark_deepseek_v3(args)
print( print(
f"Device={result['device'].type} dtype={result['dtype']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}" f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
) )
print( print(
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}" f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"

View File

@@ -21,7 +21,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
else: # pragma: no cover - defensive guard else: # pragma: no cover - defensive guard
raise SystemExit("Unable to locate axolotl repository root for imports") raise SystemExit("Unable to locate axolotl repository root for imports")
from scripts.benchmarks.deepseek_v3_moe import ( from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
ACCURACY_TOLERANCE, ACCURACY_TOLERANCE,
DTYPE_MAP, DTYPE_MAP,
benchmark_deepseek_v3, benchmark_deepseek_v3,
@@ -42,6 +42,12 @@ def parse_args() -> argparse.Namespace:
choices=["auto", "cpu", "cuda"], choices=["auto", "cpu", "cuda"],
help="Execution device", help="Execution device",
) )
parser.add_argument(
"--backend",
choices=["cg", "mg"],
default="mg",
help="MoE kernel backend to benchmark",
)
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations") parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations") parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
parser.add_argument("--seed", type=int, default=0, help="Random seed") parser.add_argument("--seed", type=int, default=0, help="Random seed")
@@ -105,6 +111,7 @@ def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace:
{ {
"dtype": args.dtype, "dtype": args.dtype,
"device": args.device, "device": args.device,
"backend": args.backend,
"warmup": args.warmup, "warmup": args.warmup,
"iters": args.iters, "iters": args.iters,
"seed": args.seed, "seed": args.seed,
@@ -164,6 +171,7 @@ def main() -> None: # pragma: no cover - utility script
"n_experts", "n_experts",
"top_k", "top_k",
"groups", "groups",
"backend",
"baseline_ms", "baseline_ms",
"patched_ms", "patched_ms",
"speedup", "speedup",
@@ -177,10 +185,10 @@ def main() -> None: # pragma: no cover - utility script
rows = [] rows = []
print( print(
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}" f"Running sweep on device={args.device} dtype={args.dtype} backend={args.backend} uniform_routing={args.uniform_routing}"
) )
print( print(
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}" f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}" f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
) )
@@ -206,6 +214,7 @@ def main() -> None: # pragma: no cover - utility script
cfg["n_experts"], cfg["n_experts"],
cfg["top_k"], cfg["top_k"],
cfg["groups"], cfg["groups"],
args.backend,
result["baseline_ms"], result["baseline_ms"],
result["patched_ms"], result["patched_ms"],
result["speedup"], result["speedup"],
@@ -219,7 +228,7 @@ def main() -> None: # pragma: no cover - utility script
) )
status = "OK" if result["accuracy_ok"] else "FAIL" status = "OK" if result["accuracy_ok"] else "FAIL"
print( print(
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}" f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {args.backend:>8}"
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x" f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}" f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
) )

View File

@@ -8,6 +8,7 @@ from .tt_cg_gemm import (
cg_grouped_gemm_forward, cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic, cg_grouped_gemm_forward_dynamic,
) )
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
__all__ = [ __all__ = [
"cg_grouped_gemm", "cg_grouped_gemm",
@@ -16,4 +17,5 @@ __all__ = [
"ContiguousGroupedGEMM", "ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly", "ContiguousGroupedGEMMForwardOnly",
"generate_permute_indices", "generate_permute_indices",
"mg_grouped_gemm",
] ]

View File

@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M
__all__ = [
"grouped_gemm_forward",
"ALIGN_SIZE_M",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,237 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import os
import sys
from typing import Dict
import torch
import triton
import triton.language as tl
from triton.runtime import driver # @manual
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# ===== Supporting utils, CUDA and TMA =====
class CudaUtils:
    """Static helpers for querying the CUDA / Triton runtime environment."""

    @staticmethod
    def is_cuda() -> bool:
        """Return True when Triton's active target reports the "cuda" backend."""
        target = driver.active.get_current_target()
        return target.backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Return True when TMA can be used on the current device.

        All three conditions must hold: Triton targets CUDA, PyTorch sees a
        CUDA device, and the device's compute capability major version is at
        least 9.
        """
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        major = torch.cuda.get_device_capability()[0]
        return major >= 9

    @staticmethod
    def get_num_sms() -> int:
        """Return the streaming multiprocessor count of the "cuda" device.

        Raises:
            RuntimeError: If Triton is not on the CUDA backend, or if PyTorch
                reports that CUDA is unavailable.
        """
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        properties = torch.cuda.get_device_properties("cuda")
        return properties.multi_processor_count
class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels.

    Each descriptor is a named CPU-side ``int8`` buffer of ``tma_size``
    elements. Buffers are created with :meth:`init_tma_descriptor`, populated
    through the Triton driver's ``fill_*_tma_descriptor`` utilities, and handed
    to kernels via :meth:`get_tma_descriptor_kernel_param`.

    Args:
        tma_size: Size of the TMA descriptor in bytes

    Raises:
        RuntimeError: If TMA is not supported on this device, or if the
            installed Triton lacks ``tl.nv_tma_desc_type``.
    """

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if not hasattr(tl, "nv_tma_desc_type"):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )

        self.tma_size = tma_size
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        self.descriptors: Dict[str, torch.Tensor] = {}

    def _aligned_descriptor_ptr(self, name: str) -> int:
        """Return the data pointer of descriptor ``name`` after validating it.

        Shared by both fill methods so that the "initialized" and "64-byte
        aligned" checks live in exactly one place.

        Raises:
            ValueError: If the descriptor is missing (or stored as ``None``),
                or if its buffer is not 64-byte aligned.
        """
        desc = self.descriptors.get(name)
        if desc is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        desc_ptr = desc.data_ptr()
        if desc_ptr % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        return desc_ptr

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Allocates an uninitialized CPU ``int8`` buffer of ``tma_size``
        elements, replacing any existing descriptor stored under ``name``.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.

        Raises:
            ValueError: If ``name`` was never initialized or its buffer is not
                64-byte aligned.
        """
        desc_ptr = self._aligned_descriptor_ptr(name)
        self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_ptr)

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.

        Raises:
            ValueError: If ``name`` was never initialized or its buffer is not
                64-byte aligned.
        """
        desc_ptr = self._aligned_descriptor_ptr(name)
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_ptr
        )

    def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
        """Get the TMA descriptor kernel parameter for the given name.

        Raises:
            ValueError: If ``name`` was never initialized.
        """
        desc = self.descriptors.get(name)
        if desc is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(desc)
# ====== Autotuning utilities ======
# Row alignment used for the M dimension; also the only BLOCK_SIZE_M candidate.
ALIGN_SIZE_M = 128

# Autotuning candidates: the full cross product of the option tuples below.
# BLOCK_SIZE_M is pinned to ALIGN_SIZE_M and num_ctas to 1, so only the N/K
# block sizes, pipeline stages and warp count actually vary.
_NV_CONFIGS = [
    triton.Config(
        dict(BLOCK_SIZE_M=tile_m, BLOCK_SIZE_N=tile_n, BLOCK_SIZE_K=tile_k),
        num_stages=stages,
        num_warps=warps,
        num_ctas=ctas,
    )
    for tile_m in (ALIGN_SIZE_M,)
    for tile_n in (64, 128, 256)
    for tile_k in (64, 128, 256)
    for stages in (3, 4)
    for warps in (4, 8)
    for ctas in (1,)
]
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Filter autotuning configs down to those suited to the current problem.

    Args:
        configs: Candidate configs. Each must expose ``kwargs`` containing
            ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N`` and ``BLOCK_SIZE_K``, plus a
            ``num_stages`` attribute.
        named_args: Mapping of kernel argument names to values. Must contain
            ``G``, ``M_BUCKET``, ``N``, ``K`` and one of ``grad_input_ptr``,
            ``c_ptr`` or ``grad_weight_ptr``.
        dtsize: Element size in bytes. When ``None`` it is taken from
            ``element_size()`` of the recognized pointer argument.
        dtype: Accepted for signature compatibility; none of the pruning rules
            consult it.
        **kwargs: Ignored.

    Returns:
        The configs that pass every rule, in their original order.

    Raises:
        KeyError: If ``named_args`` has no recognized pointer argument.
    """
    device = torch.cuda.current_device()

    # Check for all possible pointer parameter names (first match wins).
    ptr_name = next(
        (
            candidate
            for candidate in ("grad_input_ptr", "c_ptr", "grad_weight_ptr")
            if candidate in named_args
        ),
        None,
    )
    if ptr_name is None:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()

    # Device limits and problem sizes do not depend on the config under test,
    # so query them once instead of on every loop iteration.
    device_props = driver.active.utils.get_device_properties(device)
    max_shared_memory = device_props["max_shared_mem"]
    num_sm = device_props["multiprocessor_count"]

    G, M, N, K = (
        named_args["G"],
        named_args["M_BUCKET"],
        named_args["N"],
        named_args["K"],
    )
    M_PER_GROUP = M // G
    # NOTE: despite the names, these two thresholds are compared against block
    # sizes (BLOCK_M / BLOCK_N), not against tile counts.
    MIN_M_TILES = 64
    MIN_N_TILES = 64

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )

        # 1. make sure we have enough smem
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue

        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        N_TILES = N // BLOCK_N

        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue

        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue

        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs
# ======== End Autotuning utilities ========

View File

@@ -193,7 +193,7 @@ class PatchManager:
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3": if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
patch_deepseek_v3_moe() patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels: elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
LOG.info( LOG.info(
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`" "Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"

View File

@@ -8,6 +8,7 @@ import torch
from axolotl.kernels.moe import ContiguousGroupedGEMM from axolotl.kernels.moe import ContiguousGroupedGEMM
from axolotl.kernels.moe.indices import generate_permute_indices from axolotl.kernels.moe.indices import generate_permute_indices
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
_GROUP_SIZE_M = 128 _GROUP_SIZE_M = 128
@@ -30,28 +31,36 @@ def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
def _ensure_combined_expert_weights( def _ensure_combined_expert_weights(
module, dtype: torch.dtype, device: torch.device module, dtype: torch.dtype, device: torch.device, backend: str
) -> None: ) -> None:
if not hasattr(module, "_axolotl_original_specs"): if not hasattr(module, "_axolotl_original_specs"):
module._axolotl_original_specs = {} module._axolotl_original_specs = {}
if getattr(module, "_axolotl_combined_weights", False): if not hasattr(module, "_axolotl_mg_shapes"):
# Move cached combined weights to the working dtype/device if required. module._axolotl_mg_shapes = {}
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
param = module.get_parameter(param_name)
if param.device != device or param.dtype != dtype:
module._parameters[param_name] = torch.nn.Parameter(
param.to(device=device, dtype=dtype).contiguous()
)
module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device
return
combined = {} prev_backend = getattr(module, "_axolotl_combined_backend", None)
if getattr(module, "_axolotl_combined_weights", False):
if prev_backend != backend:
_restore_expert_weights(module)
else:
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
param = module.get_parameter(param_name)
if param.device != device or param.dtype != dtype:
module._parameters[param_name] = torch.nn.Parameter(
param.to(device=device, dtype=dtype).contiguous()
)
module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device
module._axolotl_combined_backend = backend
return
module._axolotl_mg_shapes = {}
for name in _COMBINED_SUBMODULES: for name in _COMBINED_SUBMODULES:
weights = [] weights = []
orig_device = None orig_device = None
orig_dtype = None orig_dtype = None
orig_shape = None
for expert in module.experts: for expert in module.experts:
lin = expert.get_submodule(name) lin = expert.get_submodule(name)
weight_param = lin._parameters.get("weight") weight_param = lin._parameters.get("weight")
@@ -60,19 +69,24 @@ def _ensure_combined_expert_weights(
if orig_device is None: if orig_device is None:
orig_device = weight_param.device orig_device = weight_param.device
orig_dtype = weight_param.dtype orig_dtype = weight_param.dtype
orig_shape = tuple(weight_param.shape)
weights.append(weight_param.detach().to(device=device, dtype=dtype)) weights.append(weight_param.detach().to(device=device, dtype=dtype))
if "weight" in lin._parameters: if "weight" in lin._parameters:
del lin._parameters["weight"] del lin._parameters["weight"]
if "bias" in lin._parameters: if "bias" in lin._parameters:
# DeepseekV3 MLP layers are bias-free, but keep this for safety.
del lin._parameters["bias"] del lin._parameters["bias"]
combined[name] = torch.stack(weights, dim=0).contiguous() if backend == "cg":
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined[name])) combined_weight = torch.stack(weights, dim=0).contiguous()
module._axolotl_original_specs[name] = (orig_device, orig_dtype) else:
combined_weight = torch.cat(weights, dim=0).contiguous()
module._axolotl_mg_shapes[name] = orig_shape
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)
module._axolotl_combined_weights = True module._axolotl_combined_weights = True
module._axolotl_combined_dtype = dtype module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device module._axolotl_combined_device = device
module._axolotl_combined_backend = backend
def _restore_expert_weights(module) -> None: def _restore_expert_weights(module) -> None:
@@ -82,19 +96,111 @@ def _restore_expert_weights(module) -> None:
for name in _COMBINED_SUBMODULES: for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight" param_name = f"{name}_weight"
combined = module._parameters.pop(param_name) combined = module._parameters.pop(param_name)
orig_device, orig_dtype = module._axolotl_original_specs.get( orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
name, (combined.device, combined.dtype) name, (combined.device, combined.dtype, None)
) )
rows_per = orig_shape[0] if orig_shape else None
for idx, expert in enumerate(module.experts): for idx, expert in enumerate(module.experts):
lin = expert.get_submodule(name) lin = expert.get_submodule(name)
if combined.dim() == 3:
slice_tensor = combined[idx]
elif rows_per is not None:
start = idx * rows_per
end = start + rows_per
slice_tensor = combined[start:end]
else:
raise RuntimeError(
"Unable to recover expert weight shape during restore"
)
lin._parameters["weight"] = torch.nn.Parameter( lin._parameters["weight"] = torch.nn.Parameter(
combined[idx].detach().clone().to(orig_device, dtype=orig_dtype) slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
) )
module._axolotl_combined_weights = False module._axolotl_combined_weights = False
module._axolotl_combined_dtype = None module._axolotl_combined_dtype = None
module._axolotl_combined_device = None module._axolotl_combined_device = None
module._axolotl_combined_backend = None
module._axolotl_original_specs = {} module._axolotl_original_specs = {}
module._axolotl_mg_shapes = {}
def _run_cg_grouped_gemm(
    module,
    grouped_hidden: torch.Tensor,
    m_sizes: torch.Tensor,
    num_experts: int,
    group_size_m: int,
    hidden_dtype: torch.dtype,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the gate/up/down projections through the contiguous grouped GEMM.

    Ensures the module holds combined expert weights for the ``"cg"`` backend,
    builds a per-row expert index from ``m_sizes``, and applies
    ``ContiguousGroupedGEMM`` once per projection.

    Args:
        module: MoE module exposing ``gate_proj_weight``, ``up_proj_weight``
            and ``down_proj_weight`` parameters once weights are combined.
        grouped_hidden: Input rows for all three GEMMs.
        m_sizes: Per-expert row counts; expert ``i`` is repeated
            ``m_sizes[i]`` times in the expert index.
        num_experts: Number of experts.
        group_size_m: Group size forwarded to ``ContiguousGroupedGEMM``.
        hidden_dtype: Dtype the outputs are cast to.
        device: Device on which the expert index is created.

    Returns:
        ``(gate_out, up_out, down_out)``, each cast to ``hidden_dtype``.
    """
    _ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")

    expert_index_tensor = torch.repeat_interleave(
        torch.arange(num_experts, device=device, dtype=torch.int32),
        m_sizes.to(torch.int64),
    )

    def _project(param_name: str) -> torch.Tensor:
        """Apply the grouped GEMM for one combined projection weight."""
        weights = module.get_parameter(param_name)
        if weights.dim() == 2:
            # A 2-D weight holds the experts concatenated along dim 0; view it
            # as (num_experts, out_dim, in_dim) for the kernel.
            out_dim = weights.shape[0] // num_experts
            weights = weights.view(num_experts, out_dim, weights.shape[1])
        projected = ContiguousGroupedGEMM.apply(
            grouped_hidden,
            weights,
            expert_index_tensor,
            group_size_m,
        )
        return projected.to(hidden_dtype)

    # NOTE(review): the down projection is fed `grouped_hidden`, the same input
    # as gate/up, rather than the activated gate*up product. Confirm that the
    # caller really wants this before relying on the third output.
    return (
        _project("gate_proj_weight"),
        _project("up_proj_weight"),
        _project("down_proj_weight"),
    )
def _moe_triton_forward( def _moe_triton_forward(
@@ -103,6 +209,7 @@ def _moe_triton_forward(
topk_indices: torch.Tensor, topk_indices: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
group_size_m: int, group_size_m: int,
backend: str,
fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
if not _is_triton_eligible(hidden_states): if not _is_triton_eligible(hidden_states):
@@ -146,7 +253,7 @@ def _moe_triton_forward(
) * group_size_m ) * group_size_m
max_len = int(aligned_counts.sum().item()) max_len = int(aligned_counts.sum().item())
permuted_indices, m_sizes, m_offsets = generate_permute_indices( permuted_indices, m_sizes, _ = generate_permute_indices(
counts_int.to(device), counts_int.to(device),
experts_per_rank=num_experts, experts_per_rank=num_experts,
num_ranks=1, num_ranks=1,
@@ -155,12 +262,8 @@ def _moe_triton_forward(
use_cpu=not hidden_states.is_cuda, use_cpu=not hidden_states.is_cuda,
) )
if permuted_indices.device != device: permuted_indices = permuted_indices.to(device)
permuted_indices = permuted_indices.to(device) m_sizes = m_sizes.to(device)
if m_sizes.device != device:
m_sizes = m_sizes.to(device)
if m_offsets.device != device:
m_offsets = m_offsets.to(device)
permuted_indices_long = permuted_indices.to(torch.int64) permuted_indices_long = permuted_indices.to(torch.int64)
valid_mask = permuted_indices_long >= 0 valid_mask = permuted_indices_long >= 0
@@ -178,34 +281,35 @@ def _moe_triton_forward(
if valid_positions.numel() < max_len: if valid_positions.numel() < max_len:
grouped_hidden.index_fill_(0, padded_positions, 0) grouped_hidden.index_fill_(0, padded_positions, 0)
expert_index_tensor = torch.repeat_interleave( m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)
torch.arange(num_experts, device=device, dtype=torch.int32),
m_sizes.to(torch.int64),
)
_ensure_combined_expert_weights(module, hidden_dtype, device) if backend == "mg":
_ensure_combined_expert_weights(module, hidden_dtype, device, backend)
gate_weights = module.get_parameter("gate_proj_weight") gate_out = mg_grouped_gemm(
up_weights = module.get_parameter("up_proj_weight") grouped_hidden,
down_weights = module.get_parameter("down_proj_weight") module.get_parameter("gate_proj_weight"),
m_sizes_tensor,
gate_out = ContiguousGroupedGEMM.apply( ).to(hidden_dtype)
grouped_hidden, up_out = mg_grouped_gemm(
gate_weights, grouped_hidden,
expert_index_tensor, module.get_parameter("up_proj_weight"),
group_size_m, m_sizes_tensor,
) ).to(hidden_dtype)
up_out = ContiguousGroupedGEMM.apply( else:
grouped_hidden, gate_out, up_out, down_out_cg = _run_cg_grouped_gemm(
up_weights, module,
expert_index_tensor, grouped_hidden,
group_size_m, m_sizes,
) num_experts,
group_size_m,
hidden_dtype,
device,
)
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
if valid_positions.numel() > 0: if valid_positions.numel() > 0:
gate_valid = gate_out.index_select(0, valid_positions).to(hidden_dtype) gate_valid = gate_out.index_select(0, valid_positions)
up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype) up_valid = up_out.index_select(0, valid_positions)
hidden_concat = act_fn(gate_valid) * up_valid hidden_concat = act_fn(gate_valid) * up_valid
else: else:
hidden_concat = torch.empty( hidden_concat = torch.empty(
@@ -219,15 +323,17 @@ def _moe_triton_forward(
if valid_positions.numel() < max_len: if valid_positions.numel() < max_len:
hidden_grouped.index_fill_(0, padded_positions, 0) hidden_grouped.index_fill_(0, padded_positions, 0)
down_out = ContiguousGroupedGEMM.apply( if backend == "mg":
hidden_grouped, down_out = mg_grouped_gemm(
down_weights, hidden_grouped,
expert_index_tensor, module.get_parameter("down_proj_weight"),
group_size_m, m_sizes_tensor,
) ).to(hidden_dtype)
else:
down_out = down_out_cg
if valid_positions.numel() > 0: if valid_positions.numel() > 0:
down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype) down_valid = down_out.index_select(0, valid_positions)
else: else:
down_valid = torch.empty( down_valid = torch.empty(
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype (0, down_out.shape[-1]), device=device, dtype=hidden_dtype
@@ -245,11 +351,16 @@ def _moe_triton_forward(
return weighted.sum(dim=1) return weighted.sum(dim=1)
def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None: def patch_deepseek_v3_moe(
group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
) -> None:
"""Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels.""" """Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
if backend not in {"cg", "mg"}:
raise ValueError(f"Unsupported MoE kernel backend: {backend}")
# Record the unpatched implementation so callers can access a true baseline even # Record the unpatched implementation so callers can access a true baseline even
# after the Triton patch has been applied (e.g. repeated microbenchmarks). # after the Triton patch has been applied (e.g. repeated microbenchmarks).
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"): if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
@@ -259,26 +370,26 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
return return
original_moe = DeepseekV3MoE._axolotl_triton_original_moe original_moe = DeepseekV3MoE._axolotl_triton_original_moe
DeepseekV3MoE._axolotl_triton_backend = backend
DeepseekV3MoE._axolotl_group_size_m = group_size_m
def patched_moe(self, hidden_states, topk_indices, topk_weights): def patched_moe(self, hidden_states, topk_indices, topk_weights):
backend_sel = getattr(self, "_axolotl_triton_backend", backend)
group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
try: try:
return _moe_triton_forward( return _moe_triton_forward(
self, self,
hidden_states, hidden_states,
topk_indices, topk_indices,
topk_weights, topk_weights,
group_size_m, group_size_sel,
backend_sel,
original_moe, original_moe,
) )
except Exception as err: # fall back if Triton compilation or runtime fails except Exception as err: # surface Triton failures explicitly
if not getattr(self, "_axolotl_triton_warned", False):
LOG.warning(
"DeepseekV3MoE Triton path failed; falling back to baseline: %s",
err,
)
self._axolotl_triton_warned = True
_restore_expert_weights(self) _restore_expert_weights(self)
return original_moe(self, hidden_states, topk_indices, topk_weights) LOG.error("DeepseekV3MoE Triton path failed: %s", err)
raise
DeepseekV3MoE.moe = patched_moe DeepseekV3MoE.moe = patched_moe
DeepseekV3MoE._axolotl_triton_patch = True DeepseekV3MoE._axolotl_triton_patch = True

View File

@@ -119,6 +119,12 @@ class AxolotlInputConfig(
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)" "description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
}, },
) )
moe_kernel_backend: Literal["cg", "mg"] | None = Field(
default="mg",
json_schema_extra={
"description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
},
)
trainer_cls: str | None = Field( trainer_cls: str | None = Field(
default=None, default=None,