From 6a45d804f9904307aa8bb9227de49e6cd06adc13 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Sun, 21 Sep 2025 16:23:23 -0400
Subject: [PATCH] glue

---
 .../monkeypatch/deepseek_v3/__init__.py | 124 ++++++----------
 1 file changed, 48 insertions(+), 76 deletions(-)

diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
index d32793656..92d57b64c 100644
--- a/src/axolotl/monkeypatch/deepseek_v3/__init__.py
+++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import contextlib
-import math
 from typing import Callable
 
 import torch
@@ -13,27 +12,17 @@ from axolotl.kernels.moe import ContiguousGroupedGEMM
 
 _GROUP_SIZE_M = 128
 
 
-def _align_to(value: int, alignment: int) -> int:
-    if value <= 0:
-        return 0
-    return math.ceil(value / alignment) * alignment
-
-
 def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
     return hidden_states.is_cuda and hidden_states.shape[0] > 0
 
 
 def _collect_expert_weights(module) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    gate_weights = []
-    up_weights = []
-    down_weights = []
-    for expert in module.experts:
-        gate_weights.append(expert.gate_proj.weight)
-        up_weights.append(expert.up_proj.weight)
-        down_weights.append(expert.down_proj.weight)
-    gate = torch.stack(gate_weights, dim=0).contiguous()
-    up = torch.stack(up_weights, dim=0).contiguous()
-    down = torch.stack(down_weights, dim=0).contiguous()
+    gate_weights = [expert.gate_proj.weight for expert in module.experts]
+    up_weights = [expert.up_proj.weight for expert in module.experts]
+    down_weights = [expert.down_proj.weight for expert in module.experts]
+    gate = torch.stack(gate_weights, dim=0)
+    up = torch.stack(up_weights, dim=0)
+    down = torch.stack(down_weights, dim=0)
     return gate, up, down
 
@@ -64,35 +53,44 @@ def _moe_triton_forward(
 
     num_experts = len(module.experts)
     counts = torch.bincount(sorted_assignments, minlength=num_experts)
-    counts_cpu = counts.to(torch.int64).cpu().tolist()
-    padded_counts = [_align_to(c, group_size_m) for c in counts_cpu]
-
-    total_actual = sum(counts_cpu)
-    total_padded = sum(padded_counts)
-    if total_actual == 0 or total_padded == 0:
-        return hidden_states.new_zeros_like(hidden_states)
+    total_actual = int(counts.sum().item())
+    if total_actual == 0:
+        return torch.zeros_like(hidden_states)
 
-    actual_offsets = [0]
-    padded_offsets = [0]
-    for count, padded in zip(counts_cpu, padded_counts, strict=False):
-        actual_offsets.append(actual_offsets[-1] + count)
-        padded_offsets.append(padded_offsets[-1] + padded)
-
-    grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
-    expert_index_tensor = torch.empty(total_padded, dtype=torch.int32, device=device)
-
-    for idx, (count, padded) in enumerate(zip(counts_cpu, padded_counts, strict=False)):
-        dst_start = padded_offsets[idx]
-        dst_end = dst_start + padded
-        if padded == 0:
-            continue
-        expert_index_tensor[dst_start:dst_end] = idx
-        if count > 0:
-            src_start = actual_offsets[idx]
-            src_end = src_start + count
-            grouped_hidden[dst_start : dst_start + count].copy_(
-                sorted_hidden[src_start:src_end]
+    # Round each expert's token count up to a multiple of group_size_m;
+    # experts with zero tokens still get one full (zero-filled) group.
+    padded_counts = (
+        (
+            torch.where(
+                counts > 0,
+                counts,
+                torch.full_like(counts, group_size_m),
             )
+            + group_size_m
+            - 1
+        )
+        // group_size_m
+    ) * group_size_m
+
+    total_padded = int(padded_counts.sum().item())
+    grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
+
+    write_offsets = torch.cumsum(padded_counts, dim=0) - padded_counts
+    actual_offsets = torch.cumsum(counts, dim=0) - counts
+
+    # Map each sorted token to its destination row in the padded layout.
+    repeated_offsets = torch.repeat_interleave(actual_offsets, counts)
+    token_index = torch.arange(total_actual, device=device) - repeated_offsets
+    dest_indices = write_offsets[sorted_assignments] + token_index
+
+    grouped_hidden.index_copy_(0, dest_indices, sorted_hidden)
+    padded_counts_idx = padded_counts.to(torch.int64)
+    expert_index_tensor = (
+        torch.arange(num_experts, device=device, dtype=torch.int64)
+        .repeat_interleave(padded_counts_idx)
+        .to(torch.int32)
+        .contiguous()
+    )
 
     gate_weights, up_weights, down_weights = _collect_expert_weights(module)
 
@@ -110,31 +108,13 @@
     )
 
     act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
-
-    hidden_chunks = []
-    for idx, count in enumerate(counts_cpu):
-        if count == 0:
-            continue
-        pad_start = padded_offsets[idx]
-        pad_end = pad_start + count
-        gate_slice = gate_out[pad_start:pad_end].to(hidden_dtype)
-        up_slice = up_out[pad_start:pad_end].to(hidden_dtype)
-        hidden_chunks.append(act_fn(gate_slice) * up_slice)
-
-    hidden_concat = torch.cat(hidden_chunks, dim=0)
+    valid_gate = gate_out.index_select(0, dest_indices).to(hidden_dtype)
+    valid_up = up_out.index_select(0, dest_indices).to(hidden_dtype)
+    hidden_concat = act_fn(valid_gate) * valid_up
 
     intermediate_dim = hidden_concat.shape[-1]
     hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim))
-
-    for idx, count in enumerate(counts_cpu):
-        if count == 0:
-            continue
-        pad_start = padded_offsets[idx]
-        src_start = actual_offsets[idx]
-        src_end = src_start + count
-        hidden_grouped[pad_start : pad_start + count].copy_(
-            hidden_concat[src_start:src_end]
-        )
+    hidden_grouped.index_copy_(0, dest_indices, hidden_concat)
 
     down_out = ContiguousGroupedGEMM.apply(
         hidden_grouped,
@@ -143,18 +123,10 @@
         group_size_m,
     )
 
-    down_chunks = []
-    for idx, count in enumerate(counts_cpu):
-        if count == 0:
-            continue
-        pad_start = padded_offsets[idx]
-        pad_end = pad_start + count
-        down_chunks.append(down_out[pad_start:pad_end].to(hidden_dtype))
-
-    down_concat = torch.cat(down_chunks, dim=0)
+    down_valid = down_out.index_select(0, dest_indices).to(hidden_dtype)
 
     expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
-    expanded_output.index_copy_(0, sort_perm, down_concat.to(hidden_dtype))
+    expanded_output.index_copy_(0, sort_perm, down_valid)
 
     expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
     weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
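
A minimal, self-contained CPU sketch of the padding and scatter arithmetic the
new code path vectorizes. All names, sizes, and values here are invented for
illustration and are not part of the patch; only the formulas mirror the hunks
above.

import torch

group_size_m = 4
num_experts = 3
# Token-to-expert assignments, already sorted by expert id; expert 1 is empty.
sorted_assignments = torch.tensor([0, 0, 0, 0, 0, 2, 2])
sorted_hidden = torch.arange(7, dtype=torch.float32).unsqueeze(-1)  # (tokens, hidden=1)

counts = torch.bincount(sorted_assignments, minlength=num_experts)  # [5, 0, 2]
# Round each count up to a multiple of group_size_m; empty experts get one full group.
padded_counts = (
    (torch.where(counts > 0, counts, torch.full_like(counts, group_size_m))
     + group_size_m - 1) // group_size_m
) * group_size_m                                                    # [8, 4, 4]

write_offsets = torch.cumsum(padded_counts, dim=0) - padded_counts  # [0, 8, 12]
actual_offsets = torch.cumsum(counts, dim=0) - counts               # [0, 5, 5]

# Position of each token within its expert segment, then its padded destination row.
token_index = torch.arange(int(counts.sum())) - torch.repeat_interleave(actual_offsets, counts)
dest_indices = write_offsets[sorted_assignments] + token_index      # [0, 1, 2, 3, 4, 12, 13]

grouped = sorted_hidden.new_zeros((int(padded_counts.sum()), sorted_hidden.shape[-1]))
grouped.index_copy_(0, dest_indices, sorted_hidden)

# Each expert's rows now start on a group_size_m boundary, as the contiguous
# grouped GEMM expects, and gathering with dest_indices inverts the scatter.
assert torch.equal(grouped.index_select(0, dest_indices), sorted_hidden)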