This commit is contained in:
Dan Saunders
2025-09-21 16:23:23 -04:00
parent 95e607574a
commit 6a45d804f9

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import contextlib
import math
from typing import Callable
import torch
@@ -13,27 +12,17 @@ from axolotl.kernels.moe import ContiguousGroupedGEMM
_GROUP_SIZE_M = 128
def _align_to(value: int, alignment: int) -> int:
if value <= 0:
return 0
return math.ceil(value / alignment) * alignment
def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
return hidden_states.is_cuda and hidden_states.shape[0] > 0
def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
gate_weights = []
up_weights = []
down_weights = []
for expert in module.experts:
gate_weights.append(expert.gate_proj.weight)
up_weights.append(expert.up_proj.weight)
down_weights.append(expert.down_proj.weight)
gate = torch.stack(gate_weights, dim=0).contiguous()
up = torch.stack(up_weights, dim=0).contiguous()
down = torch.stack(down_weights, dim=0).contiguous()
gate_weights = [expert.gate_proj.weight for expert in module.experts]
up_weights = [expert.up_proj.weight for expert in module.experts]
down_weights = [expert.down_proj.weight for expert in module.experts]
gate = torch.stack(gate_weights, dim=0)
up = torch.stack(up_weights, dim=0)
down = torch.stack(down_weights, dim=0)
return gate, up, down
@@ -64,35 +53,41 @@ def _moe_triton_forward(
num_experts = len(module.experts)
counts = torch.bincount(sorted_assignments, minlength=num_experts)
counts_cpu = counts.to(torch.int64).cpu().tolist()
padded_counts = [_align_to(c, group_size_m) for c in counts_cpu]
total_actual = sum(counts_cpu)
total_padded = sum(padded_counts)
if total_actual == 0 or total_padded == 0:
total_actual = int(counts.sum().item())
if total_actual == 0:
return hidden_states.new_zeros_like(hidden_states)
actual_offsets = [0]
padded_offsets = [0]
for count, padded in zip(counts_cpu, padded_counts, strict=False):
actual_offsets.append(actual_offsets[-1] + count)
padded_offsets.append(padded_offsets[-1] + padded)
grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
expert_index_tensor = torch.empty(total_padded, dtype=torch.int32, device=device)
for idx, (count, padded) in enumerate(zip(counts_cpu, padded_counts, strict=False)):
dst_start = padded_offsets[idx]
dst_end = dst_start + padded
if padded == 0:
continue
expert_index_tensor[dst_start:dst_end] = idx
if count > 0:
src_start = actual_offsets[idx]
src_end = src_start + count
grouped_hidden[dst_start : dst_start + count].copy_(
sorted_hidden[src_start:src_end]
padded_counts = (
(
torch.where(
counts > 0,
counts,
torch.full_like(counts, group_size_m),
)
+ group_size_m
- 1
)
// group_size_m
) * group_size_m
total_padded = int(padded_counts.sum().item())
grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
write_offsets = torch.cumsum(padded_counts, dim=0) - padded_counts
actual_offsets = torch.cumsum(counts, dim=0) - counts
repeated_offsets = torch.repeat_interleave(actual_offsets, counts)
token_index = torch.arange(total_actual, device=device) - repeated_offsets
dest_indices = write_offsets[sorted_assignments] + token_index
grouped_hidden.index_copy_(0, dest_indices, sorted_hidden)
padded_counts_idx = padded_counts.to(torch.int64)
expert_index_tensor = (
torch.arange(num_experts, device=device, dtype=torch.int64)
.repeat_interleave(padded_counts_idx)
.to(torch.int32)
.contiguous()
)
gate_weights, up_weights, down_weights = _collect_expert_weights(module)
@@ -110,31 +105,13 @@ def _moe_triton_forward(
)
act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
hidden_chunks = []
for idx, count in enumerate(counts_cpu):
if count == 0:
continue
pad_start = padded_offsets[idx]
pad_end = pad_start + count
gate_slice = gate_out[pad_start:pad_end].to(hidden_dtype)
up_slice = up_out[pad_start:pad_end].to(hidden_dtype)
hidden_chunks.append(act_fn(gate_slice) * up_slice)
hidden_concat = torch.cat(hidden_chunks, dim=0)
valid_gate = gate_out.index_select(0, dest_indices).to(hidden_dtype)
valid_up = up_out.index_select(0, dest_indices).to(hidden_dtype)
hidden_concat = act_fn(valid_gate) * valid_up
intermediate_dim = hidden_concat.shape[-1]
hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim))
for idx, count in enumerate(counts_cpu):
if count == 0:
continue
pad_start = padded_offsets[idx]
src_start = actual_offsets[idx]
src_end = src_start + count
hidden_grouped[pad_start : pad_start + count].copy_(
hidden_concat[src_start:src_end]
)
hidden_grouped.index_copy_(0, dest_indices, hidden_concat)
down_out = ContiguousGroupedGEMM.apply(
hidden_grouped,
@@ -143,18 +120,10 @@ def _moe_triton_forward(
group_size_m,
)
down_chunks = []
for idx, count in enumerate(counts_cpu):
if count == 0:
continue
pad_start = padded_offsets[idx]
pad_end = pad_start + count
down_chunks.append(down_out[pad_start:pad_end].to(hidden_dtype))
down_concat = torch.cat(down_chunks, dim=0)
down_valid = down_out.index_select(0, dest_indices).to(hidden_dtype)
expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
expanded_output.index_copy_(0, sort_perm, down_concat.to(hidden_dtype))
expanded_output.index_copy_(0, sort_perm, down_valid)
expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)