From de4344a56e017658669c8d7f728f579a10879bca Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 15 Sep 2025 23:15:20 -0400
Subject: [PATCH] patch

---
 src/axolotl/loaders/patch_manager.py   |  3 +
 src/axolotl/monkeypatch/moe_grouped.py | 96 ++++++++++++++++++++++++++
 2 files changed, 99 insertions(+)
 create mode 100644 src/axolotl/monkeypatch/moe_grouped.py

diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 71933a9dd..eed35a027 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -12,6 +12,7 @@ import transformers
 from transformers import PretrainedConfig, PreTrainedModel
 
 from axolotl.integrations.base import PluginManager
+from axolotl.monkeypatch.moe_grouped import apply_grouped_to_moe_blocks
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
     patch_for_multipack,
@@ -57,6 +58,8 @@ class PatchManager:
         self._apply_fsdp_patches()
         self._apply_adapter_patches()
         self._apply_model_specific_patches()
+        # Apply MoE grouped GEMM patches (cfg.moe_backend)
+        apply_grouped_to_moe_blocks(self.cfg)
         self._apply_fp8_patches()
         self._apply_flash_attention_peft_patches()
         self._apply_gradient_checkpointing_patches()
diff --git a/src/axolotl/monkeypatch/moe_grouped.py b/src/axolotl/monkeypatch/moe_grouped.py
new file mode 100644
index 000000000..24408d03a
--- /dev/null
+++ b/src/axolotl/monkeypatch/moe_grouped.py
@@ -0,0 +1,96 @@
+import warnings
+
+import torch
+
+from axolotl.common.architectures import MOE_ARCH_BLOCK
+from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
+
+
+def _patch_block_forward(block_cls, grouped_fn):
+    """Replace block_cls.forward with grouped_fn preserving signature."""
+    setattr(block_cls, "forward", grouped_fn)
+
+
+def apply_grouped_to_moe_blocks(cfg=None) -> None:
+    """
+    Attempt to patch all known MoE block classes to use the torch_grouped backend
+    when cfg.moe_backend resolves to 'torch_grouped' and the op is available.
+    Falls back to original forwards otherwise.
+    """
+    preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
+    backend = get_moe_backend_name(preferred)
+    if backend != MOEBackend.TORCH_GROUPED:
+        return
+    try:
+        from axolotl.kernels.moe import torch_grouped as _tg
+    except Exception:
+        return
+    if not _tg.available():
+        warnings.warn("torch_grouped requested but unavailable; skipping MoE patches")
+        return
+
+    # Map of architecture key to (modeling module path, class name or list of class names)
+    model_mods = {
+        "mixtral": (
+            "transformers.models.mixtral.modeling_mixtral",
+            MOE_ARCH_BLOCK.get("mixtral"),
+        ),
+        "qwen2_moe": (
+            "transformers.models.qwen2_moe.modeling_qwen2_moe",
+            MOE_ARCH_BLOCK.get("qwen2_moe"),
+        ),
+        "qwen3_moe": (
+            "transformers.models.qwen3_moe.modeling_qwen3_moe",
+            MOE_ARCH_BLOCK.get("qwen3_moe"),
+        ),
+        "jamba": (
+            "transformers.models.jamba.modeling_jamba",
+            MOE_ARCH_BLOCK.get("jamba"),
+        ),
+        "deepseek_v2": (
+            "transformers.models.deepseek_v2.modeling_deepseek_v2",
+            MOE_ARCH_BLOCK.get("deepseek_v2"),
+        ),
+        # Others may not follow standard paths; best-effort import
+        "dbrx": ("transformers.models.dbrx.modeling_dbrx", MOE_ARCH_BLOCK.get("dbrx")),
+        "jetmoe": (
+            "transformers.models.jetmoe.modeling_jetmoe",
+            MOE_ARCH_BLOCK.get("jetmoe"),
+        ),
+        "gpt_oss": (
+            "transformers.models.gpt_oss.modeling_gpt_oss",
+            MOE_ARCH_BLOCK.get("gpt_oss"),
+        ),
+    }
+
+    def make_grouped_forward(orig_forward):
+        def _grouped_forward(self, hidden_states: torch.Tensor):
+            bsz, seqlen, hdim = hidden_states.shape
+            y, router_logits = _tg.moe_ffn_forward_grouped(
+                hidden_states, self.gate, self.experts, self.top_k
+            )
+            if y is None:
+                return orig_forward(self, hidden_states)
+            return y, router_logits
+
+        return _grouped_forward
+
+    for key, (mod_path, cls_names) in model_mods.items():
+        if not cls_names:
+            continue
+        try:
+            import importlib
+
+            modeling = importlib.import_module(mod_path)
+            names = cls_names if isinstance(cls_names, list) else [cls_names]
+            for name in names:
+                if not hasattr(modeling, name):
+                    continue
+                block_cls = getattr(modeling, name)
+                orig_forward = getattr(block_cls, "forward", None)
+                if orig_forward is None:
+                    continue
+                _patch_block_forward(block_cls, make_grouped_forward(orig_forward))
+        except Exception as e:
+            # Best effort; log and skip this entry
+            warnings.warn(f"Skipping MoE patch for {key}: {e}")