Dan Saunders
2025-09-22 15:54:44 -04:00
parent 5c74edeefe
commit db782430f8
2 changed files with 16 additions and 11 deletions

View File

@@ -1,5 +1,6 @@
"""Mixture-of-Experts kernel implementations."""
from .indices import generate_permute_indices
from .tt_cg_gemm import (
ContiguousGroupedGEMM,
ContiguousGroupedGEMMForwardOnly,
@@ -7,7 +8,6 @@ from .tt_cg_gemm import (
cg_grouped_gemm_forward,
cg_grouped_gemm_forward_dynamic,
)
from .indices import generate_permute_indices
__all__ = [
"cg_grouped_gemm",

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import contextlib
from typing import Callable
import torch
@@ -122,8 +121,7 @@ def _moe_triton_forward(
counts_int = counts.to(torch.int32)
aligned_counts = (
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1)
// group_size_m
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
) * group_size_m
max_len = int(aligned_counts.sum().item())
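The arithmetic above rounds each expert's token count up to a multiple of group_size_m, with a floor of one full group, so every expert occupies whole row-tiles in the grouped GEMM. A minimal standalone sketch of the same rounding; the counts below are made-up values for illustration only:

    import torch

    # Illustrative only: round per-expert token counts up to a multiple of
    # group_size_m, never below one full group (mirrors clamp_min + ceil-div).
    counts_int = torch.tensor([5, 0, 17], dtype=torch.int32)   # hypothetical counts
    group_size_m = 8
    aligned_counts = (
        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
    ) * group_size_m
    print(aligned_counts)                         # tensor([ 8,  8, 24], dtype=torch.int32)
    max_len = int(aligned_counts.sum().item())    # 40 rows in the padded layout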
@@ -147,14 +145,17 @@ def _moe_triton_forward(
valid_mask = permuted_indices_long >= 0
valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
source_indices = permuted_indices_long[valid_mask]
padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)
grouped_hidden = hidden_states.new_zeros((max_len, hidden_dim))
grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
if valid_positions.numel() > 0:
grouped_hidden.index_copy_(
0,
valid_positions,
sorted_hidden.index_select(0, source_indices),
)
if valid_positions.numel() < max_len:
grouped_hidden.index_fill_(0, padded_positions, 0)
expert_index_tensor = torch.repeat_interleave(
torch.arange(num_experts, device=device, dtype=torch.int32),
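The gather above builds the padded, expert-grouped activation buffer from an uninitialized allocation: valid slots receive real token rows via index_copy_, and only the padding slots are zeroed via index_fill_, instead of zero-initializing the whole buffer. A self-contained sketch of that pattern with hypothetical sizes:

    import torch

    # Sketch of the scatter-with-padding pattern (all sizes are made up).
    max_len, hidden_dim = 8, 4
    sorted_hidden = torch.randn(5, hidden_dim)                    # 5 real rows
    permuted_indices = torch.tensor([0, 1, -1, 2, 3, -1, 4, -1])  # -1 marks padding

    valid_mask = permuted_indices >= 0
    valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
    padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)
    source_indices = permuted_indices[valid_mask]

    grouped = sorted_hidden.new_empty((max_len, hidden_dim))      # uninitialized buffer
    grouped.index_copy_(0, valid_positions,
                        sorted_hidden.index_select(0, source_indices))
    grouped.index_fill_(0, padded_positions, 0)                   # zero only the pad rows

Zeroing only the padded rows touches max_len minus the number of valid rows, rather than the full buffer, which appears to be the intent of pairing new_empty with index_fill_ here.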
@@ -186,12 +187,16 @@ def _moe_triton_forward(
up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
hidden_concat = act_fn(gate_valid) * up_valid
else:
hidden_concat = torch.empty((0, gate_out.shape[-1]), device=device, dtype=hidden_dtype)
hidden_concat = torch.empty(
(0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
)
intermediate_dim = hidden_concat.shape[-1]
hidden_grouped = hidden_states.new_zeros((max_len, intermediate_dim))
hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
if valid_positions.numel() > 0:
hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
if valid_positions.numel() < max_len:
hidden_grouped.index_fill_(0, padded_positions, 0)
down_out = ContiguousGroupedGEMM.apply(
hidden_grouped,
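Between the two grouped GEMMs, the valid rows of the gate and up projections are combined with the usual gated-activation pattern, act_fn(gate) * up, before being scattered back into a padded buffer for the down projection. A rough sketch of just the gating step; SiLU is only an assumed stand-in for whatever act_fn the caller actually passes, and the shapes are hypothetical:

    import torch
    import torch.nn.functional as F

    # Hypothetical shapes; SiLU is an assumption for illustration.
    n_valid, intermediate_dim = 5, 16
    gate_valid = torch.randn(n_valid, intermediate_dim)
    up_valid = torch.randn(n_valid, intermediate_dim)

    hidden_concat = F.silu(gate_valid) * up_valid   # gated activation on valid rows only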
@@ -203,13 +208,13 @@ def _moe_triton_forward(
if valid_positions.numel() > 0:
down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
else:
down_valid = torch.empty((0, down_out.shape[-1]), device=device, dtype=hidden_dtype)
down_valid = torch.empty(
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
)
sorted_outputs = hidden_states.new_empty((total_actual, hidden_dim))
sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
if down_valid.numel() > 0:
sorted_outputs.index_copy_(0, source_indices, down_valid)
else:
sorted_outputs.zero_()
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
expanded_output.index_copy_(0, sort_perm, sorted_outputs)
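The last two lines undo the earlier token sort: row i of sorted_outputs is written back to position sort_perm[i] of the expanded output, which is the scatter form of applying the inverse permutation. A toy illustration of that idiom:

    import torch

    # Toy example: scattering rows through the sort permutation with index_copy_
    # is equivalent to gathering with the inverse permutation.
    x = torch.randn(6, 4)                        # original (expanded) rows, made-up size
    sort_perm = torch.randperm(6)                # some sort order
    sorted_x = x.index_select(0, sort_perm)      # rows in sorted order

    restored = x.new_empty(x.shape)
    restored.index_copy_(0, sort_perm, sorted_x) # sorted_x[i] goes back to slot sort_perm[i]
    assert torch.equal(restored, x)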