From fd87eed501a3d17f74df1efa570302cc7eb981bb Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 17 Sep 2025 16:42:35 -0400 Subject: [PATCH] minify --- scripts/bench_moe.py | 11 +- src/axolotl/kernels/moe/torch_grouped.py | 356 ++++++----------------- src/axolotl/monkeypatch/moe_grouped.py | 16 +- src/axolotl/utils/schemas/config.py | 1 - tests/monkeypatch/test_moe_grouped.py | 4 +- 5 files changed, 91 insertions(+), 297 deletions(-) diff --git a/scripts/bench_moe.py b/scripts/bench_moe.py index 287a9cbe3..49f54a23b 100644 --- a/scripts/bench_moe.py +++ b/scripts/bench_moe.py @@ -54,7 +54,7 @@ def forward_naive( def bench(fn, *args, iters=50, warmup=10, sync=True): # warmup for _ in range(warmup): - out = fn(*args) + fn(*args) if sync and torch.cuda.is_available(): torch.cuda.synchronize() # measure @@ -63,7 +63,7 @@ def bench(fn, *args, iters=50, warmup=10, sync=True): if sync and torch.cuda.is_available(): torch.cuda.synchronize() t0 = time.perf_counter() - out = fn(*args) + fn(*args) if sync and torch.cuda.is_available(): torch.cuda.synchronize() dt = (time.perf_counter() - t0) * 1000.0 @@ -185,12 +185,7 @@ def main(): f"torch_grouped_check: max_abs={max_abs:.3e} mean_abs={mean_abs:.3e} rel_l2={rel_l2:.3e}" ) else: - try: - from axolotl.kernels.moe.torch_grouped import LAST_ERROR as _TG_ERR - except Exception: - _TG_ERR = None - reason = f" (reason: {_TG_ERR})" if _TG_ERR else "" - print(f"torch_grouped\tN/A (op not callable){reason}") + print("torch_grouped\tN/A (op not callable)") else: print("torch_grouped\tN/A (unavailable)") diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py index 2e947c83f..f712eca37 100644 --- a/src/axolotl/kernels/moe/torch_grouped.py +++ b/src/axolotl/kernels/moe/torch_grouped.py @@ -1,12 +1,8 @@ -""" -PyTorch 2.8+ grouped GEMM MoE path (cuBLASLt-backed). -This is a cautious first pass that probes available ops and only runs when supported. 
-""" +"""Minimal grouped GEMM fast path for MoE experts using PyTorch _grouped_mm.""" from __future__ import annotations -import logging -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple import torch import torch.nn.functional as F @@ -14,316 +10,128 @@ import torch.nn.functional as F def available() -> bool: try: - ver = tuple(int(x) for x in torch.__version__.split("+")[0].split(".")[:2]) - if ver < (2, 8): + major, minor = map(int, torch.__version__.split("+")[0].split(".")[:2]) + if (major, minor) < (2, 8): return False - # Require Hopper+ (SM90) per torch error message and check op presence if not torch.cuda.is_available(): return False - major, minor = torch.cuda.get_device_capability() - if major < 9: + sm, _ = torch.cuda.get_device_capability() + if sm < 9: return False return hasattr(torch.ops, "_grouped_mm") except Exception: return False -LAST_ERROR: Optional[str] = None -_LOGGER = logging.getLogger("axolotl.moe.grouped") - - -def _is_mixtral_layout(mod: torch.nn.Module) -> bool: - return all(hasattr(mod, attr) for attr in ("w1", "w3", "w2")) - - -def _is_qwen_layout(mod: torch.nn.Module) -> bool: - has_fused = hasattr(mod, "gate_up_proj") - has_split = hasattr(mod, "up_proj") and hasattr(mod, "gate_proj") - return (has_fused or has_split) and hasattr(mod, "down_proj") - - -def _num_experts(module: torch.nn.Module) -> int: - """Return expert count, supporting ModuleList-style inputs.""" - count = getattr(module, "num_experts", None) - if count is not None: - return int(count() if callable(count) else count) - try: - return len(module) # type: ignore[arg-type] - except TypeError as exc: # pragma: no cover - defensive - raise AttributeError("experts module missing num_experts/len support") from exc +def _stack_weights( + experts: Sequence[torch.nn.Module], names: Tuple[str, ...] +) -> torch.Tensor: + stacked: List[torch.Tensor] = [] + for expert in experts: + mod = getattr(expert, "mlp", getattr(expert, "ffn", expert)) + parts = [getattr(mod, name).weight.t() for name in names] + stacked.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1)) + return torch.stack(stacked, dim=0) def _call_grouped_mm( - As: List[torch.Tensor], Bs: List[torch.Tensor] + As: List[torch.Tensor], Bs: List[torch.Tensor], dtype: torch.dtype ) -> Optional[List[torch.Tensor]]: - """ - Call grouped mm with packed representation - - - A_cat: concat As along rows -> [sum_i Mi, K] - - B_stk: stack Bs per group -> [G, K, N] - - offs: lengths per group Mi -> [G] int32 - Returns list of per-group outputs split from concatenated result. 
- """ - global LAST_ERROR - # Ensure 2D contiguous inputs - As2 = [a.contiguous().view(a.shape[0], a.shape[1]) for a in As] - Bs2 = [b.contiguous().view(b.shape[0], b.shape[1]) for b in Bs] - - if not As2: + if not As: return [] + + As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As] + Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs] device = As2[0].device - A_cat = torch.cat(As2, dim=0).to(torch.bfloat16) - B_stk = torch.stack(Bs2, dim=0).to(torch.bfloat16) offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32) - Y_cat = torch._grouped_mm(A_cat, B_stk, offs) # type: ignore[attr-defined] + Y_cat = torch.ops.aten._grouped_mm( + torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs + ) outs: List[torch.Tensor] = [] start = 0 for m in offs.tolist(): - outs.append(Y_cat[start : start + m, :]) + outs.append(Y_cat[start : start + m]) start += m return outs def moe_ffn_forward_grouped( - hidden_states, gate_linear, experts_module, top_k: int + hidden_states: torch.Tensor, + gate_linear: torch.nn.Linear, + experts_module, + top_k: int, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - """Attempt grouped GEMM fast path using PyTorch 2.8+.""" - global LAST_ERROR - LAST_ERROR = None + if not available(): + return None, None + bsz, seqlen, hdim = hidden_states.shape + tokens = bsz * seqlen + device = hidden_states.device + routing_dtype = gate_linear.weight.dtype - use_mixed_router = ( - hidden_states.device.type == "cuda" and routing_dtype == torch.float32 - ) + expert_dtype = hidden_states.dtype + x_flat = hidden_states.view(tokens, hdim) + router_logits = gate_linear(x_flat.to(routing_dtype)) - x_base = hidden_states.view(-1, hdim) - if use_mixed_router: - x_router = x_base.to(dtype=routing_dtype) - else: - x_router = x_base - if x_router.dtype != routing_dtype: - x_router = x_router.to(dtype=routing_dtype) - - router_logits = gate_linear(x_router) - if router_logits.dtype != routing_dtype: - router_logits = router_logits.to(dtype=routing_dtype) - - x = x_base - - # top-k routing executed in torch to avoid extra dependencies routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False) topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True) - E = _num_experts(experts_module) - dev = hidden_states.device - first = experts_module[0] - - is_mixtral = _is_mixtral_layout(first) - is_qwen2 = _is_qwen_layout(first) - nested_attr: Optional[str] = None - if not (is_mixtral or is_qwen2): - for candidate in ("mlp", "ffn"): - nested = getattr(first, candidate, None) - if nested is None: - continue - if _is_mixtral_layout(nested): - is_mixtral = True - nested_attr = candidate - break - if _is_qwen_layout(nested): - is_qwen2 = True - nested_attr = candidate - break - if not (is_mixtral or is_qwen2): - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - "torch_grouped: unsupported expert layout; falling back to naive" - ) - experts_module._ax_grouped_logged_fail = True - LAST_ERROR = "unsupported expert layout" - return None, None - - def _resolve_expert(idx: int) -> torch.nn.Module: - expert = experts_module[idx] - if nested_attr is None: - return expert - nested_mod = getattr(expert, nested_attr, None) - if nested_mod is None: - raise AttributeError(f"expert {idx} missing nested module '{nested_attr}'") - return nested_mod - - if is_mixtral: - dt: torch.dtype = first.w1.weight.dtype # type: 
ignore[assignment] - if ( - not hasattr(experts_module, "_stacked_w1") - or experts_module._stacked_w1.device != dev - or experts_module._stacked_w1.dtype != dt - ): - mods = [_resolve_expert(i) for i in range(E)] - w1 = [mod.w1.weight.t() for mod in mods] - w3 = [mod.w3.weight.t() for mod in mods] - w2 = [mod.w2.weight.t() for mod in mods] - experts_module._stacked_w1 = ( - torch.stack(w1, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w3 = ( - torch.stack(w3, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w2 = ( - torch.stack(w2, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w13 = torch.cat( - [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1 - ).contiguous() - W13 = experts_module._stacked_w13 - W2 = experts_module._stacked_w2 + experts = list(experts_module) + sample = getattr(experts[0], "mlp", getattr(experts[0], "ffn", experts[0])) + if hasattr(sample, "w1") and hasattr(sample, "w3") and hasattr(sample, "w2"): + w13 = _stack_weights(experts, ("w1", "w3")).to( + device=device, dtype=expert_dtype + ) + w2 = _stack_weights(experts, ("w2",)).to(device=device, dtype=expert_dtype) else: - sample_mod = _resolve_expert(0) - if hasattr(sample_mod, "gate_up_proj"): - dt = sample_mod.gate_up_proj.weight.dtype # type: ignore[assignment] - elif hasattr(sample_mod, "up_proj"): - dt = sample_mod.up_proj.weight.dtype # type: ignore[assignment] - else: - dt = sample_mod.down_proj.weight.dtype # type: ignore[assignment] - if ( - not hasattr(experts_module, "_stacked_w13") - or experts_module._stacked_w13.device != dev - or experts_module._stacked_w13.dtype != dt - ): - w13 = [] - w2 = [] - for i in range(E): - mod = _resolve_expert(i) - if hasattr(mod, "gate_up_proj"): - w13.append(mod.gate_up_proj.weight.t()) - elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"): - w13.append( - torch.cat( - [mod.up_proj.weight.t(), mod.gate_proj.weight.t()], - dim=-1, - ) - ) - else: - LAST_ERROR = "unrecognized Qwen MoE expert weight layout" - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - "torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive" - ) - experts_module._ax_grouped_logged_fail = True - return None, None - w2.append(mod.down_proj.weight.t()) - experts_module._stacked_w13 = ( - torch.stack(w13, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - experts_module._stacked_w2 = ( - torch.stack(w2, dim=0) - .to(device=dev, dtype=dt, non_blocking=True) - .contiguous() - ) - W13 = experts_module._stacked_w13 - W2 = experts_module._stacked_w2 + names13 = ( + ("gate_up_proj",) + if hasattr(sample, "gate_up_proj") + else ("up_proj", "gate_proj") + ) + w13 = _stack_weights(experts, names13).to(device=device, dtype=expert_dtype) + w2 = _stack_weights(experts, ("down_proj",)).to( + device=device, dtype=expert_dtype + ) - dt = W13.dtype - if router_logits.dtype != dt: - router_logits = router_logits.to(dtype=dt) - if x.dtype != dt: - x = x.to(dtype=dt) flat_idx = topk_idx.view(-1) - if topk_weight.dtype != dt: - topk_weight = topk_weight.to(dtype=dt) - x_rep = x.repeat_interleave(top_k, dim=0) - if x_rep.dtype != dt: - x_rep = x_rep.to(dtype=dt) + x_rep = x_flat.to(expert_dtype).repeat_interleave(top_k, dim=0) - As: List[torch.Tensor] = [] - Bs: List[torch.Tensor] = [] - expert_slices: List[Tuple[int, torch.Tensor]] = [] - for i in range(E): + as_list: 
List[torch.Tensor] = [] + bs_list: List[torch.Tensor] = [] + slices: List[Tuple[int, torch.Tensor]] = [] + for i in range(len(experts)): sel = flat_idx == i if sel.any(): - Xi = x_rep[sel].contiguous() - As.append(Xi) - Bs.append(W13[i].reshape(hdim, -1).contiguous()) - expert_slices.append((i, sel)) + as_list.append(x_rep[sel]) + bs_list.append(w13[i]) + slices.append((i, sel)) - if not As: - out = torch.zeros_like(x) - return out.view(bsz, seqlen, hdim), router_logits + if not as_list: + return torch.zeros_like(x_flat).view(bsz, seqlen, hdim), router_logits - def _run_grouped_mm( - a_tensors: List[torch.Tensor], - b_tensors: List[torch.Tensor], - target_dtype: torch.dtype, - ) -> Optional[List[torch.Tensor]]: - global LAST_ERROR - if target_dtype != dt: - a_tensors = [t.to(target_dtype).contiguous() for t in a_tensors] - b_tensors = [t.to(target_dtype).contiguous() for t in b_tensors] - outputs = _call_grouped_mm(a_tensors, b_tensors) - if outputs is not None and target_dtype != dt: - outputs = [t.to(dt).contiguous() for t in outputs] - return outputs - - def _try_grouped_mm( - a_tensors: List[torch.Tensor], b_tensors: List[torch.Tensor] - ) -> Tuple[Optional[List[torch.Tensor]], bool]: - global LAST_ERROR - result = _run_grouped_mm(a_tensors, b_tensors, target_dtype=dt) - cast_used_local = False - if result is None and dt == torch.bfloat16: - result = _run_grouped_mm(a_tensors, b_tensors, target_dtype=torch.float16) - if result is not None: - cast_used_local = True - LAST_ERROR = None - if not getattr(experts_module, "_ax_grouped_logged_cast", False): - _LOGGER.info( - "torch_grouped: grouped_mm casting bfloat16 operands to float16" - ) - experts_module._ax_grouped_logged_cast = True - return result, cast_used_local - - Y_list, _cast_used_up = _try_grouped_mm(As, Bs) - if Y_list is None: - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - f"torch_grouped: grouped_mm up+gate failed; falling back to naive. Reason: {LAST_ERROR}" - ) - experts_module._ax_grouped_logged_fail = True + up_out = _call_grouped_mm(as_list, bs_list, expert_dtype) + if up_out is None: return None, None - As2: List[torch.Tensor] = [] - Bs2: List[torch.Tensor] = [] - y_buf = torch.empty_like(x_rep, dtype=dt) - for (i, _sel), Yi in zip(expert_slices, Y_list, strict=False): - I2 = Yi.shape[-1] // 2 - Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:] - As2.append(Yi_hidden) - Bs2.append(W2[i].reshape(I2, hdim).contiguous()) + down_inputs: List[torch.Tensor] = [] + down_weights: List[torch.Tensor] = [] + buf = torch.empty_like(x_rep) + for (i, _sel), Yi in zip(slices, up_out, strict=False): + mid = Yi.shape[-1] // 2 + hidden = F.silu(Yi[:, :mid]) * Yi[:, mid:] + down_inputs.append(hidden) + down_weights.append(w2[i]) - Y2_list, _cast_used_down = _try_grouped_mm(As2, Bs2) - if Y2_list is None: - if not getattr(experts_module, "_ax_grouped_logged_fail", False): - _LOGGER.warning( - f"torch_grouped: grouped_mm down failed; falling back to naive. 
Reason: {LAST_ERROR}" - ) - experts_module._ax_grouped_logged_fail = True + down_out = _call_grouped_mm(down_inputs, down_weights, expert_dtype) + if down_out is None: return None, None - for (_i, sel), Out_i in zip(expert_slices, Y2_list, strict=False): - y_buf[sel] = Out_i - y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) - if not getattr(experts_module, "_ax_grouped_logged_ok", False): - _LOGGER.info( - f"torch_grouped: engaged grouped GEMM path (experts={E}, top_k={top_k})" - ) - experts_module._ax_grouped_logged_ok = True - return y.view(bsz, seqlen, hdim), router_logits + for (_i, sel), tensor in zip(slices, down_out, strict=False): + buf[sel] = tensor + + combined = ( + buf.view(tokens, top_k, -1) * topk_weight.to(expert_dtype).unsqueeze(-1) + ).sum(dim=1) + return combined.view(bsz, seqlen, hdim), router_logits diff --git a/src/axolotl/monkeypatch/moe_grouped.py b/src/axolotl/monkeypatch/moe_grouped.py index ab5bfba27..0d2d67c97 100644 --- a/src/axolotl/monkeypatch/moe_grouped.py +++ b/src/axolotl/monkeypatch/moe_grouped.py @@ -82,18 +82,10 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None: # One-time log per block instance indicating whether grouped engaged or fallback occurred if not getattr(self, "_ax_grouped_wrapper_logged", False): if y is None: - reason = getattr(_tg, "LAST_ERROR", None) - if reason: - _LOG.warning( - "Grouped wrapper fell back to naive for %s (reason=%s)", - self.__class__.__name__, - reason, - ) - else: - _LOG.warning( - "Grouped wrapper active but fell back to naive for %s", - self.__class__.__name__, - ) + _LOG.warning( + "Grouped wrapper active but fell back to naive for %s", + self.__class__.__name__, + ) else: _LOG.info( f"Grouped wrapper engaged for {self.__class__.__name__} (top_k={self.top_k})" diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index a5cdacf9e..1f4043ced 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -9,7 +9,6 @@ from pydantic import ( Field, StringConstraints, field_serializer, - field_validator, model_validator, ) diff --git a/tests/monkeypatch/test_moe_grouped.py b/tests/monkeypatch/test_moe_grouped.py index f668f2886..5e1cea925 100644 --- a/tests/monkeypatch/test_moe_grouped.py +++ b/tests/monkeypatch/test_moe_grouped.py @@ -75,7 +75,7 @@ def test_grouped_uses_per_expert_nested_modules(monkeypatch): captured = [] - def fake_grouped_mm(As, Bs): + def fake_grouped_mm(As, Bs, dtype): captured.append([b.detach().clone() for b in Bs]) return [ torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype) @@ -111,7 +111,7 @@ def test_grouped_accepts_module_list_experts(monkeypatch): calls = {"count": 0} - def fake_grouped_mm(As, Bs): + def fake_grouped_mm(As, Bs, dtype): calls["count"] += 1 return [ torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
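
A minimal sketch (reviewer-side, not part of the patch) of how the new moe_ffn_forward_grouped path can be cross-checked against a plain per-expert loop. It assumes torch >= 2.8 on an SM90+ GPU (the same conditions available() probes), and the TinyExpert / naive_forward names are illustrative only. The expert layout mirrors the Mixtral-style w1/w3/w2 modules that _stack_weights expects, i.e. w2(silu(w1(x)) * w3(x)).

import torch
import torch.nn as nn
import torch.nn.functional as F

from axolotl.kernels.moe.torch_grouped import available, moe_ffn_forward_grouped


class TinyExpert(nn.Module):
    """Mixtral-style expert: w2(silu(w1(x)) * w3(x)), matching _stack_weights' layout."""

    def __init__(self, hidden: int, inter: int):
        super().__init__()
        self.w1 = nn.Linear(hidden, inter, bias=False)
        self.w3 = nn.Linear(hidden, inter, bias=False)
        self.w2 = nn.Linear(inter, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def naive_forward(x, gate, experts, top_k):
    """Per-expert reference loop using the same routing math as the grouped path."""
    bsz, seqlen, hdim = x.shape
    flat = x.view(-1, hdim)
    probs = F.softmax(gate(flat), dim=-1, dtype=torch.float)
    weight, idx = torch.topk(probs, top_k, dim=-1, sorted=False)
    weight = weight / weight.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(flat)
    for e, expert in enumerate(experts):
        for k in range(top_k):
            sel = idx[:, k] == e
            if sel.any():
                out[sel] += weight[sel, k].to(flat.dtype).unsqueeze(-1) * expert(flat[sel])
    return out.view(bsz, seqlen, hdim)


if __name__ == "__main__":
    if not available():
        raise SystemExit("grouped path unavailable (needs torch >= 2.8 and SM90+)")
    torch.manual_seed(0)
    device, dtype = "cuda", torch.bfloat16
    hidden, inter, num_experts, top_k = 64, 128, 4, 2
    experts = nn.ModuleList(TinyExpert(hidden, inter) for _ in range(num_experts)).to(device, dtype)
    gate = nn.Linear(hidden, num_experts, bias=False).to(device, dtype)
    x = torch.randn(2, 8, hidden, device=device, dtype=dtype)
    y_grouped, _ = moe_ffn_forward_grouped(x, gate, experts, top_k)
    y_ref = naive_forward(x, gate, experts, top_k)
    # bf16 accumulation differs between the single grouped GEMM and the per-expert
    # loop, so only inspect the magnitude of the difference rather than asserting.
    print("max_abs_diff:", (y_grouped - y_ref).abs().max().item())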