minify
This commit is contained in:
@@ -54,7 +54,7 @@ def forward_naive(
|
|||||||
def bench(fn, *args, iters=50, warmup=10, sync=True):
|
def bench(fn, *args, iters=50, warmup=10, sync=True):
|
||||||
# warmup
|
# warmup
|
||||||
for _ in range(warmup):
|
for _ in range(warmup):
|
||||||
out = fn(*args)
|
fn(*args)
|
||||||
if sync and torch.cuda.is_available():
|
if sync and torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
# measure
|
# measure
|
||||||
@@ -63,7 +63,7 @@ def bench(fn, *args, iters=50, warmup=10, sync=True):
|
|||||||
if sync and torch.cuda.is_available():
|
if sync and torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
out = fn(*args)
|
fn(*args)
|
||||||
if sync and torch.cuda.is_available():
|
if sync and torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
dt = (time.perf_counter() - t0) * 1000.0
|
dt = (time.perf_counter() - t0) * 1000.0
|
||||||
@@ -185,12 +185,7 @@ def main():
|
|||||||
f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
|
f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
print("torch_grouped\tN/A (op not callable)")
|
||||||
from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR
|
|
||||||
except Exception:
|
|
||||||
_TG_ERR = None
|
|
||||||
reason = f" (reason: {_TG_ERR})" if _TG_ERR else ""
|
|
||||||
print(f"torch_grouped\tN/A (op not callable){reason}")
|
|
||||||
else:
|
else:
|
||||||
print("torch_grouped\tN/A (unavailable)")
|
print("torch_grouped\tN/A (unavailable)")
|
||||||
|
|
||||||
|
|||||||
@@ -1,12 +1,8 @@
|
|||||||
"""
|
"""Minimal grouped GEMM fast path for MoE experts using PyTorch _grouped_mm."""
|
||||||
PyTorch 2.8+ grouped GEMM MoE path (cuBLASLt-backed).
|
|
||||||
This is a cautious first pass that probes available ops and only runs when supported.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
from typing import List, Optional, Sequence, Tuple
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -14,316 +10,128 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
def available() -> bool:
|
def available() -> bool:
|
||||||
try:
|
try:
|
||||||
ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
|
major, minor = map(int, torch.__version__.split("+")[0].split(".")[:2])
|
||||||
if ver < (2, 8):
|
if (major, minor) < (2, 8):
|
||||||
return False
|
return False
|
||||||
# Require Hopper+ (SM90) per torch error message and check op presence
|
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
return False
|
return False
|
||||||
major, minor = torch.cuda.get_device_capability()
|
sm, _ = torch.cuda.get_device_capability()
|
||||||
if major < 9:
|
if sm < 9:
|
||||||
return False
|
return False
|
||||||
return hasattr(torch.ops, "_grouped_mm")
|
return hasattr(torch.ops, "_grouped_mm")
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
LAST_ERROR: Optional[str] = None
|
def _stack_weights(
|
||||||
_LOGGER = logging.getLogger("axolotl.moe.grouped")
|
experts: Sequence[torch.nn.Module], names: Tuple[str, ...]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
stacked: List[torch.Tensor] = []
|
||||||
def _is_mixtral_layout(mod: torch.nn.Module) -> bool:
|
for expert in experts:
|
||||||
return all(hasattr(mod, attr) for attr in ("w1", "w3", "w2"))
|
mod = getattr(expert, "mlp", getattr(expert, "ffn", expert))
|
||||||
|
parts = [getattr(mod, name).weight.t() for name in names]
|
||||||
|
stacked.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
|
||||||
def _is_qwen_layout(mod: torch.nn.Module) -> bool:
|
return torch.stack(stacked, dim=0)
|
||||||
has_fused = hasattr(mod, "gate_up_proj")
|
|
||||||
has_split = hasattr(mod, "up_proj") and hasattr(mod, "gate_proj")
|
|
||||||
return (has_fused or has_split) and hasattr(mod, "down_proj")
|
|
||||||
|
|
||||||
|
|
||||||
def _num_experts(module: torch.nn.Module) -> int:
|
|
||||||
"""Return expert count, supporting ModuleList-style inputs."""
|
|
||||||
count = getattr(module, "num_experts", None)
|
|
||||||
if count is not None:
|
|
||||||
return int(count() if callable(count) else count)
|
|
||||||
try:
|
|
||||||
return len(module) # type: ignore[arg-type]
|
|
||||||
except TypeError as exc: # pragma: no cover - defensive
|
|
||||||
raise AttributeError("experts module missing num_experts/len support") from exc
|
|
||||||
|
|
||||||
|
|
||||||
def _call_grouped_mm(
|
def _call_grouped_mm(
|
||||||
As: List[torch.Tensor], Bs: List[torch.Tensor]
|
As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype
|
||||||
) -> Optional[List[torch.Tensor]]:
|
) -> Optional[List[torch.Tensor]]:
|
||||||
"""
|
if not As:
|
||||||
Call grouped mm with packed representation
|
|
||||||
|
|
||||||
- A_cat: concat As along rows -> [sum_i Mi, K]
|
|
||||||
- B_stk: stack Bs per group -> [G, K, N]
|
|
||||||
- offs: lengths per group Mi -> [G] int32
|
|
||||||
Returns list of per-group outputs split from concatenated result.
|
|
||||||
"""
|
|
||||||
global LAST_ERROR
|
|
||||||
# Ensure 2D contiguous inputs
|
|
||||||
As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
|
||||||
Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
|
||||||
|
|
||||||
if not As2:
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
|
||||||
|
Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
|
||||||
device = As2[0].device
|
device = As2[0].device
|
||||||
A_cat = torch.cat(As2, dim=0).to(torch.bfloat16)
|
|
||||||
B_stk = torch.stack(Bs2, dim=0).to(torch.bfloat16)
|
|
||||||
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
|
||||||
Y_cat = torch._grouped_mm(A_cat, B_stk, offs) # type: ignore[attr-defined]
|
Y_cat = torch.ops.aten._grouped_mm(
|
||||||
|
torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs
|
||||||
|
)
|
||||||
outs: List[torch.Tensor] = []
|
outs: List[torch.Tensor] = []
|
||||||
start = 0
|
start = 0
|
||||||
for m in offs.tolist():
|
for m in offs.tolist():
|
||||||
outs.append(Y_cat[start : start + m, :])
|
outs.append(Y_cat[start : start + m])
|
||||||
start += m
|
start += m
|
||||||
return outs
|
return outs
|
||||||
|
|
||||||
|
|
||||||
def moe_ffn_forward_grouped(
|
def moe_ffn_forward_grouped(
|
||||||
hidden_states, gate_linear, experts_module, top_k: int
|
hidden_states: torch.Tensor,
|
||||||
|
gate_linear: torch.nn.Linear,
|
||||||
|
experts_module,
|
||||||
|
top_k: int,
|
||||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
"""Attempt grouped GEMM fast path using PyTorch 2.8+."""
|
if not available():
|
||||||
global LAST_ERROR
|
return None, None
|
||||||
LAST_ERROR = None
|
|
||||||
bsz, seqlen, hdim = hidden_states.shape
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
|
tokens = bsz * seqlen
|
||||||
|
device = hidden_states.device
|
||||||
|
|
||||||
routing_dtype = gate_linear.weight.dtype
|
routing_dtype = gate_linear.weight.dtype
|
||||||
use_mixed_router = (
|
expert_dtype = hidden_states.dtype
|
||||||
hidden_states.device.type == "cuda" and routing_dtype == torch.float32
|
x_flat = hidden_states.view(tokens, hdim)
|
||||||
)
|
router_logits = gate_linear(x_flat.to(routing_dtype))
|
||||||
|
|
||||||
x_base = hidden_states.view(-1, hdim)
|
|
||||||
if use_mixed_router:
|
|
||||||
x_router = x_base.to(dtype=routing_dtype)
|
|
||||||
else:
|
|
||||||
x_router = x_base
|
|
||||||
if x_router.dtype != routing_dtype:
|
|
||||||
x_router = x_router.to(dtype=routing_dtype)
|
|
||||||
|
|
||||||
router_logits = gate_linear(x_router)
|
|
||||||
if router_logits.dtype != routing_dtype:
|
|
||||||
router_logits = router_logits.to(dtype=routing_dtype)
|
|
||||||
|
|
||||||
x = x_base
|
|
||||||
|
|
||||||
# top-k routing executed in torch to avoid extra dependencies
|
|
||||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||||
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
|
||||||
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
E = _num_experts(experts_module)
|
experts = list(experts_module)
|
||||||
dev = hidden_states.device
|
sample = getattr(experts[0], "mlp", getattr(experts[0], "ffn", experts[0]))
|
||||||
first = experts_module[0]
|
if hasattr(sample, "w1") and hasattr(sample, "w3") and hasattr(sample, "w2"):
|
||||||
|
w13 = _stack_weights(experts, ("w1", "w3")).to(
|
||||||
is_mixtral = _is_mixtral_layout(first)
|
device=device, dtype=expert_dtype
|
||||||
is_qwen2 = _is_qwen_layout(first)
|
)
|
||||||
nested_attr: Optional[str] = None
|
w2 = _stack_weights(experts, ("w2",)).to(device=device, dtype=expert_dtype)
|
||||||
if not (is_mixtral or is_qwen2):
|
|
||||||
for candidate in ("mlp", "ffn"):
|
|
||||||
nested = getattr(first, candidate, None)
|
|
||||||
if nested is None:
|
|
||||||
continue
|
|
||||||
if _is_mixtral_layout(nested):
|
|
||||||
is_mixtral = True
|
|
||||||
nested_attr = candidate
|
|
||||||
break
|
|
||||||
if _is_qwen_layout(nested):
|
|
||||||
is_qwen2 = True
|
|
||||||
nested_attr = candidate
|
|
||||||
break
|
|
||||||
if not (is_mixtral or is_qwen2):
|
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
|
||||||
_LOGGER.warning(
|
|
||||||
"torch_grouped: unsupported expert layout; falling back to naive"
|
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_fail = True
|
|
||||||
LAST_ERROR = "unsupported expert layout"
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
def _resolve_expert(idx: int) -> torch.nn.Module:
|
|
||||||
expert = experts_module[idx]
|
|
||||||
if nested_attr is None:
|
|
||||||
return expert
|
|
||||||
nested_mod = getattr(expert, nested_attr, None)
|
|
||||||
if nested_mod is None:
|
|
||||||
raise AttributeError(f"expert {idx} missing nested module '{nested_attr}'")
|
|
||||||
return nested_mod
|
|
||||||
|
|
||||||
if is_mixtral:
|
|
||||||
dt: torch.dtype = first.w1.weight.dtype # type: ignore[assignment]
|
|
||||||
if (
|
|
||||||
not hasattr(experts_module, "_stacked_w1")
|
|
||||||
or experts_module._stacked_w1.device != dev
|
|
||||||
or experts_module._stacked_w1.dtype != dt
|
|
||||||
):
|
|
||||||
mods = [_resolve_expert(i) for i in range(E)]
|
|
||||||
w1 = [mod.w1.weight.t() for mod in mods]
|
|
||||||
w3 = [mod.w3.weight.t() for mod in mods]
|
|
||||||
w2 = [mod.w2.weight.t() for mod in mods]
|
|
||||||
experts_module._stacked_w1 = (
|
|
||||||
torch.stack(w1, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w3 = (
|
|
||||||
torch.stack(w3, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w2 = (
|
|
||||||
torch.stack(w2, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w13 = torch.cat(
|
|
||||||
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
|
|
||||||
).contiguous()
|
|
||||||
W13 = experts_module._stacked_w13
|
|
||||||
W2 = experts_module._stacked_w2
|
|
||||||
else:
|
else:
|
||||||
sample_mod = _resolve_expert(0)
|
names13 = (
|
||||||
if hasattr(sample_mod, "gate_up_proj"):
|
("gate_up_proj",)
|
||||||
dt = sample_mod.gate_up_proj.weight.dtype # type: ignore[assignment]
|
if hasattr(sample, "gate_up_proj")
|
||||||
elif hasattr(sample_mod, "up_proj"):
|
else ("up_proj", "gate_proj")
|
||||||
dt = sample_mod.up_proj.weight.dtype # type: ignore[assignment]
|
)
|
||||||
else:
|
w13 = _stack_weights(experts, names13).to(device=device, dtype=expert_dtype)
|
||||||
dt = sample_mod.down_proj.weight.dtype # type: ignore[assignment]
|
w2 = _stack_weights(experts, ("down_proj",)).to(
|
||||||
if (
|
device=device, dtype=expert_dtype
|
||||||
not hasattr(experts_module, "_stacked_w13")
|
)
|
||||||
or experts_module._stacked_w13.device != dev
|
|
||||||
or experts_module._stacked_w13.dtype != dt
|
|
||||||
):
|
|
||||||
w13 = []
|
|
||||||
w2 = []
|
|
||||||
for i in range(E):
|
|
||||||
mod = _resolve_expert(i)
|
|
||||||
if hasattr(mod, "gate_up_proj"):
|
|
||||||
w13.append(mod.gate_up_proj.weight.t())
|
|
||||||
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
|
|
||||||
w13.append(
|
|
||||||
torch.cat(
|
|
||||||
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
|
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
|
||||||
_LOGGER.warning(
|
|
||||||
"torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
|
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_fail = True
|
|
||||||
return None, None
|
|
||||||
w2.append(mod.down_proj.weight.t())
|
|
||||||
experts_module._stacked_w13 = (
|
|
||||||
torch.stack(w13, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
experts_module._stacked_w2 = (
|
|
||||||
torch.stack(w2, dim=0)
|
|
||||||
.to(device=dev, dtype=dt, non_blocking=True)
|
|
||||||
.contiguous()
|
|
||||||
)
|
|
||||||
W13 = experts_module._stacked_w13
|
|
||||||
W2 = experts_module._stacked_w2
|
|
||||||
|
|
||||||
dt = W13.dtype
|
|
||||||
if router_logits.dtype != dt:
|
|
||||||
router_logits = router_logits.to(dtype=dt)
|
|
||||||
if x.dtype != dt:
|
|
||||||
x = x.to(dtype=dt)
|
|
||||||
flat_idx = topk_idx.view(-1)
|
flat_idx = topk_idx.view(-1)
|
||||||
if topk_weight.dtype != dt:
|
x_rep = x_flat.to(expert_dtype).repeat_interleave(top_k, dim=0)
|
||||||
topk_weight = topk_weight.to(dtype=dt)
|
|
||||||
x_rep = x.repeat_interleave(top_k, dim=0)
|
|
||||||
if x_rep.dtype != dt:
|
|
||||||
x_rep = x_rep.to(dtype=dt)
|
|
||||||
|
|
||||||
As: List[torch.Tensor] = []
|
as_list: List[torch.Tensor] = []
|
||||||
Bs: List[torch.Tensor] = []
|
bs_list: List[torch.Tensor] = []
|
||||||
expert_slices: List[Tuple[int, torch.Tensor]] = []
|
slices: List[Tuple[int, torch.Tensor]] = []
|
||||||
for i in range(E):
|
for i in range(len(experts)):
|
||||||
sel = flat_idx == i
|
sel = flat_idx == i
|
||||||
if sel.any():
|
if sel.any():
|
||||||
Xi = x_rep[sel].contiguous()
|
as_list.append(x_rep[sel])
|
||||||
As.append(Xi)
|
bs_list.append(w13[i])
|
||||||
Bs.append(W13[i].reshape(hdim, -1).contiguous())
|
slices.append((i, sel))
|
||||||
expert_slices.append((i, sel))
|
|
||||||
|
|
||||||
if not As:
|
if not as_list:
|
||||||
out = torch.zeros_like(x)
|
return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits
|
||||||
return out.view(bsz, seqlen, hdim), router_logits
|
|
||||||
|
|
||||||
def _run_grouped_mm(
|
up_out = _call_grouped_mm(as_list, bs_list, expert_dtype)
|
||||||
a_tensors: List[torch.Tensor],
|
if up_out is None:
|
||||||
b_tensors: List[torch.Tensor],
|
|
||||||
target_dtype: torch.dtype,
|
|
||||||
) -> Optional[List[torch.Tensor]]:
|
|
||||||
global LAST_ERROR
|
|
||||||
if target_dtype != dt:
|
|
||||||
a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors]
|
|
||||||
b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors]
|
|
||||||
outputs = _call_grouped_mm(a_tensors, b_tensors)
|
|
||||||
if outputs is not None and target_dtype != dt:
|
|
||||||
outputs = [t.to(dt).contiguous() for t in outputs]
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
def _try_grouped_mm(
|
|
||||||
a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor]
|
|
||||||
) -> Tuple[Optional[List[torch.Tensor]], bool]:
|
|
||||||
global LAST_ERROR
|
|
||||||
result = _run_grouped_mm(a_tensors, b_tensors, target_dtype=dt)
|
|
||||||
cast_used_local = False
|
|
||||||
if result is None and dt == torch.bfloat16:
|
|
||||||
result = _run_grouped_mm(a_tensors, b_tensors, target_dtype=torch.float16)
|
|
||||||
if result is not None:
|
|
||||||
cast_used_local = True
|
|
||||||
LAST_ERROR = None
|
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_cast", False):
|
|
||||||
_LOGGER.info(
|
|
||||||
"torch_grouped: grouped_mm casting bfloat16 operands to float16"
|
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_cast = True
|
|
||||||
return result, cast_used_local
|
|
||||||
|
|
||||||
Y_list, _cast_used_up = _try_grouped_mm(As, Bs)
|
|
||||||
if Y_list is None:
|
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
|
||||||
_LOGGER.warning(
|
|
||||||
f"torch_grouped: grouped_mm up+gate failed; falling back to naive. Reason: {LAST_ERROR}"
|
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_fail = True
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
As2: List[torch.Tensor] = []
|
down_inputs: List[torch.Tensor] = []
|
||||||
Bs2: List[torch.Tensor] = []
|
down_weights: List[torch.Tensor] = []
|
||||||
y_buf = torch.empty_like(x_rep, dtype=dt)
|
buf = torch.empty_like(x_rep)
|
||||||
for (i, _sel), Yi in zip(expert_slices, Y_list, strict=False):
|
for (i, _sel), Yi in zip(slices, up_out, strict=False):
|
||||||
I2 = Yi.shape[-1] // 2
|
mid = Yi.shape[-1] // 2
|
||||||
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
|
hidden = F.silu(Yi[:, :mid]) * Yi[:, mid:]
|
||||||
As2.append(Yi_hidden)
|
down_inputs.append(hidden)
|
||||||
Bs2.append(W2[i].reshape(I2, hdim).contiguous())
|
down_weights.append(w2[i])
|
||||||
|
|
||||||
Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2)
|
down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype)
|
||||||
if Y2_list is None:
|
if down_out is None:
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
|
|
||||||
_LOGGER.warning(
|
|
||||||
f"torch_grouped: grouped_mm down failed; falling back to naive. Reason: {LAST_ERROR}"
|
|
||||||
)
|
|
||||||
experts_module._ax_grouped_logged_fail = True
|
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
for (_i, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
|
for (_i, sel), tensor in zip(slices, down_out, strict=False):
|
||||||
y_buf[sel] = Out_i
|
buf[sel] = tensor
|
||||||
y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
|
||||||
if not getattr(experts_module, "_ax_grouped_logged_ok", False):
|
combined = (
|
||||||
_LOGGER.info(
|
buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1)
|
||||||
f"torch_grouped: engaged grouped GEMM path (experts={E}, top_k={top_k})"
|
).sum(dim=1)
|
||||||
)
|
return combined.view(bsz, seqlen, hdim), router_logits
|
||||||
experts_module._ax_grouped_logged_ok = True
|
|
||||||
return y.view(bsz, seqlen, hdim), router_logits
|
|
||||||
|
|||||||
@@ -82,18 +82,10 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
|||||||
# One-time log per block instance indicating whether grouped engaged or fallback occurred
|
# One-time log per block instance indicating whether grouped engaged or fallback occurred
|
||||||
if not getattr(self, "_ax_grouped_wrapper_logged", False):
|
if not getattr(self, "_ax_grouped_wrapper_logged", False):
|
||||||
if y is None:
|
if y is None:
|
||||||
reason = getattr(_tg, "LAST_ERROR", None)
|
_LOG.warning(
|
||||||
if reason:
|
"Grouped wrapper active but fell back to naive for %s",
|
||||||
_LOG.warning(
|
self.__class__.__name__,
|
||||||
"Grouped wrapper fell back to naive for %s (reason=%s)",
|
)
|
||||||
self.__class__.__name__,
|
|
||||||
reason,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_LOG.warning(
|
|
||||||
"Grouped wrapper active but fell back to naive for %s",
|
|
||||||
self.__class__.__name__,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
_LOG.info(
|
_LOG.info(
|
||||||
f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
|
f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})"
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
StringConstraints,
|
StringConstraints,
|
||||||
field_serializer,
|
field_serializer,
|
||||||
field_validator,
|
|
||||||
model_validator,
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ def test_grouped_uses_per_expert_nested_modules(monkeypatch):
|
|||||||
|
|
||||||
captured = []
|
captured = []
|
||||||
|
|
||||||
def fake_grouped_mm(As, Bs):
|
def fake_grouped_mm(As, Bs, dtype):
|
||||||
captured.append([b.detach().clone() for b in Bs])
|
captured.append([b.detach().clone() for b in Bs])
|
||||||
return [
|
return [
|
||||||
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
|
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
|
||||||
@@ -111,7 +111,7 @@ def test_grouped_accepts_module_list_experts(monkeypatch):
|
|||||||
|
|
||||||
calls = {"count": 0}
|
calls = {"count": 0}
|
||||||
|
|
||||||
def fake_grouped_mm(As, Bs):
|
def fake_grouped_mm(As, Bs, dtype):
|
||||||
calls["count"] += 1
|
calls["count"] += 1
|
||||||
return [
|
return [
|
||||||
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
|
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
|
||||||
|
|||||||
Reference in New Issue
Block a user