add mg kernel backend
This commit is contained in:
@@ -8,6 +8,7 @@ from .tt_cg_gemm import (
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
|
||||
|
||||
__all__ = [
|
||||
"cg_grouped_gemm",
|
||||
@@ -16,4 +17,5 @@ __all__ = [
|
||||
"ContiguousGroupedGEMM",
|
||||
"ContiguousGroupedGEMMForwardOnly",
|
||||
"generate_permute_indices",
|
||||
"mg_grouped_gemm",
|
||||
]
|
||||
|
||||
13
src/axolotl/kernels/moe/tt_mg_gemm/__init__.py
Normal file
13
src/axolotl/kernels/moe/tt_mg_gemm/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .mg_grouped_gemm import grouped_gemm_forward
|
||||
from .tma_autotuning import ALIGN_SIZE_M
|
||||
|
||||
__all__ = [
|
||||
"grouped_gemm_forward",
|
||||
"ALIGN_SIZE_M",
|
||||
]
|
||||
1291
src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
Normal file
1291
src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
Normal file
File diff suppressed because it is too large
Load Diff
237
src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
Normal file
237
src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
|
||||
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
||||
|
||||
# pyre-unsafe
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.runtime import driver # @manual
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
# ===== Supporting utils, CUDA and TMA =====
|
||||
|
||||
|
||||
class CudaUtils:
|
||||
@staticmethod
|
||||
def is_cuda() -> bool:
|
||||
"""Check if Triton is running on CUDA backend."""
|
||||
return driver.active.get_current_target().backend == "cuda"
|
||||
|
||||
@staticmethod
|
||||
def verify_tma() -> bool:
|
||||
"""Check if TMA is supported on the current device."""
|
||||
return (
|
||||
CudaUtils.is_cuda()
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.get_device_capability()[0] >= 9
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_num_sms() -> int:
|
||||
"""Get the number of streaming multiprocessors on the current device."""
|
||||
if not CudaUtils.is_cuda():
|
||||
raise RuntimeError("Triton is not running on CUDA backend")
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA is not available")
|
||||
return torch.cuda.get_device_properties("cuda").multi_processor_count
|
||||
|
||||
|
||||
class TmaDescriptorHelper:
|
||||
"""Helper class for managing TMA descriptors in Triton kernels.
|
||||
|
||||
Args:
|
||||
tma_size: Size of the TMA descriptor in bytes
|
||||
"""
|
||||
|
||||
class KernelParamWrapper:
|
||||
"""Wrapper to implement the TmaDescKernelParam interface."""
|
||||
|
||||
def __init__(self, desc: torch.Tensor):
|
||||
self.desc = desc
|
||||
|
||||
def tma_desc_cpu_ptr(self) -> int:
|
||||
"""Return the CPU pointer to the TMA descriptor."""
|
||||
return self.desc.data_ptr()
|
||||
|
||||
def __init__(self, tma_size: int = 128):
|
||||
if not CudaUtils.verify_tma():
|
||||
raise RuntimeError(
|
||||
"TMA not supported on this device (requires Hopper or newer)"
|
||||
)
|
||||
if "nv_tma_desc_type" not in dir(tl):
|
||||
raise RuntimeError(
|
||||
"TMA grid constant descriptors not supported in your Triton version"
|
||||
)
|
||||
|
||||
self.tma_size = tma_size
|
||||
self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
|
||||
self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
|
||||
self.descriptors: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def init_tma_descriptor(self, name: str) -> None:
|
||||
"""Initialize a TMA descriptor with the given name.
|
||||
|
||||
Call this method outside of the lambda function for grid size.
|
||||
"""
|
||||
self.descriptors[name] = torch.empty(
|
||||
self.tma_size, device="cpu", dtype=torch.int8
|
||||
)
|
||||
|
||||
def fill_1d_tma_descriptor(
|
||||
self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
|
||||
) -> None:
|
||||
"""Fill a 1D TMA descriptor.
|
||||
|
||||
Call this method inside the lambda function for grid size.
|
||||
"""
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_1d_tma_descriptor_inner(
|
||||
ptr, dim, block_dim, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def fill_2d_tma_descriptor(
|
||||
self,
|
||||
name: str,
|
||||
ptr: int,
|
||||
dim1: int,
|
||||
dim0: int,
|
||||
block_dim1: int,
|
||||
block_dim0: int,
|
||||
element_size: int,
|
||||
) -> None:
|
||||
"""Fill a 2D TMA descriptor.
|
||||
|
||||
Call this method inside the lambda function for grid size.
|
||||
"""
|
||||
if name not in self.descriptors:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
|
||||
desc_x = self.descriptors[name]
|
||||
if desc_x.data_ptr() % 64 != 0:
|
||||
raise ValueError("TMA descriptor must be 64-byte aligned")
|
||||
self.fill_2d_tma_descriptor_inner(
|
||||
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
|
||||
)
|
||||
|
||||
def get_tma_descriptor_kernel_param(self, name: str) -> KernelParamWrapper:
|
||||
"""Get the TMA descriptor kernel parameter for the given name."""
|
||||
if name not in self.descriptors or self.descriptors[name] is None:
|
||||
raise ValueError(f"TMA descriptor '{name}' not initialized")
|
||||
return self.KernelParamWrapper(self.descriptors[name])
|
||||
|
||||
|
||||
# ====== Autotuning utilities ======
|
||||
ALIGN_SIZE_M = 128
|
||||
|
||||
_NV_CONFIGS = [
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_size_m,
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
},
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
num_ctas=num_ctas,
|
||||
)
|
||||
for block_size_m in [
|
||||
ALIGN_SIZE_M,
|
||||
]
|
||||
for block_size_n in [64, 128, 256]
|
||||
for block_size_k in [64, 128, 256]
|
||||
for num_stages in [3, 4]
|
||||
for num_warps in [4, 8]
|
||||
for num_ctas in [1]
|
||||
]
|
||||
|
||||
|
||||
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
|
||||
device = torch.cuda.current_device()
|
||||
# Check for all possible pointer parameter names
|
||||
if "grad_input_ptr" in named_args:
|
||||
ptr_name = "grad_input_ptr"
|
||||
elif "c_ptr" in named_args:
|
||||
ptr_name = "c_ptr"
|
||||
elif "grad_weight_ptr" in named_args:
|
||||
ptr_name = "grad_weight_ptr"
|
||||
else:
|
||||
raise KeyError("No recognized pointer parameter found in kernel arguments")
|
||||
|
||||
if dtsize is None:
|
||||
dtsize = named_args[ptr_name].element_size()
|
||||
if dtype is None:
|
||||
dtype = named_args[ptr_name].dtype
|
||||
|
||||
pruned_configs = []
|
||||
for config in configs:
|
||||
kw = config.kwargs
|
||||
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
|
||||
kw["BLOCK_SIZE_M"],
|
||||
kw["BLOCK_SIZE_N"],
|
||||
kw["BLOCK_SIZE_K"],
|
||||
config.num_stages,
|
||||
)
|
||||
G, M, N, K = (
|
||||
named_args["G"],
|
||||
named_args["M_BUCKET"],
|
||||
named_args["N"],
|
||||
named_args["K"],
|
||||
)
|
||||
|
||||
# 1. make sure we have enough smem
|
||||
max_shared_memory = driver.active.utils.get_device_properties(device)[
|
||||
"max_shared_mem"
|
||||
]
|
||||
|
||||
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
|
||||
if required_shared_memory > max_shared_memory:
|
||||
continue
|
||||
|
||||
M_PER_GROUP = M // G
|
||||
MIN_M_TILES = 64
|
||||
# 2. make sure we don't load M tiles that are too big
|
||||
if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
|
||||
continue
|
||||
# 3. make sure we don't load N tiles that are too small
|
||||
if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
|
||||
continue
|
||||
|
||||
num_sm = driver.active.utils.get_device_properties(device)[
|
||||
"multiprocessor_count"
|
||||
]
|
||||
N_TILES = N // BLOCK_N
|
||||
MIN_N_TILES = 64
|
||||
# 4. make sure we don't load N tiles that are too big
|
||||
if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
|
||||
continue
|
||||
# 5. make sure we don't load N tiles that are too small
|
||||
if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
|
||||
continue
|
||||
# 6. make sure K can be evenly divided
|
||||
if K % BLOCK_K != 0:
|
||||
continue
|
||||
|
||||
pruned_configs.append(config)
|
||||
|
||||
return pruned_configs
|
||||
|
||||
|
||||
# ======== End Autotuning utilities ========
|
||||
@@ -193,7 +193,7 @@ class PatchManager:
|
||||
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||
|
||||
patch_deepseek_v3_moe()
|
||||
patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
|
||||
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
|
||||
LOG.info(
|
||||
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch
|
||||
|
||||
from axolotl.kernels.moe import ContiguousGroupedGEMM
|
||||
from axolotl.kernels.moe.indices import generate_permute_indices
|
||||
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
_GROUP_SIZE_M = 128
|
||||
@@ -30,28 +31,36 @@ def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
||||
|
||||
|
||||
def _ensure_combined_expert_weights(
|
||||
module, dtype: torch.dtype, device: torch.device
|
||||
module, dtype: torch.dtype, device: torch.device, backend: str
|
||||
) -> None:
|
||||
if not hasattr(module, "_axolotl_original_specs"):
|
||||
module._axolotl_original_specs = {}
|
||||
if getattr(module, "_axolotl_combined_weights", False):
|
||||
# Move cached combined weights to the working dtype/device if required.
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
param_name = f"{name}_weight"
|
||||
param = module.get_parameter(param_name)
|
||||
if param.device != device or param.dtype != dtype:
|
||||
module._parameters[param_name] = torch.nn.Parameter(
|
||||
param.to(device=device, dtype=dtype).contiguous()
|
||||
)
|
||||
module._axolotl_combined_dtype = dtype
|
||||
module._axolotl_combined_device = device
|
||||
return
|
||||
if not hasattr(module, "_axolotl_mg_shapes"):
|
||||
module._axolotl_mg_shapes = {}
|
||||
|
||||
combined = {}
|
||||
prev_backend = getattr(module, "_axolotl_combined_backend", None)
|
||||
if getattr(module, "_axolotl_combined_weights", False):
|
||||
if prev_backend != backend:
|
||||
_restore_expert_weights(module)
|
||||
else:
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
param_name = f"{name}_weight"
|
||||
param = module.get_parameter(param_name)
|
||||
if param.device != device or param.dtype != dtype:
|
||||
module._parameters[param_name] = torch.nn.Parameter(
|
||||
param.to(device=device, dtype=dtype).contiguous()
|
||||
)
|
||||
module._axolotl_combined_dtype = dtype
|
||||
module._axolotl_combined_device = device
|
||||
module._axolotl_combined_backend = backend
|
||||
return
|
||||
|
||||
module._axolotl_mg_shapes = {}
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
weights = []
|
||||
orig_device = None
|
||||
orig_dtype = None
|
||||
orig_shape = None
|
||||
for expert in module.experts:
|
||||
lin = expert.get_submodule(name)
|
||||
weight_param = lin._parameters.get("weight")
|
||||
@@ -60,19 +69,24 @@ def _ensure_combined_expert_weights(
|
||||
if orig_device is None:
|
||||
orig_device = weight_param.device
|
||||
orig_dtype = weight_param.dtype
|
||||
orig_shape = tuple(weight_param.shape)
|
||||
weights.append(weight_param.detach().to(device=device, dtype=dtype))
|
||||
if "weight" in lin._parameters:
|
||||
del lin._parameters["weight"]
|
||||
if "bias" in lin._parameters:
|
||||
# DeepseekV3 MLP layers are bias-free, but keep this for safety.
|
||||
del lin._parameters["bias"]
|
||||
combined[name] = torch.stack(weights, dim=0).contiguous()
|
||||
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined[name]))
|
||||
module._axolotl_original_specs[name] = (orig_device, orig_dtype)
|
||||
if backend == "cg":
|
||||
combined_weight = torch.stack(weights, dim=0).contiguous()
|
||||
else:
|
||||
combined_weight = torch.cat(weights, dim=0).contiguous()
|
||||
module._axolotl_mg_shapes[name] = orig_shape
|
||||
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
|
||||
module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)
|
||||
|
||||
module._axolotl_combined_weights = True
|
||||
module._axolotl_combined_dtype = dtype
|
||||
module._axolotl_combined_device = device
|
||||
module._axolotl_combined_backend = backend
|
||||
|
||||
|
||||
def _restore_expert_weights(module) -> None:
|
||||
@@ -82,19 +96,111 @@ def _restore_expert_weights(module) -> None:
|
||||
for name in _COMBINED_SUBMODULES:
|
||||
param_name = f"{name}_weight"
|
||||
combined = module._parameters.pop(param_name)
|
||||
orig_device, orig_dtype = module._axolotl_original_specs.get(
|
||||
name, (combined.device, combined.dtype)
|
||||
orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
|
||||
name, (combined.device, combined.dtype, None)
|
||||
)
|
||||
rows_per = orig_shape[0] if orig_shape else None
|
||||
for idx, expert in enumerate(module.experts):
|
||||
lin = expert.get_submodule(name)
|
||||
if combined.dim() == 3:
|
||||
slice_tensor = combined[idx]
|
||||
elif rows_per is not None:
|
||||
start = idx * rows_per
|
||||
end = start + rows_per
|
||||
slice_tensor = combined[start:end]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unable to recover expert weight shape during restore"
|
||||
)
|
||||
lin._parameters["weight"] = torch.nn.Parameter(
|
||||
combined[idx].detach().clone().to(orig_device, dtype=orig_dtype)
|
||||
slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
|
||||
)
|
||||
|
||||
module._axolotl_combined_weights = False
|
||||
module._axolotl_combined_dtype = None
|
||||
module._axolotl_combined_device = None
|
||||
module._axolotl_combined_backend = None
|
||||
module._axolotl_original_specs = {}
|
||||
module._axolotl_mg_shapes = {}
|
||||
|
||||
|
||||
def _run_cg_grouped_gemm(
|
||||
module,
|
||||
grouped_hidden: torch.Tensor,
|
||||
m_sizes: torch.Tensor,
|
||||
num_experts: int,
|
||||
group_size_m: int,
|
||||
hidden_dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
_ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")
|
||||
|
||||
expert_index_tensor = torch.repeat_interleave(
|
||||
torch.arange(num_experts, device=device, dtype=torch.int32),
|
||||
m_sizes.to(torch.int64),
|
||||
)
|
||||
|
||||
gate_weights = module.get_parameter("gate_proj_weight")
|
||||
if gate_weights.dim() == 2:
|
||||
out_dim = gate_weights.shape[0] // num_experts
|
||||
gate_weights = gate_weights.view(num_experts, out_dim, gate_weights.shape[1])
|
||||
|
||||
up_weights = module.get_parameter("up_proj_weight")
|
||||
if up_weights.dim() == 2:
|
||||
out_dim = up_weights.shape[0] // num_experts
|
||||
up_weights = up_weights.view(num_experts, out_dim, up_weights.shape[1])
|
||||
|
||||
down_weights = module.get_parameter("down_proj_weight")
|
||||
if down_weights.dim() == 2:
|
||||
out_dim = down_weights.shape[0] // num_experts
|
||||
down_weights = down_weights.view(num_experts, out_dim, down_weights.shape[1])
|
||||
|
||||
gate_out = ContiguousGroupedGEMM.apply(
|
||||
grouped_hidden,
|
||||
gate_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
)
|
||||
up_out = ContiguousGroupedGEMM.apply(
|
||||
grouped_hidden,
|
||||
up_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
)
|
||||
down_out = ContiguousGroupedGEMM.apply(
|
||||
grouped_hidden,
|
||||
down_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
)
|
||||
|
||||
return (
|
||||
gate_out.to(hidden_dtype),
|
||||
up_out.to(hidden_dtype),
|
||||
down_out.to(hidden_dtype),
|
||||
)
|
||||
|
||||
gate_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("gate_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
)
|
||||
up_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("up_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
)
|
||||
down_out = mg_grouped_gemm(
|
||||
hidden_grouped,
|
||||
module.get_parameter("down_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
)
|
||||
|
||||
return (
|
||||
gate_out.to(hidden_dtype),
|
||||
up_out.to(hidden_dtype),
|
||||
down_out.to(hidden_dtype),
|
||||
)
|
||||
|
||||
|
||||
def _moe_triton_forward(
|
||||
@@ -103,6 +209,7 @@ def _moe_triton_forward(
|
||||
topk_indices: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
group_size_m: int,
|
||||
backend: str,
|
||||
fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if not _is_triton_eligible(hidden_states):
|
||||
@@ -146,7 +253,7 @@ def _moe_triton_forward(
|
||||
) * group_size_m
|
||||
max_len = int(aligned_counts.sum().item())
|
||||
|
||||
permuted_indices, m_sizes, m_offsets = generate_permute_indices(
|
||||
permuted_indices, m_sizes, _ = generate_permute_indices(
|
||||
counts_int.to(device),
|
||||
experts_per_rank=num_experts,
|
||||
num_ranks=1,
|
||||
@@ -155,12 +262,8 @@ def _moe_triton_forward(
|
||||
use_cpu=not hidden_states.is_cuda,
|
||||
)
|
||||
|
||||
if permuted_indices.device != device:
|
||||
permuted_indices = permuted_indices.to(device)
|
||||
if m_sizes.device != device:
|
||||
m_sizes = m_sizes.to(device)
|
||||
if m_offsets.device != device:
|
||||
m_offsets = m_offsets.to(device)
|
||||
permuted_indices = permuted_indices.to(device)
|
||||
m_sizes = m_sizes.to(device)
|
||||
|
||||
permuted_indices_long = permuted_indices.to(torch.int64)
|
||||
valid_mask = permuted_indices_long >= 0
|
||||
@@ -178,34 +281,35 @@ def _moe_triton_forward(
|
||||
if valid_positions.numel() < max_len:
|
||||
grouped_hidden.index_fill_(0, padded_positions, 0)
|
||||
|
||||
expert_index_tensor = torch.repeat_interleave(
|
||||
torch.arange(num_experts, device=device, dtype=torch.int32),
|
||||
m_sizes.to(torch.int64),
|
||||
)
|
||||
m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)
|
||||
|
||||
_ensure_combined_expert_weights(module, hidden_dtype, device)
|
||||
|
||||
gate_weights = module.get_parameter("gate_proj_weight")
|
||||
up_weights = module.get_parameter("up_proj_weight")
|
||||
down_weights = module.get_parameter("down_proj_weight")
|
||||
|
||||
gate_out = ContiguousGroupedGEMM.apply(
|
||||
grouped_hidden,
|
||||
gate_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
)
|
||||
up_out = ContiguousGroupedGEMM.apply(
|
||||
grouped_hidden,
|
||||
up_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
)
|
||||
if backend == "mg":
|
||||
_ensure_combined_expert_weights(module, hidden_dtype, device, backend)
|
||||
gate_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("gate_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
).to(hidden_dtype)
|
||||
up_out = mg_grouped_gemm(
|
||||
grouped_hidden,
|
||||
module.get_parameter("up_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
).to(hidden_dtype)
|
||||
else:
|
||||
gate_out, up_out, down_out_cg = _run_cg_grouped_gemm(
|
||||
module,
|
||||
grouped_hidden,
|
||||
m_sizes,
|
||||
num_experts,
|
||||
group_size_m,
|
||||
hidden_dtype,
|
||||
device,
|
||||
)
|
||||
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
|
||||
if valid_positions.numel() > 0:
|
||||
gate_valid = gate_out.index_select(0, valid_positions).to(hidden_dtype)
|
||||
up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
|
||||
gate_valid = gate_out.index_select(0, valid_positions)
|
||||
up_valid = up_out.index_select(0, valid_positions)
|
||||
hidden_concat = act_fn(gate_valid) * up_valid
|
||||
else:
|
||||
hidden_concat = torch.empty(
|
||||
@@ -219,15 +323,17 @@ def _moe_triton_forward(
|
||||
if valid_positions.numel() < max_len:
|
||||
hidden_grouped.index_fill_(0, padded_positions, 0)
|
||||
|
||||
down_out = ContiguousGroupedGEMM.apply(
|
||||
hidden_grouped,
|
||||
down_weights,
|
||||
expert_index_tensor,
|
||||
group_size_m,
|
||||
)
|
||||
if backend == "mg":
|
||||
down_out = mg_grouped_gemm(
|
||||
hidden_grouped,
|
||||
module.get_parameter("down_proj_weight"),
|
||||
m_sizes_tensor,
|
||||
).to(hidden_dtype)
|
||||
else:
|
||||
down_out = down_out_cg
|
||||
|
||||
if valid_positions.numel() > 0:
|
||||
down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
|
||||
down_valid = down_out.index_select(0, valid_positions)
|
||||
else:
|
||||
down_valid = torch.empty(
|
||||
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
|
||||
@@ -245,11 +351,16 @@ def _moe_triton_forward(
|
||||
return weighted.sum(dim=1)
|
||||
|
||||
|
||||
def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
||||
def patch_deepseek_v3_moe(
|
||||
group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
|
||||
) -> None:
|
||||
"""Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""
|
||||
|
||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||
|
||||
if backend not in {"cg", "mg"}:
|
||||
raise ValueError(f"Unsupported MoE kernel backend: {backend}")
|
||||
|
||||
# Record the unpatched implementation so callers can access a true baseline even
|
||||
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
|
||||
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
|
||||
@@ -259,26 +370,26 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
||||
return
|
||||
|
||||
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
|
||||
DeepseekV3MoE._axolotl_triton_backend = backend
|
||||
DeepseekV3MoE._axolotl_group_size_m = group_size_m
|
||||
|
||||
def patched_moe(self, hidden_states, topk_indices, topk_weights):
|
||||
backend_sel = getattr(self, "_axolotl_triton_backend", backend)
|
||||
group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
|
||||
try:
|
||||
return _moe_triton_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
topk_indices,
|
||||
topk_weights,
|
||||
group_size_m,
|
||||
group_size_sel,
|
||||
backend_sel,
|
||||
original_moe,
|
||||
)
|
||||
except Exception as err: # fall back if Triton compilation or runtime fails
|
||||
if not getattr(self, "_axolotl_triton_warned", False):
|
||||
LOG.warning(
|
||||
"DeepseekV3MoE Triton path failed; falling back to baseline: %s",
|
||||
err,
|
||||
)
|
||||
self._axolotl_triton_warned = True
|
||||
except Exception as err: # surface Triton failures explicitly
|
||||
_restore_expert_weights(self)
|
||||
return original_moe(self, hidden_states, topk_indices, topk_weights)
|
||||
LOG.error("DeepseekV3MoE Triton path failed: %s", err)
|
||||
raise
|
||||
|
||||
DeepseekV3MoE.moe = patched_moe
|
||||
DeepseekV3MoE._axolotl_triton_patch = True
|
||||
|
||||
@@ -119,6 +119,12 @@ class AxolotlInputConfig(
|
||||
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
|
||||
},
|
||||
)
|
||||
moe_kernel_backend: Literal["cg", "mg"] | None = Field(
|
||||
default="mg",
|
||||
json_schema_extra={
|
||||
"description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
|
||||
},
|
||||
)
|
||||
|
||||
trainer_cls: str | None = Field(
|
||||
default=None,
|
||||
|
||||
Reference in New Issue
Block a user