vendor torchtitan moe kernels
This commit is contained in:
17
src/axolotl/kernels/moe/__init__.py
Normal file
17
src/axolotl/kernels/moe/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Mixture-of-Experts kernel implementations."""
|
||||
|
||||
from .tt_cg_gemm import (
|
||||
ContiguousGroupedGEMM,
|
||||
ContiguousGroupedGEMMForwardOnly,
|
||||
cg_grouped_gemm,
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"cg_grouped_gemm",
|
||||
"cg_grouped_gemm_forward",
|
||||
"cg_grouped_gemm_forward_dynamic",
|
||||
"ContiguousGroupedGEMM",
|
||||
"ContiguousGroupedGEMMForwardOnly",
|
||||
]
|
||||
17
src/axolotl/kernels/moe/tt_cg_gemm/__init__.py
Normal file
17
src/axolotl/kernels/moe/tt_cg_gemm/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Vendored Triton contiguous grouped GEMM kernels from TorchTitan."""
|
||||
|
||||
from .cg_backward import ContiguousGroupedGEMM
|
||||
from .cg_forward import (
|
||||
ContiguousGroupedGEMM as ContiguousGroupedGEMMForwardOnly,
|
||||
cg_grouped_gemm,
|
||||
cg_grouped_gemm_forward,
|
||||
cg_grouped_gemm_forward_dynamic,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"cg_grouped_gemm",
|
||||
"cg_grouped_gemm_forward",
|
||||
"cg_grouped_gemm_forward_dynamic",
|
||||
"ContiguousGroupedGEMM",
|
||||
"ContiguousGroupedGEMMForwardOnly",
|
||||
]
|
||||
290
src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py
Normal file
290
src/axolotl/kernels/moe/tt_cg_gemm/cg_backward.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Vendored backward pass for Triton contiguous grouped GEMM."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .cg_forward import cg_grouped_gemm_forward
|
||||
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
|
||||
|
||||
GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_backward_dx(
    grad_output_ptr,
    b_ptr,
    grad_input_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
    """Compute one tile of the gradient with respect to the inputs.

    Each program writes a ``BLOCK_SIZE_M x BLOCK_SIZE_K`` tile of
    ``grad_input``, accumulating ``grad_output[m, n] * b[expert, n, k]`` over
    ``n`` in float32. Addressing assumes row-major buffers: ``grad_output`` is
    indexed as ``m * N + n``, the weights as ``expert * N * K + n * K + k`` and
    ``grad_input`` as ``m * K + k``.

    A single expert index is read per tile, from the first row of the
    ``GROUP_SIZE_M`` group that contains the tile's first row, so a tile must
    not straddle two groups that belong to different experts.
    """

    pid = tl.program_id(0)

    # Programs are laid out row-major over (m tiles, k tiles).
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)

    tile_m = pid // num_k_tiles
    tile_k = pid % num_k_tiles

    m_start = tile_m * BLOCK_SIZE_M
    k_start = tile_k * BLOCK_SIZE_K

    if m_start < M_TOTAL:
        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
        offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start

        mask_m = offs_m < M_TOTAL
        mask_k = offs_k < K

        group_idx = m_start // GROUP_SIZE_M
        # Widen to int64: ``expert_idx * N * K`` exceeds the int32 range for
        # large expert banks (e.g. 256 experts of 2048 x 7168 weights).
        expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M).to(tl.int64)

        grad_input = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32)

        for n in range(0, N, BLOCK_SIZE_N):
            offs_n = tl.arange(0, BLOCK_SIZE_N) + n
            mask_n = offs_n < N

            mask_go = mask_m[:, None] & mask_n[None, :]
            mask_w = mask_n[:, None] & mask_k[None, :]

            go_ptrs = grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
            go = tl.load(go_ptrs, mask=mask_go, other=0.0)

            w_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
            w = tl.load(w_ptrs, mask=mask_w, other=0.0)

            # [BLOCK_M, BLOCK_N] @ [BLOCK_N, BLOCK_K]
            grad_input += tl.dot(go, w)

        grad_input_ptrs = grad_input_ptr + offs_m[:, None] * K + offs_k[None, :]
        mask_gi = mask_m[:, None] & mask_k[None, :]
        tl.store(grad_input_ptrs, grad_input, mask=mask_gi)
|
||||
|
||||
|
||||
@triton.jit
def _kernel_cg_backward_dw(
    grad_output_ptr,
    inputs_ptr,
    grad_weights_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Compute one tile of the expert weight gradients.

    Each program owns one ``BLOCK_SIZE_N x BLOCK_SIZE_K`` tile of one expert's
    gradient. It walks every ``GROUP_SIZE_M`` row group, reads the group's
    expert index from the group's first row, and for matching groups
    accumulates ``grad_output^T @ inputs`` in float32.

    The program id is decoded as ``(expert, n tile, k tile)`` using ceil
    division so that it matches a launch grid of
    ``num_experts * cdiv(N, BLOCK_SIZE_N) * cdiv(K, BLOCK_SIZE_K)`` even when
    ``N`` or ``K`` is not a multiple of the block size.

    Only ``M_TOTAL // GROUP_SIZE_M`` full groups are visited; callers are
    expected to pad ``M_TOTAL`` to a multiple of ``GROUP_SIZE_M``.
    """

    pid = tl.program_id(0)

    # Ceil division: partial edge tiles still need a program of their own.
    n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    tiles_per_expert = n_tiles * k_tiles

    expert_id = pid // tiles_per_expert
    position_id = pid % tiles_per_expert

    if expert_id < NUM_EXPERTS:
        tile_n = position_id // k_tiles
        tile_k = position_id % k_tiles

        n_start = tile_n * BLOCK_SIZE_N
        k_start = tile_k * BLOCK_SIZE_K

        if n_start < N and k_start < K:
            offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start
            offs_k = tl.arange(0, BLOCK_SIZE_K) + k_start

            mask_n = offs_n < N
            mask_k = offs_k < K

            grad_weights = tl.zeros([BLOCK_SIZE_N, BLOCK_SIZE_K], dtype=tl.float32)

            for group_idx in range(0, M_TOTAL // GROUP_SIZE_M):
                group_start = group_idx * GROUP_SIZE_M
                group_expert = tl.load(indices_ptr + group_start)

                if group_expert == expert_id:
                    for m_offset in range(0, GROUP_SIZE_M, BLOCK_SIZE_M):
                        m_start = group_start + m_offset
                        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start

                        # Clip rows to the current group and to the buffer end.
                        mask_m = offs_m < min(group_start + GROUP_SIZE_M, M_TOTAL)

                        go_ptrs = (
                            grad_output_ptr + offs_m[:, None] * N + offs_n[None, :]
                        )
                        mask_go = mask_m[:, None] & mask_n[None, :]
                        go = tl.load(go_ptrs, mask=mask_go, other=0.0)

                        in_ptrs = inputs_ptr + offs_m[:, None] * K + offs_k[None, :]
                        mask_in = mask_m[:, None] & mask_k[None, :]
                        inp = tl.load(in_ptrs, mask=mask_in, other=0.0)

                        # [BLOCK_N, BLOCK_M] @ [BLOCK_M, BLOCK_K]
                        go_t = tl.trans(go)
                        grad_weights += tl.dot(go_t, inp)

            # Widen to int64: ``expert_id * N * K`` exceeds the int32 range
            # for large expert banks.
            expert_offset = expert_id.to(tl.int64) * N * K
            grad_w_ptrs = (
                grad_weights_ptr
                + expert_offset
                + offs_n[:, None] * K
                + offs_k[None, :]
            )
            mask_gw = mask_n[:, None] & mask_k[None, :]
            tl.store(grad_w_ptrs, grad_weights, mask=mask_gw)
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_weights(
    grad_output: torch.Tensor,
    inputs: torch.Tensor,
    expert_indices: torch.Tensor,
    num_experts: int,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Compute expert weight gradients for the contiguous grouped GEMM.

    Args:
        grad_output: Contiguous ``(M_total, N)`` gradient of the output.
        inputs: Contiguous ``(M_total, K)`` inputs used in the forward pass.
        expert_indices: Contiguous per-row expert ids of length ``M_total``;
            converted to int32 if needed.
        num_experts: Number of experts, i.e. the leading dimension of the result.
        group_size_m: Number of consecutive rows that share one expert.

    Returns:
        A ``(num_experts, N, K)`` tensor with the dtype and device of
        ``grad_output``.

    Raises:
        AssertionError: If a tensor is not contiguous, the row counts
            disagree, or ``M_total`` is not a multiple of ``group_size_m``.
    """

    assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
    assert inputs.is_contiguous(), "Inputs tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, N = grad_output.shape
    M_inputs, K = inputs.shape

    assert M_inputs == M_total, (
        f"Inputs rows ({M_inputs}) must match grad output rows ({M_total})"
    )
    # The kernel only walks M_total // group_size_m full groups, so a trailing
    # partial group would silently contribute nothing to the gradient.
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    grad_weights = torch.zeros(
        (num_experts, N, K), device=grad_output.device, dtype=grad_output.dtype
    )

    # Fixed (non-autotuned) tile sizes, clamped to the problem size.
    block_size_n = min(128, N)
    block_size_k = min(32, K)
    block_size_m = min(32, group_size_m)

    # One program per (expert, n tile, k tile).
    n_tiles = triton.cdiv(N, block_size_n)
    k_tiles = triton.cdiv(K, block_size_k)
    grid = (num_experts * n_tiles * k_tiles,)

    _kernel_cg_backward_dw[grid](
        grad_output,
        inputs,
        grad_weights,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
        BLOCK_SIZE_M=block_size_m,
    )

    return grad_weights
|
||||
|
||||
|
||||
def cg_grouped_gemm_backward_inputs(
    grad_output: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Compute input gradients for the contiguous grouped GEMM.

    Args:
        grad_output: Contiguous ``(M_total, N)`` gradient of the output.
        expert_weights: Contiguous ``(num_experts, N, K)`` expert weights.
        expert_indices: Contiguous per-row expert ids of length ``M_total``.
        group_size_m: Number of consecutive rows that share one expert.

    Returns:
        A ``(M_total, K)`` tensor with the dtype and device of ``grad_output``.

    Raises:
        AssertionError: If a tensor is not contiguous, the shapes disagree, or
            ``M_total`` is not a multiple of ``group_size_m``.
    """

    assert grad_output.is_contiguous(), "Grad output tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, N = grad_output.shape
    num_experts, N_weights, K = expert_weights.shape

    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )
    # The kernel addresses raw pointers from these sizes; a mismatch would
    # read out of bounds instead of raising.
    assert N == N_weights, f"Grad output N ({N}) must match weight N ({N_weights})"
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    grad_inputs = torch.zeros(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    def grid(meta):
        # One program per (m tile, k tile) for the autotuned block sizes.
        return (
            triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
            * triton.cdiv(K, meta["BLOCK_SIZE_K"]),
        )

    _kernel_cg_backward_dx[grid](
        grad_output,
        expert_weights,
        grad_inputs,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
    )

    return grad_inputs
|
||||
|
||||
|
||||
class ContiguousGroupedGEMM(torch.autograd.Function):
    """Contiguous grouped GEMM autograd function with backward support.

    The forward pass delegates to ``cg_grouped_gemm_forward``; the backward
    pass produces gradients for ``inputs`` and ``expert_weights`` only.
    """

    @staticmethod
    def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
        """Run the grouped GEMM and stash what backward needs."""
        ctx.save_for_backward(inputs, expert_weights, expert_indices)
        ctx.group_size_m = group_size_m

        return cg_grouped_gemm_forward(
            inputs=inputs,
            expert_weights=expert_weights,
            expert_indices=expert_indices,
            group_size_m=group_size_m,
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(inputs, expert_weights, indices, group_size_m)``.

        Each gradient kernel is launched only when autograd reports that the
        corresponding forward input needs a gradient (for example, the weight
        gradient is skipped when the expert weights are frozen). The expert
        indices and group size never receive a gradient.
        """
        inputs, expert_weights, expert_indices = ctx.saved_tensors
        group_size_m = ctx.group_size_m

        # The launchers require contiguous memory.
        grad_output = grad_output.contiguous()

        grad_inputs = None
        if ctx.needs_input_grad[0]:
            grad_inputs = cg_grouped_gemm_backward_inputs(
                grad_output=grad_output,
                expert_weights=expert_weights,
                expert_indices=expert_indices,
                group_size_m=group_size_m,
            )

        grad_weights = None
        if ctx.needs_input_grad[1]:
            grad_weights = cg_grouped_gemm_backward_weights(
                grad_output=grad_output,
                inputs=inputs,
                expert_indices=expert_indices,
                num_experts=expert_weights.shape[0],
                group_size_m=group_size_m,
            )

        return grad_inputs, grad_weights, None, None
|
||||
311
src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py
Normal file
311
src/axolotl/kernels/moe/tt_cg_gemm/cg_forward.py
Normal file
@@ -0,0 +1,311 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Vendored forward Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .tma_cuda_autotune import STANDARD_CONFIGS, early_config_prune
|
||||
|
||||
GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, super_group_m):
    """Map a linear tile id to ``(pid_m, pid_n)`` tile coordinates.

    Tiles are numbered in bands of ``num_pid_in_group`` ids; each band covers
    up to ``super_group_m`` consecutive tile rows (fewer in the last band).
    """
    band = tile_id // num_pid_in_group
    band_first_row = band * super_group_m
    # The last band may hold fewer than super_group_m rows.
    band_rows = min(num_pid_m - band_first_row, super_group_m)
    pid_m = band_first_row + (tile_id % band_rows)
    pid_n = (tile_id % num_pid_in_group) // band_rows
    return pid_m, pid_n
|
||||
|
||||
|
||||
@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_persistent_forward(
    a_ptr,
    b_ptr,
    c_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
    SUPER_GROUP_M: tl.constexpr = 32,
):
    """Contiguous grouped GEMM forward kernel (persistent variant).

    Each program starts at its own program id and strides over output tiles
    in steps of ``NUM_SMS``. For every tile it accumulates
    ``a[m, k] * b[expert, n, k]`` over ``k`` in float32 and stores the result
    cast to the element type of ``c_ptr``.

    Addressing assumes row-major buffers: ``a`` is indexed as ``m * K + k``,
    the weights as ``expert * N * K + n * K + k`` and ``c`` as ``m * N + n``.
    One expert index is read per tile, from the first row of the
    ``GROUP_SIZE_M`` group that contains the tile's first row.
    """

    c_type = c_ptr.dtype.element_ty

    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M_TOTAL, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n
    # Separate tile counter for the store phase; advanced once per tile below.
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = SUPER_GROUP_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS):
        tile_m_idx, tile_n_idx = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, SUPER_GROUP_M
        )

        m_start = tile_m_idx * BLOCK_SIZE_M
        n_start = tile_n_idx * BLOCK_SIZE_N

        if m_start < M_TOTAL:
            offs_m = m_start + tl.arange(0, BLOCK_SIZE_M)
            offs_n = n_start + tl.arange(0, BLOCK_SIZE_N)
            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

            # These do not depend on the K step, so compute them once per tile.
            mask_m = offs_m < M_TOTAL
            mask_n = offs_n < N

            group_idx = m_start // GROUP_SIZE_M
            # Widen to int64: ``expert_idx * N * K`` exceeds the int32 range
            # for large expert banks (e.g. 256 experts of 2048 x 7168 weights).
            expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M).to(tl.int64)

            for ki in range(k_tiles):
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
                mask_k = offs_k < K

                mask_a = mask_m[:, None] & mask_k[None, :]
                mask_b = mask_n[:, None] & mask_k[None, :]

                a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
                a = tl.load(a_ptrs, mask=mask_a, other=0.0)

                b_ptrs = (
                    b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
                )
                b = tl.load(b_ptrs, mask=mask_b, other=0.0)

                # Weights are stored [N, K]; transpose the tile for a @ b^T.
                accumulator += tl.dot(a, b.T)

            # Recompute the tile coordinates for the store from tile_id_c.
            tile_id_c += NUM_SMS
            tile_m_idx, tile_n_idx = _compute_pid(
                tile_id_c, num_pid_in_group, num_pid_m, SUPER_GROUP_M
            )

            offs_m = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_n = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

            mask_m = offs_m < M_TOTAL
            mask_n = offs_n < N
            mask_c = mask_m[:, None] & mask_n[None, :]

            c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
            tl.store(c_ptrs, accumulator.to(c_type), mask=mask_c)
|
||||
|
||||
|
||||
@triton.autotune(
    configs=STANDARD_CONFIGS,
    key=["M_TOTAL", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_cg_forward_aligned(
    a_ptr,
    b_ptr,
    c_ptr,
    indices_ptr,
    M_TOTAL: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr = GROUP_SIZE_M,
):
    """Contiguous grouped GEMM forward kernel, one program per output tile.

    Each program computes a ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile of ``c`` by
    accumulating ``a[m, k] * b[expert, n, k]`` over ``k`` in float32, then
    stores it cast to the element type of ``c_ptr``.

    Addressing assumes row-major buffers: ``a`` is indexed as ``m * K + k``,
    the weights as ``expert * N * K + n * K + k`` and ``c`` as ``m * N + n``.
    One expert index is read per tile, from the first row of the
    ``GROUP_SIZE_M`` group that contains the tile's first row.
    """

    pid = tl.program_id(0)

    c_type = c_ptr.dtype.element_ty

    # Programs are laid out row-major over (m tiles, n tiles).
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)

    tile_m = pid // num_n_tiles
    tile_n = pid % num_n_tiles

    m_start = tile_m * BLOCK_SIZE_M
    n_start = tile_n * BLOCK_SIZE_N

    if m_start < M_TOTAL:
        offs_m = tl.arange(0, BLOCK_SIZE_M) + m_start
        offs_n = tl.arange(0, BLOCK_SIZE_N) + n_start

        mask_m = offs_m < M_TOTAL
        mask_n = offs_n < N

        group_idx = m_start // GROUP_SIZE_M
        # Widen to int64: ``expert_idx * N * K`` exceeds the int32 range for
        # large expert banks (e.g. 256 experts of 2048 x 7168 weights).
        expert_idx = tl.load(indices_ptr + group_idx * GROUP_SIZE_M).to(tl.int64)

        acc = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], dtype=tl.float32)

        for k in range(0, K, BLOCK_SIZE_K):
            offs_k = tl.arange(0, BLOCK_SIZE_K) + k
            mask_k = offs_k < K

            mask_a = mask_m[:, None] & mask_k[None, :]
            mask_b = mask_n[:, None] & mask_k[None, :]

            a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)

            b_ptrs = b_ptr + expert_idx * N * K + offs_n[:, None] * K + offs_k[None, :]
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)

            # Weights are stored [N, K]; transpose the tile for a @ b^T.
            acc += tl.dot(a, b.T)

        c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
        mask_c = mask_m[:, None] & mask_n[None, :]
        tl.store(c_ptrs, acc.to(c_type), mask=mask_c)
|
||||
|
||||
|
||||
def cg_grouped_gemm_forward(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Contiguous grouped GEMM forward pass for MoE (persistent kernel).

    Args:
        inputs: Contiguous ``(M_total, K)`` tokens, grouped by expert in runs
            of ``group_size_m`` rows.
        expert_weights: Contiguous ``(num_experts, N, K)`` expert weights.
        expert_indices: Contiguous per-row expert ids of length ``M_total``;
            converted to int32 if needed.
        group_size_m: Number of consecutive rows that share one expert.

    Returns:
        A ``(M_total, N)`` bfloat16 tensor on the device of ``inputs``.

    Raises:
        AssertionError: If a tensor is not contiguous, the shapes disagree, or
            ``M_total`` is not a multiple of ``group_size_m``.
    """

    assert inputs.is_contiguous(), "Input tensor must be contiguous"
    assert expert_weights.is_contiguous(), "Expert weights tensor must be contiguous"
    assert expert_indices.is_contiguous(), "Expert indices tensor must be contiguous"

    M_total, K = inputs.shape
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )

    if expert_indices.dtype != torch.int32:
        expert_indices = expert_indices.to(torch.int32)

    num_experts, N, K_weights = expert_weights.shape
    assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
    assert expert_indices.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    # NOTE(review): the output dtype is fixed to bfloat16 here, while
    # cg_grouped_gemm_forward_dynamic follows inputs.dtype. Confirm that this
    # is intended for fp16/fp32 callers.
    output = torch.empty((M_total, N), device=inputs.device, dtype=torch.bfloat16)

    # Size the persistent grid from the device that holds the data rather than
    # whichever CUDA device happens to be current.
    NUM_SMS = torch.cuda.get_device_properties(inputs.device).multi_processor_count

    # One persistent program per SM; each strides over the output tiles.
    grid = (NUM_SMS, 1, 1)
    _kernel_cg_persistent_forward[grid](
        inputs,
        expert_weights,
        output,
        expert_indices,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
        NUM_SMS=NUM_SMS,
    )

    return output
|
||||
|
||||
|
||||
def cg_grouped_gemm_forward_dynamic(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Run the grouped GEMM forward pass with an autotuned, per-tile launch.

    Validates contiguity and shapes, normalizes the expert indices to int32,
    then launches ``_kernel_cg_forward_aligned`` with one program per output
    tile. Returns a ``(M_total, N)`` tensor with the dtype and device of
    ``inputs``.
    """

    contiguity_checks = (
        (inputs, "Input tensor must be contiguous"),
        (expert_weights, "Expert weights tensor must be contiguous"),
        (expert_indices, "Expert indices tensor must be contiguous"),
    )
    for tensor, message in contiguity_checks:
        assert tensor.is_contiguous(), message

    M_total, K = inputs.shape
    assert M_total % group_size_m == 0, (
        f"M_total ({M_total}) must be a multiple of group_size_m ({group_size_m})"
    )

    indices_i32 = (
        expert_indices
        if expert_indices.dtype == torch.int32
        else expert_indices.to(torch.int32)
    )

    num_experts, N, K_weights = expert_weights.shape
    assert K == K_weights, f"Input K ({K}) must match weight K ({K_weights})"
    assert indices_i32.shape[0] == M_total, (
        "Expert indices length must match M_total"
    )

    result = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)

    def launch_grid(meta):
        # One program per (m tile, n tile) for the autotuned block sizes.
        tiles_m = triton.cdiv(M_total, meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(N, meta["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n,)

    _kernel_cg_forward_aligned[launch_grid](
        inputs,
        expert_weights,
        result,
        indices_i32,
        M_TOTAL=M_total,
        N=N,
        K=K,
        NUM_EXPERTS=num_experts,
        GROUP_SIZE_M=group_size_m,
    )

    return result
|
||||
|
||||
|
||||
class ContiguousGroupedGEMM(torch.autograd.Function):
    """Forward-only autograd wrapper around ``cg_grouped_gemm_forward``."""

    @staticmethod
    def forward(ctx, inputs, expert_weights, expert_indices, group_size_m=GROUP_SIZE_M):
        # Nothing is saved on ctx because this variant has no backward pass.
        output = cg_grouped_gemm_forward(
            inputs, expert_weights, expert_indices, group_size_m
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover - not implemented
        raise NotImplementedError("Backward pass not implemented")
|
||||
|
||||
|
||||
def cg_grouped_gemm(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = GROUP_SIZE_M,
) -> torch.Tensor:
    """Apply the forward-only ``ContiguousGroupedGEMM`` autograd function.

    The expert indices are converted to int32 first when they use another
    dtype.
    """

    indices_i32 = (
        expert_indices
        if expert_indices.dtype == torch.int32
        else expert_indices.to(torch.int32)
    )
    return ContiguousGroupedGEMM.apply(
        inputs, expert_weights, indices_i32, group_size_m
    )
|
||||
31
src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py
Normal file
31
src/axolotl/kernels/moe/tt_cg_gemm/cg_reference.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Reference implementation for contiguous grouped GEMM."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def pytorch_reference(
    inputs: torch.Tensor,
    expert_weights: torch.Tensor,
    expert_indices: torch.Tensor,
    group_size_m: int = 128,
) -> torch.Tensor:
    """Simple PyTorch implementation of the grouped GEMM, for verification.

    Rows are processed in consecutive groups of ``group_size_m``; the expert
    for a group is taken from the index of the group's first row, and the
    group's output is ``inputs[group] @ expert_weights[expert].T``.

    Args:
        inputs: ``(M_total, K)`` input rows.
        expert_weights: ``(num_experts, N, K)`` expert weights.
        expert_indices: Per-row expert ids of length ``M_total``.
        group_size_m: Number of consecutive rows that share one expert. The
            last group may be shorter.

    Returns:
        A ``(M_total, N)`` tensor with the dtype and device of ``inputs``.
    """

    M_total, _ = inputs.shape
    _, N, _ = expert_weights.shape

    output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype)

    # Fetch the first index of every group in one transfer instead of one
    # ``.item()`` host sync per group.
    group_experts = expert_indices[::group_size_m].tolist()

    for group, expert_idx in enumerate(group_experts):
        start = group * group_size_m
        end = min(start + group_size_m, M_total)
        output[start:end] = torch.matmul(
            inputs[start:end], expert_weights[expert_idx].T
        )

    return output
|
||||
209
src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py
Normal file
209
src/axolotl/kernels/moe/tt_cg_gemm/tma_cuda_autotune.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Autotuning utilities for Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.runtime import driver
|
||||
|
||||
|
||||
class CudaUtils:
    """Static helpers for querying CUDA-specific Triton capabilities."""

    @staticmethod
    def is_cuda() -> bool:
        """Return whether Triton's active target backend is ``"cuda"``."""
        target = driver.active.get_current_target()
        return target.backend == "cuda"

    @staticmethod
    def verify_tma() -> bool:
        """Return whether the current device reports compute capability >= 9."""
        if not CudaUtils.is_cuda():
            return False
        if not torch.cuda.is_available():
            return False
        major = torch.cuda.get_device_capability()[0]
        return major >= 9

    @staticmethod
    def get_num_sms() -> int:
        """Return the multiprocessor count of the ``"cuda"`` device.

        Raises:
            RuntimeError: If Triton is not on the CUDA backend or CUDA is not
                available.
        """
        if not CudaUtils.is_cuda():
            raise RuntimeError("Triton is not running on CUDA backend")
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        properties = torch.cuda.get_device_properties("cuda")
        return properties.multi_processor_count
|
||||
|
||||
|
||||
class TmaDescriptorHelper:
    """Owns named, CPU-resident TMA descriptor buffers for Triton kernels."""

    class KernelParamWrapper:
        """Thin wrapper exposing a descriptor buffer's address to a kernel."""

        def __init__(self, desc: torch.Tensor):
            self.desc = desc

        def tma_desc_cpu_ptr(self) -> int:
            return self.desc.data_ptr()

    def __init__(self, tma_size: int = 128):
        """Check TMA support and bind the driver's descriptor fill routines.

        Raises:
            RuntimeError: If ``CudaUtils.verify_tma()`` fails or ``tl`` does
                not expose ``nv_tma_desc_type``.
        """
        if not CudaUtils.verify_tma():
            raise RuntimeError(
                "TMA not supported on this device (requires Hopper or newer)"
            )
        if "nv_tma_desc_type" not in dir(tl):
            raise RuntimeError(
                "TMA grid constant descriptors not supported in your Triton version"
            )

        self.tma_size = tma_size
        driver_utils = driver.active.utils
        self.fill_1d_tma_descriptor_inner = driver_utils.fill_1d_tma_descriptor
        self.fill_2d_tma_descriptor_inner = driver_utils.fill_2d_tma_descriptor
        self.descriptors: Dict[str, torch.Tensor] = {}

    def init_tma_descriptor(self, name: str) -> None:
        """Allocate an uninitialized ``tma_size``-byte CPU buffer for ``name``."""
        buffer = torch.empty(self.tma_size, device="cpu", dtype=torch.int8)
        self.descriptors[name] = buffer

    def _descriptor_address(self, name: str) -> int:
        """Return the address of descriptor ``name`` after validating it."""
        if name not in self.descriptors:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        address = self.descriptors[name].data_ptr()
        if address % 64 != 0:
            raise ValueError("TMA descriptor must be 64-byte aligned")
        return address

    def fill_1d_tma_descriptor(
        self, name: str, ptr: int, dim: int, block_dim: int, element_size: int
    ) -> None:
        """Fill descriptor ``name`` through the driver's 1D fill routine."""
        address = self._descriptor_address(name)
        self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, address)

    def fill_2d_tma_descriptor(
        self,
        name: str,
        ptr: int,
        dim1: int,
        dim0: int,
        block_dim1: int,
        block_dim0: int,
        element_size: int,
    ) -> None:
        """Fill descriptor ``name`` through the driver's 2D fill routine."""
        address = self._descriptor_address(name)
        self.fill_2d_tma_descriptor_inner(
            ptr, dim1, dim0, block_dim1, block_dim0, element_size, address
        )

    def get_tma_descriptor_kernel_param(
        self, name: str
    ) -> "TmaDescriptorHelper.KernelParamWrapper":
        """Wrap descriptor ``name`` for passing to a kernel.

        Raises:
            ValueError: If ``name`` has no descriptor or it is ``None``.
        """
        descriptor = self.descriptors.get(name)
        if descriptor is None:
            raise ValueError(f"TMA descriptor '{name}' not initialized")
        return self.KernelParamWrapper(descriptor)
|
||||
|
||||
|
||||
# Autotune candidates, one row per config:
# (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps)
HOPPER_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": block_k},
        num_stages=stages,
        num_warps=warps,
    )
    for block_m, block_n, block_k, stages, warps in (
        (128, 128, 32, 2, 8),
        (64, 256, 32, 3, 8),
        (256, 64, 32, 3, 8),
        (128, 64, 32, 4, 4),
        (64, 128, 32, 4, 4),
        (128, 128, 64, 3, 8),
        (64, 256, 64, 3, 8),
        (256, 64, 64, 3, 8),
        (128, 256, 64, 4, 8),
        (256, 128, 64, 4, 8),
    )
]
|
||||
|
||||
|
||||
# Autotune candidates, one row per config:
# (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, num_stages, num_warps)
# The last two rows share block sizes and differ only in num_stages.
STANDARD_CONFIGS = [
    triton.Config(
        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": block_k},
        num_stages=stages,
        num_warps=warps,
    )
    for block_m, block_n, block_k, stages, warps in (
        (64, 64, 32, 2, 4),
        (128, 64, 32, 2, 4),
        (128, 128, 32, 2, 8),
        (128, 128, 64, 3, 8),
        (128, 128, 64, 4, 8),
    )
]
|
||||
|
||||
|
||||
def early_config_prune(configs, args, **kwargs):
    """Drop autotune configs whose ``BLOCK_SIZE_K`` is larger than ``K``.

    ``K`` is looked up in the launch keyword arguments first and then, when
    ``args`` is a dict of named arguments, in ``args``. A missing ``K`` is
    treated as 0.

    Args:
        configs: Candidate configs; each exposes a ``kwargs`` dict.
        args: Named launch arguments supplied by the autotuner.
        **kwargs: Keyword arguments of the kernel launch.

    Returns:
        The configs with ``BLOCK_SIZE_K <= K``. If none qualify, a one-element
        list holding the config with the smallest ``BLOCK_SIZE_K`` so the
        autotuner always has a candidate. An empty ``configs`` yields ``[]``.
    """
    k = kwargs.get("K")
    if k is None and isinstance(args, dict):
        # K was not passed by keyword; fall back to the named arguments.
        k = args.get("K")
    if k is None:
        k = 0

    valid_configs = [
        config for config in configs if config.kwargs.get("BLOCK_SIZE_K", 0) <= k
    ]
    if valid_configs or not configs:
        return valid_configs

    return [
        min(
            configs,
            key=lambda c: c.kwargs.get("BLOCK_SIZE_K", float("inf")),
        )
    ]
|
||||
@@ -190,6 +190,11 @@ class PatchManager:
|
||||
|
||||
apply_mistral_tokenizer_image_patch()
|
||||
|
||||
if self.cfg.model_config_type == "deepseek_v3":
|
||||
from axolotl.monkeypatch.deepseek_v3 import patch_deepseek_v3_moe
|
||||
|
||||
patch_deepseek_v3_moe()
|
||||
|
||||
def _apply_fp8_patches(self):
|
||||
"""Apply patches for FP8 support."""
|
||||
if self.cfg.fp8:
|
||||
|
||||
187
src/axolotl/monkeypatch/deepseek_v3/__init__.py
Normal file
187
src/axolotl/monkeypatch/deepseek_v3/__init__.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""Monkeypatches for DeepSeek V3 MoE to use Triton contiguous grouped GEMM kernels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.kernels.moe import ContiguousGroupedGEMM
|
||||
|
||||
_GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
def _align_to(value: int, alignment: int) -> int:
|
||||
if value <= 0:
|
||||
return 0
|
||||
return math.ceil(value / alignment) * alignment
|
||||
|
||||
|
||||
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
||||
return hidden_states.is_cuda and hidden_states.shape[0] > 0
|
||||
|
||||
|
||||
def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
gate_weights = []
|
||||
up_weights = []
|
||||
down_weights = []
|
||||
for expert in module.experts:
|
||||
gate_weights.append(expert.gate_proj.weight)
|
||||
up_weights.append(expert.up_proj.weight)
|
||||
down_weights.append(expert.down_proj.weight)
|
||||
gate = torch.stack(gate_weights, dim=0).contiguous()
|
||||
up = torch.stack(up_weights, dim=0).contiguous()
|
||||
down = torch.stack(down_weights, dim=0).contiguous()
|
||||
return gate, up, down
|
||||
|
||||
|
||||
def _moe_triton_forward(
    module,
    hidden_states: torch.Tensor,
    topk_indices: torch.Tensor,
    topk_weights: torch.Tensor,
    group_size_m: int,
    fallback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """Compute the routed-expert MoE output with ``ContiguousGroupedGEMM``.

    Each token is replicated once per selected expert, the replicas are sorted
    by expert id and laid out in per-expert segments padded to a multiple of
    ``group_size_m``. Gate/up projections, the activation, and the down
    projection are then applied through three grouped GEMM calls, and the
    results are scattered back to token order and combined with
    ``topk_weights``.

    Args:
        module: MoE module exposing ``experts``; each expert provides
            ``gate_proj``, ``up_proj``, ``down_proj`` and ``act_fn``.
        hidden_states: 2-D tensor of shape ``(num_tokens, hidden_dim)``.
        topk_indices: Expert ids chosen per token; the last dimension is
            ``top_k`` (presumably shaped ``(num_tokens, top_k)`` — confirm
            against the caller).
        topk_weights: Routing weights matching ``topk_indices``.
        group_size_m: Row alignment used to pad every expert segment and
            forwarded to the kernel.
        fallback: Callable invoked as ``fallback(hidden_states, topk_indices,
            topk_weights)`` when the Triton path is not eligible.

    Returns:
        Tensor of shape ``(num_tokens, hidden_dim)`` with the weighted sum of
        the selected experts' outputs.
    """
    if not _is_triton_eligible(hidden_states):
        return fallback(hidden_states, topk_indices, topk_weights)

    device = hidden_states.device
    hidden_dtype = hidden_states.dtype
    num_tokens, hidden_dim = hidden_states.shape
    top_k = topk_indices.size(-1)

    # One row per (token, selected expert) pair.
    expanded_hidden = hidden_states.repeat_interleave(top_k, dim=0)
    expert_assignments = topk_indices.reshape(-1)
    if expanded_hidden.numel() == 0:
        # ``torch.Tensor`` has no ``new_zeros_like`` method; use the
        # module-level ``torch.zeros_like`` instead.
        return torch.zeros_like(hidden_states)

    # Group rows by expert; ``sort_perm`` is kept to undo the ordering later.
    sort_perm = torch.argsort(expert_assignments)
    sorted_hidden = expanded_hidden.index_select(0, sort_perm)
    sorted_assignments = expert_assignments.index_select(0, sort_perm)

    num_experts = len(module.experts)
    counts = torch.bincount(sorted_assignments, minlength=num_experts)
    # Bring the per-expert counts to the host as Python ints for slicing.
    counts_cpu = counts.to(torch.int64).cpu().tolist()
    padded_counts = [_align_to(c, group_size_m) for c in counts_cpu]

    total_actual = sum(counts_cpu)
    total_padded = sum(padded_counts)
    if total_actual == 0 or total_padded == 0:
        return torch.zeros_like(hidden_states)

    # Prefix sums: where each expert's segment starts in the compact (sorted)
    # layout and in the padded layout respectively.
    actual_offsets = [0]
    padded_offsets = [0]
    for count, padded in zip(counts_cpu, padded_counts, strict=False):
        actual_offsets.append(actual_offsets[-1] + count)
        padded_offsets.append(padded_offsets[-1] + padded)

    # Padded input buffer (padding rows stay zero) and the per-row expert id.
    grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
    expert_index_tensor = torch.empty(total_padded, dtype=torch.int32, device=device)

    for idx, (count, padded) in enumerate(zip(counts_cpu, padded_counts, strict=False)):
        dst_start = padded_offsets[idx]
        dst_end = dst_start + padded
        if padded == 0:
            continue
        expert_index_tensor[dst_start:dst_end] = idx
        if count > 0:
            src_start = actual_offsets[idx]
            src_end = src_start + count
            grouped_hidden[dst_start : dst_start + count].copy_(
                sorted_hidden[src_start:src_end]
            )

    gate_weights, up_weights, down_weights = _collect_expert_weights(module)

    gate_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        gate_weights,
        expert_index_tensor,
        group_size_m,
    )
    up_out = ContiguousGroupedGEMM.apply(
        grouped_hidden,
        up_weights,
        expert_index_tensor,
        group_size_m,
    )

    # The activation is taken from the first expert and applied to all of them.
    act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn

    # Apply the gated activation to the real (non-padding) rows of each expert.
    hidden_chunks = []
    for idx, count in enumerate(counts_cpu):
        if count == 0:
            continue
        pad_start = padded_offsets[idx]
        pad_end = pad_start + count
        gate_slice = gate_out[pad_start:pad_end].to(hidden_dtype)
        up_slice = up_out[pad_start:pad_end].to(hidden_dtype)
        hidden_chunks.append(act_fn(gate_slice) * up_slice)

    hidden_concat = torch.cat(hidden_chunks, dim=0)

    # Re-pad the activated rows into the same segment layout for the down GEMM.
    intermediate_dim = hidden_concat.shape[-1]
    hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim))

    for idx, count in enumerate(counts_cpu):
        if count == 0:
            continue
        pad_start = padded_offsets[idx]
        src_start = actual_offsets[idx]
        src_end = src_start + count
        hidden_grouped[pad_start : pad_start + count].copy_(
            hidden_concat[src_start:src_end]
        )

    down_out = ContiguousGroupedGEMM.apply(
        hidden_grouped,
        down_weights,
        expert_index_tensor,
        group_size_m,
    )

    # Drop the padding rows again, keeping expert-sorted order.
    down_chunks = []
    for idx, count in enumerate(counts_cpu):
        if count == 0:
            continue
        pad_start = padded_offsets[idx]
        pad_end = pad_start + count
        down_chunks.append(down_out[pad_start:pad_end].to(hidden_dtype))

    down_concat = torch.cat(down_chunks, dim=0)

    # Undo the expert sort: row ``sort_perm[i]`` receives sorted row ``i``.
    expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
    expanded_output.index_copy_(0, sort_perm, down_concat.to(hidden_dtype))
    expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)

    # Weight each expert's contribution and sum over the top-k slots.
    weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
    return weighted.sum(dim=1)
|
||||
|
||||
|
||||
def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
    """Patch HuggingFace DeepseekV3MoE to use Triton contiguous group GEMM kernels.

    Replaces ``DeepseekV3MoE.moe`` with a wrapper that routes through
    ``_moe_triton_forward`` and falls back to the original implementation when
    the Triton path is not eligible or raises ``RuntimeError``. The patch is
    applied at most once; repeated calls are no-ops.

    Args:
        group_size_m: Row alignment forwarded to ``_moe_triton_forward``.
    """
    import logging

    from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE

    if getattr(DeepseekV3MoE, "_axolotl_triton_patch", False):
        return

    logger = logging.getLogger(__name__)
    original_moe = DeepseekV3MoE.moe
    warned = False

    def patched_moe(self, hidden_states, topk_indices, topk_weights):
        nonlocal warned

        def fallback(states, indices, weights):
            # ``original_moe`` is the unbound function taken from the class,
            # so ``self`` must be supplied explicitly; ``_moe_triton_forward``
            # calls its fallback with the three tensors only.
            return original_moe(self, states, indices, weights)

        try:
            return _moe_triton_forward(
                self,
                hidden_states,
                topk_indices,
                topk_weights,
                group_size_m,
                fallback,
            )
        except RuntimeError:
            # Best-effort: keep running on the reference implementation, but
            # surface the failure once instead of hiding it completely.
            if not warned:
                logger.warning(
                    "Triton grouped GEMM MoE path failed; falling back to the "
                    "original DeepseekV3MoE.moe implementation.",
                    exc_info=True,
                )
                warned = True
        return fallback(hidden_states, topk_indices, topk_weights)

    DeepseekV3MoE.moe = patched_moe
    DeepseekV3MoE._axolotl_triton_patch = True
|
||||
Reference in New Issue
Block a user