glue
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
@@ -13,27 +12,17 @@ from axolotl.kernels.moe import ContiguousGroupedGEMM
|
||||
_GROUP_SIZE_M = 128
|
||||
|
||||
|
||||
def _align_to(value: int, alignment: int) -> int:
|
||||
if value <= 0:
|
||||
return 0
|
||||
return math.ceil(value / alignment) * alignment
|
||||
|
||||
|
||||
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
|
||||
return hidden_states.is_cuda and hidden_states.shape[0] > 0
|
||||
|
||||
|
||||
def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
gate_weights = []
|
||||
up_weights = []
|
||||
down_weights = []
|
||||
for expert in module.experts:
|
||||
gate_weights.append(expert.gate_proj.weight)
|
||||
up_weights.append(expert.up_proj.weight)
|
||||
down_weights.append(expert.down_proj.weight)
|
||||
gate = torch.stack(gate_weights, dim=0).contiguous()
|
||||
up = torch.stack(up_weights, dim=0).contiguous()
|
||||
down = torch.stack(down_weights, dim=0).contiguous()
|
||||
gate_weights = [expert.gate_proj.weight for expert in module.experts]
|
||||
up_weights = [expert.up_proj.weight for expert in module.experts]
|
||||
down_weights = [expert.down_proj.weight for expert in module.experts]
|
||||
gate = torch.stack(gate_weights, dim=0)
|
||||
up = torch.stack(up_weights, dim=0)
|
||||
down = torch.stack(down_weights, dim=0)
|
||||
return gate, up, down
|
||||
|
||||
|
||||
@@ -64,35 +53,41 @@ def _moe_triton_forward(
|
||||
|
||||
num_experts = len(module.experts)
|
||||
counts = torch.bincount(sorted_assignments, minlength=num_experts)
|
||||
counts_cpu = counts.to(torch.int64).cpu().tolist()
|
||||
padded_counts = [_align_to(c, group_size_m) for c in counts_cpu]
|
||||
|
||||
total_actual = sum(counts_cpu)
|
||||
total_padded = sum(padded_counts)
|
||||
if total_actual == 0 or total_padded == 0:
|
||||
total_actual = int(counts.sum().item())
|
||||
if total_actual == 0:
|
||||
return hidden_states.new_zeros_like(hidden_states)
|
||||
|
||||
actual_offsets = [0]
|
||||
padded_offsets = [0]
|
||||
for count, padded in zip(counts_cpu, padded_counts, strict=False):
|
||||
actual_offsets.append(actual_offsets[-1] + count)
|
||||
padded_offsets.append(padded_offsets[-1] + padded)
|
||||
|
||||
grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
|
||||
expert_index_tensor = torch.empty(total_padded, dtype=torch.int32, device=device)
|
||||
|
||||
for idx, (count, padded) in enumerate(zip(counts_cpu, padded_counts, strict=False)):
|
||||
dst_start = padded_offsets[idx]
|
||||
dst_end = dst_start + padded
|
||||
if padded == 0:
|
||||
continue
|
||||
expert_index_tensor[dst_start:dst_end] = idx
|
||||
if count > 0:
|
||||
src_start = actual_offsets[idx]
|
||||
src_end = src_start + count
|
||||
grouped_hidden[dst_start : dst_start + count].copy_(
|
||||
sorted_hidden[src_start:src_end]
|
||||
padded_counts = (
|
||||
(
|
||||
torch.where(
|
||||
counts > 0,
|
||||
counts,
|
||||
torch.full_like(counts, group_size_m),
|
||||
)
|
||||
+ group_size_m
|
||||
- 1
|
||||
)
|
||||
// group_size_m
|
||||
) * group_size_m
|
||||
|
||||
total_padded = int(padded_counts.sum().item())
|
||||
grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
|
||||
|
||||
write_offsets = torch.cumsum(padded_counts, dim=0) - padded_counts
|
||||
actual_offsets = torch.cumsum(counts, dim=0) - counts
|
||||
|
||||
repeated_offsets = torch.repeat_interleave(actual_offsets, counts)
|
||||
token_index = torch.arange(total_actual, device=device) - repeated_offsets
|
||||
dest_indices = write_offsets[sorted_assignments] + token_index
|
||||
|
||||
grouped_hidden.index_copy_(0, dest_indices, sorted_hidden)
|
||||
padded_counts_idx = padded_counts.to(torch.int64)
|
||||
expert_index_tensor = (
|
||||
torch.arange(num_experts, device=device, dtype=torch.int64)
|
||||
.repeat_interleave(padded_counts_idx)
|
||||
.to(torch.int32)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
gate_weights, up_weights, down_weights = _collect_expert_weights(module)
|
||||
|
||||
@@ -110,31 +105,13 @@ def _moe_triton_forward(
|
||||
)
|
||||
|
||||
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
|
||||
|
||||
hidden_chunks = []
|
||||
for idx, count in enumerate(counts_cpu):
|
||||
if count == 0:
|
||||
continue
|
||||
pad_start = padded_offsets[idx]
|
||||
pad_end = pad_start + count
|
||||
gate_slice = gate_out[pad_start:pad_end].to(hidden_dtype)
|
||||
up_slice = up_out[pad_start:pad_end].to(hidden_dtype)
|
||||
hidden_chunks.append(act_fn(gate_slice) * up_slice)
|
||||
|
||||
hidden_concat = torch.cat(hidden_chunks, dim=0)
|
||||
valid_gate = gate_out.index_select(0, dest_indices).to(hidden_dtype)
|
||||
valid_up = up_out.index_select(0, dest_indices).to(hidden_dtype)
|
||||
hidden_concat = act_fn(valid_gate) * valid_up
|
||||
|
||||
intermediate_dim = hidden_concat.shape[-1]
|
||||
hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim))
|
||||
|
||||
for idx, count in enumerate(counts_cpu):
|
||||
if count == 0:
|
||||
continue
|
||||
pad_start = padded_offsets[idx]
|
||||
src_start = actual_offsets[idx]
|
||||
src_end = src_start + count
|
||||
hidden_grouped[pad_start : pad_start + count].copy_(
|
||||
hidden_concat[src_start:src_end]
|
||||
)
|
||||
hidden_grouped.index_copy_(0, dest_indices, hidden_concat)
|
||||
|
||||
down_out = ContiguousGroupedGEMM.apply(
|
||||
hidden_grouped,
|
||||
@@ -143,18 +120,10 @@ def _moe_triton_forward(
|
||||
group_size_m,
|
||||
)
|
||||
|
||||
down_chunks = []
|
||||
for idx, count in enumerate(counts_cpu):
|
||||
if count == 0:
|
||||
continue
|
||||
pad_start = padded_offsets[idx]
|
||||
pad_end = pad_start + count
|
||||
down_chunks.append(down_out[pad_start:pad_end].to(hidden_dtype))
|
||||
|
||||
down_concat = torch.cat(down_chunks, dim=0)
|
||||
down_valid = down_out.index_select(0, dest_indices).to(hidden_dtype)
|
||||
|
||||
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
|
||||
expanded_output.index_copy_(0, sort_perm, down_concat.to(hidden_dtype))
|
||||
expanded_output.index_copy_(0, sort_perm, down_valid)
|
||||
expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
|
||||
|
||||
weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
|
||||
|
||||
Reference in New Issue
Block a user