From 19c91e36752ec5a288b396b996f9b395a1a2ab40 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Thu, 18 Sep 2025 11:44:21 -0400
Subject: [PATCH] refactor

---
 src/axolotl/kernels/moe/torch_grouped.py | 230 ++++++++++++++++++-----
 1 file changed, 188 insertions(+), 42 deletions(-)

diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index d5d311b6b..c6d671241 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import logging
+from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
 import torch
@@ -33,22 +34,190 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
     return impls
 
 
-def _stack_weights(
-    experts_module,
-    names: Tuple[str, ...],
-    *,
-    dtype: torch.dtype,
-    device: torch.device,
-) -> torch.Tensor:
-    tensors: List[torch.Tensor] = []
-    for mod in _iter_expert_impls(experts_module):
-        parts = [getattr(mod, name).weight.t() for name in names]
-        tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
+@dataclass
+class _GroupedWeightStorage:
+    pattern: str
+    gate: torch.Tensor
+    up: torch.Tensor
+    down: torch.Tensor
+    dtype: torch.dtype
+    device: torch.device
+    base_gate: Optional[torch.Tensor] = None
 
-    return (
-        torch.stack(tensors, dim=0)
-        .to(device=device, dtype=dtype, non_blocking=True)
-        .contiguous()
+
+def _ensure_grouped_weights(
+    experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
+) -> _GroupedWeightStorage:
+    storage: Optional[_GroupedWeightStorage] = getattr(
+        experts_module, "_ax_grouped_storage", None
+    )
+
+    def _store(new_storage: _GroupedWeightStorage) -> _GroupedWeightStorage:
+        experts_module._ax_grouped_storage = new_storage
+        return new_storage
+
+    # Identify expert parameter layout
+    if (
+        hasattr(sample_mod, "w1")
+        and hasattr(sample_mod, "w3")
+        and hasattr(sample_mod, "w2")
+    ):
+        pattern = "swi_glu"
+        if (
+            storage is not None
+            and storage.pattern == pattern
+            and storage.dtype == sample_mod.w1.weight.dtype
+            and storage.device == sample_mod.w1.weight.device
+        ):
+            return storage
+
+        num_experts = len(expert_impls)
+        w1_shape = sample_mod.w1.weight.shape
+        w3_shape = sample_mod.w3.weight.shape
+        w2_shape = sample_mod.w2.weight.shape
+        gate = torch.empty(
+            (num_experts, *w1_shape),
+            device=sample_mod.w1.weight.device,
+            dtype=sample_mod.w1.weight.dtype,
+        )
+        up = torch.empty(
+            (num_experts, *w3_shape),
+            device=sample_mod.w3.weight.device,
+            dtype=sample_mod.w3.weight.dtype,
+        )
+        down = torch.empty(
+            (num_experts, *w2_shape),
+            device=sample_mod.w2.weight.device,
+            dtype=sample_mod.w2.weight.dtype,
+        )
+        with torch.no_grad():
+            for idx, mod in enumerate(expert_impls):
+                gate[idx].copy_(mod.w1.weight.detach())
+                up[idx].copy_(mod.w3.weight.detach())
+                down[idx].copy_(mod.w2.weight.detach())
+                mod.w1.weight.detach_()
+                mod.w1.weight.set_(gate[idx])
+                mod.w3.weight.detach_()
+                mod.w3.weight.set_(up[idx])
+                mod.w2.weight.detach_()
+                mod.w2.weight.set_(down[idx])
+
+        return _store(
+            _GroupedWeightStorage(
+                pattern=pattern,
+                gate=gate,
+                up=up,
+                down=down,
+                dtype=gate.dtype,
+                device=gate.device,
+            )
+        )
+
+    if hasattr(sample_mod, "gate_up_proj") and hasattr(sample_mod, "down_proj"):
+        pattern = "fused_gate_up"
+        gate_weight = sample_mod.gate_up_proj.weight
+        down_weight = sample_mod.down_proj.weight
+        if (
+            storage is not None
+            and storage.pattern == pattern
+            and storage.dtype == gate_weight.dtype
+            and storage.device == gate_weight.device
+        ):
+            return storage
+
+        num_experts = len(expert_impls)
+        gate_full = torch.empty(
+            (num_experts, *gate_weight.shape),
+            device=gate_weight.device,
+            dtype=gate_weight.dtype,
+        )
+        down = torch.empty(
+            (num_experts, *down_weight.shape),
+            device=down_weight.device,
+            dtype=down_weight.dtype,
+        )
+        with torch.no_grad():
+            for idx, mod in enumerate(expert_impls):
+                gate_full[idx].copy_(mod.gate_up_proj.weight.detach())
+                down[idx].copy_(mod.down_proj.weight.detach())
+                mod.gate_up_proj.weight.detach_()
+                mod.gate_up_proj.weight.set_(gate_full[idx])
+                mod.down_proj.weight.detach_()
+                mod.down_proj.weight.set_(down[idx])
+
+        inter = gate_weight.shape[0] // 2
+        gate = gate_full[:, :inter]
+        up = gate_full[:, inter:]
+        return _store(
+            _GroupedWeightStorage(
+                pattern=pattern,
+                gate=gate,
+                up=up,
+                down=down,
+                dtype=gate.dtype,
+                device=gate.device,
+                base_gate=gate_full,
+            )
+        )
+
+    if (
+        hasattr(sample_mod, "up_proj")
+        and hasattr(sample_mod, "gate_proj")
+        and hasattr(sample_mod, "down_proj")
+    ):
+        pattern = "dual_proj"
+        if (
+            storage is not None
+            and storage.pattern == pattern
+            and storage.dtype == sample_mod.up_proj.weight.dtype
+            and storage.device == sample_mod.up_proj.weight.device
+        ):
+            return storage
+
+        num_experts = len(expert_impls)
+        up_weight = sample_mod.up_proj.weight
+        gate_weight = sample_mod.gate_proj.weight
+        down_weight = sample_mod.down_proj.weight
+        up = torch.empty(
+            (num_experts, *up_weight.shape),
+            device=up_weight.device,
+            dtype=up_weight.dtype,
+        )
+        gate = torch.empty(
+            (num_experts, *gate_weight.shape),
+            device=gate_weight.device,
+            dtype=gate_weight.dtype,
+        )
+        down = torch.empty(
+            (num_experts, *down_weight.shape),
+            device=down_weight.device,
+            dtype=down_weight.dtype,
+        )
+        with torch.no_grad():
+            for idx, mod in enumerate(expert_impls):
+                up[idx].copy_(mod.up_proj.weight.detach())
+                gate[idx].copy_(mod.gate_proj.weight.detach())
+                down[idx].copy_(mod.down_proj.weight.detach())
+                mod.up_proj.weight.detach_()
+                mod.up_proj.weight.set_(up[idx])
+                mod.gate_proj.weight.detach_()
+                mod.gate_proj.weight.set_(gate[idx])
+                mod.down_proj.weight.detach_()
+                mod.down_proj.weight.set_(down[idx])
+
+        return _store(
+            _GroupedWeightStorage(
+                pattern=pattern,
+                gate=gate,
+                up=up,
+                down=down,
+                dtype=gate.dtype,
+                device=gate.device,
+            )
+        )
+
+    raise RuntimeError(
+        "torch_grouped: unsupported expert module layout for grouped weights"
     )
 
 
@@ -75,31 +244,12 @@ def moe_ffn_forward_grouped(
         )
         return None, None
 
-    for suffix in ("w13", "w2"):
-        attr = f"_ax_grouped_{suffix}"
-        if hasattr(experts_module, attr):
-            delattr(experts_module, attr)
-
     expert_impls = _iter_expert_impls(experts_module)
     sample_mod = expert_impls[0]
-    if (
-        hasattr(sample_mod, "w1")
-        and hasattr(sample_mod, "w3")
-        and hasattr(sample_mod, "w2")
-    ):
-        w13 = _stack_weights(
-            experts_module, ("w1", "w3"), dtype=expert_dtype, device=device
-        )
-        w2 = _stack_weights(experts_module, ("w2",), dtype=expert_dtype, device=device)
-    else:
-        if hasattr(sample_mod, "gate_up_proj"):
-            names13: Tuple[str, ...] = ("gate_up_proj",)
-        else:
-            names13 = ("up_proj", "gate_proj")
-        w13 = _stack_weights(experts_module, names13, dtype=expert_dtype, device=device)
-        w2 = _stack_weights(
-            experts_module, ("down_proj",), dtype=expert_dtype, device=device
-        )
+    storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
+    w_gate = storage.gate
+    w_up = storage.up
+    w2 = storage.down
 
     x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
     router_logits = gate_linear(x_flat.to(routing_dtype))
@@ -139,10 +289,6 @@ def moe_ffn_forward_grouped(
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits
 
-    mid = w13.shape[-1] // 2
-    w_gate = w13[..., :mid]
-    w_up = w13[..., mid:]
-
     w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()
    w_up_t = w_up[active_idx].transpose(-2, -1).contiguous()
    w2_t = w2[active_idx].transpose(-2, -1).contiguous()