add mg kernel backend

This commit is contained in:
Dan Saunders
2025-09-23 15:43:16 -04:00
parent 8a1f5ae940
commit d0da67eb17
9 changed files with 1753 additions and 77 deletions

View File

@@ -8,6 +8,7 @@ from .tt_cg_gemm import (
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
__all__ = [
"cg_grouped_gemm",
@@ -16,4 +17,5 @@ __all__ = [
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
"generate_permute_indices",
"mg_grouped_gemm",
]

View File

@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .mg_grouped_gemm import grouped_gemm_forward
from .tma_autotuning import ALIGN_SIZE_M
__all__ = [
"grouped_gemm_forward",
"ALIGN_SIZE_M",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,237 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
# pyre-unsafe
import os
import sys
from typing import Dict
import torch
import triton
import triton.language as tl
from triton.runtime import driver # @manual
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# ===== Supporting utils, CUDA and TMA =====
class CudaUtils:
@staticmethod
def is_cuda() -> bool:
"""Check if Triton is running on CUDA backend."""
return driver.active.get_current_target().backend == "cuda"
@staticmethod
def verify_tma() -> bool:
"""Check if TMA is supported on the current device."""
return (
CudaUtils.is_cuda()
and torch.cuda.is_available()
and torch.cuda.get_device_capability()[0] >= 9
)
@staticmethod
def get_num_sms() -> int:
"""Get the number of streaming multiprocessors on the current device."""
if not CudaUtils.is_cuda():
raise RuntimeError("Triton is not running on CUDA backend")
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
return torch.cuda.get_device_properties("cuda").multi_processor_count
class TmaDescriptorHelper:
"""Helper class for managing TMA descriptors in Triton kernels.
Args:
tma_size: Size of the TMA descriptor in bytes
"""
class KernelParamWrapper:
"""Wrapper to implement the TmaDescKernelParam interface."""
def __init__(self, desc: torch.Tensor):
self.desc = desc
def tma_desc_cpu_ptr(self) -> int:
"""Return the CPU pointer to the TMA descriptor."""
return self.desc.data_ptr()
def __init__(self, tma_size: int = 128):
if not CudaUtils.verify_tma():
raise RuntimeError(
"TMA not supported on this device (requires Hopper or newer)"
)
if "nv_tma_desc_type" not in dir(tl):
raise RuntimeError(
"TMA grid constant descriptors not supported in your Triton version"
)
self.tma_size = tma_size
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
self.descriptors: Dict[str, torch.Tensor] = {}
def init_tma_descriptor(self, name: str) -> None:
"""Initialize a TMA descriptor with the given name.
Call this method outside of the lambda function for grid size.
"""
self.descriptors[name] = torch.empty(
self.tma_size, device="cpu", dtype=torch.int8
)
def fill_1d_tma_descriptor(
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
) -> None:
"""Fill a 1D TMA descriptor.
Call this method inside the lambda function for grid size.
"""
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
def fill_2d_tma_descriptor(
self,
name: str,
ptr: int,
dim1: int,
dim0: int,
block_dim1: int,
block_dim0: int,
element_size: int,
) -> None:
"""Fill a 2D TMA descriptor.
Call this method inside the lambda function for grid size.
"""
if name not in self.descriptors:
raise ValueError(f"TMA descriptor '{name}' not initialized")
desc_x = self.descriptors[name]
if desc_x.data_ptr() % 64 != 0:
raise ValueError("TMA descriptor must be 64-byte aligned")
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
"""Get the TMA descriptor kernel parameter for the given name."""
if name not in self.descriptors or self.descriptors[name] is None:
raise ValueError(f"TMA descriptor '{name}' not initialized")
return self.KernelParamWrapper(self.descriptors[name])
# ====== Autotuning utilities ======
ALIGN_SIZE_M = 128
_NV_CONFIGS = [
triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
},
num_stages=num_stages,
num_warps=num_warps,
num_ctas=num_ctas,
)
for block_size_m in [
ALIGN_SIZE_M,
]
for block_size_n in [64, 128, 256]
for block_size_k in [64, 128, 256]
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
]
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
device = torch.cuda.current_device()
# Check for all possible pointer parameter names
if "grad_input_ptr" in named_args:
ptr_name = "grad_input_ptr"
elif "c_ptr" in named_args:
ptr_name = "c_ptr"
elif "grad_weight_ptr" in named_args:
ptr_name = "grad_weight_ptr"
else:
raise KeyError("No recognized pointer parameter found in kernel arguments")
if dtsize is None:
dtsize = named_args[ptr_name].element_size()
if dtype is None:
dtype = named_args[ptr_name].dtype
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
kw["BLOCK_SIZE_M"],
kw["BLOCK_SIZE_N"],
kw["BLOCK_SIZE_K"],
config.num_stages,
)
G, M, N, K = (
named_args["G"],
named_args["M_BUCKET"],
named_args["N"],
named_args["K"],
)
# 1. make sure we have enough smem
max_shared_memory = driver.active.utils.get_device_properties(device)[
"max_shared_mem"
]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory > max_shared_memory:
continue
M_PER_GROUP = M // G
MIN_M_TILES = 64
# 2. make sure we don't load M tiles that are too big
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
continue
# 3. make sure we don't load N tiles that are too small
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
continue
num_sm = driver.active.utils.get_device_properties(device)[
"multiprocessor_count"
]
N_TILES = N // BLOCK_N
MIN_N_TILES = 64
# 4. make sure we don't load N tiles that are too big
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
continue
# 5. make sure we don't load N tiles that are too small
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
continue
# 6. make sure K can be evenly divided
if K % BLOCK_K != 0:
continue
pruned_configs.append(config)
return pruned_configs
# ======== End Autotuning utilities ========

View File

@@ -193,7 +193,7 @@ class PatchManager:
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
patch_deepseek_v3_moe()
patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
LOG.info(
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"

View File

@@ -8,6 +8,7 @@ import torch
from axolotl.kernels.moe import ContiguousGroupedGEMM
from axolotl.kernels.moe.indices import generate_permute_indices
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
from axolotl.utils.logging import get_logger
_GROUP_SIZE_M = 128
@@ -30,28 +31,36 @@ def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
def _ensure_combined_expert_weights(
module, dtype: torch.dtype, device: torch.device
module, dtype: torch.dtype, device: torch.device, backend: str
) -> None:
if not hasattr(module, "_axolotl_original_specs"):
module._axolotl_original_specs = {}
if getattr(module, "_axolotl_combined_weights", False):
# Move cached combined weights to the working dtype/device if required.
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
param = module.get_parameter(param_name)
if param.device != device or param.dtype != dtype:
module._parameters[param_name] = torch.nn.Parameter(
param.to(device=device, dtype=dtype).contiguous()
)
module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device
return
if not hasattr(module, "_axolotl_mg_shapes"):
module._axolotl_mg_shapes = {}
combined = {}
prev_backend = getattr(module, "_axolotl_combined_backend", None)
if getattr(module, "_axolotl_combined_weights", False):
if prev_backend != backend:
_restore_expert_weights(module)
else:
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
param = module.get_parameter(param_name)
if param.device != device or param.dtype != dtype:
module._parameters[param_name] = torch.nn.Parameter(
param.to(device=device, dtype=dtype).contiguous()
)
module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device
module._axolotl_combined_backend = backend
return
module._axolotl_mg_shapes = {}
for name in _COMBINED_SUBMODULES:
weights = []
orig_device = None
orig_dtype = None
orig_shape = None
for expert in module.experts:
lin = expert.get_submodule(name)
weight_param = lin._parameters.get("weight")
@@ -60,19 +69,24 @@ def _ensure_combined_expert_weights(
if orig_device is None:
orig_device = weight_param.device
orig_dtype = weight_param.dtype
orig_shape = tuple(weight_param.shape)
weights.append(weight_param.detach().to(device=device, dtype=dtype))
if "weight" in lin._parameters:
del lin._parameters["weight"]
if "bias" in lin._parameters:
# DeepseekV3 MLP layers are bias-free, but keep this for safety.
del lin._parameters["bias"]
combined[name] = torch.stack(weights, dim=0).contiguous()
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined[name]))
module._axolotl_original_specs[name] = (orig_device, orig_dtype)
if backend == "cg":
combined_weight = torch.stack(weights, dim=0).contiguous()
else:
combined_weight = torch.cat(weights, dim=0).contiguous()
module._axolotl_mg_shapes[name] = orig_shape
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)
module._axolotl_combined_weights = True
module._axolotl_combined_dtype = dtype
module._axolotl_combined_device = device
module._axolotl_combined_backend = backend
def _restore_expert_weights(module) -> None:
@@ -82,19 +96,111 @@ def _restore_expert_weights(module) -> None:
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
combined = module._parameters.pop(param_name)
orig_device, orig_dtype = module._axolotl_original_specs.get(
name, (combined.device, combined.dtype)
orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
name, (combined.device, combined.dtype, None)
)
rows_per = orig_shape[0] if orig_shape else None
for idx, expert in enumerate(module.experts):
lin = expert.get_submodule(name)
if combined.dim() == 3:
slice_tensor = combined[idx]
elif rows_per is not None:
start = idx * rows_per
end = start + rows_per
slice_tensor = combined[start:end]
else:
raise RuntimeError(
"Unable to recover expert weight shape during restore"
)
lin._parameters["weight"] = torch.nn.Parameter(
combined[idx].detach().clone().to(orig_device, dtype=orig_dtype)
slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
)
module._axolotl_combined_weights = False
module._axolotl_combined_dtype = None
module._axolotl_combined_device = None
module._axolotl_combined_backend = None
module._axolotl_original_specs = {}
module._axolotl_mg_shapes = {}
def _run_cg_grouped_gemm(
module,
grouped_hidden: torch.Tensor,
m_sizes: torch.Tensor,
num_experts: int,
group_size_m: int,
hidden_dtype: torch.dtype,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
_ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
expert_index_tensor = torch.repeat_interleave(
torch.arange(num_experts, device=device, dtype=torch.int32),
m_sizes.to(torch.int64),
)
gate_weights = module.get_parameter("gate_proj_weight")
if gate_weights.dim() == 2:
out_dim = gate_weights.shape[0] // num_experts
gate_weights = gate_weights.view(num_experts, out_dim, gate_weights.shape[1])
up_weights = module.get_parameter("up_proj_weight")
if up_weights.dim() == 2:
out_dim = up_weights.shape[0] // num_experts
up_weights = up_weights.view(num_experts, out_dim, up_weights.shape[1])
down_weights = module.get_parameter("down_proj_weight")
if down_weights.dim() == 2:
out_dim = down_weights.shape[0] // num_experts
down_weights = down_weights.view(num_experts, out_dim, down_weights.shape[1])
gate_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
gate_weights,
expert_index_tensor,
group_size_m,
)
up_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
up_weights,
expert_index_tensor,
group_size_m,
)
down_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
down_weights,
expert_index_tensor,
group_size_m,
)
return (
gate_out.to(hidden_dtype),
up_out.to(hidden_dtype),
down_out.to(hidden_dtype),
)
gate_out = mg_grouped_gemm(
grouped_hidden,
module.get_parameter("gate_proj_weight"),
m_sizes_tensor,
)
up_out = mg_grouped_gemm(
grouped_hidden,
module.get_parameter("up_proj_weight"),
m_sizes_tensor,
)
down_out = mg_grouped_gemm(
hidden_grouped,
module.get_parameter("down_proj_weight"),
m_sizes_tensor,
)
return (
gate_out.to(hidden_dtype),
up_out.to(hidden_dtype),
down_out.to(hidden_dtype),
)
def _moe_triton_forward(
@@ -103,6 +209,7 @@ def _moe_triton_forward(
topk_indices: torch.Tensor,
topk_weights: torch.Tensor,
group_size_m: int,
backend: str,
fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
if not _is_triton_eligible(hidden_states):
@@ -146,7 +253,7 @@ def _moe_triton_forward(
) * group_size_m
max_len = int(aligned_counts.sum().item())
permuted_indices, m_sizes, m_offsets = generate_permute_indices(
permuted_indices, m_sizes, _ = generate_permute_indices(
counts_int.to(device),
experts_per_rank=num_experts,
num_ranks=1,
@@ -155,12 +262,8 @@ def _moe_triton_forward(
use_cpu=not hidden_states.is_cuda,
)
if permuted_indices.device != device:
permuted_indices = permuted_indices.to(device)
if m_sizes.device != device:
m_sizes = m_sizes.to(device)
if m_offsets.device != device:
m_offsets = m_offsets.to(device)
permuted_indices = permuted_indices.to(device)
m_sizes = m_sizes.to(device)
permuted_indices_long = permuted_indices.to(torch.int64)
valid_mask = permuted_indices_long >= 0
@@ -178,34 +281,35 @@ def _moe_triton_forward(
if valid_positions.numel() < max_len:
grouped_hidden.index_fill_(0, padded_positions, 0)
expert_index_tensor = torch.repeat_interleave(
torch.arange(num_experts, device=device, dtype=torch.int32),
m_sizes.to(torch.int64),
)
m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)
_ensure_combined_expert_weights(module, hidden_dtype, device)
gate_weights = module.get_parameter("gate_proj_weight")
up_weights = module.get_parameter("up_proj_weight")
down_weights = module.get_parameter("down_proj_weight")
gate_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
gate_weights,
expert_index_tensor,
group_size_m,
)
up_out = ContiguousGroupedGEMM.apply(
grouped_hidden,
up_weights,
expert_index_tensor,
group_size_m,
)
if backend == "mg":
_ensure_combined_expert_weights(module, hidden_dtype, device, backend)
gate_out = mg_grouped_gemm(
grouped_hidden,
module.get_parameter("gate_proj_weight"),
m_sizes_tensor,
).to(hidden_dtype)
up_out = mg_grouped_gemm(
grouped_hidden,
module.get_parameter("up_proj_weight"),
m_sizes_tensor,
).to(hidden_dtype)
else:
gate_out, up_out, down_out_cg = _run_cg_grouped_gemm(
module,
grouped_hidden,
m_sizes,
num_experts,
group_size_m,
hidden_dtype,
device,
)
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
if valid_positions.numel() > 0:
gate_valid = gate_out.index_select(0, valid_positions).to(hidden_dtype)
up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
gate_valid = gate_out.index_select(0, valid_positions)
up_valid = up_out.index_select(0, valid_positions)
hidden_concat = act_fn(gate_valid) * up_valid
else:
hidden_concat = torch.empty(
@@ -219,15 +323,17 @@ def _moe_triton_forward(
if valid_positions.numel() < max_len:
hidden_grouped.index_fill_(0, padded_positions, 0)
down_out = ContiguousGroupedGEMM.apply(
hidden_grouped,
down_weights,
expert_index_tensor,
group_size_m,
)
if backend == "mg":
down_out = mg_grouped_gemm(
hidden_grouped,
module.get_parameter("down_proj_weight"),
m_sizes_tensor,
).to(hidden_dtype)
else:
down_out = down_out_cg
if valid_positions.numel() > 0:
down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
down_valid = down_out.index_select(0, valid_positions)
else:
down_valid = torch.empty(
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
@@ -245,11 +351,16 @@ def _moe_triton_forward(
return weighted.sum(dim=1)
def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
def patch_deepseek_v3_moe(
group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
) -> None:
"""Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
if backend not in {"cg", "mg"}:
raise ValueError(f"Unsupported MoE kernel backend: {backend}")
# Record the unpatched implementation so callers can access a true baseline even
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
@@ -259,26 +370,26 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
return
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
DeepseekV3MoE._axolotl_triton_backend = backend
DeepseekV3MoE._axolotl_group_size_m = group_size_m
def patched_moe(self, hidden_states, topk_indices, topk_weights):
backend_sel = getattr(self, "_axolotl_triton_backend", backend)
group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
try:
return _moe_triton_forward(
self,
hidden_states,
topk_indices,
topk_weights,
group_size_m,
group_size_sel,
backend_sel,
original_moe,
)
except Exception as err: # fall back if Triton compilation or runtime fails
if not getattr(self, "_axolotl_triton_warned", False):
LOG.warning(
"DeepseekV3MoE Triton path failed; falling back to baseline: %s",
err,
)
self._axolotl_triton_warned = True
except Exception as err: # surface Triton failures explicitly
_restore_expert_weights(self)
return original_moe(self, hidden_states, topk_indices, topk_weights)
LOG.error("DeepseekV3MoE Triton path failed: %s", err)
raise
DeepseekV3MoE.moe = patched_moe
DeepseekV3MoE._axolotl_triton_patch = True

View File

@@ -119,6 +119,12 @@ class AxolotlInputConfig(
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
},
)
moe_kernel_backend: Literal["cg", "mg"] | None = Field(
default="mg",
json_schema_extra={
"description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
},
)
trainer_cls: str | None = Field(
default=None,