refactor + fix
This commit is contained in:
@@ -32,6 +32,16 @@ LAST_ERROR: Optional[str] = None
|
||||
_LOGGER = logging.getLogger("axolotl.moe.grouped")
|
||||
|
||||
|
||||
def _is_mixtral_layout(mod: torch.nn.Module) -> bool:
|
||||
return all(hasattr(mod, attr) for attr in ("w1", "w3", "w2"))
|
||||
|
||||
|
||||
def _is_qwen_layout(mod: torch.nn.Module) -> bool:
|
||||
has_fused = hasattr(mod, "gate_up_proj")
|
||||
has_split = hasattr(mod, "up_proj") and hasattr(mod, "gate_proj")
|
||||
return (has_fused or has_split) and hasattr(mod, "down_proj")
|
||||
|
||||
|
||||
def _call_grouped_mm(
|
||||
As: List[torch.Tensor], Bs: List[torch.Tensor]
|
||||
) -> Optional[List[torch.Tensor]]:
|
||||
@@ -96,33 +106,27 @@ def moe_ffn_forward_grouped(
|
||||
flat_idx = topk_idx.view(-1)
|
||||
x_rep = x.repeat_interleave(top_k, dim=0)
|
||||
|
||||
# Cache stacked weights on experts (support Mixtral and Qwen2-MoE layouts)
|
||||
# Cache stacked weights on experts (support Mixtral and Qwen-style layouts)
|
||||
E = experts_module.num_experts
|
||||
dev, dt = x.device, x.dtype
|
||||
first = experts_module[0]
|
||||
is_mixtral = (
|
||||
hasattr(first, "w1") and hasattr(first, "w3") and hasattr(first, "w2")
|
||||
)
|
||||
is_qwen2 = (
|
||||
hasattr(first, "gate_up_proj")
|
||||
or hasattr(first, "gate_proj")
|
||||
or hasattr(first, "up_proj")
|
||||
) and hasattr(first, "down_proj")
|
||||
# try nested mlp/ffn module
|
||||
nested = None
|
||||
|
||||
is_mixtral = _is_mixtral_layout(first)
|
||||
is_qwen2 = _is_qwen_layout(first)
|
||||
nested_attr: Optional[str] = None
|
||||
if not (is_mixtral or is_qwen2):
|
||||
nested = getattr(first, "mlp", None) or getattr(first, "ffn", None)
|
||||
if nested is not None:
|
||||
is_mixtral = (
|
||||
hasattr(nested, "w1")
|
||||
and hasattr(nested, "w3")
|
||||
and hasattr(nested, "w2")
|
||||
)
|
||||
is_qwen2 = (
|
||||
hasattr(nested, "gate_up_proj")
|
||||
or hasattr(nested, "gate_proj")
|
||||
or hasattr(nested, "up_proj")
|
||||
) and hasattr(nested, "down_proj")
|
||||
for candidate in ("mlp", "ffn"):
|
||||
nested = getattr(first, candidate, None)
|
||||
if nested is None:
|
||||
continue
|
||||
if _is_mixtral_layout(nested):
|
||||
is_mixtral = True
|
||||
nested_attr = candidate
|
||||
break
|
||||
if _is_qwen_layout(nested):
|
||||
is_qwen2 = True
|
||||
nested_attr = candidate
|
||||
break
|
||||
if not (is_mixtral or is_qwen2):
|
||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||
_LOGGER.warning(
|
||||
@@ -131,81 +135,101 @@ def moe_ffn_forward_grouped(
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
|
||||
if is_mixtral:
|
||||
if (
|
||||
not hasattr(experts_module, "_stacked_w1")
|
||||
or experts_module._stacked_w1.device != dev
|
||||
or experts_module._stacked_w1.dtype != dt
|
||||
):
|
||||
w1 = [experts_module[i].w1.weight.t() for i in range(E)]
|
||||
w3 = [experts_module[i].w3.weight.t() for i in range(E)]
|
||||
w2 = [experts_module[i].w2.weight.t() for i in range(E)]
|
||||
experts_module._stacked_w1 = (
|
||||
torch.stack(w1, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
def _resolve_expert(idx: int):
|
||||
expert = experts_module[idx]
|
||||
if nested_attr is None:
|
||||
return expert
|
||||
nested_mod = getattr(expert, nested_attr, None)
|
||||
if nested_mod is None:
|
||||
raise AttributeError(
|
||||
f"expert {idx} missing nested module '{nested_attr}'"
|
||||
)
|
||||
experts_module._stacked_w3 = (
|
||||
torch.stack(w3, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w2 = (
|
||||
torch.stack(w2, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w13 = torch.cat(
|
||||
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
|
||||
).contiguous()
|
||||
W13 = experts_module._stacked_w13
|
||||
W2 = experts_module._stacked_w2
|
||||
else:
|
||||
# Qwen2/3 MoE style: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I)
|
||||
if (
|
||||
not hasattr(experts_module, "_stacked_w13")
|
||||
or experts_module._stacked_w13.device != dev
|
||||
or experts_module._stacked_w13.dtype != dt
|
||||
):
|
||||
w13 = []
|
||||
w2 = []
|
||||
for i in range(E):
|
||||
exp = experts_module[i]
|
||||
mod = nested if nested is not None else exp
|
||||
# prefer fused gate_up_proj if present
|
||||
if hasattr(mod, "gate_up_proj"):
|
||||
w13.append(mod.gate_up_proj.weight.t())
|
||||
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
|
||||
# concatenate [up | gate] along N
|
||||
w13.append(
|
||||
torch.cat(
|
||||
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
|
||||
dim=-1,
|
||||
return nested_mod
|
||||
|
||||
try:
|
||||
if is_mixtral:
|
||||
if (
|
||||
not hasattr(experts_module, "_stacked_w1")
|
||||
or experts_module._stacked_w1.device != dev
|
||||
or experts_module._stacked_w1.dtype != dt
|
||||
):
|
||||
mods = [_resolve_expert(i) for i in range(E)]
|
||||
w1 = [mod.w1.weight.t() for mod in mods]
|
||||
w3 = [mod.w3.weight.t() for mod in mods]
|
||||
w2 = [mod.w2.weight.t() for mod in mods]
|
||||
experts_module._stacked_w1 = (
|
||||
torch.stack(w1, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w3 = (
|
||||
torch.stack(w3, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w2 = (
|
||||
torch.stack(w2, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w13 = torch.cat(
|
||||
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
|
||||
).contiguous()
|
||||
W13 = experts_module._stacked_w13
|
||||
W2 = experts_module._stacked_w2
|
||||
else:
|
||||
# Qwen-style MoE: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I)
|
||||
if (
|
||||
not hasattr(experts_module, "_stacked_w13")
|
||||
or experts_module._stacked_w13.device != dev
|
||||
or experts_module._stacked_w13.dtype != dt
|
||||
):
|
||||
w13 = []
|
||||
w2 = []
|
||||
for i in range(E):
|
||||
mod = _resolve_expert(i)
|
||||
# prefer fused gate_up_proj if present
|
||||
if hasattr(mod, "gate_up_proj"):
|
||||
w13.append(mod.gate_up_proj.weight.t())
|
||||
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
|
||||
# concatenate [up | gate] along N
|
||||
w13.append(
|
||||
torch.cat(
|
||||
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
|
||||
if not getattr(
|
||||
experts_module, "_ax_grouped_logged_fail", False
|
||||
):
|
||||
_LOGGER.warning(
|
||||
"torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
|
||||
)
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
w2.append((mod.down_proj.weight.t()))
|
||||
experts_module._stacked_w13 = (
|
||||
torch.stack(w13, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
else:
|
||||
LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
|
||||
if not getattr(
|
||||
experts_module, "_ax_grouped_logged_fail", False
|
||||
):
|
||||
_LOGGER.warning(
|
||||
"torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
|
||||
)
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
w2.append(mod.down_proj.weight.t())
|
||||
experts_module._stacked_w13 = (
|
||||
torch.stack(w13, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
experts_module._stacked_w2 = (
|
||||
torch.stack(w2, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
W13 = experts_module._stacked_w13
|
||||
W2 = experts_module._stacked_w2
|
||||
except AttributeError as err:
|
||||
LAST_ERROR = str(err)
|
||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||
_LOGGER.warning(
|
||||
"torch_grouped: expert weights missing expected attributes; falling back to naive"
|
||||
)
|
||||
experts_module._stacked_w2 = (
|
||||
torch.stack(w2, dim=0)
|
||||
.to(device=dev, dtype=dt, non_blocking=True)
|
||||
.contiguous()
|
||||
)
|
||||
W13 = experts_module._stacked_w13
|
||||
W2 = experts_module._stacked_w2
|
||||
experts_module._ax_grouped_logged_fail = True
|
||||
return None, None
|
||||
|
||||
# Grouped GEMM for up+gate
|
||||
As: List[torch.Tensor] = []
|
||||
@@ -237,8 +261,9 @@ def moe_ffn_forward_grouped(
|
||||
As2: List[torch.Tensor] = []
|
||||
Bs2: List[torch.Tensor] = []
|
||||
y_buf = torch.empty_like(x_rep)
|
||||
|
||||
# split Y into (I, I)
|
||||
for (i, sel), Yi in zip(expert_slices, Y_list):
|
||||
for Yi in Y_list:
|
||||
I2 = Yi.shape[-1] // 2
|
||||
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
|
||||
As2.append(Yi_hidden)
|
||||
@@ -254,7 +279,7 @@ def moe_ffn_forward_grouped(
|
||||
return None, None
|
||||
|
||||
# Write back, apply per-token weighting, and reduce over top_k
|
||||
for (i, sel), Out_i in zip(expert_slices, Y2_list):
|
||||
for (_, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
|
||||
y_buf[sel] = Out_i
|
||||
y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||
if not getattr(experts_module, "_ax_grouped_logged_ok", False):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
import warnings
|
||||
from functools import wraps
|
||||
|
||||
import torch
|
||||
|
||||
@@ -11,7 +11,7 @@ _LOG = logging.getLogger("axolotl.moe.patch")
|
||||
|
||||
def _patch_block_forward(block_cls, grouped_fn):
|
||||
"""Replace block_cls.forward with grouped_fn preserving signature."""
|
||||
setattr(block_cls, "forward", grouped_fn)
|
||||
block_cls.forward = grouped_fn
|
||||
|
||||
|
||||
def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
||||
@@ -73,7 +73,8 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
||||
}
|
||||
|
||||
def make_grouped_forward(orig_forward):
|
||||
def _grouped_forward(self, hidden_states: torch.Tensor):
|
||||
@wraps(orig_forward)
|
||||
def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
|
||||
bsz, seqlen, hdim = hidden_states.shape
|
||||
y, router_logits = _tg.moe_ffn_forward_grouped(
|
||||
hidden_states, self.gate, self.experts, self.top_k
|
||||
@@ -90,7 +91,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
||||
)
|
||||
self._ax_grouped_wrapper_logged = True
|
||||
if y is None:
|
||||
return orig_forward(self, hidden_states)
|
||||
return orig_forward(self, hidden_states, *args, **kwargs)
|
||||
return y, router_logits
|
||||
|
||||
return _grouped_forward
|
||||
|
||||
Reference in New Issue
Block a user