fix
@@ -1,5 +1,6 @@
 """Mixture-of-Experts kernel implementations."""

+from .indices import generate_permute_indices
 from .tt_cg_gemm import (
     ContiguousGroupedGEMM,
     ContiguousGroupedGEMMForwardOnly,
@@ -7,7 +8,6 @@ from .tt_cg_gemm import (
     cg_grouped_gemm_forward,
     cg_grouped_gemm_forward_dynamic,
 )
-from .indices import generate_permute_indices

 __all__ = [
     "cg_grouped_gemm",

@@ -2,7 +2,6 @@

 from __future__ import annotations

 import contextlib
 from typing import Callable

 import torch
@@ -122,8 +121,7 @@ def _moe_triton_forward(

     counts_int = counts.to(torch.int32)
     aligned_counts = (
-        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1)
-        // group_size_m
+        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
     ) * group_size_m
     max_len = int(aligned_counts.sum().item())

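
Note on the rounding above: each expert's token count is clamped to at least one group and rounded up to the next multiple of group_size_m, so every expert occupies whole groups in the packed buffer. A standalone sketch with made-up counts (values are illustrative only):

    import torch

    # Made-up per-expert token counts; the real values come from the routing step.
    counts_int = torch.tensor([3, 0, 5, 17], dtype=torch.int32)
    group_size_m = 8

    aligned_counts = (
        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
    ) * group_size_m
    print(aligned_counts.tolist())             # [8, 8, 8, 24]
    print(int(aligned_counts.sum().item()))    # 48 rows in the padded grouped buffer
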
@@ -147,14 +145,17 @@ def _moe_triton_forward(
     valid_mask = permuted_indices_long >= 0
     valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
     source_indices = permuted_indices_long[valid_mask]
+    padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)

-    grouped_hidden = hidden_states.new_zeros((max_len, hidden_dim))
+    grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
     if valid_positions.numel() > 0:
         grouped_hidden.index_copy_(
             0,
             valid_positions,
             sorted_hidden.index_select(0, source_indices),
         )
+    if valid_positions.numel() < max_len:
+        grouped_hidden.index_fill_(0, padded_positions, 0)

     expert_index_tensor = torch.repeat_interleave(
         torch.arange(num_experts, device=device, dtype=torch.int32),
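
Note on the scatter above: new_empty plus index_copy_ plus index_fill_ replaces a full zero-initialisation, writing valid slots directly from the routed activations and clearing only the padding rows. A minimal sketch of the pattern with toy tensors (negative entries mark padding slots, as in permuted_indices_long above):

    import torch

    max_len, hidden_dim = 8, 4
    # Toy routing permutation: each slot holds a source row index, -1 marks padding.
    permuted_indices_long = torch.tensor([2, 0, -1, -1, 1, 3, -1, -1])
    sorted_hidden = torch.randn(4, hidden_dim)

    valid_mask = permuted_indices_long >= 0
    valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
    source_indices = permuted_indices_long[valid_mask]
    padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)

    grouped_hidden = sorted_hidden.new_empty((max_len, hidden_dim))
    grouped_hidden.index_copy_(0, valid_positions, sorted_hidden.index_select(0, source_indices))
    grouped_hidden.index_fill_(0, padded_positions, 0)  # clear only the padding rows
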
@@ -186,12 +187,16 @@ def _moe_triton_forward(
         up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
         hidden_concat = act_fn(gate_valid) * up_valid
     else:
-        hidden_concat = torch.empty((0, gate_out.shape[-1]), device=device, dtype=hidden_dtype)
+        hidden_concat = torch.empty(
+            (0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
+        )

     intermediate_dim = hidden_concat.shape[-1]
-    hidden_grouped = hidden_states.new_zeros((max_len, intermediate_dim))
+    hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
     if valid_positions.numel() > 0:
         hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
+    if valid_positions.numel() < max_len:
+        hidden_grouped.index_fill_(0, padded_positions, 0)

     down_out = ContiguousGroupedGEMM.apply(
         hidden_grouped,
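
Note on the grouped GEMM call above: ContiguousGroupedGEMM.apply is the Triton kernel defined elsewhere in this package, and its exact signature is not shown here. Purely as a reference for contiguous-grouped semantics, a plain-PyTorch loop over experts looks like the sketch below (the [num_experts, in_dim, out_dim] weight layout is an assumption for illustration):

    import torch

    num_experts, in_dim, out_dim = 2, 8, 16
    rows_per_expert = torch.tensor([8, 16])               # aligned counts per expert
    x = torch.randn(int(rows_per_expert.sum()), in_dim)   # rows packed expert-by-expert
    w = torch.randn(num_experts, in_dim, out_dim)         # hypothetical per-expert weights

    out = x.new_empty((x.shape[0], out_dim))
    start = 0
    for e in range(num_experts):
        end = start + int(rows_per_expert[e])
        out[start:end] = x[start:end] @ w[e]   # each expert multiplies its contiguous slice
        start = end
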
@@ -203,13 +208,13 @@ def _moe_triton_forward(
     if valid_positions.numel() > 0:
         down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
     else:
-        down_valid = torch.empty((0, down_out.shape[-1]), device=device, dtype=hidden_dtype)
+        down_valid = torch.empty(
+            (0, down_out.shape[-1]), device=device, dtype=hidden_dtype
+        )

-    sorted_outputs = hidden_states.new_empty((total_actual, hidden_dim))
+    sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
     if down_valid.numel() > 0:
         sorted_outputs.index_copy_(0, source_indices, down_valid)
     else:
         sorted_outputs.zero_()

     expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
     expanded_output.index_copy_(0, sort_perm, sorted_outputs)
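
Note on the last two lines of the hunk: index_copy_ through sort_perm is the inverse of index_select with the same permutation, so the routed outputs land back in the original token order. A quick self-contained check with toy tensors:

    import torch

    tokens = torch.randn(6, 4)
    sort_perm = torch.randperm(6)

    sorted_tokens = tokens.index_select(0, sort_perm)   # gather into routed order
    restored = tokens.new_empty(tokens.shape)
    restored.index_copy_(0, sort_perm, sorted_tokens)   # scatter back to original order

    assert torch.equal(restored, tokens)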