refactor + fix

This commit is contained in:
Dan Saunders
2025-09-17 14:01:39 -04:00
parent 7289e0cb55
commit c774dd0409
4 changed files with 356 additions and 101 deletions

View File

@@ -285,6 +285,7 @@ website:
- docs/custom_integrations.qmd - docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd - docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd - docs/gradient_checkpointing.qmd
- docs/moe_backends.md
- docs/nd_parallelism.qmd - docs/nd_parallelism.qmd
- section: "Troubleshooting" - section: "Troubleshooting"

View File

@@ -32,6 +32,16 @@ LAST_ERROR: Optional[str] = None
_LOGGER = logging.getLogger("axolotl.moe.grouped") _LOGGER = logging.getLogger("axolotl.moe.grouped")
def _is_mixtral_layout(mod: torch.nn.Module) -> bool:
return all(hasattr(mod, attr) for attr in ("w1", "w3", "w2"))
def _is_qwen_layout(mod: torch.nn.Module) -> bool:
has_fused = hasattr(mod, "gate_up_proj")
has_split = hasattr(mod, "up_proj") and hasattr(mod, "gate_proj")
return (has_fused or has_split) and hasattr(mod, "down_proj")
def _call_grouped_mm( def _call_grouped_mm(
As: List[torch.Tensor], Bs: List[torch.Tensor] As: List[torch.Tensor], Bs: List[torch.Tensor]
) -> Optional[List[torch.Tensor]]: ) -> Optional[List[torch.Tensor]]:
@@ -96,33 +106,27 @@ def moe_ffn_forward_grouped(
flat_idx = topk_idx.view(-1) flat_idx = topk_idx.view(-1)
x_rep = x.repeat_interleave(top_k, dim=0) x_rep = x.repeat_interleave(top_k, dim=0)
# Cache stacked weights on experts (support Mixtral and Qwen2-MoE layouts) # Cache stacked weights on experts (support Mixtral and Qwen-style layouts)
E = experts_module.num_experts E = experts_module.num_experts
dev, dt = x.device, x.dtype dev, dt = x.device, x.dtype
first = experts_module[0] first = experts_module[0]
is_mixtral = (
hasattr(first, "w1") and hasattr(first, "w3") and hasattr(first, "w2") is_mixtral = _is_mixtral_layout(first)
) is_qwen2 = _is_qwen_layout(first)
is_qwen2 = ( nested_attr: Optional[str] = None
hasattr(first, "gate_up_proj")
or hasattr(first, "gate_proj")
or hasattr(first, "up_proj")
) and hasattr(first, "down_proj")
# try nested mlp/ffn module
nested = None
if not (is_mixtral or is_qwen2): if not (is_mixtral or is_qwen2):
nested = getattr(first, "mlp", None) or getattr(first, "ffn", None) for candidate in ("mlp", "ffn"):
if nested is not None: nested = getattr(first, candidate, None)
is_mixtral = ( if nested is None:
hasattr(nested, "w1") continue
and hasattr(nested, "w3") if _is_mixtral_layout(nested):
and hasattr(nested, "w2") is_mixtral = True
) nested_attr = candidate
is_qwen2 = ( break
hasattr(nested, "gate_up_proj") if _is_qwen_layout(nested):
or hasattr(nested, "gate_proj") is_qwen2 = True
or hasattr(nested, "up_proj") nested_attr = candidate
) and hasattr(nested, "down_proj") break
if not (is_mixtral or is_qwen2): if not (is_mixtral or is_qwen2):
if not getattr(experts_module, "_ax_grouped_logged_fail", False): if not getattr(experts_module, "_ax_grouped_logged_fail", False):
_LOGGER.warning( _LOGGER.warning(
@@ -131,81 +135,101 @@ def moe_ffn_forward_grouped(
experts_module._ax_grouped_logged_fail = True experts_module._ax_grouped_logged_fail = True
return None, None return None, None
if is_mixtral: def _resolve_expert(idx: int):
if ( expert = experts_module[idx]
not hasattr(experts_module, "_stacked_w1") if nested_attr is None:
or experts_module._stacked_w1.device != dev return expert
or experts_module._stacked_w1.dtype != dt nested_mod = getattr(expert, nested_attr, None)
): if nested_mod is None:
w1 = [experts_module[i].w1.weight.t() for i in range(E)] raise AttributeError(
w3 = [experts_module[i].w3.weight.t() for i in range(E)] f"expert {idx} missing nested module '{nested_attr}'"
w2 = [experts_module[i].w2.weight.t() for i in range(E)]
experts_module._stacked_w1 = (
torch.stack(w1, dim=0)
.to(device=dev, dtype=dt, non_blocking=True)
.contiguous()
) )
experts_module._stacked_w3 = ( return nested_mod
torch.stack(w3, dim=0)
.to(device=dev, dtype=dt, non_blocking=True) try:
.contiguous() if is_mixtral:
) if (
experts_module._stacked_w2 = ( not hasattr(experts_module, "_stacked_w1")
torch.stack(w2, dim=0) or experts_module._stacked_w1.device != dev
.to(device=dev, dtype=dt, non_blocking=True) or experts_module._stacked_w1.dtype != dt
.contiguous() ):
) mods = [_resolve_expert(i) for i in range(E)]
experts_module._stacked_w13 = torch.cat( w1 = [mod.w1.weight.t() for mod in mods]
[experts_module._stacked_w1, experts_module._stacked_w3], dim=-1 w3 = [mod.w3.weight.t() for mod in mods]
).contiguous() w2 = [mod.w2.weight.t() for mod in mods]
W13 = experts_module._stacked_w13 experts_module._stacked_w1 = (
W2 = experts_module._stacked_w2 torch.stack(w1, dim=0)
else: .to(device=dev, dtype=dt, non_blocking=True)
# Qwen2/3 MoE style: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I) .contiguous()
if ( )
not hasattr(experts_module, "_stacked_w13") experts_module._stacked_w3 = (
or experts_module._stacked_w13.device != dev torch.stack(w3, dim=0)
or experts_module._stacked_w13.dtype != dt .to(device=dev, dtype=dt, non_blocking=True)
): .contiguous()
w13 = [] )
w2 = [] experts_module._stacked_w2 = (
for i in range(E): torch.stack(w2, dim=0)
exp = experts_module[i] .to(device=dev, dtype=dt, non_blocking=True)
mod = nested if nested is not None else exp .contiguous()
# prefer fused gate_up_proj if present )
if hasattr(mod, "gate_up_proj"): experts_module._stacked_w13 = torch.cat(
w13.append(mod.gate_up_proj.weight.t()) [experts_module._stacked_w1, experts_module._stacked_w3], dim=-1
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"): ).contiguous()
# concatenate [up | gate] along N W13 = experts_module._stacked_w13
w13.append( W2 = experts_module._stacked_w2
torch.cat( else:
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()], # Qwen-style MoE: either gate_up_proj (2I x H) or (up_proj + gate_proj), down_proj (H x I)
dim=-1, if (
not hasattr(experts_module, "_stacked_w13")
or experts_module._stacked_w13.device != dev
or experts_module._stacked_w13.dtype != dt
):
w13 = []
w2 = []
for i in range(E):
mod = _resolve_expert(i)
# prefer fused gate_up_proj if present
if hasattr(mod, "gate_up_proj"):
w13.append(mod.gate_up_proj.weight.t())
elif hasattr(mod, "up_proj") and hasattr(mod, "gate_proj"):
# concatenate [up | gate] along N
w13.append(
torch.cat(
[mod.up_proj.weight.t(), mod.gate_proj.weight.t()],
dim=-1,
)
) )
) else:
else: LAST_ERROR = "unrecognized Qwen MoE expert weight layout"
LAST_ERROR = "unrecognized Qwen MoE expert weight layout" if not getattr(
if not getattr( experts_module, "_ax_grouped_logged_fail", False
experts_module, "_ax_grouped_logged_fail", False ):
): _LOGGER.warning(
_LOGGER.warning( "torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive"
"torch_grouped: could not resolve Qwen MoE expert weights; fallback to naive" )
) experts_module._ax_grouped_logged_fail = True
experts_module._ax_grouped_logged_fail = True return None, None
return None, None w2.append(mod.down_proj.weight.t())
w2.append((mod.down_proj.weight.t())) experts_module._stacked_w13 = (
experts_module._stacked_w13 = ( torch.stack(w13, dim=0)
torch.stack(w13, dim=0) .to(device=dev, dtype=dt, non_blocking=True)
.to(device=dev, dtype=dt, non_blocking=True) .contiguous()
.contiguous() )
experts_module._stacked_w2 = (
torch.stack(w2, dim=0)
.to(device=dev, dtype=dt, non_blocking=True)
.contiguous()
)
W13 = experts_module._stacked_w13
W2 = experts_module._stacked_w2
except AttributeError as err:
LAST_ERROR = str(err)
if not getattr(experts_module, "_ax_grouped_logged_fail", False):
_LOGGER.warning(
"torch_grouped: expert weights missing expected attributes; falling back to naive"
) )
experts_module._stacked_w2 = ( experts_module._ax_grouped_logged_fail = True
torch.stack(w2, dim=0) return None, None
.to(device=dev, dtype=dt, non_blocking=True)
.contiguous()
)
W13 = experts_module._stacked_w13
W2 = experts_module._stacked_w2
# Grouped GEMM for up+gate # Grouped GEMM for up+gate
As: List[torch.Tensor] = [] As: List[torch.Tensor] = []
@@ -237,8 +261,9 @@ def moe_ffn_forward_grouped(
As2: List[torch.Tensor] = [] As2: List[torch.Tensor] = []
Bs2: List[torch.Tensor] = [] Bs2: List[torch.Tensor] = []
y_buf = torch.empty_like(x_rep) y_buf = torch.empty_like(x_rep)
# split Y into (I, I) # split Y into (I, I)
for (i, sel), Yi in zip(expert_slices, Y_list): for Yi in Y_list:
I2 = Yi.shape[-1] // 2 I2 = Yi.shape[-1] // 2
Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:] Yi_hidden = F.silu(Yi[:, :I2]) * Yi[:, I2:]
As2.append(Yi_hidden) As2.append(Yi_hidden)
@@ -254,7 +279,7 @@ def moe_ffn_forward_grouped(
return None, None return None, None
# Write back, apply per-token weighting, and reduce over top_k # Write back, apply per-token weighting, and reduce over top_k
for (i, sel), Out_i in zip(expert_slices, Y2_list): for (_, sel), Out_i in zip(expert_slices, Y2_list, strict=False):
y_buf[sel] = Out_i y_buf[sel] = Out_i
y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) y = (y_buf.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
if not getattr(experts_module, "_ax_grouped_logged_ok", False): if not getattr(experts_module, "_ax_grouped_logged_ok", False):

View File

@@ -1,5 +1,5 @@
import logging import logging
import warnings from functools import wraps
import torch import torch
@@ -11,7 +11,7 @@ _LOG = logging.getLogger("axolotl.moe.patch")
def _patch_block_forward(block_cls, grouped_fn): def _patch_block_forward(block_cls, grouped_fn):
"""Replace block_cls.forward with grouped_fn preserving signature.""" """Replace block_cls.forward with grouped_fn preserving signature."""
setattr(block_cls, "forward", grouped_fn) block_cls.forward = grouped_fn
def apply_grouped_to_moe_blocks(cfg=None) -> None: def apply_grouped_to_moe_blocks(cfg=None) -> None:
@@ -73,7 +73,8 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
} }
def make_grouped_forward(orig_forward): def make_grouped_forward(orig_forward):
def _grouped_forward(self, hidden_states: torch.Tensor): @wraps(orig_forward)
def _grouped_forward(self, hidden_states: torch.Tensor, *args, **kwargs):
bsz, seqlen, hdim = hidden_states.shape bsz, seqlen, hdim = hidden_states.shape
y, router_logits = _tg.moe_ffn_forward_grouped( y, router_logits = _tg.moe_ffn_forward_grouped(
hidden_states, self.gate, self.experts, self.top_k hidden_states, self.gate, self.experts, self.top_k
@@ -90,7 +91,7 @@ def apply_grouped_to_moe_blocks(cfg=None) -> None:
) )
self._ax_grouped_wrapper_logged = True self._ax_grouped_wrapper_logged = True
if y is None: if y is None:
return orig_forward(self, hidden_states) return orig_forward(self, hidden_states, *args, **kwargs)
return y, router_logits return y, router_logits
return _grouped_forward return _grouped_forward

View File

@@ -0,0 +1,228 @@
import sys
import types
import torch
import torch.nn as nn
from axolotl.kernels.moe import (
backends as moe_backends,
torch_grouped as torch_grouped_module,
)
from axolotl.monkeypatch import moe_grouped
class DummyExperts(nn.Module):
def __init__(self, layers):
super().__init__()
self.layers = nn.ModuleList(layers)
self.num_experts = len(layers)
def __getitem__(self, idx):
return self.layers[idx]
class DummyQwenMLP(nn.Module):
def __init__(self, idx: int, hidden: int, intermediate: int):
super().__init__()
self.gate_up_proj = nn.Linear(hidden, 2 * intermediate, bias=False)
self.down_proj = nn.Linear(intermediate, hidden, bias=False)
nn.init.constant_(self.gate_up_proj.weight, float(idx + 1))
nn.init.constant_(self.down_proj.weight, float((idx + 1) * 10))
class DummyQwenExpert(nn.Module):
def __init__(self, idx: int, hidden: int, intermediate: int):
super().__init__()
self.mlp = DummyQwenMLP(idx, hidden, intermediate)
def _make_transformers_stub(monkeypatch, block_cls):
# ensure we start from the original forward for each test
if block_cls is DummyMixtralBlock:
DummyMixtralBlock.forward = _DUMMY_MIXTRAL_ORIG_FORWARD
transformers_mod = types.ModuleType("transformers")
models_mod = types.ModuleType("transformers.models")
mixtral_mod = types.ModuleType("transformers.models.mixtral")
modeling_mixtral = types.ModuleType("transformers.models.mixtral.modeling_mixtral")
modeling_mixtral.MixtralSparseMoeBlock = block_cls
transformers_mod.models = models_mod
models_mod.mixtral = mixtral_mod
mixtral_mod.modeling_mixtral = modeling_mixtral
monkeypatch.setitem(sys.modules, "transformers", transformers_mod)
monkeypatch.setitem(sys.modules, "transformers.models", models_mod)
monkeypatch.setitem(sys.modules, "transformers.models.mixtral", mixtral_mod)
monkeypatch.setitem(
sys.modules,
"transformers.models.mixtral.modeling_mixtral",
modeling_mixtral,
)
def test_grouped_uses_per_expert_nested_modules(monkeypatch):
hidden = 4
intermediate = 2
num_experts = 2
experts = DummyExperts(
[DummyQwenExpert(i, hidden, intermediate) for i in range(num_experts)]
)
gate = nn.Linear(hidden, num_experts, bias=False)
nn.init.zeros_(gate.weight)
captured = []
def fake_grouped_mm(As, Bs):
captured.append([b.detach().clone() for b in Bs])
return [
torch.zeros(a.shape[0], b.shape[-1], device=a.device, dtype=a.dtype)
for a, b in zip(As, Bs, strict=False)
]
monkeypatch.setattr(torch_grouped_module, "_call_grouped_mm", fake_grouped_mm)
hidden_states = torch.randn(1, 2, hidden)
y, router_logits = torch_grouped_module.moe_ffn_forward_grouped(
hidden_states, gate, experts, top_k=2
)
assert y is not None
assert router_logits is not None
assert captured, "Grouped GEMM path should have been invoked"
first_call = captured[0]
expected0 = experts[0].mlp.gate_up_proj.weight.t()
expected1 = experts[1].mlp.gate_up_proj.weight.t()
assert torch.equal(first_call[0], expected0)
assert torch.equal(first_call[1], expected1)
assert not torch.equal(first_call[0], first_call[1])
class _DummyCfg:
moe_backend = "torch_grouped"
class DummyMixtralBlock(nn.Module):
def __init__(self):
super().__init__()
self.top_k = 1
self.gate = lambda x: x
self.experts = object()
self._calls = []
def forward(self, hidden_states: torch.Tensor, attention_mask=None):
self._calls.append((hidden_states, attention_mask))
tokens = hidden_states.shape[0] * hidden_states.shape[1]
router = torch.ones(
tokens, 2, device=hidden_states.device, dtype=hidden_states.dtype
)
return hidden_states + 5, router
_DUMMY_MIXTRAL_ORIG_FORWARD = DummyMixtralBlock.forward
def test_apply_grouped_forward_handles_args(monkeypatch):
_make_transformers_stub(monkeypatch, DummyMixtralBlock)
import axolotl.common.architectures as arch
original_map = arch.MOE_ARCH_BLOCK.copy()
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
for key in list(original_map.keys()):
if key != "mixtral":
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
monkeypatch.setattr(
moe_grouped,
"get_moe_backend_name",
lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
)
results = {}
def fake_grouped_forward(hidden_states, gate, experts, top_k):
results["called"] = True
router = torch.zeros(
hidden_states.shape[0] * hidden_states.shape[1],
2,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
return hidden_states + 1, router
monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
monkeypatch.setattr(
torch_grouped_module,
"moe_ffn_forward_grouped",
fake_grouped_forward,
)
cfg = _DummyCfg()
moe_grouped.apply_grouped_to_moe_blocks(cfg)
block = DummyMixtralBlock()
hidden_states = torch.ones(1, 2, 3)
mask = torch.zeros(1, 2)
out, router = block.forward(hidden_states, attention_mask=mask)
assert results.get("called") is True
assert torch.equal(out, hidden_states + 1)
assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
def test_apply_grouped_forward_fallback(monkeypatch):
_make_transformers_stub(monkeypatch, DummyMixtralBlock)
import axolotl.common.architectures as arch
original_map = arch.MOE_ARCH_BLOCK.copy()
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, "mixtral", "MixtralSparseMoeBlock")
for key in list(original_map.keys()):
if key != "mixtral":
monkeypatch.setitem(arch.MOE_ARCH_BLOCK, key, None)
monkeypatch.setattr(
moe_grouped,
"get_moe_backend_name",
lambda preferred=None: moe_backends.MOEBackend.TORCH_GROUPED,
)
monkeypatch.setattr(torch_grouped_module, "available", lambda: True)
monkeypatch.setattr(
torch_grouped_module,
"moe_ffn_forward_grouped",
lambda *args, **kwargs: (None, None),
)
cfg = _DummyCfg()
moe_grouped.apply_grouped_to_moe_blocks(cfg)
block = DummyMixtralBlock()
hidden_states = torch.ones(1, 2, 3)
mask = torch.zeros(1, 2)
out, router = block.forward(hidden_states, attention_mask=mask)
assert torch.equal(out, hidden_states + 5)
assert router.shape[0] == hidden_states.shape[0] * hidden_states.shape[1]
assert block._calls, "Original forward should have been invoked"
call_hidden, call_mask = block._calls[-1]
assert torch.equal(call_hidden, hidden_states)
assert torch.equal(call_mask, mask)
def test_get_moe_backend_name_prefers_probe(monkeypatch):
monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: True)
assert moe_backends.get_moe_backend_name() == moe_backends.MOEBackend.TORCH_GROUPED
def test_get_moe_backend_name_falls_back(monkeypatch):
warnings_captured = []
def fake_warn(msg):
warnings_captured.append(msg)
monkeypatch.setattr(moe_backends, "_probe_torch_grouped", lambda: False)
monkeypatch.setattr(moe_backends.warnings, "warn", fake_warn)
backend = moe_backends.get_moe_backend_name("torch_grouped")
assert backend == moe_backends.MOEBackend.NAIVE
assert warnings_captured, "Expected warning when torch_grouped unavailable"