refactor
This commit is contained in:
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -33,22 +34,190 @@ def _iter_expert_impls(experts_module) -> List[torch.nn.Module]:
|
|||||||
return impls
|
return impls
|
||||||
|
|
||||||
|
|
||||||
def _stack_weights(
|
@dataclass
|
||||||
experts_module,
|
class _GroupedWeightStorage:
|
||||||
names: Tuple[str, ...],
|
pattern: str
|
||||||
*,
|
gate: torch.Tensor
|
||||||
dtype: torch.dtype,
|
up: torch.Tensor
|
||||||
device: torch.device,
|
down: torch.Tensor
|
||||||
) -> torch.Tensor:
|
dtype: torch.dtype
|
||||||
tensors: List[torch.Tensor] = []
|
device: torch.device
|
||||||
for mod in _iter_expert_impls(experts_module):
|
base_gate: Optional[torch.Tensor] = None
|
||||||
parts = [getattr(mod, name).weight.t() for name in names]
|
|
||||||
tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
|
|
||||||
|
|
||||||
return (
|
|
||||||
torch.stack(tensors, dim=0)
|
def _ensure_grouped_weights(
|
||||||
.to(device=device, dtype=dtype, non_blocking=True)
|
experts_module, expert_impls: List[torch.nn.Module], sample_mod: torch.nn.Module
|
||||||
.contiguous()
|
) -> _GroupedWeightStorage:
|
||||||
|
storage: Optional[_GroupedWeightStorage] = getattr(
|
||||||
|
experts_module, "_ax_grouped_storage", None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _store(new_storage: _GroupedWeightStorage) -> _GroupedWeightStorage:
|
||||||
|
experts_module._ax_grouped_storage = new_storage
|
||||||
|
return new_storage
|
||||||
|
|
||||||
|
# Identify expert parameter layout
|
||||||
|
if (
|
||||||
|
hasattr(sample_mod, "w1")
|
||||||
|
and hasattr(sample_mod, "w3")
|
||||||
|
and hasattr(sample_mod, "w2")
|
||||||
|
):
|
||||||
|
pattern = "swi_glu"
|
||||||
|
if (
|
||||||
|
storage is not None
|
||||||
|
and storage.pattern == pattern
|
||||||
|
and storage.dtype == sample_mod.w1.weight.dtype
|
||||||
|
and storage.device == sample_mod.w1.weight.device
|
||||||
|
):
|
||||||
|
return storage
|
||||||
|
|
||||||
|
num_experts = len(expert_impls)
|
||||||
|
w1_shape = sample_mod.w1.weight.shape
|
||||||
|
w3_shape = sample_mod.w3.weight.shape
|
||||||
|
w2_shape = sample_mod.w2.weight.shape
|
||||||
|
gate = torch.empty(
|
||||||
|
(num_experts, *w1_shape),
|
||||||
|
device=sample_mod.w1.weight.device,
|
||||||
|
dtype=sample_mod.w1.weight.dtype,
|
||||||
|
)
|
||||||
|
up = torch.empty(
|
||||||
|
(num_experts, *w3_shape),
|
||||||
|
device=sample_mod.w3.weight.device,
|
||||||
|
dtype=sample_mod.w3.weight.dtype,
|
||||||
|
)
|
||||||
|
down = torch.empty(
|
||||||
|
(num_experts, *w2_shape),
|
||||||
|
device=sample_mod.w2.weight.device,
|
||||||
|
dtype=sample_mod.w2.weight.dtype,
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
for idx, mod in enumerate(expert_impls):
|
||||||
|
gate[idx].copy_(mod.w1.weight.detach())
|
||||||
|
up[idx].copy_(mod.w3.weight.detach())
|
||||||
|
down[idx].copy_(mod.w2.weight.detach())
|
||||||
|
mod.w1.weight.detach_()
|
||||||
|
mod.w1.weight.set_(gate[idx])
|
||||||
|
mod.w3.weight.detach_()
|
||||||
|
mod.w3.weight.set_(up[idx])
|
||||||
|
mod.w2.weight.detach_()
|
||||||
|
mod.w2.weight.set_(down[idx])
|
||||||
|
|
||||||
|
return _store(
|
||||||
|
_GroupedWeightStorage(
|
||||||
|
pattern=pattern,
|
||||||
|
gate=gate,
|
||||||
|
up=up,
|
||||||
|
down=down,
|
||||||
|
dtype=gate.dtype,
|
||||||
|
device=gate.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(sample_mod, "gate_up_proj") and hasattr(sample_mod, "down_proj"):
|
||||||
|
pattern = "fused_gate_up"
|
||||||
|
gate_weight = sample_mod.gate_up_proj.weight
|
||||||
|
down_weight = sample_mod.down_proj.weight
|
||||||
|
if (
|
||||||
|
storage is not None
|
||||||
|
and storage.pattern == pattern
|
||||||
|
and storage.dtype == gate_weight.dtype
|
||||||
|
and storage.device == gate_weight.device
|
||||||
|
):
|
||||||
|
return storage
|
||||||
|
|
||||||
|
num_experts = len(expert_impls)
|
||||||
|
gate_full = torch.empty(
|
||||||
|
(num_experts, *gate_weight.shape),
|
||||||
|
device=gate_weight.device,
|
||||||
|
dtype=gate_weight.dtype,
|
||||||
|
)
|
||||||
|
down = torch.empty(
|
||||||
|
(num_experts, *down_weight.shape),
|
||||||
|
device=down_weight.device,
|
||||||
|
dtype=down_weight.dtype,
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
for idx, mod in enumerate(expert_impls):
|
||||||
|
gate_full[idx].copy_(mod.gate_up_proj.weight.detach())
|
||||||
|
down[idx].copy_(mod.down_proj.weight.detach())
|
||||||
|
mod.gate_up_proj.weight.detach_()
|
||||||
|
mod.gate_up_proj.weight.set_(gate_full[idx])
|
||||||
|
mod.down_proj.weight.detach_()
|
||||||
|
mod.down_proj.weight.set_(down[idx])
|
||||||
|
|
||||||
|
inter = gate_weight.shape[0] // 2
|
||||||
|
gate = gate_full[:, :inter]
|
||||||
|
up = gate_full[:, inter:]
|
||||||
|
return _store(
|
||||||
|
_GroupedWeightStorage(
|
||||||
|
pattern=pattern,
|
||||||
|
gate=gate,
|
||||||
|
up=up,
|
||||||
|
down=down,
|
||||||
|
dtype=gate.dtype,
|
||||||
|
device=gate.device,
|
||||||
|
base_gate=gate_full,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(sample_mod, "up_proj")
|
||||||
|
and hasattr(sample_mod, "gate_proj")
|
||||||
|
and hasattr(sample_mod, "down_proj")
|
||||||
|
):
|
||||||
|
pattern = "dual_proj"
|
||||||
|
if (
|
||||||
|
storage is not None
|
||||||
|
and storage.pattern == pattern
|
||||||
|
and storage.dtype == sample_mod.up_proj.weight.dtype
|
||||||
|
and storage.device == sample_mod.up_proj.weight.device
|
||||||
|
):
|
||||||
|
return storage
|
||||||
|
|
||||||
|
num_experts = len(expert_impls)
|
||||||
|
up_weight = sample_mod.up_proj.weight
|
||||||
|
gate_weight = sample_mod.gate_proj.weight
|
||||||
|
down_weight = sample_mod.down_proj.weight
|
||||||
|
up = torch.empty(
|
||||||
|
(num_experts, *up_weight.shape),
|
||||||
|
device=up_weight.device,
|
||||||
|
dtype=up_weight.dtype,
|
||||||
|
)
|
||||||
|
gate = torch.empty(
|
||||||
|
(num_experts, *gate_weight.shape),
|
||||||
|
device=gate_weight.device,
|
||||||
|
dtype=gate_weight.dtype,
|
||||||
|
)
|
||||||
|
down = torch.empty(
|
||||||
|
(num_experts, *down_weight.shape),
|
||||||
|
device=down_weight.device,
|
||||||
|
dtype=down_weight.dtype,
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
for idx, mod in enumerate(expert_impls):
|
||||||
|
up[idx].copy_(mod.up_proj.weight.detach())
|
||||||
|
gate[idx].copy_(mod.gate_proj.weight.detach())
|
||||||
|
down[idx].copy_(mod.down_proj.weight.detach())
|
||||||
|
mod.up_proj.weight.detach_()
|
||||||
|
mod.up_proj.weight.set_(up[idx])
|
||||||
|
mod.gate_proj.weight.detach_()
|
||||||
|
mod.gate_proj.weight.set_(gate[idx])
|
||||||
|
mod.down_proj.weight.detach_()
|
||||||
|
mod.down_proj.weight.set_(down[idx])
|
||||||
|
|
||||||
|
return _store(
|
||||||
|
_GroupedWeightStorage(
|
||||||
|
pattern=pattern,
|
||||||
|
gate=gate,
|
||||||
|
up=up,
|
||||||
|
down=down,
|
||||||
|
dtype=gate.dtype,
|
||||||
|
device=gate.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
"torch_grouped: unsupported expert module layout for grouped weights"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,31 +244,12 @@ def moe_ffn_forward_grouped(
|
|||||||
)
|
)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
for suffix in ("w13", "w2"):
|
|
||||||
attr = f"_ax_grouped_{suffix}"
|
|
||||||
if hasattr(experts_module, attr):
|
|
||||||
delattr(experts_module, attr)
|
|
||||||
|
|
||||||
expert_impls = _iter_expert_impls(experts_module)
|
expert_impls = _iter_expert_impls(experts_module)
|
||||||
sample_mod = expert_impls[0]
|
sample_mod = expert_impls[0]
|
||||||
if (
|
storage = _ensure_grouped_weights(experts_module, expert_impls, sample_mod)
|
||||||
hasattr(sample_mod, "w1")
|
w_gate = storage.gate
|
||||||
and hasattr(sample_mod, "w3")
|
w_up = storage.up
|
||||||
and hasattr(sample_mod, "w2")
|
w2 = storage.down
|
||||||
):
|
|
||||||
w13 = _stack_weights(
|
|
||||||
experts_module, ("w1", "w3"), dtype=expert_dtype, device=device
|
|
||||||
)
|
|
||||||
w2 = _stack_weights(experts_module, ("w2",), dtype=expert_dtype, device=device)
|
|
||||||
else:
|
|
||||||
if hasattr(sample_mod, "gate_up_proj"):
|
|
||||||
names13: Tuple[str, ...] = ("gate_up_proj",)
|
|
||||||
else:
|
|
||||||
names13 = ("up_proj", "gate_proj")
|
|
||||||
w13 = _stack_weights(experts_module, names13, dtype=expert_dtype, device=device)
|
|
||||||
w2 = _stack_weights(
|
|
||||||
experts_module, ("down_proj",), dtype=expert_dtype, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
x_flat = hidden_states.view(tokens, hdim).to(expert_dtype)
|
||||||
router_logits = gate_linear(x_flat.to(routing_dtype))
|
router_logits = gate_linear(x_flat.to(routing_dtype))
|
||||||
@@ -139,10 +289,6 @@ def moe_ffn_forward_grouped(
|
|||||||
zero = torch.zeros_like(x_flat)
|
zero = torch.zeros_like(x_flat)
|
||||||
return zero.view(bsz, seqlen, hdim), router_logits
|
return zero.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
mid = w13.shape[-1] // 2
|
|
||||||
w_gate = w13[..., :mid]
|
|
||||||
w_up = w13[..., mid:]
|
|
||||||
|
|
||||||
w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()
|
w_gate_t = w_gate[active_idx].transpose(-2, -1).contiguous()
|
||||||
w_up_t = w_up[active_idx].transpose(-2, -1).contiguous()
|
w_up_t = w_up[active_idx].transpose(-2, -1).contiguous()
|
||||||
w2_t = w2[active_idx].transpose(-2, -1).contiguous()
|
w2_t = w2[active_idx].transpose(-2, -1).contiguous()
|
||||||
|
|||||||
Reference in New Issue
Block a user