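fix: zero only the padded rows of the grouped MoE buffers, zero-initialize sorted_outputs, and tidy imports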
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
"""Mixture-of-Experts kernel implementations."""
|
"""Mixture-of-Experts kernel implementations."""
|
||||||
|
|
||||||
|
from .indices import generate_permute_indices
|
||||||
from .tt_cg_gemm import (
|
from .tt_cg_gemm import (
|
||||||
ContiguousGroupedGEMM,
|
ContiguousGroupedGEMM,
|
||||||
ContiguousGroupedGEMMForwardOnly,
|
ContiguousGroupedGEMMForwardOnly,
|
||||||
@@ -7,7 +8,6 @@ from .tt_cg_gemm import (
|
|||||||
cg_grouped_gemm_forward,
|
cg_grouped_gemm_forward,
|
||||||
cg_grouped_gemm_forward_dynamic,
|
cg_grouped_gemm_forward_dynamic,
|
||||||
)
|
)
|
||||||
from .indices import generate_permute_indices
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"cg_grouped_gemm",
|
"cg_grouped_gemm",
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -122,8 +121,7 @@ def _moe_triton_forward(
|
|||||||
|
|
||||||
counts_int = counts.to(torch.int32)
|
counts_int = counts.to(torch.int32)
|
||||||
aligned_counts = (
|
aligned_counts = (
|
||||||
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1)
|
(torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
|
||||||
// group_size_m
|
|
||||||
) * group_size_m
|
) * group_size_m
|
||||||
max_len = int(aligned_counts.sum().item())
|
max_len = int(aligned_counts.sum().item())
|
||||||
|
|
||||||
@@ -147,14 +145,17 @@ def _moe_triton_forward(
|
|||||||
valid_mask = permuted_indices_long >= 0
|
valid_mask = permuted_indices_long >= 0
|
||||||
valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
|
valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
|
||||||
source_indices = permuted_indices_long[valid_mask]
|
source_indices = permuted_indices_long[valid_mask]
|
||||||
|
padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)
|
||||||
|
|
||||||
grouped_hidden = hidden_states.new_zeros((max_len, hidden_dim))
|
grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
|
||||||
if valid_positions.numel() > 0:
|
if valid_positions.numel() > 0:
|
||||||
grouped_hidden.index_copy_(
|
grouped_hidden.index_copy_(
|
||||||
0,
|
0,
|
||||||
valid_positions,
|
valid_positions,
|
||||||
sorted_hidden.index_select(0, source_indices),
|
sorted_hidden.index_select(0, source_indices),
|
||||||
)
|
)
|
||||||
|
if valid_positions.numel() < max_len:
|
||||||
|
grouped_hidden.index_fill_(0, padded_positions, 0)
|
||||||
|
|
||||||
expert_index_tensor = torch.repeat_interleave(
|
expert_index_tensor = torch.repeat_interleave(
|
||||||
torch.arange(num_experts, device=device, dtype=torch.int32),
|
torch.arange(num_experts, device=device, dtype=torch.int32),
|
||||||
@@ -186,12 +187,16 @@ def _moe_triton_forward(
|
|||||||
up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
|
up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
|
||||||
hidden_concat = act_fn(gate_valid) * up_valid
|
hidden_concat = act_fn(gate_valid) * up_valid
|
||||||
else:
|
else:
|
||||||
hidden_concat = torch.empty((0, gate_out.shape[-1]), device=device, dtype=hidden_dtype)
|
hidden_concat = torch.empty(
|
||||||
|
(0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
|
||||||
|
)
|
||||||
|
|
||||||
intermediate_dim = hidden_concat.shape[-1]
|
intermediate_dim = hidden_concat.shape[-1]
|
||||||
hidden_grouped = hidden_states.new_zeros((max_len, intermediate_dim))
|
hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
|
||||||
if valid_positions.numel() > 0:
|
if valid_positions.numel() > 0:
|
||||||
hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
|
hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
|
||||||
|
if valid_positions.numel() < max_len:
|
||||||
|
hidden_grouped.index_fill_(0, padded_positions, 0)
|
||||||
|
|
||||||
down_out = ContiguousGroupedGEMM.apply(
|
down_out = ContiguousGroupedGEMM.apply(
|
||||||
hidden_grouped,
|
hidden_grouped,
|
||||||
@@ -203,13 +208,13 @@ def _moe_triton_forward(
|
|||||||
if valid_positions.numel() > 0:
|
if valid_positions.numel() > 0:
|
||||||
down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
|
down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
|
||||||
else:
|
else:
|
||||||
down_valid = torch.empty((0, down_out.shape[-1]), device=device, dtype=hidden_dtype)
|
down_valid = torch.empty(
|
||||||
|
(0, down_out.shape[-1]), device=device, dtype=hidden_dtype
|
||||||
|
)
|
||||||
|
|
||||||
sorted_outputs = hidden_states.new_empty((total_actual, hidden_dim))
|
sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
|
||||||
if down_valid.numel() > 0:
|
if down_valid.numel() > 0:
|
||||||
sorted_outputs.index_copy_(0, source_indices, down_valid)
|
sorted_outputs.index_copy_(0, source_indices, down_valid)
|
||||||
else:
|
|
||||||
sorted_outputs.zero_()
|
|
||||||
|
|
||||||
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
|
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
|
||||||
expanded_output.index_copy_(0, sort_perm, sorted_outputs)
|
expanded_output.index_copy_(0, sort_perm, sorted_outputs)
|
||||||
|
|||||||
Reference in New Issue
Block a user