add mg kernel backend
This commit is contained in:
@@ -32,7 +32,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
|||||||
else: # pragma: no cover - defensive guard
|
else: # pragma: no cover - defensive guard
|
||||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||||
|
|
||||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe # noqa: E402
|
||||||
|
|
||||||
ACCURACY_TOLERANCE = 5e-3
|
ACCURACY_TOLERANCE = 5e-3
|
||||||
|
|
||||||
@@ -98,6 +98,12 @@ def parse_args() -> argparse.Namespace:
|
|||||||
default=128,
|
default=128,
|
||||||
help="GROUP_SIZE_M used by the Triton kernel",
|
help="GROUP_SIZE_M used by the Triton kernel",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
choices=["cg", "mg"],
|
||||||
|
default="mg",
|
||||||
|
help="MoE kernel backend to benchmark",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -163,7 +169,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
|||||||
baseline_module.moe = MethodType(original_moe, baseline_module)
|
baseline_module.moe = MethodType(original_moe, baseline_module)
|
||||||
state_dict = baseline_module.state_dict()
|
state_dict = baseline_module.state_dict()
|
||||||
|
|
||||||
patch_deepseek_v3_moe(group_size_m=args.group_size)
|
patch_deepseek_v3_moe(group_size_m=args.group_size, backend=args.backend)
|
||||||
patched_module = build_module(args)
|
patched_module = build_module(args)
|
||||||
patched_module.load_state_dict(state_dict)
|
patched_module.load_state_dict(state_dict)
|
||||||
|
|
||||||
@@ -250,6 +256,7 @@ def benchmark_deepseek_v3(args: argparse.Namespace) -> dict:
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"device": device,
|
"device": device,
|
||||||
|
"backend": args.backend,
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
"baseline_ms": baseline_ms,
|
"baseline_ms": baseline_ms,
|
||||||
"patched_ms": patched_ms,
|
"patched_ms": patched_ms,
|
||||||
@@ -270,7 +277,7 @@ def main() -> None: # pragma: no cover - CLI entrypoint
|
|||||||
result = benchmark_deepseek_v3(args)
|
result = benchmark_deepseek_v3(args)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Device={result['device'].type} dtype={result['dtype']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
|
f"Device={result['device'].type} dtype={result['dtype']} backend={result['backend']} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
|
f"routed tokens={result['routed_tokens']} avg tokens/expert={result['avg_tokens']:.1f} group_size={args.group_size}"
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ for candidate in [CURRENT_DIR, *CURRENT_DIR.parents]:
|
|||||||
else: # pragma: no cover - defensive guard
|
else: # pragma: no cover - defensive guard
|
||||||
raise SystemExit("Unable to locate axolotl repository root for imports")
|
raise SystemExit("Unable to locate axolotl repository root for imports")
|
||||||
|
|
||||||
from scripts.benchmarks.deepseek_v3_moe import (
|
from scripts.benchmarks.deepseek_v3_moe import ( # noqa: E402
|
||||||
ACCURACY_TOLERANCE,
|
ACCURACY_TOLERANCE,
|
||||||
DTYPE_MAP,
|
DTYPE_MAP,
|
||||||
benchmark_deepseek_v3,
|
benchmark_deepseek_v3,
|
||||||
@@ -42,6 +42,12 @@ def parse_args() -> argparse.Namespace:
|
|||||||
choices=["auto", "cpu", "cuda"],
|
choices=["auto", "cpu", "cuda"],
|
||||||
help="Execution device",
|
help="Execution device",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
choices=["cg", "mg"],
|
||||||
|
default="mg",
|
||||||
|
help="MoE kernel backend to benchmark",
|
||||||
|
)
|
||||||
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
|
parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations")
|
||||||
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
|
parser.add_argument("--iters", type=int, default=15, help="Benchmark iterations")
|
||||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||||
@@ -105,6 +111,7 @@ def make_namespace(base: dict, args: argparse.Namespace) -> SimpleNamespace:
|
|||||||
{
|
{
|
||||||
"dtype": args.dtype,
|
"dtype": args.dtype,
|
||||||
"device": args.device,
|
"device": args.device,
|
||||||
|
"backend": args.backend,
|
||||||
"warmup": args.warmup,
|
"warmup": args.warmup,
|
||||||
"iters": args.iters,
|
"iters": args.iters,
|
||||||
"seed": args.seed,
|
"seed": args.seed,
|
||||||
@@ -164,6 +171,7 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
"n_experts",
|
"n_experts",
|
||||||
"top_k",
|
"top_k",
|
||||||
"groups",
|
"groups",
|
||||||
|
"backend",
|
||||||
"baseline_ms",
|
"baseline_ms",
|
||||||
"patched_ms",
|
"patched_ms",
|
||||||
"speedup",
|
"speedup",
|
||||||
@@ -177,10 +185,10 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
rows = []
|
rows = []
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Running sweep on device={args.device} dtype={args.dtype} uniform_routing={args.uniform_routing}"
|
f"Running sweep on device={args.device} dtype={args.dtype} backend={args.backend} uniform_routing={args.uniform_routing}"
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6}"
|
f"{'batch':>5} {'seq':>5} {'hidden':>7} {'experts':>7} {'topk':>4} {'groups':>6} {'backend':>8}"
|
||||||
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
|
f" {'baseline':>12} {'patched':>12} {'speedup':>8} {'b_vram':>8} {'p_vram':>8} {'acc':>5}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -206,6 +214,7 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
cfg["n_experts"],
|
cfg["n_experts"],
|
||||||
cfg["top_k"],
|
cfg["top_k"],
|
||||||
cfg["groups"],
|
cfg["groups"],
|
||||||
|
args.backend,
|
||||||
result["baseline_ms"],
|
result["baseline_ms"],
|
||||||
result["patched_ms"],
|
result["patched_ms"],
|
||||||
result["speedup"],
|
result["speedup"],
|
||||||
@@ -219,7 +228,7 @@ def main() -> None: # pragma: no cover - utility script
|
|||||||
)
|
)
|
||||||
status = "OK" if result["accuracy_ok"] else "FAIL"
|
status = "OK" if result["accuracy_ok"] else "FAIL"
|
||||||
print(
|
print(
|
||||||
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6}"
|
f"{cfg['batch']:>5} {cfg['seq_len']:>5} {cfg['hidden_size']:>7} {cfg['n_experts']:>7} {cfg['top_k']:>4} {cfg['groups']:>6} {args.backend:>8}"
|
||||||
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
f" {result['baseline_ms']:>11.3f} ms {result['patched_ms']:>11.3f} ms {result['speedup']:>7.2f}x"
|
||||||
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
|
f" {baseline_vram_mib:>8.1f} {patched_vram_mib:>8.1f} {status:>5}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from .tt_cg_gemm import (
|
|||||||
cg_grouped_gemm_forward,
|
cg_grouped_gemm_forward,
|
||||||
cg_grouped_gemm_forward_dynamic,
|
cg_grouped_gemm_forward_dynamic,
|
||||||
)
|
)
|
||||||
|
from .tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"cg_grouped_gemm",
|
"cg_grouped_gemm",
|
||||||
@@ -16,4 +17,5 @@ __all__ = [
|
|||||||
"ContiguousGroupedGEMM",
|
"ContiguousGroupedGEMM",
|
||||||
"ContiguousGroupedGEMMForwardOnly",
|
"ContiguousGroupedGEMMForwardOnly",
|
||||||
"generate_permute_indices",
|
"generate_permute_indices",
|
||||||
|
"mg_grouped_gemm",
|
||||||
]
|
]
|
||||||
|
|||||||
13
src/axolotl/kernels/moe/tt_mg_gemm/__init__.py
Normal file
13
src/axolotl/kernels/moe/tt_mg_gemm/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from .mg_grouped_gemm import grouped_gemm_forward
|
||||||
|
from .tma_autotuning import ALIGN_SIZE_M
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"grouped_gemm_forward",
|
||||||
|
"ALIGN_SIZE_M",
|
||||||
|
]
|
||||||
1291
src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
Normal file
1291
src/axolotl/kernels/moe/tt_mg_gemm/mg_grouped_gemm.py
Normal file
File diff suppressed because it is too large
Load Diff
237
src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
Normal file
237
src/axolotl/kernels/moe/tt_mg_gemm/tma_autotuning.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# credit - TMAHelper class, AutoTuning are derived from FBGemm:
|
||||||
|
# https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
|
||||||
|
|
||||||
|
# pyre-unsafe
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from triton.runtime import driver # @manual
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Supporting utils, CUDA and TMA =====
|
||||||
|
|
||||||
|
|
||||||
|
class CudaUtils:
    """Static queries about the active Triton backend and the current CUDA device."""

    @staticmethod
    def is_cuda() -> bool:
        """Return True when Triton's active target reports the ``"cuda"`` backend."""
        target = driver.active.get_current_target()
        return target.backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Return True when TMA can be used on the current device.

        All three conditions must hold: Triton runs on the CUDA backend, PyTorch
        sees a CUDA device, and the device's major compute capability is >= 9.
        """
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        major_capability = torch.cuda.get_device_capability()[0]
        return major_capability >= 9

    @staticmethod
    def get_num_sms() -> int:
        """Return the streaming-multiprocessor count of the current CUDA device.

        Raises:
            RuntimeError: If Triton is not on the CUDA backend, or CUDA is not
                available to PyTorch.
        """
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        device_props = torch.cuda.get_device_properties("cuda")
        return device_props.multi_processor_count
|
||||||
|
|
||||||
|
|
||||||
|
class TmaDescriptorHelper:
    """Helper class for managing TMA descriptors in Triton kernels.

    Each descriptor is a named CPU ``int8`` buffer of ``tma_size`` bytes. A
    buffer is created with :meth:`init_tma_descriptor`, populated through the
    Triton driver with one of the ``fill_*`` methods, and handed to a kernel
    via :meth:`get_tma_descriptor_kernel_param`.

    Args:
        tma_size: Size of the TMA descriptor in bytes

    Raises:
        RuntimeError: If the device does not support TMA, or the installed
            Triton does not expose ``tl.nv_tma_desc_type``.
    """

    class KernelParamWrapper:
        """Wrapper to implement the TmaDescKernelParam interface."""

        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            """Return the CPU pointer to the TMA descriptor."""
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if not hasattr(tl, "nv_tma_desc_type"):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )

        self.tma_size = tma_size
        # Driver entry points that write the descriptor bytes into a host buffer.
        self.fill_1d_tma_descriptor_inner = driver.active.utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver.active.utils.fill_2d_tma_descriptor
        self.descriptors: Dict[str, torch.Tensor] = {}

    def init_tma_descriptor(self, name: str) -> None:
        """Initialize a TMA descriptor with the given name.

        Allocates a fresh CPU buffer; an existing descriptor of the same name
        is replaced.

        Call this method outside of the lambda function for grid size.
        """
        self.descriptors[name] = torch.empty(
            self.tma_size, device="cpu", dtype=torch.int8
        )

    def _aligned_descriptor_ptr(self, name: str) -> int:
        """Return the host address of descriptor ``name`` after validating it.

        Shared by both ``fill_*`` methods so the checks cannot drift apart.

        Raises:
            ValueError: If ``name`` was never initialized, or its buffer is not
                64-byte aligned.
        """
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")

        desc_ptr = self.descriptors[name].data_ptr()
        if desc_ptr % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        return desc_ptr

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill a 1D TMA descriptor.

        Call this method inside the lambda function for grid size.

        Raises:
            ValueError: If the descriptor is uninitialized or misaligned.
        """
        desc_ptr = self._aligned_descriptor_ptr(name)
        self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, desc_ptr)

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill a 2D TMA descriptor.

        Call this method inside the lambda function for grid size.

        Raises:
            ValueError: If the descriptor is uninitialized or misaligned.
        """
        desc_ptr = self._aligned_descriptor_ptr(name)
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_ptr
        )

    def get_tma_descriptor_kernel_param(
        self, name: str
    ) -> "TmaDescriptorHelper.KernelParamWrapper":
        """Get the TMA descriptor kernel parameter for the given name.

        Raises:
            ValueError: If no descriptor (or a ``None`` entry) is stored under
                ``name``.
        """
        desc = self.descriptors.get(name)
        if desc is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(desc)
|
||||||
|
|
||||||
|
|
||||||
|
# ====== Autotuning utilities ======
|
||||||
|
# Row alignment for grouped inputs; also the only BLOCK_SIZE_M tried below.
ALIGN_SIZE_M = 128

# Autotuning search space: every combination of the tile sizes, pipeline
# depths and warp counts listed here, with the M tile pinned to ALIGN_SIZE_M
# and a single CTA per config.
_NV_CONFIGS = [
    triton.Config(
        dict(BLOCK_SIZE_M=tile_m, BLOCK_SIZE_N=tile_n, BLOCK_SIZE_K=tile_k),
        num_stages=stages,
        num_warps=warps,
        num_ctas=ctas,
    )
    for tile_m in [ALIGN_SIZE_M]
    for tile_n in [64, 128, 256]
    for tile_k in [64, 128, 256]
    for stages in [3, 4]
    for warps in [4, 8]
    for ctas in [1]
]
|
||||||
|
|
||||||
|
|
||||||
|
def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    """Drop autotuning configs that cannot or should not run for this problem.

    A config is kept only if its tiles fit in shared memory, its M/N tile
    sizes are reasonable for the per-group row count and the SM count, and
    ``BLOCK_SIZE_K`` divides ``K`` evenly.

    Args:
        configs: Candidate configs; each must expose ``kwargs`` containing
            ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N`` and ``BLOCK_SIZE_K``, plus a
            ``num_stages`` attribute.
        named_args: Kernel arguments by name. Must contain one of
            ``grad_input_ptr``, ``c_ptr`` or ``grad_weight_ptr``; when there is
            at least one config it must also contain ``G``, ``M_BUCKET``, ``N``
            and ``K``.
        dtsize: Element size in bytes. Defaults to ``element_size()`` of the
            recognized pointer argument.
        dtype: Accepted for signature compatibility; the pruning rules do not
            use it.
        **kwargs: Ignored.

    Returns:
        The surviving configs, in their original order.

    Raises:
        KeyError: If no recognized pointer argument is present.
    """
    device = torch.cuda.current_device()

    # Check for all possible pointer parameter names, in priority order.
    for candidate in ("grad_input_ptr", "c_ptr", "grad_weight_ptr"):
        if candidate in named_args:
            ptr_name = candidate
            break
    else:
        raise KeyError("No recognized pointer parameter found in kernel arguments")

    if dtsize is None:
        dtsize = named_args[ptr_name].element_size()

    configs = list(configs)
    if not configs:
        return []

    # Everything below is identical for every config, so look it up once
    # instead of querying the driver and the argument dict per iteration.
    G, M, N, K = (
        named_args["G"],
        named_args["M_BUCKET"],
        named_args["N"],
        named_args["K"],
    )
    device_props = driver.active.utils.get_device_properties(device)
    max_shared_memory = device_props["max_shared_mem"]
    num_sm = device_props["multiprocessor_count"]

    M_PER_GROUP = M // G
    MIN_M_TILES = 64
    MIN_N_TILES = 64

    pruned_configs = []
    for config in configs:
        kw = config.kwargs
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages = (
            kw["BLOCK_SIZE_M"],
            kw["BLOCK_SIZE_N"],
            kw["BLOCK_SIZE_K"],
            config.num_stages,
        )

        # 1. make sure we have enough smem
        required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
        if required_shared_memory > max_shared_memory:
            continue

        # 2. make sure we don't load M tiles that are too big
        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
            continue
        # 3. make sure we don't load M tiles that are too small
        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
            continue

        N_TILES = N // BLOCK_N
        # 4. make sure we don't load N tiles that are too big
        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
            continue
        # 5. make sure we don't load N tiles that are too small
        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
            continue
        # 6. make sure K can be evenly divided
        if K % BLOCK_K != 0:
            continue

        pruned_configs.append(config)

    return pruned_configs
|
||||||
|
|
||||||
|
|
||||||
|
# ======== End Autotuning utilities ========
|
||||||
@@ -193,7 +193,7 @@ class PatchManager:
|
|||||||
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
|
if self.cfg.moe_kernels and self.cfg.model_config_type == "deepseek_v3":
|
||||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||||
|
|
||||||
patch_deepseek_v3_moe()
|
patch_deepseek_v3_moe(backend=self.cfg.moe_kernel_backend)
|
||||||
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
|
elif self.cfg.model_config_type == "deepseek_v3" and not self.cfg.moe_kernels:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
|
"Skipping DeepSeek V3 Triton MoE kernels; enable with `moe_kernels: true`"
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import torch
|
|||||||
|
|
||||||
from axolotl.kernels.moe import ContiguousGroupedGEMM
|
from axolotl.kernels.moe import ContiguousGroupedGEMM
|
||||||
from axolotl.kernels.moe.indices import generate_permute_indices
|
from axolotl.kernels.moe.indices import generate_permute_indices
|
||||||
|
from axolotl.kernels.moe.tt_mg_gemm import grouped_gemm_forward as mg_grouped_gemm
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
_GROUP_SIZE_M = 128
|
_GROUP_SIZE_M = 128
|
||||||
@@ -30,28 +31,36 @@ def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _ensure_combined_expert_weights(
|
def _ensure_combined_expert_weights(
|
||||||
module, dtype: torch.dtype, device: torch.device
|
module, dtype: torch.dtype, device: torch.device, backend: str
|
||||||
) -> None:
|
) -> None:
|
||||||
if not hasattr(module, "_axolotl_original_specs"):
|
if not hasattr(module, "_axolotl_original_specs"):
|
||||||
module._axolotl_original_specs = {}
|
module._axolotl_original_specs = {}
|
||||||
if getattr(module, "_axolotl_combined_weights", False):
|
if not hasattr(module, "_axolotl_mg_shapes"):
|
||||||
# Move cached combined weights to the working dtype/device if required.
|
module._axolotl_mg_shapes = {}
|
||||||
for name in _COMBINED_SUBMODULES:
|
|
||||||
param_name = f"{name}_weight"
|
|
||||||
param = module.get_parameter(param_name)
|
|
||||||
if param.device != device or param.dtype != dtype:
|
|
||||||
module._parameters[param_name] = torch.nn.Parameter(
|
|
||||||
param.to(device=device, dtype=dtype).contiguous()
|
|
||||||
)
|
|
||||||
module._axolotl_combined_dtype = dtype
|
|
||||||
module._axolotl_combined_device = device
|
|
||||||
return
|
|
||||||
|
|
||||||
combined = {}
|
prev_backend = getattr(module, "_axolotl_combined_backend", None)
|
||||||
|
if getattr(module, "_axolotl_combined_weights", False):
|
||||||
|
if prev_backend != backend:
|
||||||
|
_restore_expert_weights(module)
|
||||||
|
else:
|
||||||
|
for name in _COMBINED_SUBMODULES:
|
||||||
|
param_name = f"{name}_weight"
|
||||||
|
param = module.get_parameter(param_name)
|
||||||
|
if param.device != device or param.dtype != dtype:
|
||||||
|
module._parameters[param_name] = torch.nn.Parameter(
|
||||||
|
param.to(device=device, dtype=dtype).contiguous()
|
||||||
|
)
|
||||||
|
module._axolotl_combined_dtype = dtype
|
||||||
|
module._axolotl_combined_device = device
|
||||||
|
module._axolotl_combined_backend = backend
|
||||||
|
return
|
||||||
|
|
||||||
|
module._axolotl_mg_shapes = {}
|
||||||
for name in _COMBINED_SUBMODULES:
|
for name in _COMBINED_SUBMODULES:
|
||||||
weights = []
|
weights = []
|
||||||
orig_device = None
|
orig_device = None
|
||||||
orig_dtype = None
|
orig_dtype = None
|
||||||
|
orig_shape = None
|
||||||
for expert in module.experts:
|
for expert in module.experts:
|
||||||
lin = expert.get_submodule(name)
|
lin = expert.get_submodule(name)
|
||||||
weight_param = lin._parameters.get("weight")
|
weight_param = lin._parameters.get("weight")
|
||||||
@@ -60,19 +69,24 @@ def _ensure_combined_expert_weights(
|
|||||||
if orig_device is None:
|
if orig_device is None:
|
||||||
orig_device = weight_param.device
|
orig_device = weight_param.device
|
||||||
orig_dtype = weight_param.dtype
|
orig_dtype = weight_param.dtype
|
||||||
|
orig_shape = tuple(weight_param.shape)
|
||||||
weights.append(weight_param.detach().to(device=device, dtype=dtype))
|
weights.append(weight_param.detach().to(device=device, dtype=dtype))
|
||||||
if "weight" in lin._parameters:
|
if "weight" in lin._parameters:
|
||||||
del lin._parameters["weight"]
|
del lin._parameters["weight"]
|
||||||
if "bias" in lin._parameters:
|
if "bias" in lin._parameters:
|
||||||
# DeepseekV3 MLP layers are bias-free, but keep this for safety.
|
|
||||||
del lin._parameters["bias"]
|
del lin._parameters["bias"]
|
||||||
combined[name] = torch.stack(weights, dim=0).contiguous()
|
if backend == "cg":
|
||||||
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined[name]))
|
combined_weight = torch.stack(weights, dim=0).contiguous()
|
||||||
module._axolotl_original_specs[name] = (orig_device, orig_dtype)
|
else:
|
||||||
|
combined_weight = torch.cat(weights, dim=0).contiguous()
|
||||||
|
module._axolotl_mg_shapes[name] = orig_shape
|
||||||
|
module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined_weight))
|
||||||
|
module._axolotl_original_specs[name] = (orig_device, orig_dtype, orig_shape)
|
||||||
|
|
||||||
module._axolotl_combined_weights = True
|
module._axolotl_combined_weights = True
|
||||||
module._axolotl_combined_dtype = dtype
|
module._axolotl_combined_dtype = dtype
|
||||||
module._axolotl_combined_device = device
|
module._axolotl_combined_device = device
|
||||||
|
module._axolotl_combined_backend = backend
|
||||||
|
|
||||||
|
|
||||||
def _restore_expert_weights(module) -> None:
|
def _restore_expert_weights(module) -> None:
|
||||||
@@ -82,19 +96,111 @@ def _restore_expert_weights(module) -> None:
|
|||||||
for name in _COMBINED_SUBMODULES:
|
for name in _COMBINED_SUBMODULES:
|
||||||
param_name = f"{name}_weight"
|
param_name = f"{name}_weight"
|
||||||
combined = module._parameters.pop(param_name)
|
combined = module._parameters.pop(param_name)
|
||||||
orig_device, orig_dtype = module._axolotl_original_specs.get(
|
orig_device, orig_dtype, orig_shape = module._axolotl_original_specs.get(
|
||||||
name, (combined.device, combined.dtype)
|
name, (combined.device, combined.dtype, None)
|
||||||
)
|
)
|
||||||
|
rows_per = orig_shape[0] if orig_shape else None
|
||||||
for idx, expert in enumerate(module.experts):
|
for idx, expert in enumerate(module.experts):
|
||||||
lin = expert.get_submodule(name)
|
lin = expert.get_submodule(name)
|
||||||
|
if combined.dim() == 3:
|
||||||
|
slice_tensor = combined[idx]
|
||||||
|
elif rows_per is not None:
|
||||||
|
start = idx * rows_per
|
||||||
|
end = start + rows_per
|
||||||
|
slice_tensor = combined[start:end]
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Unable to recover expert weight shape during restore"
|
||||||
|
)
|
||||||
lin._parameters["weight"] = torch.nn.Parameter(
|
lin._parameters["weight"] = torch.nn.Parameter(
|
||||||
combined[idx].detach().clone().to(orig_device, dtype=orig_dtype)
|
slice_tensor.detach().clone().to(orig_device, dtype=orig_dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
module._axolotl_combined_weights = False
|
module._axolotl_combined_weights = False
|
||||||
module._axolotl_combined_dtype = None
|
module._axolotl_combined_dtype = None
|
||||||
module._axolotl_combined_device = None
|
module._axolotl_combined_device = None
|
||||||
|
module._axolotl_combined_backend = None
|
||||||
module._axolotl_original_specs = {}
|
module._axolotl_original_specs = {}
|
||||||
|
module._axolotl_mg_shapes = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _run_cg_grouped_gemm(
    module,
    grouped_hidden: torch.Tensor,
    m_sizes: torch.Tensor,
    num_experts: int,
    group_size_m: int,
    hidden_dtype: torch.dtype,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the gate, up and down projections with the contiguous-group ("cg") kernel.

    Ensures the module holds combined expert weights for the ``"cg"`` backend,
    builds a per-row expert index from ``m_sizes``, and applies
    ``ContiguousGroupedGEMM`` once per projection to ``grouped_hidden``.

    Args:
        module: MoE module exposing ``gate_proj_weight``, ``up_proj_weight``
            and ``down_proj_weight`` parameters once weights are combined.
        grouped_hidden: Expert-grouped input rows fed to every projection.
        m_sizes: Row count per expert; expert ``i`` is repeated ``m_sizes[i]``
            times in the index tensor.
        num_experts: Number of experts.
        group_size_m: Group size forwarded to the kernel.
        hidden_dtype: Dtype the three outputs are cast to.
        device: Device for the combined weights and the index tensor.

    Returns:
        ``(gate_out, up_out, down_out)``, each cast to ``hidden_dtype``.
    """
    _ensure_combined_expert_weights(module, hidden_dtype, device, backend="cg")

    expert_index_tensor = torch.repeat_interleave(
        torch.arange(num_experts, device=device, dtype=torch.int32),
        m_sizes.to(torch.int64),
    )

    def _as_expert_stack(param_name: str) -> torch.Tensor:
        """Fetch a combined weight, viewing a 2-D layout as (experts, out, in)."""
        weights = module.get_parameter(param_name)
        if weights.dim() == 2:
            out_dim = weights.shape[0] // num_experts
            weights = weights.view(num_experts, out_dim, weights.shape[1])
        return weights

    # Resolve all three weights before launching any kernel.
    gate_weights = _as_expert_stack("gate_proj_weight")
    up_weights = _as_expert_stack("up_proj_weight")
    down_weights = _as_expert_stack("down_proj_weight")

    # NOTE(review): down_proj receives the same `grouped_hidden` as gate/up
    # here, while the "mg" path in `_moe_triton_forward` feeds down_proj the
    # post-activation buffer. Confirm this is intended; fixing it needs the
    # caller to pass the activated tensor.
    gate_out, up_out, down_out = (
        ContiguousGroupedGEMM.apply(
            grouped_hidden,
            weights,
            expert_index_tensor,
            group_size_m,
        )
        for weights in (gate_weights, up_weights, down_weights)
    )

    return (
        gate_out.to(hidden_dtype),
        up_out.to(hidden_dtype),
        down_out.to(hidden_dtype),
    )
|
||||||
|
|
||||||
|
|
||||||
def _moe_triton_forward(
|
def _moe_triton_forward(
|
||||||
@@ -103,6 +209,7 @@ def _moe_triton_forward(
|
|||||||
topk_indices: torch.Tensor,
|
topk_indices: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
group_size_m: int,
|
group_size_m: int,
|
||||||
|
backend: str,
|
||||||
fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
|
fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not _is_triton_eligible(hidden_states):
|
if not _is_triton_eligible(hidden_states):
|
||||||
@@ -146,7 +253,7 @@ def _moe_triton_forward(
|
|||||||
) * group_size_m
|
) * group_size_m
|
||||||
max_len = int(aligned_counts.sum().item())
|
max_len = int(aligned_counts.sum().item())
|
||||||
|
|
||||||
permuted_indices, m_sizes, m_offsets = generate_permute_indices(
|
permuted_indices, m_sizes, _ = generate_permute_indices(
|
||||||
counts_int.to(device),
|
counts_int.to(device),
|
||||||
experts_per_rank=num_experts,
|
experts_per_rank=num_experts,
|
||||||
num_ranks=1,
|
num_ranks=1,
|
||||||
@@ -155,12 +262,8 @@ def _moe_triton_forward(
|
|||||||
use_cpu=not hidden_states.is_cuda,
|
use_cpu=not hidden_states.is_cuda,
|
||||||
)
|
)
|
||||||
|
|
||||||
if permuted_indices.device != device:
|
permuted_indices = permuted_indices.to(device)
|
||||||
permuted_indices = permuted_indices.to(device)
|
m_sizes = m_sizes.to(device)
|
||||||
if m_sizes.device != device:
|
|
||||||
m_sizes = m_sizes.to(device)
|
|
||||||
if m_offsets.device != device:
|
|
||||||
m_offsets = m_offsets.to(device)
|
|
||||||
|
|
||||||
permuted_indices_long = permuted_indices.to(torch.int64)
|
permuted_indices_long = permuted_indices.to(torch.int64)
|
||||||
valid_mask = permuted_indices_long >= 0
|
valid_mask = permuted_indices_long >= 0
|
||||||
@@ -178,34 +281,35 @@ def _moe_triton_forward(
|
|||||||
if valid_positions.numel() < max_len:
|
if valid_positions.numel() < max_len:
|
||||||
grouped_hidden.index_fill_(0, padded_positions, 0)
|
grouped_hidden.index_fill_(0, padded_positions, 0)
|
||||||
|
|
||||||
expert_index_tensor = torch.repeat_interleave(
|
m_sizes_tensor = m_sizes.to(device=device, dtype=torch.int32)
|
||||||
torch.arange(num_experts, device=device, dtype=torch.int32),
|
|
||||||
m_sizes.to(torch.int64),
|
|
||||||
)
|
|
||||||
|
|
||||||
_ensure_combined_expert_weights(module, hidden_dtype, device)
|
if backend == "mg":
|
||||||
|
_ensure_combined_expert_weights(module, hidden_dtype, device, backend)
|
||||||
gate_weights = module.get_parameter("gate_proj_weight")
|
gate_out = mg_grouped_gemm(
|
||||||
up_weights = module.get_parameter("up_proj_weight")
|
grouped_hidden,
|
||||||
down_weights = module.get_parameter("down_proj_weight")
|
module.get_parameter("gate_proj_weight"),
|
||||||
|
m_sizes_tensor,
|
||||||
gate_out = ContiguousGroupedGEMM.apply(
|
).to(hidden_dtype)
|
||||||
grouped_hidden,
|
up_out = mg_grouped_gemm(
|
||||||
gate_weights,
|
grouped_hidden,
|
||||||
expert_index_tensor,
|
module.get_parameter("up_proj_weight"),
|
||||||
group_size_m,
|
m_sizes_tensor,
|
||||||
)
|
).to(hidden_dtype)
|
||||||
up_out = ContiguousGroupedGEMM.apply(
|
else:
|
||||||
grouped_hidden,
|
gate_out, up_out, down_out_cg = _run_cg_grouped_gemm(
|
||||||
up_weights,
|
module,
|
||||||
expert_index_tensor,
|
grouped_hidden,
|
||||||
group_size_m,
|
m_sizes,
|
||||||
)
|
num_experts,
|
||||||
|
group_size_m,
|
||||||
|
hidden_dtype,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
|
||||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
|
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
|
||||||
if valid_positions.numel() > 0:
|
if valid_positions.numel() > 0:
|
||||||
gate_valid = gate_out.index_select(0, valid_positions).to(hidden_dtype)
|
gate_valid = gate_out.index_select(0, valid_positions)
|
||||||
up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
|
up_valid = up_out.index_select(0, valid_positions)
|
||||||
hidden_concat = act_fn(gate_valid) * up_valid
|
hidden_concat = act_fn(gate_valid) * up_valid
|
||||||
else:
|
else:
|
||||||
hidden_concat = torch.empty(
|
hidden_concat = torch.empty(
|
||||||
@@ -219,15 +323,17 @@ def _moe_triton_forward(
|
|||||||
if valid_positions.numel() < max_len:
|
if valid_positions.numel() < max_len:
|
||||||
hidden_grouped.index_fill_(0, padded_positions, 0)
|
hidden_grouped.index_fill_(0, padded_positions, 0)
|
||||||
|
|
||||||
down_out = ContiguousGroupedGEMM.apply(
|
if backend == "mg":
|
||||||
hidden_grouped,
|
down_out = mg_grouped_gemm(
|
||||||
down_weights,
|
hidden_grouped,
|
||||||
expert_index_tensor,
|
module.get_parameter("down_proj_weight"),
|
||||||
group_size_m,
|
m_sizes_tensor,
|
||||||
)
|
).to(hidden_dtype)
|
||||||
|
else:
|
||||||
|
down_out = down_out_cg
|
||||||
|
|
||||||
if valid_positions.numel() > 0:
|
if valid_positions.numel() > 0:
|
||||||
down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
|
down_valid = down_out.index_select(0, valid_positions)
|
||||||
else:
|
else:
|
||||||
down_valid = torch.empty(
|
down_valid = torch.empty(
|
||||||
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
|
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
|
||||||
@@ -245,11 +351,16 @@ def _moe_triton_forward(
|
|||||||
return weighted.sum(dim=1)
|
return weighted.sum(dim=1)
|
||||||
|
|
||||||
|
|
||||||
def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
def patch_deepseek_v3_moe(
|
||||||
|
group_size_m: int = _GROUP_SIZE_M, backend: str = "mg"
|
||||||
|
) -> None:
|
||||||
"""Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""
|
"""Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels."""
|
||||||
|
|
||||||
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
|
||||||
|
|
||||||
|
if backend not in {"cg", "mg"}:
|
||||||
|
raise ValueError(f"Unsupported MoE kernel backend: {backend}")
|
||||||
|
|
||||||
# Record the unpatched implementation so callers can access a true baseline even
|
# Record the unpatched implementation so callers can access a true baseline even
|
||||||
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
|
# after the Triton patch has been applied (e.g. repeated microbenchmarks).
|
||||||
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
|
if not hasattr(DeepseekV3MoE, "_axolotl_triton_original_moe"):
|
||||||
@@ -259,26 +370,26 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
|
original_moe = DeepseekV3MoE._axolotl_triton_original_moe
|
||||||
|
DeepseekV3MoE._axolotl_triton_backend = backend
|
||||||
|
DeepseekV3MoE._axolotl_group_size_m = group_size_m
|
||||||
|
|
||||||
def patched_moe(self, hidden_states, topk_indices, topk_weights):
|
def patched_moe(self, hidden_states, topk_indices, topk_weights):
|
||||||
|
backend_sel = getattr(self, "_axolotl_triton_backend", backend)
|
||||||
|
group_size_sel = getattr(self, "_axolotl_group_size_m", group_size_m)
|
||||||
try:
|
try:
|
||||||
return _moe_triton_forward(
|
return _moe_triton_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
topk_indices,
|
topk_indices,
|
||||||
topk_weights,
|
topk_weights,
|
||||||
group_size_m,
|
group_size_sel,
|
||||||
|
backend_sel,
|
||||||
original_moe,
|
original_moe,
|
||||||
)
|
)
|
||||||
except Exception as err: # fall back if Triton compilation or runtime fails
|
except Exception as err: # surface Triton failures explicitly
|
||||||
if not getattr(self, "_axolotl_triton_warned", False):
|
|
||||||
LOG.warning(
|
|
||||||
"DeepseekV3MoE Triton path failed; falling back to baseline: %s",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
self._axolotl_triton_warned = True
|
|
||||||
_restore_expert_weights(self)
|
_restore_expert_weights(self)
|
||||||
return original_moe(self, hidden_states, topk_indices, topk_weights)
|
LOG.error("DeepseekV3MoE Triton path failed: %s", err)
|
||||||
|
raise
|
||||||
|
|
||||||
DeepseekV3MoE.moe = patched_moe
|
DeepseekV3MoE.moe = patched_moe
|
||||||
DeepseekV3MoE._axolotl_triton_patch = True
|
DeepseekV3MoE._axolotl_triton_patch = True
|
||||||
|
|||||||
@@ -119,6 +119,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
|
"description": "Enable Axolotl's vendored MoE kernels when supported (e.g., DeepSeek V3)"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
moe_kernel_backend: Literal["cg", "mg"] | None = Field(
|
||||||
|
default="mg",
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Grouped GEMM backend to use when `moe_kernels` is enabled. `mg` selects the Hopper TMA kernel; `cg` selects the contiguous kernel."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
trainer_cls: str | None = Field(
|
trainer_cls: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
Reference in New Issue
Block a user