diff --git a/scripts/benchmarks/deepseek_v3_moe.py b/scripts/benchmarks/deepseek_v3_moe.py
index 3d309e723..abe5974e6 100644
--- a/scripts/benchmarks/deepseek_v3_moe.py
+++ b/scripts/benchmarks/deepseek_v3_moe.py
@@ -36,8 +36,8 @@ DTYPE_MAP = {

 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument("--batch", type=int, default=2, help="batch size")
-    parser.add_argument("--seq-len", type=int, default=256, help="sequence length")
+    parser.add_argument("--batch", type=int, default=8, help="batch size")
+    parser.add_argument("--seq-len", type=int, default=2048, help="sequence length")
     parser.add_argument("--hidden-size", type=int, default=4096, help="MoE hidden size")
     parser.add_argument(
         "--moe-intermediate-size",
@@ -48,13 +48,13 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--n-experts",
         type=int,
-        default=64,
+        default=256,
         help="Number of routed experts",
     )
     parser.add_argument(
         "--top-k",
         type=int,
-        default=4,
+        default=8,
         help="Number of experts per token",
     )
     parser.add_argument(
@@ -153,6 +153,10 @@ def main() -> None:  # pragma: no cover - CLI entrypoint
     baseline_module.to(device=device, dtype=dtype)
     patched_module.to(device=device, dtype=dtype)

+    tokens = args.batch * args.seq_len
+    routed_tokens = tokens * args.top_k
+    avg_tokens_per_expert = routed_tokens / args.n_experts
+
     inputs = torch.randn(
         args.batch,
         args.seq_len,
@@ -174,6 +178,9 @@ def main() -> None:  # pragma: no cover - CLI entrypoint
     print(
         f"Device={device.type} dtype={dtype} batch={args.batch} seq={args.seq_len} hidden={args.hidden_size}"
     )
+    print(
+        f"routed tokens={routed_tokens} avg tokens/expert={avg_tokens_per_expert:.1f} group_size={args.group_size}"
+    )
     print(
         f"Baseline: {baseline_ms:.3f} ms | Patched: {patched_ms:.3f} ms | x{speedup:.2f}"
     )
diff --git a/src/axolotl/kernels/moe/__init__.py b/src/axolotl/kernels/moe/__init__.py
index 362b2cdb4..1cb51abe9 100644
--- a/src/axolotl/kernels/moe/__init__.py
+++ b/src/axolotl/kernels/moe/__init__.py
@@ -7,6 +7,7 @@ from .tt_cg_gemm import (
     cg_grouped_gemm_forward,
     cg_grouped_gemm_forward_dynamic,
 )
+from .indices import generate_permute_indices

 __all__ = [
     "cg_grouped_gemm",
@@ -14,4 +15,5 @@ __all__ = [
     "cg_grouped_gemm_forward_dynamic",
     "ContiguousGroupedGEMM",
     "ContiguousGroupedGEMMForwardOnly",
+    "generate_permute_indices",
 ]
diff --git a/src/axolotl/kernels/moe/indices/__init__.py b/src/axolotl/kernels/moe/indices/__init__.py
new file mode 100644
index 000000000..32933e391
--- /dev/null
+++ b/src/axolotl/kernels/moe/indices/__init__.py
@@ -0,0 +1,5 @@
+"""Token permutation utilities for grouped MoE kernels."""
+
+from .indices import generate_permute_indices
+
+__all__ = ["generate_permute_indices"]
diff --git a/src/axolotl/kernels/moe/indices/indices.py b/src/axolotl/kernels/moe/indices/indices.py
new file mode 100644
index 000000000..2f873908b
--- /dev/null
+++ b/src/axolotl/kernels/moe/indices/indices.py
@@ -0,0 +1,144 @@
+"""Vendored token permutation kernels from TorchTitan."""
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import triton
+import triton.language as tl
+
+__all__ = ["generate_permute_indices"]
+
+
+@triton.jit
+def _fill_indices_kernel(
+    tokens_per_expert_group_ptr,
+    start_index_values_ptr,
+    write_offsets_ptr,
+    output_ptr,
+    experts_per_rank: tl.constexpr,
+    num_ranks: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    num_programs = tl.num_programs(axis=0)
+
+    for expert_id in range(pid, experts_per_rank, num_programs):
+        write_offset = tl.load(write_offsets_ptr + expert_id)
+
+        for r in range(num_ranks):
+            idx = r * experts_per_rank + expert_id
+
+            start_index = tl.load(start_index_values_ptr + idx)
+            length = tl.load(tokens_per_expert_group_ptr + idx)
+
+            offsets = tl.arange(0, BLOCK_SIZE)
+
+            for chunk_start in range(0, length, BLOCK_SIZE):
+                chunk_offsets = chunk_start + offsets
+                mask = chunk_offsets < length
+                values = start_index + chunk_offsets
+                dest_indices = write_offset + chunk_offsets
+                tl.store(output_ptr + dest_indices, values, mask=mask)
+
+            write_offset += length
+
+
+def fill_indices_wrapper(
+    tokens_per_expert_group: torch.Tensor,
+    start_index_values: torch.Tensor,
+    write_offsets: torch.Tensor,
+    experts_per_rank: int,
+    num_ranks: int,
+    max_len: int,
+    block_size: int = 128,
+    max_blocks: int = 1024,
+):
+    permuted_indices = torch.full(
+        (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device
+    )
+    num_blocks = min(experts_per_rank, max_blocks)
+    grid = (num_blocks,)
+    _fill_indices_kernel[grid](
+        tokens_per_expert_group,
+        start_index_values,
+        write_offsets,
+        permuted_indices,
+        experts_per_rank,
+        num_ranks,
+        BLOCK_SIZE=block_size,
+    )
+    return permuted_indices
+
+
+def fill_indices_cpu(
+    tokens_per_expert_group: torch.Tensor,
+    start_index_values: torch.Tensor,
+    write_offsets: torch.Tensor,
+    experts_per_rank: int,
+    num_ranks: int,
+    max_len: int,
+):
+    permuted_indices = torch.full((max_len,), -1, dtype=torch.int32)
+    for expert_id in range(experts_per_rank):
+        write_start = write_offsets[expert_id].item()
+        for r in range(num_ranks):
+            idx = r * experts_per_rank + expert_id
+            start_index = start_index_values[idx].item()
+            length = tokens_per_expert_group[idx].item()
+            if length > 0:
+                end_idx = min(write_start + length, max_len)
+                permuted_indices[write_start:end_idx] = torch.arange(
+                    start_index,
+                    start_index + (end_idx - write_start),
+                    dtype=torch.int32,
+                )
+            write_start += length
+    return permuted_indices
+
+
+def generate_permute_indices(
+    tokens_per_expert_group: torch.Tensor,
+    experts_per_rank: int,
+    num_ranks: int,
+    max_len: int,
+    alignment: int,
+    use_cpu: bool = False,
+):
+    start_index_values = (
+        torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
+    )
+
+    total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0)
+    total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment)
+
+    m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to(
+        torch.int32
+    )
+
+    m_offsets = torch.cumsum(m_sizes, 0)
+    write_offsets = m_offsets - m_sizes
+
+    if use_cpu:
+        permuted_indices = fill_indices_cpu(
+            tokens_per_expert_group,
+            start_index_values,
+            write_offsets,
+            experts_per_rank,
+            num_ranks,
+            max_len,
+        )
+    else:
+        permuted_indices = fill_indices_wrapper(
+            tokens_per_expert_group,
+            start_index_values,
+            write_offsets,
+            experts_per_rank,
+            num_ranks,
+            max_len,
+        )
+
+    return permuted_indices, m_sizes, m_offsets.to(torch.int32)
diff --git a/src/axolotl/monkeypatch/deepseek_v3/__init__.py b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
index 2e9d53a32..b4cda94d2 100644
--- a/src/axolotl/monkeypatch/deepseek_v3/__init__.py
+++ b/src/axolotl/monkeypatch/deepseek_v3/__init__.py
@@ -8,10 +8,14 @@ from typing import Callable
 import torch

 from axolotl.kernels.moe import ContiguousGroupedGEMM
+from axolotl.kernels.moe.indices import generate_permute_indices
+from axolotl.utils.logging import get_logger

 _GROUP_SIZE_M = 128
 _COMBINED_SUBMODULES = ("gate_proj", "up_proj", "down_proj")

+LOG = get_logger(__name__)
+

 def _is_triton_eligible(hidden_states: torch.Tensor) -> bool:
     return hidden_states.is_cuda and hidden_states.shape[0] > 0
@@ -55,9 +59,7 @@ def _ensure_combined_expert_weights(
             # DeepseekV3 MLP layers are bias-free, but keep this for safety.
             del lin._parameters["bias"]
         combined[name] = torch.stack(weights, dim=0).contiguous()
-        module.register_parameter(
-            f"{name}_weight", torch.nn.Parameter(combined[name])
-        )
+        module.register_parameter(f"{name}_weight", torch.nn.Parameter(combined[name]))
         module._axolotl_original_specs[name] = (orig_device, orig_dtype)

     module._axolotl_combined_weights = True
@@ -72,7 +74,9 @@ def _restore_expert_weights(module) -> None:
     for name in _COMBINED_SUBMODULES:
         param_name = f"{name}_weight"
         combined = module._parameters.pop(param_name)
-        orig_device, orig_dtype = module._axolotl_original_specs.get(name, (combined.device, combined.dtype))
+        orig_device, orig_dtype = module._axolotl_original_specs.get(
+            name, (combined.device, combined.dtype)
+        )
         for idx, expert in enumerate(module.experts):
             lin = expert.get_submodule(name)
             lin._parameters["weight"] = torch.nn.Parameter(
@@ -82,6 +86,7 @@
     module._axolotl_combined_weights = False
     module._axolotl_combined_dtype = None
     module._axolotl_combined_device = None
+    module._axolotl_original_specs = {}


 def _moe_triton_forward(
@@ -115,36 +120,45 @@
     if total_actual == 0:
         return hidden_states.new_zeros_like(hidden_states)

-    padded_counts = (
-        (
-            torch.where(
-                counts > 0,
-                counts,
-                torch.full_like(counts, group_size_m),
-            )
-            + group_size_m
-            - 1
-        )
+    counts_int = counts.to(torch.int32)
+    aligned_counts = (
+        (torch.clamp_min(counts_int, group_size_m) + group_size_m - 1) // group_size_m
     ) * group_size_m
+    max_len = int(aligned_counts.sum().item())

-    total_padded = int(padded_counts.sum().item())
-    grouped_hidden = hidden_states.new_zeros((total_padded, hidden_dim))
+    permuted_indices, m_sizes, m_offsets = generate_permute_indices(
+        counts_int.to(device),
+        experts_per_rank=num_experts,
+        num_ranks=1,
+        max_len=max_len,
+        alignment=group_size_m,
+        use_cpu=not hidden_states.is_cuda,
+    )

-    write_offsets = torch.cumsum(padded_counts, dim=0) - padded_counts
-    actual_offsets = torch.cumsum(counts, dim=0) - counts
+    if permuted_indices.device != device:
+        permuted_indices = permuted_indices.to(device)
+    if m_sizes.device != device:
+        m_sizes = m_sizes.to(device)
+    if m_offsets.device != device:
+        m_offsets = m_offsets.to(device)

-    repeated_offsets = torch.repeat_interleave(actual_offsets, counts)
-    token_index = torch.arange(total_actual, device=device) - repeated_offsets
-    dest_indices = write_offsets[sorted_assignments] + token_index
+    permuted_indices_long = permuted_indices.to(torch.int64)
+    valid_mask = permuted_indices_long >= 0
+    valid_positions = torch.nonzero(valid_mask, as_tuple=False).squeeze(-1)
+    source_indices = permuted_indices_long[valid_mask]

-    grouped_hidden.index_copy_(0, dest_indices, sorted_hidden)
-    padded_counts_idx = padded_counts.to(torch.int64)
-    expert_index_tensor = (
-        torch.arange(num_experts, device=device, dtype=torch.int64)
-        .repeat_interleave(padded_counts_idx)
-        .to(torch.int32)
-        .contiguous()
+    grouped_hidden = hidden_states.new_zeros((max_len, hidden_dim))
+    if valid_positions.numel() > 0:
+        grouped_hidden.index_copy_(
+            0,
+            valid_positions,
+            sorted_hidden.index_select(0, source_indices),
+        )
+
+    expert_index_tensor = torch.repeat_interleave(
+        torch.arange(num_experts, device=device, dtype=torch.int32),
+        m_sizes.to(torch.int64),
     )

     _ensure_combined_expert_weights(module, hidden_dtype, device)
@@ -167,13 +181,17 @@
     )

     act_fn: Callable[[torch.Tensor], torch.Tensor] = module.experts[0].act_fn
-    valid_gate = gate_out.index_select(0, dest_indices).to(hidden_dtype)
-    valid_up = up_out.index_select(0, dest_indices).to(hidden_dtype)
-    hidden_concat = act_fn(valid_gate) * valid_up
+    if valid_positions.numel() > 0:
+        gate_valid = gate_out.index_select(0, valid_positions).to(hidden_dtype)
+        up_valid = up_out.index_select(0, valid_positions).to(hidden_dtype)
+        hidden_concat = act_fn(gate_valid) * up_valid
+    else:
+        hidden_concat = torch.empty((0, gate_out.shape[-1]), device=device, dtype=hidden_dtype)

     intermediate_dim = hidden_concat.shape[-1]
-    hidden_grouped = hidden_states.new_zeros((total_padded, intermediate_dim))
-    hidden_grouped.index_copy_(0, dest_indices, hidden_concat)
+    hidden_grouped = hidden_states.new_zeros((max_len, intermediate_dim))
+    if valid_positions.numel() > 0:
+        hidden_grouped.index_copy_(0, valid_positions, hidden_concat)

     down_out = ContiguousGroupedGEMM.apply(
         hidden_grouped,
@@ -182,10 +200,19 @@
     )

-    down_valid = down_out.index_select(0, dest_indices).to(hidden_dtype)
+    if valid_positions.numel() > 0:
+        down_valid = down_out.index_select(0, valid_positions).to(hidden_dtype)
+    else:
+        down_valid = torch.empty((0, down_out.shape[-1]), device=device, dtype=hidden_dtype)
+
+    sorted_outputs = hidden_states.new_empty((total_actual, hidden_dim))
+    if down_valid.numel() > 0:
+        sorted_outputs.index_copy_(0, source_indices, down_valid)
+    else:
+        sorted_outputs.zero_()

     expanded_output = expanded_hidden.new_empty(expanded_hidden.shape)
-    expanded_output.index_copy_(0, sort_perm, down_valid)
+    expanded_output.index_copy_(0, sort_perm, sorted_outputs)

     expert_outputs = expanded_output.view(num_tokens, top_k, hidden_dim)
     weighted = expert_outputs * topk_weights.unsqueeze(-1).to(hidden_dtype)
@@ -212,7 +239,13 @@ def patch_deepseek_v3_moe(group_size_m: int = _GROUP_SIZE_M) -> None:
                 group_size_m,
                 original_moe,
             )
-        except RuntimeError:
+        except RuntimeError as err:
+            if not getattr(self, "_axolotl_triton_warned", False):
+                LOG.warning(
+                    "DeepseekV3MoE Triton path failed; falling back to baseline: %s",
+                    err,
+                )
+                self._axolotl_triton_warned = True
             _restore_expert_weights(self)
             return original_moe(self, hidden_states, topk_indices, topk_weights)
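
For reference, a minimal sketch (not part of the patch) of the layout generate_permute_indices produces for a toy single-rank case; the values were worked out by hand from the vendored code above and are illustrative only. It assumes triton is importable and uses the CPU fill path via use_cpu=True.

# Toy example: 3 experts on a single rank, alignment (group_size_m) of 4.
import torch

from axolotl.kernels.moe.indices import generate_permute_indices

counts = torch.tensor([5, 0, 3], dtype=torch.int32)  # tokens routed to each expert
# max_len mirrors the monkeypatch: per-expert counts clamped to the alignment,
# rounded up to a multiple of it, then summed -> 8 + 4 + 4 = 16.
permuted, m_sizes, m_offsets = generate_permute_indices(
    counts,
    experts_per_rank=3,
    num_ranks=1,
    max_len=16,
    alignment=4,
    use_cpu=True,
)
# m_sizes   -> [8, 4, 4]    (aligned row counts fed to the grouped GEMM)
# m_offsets -> [8, 12, 16]
# permuted  -> [0, 1, 2, 3, 4, -1, -1, -1,    expert 0: tokens 0-4 plus padding
#               -1, -1, -1, -1,               expert 1: no tokens, all padding
#               5, 6, 7, -1]                  expert 2: tokens 5-7 plus padding
# -1 marks padded slots; the monkeypatch skips them via the valid_positions mask.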