diff --git a/src/axolotl/kernels/moe/__init__.py b/src/axolotl/kernels/moe/__init__.py
index 1cb51abe9..eb1e5c3cf 100644
--- a/src/axolotl/kernels/moe/__init__.py
+++ b/src/axolotl/kernels/moe/__init__.py
@@ -1,5 +1,6 @@
 """Mixture-of-Experts kernel implementations."""

+from .indices import generate_permute_indices
 from .tt_cg_gemm import (
     ContiguousGroupedGEMM,
     ContiguousGroupedGEMMForwardOnly,
@@ -7,7 +8,6 @@ from .tt_cg_gemm import (
     cg_grouped_gemm_forward,
     cg_grouped_gemm_forward_dynamic,
 )
-from .indices import generate_permute_indices

 __all__ = [
     "cg_grouped_gemm",
diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
index b4cda94d2..c46ec34e8 100644
--- a/src/axolotl/monkeypatch/deepseek_v3/__init__.py
+++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
@@ -2,7 +2,6 @@

 from __future__ import annotations

-import contextlib
 from typing import Callable

 import torch
@@ -122,8 +121,7 @@ def _moe_triton_forward(
     counts_int = counts.to(torch.int32)

     aligned_counts = (
-        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1)
-        // group_size_m
+        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
     ) * group_size_m
     max_len = int(aligned_counts.sum().item())

@@ -147,14 +145,17 @@ def _moe_triton_forward(
     valid_mask = permuted_indices_long >= 0
     valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
     source_indices = permuted_indices_long[valid_mask]
+    padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)

-    grouped_hidden = hidden_states.new_zeros((max_len, hidden_dim))
+    grouped_hidden = hidden_states.new_empty((max_len, hidden_dim))
     if valid_positions.numel() > 0:
         grouped_hidden.index_copy_(
             0,
             valid_positions,
             sorted_hidden.index_select(0, source_indices),
         )
+    if valid_positions.numel() < max_len:
+        grouped_hidden.index_fill_(0, padded_positions, 0)

     expert_index_tensor = torch.repeat_interleave(
         torch.arange(num_experts, device=device, dtype=torch.int32),
@@ -186,12 +187,16 @@ def _moe_triton_forward(
         up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
         hidden_concat = act_fn(gate_valid) * up_valid
     else:
-        hidden_concat = torch.empty((0, gate_out.shape[-1]), device=device, dtype=hidden_dtype)
+        hidden_concat = torch.empty(
+            (0, gate_out.shape[-1]), device=device, dtype=hidden_dtype
+        )

     intermediate_dim = hidden_concat.shape[-1]
-    hidden_grouped = hidden_states.new_zeros((max_len, intermediate_dim))
+    hidden_grouped = hidden_states.new_empty((max_len, intermediate_dim))
     if valid_positions.numel() > 0:
         hidden_grouped.index_copy_(0, valid_positions, hidden_concat)
+    if valid_positions.numel() < max_len:
+        hidden_grouped.index_fill_(0, padded_positions, 0)

     down_out = ContiguousGroupedGEMM.apply(
         hidden_grouped,
@@ -203,13 +208,13 @@ def _moe_triton_forward(
     if valid_positions.numel() > 0:
         down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
     else:
-        down_valid = torch.empty((0, down_out.shape[-1]), device=device, dtype=hidden_dtype)
+        down_valid = torch.empty(
+            (0, down_out.shape[-1]), device=device, dtype=hidden_dtype
+        )

-    sorted_outputs = hidden_states.new_empty((total_actual, hidden_dim))
+    sorted_outputs = hidden_states.new_zeros((total_actual, hidden_dim))
     if down_valid.numel() > 0:
         sorted_outputs.index_copy_(0, source_indices, down_valid)
-    else:
-        sorted_outputs.zero_()

     expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
     expanded_output.index_copy_(0, sort_perm, sorted_outputs)
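
The recurring change in `_moe_triton_forward` is an allocation pattern: the padded grouped buffers switch from `new_zeros` (zero the whole buffer, then overwrite the valid region) to `new_empty` plus two targeted writes (`index_copy_` for valid rows, `index_fill_` for only the padded slots), so each row is written exactly once. Below is a minimal standalone sketch of that pattern; `fill_grouped_buffer` and its arguments are illustrative names for this sketch, not identifiers from the patch.

```python
# Minimal sketch of the new_empty + index_copy_ + index_fill_ padding
# pattern. Names here are illustrative, not the patch's variables.
import torch


def fill_grouped_buffer(rows: torch.Tensor, permuted_indices: torch.Tensor) -> torch.Tensor:
    """Scatter `rows` into an aligned buffer of length len(permuted_indices).

    Entries of `permuted_indices` equal to -1 mark padding slots; those rows
    are zero-filled explicitly instead of zero-initializing the whole buffer.
    """
    max_len = permuted_indices.numel()
    valid_mask = permuted_indices >= 0
    valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
    padded_positions = torch.nonzero(~valid_mask, as_tuple=False).squeeze(-1)

    # Uninitialized allocation; every row is written exactly once below.
    out = rows.new_empty((max_len, rows.shape[-1]))
    if valid_positions.numel() > 0:
        out.index_copy_(
            0, valid_positions, rows.index_select(0, permuted_indices[valid_mask])
        )
    if valid_positions.numel() < max_len:
        out.index_fill_(0, padded_positions, 0)
    return out


# Quick check against the straightforward zero-initialized version.
rows = torch.randn(3, 4)
idx = torch.tensor([1, -1, 0, 2, -1])  # two padding slots
ref = rows.new_zeros((idx.numel(), 4))
ref[idx >= 0] = rows[idx[idx >= 0]]
assert torch.equal(fill_grouped_buffer(rows, idx), ref)
```

Note that `sorted_outputs` moves in the opposite direction (`new_empty` to `new_zeros`): since `index_copy_` with `source_indices` covers every row whenever `down_valid` is non-empty, pre-zeroing only matters in the empty case, and it lets the explicit `else: sorted_outputs.zero_()` branch be dropped.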