logs, qwen2 support
This commit is contained in:
@@ -5,6 +5,7 @@ This is a cautious first pass that probes available ops and only runs when suppo
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -28,6 +29,7 @@ def available() -> bool:
|
|||||||
|
|
||||||
|
|
||||||
LAST_ERROR: Optional[str] = None
|
LAST_ERROR: Optional[str] = None
|
||||||
|
_LOGGER = logging.getLogger("axolotl.moe.grouped")
|
||||||
|
|
||||||
|
|
||||||
def _call_grouped_mm(
|
def _call_grouped_mm(
|
||||||
@@ -94,37 +96,72 @@ def moe_ffn_forward_grouped(
|
|||||||
flat_idx = topk_idx.view(-1)
|
flat_idx = topk_idx.view(-1)
|
||||||
x_rep = x.repeat_interleave(top_k, dim=0)
|
x_rep = x.repeat_interleave(top_k, dim=0)
|
||||||
|
|
||||||
# Cache stacked weights on experts
|
# Cache stacked weights on experts (support Mixtral and Qwen2-MoE layouts)
|
||||||
E = experts_module.num_experts
|
E = experts_module.num_experts
|
||||||
dev, dt = x.device, x.dtype
|
dev, dt = x.device, x.dtype
|
||||||
if (
|
first = experts_module[0]
|
||||||
not hasattr(experts_module, "_stacked_w1")
|
is_mixtral = (
|
||||||
or experts_module._stacked_w1.device != dev
|
hasattr(first, "w1") and hasattr(first, "w3") and hasattr(first, "w2")
|
||||||
or experts_module._stacked_w1.dtype != dt
|
)
|
||||||
):
|
is_qwen2 = hasattr(first, "gate_up_proj") and hasattr(first, "down_proj")
|
||||||
w1 = [experts_module[i].w1.weight.t() for i in range(E)]
|
if not (is_mixtral or is_qwen2):
|
||||||
w3 = [experts_module[i].w3.weight.t() for i in range(E)]
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
w2 = [experts_module[i].w2.weight.t() for i in range(E)]
|
_LOGGER.warning(
|
||||||
experts_module._stacked_w1 = (
|
"torch_grouped: unsupported expert layout; falling back to naive"
|
||||||
torch.stack(w1, dim=0)
|
)
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
experts_module._ax_grouped_logged_fail = True
|
||||||
.contiguous()
|
return None, None
|
||||||
)
|
|
||||||
experts_module._stacked_w3 = (
|
if is_mixtral:
|
||||||
torch.stack(w3, dim=0)
|
if (
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
not hasattr(experts_module, "_stacked_w1")
|
||||||
.contiguous()
|
or experts_module._stacked_w1.device != dev
|
||||||
)
|
or experts_module._stacked_w1.dtype != dt
|
||||||
experts_module._stacked_w2 = (
|
):
|
||||||
torch.stack(w2, dim=0)
|
w1 = [experts_module[i].w1.weight.t() for i in range(E)]
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
w3 = [experts_module[i].w3.weight.t() for i in range(E)]
|
||||||
.contiguous()
|
w2 = [experts_module[i].w2.weight.t() for i in range(E)]
|
||||||
)
|
experts_module._stacked_w1 = (
|
||||||
experts_module._stacked_w13 = torch.cat(
|
torch.stack(w1, dim=0)
|
||||||
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
).contiguous()
|
.contiguous()
|
||||||
W13 = experts_module._stacked_w13
|
)
|
||||||
W2 = experts_module._stacked_w2
|
experts_module._stacked_w3 = (
|
||||||
|
torch.stack(w3, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
experts_module._stacked_w2 = (
|
||||||
|
torch.stack(w2, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
experts_module._stacked_w13 = torch.cat(
|
||||||
|
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
|
||||||
|
).contiguous()
|
||||||
|
W13 = experts_module._stacked_w13
|
||||||
|
W2 = experts_module._stacked_w2
|
||||||
|
else:
|
||||||
|
# Qwen2-MoE style: gate_up_proj (2I x H), down_proj (H x I)
|
||||||
|
if (
|
||||||
|
not hasattr(experts_module, "_stacked_w13")
|
||||||
|
or experts_module._stacked_w13.device != dev
|
||||||
|
or experts_module._stacked_w13.dtype != dt
|
||||||
|
):
|
||||||
|
w13 = [experts_module[i].gate_up_proj.weight.t() for i in range(E)]
|
||||||
|
w2 = [experts_module[i].down_proj.weight.t() for i in range(E)]
|
||||||
|
experts_module._stacked_w13 = (
|
||||||
|
torch.stack(w13, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
experts_module._stacked_w2 = (
|
||||||
|
torch.stack(w2, dim=0)
|
||||||
|
.to(device=dev, dtype=dt, non_blocking=True)
|
||||||
|
.contiguous()
|
||||||
|
)
|
||||||
|
W13 = experts_module._stacked_w13
|
||||||
|
W2 = experts_module._stacked_w2
|
||||||
|
|
||||||
# Grouped GEMM for up+gate
|
# Grouped GEMM for up+gate
|
||||||
As: List[torch.Tensor] = []
|
As: List[torch.Tensor] = []
|
||||||
@@ -145,6 +182,11 @@ def moe_ffn_forward_grouped(
|
|||||||
|
|
||||||
Y_list = _call_grouped_mm(As, Bs)
|
Y_list = _call_grouped_mm(As, Bs)
|
||||||
if Y_list is None:
|
if Y_list is None:
|
||||||
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
|
_LOGGER.warning(
|
||||||
|
f"torch_grouped: grouped_mm up+gate failed; falling back to naive. Reason: {LAST_ERROR}"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_fail = True
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# SwiGLU on each expert block and prepare for down projection
|
# SwiGLU on each expert block and prepare for down projection
|
||||||
@@ -160,12 +202,22 @@ def moe_ffn_forward_grouped(
|
|||||||
|
|
||||||
Y2_list = _call_grouped_mm(As2, Bs2)
|
Y2_list = _call_grouped_mm(As2, Bs2)
|
||||||
if Y2_list is None:
|
if Y2_list is None:
|
||||||
|
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
||||||
|
_LOGGER.warning(
|
||||||
|
f"torch_grouped: grouped_mm down failed; falling back to naive. Reason: {LAST_ERROR}"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_fail = True
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# Write back, apply per-token weighting, and reduce over top_k
|
# Write back, apply per-token weighting, and reduce over top_k
|
||||||
for (i, sel), Out_i in zip(expert_slices, Y2_list):
|
for (i, sel), Out_i in zip(expert_slices, Y2_list):
|
||||||
y_buf[sel] = Out_i
|
y_buf[sel] = Out_i
|
||||||
y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||||
|
if not getattr(experts_module, "_ax_grouped_logged_ok", False):
|
||||||
|
_LOGGER.info(
|
||||||
|
f"torch_grouped: engaged grouped GEMM path (experts={E}, top_k={top_k})"
|
||||||
|
)
|
||||||
|
experts_module._ax_grouped_logged_ok = True
|
||||||
return y.view(bsz, seqlen, hdim), router_logits
|
return y.view(bsz, seqlen, hdim), router_logits
|
||||||
except Exception:
|
except Exception:
|
||||||
return None, None
|
return None, None
|
||||||
|
|||||||
Reference in New Issue
Block a user