token shuffle kernel
@@ -36,8 +36,8 @@ DTYPE_MAP = {
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument("--batch", type=int, default=2, help="batch size")
-    parser.add_argument("--seq-len", type=int, default=256, help="sequence length")
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
     parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
     parser.add_argument(
         "--moe-intermediate-size",
@@ -48,13 +48,13 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--n-experts",
         type=int,
-        default=64,
+        default=256,
        help="Number of routed experts",
     )
     parser.add_argument(
         "--top-k",
         type=int,
-        default=4,
+        default=8,
         help="Number of experts per token",
     )
     parser.add_argument(
@@ -153,6 +153,10 @@ def main() -> None: # pragma: no cover - CLI entrypoint
     baseline_module.to(device=device, dtype=dtype)
     patched_module.to(device=device, dtype=dtype)
 
+    tokens = args.batch * args.seq_len
+    routed_tokens = tokens * args.top_k
+    avg_tokens_per_expert = routed_tokens / args.n_experts
+
     inputs = torch.randn(
         args.batch,
         args.seq_len,
@@ -174,6 +178,9 @@ def main() -> None: # pragma: no cover - CLI entrypoint
     print(
         f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
     )
+    print(
+        f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
+    )
     print(
         f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
     )
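For reference, the routing arithmetic behind the new prints, using the updated benchmark defaults (batch=8, seq_len=2048, top_k=8, n_experts=256); a quick sanity check in plain Python:

    tokens = 8 * 2048                              # batch * seq_len -> 16384 tokens
    routed_tokens = tokens * 8                     # each token visits top_k=8 experts -> 131072
    avg_tokens_per_expert = routed_tokens / 256    # 256 routed experts -> 512.0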
@@ -7,6 +7,7 @@ from .tt_cg_gemm import (
     cg_grouped_gemm_forward,
     cg_grouped_gemm_forward_dynamic,
 )
+from .indices import generate_permute_indices
 
 __all__ = [
     "cg_grouped_gemm",
@@ -14,4 +15,5 @@ __all__ = [
     "cg_grouped_gemm_forward_dynamic",
     "ContiguousGroupedGEMM",
     "ContiguousGroupedGEMMForwardOnly",
+    "generate_permute_indices",
 ]
src/axolotl/kernels/moe/indices/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
"""Token permutation utilities for grouped MoE kernels."""

from .indices import generate_permute_indices

__all__ = ["generate_permute_indices"]
src/axolotl/kernels/moe/indices/indices.py (new file, 144 lines)
@@ -0,0 +1,144 @@
"""Vendored token permutation kernels from TorchTitan."""

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import triton
import triton.language as tl

__all__ = ["generate_permute_indices"]


@triton.jit
def _fill_indices_kernel(
    tokens_per_expert_group_ptr,
    start_index_values_ptr,
    write_offsets_ptr,
    output_ptr,
    experts_per_rank: tl.constexpr,
    num_ranks: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)

    for expert_id in range(pid, experts_per_rank, num_programs):
        write_offset = tl.load(write_offsets_ptr + expert_id)

        for r in range(num_ranks):
            idx = r * experts_per_rank + expert_id

            start_index = tl.load(start_index_values_ptr + idx)
            length = tl.load(tokens_per_expert_group_ptr + idx)

            offsets = tl.arange(0, BLOCK_SIZE)

            for chunk_start in range(0, length, BLOCK_SIZE):
                chunk_offsets = chunk_start + offsets
                mask = chunk_offsets < length
                values = start_index + chunk_offsets
                dest_indices = write_offset + chunk_offsets
                tl.store(output_ptr + dest_indices, values, mask=mask)

            write_offset += length


def fill_indices_wrapper(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,
):
    permuted_indices = torch.full(
        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
    )
    num_blocks = min(experts_per_rank, max_blocks)
    grid = (num_blocks,)
    _fill_indices_kernel[grid](
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        permuted_indices,
        experts_per_rank,
        num_ranks,
        BLOCK_SIZE=block_size,
    )
    return permuted_indices


def fill_indices_cpu(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
):
    permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
    for expert_id in range(experts_per_rank):
        write_start = write_offsets[expert_id].item()
        for r in range(num_ranks):
            idx = r * experts_per_rank + expert_id
            start_index = start_index_values[idx].item()
            length = tokens_per_expert_group[idx].item()
            if length > 0:
                end_idx = min(write_start + length, max_len)
                permuted_indices[write_start:end_idx] = torch.arange(
                    start_index,
                    start_index + (end_idx - write_start),
                    dtype=torch.int32,
                )
            write_start += length
    return permuted_indices


def generate_permute_indices(
    tokens_per_expert_group: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    alignment: int,
    use_cpu: bool = False,
):
    start_index_values = (
        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
    )

    total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
    total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)

    m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
        torch.int32
    )

    m_offsets = torch.cumsum(m_sizes, 0)
    write_offsets = m_offsets - m_sizes

    if use_cpu:
        permuted_indices = fill_indices_cpu(
            tokens_per_expert_group,
            start_index_values,
            write_offsets,
            experts_per_rank,
            num_ranks,
            max_len,
        )
    else:
        permuted_indices = fill_indices_wrapper(
            tokens_per_expert_group,
            start_index_values,
            write_offsets,
            experts_per_rank,
            num_ranks,
            max_len,
        )

    return permuted_indices, m_sizes, m_offsets.to(torch.int32)
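A minimal usage sketch of the vendored helper (illustrative tensor values chosen here; the single-rank call mirrors how the MoE forward patch below invokes it):

    import torch
    from axolotl.kernels.moe.indices import generate_permute_indices

    # Hypothetical single-rank routing: 4 experts receiving 3, 0, 5 and 2 tokens.
    counts = torch.tensor([3, 0, 5, 2], dtype=torch.int32)
    permuted, m_sizes, m_offsets = generate_permute_indices(
        counts,
        experts_per_rank=4,
        num_ranks=1,
        max_len=20,      # >= sum of the aligned per-expert sizes (4 + 4 + 8 + 4)
        alignment=4,     # pad every expert group up to a multiple of 4 rows
        use_cpu=True,    # CPU reference path; the Triton kernel handles CUDA tensors
    )
    # m_sizes   -> [4, 4, 8, 4]   aligned rows per expert
    # m_offsets -> [4, 8, 16, 20] end offset of each padded expert group
    # permuted maps each padded row to a source token index; -1 marks padding rows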
@@ -8,10 +8,14 @@ from typing import Callable
 import torch
 
 from axolotl.kernels.moe import ContiguousGroupedGEMM
+from axolotl.kernels.moe.indices import generate_permute_indices
+from axolotl.utils.logging import get_logger
 
 _GROUP_SIZE_M = 128
 _COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")
 
+LOG = get_logger(__name__)
+
 
 def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
     return hidden_states.is_cuda and hidden_states.shape[0] > 0
@@ -55,9 +59,7 @@ def _ensure_combined_expert_weights(
                 # DeepseekV3 MLP layers are bias-free, but keep this for safety.
                 del lin._parameters["bias"]
         combined[name] = torch.stack(weights, dim=0).contiguous()
-        module.register_parameter(
-            f"{name}_weight", torch.nn.Parameter(combined[name])
-        )
+        module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined[name]))
         module._axolotl_original_specs[name] = (orig_device, orig_dtype)
 
     module._axolotl_combined_weights = True
@@ -72,7 +74,9 @@ def _restore_expert_weights(module) -> None:
     for name in _COMBINED_SUBMODULES:
         param_name = f"{name}_weight"
         combined = module._parameters.pop(param_name)
-        orig_device, orig_dtype = module._axolotl_original_specs.get(name, (combined.device, combined.dtype))
+        orig_device, orig_dtype = module._axolotl_original_specs.get(
+            name, (combined.device, combined.dtype)
+        )
         for idx, expert in enumerate(module.experts):
             lin = expert.get_submodule(name)
             lin._parameters["weight"] = torch.nn.Parameter(
@@ -82,6 +86,7 @@ def _restore_expert_weights(module) -> None:
     module._axolotl_combined_weights = False
     module._axolotl_combined_dtype = None
     module._axolotl_combined_device = None
+    module._axolotl_original_specs = {}
 
 
 def _moe_triton_forward(
@@ -115,36 +120,45 @@ def _moe_triton_forward(
     if total_actual == 0:
         return hidden_states.new_zeros_like(hidden_states)
 
-    padded_counts = (
-        (
-            torch.where(
-                counts > 0,
-                counts,
-                torch.full_like(counts, group_size_m),
-            )
-            + group_size_m
-            - 1
-        )
+    counts_int = counts.to(torch.int32)
+    aligned_counts = (
+        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1)
         // group_size_m
     ) * group_size_m
+    max_len = int(aligned_counts.sum().item())
 
-    total_padded = int(padded_counts.sum().item())
-    grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
+    permuted_indices, m_sizes, m_offsets = generate_permute_indices(
+        counts_int.to(device),
+        experts_per_rank=num_experts,
+        num_ranks=1,
+        max_len=max_len,
+        alignment=group_size_m,
+        use_cpu=not hidden_states.is_cuda,
+    )
 
-    write_offsets = torch.cumsum(padded_counts, dim=0) - padded_counts
-    actual_offsets = torch.cumsum(counts, dim=0) - counts
+    if permuted_indices.device != device:
+        permuted_indices = permuted_indices.to(device)
+    if m_sizes.device != device:
+        m_sizes = m_sizes.to(device)
+    if m_offsets.device != device:
+        m_offsets = m_offsets.to(device)
 
-    repeated_offsets = torch.repeat_interleave(actual_offsets, counts)
-    token_index = torch.arange(total_actual, device=device) - repeated_offsets
-    dest_indices = write_offsets[sorted_assignments] + token_index
+    permuted_indices_long = permuted_indices.to(torch.int64)
+    valid_mask = permuted_indices_long >= 0
+    valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
+    source_indices = permuted_indices_long[valid_mask]
 
-    grouped_hidden.index_copy_(0, dest_indices, sorted_hidden)
-    padded_counts_idx = padded_counts.to(torch.int64)
-    expert_index_tensor = (
-        torch.arange(num_experts, device=device, dtype=torch.int64)
-        .repeat_interleave(padded_counts_idx)
-        .to(torch.int32)
-        .contiguous()
+    grouped_hidden = hidden_states.new_zeros((max_len, hidden_dim))
+    if valid_positions.numel() > 0:
+        grouped_hidden.index_copy_(
+            0,
+            valid_positions,
+            sorted_hidden.index_select(0, source_indices),
+        )
+
+    expert_index_tensor = torch.repeat_interleave(
+        torch.arange(num_experts, device=device, dtype=torch.int32),
+        m_sizes.to(torch.int64),
     )
 
     _ensure_combined_expert_weights(module, hidden_dtype, device)
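To make the new buffer sizing concrete, a small worked example of the aligned_counts / max_len computation above (illustrative counts; group_size_m=128 matches _GROUP_SIZE_M):

    import torch

    group_size_m = 128
    counts_int = torch.tensor([5, 0, 300], dtype=torch.int32)
    aligned_counts = (
        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
    ) * group_size_m
    # aligned_counts -> [128, 128, 384]: every expert gets at least one full group
    max_len = int(aligned_counts.sum().item())  # 640 rows in the padded grouped buffer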
@@ -167,13 +181,17 @@
     )
 
     act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
-    valid_gate = gate_out.index_select(0, dest_indices).to(hidden_dtype)
-    valid_up = up_out.index_select(0, dest_indices).to(hidden_dtype)
-    hidden_concat = act_fn(valid_gate) * valid_up
+    if valid_positions.numel() > 0:
+        gate_valid = gate_out.index_select(0, valid_positions).to(hidden_dtype)
+        up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
+        hidden_concat = act_fn(gate_valid) * up_valid
+    else:
+        hidden_concat = torch.empty((0, gate_out.shape[-1]), device=device, dtype=hidden_dtype)
 
     intermediate_dim = hidden_concat.shape[-1]
-    hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim))
-    hidden_grouped.index_copy_(0, dest_indices, hidden_concat)
+    hidden_grouped = hidden_states.new_zeros((max_len, intermediate_dim))
+    if valid_positions.numel() > 0:
+        hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
 
     down_out = ContiguousGroupedGEMM.apply(
         hidden_grouped,
@@ -182,10 +200,19 @@
         group_size_m,
     )
 
-    down_valid = down_out.index_select(0, dest_indices).to(hidden_dtype)
+    if valid_positions.numel() > 0:
+        down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
+    else:
+        down_valid = torch.empty((0, down_out.shape[-1]), device=device, dtype=hidden_dtype)
+
+    sorted_outputs = hidden_states.new_empty((total_actual, hidden_dim))
+    if down_valid.numel() > 0:
+        sorted_outputs.index_copy_(0, source_indices, down_valid)
+    else:
+        sorted_outputs.zero_()
 
     expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
-    expanded_output.index_copy_(0, sort_perm, down_valid)
+    expanded_output.index_copy_(0, sort_perm, sorted_outputs)
     expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
 
     weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
@@ -212,7 +239,13 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
                 group_size_m,
                 original_moe,
             )
-        except RuntimeError:
+        except RuntimeError as err:
+            if not getattr(self, "_axolotl_triton_warned", False):
+                LOG.warning(
+                    "DeepseekV3MoE Triton path failed; falling back to baseline: %s",
+                    err,
+                )
+                self._axolotl_triton_warned = True
             _restore_expert_weights(self)
             return original_moe(self, hidden_states, topk_indices, topk_weights)