more logs
This commit is contained in:
@@ -103,7 +103,26 @@ def moe_ffn_forward_grouped(
|
|||||||
is_mixtral = (
|
is_mixtral = (
|
||||||
hasattr(first, "w1") and hasattr(first, "w3") and hasattr(first, "w2")
|
hasattr(first, "w1") and hasattr(first, "w3") and hasattr(first, "w2")
|
||||||
)
|
)
|
||||||
is_qwen2 = hasattr(first, "gate_up_proj") and hasattr(first, "down_proj")
|
is_qwen2 = (
|
||||||
|
hasattr(first, "gate_up_proj")
|
||||||
|
or hasattr(first, "gate_proj")
|
||||||
|
or hasattr(first, "up_proj")
|
||||||
|
) and hasattr(first, "down_proj")
|
||||||
|
# try nested mlp/ffn module
|
||||||
|
nested = None
|
||||||
|
if not (is_mixtral or is_qwen2):
|
||||||
|
nested = getattr(first, "mlp", None) or getattr(first, "ffn", None)
|
||||||
|
if nested is not None:
|
||||||
|
is_mixtral = (
|
||||||
|
hasattr(nested, "w1")
|
||||||
|
and hasattr(nested, "w3")
|
||||||
|
and hasattr(nested, "w2")
|
||||||
|
)
|
||||||
|
is_qwen2 = (
|
||||||
|
hasattr(nested, "gate_up_proj")
|
||||||
|
or hasattr(nested, "gate_proj")
|
||||||
|
or hasattr(nested, "up_proj")
|
||||||
|
) and hasattr(nested, "down_proj")
|
||||||
if not (is_mixtral or is_qwen2):
|
if not (is_mixtral or is_qwen2):
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
@@ -142,14 +161,39 @@ def moe_ffn_forward_grouped(
|
|||||||
W13 = experts_module._stacked_w13
|
W13 = experts_module._stacked_w13
|
||||||
W2 = experts_module._stacked_w2
|
W2 = experts_module._stacked_w2
|
||||||
else:
|
else:
|
||||||
# Qwen2-MoE style: gate_up_proj (2I x H), down_proj (H x I)
|
# Qwen2/3 MoE style: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I)
|
||||||
if (
|
if (
|
||||||
not hasattr(experts_module, "_stacked_w13")
|
not hasattr(experts_module, "_stacked_w13")
|
||||||
or experts_module._stacked_w13.device != dev
|
or experts_module._stacked_w13.device != dev
|
||||||
or experts_module._stacked_w13.dtype != dt
|
or experts_module._stacked_w13.dtype != dt
|
||||||
):
|
):
|
||||||
w13 = [experts_module[i].gate_up_proj.weight.t() for i in range(E)]
|
w13 = []
|
||||||
w2 = [experts_module[i].down_proj.weight.t() for i in range(E)]
|
w2 = []
|
||||||
|
for i in range(E):
|
||||||
|
exp = experts_module[i]
|
||||||
|
mod = nested if nested is not None else exp
|
||||||
|
# prefer fused gate_up_proj if present
|
||||||
|
if hasattr(mod, "gate_up_proj"):
|
||||||
|
w13.append(mod.gate_up_proj.weight.t())
|
||||||
|
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
|
||||||
|
# concatenate [up | gate] along N
|
||||||
|
w13.append(
|
||||||
|
torch.cat(
|
||||||
|
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
|
||||||
|
if not getattr(
|
||||||
|
experts_module, "_ax_grouped_logged_fail", False
|
||||||
|
):
|
||||||
|
_LOGGER.warning(
|
||||||
|
"torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_fail = True
|
||||||
|
return None, None
|
||||||
|
w2.append((mod.down_proj.weight.t()))
|
||||||
experts_module._stacked_w13 = (
|
experts_module._stacked_w13 = (
|
||||||
torch.stack(w13, dim=0)
|
torch.stack(w13, dim=0)
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user