refactor: replace per-call expert weight stacking with cached grouped weight storage
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import logging
+from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
 import torch
@@ -33,22 +34,190 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
     return impls
 
 
-def _stack_weights(
-    experts_module,
-    names: Tuple[str, ...],
-    *,
-    dtype: torch.dtype,
-    device: torch.device,
-) -> torch.Tensor:
-    tensors: List[torch.Tensor] = []
-    for mod in _iter_expert_impls(experts_module):
-        parts = [getattr(mod, name).weight.t() for name in names]
-        tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
-    return (
-        torch.stack(tensors, dim=0)
-        .to(device=device, dtype=dtype, non_blocking=True)
-        .contiguous()
-    )
+@dataclass
+class _GroupedWeightStorage:
+    pattern: str
+    gate: torch.Tensor
+    up: torch.Tensor
+    down: torch.Tensor
+    dtype: torch.dtype
+    device: torch.device
+    base_gate: Optional[torch.Tensor] = None
+
+
+def _ensure_grouped_weights(
+    experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
+) -> _GroupedWeightStorage:
+    storage: Optional[_GroupedWeightStorage] = getattr(
+        experts_module, "_ax_grouped_storage", None
+    )
+
+    def _store(new_storage: _GroupedWeightStorage) -> _GroupedWeightStorage:
+        experts_module._ax_grouped_storage = new_storage
+        return new_storage
+
+    # Identify the expert parameter layout.
+    if (
+        hasattr(sample_mod, "w1")
+        and hasattr(sample_mod, "w3")
+        and hasattr(sample_mod, "w2")
+    ):
+        pattern = "swi_glu"
+        if (
+            storage is not None
+            and storage.pattern == pattern
+            and storage.dtype == sample_mod.w1.weight.dtype
+            and storage.device == sample_mod.w1.weight.device
+        ):
+            return storage
+
+        num_experts = len(expert_impls)
+        w1_shape = sample_mod.w1.weight.shape
+        w3_shape = sample_mod.w3.weight.shape
+        w2_shape = sample_mod.w2.weight.shape
+        gate = torch.empty(
+            (num_experts, *w1_shape),
+            device=sample_mod.w1.weight.device,
+            dtype=sample_mod.w1.weight.dtype,
+        )
+        up = torch.empty(
+            (num_experts, *w3_shape),
+            device=sample_mod.w3.weight.device,
+            dtype=sample_mod.w3.weight.dtype,
+        )
+        down = torch.empty(
+            (num_experts, *w2_shape),
+            device=sample_mod.w2.weight.device,
+            dtype=sample_mod.w2.weight.dtype,
+        )
+        with torch.no_grad():
+            for idx, mod in enumerate(expert_impls):
+                gate[idx].copy_(mod.w1.weight.detach())
+                up[idx].copy_(mod.w3.weight.detach())
+                down[idx].copy_(mod.w2.weight.detach())
+                mod.w1.weight.detach_()
+                mod.w1.weight.set_(gate[idx])
+                mod.w3.weight.detach_()
+                mod.w3.weight.set_(up[idx])
+                mod.w2.weight.detach_()
+                mod.w2.weight.set_(down[idx])
+
+        return _store(
+            _GroupedWeightStorage(
+                pattern=pattern,
+                gate=gate,
+                up=up,
+                down=down,
+                dtype=gate.dtype,
+                device=gate.device,
+            )
+        )
+
+    if hasattr(sample_mod, "gate_up_proj") and hasattr(sample_mod, "down_proj"):
+        pattern = "fused_gate_up"
+        gate_weight = sample_mod.gate_up_proj.weight
+        down_weight = sample_mod.down_proj.weight
+        if (
+            storage is not None
+            and storage.pattern == pattern
+            and storage.dtype == gate_weight.dtype
+            and storage.device == gate_weight.device
+        ):
+            return storage
+
+        num_experts = len(expert_impls)
+        gate_full = torch.empty(
+            (num_experts, *gate_weight.shape),
+            device=gate_weight.device,
+            dtype=gate_weight.dtype,
+        )
+        down = torch.empty(
+            (num_experts, *down_weight.shape),
+            device=down_weight.device,
+            dtype=down_weight.dtype,
+        )
+        with torch.no_grad():
+            for idx, mod in enumerate(expert_impls):
+                gate_full[idx].copy_(mod.gate_up_proj.weight.detach())
+                down[idx].copy_(mod.down_proj.weight.detach())
+                mod.gate_up_proj.weight.detach_()
+                mod.gate_up_proj.weight.set_(gate_full[idx])
+                mod.down_proj.weight.detach_()
+                mod.down_proj.weight.set_(down[idx])
+
+        inter = gate_weight.shape[0] // 2
+        gate = gate_full[:, :inter]
+        up = gate_full[:, inter:]
+        return _store(
+            _GroupedWeightStorage(
+                pattern=pattern,
+                gate=gate,
+                up=up,
+                down=down,
+                dtype=gate.dtype,
+                device=gate.device,
+                base_gate=gate_full,
+            )
+        )
+
+    if (
+        hasattr(sample_mod, "up_proj")
+        and hasattr(sample_mod, "gate_proj")
+        and hasattr(sample_mod, "down_proj")
+    ):
+        pattern = "dual_proj"
+        if (
+            storage is not None
+            and storage.pattern == pattern
+            and storage.dtype == sample_mod.up_proj.weight.dtype
+            and storage.device == sample_mod.up_proj.weight.device
+        ):
+            return storage
+
+        num_experts = len(expert_impls)
+        up_weight = sample_mod.up_proj.weight
+        gate_weight = sample_mod.gate_proj.weight
+        down_weight = sample_mod.down_proj.weight
+        up = torch.empty(
+            (num_experts, *up_weight.shape),
+            device=up_weight.device,
+            dtype=up_weight.dtype,
+        )
+        gate = torch.empty(
+            (num_experts, *gate_weight.shape),
+            device=gate_weight.device,
+            dtype=gate_weight.dtype,
+        )
+        down = torch.empty(
+            (num_experts, *down_weight.shape),
+            device=down_weight.device,
+            dtype=down_weight.dtype,
+        )
+        with torch.no_grad():
+            for idx, mod in enumerate(expert_impls):
+                up[idx].copy_(mod.up_proj.weight.detach())
+                gate[idx].copy_(mod.gate_proj.weight.detach())
+                down[idx].copy_(mod.down_proj.weight.detach())
+                mod.up_proj.weight.detach_()
+                mod.up_proj.weight.set_(up[idx])
+                mod.gate_proj.weight.detach_()
+                mod.gate_proj.weight.set_(gate[idx])
+                mod.down_proj.weight.detach_()
+                mod.down_proj.weight.set_(down[idx])
+
+        return _store(
+            _GroupedWeightStorage(
+                pattern=pattern,
+                gate=gate,
+                up=up,
+                down=down,
+                dtype=gate.dtype,
+                device=gate.device,
+            )
+        )
+
+    raise RuntimeError(
+        "torch_grouped: unsupported expert module layout for grouped weights"
+    )
 
 
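Note on the copy-then-`set_` dance in `_ensure_grouped_weights` above: after each expert's weight is copied into the stacked buffer, the per-expert `Parameter` is re-pointed at its row of that buffer, so the modules and the grouped tensors share one storage instead of holding two copies. A minimal standalone sketch of that aliasing, using toy shapes and toy modules rather than the real expert classes:

import torch

# Two toy stand-ins for expert projections with identical shapes.
experts = [torch.nn.Linear(4, 8, bias=False) for _ in range(2)]

# One contiguous (num_experts, out_features, in_features) buffer.
grouped = torch.stack([e.weight.detach() for e in experts], dim=0).contiguous()

with torch.no_grad():
    for idx, e in enumerate(experts):
        e.weight.detach_()           # drop autograd history on the Parameter
        e.weight.set_(grouped[idx])  # re-point the Parameter at its stacked row

# Parameter and grouped buffer now alias the same storage.
experts[0].weight[0, 0] = 123.0
assert grouped[0, 0, 0].item() == 123.0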
@@ -75,31 +244,12 @@ def moe_ffn_forward_grouped(
         )
         return None, None
 
-    for suffix in ("w13", "w2"):
-        attr = f"_ax_grouped_{suffix}"
-        if hasattr(experts_module, attr):
-            delattr(experts_module, attr)
-
     expert_impls = _iter_expert_impls(experts_module)
     sample_mod = expert_impls[0]
-    if (
-        hasattr(sample_mod, "w1")
-        and hasattr(sample_mod, "w3")
-        and hasattr(sample_mod, "w2")
-    ):
-        w13 = _stack_weights(
-            experts_module, ("w1", "w3"), dtype=expert_dtype, device=device
-        )
-        w2 = _stack_weights(experts_module, ("w2",), dtype=expert_dtype, device=device)
-    else:
-        if hasattr(sample_mod, "gate_up_proj"):
-            names13: Tuple[str, ...] = ("gate_up_proj",)
-        else:
-            names13 = ("up_proj", "gate_proj")
-        w13 = _stack_weights(experts_module, names13, dtype=expert_dtype, device=device)
-        w2 = _stack_weights(
-            experts_module, ("down_proj",), dtype=expert_dtype, device=device
-        )
+    storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
+    w_gate = storage.gate
+    w_up = storage.up
+    w2 = storage.down
 
     x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
     router_logits = gate_linear(x_flat.to(routing_dtype))
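The midpoint split removed in the next hunk is now done once inside `_ensure_grouped_weights`, where `gate = gate_full[:, :inter]` and `up = gate_full[:, inter:]` are views into the fused buffer (which is why `base_gate` keeps the full tensor referenced). A small sketch of that view behavior, with made-up shapes:

import torch

# Fused (num_experts, 2 * inter, hidden) buffer, as in the fused_gate_up branch.
gate_up = torch.randn(4, 2 * 16, 32)

inter = gate_up.shape[1] // 2
gate = gate_up[:, :inter]  # view, no copy
up = gate_up[:, inter:]    # view, no copy

# Both halves alias the fused buffer's storage.
assert gate.data_ptr() == gate_up.data_ptr()
gate_up[0, 0, 0] = 1.0
assert gate[0, 0, 0].item() == 1.0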
@@ -139,10 +289,6 @@ def moe_ffn_forward_grouped(
         zero = torch.zeros_like(x_flat)
         return zero.view(bsz, seqlen, hdim), router_logits
 
-    mid = w13.shape[-1] // 2
-    w_gate = w13[..., :mid]
-    w_up = w13[..., mid:]
-
     w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()
     w_up_t = w_up[active_idx].transpose(-2, -1).contiguous()
     w2_t = w2[active_idx].transpose(-2, -1).contiguous()
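For context on the `transpose(-2, -1)` lines above: the grouped weights are now kept in `nn.Linear`'s (out_features, in_features) layout, so the forward gathers the active experts and flips each slice to (in, out) before the batched matmul. A sketch of that indexing pattern under hypothetical shapes (the actual matmul call is outside this diff):

import torch

E, hidden, inter = 4, 32, 16
w_gate = torch.randn(E, inter, hidden)   # (out, in), as nn.Linear stores weights
active_idx = torch.tensor([0, 2, 3])     # experts hit by this batch
x = torch.randn(3, hidden)               # one token routed to each active expert

# Gather the active experts (this copies) and flip to (in, out) for x @ w.
w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()

# Batched matmul: (3, 1, hidden) @ (3, hidden, inter) -> (3, 1, inter).
h = torch.bmm(x.unsqueeze(1), w_gate_t).squeeze(1)
assert h.shape == (3, inter)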