diff --git a/_quarto.yml b/_quarto.yml
index 3ffb0e627..d50db8845 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -285,6 +285,7 @@ website:
           - docs/custom_integrations.qmd
           - docs/sequence_parallelism.qmd
           - docs/gradient_checkpointing.qmd
+          - docs/moe_backends.md
           - docs/nd_parallelism.qmd
 
       - section: "Troubleshooting"
diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index 7b356f834..38c0bb2b7 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -32,6 +32,16 @@ LAST_ERROR: Optional[str] = None
 _LOGGER = logging.getLogger("axolotl.moe.grouped")
 
 
+def _is_mixtral_layout(mod: torch.nn.Module) -> bool:
+    return all(hasattr(mod, attr) for attr in ("w1", "w3", "w2"))
+
+
+def _is_qwen_layout(mod: torch.nn.Module) -> bool:
+    has_fused = hasattr(mod, "gate_up_proj")
+    has_split = hasattr(mod, "up_proj") and hasattr(mod, "gate_proj")
+    return (has_fused or has_split) and hasattr(mod, "down_proj")
+
+
 def _call_grouped_mm(
     As: List[torch.Tensor], Bs: List[torch.Tensor]
 ) -> Optional[List[torch.Tensor]]:
@@ -96,33 +106,27 @@ def moe_ffn_forward_grouped(
     flat_idx = topk_idx.view(-1)
     x_rep = x.repeat_interleave(top_k, dim=0)
 
-    # Cache stacked weights on experts (support Mixtral and Qwen2-MoE layouts)
+    # Cache stacked weights on experts (support Mixtral and Qwen-style layouts)
     E = experts_module.num_experts
     dev, dt = x.device, x.dtype
     first = experts_module[0]
-    is_mixtral = (
-        hasattr(first, "w1") and hasattr(first, "w3") and hasattr(first, "w2")
-    )
-    is_qwen2 = (
-        hasattr(first, "gate_up_proj")
-        or hasattr(first, "gate_proj")
-        or hasattr(first, "up_proj")
-    ) and hasattr(first, "down_proj")
-    # try nested mlp/ffn module
-    nested = None
+
+    is_mixtral = _is_mixtral_layout(first)
+    is_qwen2 = _is_qwen_layout(first)
+    nested_attr: Optional[str] = None
     if not (is_mixtral or is_qwen2):
-        nested = getattr(first, "mlp", None) or getattr(first, "ffn", None)
-        if nested is not None:
-            is_mixtral = (
-                hasattr(nested, "w1")
-                and hasattr(nested, "w3")
-                and hasattr(nested, "w2")
-            )
-            is_qwen2 = (
-                hasattr(nested, "gate_up_proj")
-                or hasattr(nested, "gate_proj")
-                or hasattr(nested, "up_proj")
-            ) and hasattr(nested, "down_proj")
+        for candidate in ("mlp", "ffn"):
+            nested = getattr(first, candidate, None)
+            if nested is None:
+                continue
+            if _is_mixtral_layout(nested):
+                is_mixtral = True
+                nested_attr = candidate
+                break
+            if _is_qwen_layout(nested):
+                is_qwen2 = True
+                nested_attr = candidate
+                break
     if not (is_mixtral or is_qwen2):
         if not getattr(experts_module, "_ax_grouped_logged_fail", False):
             _LOGGER.warning(
@@ -131,81 +135,101 @@ def moe_ffn_forward_grouped(
             experts_module._ax_grouped_logged_fail = True
         return None, None
 
-    if is_mixtral:
-        if (
-            not hasattr(experts_module, "_stacked_w1")
-            or experts_module._stacked_w1.device != dev
-            or experts_module._stacked_w1.dtype != dt
-        ):
-            w1 = [experts_module[i].w1.weight.t() for i in range(E)]
-            w3 = [experts_module[i].w3.weight.t() for i in range(E)]
-            w2 = [experts_module[i].w2.weight.t() for i in range(E)]
-            experts_module._stacked_w1 = (
-                torch.stack(w1, dim=0)
-                .to(device=dev, dtype=dt, non_blocking=True)
-                .contiguous()
+    def _resolve_expert(idx: int):
+        expert = experts_module[idx]
+        if nested_attr is None:
+            return expert
+        nested_mod = getattr(expert, nested_attr, None)
+        if nested_mod is None:
+            raise AttributeError(
+                f"expert {idx} missing nested module '{nested_attr}'"
             )
-            experts_module._stacked_w3 = (
-                torch.stack(w3, dim=0)
-                .to(device=dev, dtype=dt, non_blocking=True)
-                .contiguous()
-            )
-            experts_module._stacked_w2 = (
-                torch.stack(w2, dim=0)
-                .to(device=dev, dtype=dt, non_blocking=True)
-                .contiguous()
-            )
-            experts_module._stacked_w13 = torch.cat(
-                [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
-            ).contiguous()
-        W13 = experts_module._stacked_w13
-        W2 = experts_module._stacked_w2
-    else:
-        # Qwen2/3 MoE style: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I)
-        if (
-            not hasattr(experts_module, "_stacked_w13")
-            or experts_module._stacked_w13.device != dev
-            or experts_module._stacked_w13.dtype != dt
-        ):
-            w13 = []
-            w2 = []
-            for i in range(E):
-                exp = experts_module[i]
-                mod = nested if nested is not None else exp
-                # prefer fused gate_up_proj if present
-                if hasattr(mod, "gate_up_proj"):
-                    w13.append(mod.gate_up_proj.weight.t())
-                elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
-                    # concatenate [up | gate] along N
-                    w13.append(
-                        torch.cat(
-                            [mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
-                            dim=-1,
+        return nested_mod
+
+    try:
+        if is_mixtral:
+            if (
+                not hasattr(experts_module, "_stacked_w1")
+                or experts_module._stacked_w1.device != dev
+                or experts_module._stacked_w1.dtype != dt
+            ):
+                mods = [_resolve_expert(i) for i in range(E)]
+                w1 = [mod.w1.weight.t() for mod in mods]
+                w3 = [mod.w3.weight.t() for mod in mods]
+                w2 = [mod.w2.weight.t() for mod in mods]
+                experts_module._stacked_w1 = (
+                    torch.stack(w1, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+                experts_module._stacked_w3 = (
+                    torch.stack(w3, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+                experts_module._stacked_w2 = (
+                    torch.stack(w2, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+                experts_module._stacked_w13 = torch.cat(
+                    [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
+                ).contiguous()
+            W13 = experts_module._stacked_w13
+            W2 = experts_module._stacked_w2
+        else:
+            # Qwen-style MoE: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I)
+            if (
+                not hasattr(experts_module, "_stacked_w13")
+                or experts_module._stacked_w13.device != dev
+                or experts_module._stacked_w13.dtype != dt
+            ):
+                w13 = []
+                w2 = []
+                for i in range(E):
+                    mod = _resolve_expert(i)
+                    # prefer fused gate_up_proj if present
+                    if hasattr(mod, "gate_up_proj"):
+                        w13.append(mod.gate_up_proj.weight.t())
+                    elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
+                        # concatenate [up | gate] along N
+                        w13.append(
+                            torch.cat(
+                                [mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
+                                dim=-1,
+                            )
                         )
-                    )
-                else:
-                    LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
-                    if not getattr(
-                        experts_module, "_ax_grouped_logged_fail", False
-                    ):
-                        _LOGGER.warning(
-                            "torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
-                        )
-                        experts_module._ax_grouped_logged_fail = True
-                    return None, None
-                w2.append((mod.down_proj.weight.t()))
-            experts_module._stacked_w13 = (
-                torch.stack(w13, dim=0)
-                .to(device=dev, dtype=dt, non_blocking=True)
-                .contiguous()
+                    else:
+                        LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
+                        if not getattr(
+                            experts_module, "_ax_grouped_logged_fail", False
+                        ):
+                            _LOGGER.warning(
+                                "torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
+                            )
+                            experts_module._ax_grouped_logged_fail = True
+                        return None, None
+                    w2.append(mod.down_proj.weight.t())
+                experts_module._stacked_w13 = (
+                    torch.stack(w13, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+                experts_module._stacked_w2 = (
+                    torch.stack(w2, dim=0)
+                    .to(device=dev, dtype=dt, non_blocking=True)
+                    .contiguous()
+                )
+            W13 = experts_module._stacked_w13
+            W2 = experts_module._stacked_w2
+    except AttributeError as err:
+        LAST_ERROR = str(err)
+        if not getattr(experts_module, "_ax_grouped_logged_fail", False):
+            _LOGGER.warning(
+                "torch_grouped: expert weights missing expected attributes; falling back to naive"
             )
-            experts_module._stacked_w2 = (
-                torch.stack(w2, dim=0)
-                .to(device=dev, dtype=dt, non_blocking=True)
-                .contiguous()
-            )
-        W13 = experts_module._stacked_w13
-        W2 = experts_module._stacked_w2
+        experts_module._ax_grouped_logged_fail = True
+        return None, None
 
     # Grouped GEMM for up+gate
     As: List[torch.Tensor] = []
@@ -237,8 +261,9 @@ def moe_ffn_forward_grouped(
     As2: List[torch.Tensor] = []
     Bs2: List[torch.Tensor] = []
     y_buf = torch.empty_like(x_rep)
+
     # split Y into (I, I)
-    for (i, sel), Yi in zip(expert_slices, Y_list):
+    for Yi in Y_list:
         I2 = Yi.shape[-1] // 2
         Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
         As2.append(Yi_hidden)
@@ -254,7 +279,7 @@ def moe_ffn_forward_grouped(
         return None, None
 
     # Write back, apply per-token weighting, and reduce over top_k
-    for (i, sel), Out_i in zip(expert_slices, Y2_list):
+    for (_, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
         y_buf[sel] = Out_i
     y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
     if not getattr(experts_module, "_ax_grouped_logged_ok", False):
diff --git a/src/axolotl/monkeypatch/moe_grouped.py b/src/axolotl/monkeypatch/moe_grouped.py
index 36b975d04..01ac6e101 100644
--- a/src/axolotl/monkeypatch/moe_grouped.py
+++ b/src/axolotl/monkeypatch/moe_grouped.py
@@ -1,5 +1,5 @@
 import logging
-import warnings
+from functools import wraps
 
 import torch
 
@@ -11,7 +11,7 @@ _LOG = logging.getLogger("axolotl.moe.patch")
 
 def _patch_block_forward(block_cls, grouped_fn):
     """Replace block_cls.forward with grouped_fn preserving signature."""
-    setattr(block_cls, "forward", grouped_fn)
+    block_cls.forward = grouped_fn
 
 
 def apply_grouped_to_moe_blocks(cfg=None) -> None:
@@ -73,7 +73,8 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
     }
 
     def make_grouped_forward(orig_forward):
-        def _grouped_forward(self, hidden_states: torch.Tensor):
+        @wraps(orig_forward)
+        def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
             bsz, seqlen, hdim = hidden_states.shape
             y, router_logits = _tg.moe_ffn_forward_grouped(
                 hidden_states, self.gate, self.experts, self.top_k
@@ -90,7 +91,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
                 )
                 self._ax_grouped_wrapper_logged = True
             if y is None:
-                return orig_forward(self, hidden_states)
+                return orig_forward(self, hidden_states, *args, **kwargs)
             return y, router_logits
 
         return _grouped_forward
diff --git a/tests/monkeypatch/test_moe_grouped.py b/tests/monkeypatch/test_moe_grouped.py
new file mode 100644
index 000000000..9b409e344
--- /dev/null
+++ b/tests/monkeypatch/test_moe_grouped.py
@@ -0,0 +1,228 @@
+import sys
+import types
+
+import torch
+import torch.nn as nn
+
+from axolotl.kernels.moe import (
+    backends as moe_backends,
+    torch_grouped as torch_grouped_module,
+)
+from axolotl.monkeypatch import moe_grouped
+
+
+class DummyExperts(nn.Module):
+    def __init__(self, layers):
+        super().__init__()
+        self.layers = nn.ModuleList(layers)
+        self.num_experts = len(layers)
+
+    def __getitem__(self, idx):
+        return self.layers[idx]
+
+
+class DummyQwenMLP(nn.Module):
+    def __init__(self, idx: int, hidden: int, intermediate: int):
+        super().__init__()
+        self.gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
+        self.down_proj = nn.Linear(intermediate, hidden, bias=False)
+        nn.init.constant_(self.gate_up_proj.weight, float(idx + 1))
+        nn.init.constant_(self.down_proj.weight, float((idx + 1) * 10))
+
+
+class DummyQwenExpert(nn.Module):
+    def __init__(self, idx: int, hidden: int, intermediate: int):
+        super().__init__()
+        self.mlp = DummyQwenMLP(idx, hidden, intermediate)
+
+
+def _make_transformers_stub(monkeypatch, block_cls):
+    # ensure we start from the original forward for each test
+    if block_cls is DummyMixtralBlock:
+        DummyMixtralBlock.forward = _DUMMY_MIXTRAL_ORIG_FORWARD
+
+    transformers_mod = types.ModuleType("transformers")
+    models_mod = types.ModuleType("transformers.models")
+    mixtral_mod = types.ModuleType("transformers.models.mixtral")
+    modeling_mixtral = types.ModuleType("transformers.models.mixtral.modeling_mixtral")
+    modeling_mixtral.MixtralSparseMoeBlock = block_cls
+
+    transformers_mod.models = models_mod
+    models_mod.mixtral = mixtral_mod
+    mixtral_mod.modeling_mixtral = modeling_mixtral
+
+    monkeypatch.setitem(sys.modules, "transformers", transformers_mod)
+    monkeypatch.setitem(sys.modules, "transformers.models", models_mod)
+    monkeypatch.setitem(sys.modules, "transformers.models.mixtral", mixtral_mod)
+    monkeypatch.setitem(
+        sys.modules,
+        "transformers.models.mixtral.modeling_mixtral",
+        modeling_mixtral,
+    )
+
+
+def test_grouped_uses_per_expert_nested_modules(monkeypatch):
+    hidden = 4
+    intermediate = 2
+    num_experts = 2
+
+    experts = DummyExperts(
+        [DummyQwenExpert(i, hidden, intermediate) for i in range(num_experts)]
+    )
+
+    gate = nn.Linear(hidden, num_experts, bias=False)
+    nn.init.zeros_(gate.weight)
+
+    captured = []
+
+    def fake_grouped_mm(As, Bs):
+        captured.append([b.detach().clone() for b in Bs])
+        return [
+            torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
+            for a, b in zip(As, Bs, strict=False)
+        ]
+
+    monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
+
+    hidden_states = torch.randn(1, 2, hidden)
+    y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
+        hidden_states, gate, experts, top_k=2
+    )
+
+    assert y is not None
+    assert router_logits is not None
+    assert captured, "Grouped GEMM path should have been invoked"
+    first_call = captured[0]
+    expected0 = experts[0].mlp.gate_up_proj.weight.t()
+    expected1 = experts[1].mlp.gate_up_proj.weight.t()
+    assert torch.equal(first_call[0], expected0)
+    assert torch.equal(first_call[1], expected1)
+    assert not torch.equal(first_call[0], first_call[1])
+
+
+class _DummyCfg:
+    moe_backend = "torch_grouped"
+
+
+class DummyMixtralBlock(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.top_k = 1
+        self.gate = lambda x: x
+        self.experts = object()
+        self._calls = []
+
+    def forward(self, hidden_states: torch.Tensor, attention_mask=None):
+        self._calls.append((hidden_states, attention_mask))
+        tokens = hidden_states.shape[0] * hidden_states.shape[1]
+        router = torch.ones(
+            tokens, 2, device=hidden_states.device, dtype=hidden_states.dtype
+        )
+        return hidden_states + 5, router
+
+
+_DUMMY_MIXTRAL_ORIG_FORWARD = DummyMixtralBlock.forward
+
+
+def test_apply_grouped_forward_handles_args(monkeypatch):
+    _make_transformers_stub(monkeypatch, DummyMixtralBlock)
+    import axolotl.common.architectures as arch
+
+    original_map = arch.MOE_ARCH_BLOCK.copy()
+    monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
+    for key in list(original_map.keys()):
+        if key != "mixtral":
+            monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
+
+    monkeypatch.setattr(
+        moe_grouped,
+        "get_moe_backend_name",
+        lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
+    )
+
+    results = {}
+
+    def fake_grouped_forward(hidden_states, gate, experts, top_k):
+        results["called"] = True
+        router = torch.zeros(
+            hidden_states.shape[0] * hidden_states.shape[1],
+            2,
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        return hidden_states + 1, router
+
+    monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
+    monkeypatch.setattr(
+        torch_grouped_module,
+        "moe_ffn_forward_grouped",
+        fake_grouped_forward,
+    )
+
+    cfg = _DummyCfg()
+    moe_grouped.apply_grouped_to_moe_blocks(cfg)
+
+    block = DummyMixtralBlock()
+    hidden_states = torch.ones(1, 2, 3)
+    mask = torch.zeros(1, 2)
+    out, router = block.forward(hidden_states, attention_mask=mask)
+
+    assert results.get("called") is True
+    assert torch.equal(out, hidden_states + 1)
+    assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
+
+
+def test_apply_grouped_forward_fallback(monkeypatch):
+    _make_transformers_stub(monkeypatch, DummyMixtralBlock)
+    import axolotl.common.architectures as arch
+
+    original_map = arch.MOE_ARCH_BLOCK.copy()
+    monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
+    for key in list(original_map.keys()):
+        if key != "mixtral":
+            monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
+
+    monkeypatch.setattr(
+        moe_grouped,
+        "get_moe_backend_name",
+        lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
+    )
+    monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
+    monkeypatch.setattr(
+        torch_grouped_module,
+        "moe_ffn_forward_grouped",
+        lambda *args, **kwargs: (None, None),
+    )
+
+    cfg = _DummyCfg()
+    moe_grouped.apply_grouped_to_moe_blocks(cfg)
+
+    block = DummyMixtralBlock()
+    hidden_states = torch.ones(1, 2, 3)
+    mask = torch.zeros(1, 2)
+    out, router = block.forward(hidden_states, attention_mask=mask)
+
+    assert torch.equal(out, hidden_states + 5)
+    assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
+    assert block._calls, "Original forward should have been invoked"
+    call_hidden, call_mask = block._calls[-1]
+    assert torch.equal(call_hidden, hidden_states)
+    assert torch.equal(call_mask, mask)
+
+
+def test_get_moe_backend_name_prefers_probe(monkeypatch):
+    monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: True)
+    assert moe_backends.get_moe_backend_name() == moe_backends.MOEBackend.TORCH_GROUPED
+
+
+def test_get_moe_backend_name_falls_back(monkeypatch):
+    warnings_captured = []
+
+    def fake_warn(msg):
+        warnings_captured.append(msg)
+
+    monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: False)
+    monkeypatch.setattr(moe_backends.warnings, "warn", fake_warn)
+    backend = moe_backends.get_moe_backend_name("torch_grouped")
+    assert backend == moe_backends.MOEBackend.NAIVE
+    assert warnings_captured, "Expected warning when torch_grouped unavailable"
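Note for reviewers trying the patch locally: the new tests suggest the intended wiring is that a config object exposing a moe_backend attribute is handed to apply_grouped_to_moe_blocks, which resolves the backend via get_moe_backend_name, swaps the forward of the registered MoE block classes, and falls back to the original forward whenever moe_ffn_forward_grouped returns (None, None). A minimal usage sketch under those assumptions follows; the SimpleNamespace config is a stand-in mirroring the tests' _DummyCfg, not the real axolotl config object.

    # Hedged sketch, not part of the patch: mirrors what the new tests exercise.
    from types import SimpleNamespace

    from axolotl.monkeypatch.moe_grouped import apply_grouped_to_moe_blocks

    # Assumption: any object with a `moe_backend` attribute is accepted, as in
    # the tests' _DummyCfg; a real axolotl cfg would normally be passed here.
    cfg = SimpleNamespace(moe_backend="torch_grouped")

    # Patches the registered MoE block classes (e.g. MixtralSparseMoeBlock) so
    # their forward routes through moe_ffn_forward_grouped, with the original
    # (naive) forward used whenever the grouped path returns (None, None).
    apply_grouped_to_moe_blocks(cfg)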