token shuffle kernel

Dan Saunders
2025-09-21 16:46:46 -04:00
parent 18269ee6a9
commit 5c74edeefe
5 changed files with 231 additions and 40 deletions

View File

@@ -36,8 +36,8 @@ DTYPE_MAP = {
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--batch", type=int, default=2, help="batch size")
parser.add_argument("--seq-len", type=int, default=256, help="sequence length")
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
parser.add_argument(
"--moe-intermediate-size",
@@ -48,13 +48,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--n-experts",
type=int,
-default=64,
+default=256,
help="Number of routed experts",
)
parser.add_argument(
"--top-k",
type=int,
-default=4,
+default=8,
help="Number of experts per token",
)
parser.add_argument(
@@ -153,6 +153,10 @@ def main() -> None: # pragma: no cover - CLI entrypoint
baseline_module.to(device=device, dtype=dtype)
patched_module.to(device=device, dtype=dtype)
tokens = args.batch * args.seq_len
+routed_tokens = tokens * args.top_k
+avg_tokens_per_expert = routed_tokens / args.n_experts
inputs = torch.randn(
args.batch,
args.seq_len,
@@ -174,6 +178,9 @@ def main() -> None: # pragma: no cover - CLI entrypoint
print(
f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
)
+print(
+f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
+)
print(
f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
)

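As a sanity check on the new defaults, the added print should report roughly 512 tokens per expert. A minimal sketch of the arithmetic, using only the argparse defaults shown above (real runs may override them):

# Token accounting for the new benchmark defaults.
batch, seq_len, top_k, n_experts = 8, 2048, 8, 256

tokens = batch * seq_len                            # 16384 input tokens
routed_tokens = tokens * top_k                      # 131072 expert assignments
avg_tokens_per_expert = routed_tokens / n_experts   # 512.0 on average

print(f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f}")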
View File

@@ -7,6 +7,7 @@ from .tt_cg_gemm import (
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
+from .indices import generate_permute_indices
__all__ = [
"cg_grouped_gemm",
@@ -14,4 +15,5 @@ __all__ = [
"cg_grouped_gemm_forward_dynamic",
"ContiguousGroupedGEMM",
"ContiguousGroupedGEMMForwardOnly",
"generate_permute_indices",
]

View File

@@ -0,0 +1,5 @@
"""Token permutation utilities for grouped MoE kernels."""
from .indices import generate_permute_indices
__all__ = ["generate_permute_indices"]

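With the re-export added above, the helper should be importable at either level; a short sketch assuming the axolotl.kernels.moe package layout implied by this commit:

# Re-exported at the package level ...
from axolotl.kernels.moe import generate_permute_indices
# ... or pulled directly from the vendored module:
# from axolotl.kernels.moe.indices import generate_permute_indices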
View File

@@ -0,0 +1,144 @@
"""Vendored token permutation kernels from TorchTitan."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import triton
import triton.language as tl
__all__ = ["generate_permute_indices"]
@triton.jit
def _fill_indices_kernel(
tokens_per_expert_group_ptr,
start_index_values_ptr,
write_offsets_ptr,
output_ptr,
experts_per_rank: tl.constexpr,
num_ranks: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
num_programs = tl.num_programs(axis=0)
# Grid-stride loop: each program handles a strided subset of experts.
for expert_id in range(pid, experts_per_rank, num_programs):
write_offset = tl.load(write_offsets_ptr + expert_id)
for r in range(num_ranks):
idx = r * experts_per_rank + expert_id
start_index = tl.load(start_index_values_ptr + idx)
length = tl.load(tokens_per_expert_group_ptr + idx)
offsets = tl.arange(0, BLOCK_SIZE)
for chunk_start in range(0, length, BLOCK_SIZE):
chunk_offsets = chunk_start + offsets
mask = chunk_offsets < length
values = start_index + chunk_offsets
dest_indices = write_offset + chunk_offsets
tl.store(output_ptr + dest_indices, values, mask=mask)
write_offset += length
def fill_indices_wrapper(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
block_size: int = 128,
max_blocks: int = 1024,
):
# Pre-fill with -1 so unused (padding) slots can be masked out downstream.
permuted_indices = torch.full(
(max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
)
num_blocks = min(experts_per_rank, max_blocks)
grid = (num_blocks,)
_fill_indices_kernel[grid](
tokens_per_expert_group,
start_index_values,
write_offsets,
permuted_indices,
experts_per_rank,
num_ranks,
BLOCK_SIZE=block_size,
)
return permuted_indices
def fill_indices_cpu(
tokens_per_expert_group: torch.Tensor,
start_index_values: torch.Tensor,
write_offsets: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
):
permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
for expert_id in range(experts_per_rank):
write_start = write_offsets[expert_id].item()
for r in range(num_ranks):
idx = r * experts_per_rank + expert_id
start_index = start_index_values[idx].item()
length = tokens_per_expert_group[idx].item()
if length > 0:
end_idx = min(write_start + length, max_len)
permuted_indices[write_start:end_idx] = torch.arange(
start_index,
start_index + (end_idx - write_start),
dtype=torch.int32,
)
write_start += length
return permuted_indices
def generate_permute_indices(
tokens_per_expert_group: torch.Tensor,
experts_per_rank: int,
num_ranks: int,
max_len: int,
alignment: int,
use_cpu: bool = False,
):
# Exclusive prefix sum over the rank-major counts: where each (rank, expert)
# slice starts in the incoming token buffer.
start_index_values = (
torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
)
# Per-expert totals across ranks, clamped to at least `alignment` and rounded
# up to a multiple of it, so every expert owns at least one full GEMM tile.
total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
torch.int32
)
# Aligned group boundaries: write_offsets marks where each expert's block begins.
m_offsets = torch.cumsum(m_sizes, 0)
write_offsets = m_offsets - m_sizes
if use_cpu:
permuted_indices = fill_indices_cpu(
tokens_per_expert_group,
start_index_values,
write_offsets,
experts_per_rank,
num_ranks,
max_len,
)
else:
permuted_indices = fill_indices_wrapper(
tokens_per_expert_group,
start_index_values,
write_offsets,
experts_per_rank,
num_ranks,
max_len,
)
return permuted_indices, m_sizes, m_offsets.to(torch.int32)

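To get a feel for the output layout, a small CPU-only sketch; the counts and alignment here are made up (real callers pass the GEMM group size as the alignment):

import torch
from axolotl.kernels.moe.indices import generate_permute_indices

# Hypothetical single-rank routing: three experts receive 5, 0 and 2 tokens.
counts = torch.tensor([5, 0, 2], dtype=torch.int32)
alignment = 4
max_len = 16  # sum of the per-expert padded group sizes (8 + 4 + 4)

permuted_indices, m_sizes, m_offsets = generate_permute_indices(
    counts, experts_per_rank=3, num_ranks=1,
    max_len=max_len, alignment=alignment, use_cpu=True,
)
# m_sizes   -> [8, 4, 4]   each expert padded up to a multiple of `alignment`
# m_offsets -> [8, 12, 16] cumulative (aligned) group ends
# permuted_indices maps each padded slot back to a source row, -1 marks padding:
#   [0, 1, 2, 3, 4, -1, -1, -1,  -1, -1, -1, -1,  5, 6, -1, -1]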
View File

@@ -8,10 +8,14 @@ from typing import Callable
import torch
from axolotl.kernels.moe import ContiguousGroupedGEMM
+from axolotl.kernels.moe.indices import generate_permute_indices
+from axolotl.utils.logging import get_logger
_GROUP_SIZE_M = 128
_COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
+LOG = get_logger(__name__)
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
return hidden_states.is_cuda and hidden_states.shape[0] > 0
@@ -55,9 +59,7 @@ def _ensure_combined_expert_weights(
# DeepseekV3 MLP layers are bias-free, but keep this for safety.
del lin._parameters["bias"]
combined[name] = torch.stack(weights, dim=0).contiguous()
-module.register_parameter(
-f"{name}_weight", torch.nn.Parameter(combined[name])
-)
+module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined[name]))
module._axolotl_original_specs[name] = (orig_device, orig_dtype)
module._axolotl_combined_weights = True
@@ -72,7 +74,9 @@ def _restore_expert_weights(module) -> None:
for name in _COMBINED_SUBMODULES:
param_name = f"{name}_weight"
combined = module._parameters.pop(param_name)
-orig_device, orig_dtype = module._axolotl_original_specs.get(name, (combined.device, combined.dtype))
+orig_device, orig_dtype = module._axolotl_original_specs.get(
+name, (combined.device, combined.dtype)
+)
for idx, expert in enumerate(module.experts):
lin = expert.get_submodule(name)
lin._parameters["weight"] = torch.nn.Parameter(
@@ -82,6 +86,7 @@ def _restore_expert_weights(module) -> None:
module._axolotl_combined_weights = False
module._axolotl_combined_dtype = None
module._axolotl_combined_device = None
+module._axolotl_original_specs = {}
def _moe_triton_forward(
@@ -115,36 +120,45 @@ def _moe_triton_forward(
if total_actual == 0:
return torch.zeros_like(hidden_states)
-padded_counts = (
-(
-torch.where(
-counts > 0,
-counts,
-torch.full_like(counts, group_size_m),
-)
-+ group_size_m
-- 1
-)
+counts_int = counts.to(torch.int32)
+aligned_counts = (
+(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1)
// group_size_m
) * group_size_m
+max_len = int(aligned_counts.sum().item())
-total_padded = int(padded_counts.sum().item())
-grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
+permuted_indices, m_sizes, m_offsets = generate_permute_indices(
+counts_int.to(device),
+experts_per_rank=num_experts,
+num_ranks=1,
+max_len=max_len,
+alignment=group_size_m,
+use_cpu=not hidden_states.is_cuda,
+)
-write_offsets = torch.cumsum(padded_counts, dim=0) - padded_counts
-actual_offsets = torch.cumsum(counts, dim=0) - counts
+if permuted_indices.device != device:
+permuted_indices = permuted_indices.to(device)
+if m_sizes.device != device:
+m_sizes = m_sizes.to(device)
+if m_offsets.device != device:
+m_offsets = m_offsets.to(device)
-repeated_offsets = torch.repeat_interleave(actual_offsets, counts)
-token_index = torch.arange(total_actual, device=device) - repeated_offsets
-dest_indices = write_offsets[sorted_assignments] + token_index
+permuted_indices_long = permuted_indices.to(torch.int64)
+valid_mask = permuted_indices_long >= 0
+valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
+source_indices = permuted_indices_long[valid_mask]
-grouped_hidden.index_copy_(0, dest_indices, sorted_hidden)
-padded_counts_idx = padded_counts.to(torch.int64)
-expert_index_tensor = (
-torch.arange(num_experts, device=device, dtype=torch.int64)
-.repeat_interleave(padded_counts_idx)
-.to(torch.int32)
-.contiguous()
+grouped_hidden = hidden_states.new_zeros((max_len, hidden_dim))
+if valid_positions.numel() > 0:
+grouped_hidden.index_copy_(
+0,
+valid_positions,
+sorted_hidden.index_select(0, source_indices),
+)
+expert_index_tensor = torch.repeat_interleave(
+torch.arange(num_experts, device=device, dtype=torch.int32),
+m_sizes.to(torch.int64),
+)
_ensure_combined_expert_weights(module, hidden_dtype, device)
@@ -167,13 +181,17 @@ def _moe_triton_forward(
)
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
-valid_gate = gate_out.index_select(0, dest_indices).to(hidden_dtype)
-valid_up = up_out.index_select(0, dest_indices).to(hidden_dtype)
-hidden_concat = act_fn(valid_gate) * valid_up
+if valid_positions.numel() > 0:
+gate_valid = gate_out.index_select(0, valid_positions).to(hidden_dtype)
+up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
+hidden_concat = act_fn(gate_valid) * up_valid
+else:
+hidden_concat = torch.empty((0, gate_out.shape[-1]), device=device, dtype=hidden_dtype)
intermediate_dim = hidden_concat.shape[-1]
-hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim))
-hidden_grouped.index_copy_(0, dest_indices, hidden_concat)
+hidden_grouped = hidden_states.new_zeros((max_len, intermediate_dim))
+if valid_positions.numel() > 0:
+hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
down_out = ContiguousGroupedGEMM.apply(
hidden_grouped,
@@ -182,10 +200,19 @@ def _moe_triton_forward(
group_size_m,
)
-down_valid = down_out.index_select(0, dest_indices).to(hidden_dtype)
+if valid_positions.numel() > 0:
+down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
+else:
+down_valid = torch.empty((0, down_out.shape[-1]), device=device, dtype=hidden_dtype)
+sorted_outputs = hidden_states.new_empty((total_actual, hidden_dim))
+if down_valid.numel() > 0:
+sorted_outputs.index_copy_(0, source_indices, down_valid)
+else:
+sorted_outputs.zero_()
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
-expanded_output.index_copy_(0, sort_perm, down_valid)
+expanded_output.index_copy_(0, sort_perm, sorted_outputs)
expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
@@ -212,7 +239,13 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
group_size_m,
original_moe,
)
-except RuntimeError:
+except RuntimeError as err:
+if not getattr(self, "_axolotl_triton_warned", False):
+LOG.warning(
+"DeepseekV3MoE Triton path failed; falling back to baseline: %s",
+err,
+)
+self._axolotl_triton_warned = True
_restore_expert_weights(self)
return original_moe(self, hidden_states, topk_indices, topk_weights)
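Finally, a toy CPU sketch of the shuffle/unshuffle round trip that _moe_triton_forward performs around the grouped GEMMs; the sizes are tiny stand-ins for the real _GROUP_SIZE_M=128 and the GEMMs themselves are elided:

import torch
from axolotl.kernels.moe.indices import generate_permute_indices

group_size_m = 4                                    # stand-in for _GROUP_SIZE_M
counts = torch.tensor([3, 1], dtype=torch.int32)    # tokens routed to 2 experts
sorted_hidden = torch.randn(int(counts.sum()), 16)  # rows already sorted by expert

aligned_counts = (
    (torch.clamp_min(counts, group_size_m) + group_size_m - 1) // group_size_m
) * group_size_m
max_len = int(aligned_counts.sum())

permuted_indices, m_sizes, m_offsets = generate_permute_indices(
    counts, experts_per_rank=2, num_ranks=1,
    max_len=max_len, alignment=group_size_m, use_cpu=True,
)

# Scatter tokens into group_size_m-aligned rows; padding rows stay zero.
valid = permuted_indices >= 0
grouped_hidden = sorted_hidden.new_zeros((max_len, sorted_hidden.shape[-1]))
grouped_hidden[valid] = sorted_hidden[permuted_indices[valid].long()]

# ... the grouped gate/up/down GEMMs would run on `grouped_hidden` here ...

# Gather the (here: untouched) rows back into sorted token order.
restored = sorted_hidden.new_empty(sorted_hidden.shape)
restored[permuted_indices[valid].long()] = grouped_hidden[valid]
assert torch.equal(restored, sorted_hidden)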