diff --git a/src/axolotl/kernels/moe/backends.py b/src/axolotl/kernels/moe/backends.py
index 49f365983..dee99aa84 100644
--- a/src/axolotl/kernels/moe/backends.py
+++ b/src/axolotl/kernels/moe/backends.py
@@ -29,7 +29,9 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
     try:
         selected = MOEBackend(choice)
     except ValueError:
-        warnings.warn(f"Unknown moe backend '{choice}', falling back to auto")
+        warnings.warn(
+            f"Unknown moe backend '{choice}', falling back to auto", stacklevel=2
+        )
         selected = MOEBackend.AUTO
 
     if selected == MOEBackend.AUTO:
@@ -38,7 +40,8 @@ def get_moe_backend_name(preferred: str | None = None) -> MOEBackend:
         return MOEBackend.NAIVE
     if selected == MOEBackend.TORCH_GROUPED and not _probe_torch_grouped():
         warnings.warn(
-            "torch_grouped requested but torch>=2.8 not detected; falling back to naive"
+            "torch_grouped requested but torch>=2.8 not detected; falling back to naive",
+            stacklevel=2,
         )
         return MOEBackend.NAIVE
     return selected
diff --git a/src/axolotl/kernels/moe/torch_grouped.py b/src/axolotl/kernels/moe/torch_grouped.py
index f712eca37..aea1063b7 100644
--- a/src/axolotl/kernels/moe/torch_grouped.py
+++ b/src/axolotl/kernels/moe/torch_grouped.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import List, Optional, Sequence, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -24,14 +24,31 @@ def available() -> bool:
 
 
 def _stack_weights(
-    experts: Sequence[torch.nn.Module], names: Tuple[str, ...]
+    experts_module,
+    names: Tuple[str, ...],
+    *,
+    key: str,
+    dtype: torch.dtype,
+    device: torch.device,
 ) -> torch.Tensor:
-    stacked: List[torch.Tensor] = []
-    for expert in experts:
-        mod = getattr(expert, "mlp", getattr(expert, "ffn", expert))
+    attr = f"_ax_grouped_{key}"
+    cached = getattr(experts_module, attr, None)
+    if cached is not None and cached.dtype == dtype and cached.device == device:
+        return cached
+
+    tensors: List[torch.Tensor] = []
+    for exp in experts_module:
+        mod = getattr(exp, "mlp", getattr(exp, "ffn", exp))
         parts = [getattr(mod, name).weight.t() for name in names]
-        stacked.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
-    return torch.stack(stacked, dim=0)
+        tensors.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=-1))
+
+    stacked = (
+        torch.stack(tensors, dim=0)
+        .to(device=device, dtype=dtype, non_blocking=True)
+        .contiguous()
+    )
+    setattr(experts_module, attr, stacked)
+    return stacked
 
 
 def _call_grouped_mm(
@@ -40,19 +57,22 @@
     if not As:
         return []
 
-    As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
-    Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
-    device = As2[0].device
-    offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
-    Y_cat = torch.ops.aten._grouped_mm(
-        torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs
-    )
-    outs: List[torch.Tensor] = []
-    start = 0
-    for m in offs.tolist():
-        outs.append(Y_cat[start : start + m])
-        start += m
-    return outs
+    try:
+        As2 = [a.to(dtype).contiguous().view(a.shape[0], a.shape[1]) for a in As]
+        Bs2 = [b.to(dtype).contiguous().view(b.shape[0], b.shape[1]) for b in Bs]
+        device = As2[0].device
+        offs = torch.tensor([a.shape[0] for a in As2], device=device, dtype=torch.int32)
+        Y_cat = torch.ops.aten._grouped_mm(
+            torch.cat(As2, dim=0), torch.stack(Bs2, dim=0), offs
+        )
+        outs: List[torch.Tensor] = []
+        start = 0
+        for m in offs.tolist():
+            outs.append(Y_cat[start : start + m])
+            start += m
+        return outs
+    except RuntimeError:
+        return None
 
 
 def moe_ffn_forward_grouped(
@@ -77,22 +97,27 @@
     topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1, sorted=False)
     topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
-    experts = list(experts_module)
-    sample = getattr(experts[0], "mlp", getattr(experts[0], "ffn", experts[0]))
+    sample = getattr(
+        experts_module[0], "mlp", getattr(experts_module[0], "ffn", experts_module[0])
+    )
 
     if hasattr(sample, "w1") and hasattr(sample, "w3") and hasattr(sample, "w2"):
-        w13 = _stack_weights(experts, ("w1", "w3")).to(
-            device=device, dtype=expert_dtype
+        w13 = _stack_weights(
+            experts_module, ("w1", "w3"), key="w13", dtype=expert_dtype, device=device
+        )
+        w2 = _stack_weights(
+            experts_module, ("w2",), key="w2", dtype=expert_dtype, device=device
         )
-        w2 = _stack_weights(experts, ("w2",)).to(device=device, dtype=expert_dtype)
     else:
         names13 = (
             ("gate_up_proj",)
             if hasattr(sample, "gate_up_proj")
             else ("up_proj", "gate_proj")
         )
-        w13 = _stack_weights(experts, names13).to(device=device, dtype=expert_dtype)
-        w2 = _stack_weights(experts, ("down_proj",)).to(
-            device=device, dtype=expert_dtype
+        w13 = _stack_weights(
+            experts_module, names13, key="w13", dtype=expert_dtype, device=device
+        )
+        w2 = _stack_weights(
+            experts_module, ("down_proj",), key="w2", dtype=expert_dtype, device=device
         )
 
     flat_idx = topk_idx.view(-1)
@@ -101,7 +126,7 @@
     as_list: List[torch.Tensor] = []
     bs_list: List[torch.Tensor] = []
     slices: List[Tuple[int, torch.Tensor]] = []
-    for i in range(len(experts)):
+    for i, _ in enumerate(experts_module):
         sel = flat_idx == i
         if sel.any():
             as_list.append(x_rep[sel])
diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py
index 2fb77869e..e2ac6c5d1 100644
--- a/src/axolotl/monkeypatch/mixtral/__init__.py
+++ b/src/axolotl/monkeypatch/mixtral/__init__.py
@@ -33,7 +33,8 @@ def patch_mixtral_moe_forward_zero3(cfg=None) -> None:
         and not _moe_backends._probe_torch_grouped()
     ):
         warnings.warn(
-            "torch_grouped selected but not available; falling back to naive"
+            "torch_grouped selected but not available; falling back to naive",
+            stacklevel=2,
        )
 
     routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
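
Note on the new fallback contract: `_call_grouped_mm` now returns `None` instead of propagating a `RuntimeError` from `aten._grouped_mm`, so its caller must supply a dense fallback. A minimal sketch of that handling, assuming the truncated signature accepts `(As, Bs, dtype)` as its body suggests; the per-expert loop below is illustrative and not part of this patch:

    # Sketch only: consuming the Optional[List[torch.Tensor]] contract that
    # this patch introduces for _call_grouped_mm.
    outs = _call_grouped_mm(as_list, bs_list, expert_dtype)
    if outs is None:
        # aten._grouped_mm raised (e.g. unsupported dtype or arch); recover
        # with one plain matmul per expert, matching the grouped output
        # slice-for-slice: (m_i, K) @ (K, N) -> (m_i, N).
        outs = [
            a.to(expert_dtype) @ b.to(expert_dtype)
            for a, b in zip(as_list, bs_list)
        ]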
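
A self-contained sketch of the weight layout that `_stack_weights` builds and now caches on the experts module: each expert `nn.Linear` stores `weight` as `(out_features, in_features)`, so transposing before the concat/stack yields a `(num_experts, hidden, 2 * ffn)` tensor that token activations can right-multiply without further transposes. The expert names follow the patch's `w1`/`w3` branch; the sizes are made up for illustration:

    import torch
    import torch.nn as nn

    E, hidden, ffn = 4, 8, 16
    experts = nn.ModuleList(
        nn.ModuleDict(
            {
                "w1": nn.Linear(hidden, ffn, bias=False),
                "w3": nn.Linear(hidden, ffn, bias=False),
            }
        )
        for _ in range(E)
    )
    # weight.t() flips (out, in) -> (in, out); cat along dim=-1 fuses w1 and w3.
    w13 = torch.stack(
        [torch.cat([e["w1"].weight.t(), e["w3"].weight.t()], dim=-1) for e in experts],
        dim=0,
    )
    assert w13.shape == (E, hidden, 2 * ffn)  # x @ w13[i] needs no transpose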