commit 19c91e3675
parent 2a176e4923
Author: Dan Saunders
Date:   2025-09-18 11:44:21 -04:00


@@ -3,6 +3,7 @@
 from __future__ import annotations
 import logging
+from dataclasses import dataclass
 from typing import List, Optional, Tuple
 import torch
@@ -33,22 +34,190 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
     return impls
 
 
-def _stack_weights(
-    experts_module,
-    names: Tuple[str, ...],
-    *,
-    dtype: torch.dtype,
-    device: torch.device,
-) -> torch.Tensor:
-    tensors: List[torch.Tensor] = []
-    for mod in _iter_expert_impls(experts_module):
-        parts = [getattr(mod, name).weight.t() for name in names]
-        tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
-    return (
-        torch.stack(tensors, dim=0)
-        .to(device=device, dtype=dtype, non_blocking=True)
-        .contiguous()
-    )
+@dataclass
+class _GroupedWeightStorage:
+    pattern: str
+    gate: torch.Tensor
+    up: torch.Tensor
+    down: torch.Tensor
+    dtype: torch.dtype
+    device: torch.device
+    base_gate: Optional[torch.Tensor] = None
+
+
+def _ensure_grouped_weights(
+    experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
+) -> _GroupedWeightStorage:
+    storage: Optional[_GroupedWeightStorage] = getattr(
+        experts_module, "_ax_grouped_storage", None
+    )
+
+    def _store(new_storage: _GroupedWeightStorage) -> _GroupedWeightStorage:
+        experts_module._ax_grouped_storage = new_storage
+        return new_storage
+
+    # Identify expert parameter layout
+    if (
+        hasattr(sample_mod, "w1")
+        and hasattr(sample_mod, "w3")
+        and hasattr(sample_mod, "w2")
+    ):
+        pattern = "swi_glu"
+        if (
+            storage is not None
+            and storage.pattern == pattern
+            and storage.dtype == sample_mod.w1.weight.dtype
+            and storage.device == sample_mod.w1.weight.device
+        ):
+            return storage
+        num_experts = len(expert_impls)
+        w1_shape = sample_mod.w1.weight.shape
+        w3_shape = sample_mod.w3.weight.shape
+        w2_shape = sample_mod.w2.weight.shape
+        gate = torch.empty(
+            (num_experts, *w1_shape),
+            device=sample_mod.w1.weight.device,
+            dtype=sample_mod.w1.weight.dtype,
+        )
+        up = torch.empty(
+            (num_experts, *w3_shape),
+            device=sample_mod.w3.weight.device,
+            dtype=sample_mod.w3.weight.dtype,
+        )
+        down = torch.empty(
+            (num_experts, *w2_shape),
+            device=sample_mod.w2.weight.device,
+            dtype=sample_mod.w2.weight.dtype,
+        )
+        with torch.no_grad():
+            # Copy each expert's weights into the grouped buffers, then
+            # re-point the per-expert parameters at the grouped slices so
+            # both views share a single allocation.
+            for idx, mod in enumerate(expert_impls):
+                gate[idx].copy_(mod.w1.weight.detach())
+                up[idx].copy_(mod.w3.weight.detach())
+                down[idx].copy_(mod.w2.weight.detach())
+                mod.w1.weight.detach_()
+                mod.w1.weight.set_(gate[idx])
+                mod.w3.weight.detach_()
+                mod.w3.weight.set_(up[idx])
+                mod.w2.weight.detach_()
+                mod.w2.weight.set_(down[idx])
+        return _store(
+            _GroupedWeightStorage(
+                pattern=pattern,
+                gate=gate,
+                up=up,
+                down=down,
+                dtype=gate.dtype,
+                device=gate.device,
+            )
+        )
+
+    if hasattr(sample_mod, "gate_up_proj") and hasattr(sample_mod, "down_proj"):
+        pattern = "fused_gate_up"
+        gate_weight = sample_mod.gate_up_proj.weight
+        down_weight = sample_mod.down_proj.weight
+        if (
+            storage is not None
+            and storage.pattern == pattern
+            and storage.dtype == gate_weight.dtype
+            and storage.device == gate_weight.device
+        ):
+            return storage
+        num_experts = len(expert_impls)
+        gate_full = torch.empty(
+            (num_experts, *gate_weight.shape),
+            device=gate_weight.device,
+            dtype=gate_weight.dtype,
+        )
+        down = torch.empty(
+            (num_experts, *down_weight.shape),
+            device=down_weight.device,
+            dtype=down_weight.dtype,
+        )
+        with torch.no_grad():
+            for idx, mod in enumerate(expert_impls):
+                gate_full[idx].copy_(mod.gate_up_proj.weight.detach())
+                down[idx].copy_(mod.down_proj.weight.detach())
+                mod.gate_up_proj.weight.detach_()
+                mod.gate_up_proj.weight.set_(gate_full[idx])
+                mod.down_proj.weight.detach_()
+                mod.down_proj.weight.set_(down[idx])
+        # The fused projection stacks the gate and up halves along the
+        # output dimension; expose them as views over the shared buffer.
+        inter = gate_weight.shape[0] // 2
+        gate = gate_full[:, :inter]
+        up = gate_full[:, inter:]
+        return _store(
+            _GroupedWeightStorage(
+                pattern=pattern,
+                gate=gate,
+                up=up,
+                down=down,
+                dtype=gate.dtype,
+                device=gate.device,
+                base_gate=gate_full,
+            )
+        )
+
+    if (
+        hasattr(sample_mod, "up_proj")
+        and hasattr(sample_mod, "gate_proj")
+        and hasattr(sample_mod, "down_proj")
+    ):
+        pattern = "dual_proj"
+        if (
+            storage is not None
+            and storage.pattern == pattern
+            and storage.dtype == sample_mod.up_proj.weight.dtype
+            and storage.device == sample_mod.up_proj.weight.device
+        ):
+            return storage
+        num_experts = len(expert_impls)
+        up_weight = sample_mod.up_proj.weight
+        gate_weight = sample_mod.gate_proj.weight
+        down_weight = sample_mod.down_proj.weight
+        up = torch.empty(
+            (num_experts, *up_weight.shape),
+            device=up_weight.device,
+            dtype=up_weight.dtype,
+        )
+        gate = torch.empty(
+            (num_experts, *gate_weight.shape),
+            device=gate_weight.device,
+            dtype=gate_weight.dtype,
+        )
+        down = torch.empty(
+            (num_experts, *down_weight.shape),
+            device=down_weight.device,
+            dtype=down_weight.dtype,
+        )
+        with torch.no_grad():
+            for idx, mod in enumerate(expert_impls):
+                up[idx].copy_(mod.up_proj.weight.detach())
+                gate[idx].copy_(mod.gate_proj.weight.detach())
+                down[idx].copy_(mod.down_proj.weight.detach())
+                mod.up_proj.weight.detach_()
+                mod.up_proj.weight.set_(up[idx])
+                mod.gate_proj.weight.detach_()
+                mod.gate_proj.weight.set_(gate[idx])
+                mod.down_proj.weight.detach_()
+                mod.down_proj.weight.set_(down[idx])
+        return _store(
+            _GroupedWeightStorage(
+                pattern=pattern,
+                gate=gate,
+                up=up,
+                down=down,
+                dtype=gate.dtype,
+                device=gate.device,
+            )
+        )
+
+    raise RuntimeError(
+        "torch_grouped: unsupported expert module layout for grouped weights"
+    )
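
The non-obvious part of `_ensure_grouped_weights` is the `detach_()` / `set_()` pair: after each expert's weights are copied into the grouped buffer, the per-expert parameter is re-pointed at its slice of that buffer, so the module view and the grouped view share one allocation, and later writes through either view land in the same memory. A minimal standalone sketch of the same aliasing trick (the toy module and tensor names here are illustrative, not from the patch):

```python
import torch

# Two toy "experts", each a 4x3 linear weight.
experts = [torch.nn.Linear(3, 4, bias=False) for _ in range(2)]

# One grouped buffer holding both experts' weights contiguously.
grouped = torch.empty((2, 4, 3), dtype=experts[0].weight.dtype)

with torch.no_grad():
    for idx, mod in enumerate(experts):
        grouped[idx].copy_(mod.weight.detach())  # copy into the slice
        mod.weight.detach_()                     # drop autograd history
        mod.weight.set_(grouped[idx])            # alias the grouped slice

# The per-expert parameter and the grouped slice now share storage.
assert experts[0].weight.data_ptr() == grouped[0].data_ptr()
with torch.no_grad():
    grouped[0].fill_(1.0)
assert torch.all(experts[0].weight == 1.0)
```

This sharing is also why the helper can return the cached `_ax_grouped_storage` on subsequent calls instead of re-stacking weights every forward pass.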
@@ -75,31 +244,12 @@ def moe_ffn_forward_grouped(
         )
         return None, None
 
-    for suffix in ("w13", "w2"):
-        attr = f"_ax_grouped_{suffix}"
-        if hasattr(experts_module, attr):
-            delattr(experts_module, attr)
-
     expert_impls = _iter_expert_impls(experts_module)
     sample_mod = expert_impls[0]
-    if (
-        hasattr(sample_mod, "w1")
-        and hasattr(sample_mod, "w3")
-        and hasattr(sample_mod, "w2")
-    ):
-        w13 = _stack_weights(
-            experts_module, ("w1", "w3"), dtype=expert_dtype, device=device
-        )
-        w2 = _stack_weights(experts_module, ("w2",), dtype=expert_dtype, device=device)
-    else:
-        if hasattr(sample_mod, "gate_up_proj"):
-            names13: Tuple[str, ...] = ("gate_up_proj",)
-        else:
-            names13 = ("up_proj", "gate_proj")
-        w13 = _stack_weights(experts_module, names13, dtype=expert_dtype, device=device)
-        w2 = _stack_weights(
-            experts_module, ("down_proj",), dtype=expert_dtype, device=device
-        )
+    storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
+    w_gate = storage.gate
+    w_up = storage.up
+    w2 = storage.down
 
     x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
     router_logits = gate_linear(x_flat.to(routing_dtype))
@@ -139,10 +289,6 @@ def moe_ffn_forward_grouped(
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits
 
-    mid = w13.shape[-1] // 2
-    w_gate = w13[..., :mid]
-    w_up = w13[..., mid:]
-
     w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()
     w_up_t = w_up[active_idx].transpose(-2, -1).contiguous()
     w2_t = w2[active_idx].transpose(-2, -1).contiguous()
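
For orientation: since storage holds raw `nn.Linear` weights of shape `(out, in)`, the gathered and transposed tensors line up for a right-multiply, with `w_gate_t` / `w_up_t` of shape `(num_active, hidden, inter)` and `w2_t` of shape `(num_active, inter, hidden)`. The grouped matmul that consumes them falls outside this hunk, so the following is only a hedged sketch of the kind of per-expert SwiGLU these layouts support (the `swi_glu` pattern name comes from the patch; the batched `bmm` and equal-sized token groups are simplifying assumptions, not necessarily what the kernel does):

```python
import torch
import torch.nn.functional as F

num_active, hidden, inter, tokens = 2, 8, 16, 5

# Shapes matching the transposed gather above.
w_gate_t = torch.randn(num_active, hidden, inter)
w_up_t = torch.randn(num_active, hidden, inter)
w2_t = torch.randn(num_active, inter, hidden)

# One token group per active expert (equal group sizes only for the sketch).
x = torch.randn(num_active, tokens, hidden)

# SwiGLU per expert: silu(x @ Wg) * (x @ Wu), then project back down.
h = F.silu(torch.bmm(x, w_gate_t)) * torch.bmm(x, w_up_t)
y = torch.bmm(h, w2_t)  # (num_active, tokens, hidden)
```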