logs
This commit is contained in:
@@ -88,205 +88,189 @@ def _call_grouped_mm(
|
|||||||
def moe_ffn_forward_grouped(
|
def moe_ffn_forward_grouped(
|
||||||
hidden_states, gate_linear, experts_module, top_k: int
|
hidden_states, gate_linear, experts_module, top_k: int
|
||||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
"""
|
"""Attempt grouped GEMM fast path using PyTorch 2.8+."""
|
||||||
Attempt a grouped GEMM fast path using PyTorch 2.8+.
|
global LAST_ERROR
|
||||||
If unavailable or fails, returns (None, None) so caller can fallback.
|
LAST_ERROR = None
|
||||||
"""
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
try:
|
x = hidden_states.view(-1, hdim)
|
||||||
bsz, seqlen, hdim = hidden_states.shape
|
router_logits = gate_linear(x)
|
||||||
x = hidden_states.view(-1, hdim)
|
|
||||||
router_logits = gate_linear(x)
|
|
||||||
|
|
||||||
# topk routing in torch (keep simple to avoid dependency cycles)
|
# top-k routing executed in torch to avoid extra dependencies
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)
|
topk_weight = (topk_weight / topk_weight.sum(dim=-1, keepdim=True)).to(x.dtype)
|
||||||
|
|
||||||
# Build per-expert input lists
|
flat_idx = topk_idx.view(-1)
|
||||||
flat_idx = topk_idx.view(-1)
|
x_rep = x.repeat_interleave(top_k, dim=0)
|
||||||
x_rep = x.repeat_interleave(top_k, dim=0)
|
|
||||||
|
|
||||||
# Cache stacked weights on experts (support Mixtral and Qwen-style layouts)
|
E = experts_module.num_experts
|
||||||
E = experts_module.num_experts
|
dev, dt = x.device, x.dtype
|
||||||
dev, dt = x.device, x.dtype
|
first = experts_module[0]
|
||||||
first = experts_module[0]
|
|
||||||
|
|
||||||
is_mixtral = _is_mixtral_layout(first)
|
is_mixtral = _is_mixtral_layout(first)
|
||||||
is_qwen2 = _is_qwen_layout(first)
|
is_qwen2 = _is_qwen_layout(first)
|
||||||
nested_attr: Optional[str] = None
|
nested_attr: Optional[str] = None
|
||||||
if not (is_mixtral or is_qwen2):
|
if not (is_mixtral or is_qwen2):
|
||||||
for candidate in ("mlp", "ffn"):
|
for candidate in ("mlp", "ffn"):
|
||||||
nested = getattr(first, candidate, None)
|
nested = getattr(first, candidate, None)
|
||||||
if nested is None:
|
if nested is None:
|
||||||
continue
|
continue
|
||||||
if _is_mixtral_layout(nested):
|
if _is_mixtral_layout(nested):
|
||||||
is_mixtral = True
|
is_mixtral = True
|
||||||
nested_attr = candidate
|
nested_attr = candidate
|
||||||
break
|
break
|
||||||
if _is_qwen_layout(nested):
|
if _is_qwen_layout(nested):
|
||||||
is_qwen2 = True
|
is_qwen2 = True
|
||||||
nested_attr = candidate
|
nested_attr = candidate
|
||||||
break
|
break
|
||||||
if not (is_mixtral or is_qwen2):
|
if not (is_mixtral or is_qwen2):
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"torch_grouped: unsupported expert layout; falling back to naive"
|
"torch_grouped: unsupported expert layout; falling back to naive"
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_fail = True
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
def _resolve_expert(idx: int):
|
|
||||||
expert = experts_module[idx]
|
|
||||||
if nested_attr is None:
|
|
||||||
return expert
|
|
||||||
nested_mod = getattr(expert, nested_attr, None)
|
|
||||||
if nested_mod is None:
|
|
||||||
raise AttributeError(
|
|
||||||
f"expert {idx} missing nested module '{nested_attr}'"
|
|
||||||
)
|
|
||||||
return nested_mod
|
|
||||||
|
|
||||||
try:
|
|
||||||
if is_mixtral:
|
|
||||||
if (
|
|
||||||
not hasattr(experts_module, "_stacked_w1")
|
|
||||||
or experts_module._stacked_w1.device != dev
|
|
||||||
or experts_module._stacked_w1.dtype != dt
|
|
||||||
):
|
|
||||||
mods = [_resolve_expert(i) for i in range(E)]
|
|
||||||
w1 = [mod.w1.weight.t() for mod in mods]
|
|
||||||
w3 = [mod.w3.weight.t() for mod in mods]
|
|
||||||
w2 = [mod.w2.weight.t() for mod in mods]
|
|
||||||
experts_module._stacked_w1 = (
|
|
||||||
torch.stack(w1, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w3 = (
|
|
||||||
torch.stack(w3, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w2 = (
|
|
||||||
torch.stack(w2, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w13 = torch.cat(
|
|
||||||
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
|
|
||||||
).contiguous()
|
|
||||||
W13 = experts_module._stacked_w13
|
|
||||||
W2 = experts_module._stacked_w2
|
|
||||||
else:
|
|
||||||
# Qwen-style MoE: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I)
|
|
||||||
if (
|
|
||||||
not hasattr(experts_module, "_stacked_w13")
|
|
||||||
or experts_module._stacked_w13.device != dev
|
|
||||||
or experts_module._stacked_w13.dtype != dt
|
|
||||||
):
|
|
||||||
w13 = []
|
|
||||||
w2 = []
|
|
||||||
for i in range(E):
|
|
||||||
mod = _resolve_expert(i)
|
|
||||||
# prefer fused gate_up_proj if present
|
|
||||||
if hasattr(mod, "gate_up_proj"):
|
|
||||||
w13.append(mod.gate_up_proj.weight.t())
|
|
||||||
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
|
|
||||||
# concatenate [up | gate] along N
|
|
||||||
w13.append(
|
|
||||||
torch.cat(
|
|
||||||
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
|
|
||||||
if not getattr(
|
|
||||||
experts_module, "_ax_grouped_logged_fail", False
|
|
||||||
):
|
|
||||||
_LOGGER.warning(
|
|
||||||
"torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
|
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_fail = True
|
|
||||||
return None, None
|
|
||||||
w2.append(mod.down_proj.weight.t())
|
|
||||||
experts_module._stacked_w13 = (
|
|
||||||
torch.stack(w13, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w2 = (
|
|
||||||
torch.stack(w2, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
W13 = experts_module._stacked_w13
|
|
||||||
W2 = experts_module._stacked_w2
|
|
||||||
except AttributeError as err:
|
|
||||||
LAST_ERROR = str(err)
|
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
|
||||||
_LOGGER.warning(
|
|
||||||
"torch_grouped: expert weights missing expected attributes; falling back to naive"
|
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_fail = True
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# Grouped GEMM for up+gate
|
|
||||||
As: List[torch.Tensor] = []
|
|
||||||
Bs: List[torch.Tensor] = []
|
|
||||||
expert_slices = []
|
|
||||||
for i in range(E):
|
|
||||||
sel = flat_idx == i
|
|
||||||
if sel.any():
|
|
||||||
Xi = x_rep[sel]
|
|
||||||
As.append(Xi)
|
|
||||||
Bs.append(W13[i])
|
|
||||||
expert_slices.append((i, sel))
|
|
||||||
|
|
||||||
if not As:
|
|
||||||
# no tokens routed — edge case
|
|
||||||
out = torch.zeros_like(x)
|
|
||||||
return out.view(bsz, seqlen, hdim), router_logits
|
|
||||||
|
|
||||||
Y_list = _call_grouped_mm(As, Bs)
|
|
||||||
if Y_list is None:
|
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
|
||||||
_LOGGER.warning(
|
|
||||||
f"torch_grouped: grouped_mm up+gate failed; falling back to naive. Reason: {LAST_ERROR}"
|
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_fail = True
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# SwiGLU on each expert block and prepare for down projection
|
|
||||||
As2: List[torch.Tensor] = []
|
|
||||||
Bs2: List[torch.Tensor] = []
|
|
||||||
y_buf = torch.empty_like(x_rep)
|
|
||||||
|
|
||||||
# split Y into (I, I)
|
|
||||||
for Yi in Y_list:
|
|
||||||
I2 = Yi.shape[-1] // 2
|
|
||||||
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
|
|
||||||
As2.append(Yi_hidden)
|
|
||||||
Bs2.append(W2[i])
|
|
||||||
|
|
||||||
Y2_list = _call_grouped_mm(As2, Bs2)
|
|
||||||
if Y2_list is None:
|
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
|
||||||
_LOGGER.warning(
|
|
||||||
f"torch_grouped: grouped_mm down failed; falling back to naive. Reason: {LAST_ERROR}"
|
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_fail = True
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# Write back, apply per-token weighting, and reduce over top_k
|
|
||||||
for (_, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
|
|
||||||
y_buf[sel] = Out_i
|
|
||||||
y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_ok", False):
|
|
||||||
_LOGGER.info(
|
|
||||||
f"torch_grouped: engaged grouped GEMM path (experts={E}, top_k={top_k})"
|
|
||||||
)
|
)
|
||||||
experts_module._ax_grouped_logged_ok = True
|
experts_module._ax_grouped_logged_fail = True
|
||||||
return y.view(bsz, seqlen, hdim), router_logits
|
LAST_ERROR = "unsupported expert layout"
|
||||||
except Exception:
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
def _resolve_expert(idx: int):
|
||||||
|
expert = experts_module[idx]
|
||||||
|
if nested_attr is None:
|
||||||
|
return expert
|
||||||
|
nested_mod = getattr(expert, nested_attr, None)
|
||||||
|
if nested_mod is None:
|
||||||
|
raise AttributeError(f"expert {idx} missing nested module '{nested_attr}'")
|
||||||
|
return nested_mod
|
||||||
|
|
||||||
|
try:
|
||||||
|
if is_mixtral:
|
||||||
|
if (
|
||||||
|
not hasattr(experts_module, "_stacked_w1")
|
||||||
|
or experts_module._stacked_w1.device != dev
|
||||||
|
or experts_module._stacked_w1.dtype != dt
|
||||||
|
):
|
||||||
|
mods = [_resolve_expert(i) for i in range(E)]
|
||||||
|
w1 = [mod.w1.weight.t() for mod in mods]
|
||||||
|
w3 = [mod.w3.weight.t() for mod in mods]
|
||||||
|
w2 = [mod.w2.weight.t() for mod in mods]
|
||||||
|
experts_module._stacked_w1 = (
|
||||||
|
torch.stack(w1, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
experts_module._stacked_w3 = (
|
||||||
|
torch.stack(w3, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
experts_module._stacked_w2 = (
|
||||||
|
torch.stack(w2, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
experts_module._stacked_w13 = torch.cat(
|
||||||
|
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
|
||||||
|
).contiguous()
|
||||||
|
W13 = experts_module._stacked_w13
|
||||||
|
W2 = experts_module._stacked_w2
|
||||||
|
else:
|
||||||
|
if (
|
||||||
|
not hasattr(experts_module, "_stacked_w13")
|
||||||
|
or experts_module._stacked_w13.device != dev
|
||||||
|
or experts_module._stacked_w13.dtype != dt
|
||||||
|
):
|
||||||
|
w13 = []
|
||||||
|
w2 = []
|
||||||
|
for i in range(E):
|
||||||
|
mod = _resolve_expert(i)
|
||||||
|
if hasattr(mod, "gate_up_proj"):
|
||||||
|
w13.append(mod.gate_up_proj.weight.t())
|
||||||
|
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
|
||||||
|
w13.append(
|
||||||
|
torch.cat(
|
||||||
|
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
|
||||||
|
if not getattr(
|
||||||
|
experts_module, "_ax_grouped_logged_fail", False
|
||||||
|
):
|
||||||
|
_LOGGER.warning(
|
||||||
|
"torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_fail = True
|
||||||
|
return None, None
|
||||||
|
w2.append(mod.down_proj.weight.t())
|
||||||
|
experts_module._stacked_w13 = (
|
||||||
|
torch.stack(w13, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
experts_module._stacked_w2 = (
|
||||||
|
torch.stack(w2, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
W13 = experts_module._stacked_w13
|
||||||
|
W2 = experts_module._stacked_w2
|
||||||
|
except AttributeError as err:
|
||||||
|
LAST_ERROR = str(err)
|
||||||
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
|
_LOGGER.warning(
|
||||||
|
"torch_grouped: expert weights missing expected attributes; falling back to naive"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_fail = True
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
As: List[torch.Tensor] = []
|
||||||
|
Bs: List[torch.Tensor] = []
|
||||||
|
expert_slices: List[Tuple[int, torch.Tensor]] = []
|
||||||
|
for i in range(E):
|
||||||
|
sel = flat_idx == i
|
||||||
|
if sel.any():
|
||||||
|
Xi = x_rep[sel]
|
||||||
|
As.append(Xi)
|
||||||
|
Bs.append(W13[i])
|
||||||
|
expert_slices.append((i, sel))
|
||||||
|
|
||||||
|
if not As:
|
||||||
|
out = torch.zeros_like(x)
|
||||||
|
return out.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|
||||||
|
Y_list = _call_grouped_mm(As, Bs)
|
||||||
|
if Y_list is None:
|
||||||
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
|
_LOGGER.warning(
|
||||||
|
f"torch_grouped: grouped_mm up+gate failed; falling back to naive. Reason: {LAST_ERROR}"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_fail = True
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
As2: List[torch.Tensor] = []
|
||||||
|
Bs2: List[torch.Tensor] = []
|
||||||
|
y_buf = torch.empty_like(x_rep)
|
||||||
|
for (i, _sel), Yi in zip(expert_slices, Y_list, strict=False):
|
||||||
|
I2 = Yi.shape[-1] // 2
|
||||||
|
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
|
||||||
|
As2.append(Yi_hidden)
|
||||||
|
Bs2.append(W2[i])
|
||||||
|
|
||||||
|
Y2_list = _call_grouped_mm(As2, Bs2)
|
||||||
|
if Y2_list is None:
|
||||||
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
|
_LOGGER.warning(
|
||||||
|
f"torch_grouped: grouped_mm down failed; falling back to naive. Reason: {LAST_ERROR}"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_fail = True
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
for (_i, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
|
||||||
|
y_buf[sel] = Out_i
|
||||||
|
y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||||
|
if not getattr(experts_module, "_ax_grouped_logged_ok", False):
|
||||||
|
_LOGGER.info(
|
||||||
|
f"torch_grouped: engaged grouped GEMM path (experts={E}, top_k={top_k})"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_ok = True
|
||||||
|
return y.view(bsz, seqlen, hdim), router_logits
|
||||||
|
|||||||
@@ -82,9 +82,18 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
|||||||
# One-time log per block instance indicating whether grouped engaged or fallback occurred
|
# One-time log per block instance indicating whether grouped engaged or fallback occurred
|
||||||
if not getattr(self, "_ax_grouped_wrapper_logged", False):
|
if not getattr(self, "_ax_grouped_wrapper_logged", False):
|
||||||
if y is None:
|
if y is None:
|
||||||
_LOG.warning(
|
reason = getattr(_tg, "LAST_ERROR", None)
|
||||||
f"Grouped wrapper active but fell back to naive for {self.__class__.__name__}"
|
if reason:
|
||||||
)
|
_LOG.warning(
|
||||||
|
"Grouped wrapper fell back to naive for %s (reason=%s)",
|
||||||
|
self.__class__.__name__,
|
||||||
|
reason,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_LOG.warning(
|
||||||
|
"Grouped wrapper active but fell back to naive for %s",
|
||||||
|
self.__class__.__name__,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
_LOG.info(
|
_LOG.info(
|
||||||
f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
|
f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
|
||||||
|
|||||||
Reference in New Issue
Block a user