patch
This commit is contained in:
@@ -12,6 +12,7 @@ import transformers
|
|||||||
from transformers import PretrainedConfig, PreTrainedModel
|
from transformers import PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
|
from axolotl.monkeypatch.moe_grouped import apply_grouped_to_moe_blocks
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
patch_for_multipack,
|
patch_for_multipack,
|
||||||
@@ -57,6 +58,8 @@ class PatchManager:
|
|||||||
self._apply_fsdp_patches()
|
self._apply_fsdp_patches()
|
||||||
self._apply_adapter_patches()
|
self._apply_adapter_patches()
|
||||||
self._apply_model_specific_patches()
|
self._apply_model_specific_patches()
|
||||||
|
# Apply MoE grouped GEMM patches (cfg.moe_backend)
|
||||||
|
apply_grouped_to_moe_blocks(self.cfg)
|
||||||
self._apply_fp8_patches()
|
self._apply_fp8_patches()
|
||||||
self._apply_flash_attention_peft_patches()
|
self._apply_flash_attention_peft_patches()
|
||||||
self._apply_gradient_checkpointing_patches()
|
self._apply_gradient_checkpointing_patches()
|
||||||
|
|||||||
96
src/axolotl/monkeypatch/moe_grouped.py
Normal file
96
src/axolotl/monkeypatch/moe_grouped.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||||
|
from axolotl.kernels.moe.backends import MOEBackend, get_moe_backend_name
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_block_forward(block_cls, grouped_fn):
|
||||||
|
"""Replace block_cls.forward with grouped_fn preserving signature."""
|
||||||
|
setattr(block_cls, "forward", grouped_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_grouped_to_moe_blocks(cfg=None) -> None:
|
||||||
|
"""
|
||||||
|
Attempt to patch all known MoE block classes to use the torch_grouped backend
|
||||||
|
when cfg.moe_backend resolves to 'torch_grouped' and the op is available.
|
||||||
|
Falls back to original forwards otherwise.
|
||||||
|
"""
|
||||||
|
preferred = getattr(cfg, "moe_backend", None) if cfg is not None else None
|
||||||
|
backend = get_moe_backend_name(preferred)
|
||||||
|
if backend != MOEBackend.TORCH_GROUPED:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
from axolotl.kernels.moe import torch_grouped as _tg
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
if not _tg.available():
|
||||||
|
warnings.warn("torch_grouped requested but unavailable; skipping MoE patches")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Map of architecture key to (modeling module path, class name or list of class names)
|
||||||
|
model_mods = {
|
||||||
|
"mixtral": (
|
||||||
|
"transformers.models.mixtral.modeling_mixtral",
|
||||||
|
MOE_ARCH_BLOCK.get("mixtral"),
|
||||||
|
),
|
||||||
|
"qwen2_moe": (
|
||||||
|
"transformers.models.qwen2_moe.modeling_qwen2_moe",
|
||||||
|
MOE_ARCH_BLOCK.get("qwen2_moe"),
|
||||||
|
),
|
||||||
|
"qwen3_moe": (
|
||||||
|
"transformers.models.qwen3_moe.modeling_qwen3_moe",
|
||||||
|
MOE_ARCH_BLOCK.get("qwen3_moe"),
|
||||||
|
),
|
||||||
|
"jamba": (
|
||||||
|
"transformers.models.jamba.modeling_jamba",
|
||||||
|
MOE_ARCH_BLOCK.get("jamba"),
|
||||||
|
),
|
||||||
|
"deepseek_v2": (
|
||||||
|
"transformers.models.deepseek_v2.modeling_deepseek_v2",
|
||||||
|
MOE_ARCH_BLOCK.get("deepseek_v2"),
|
||||||
|
),
|
||||||
|
# Others may not follow standard paths; best-effort import
|
||||||
|
"dbrx": ("transformers.models.dbrx.modeling_dbrx", MOE_ARCH_BLOCK.get("dbrx")),
|
||||||
|
"jetmoe": (
|
||||||
|
"transformers.models.jetmoe.modeling_jetmoe",
|
||||||
|
MOE_ARCH_BLOCK.get("jetmoe"),
|
||||||
|
),
|
||||||
|
"gpt_oss": (
|
||||||
|
"transformers.models.gpt_oss.modeling_gpt_oss",
|
||||||
|
MOE_ARCH_BLOCK.get("gpt_oss"),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def make_grouped_forward(orig_forward):
|
||||||
|
def _grouped_forward(self, hidden_states: torch.Tensor):
|
||||||
|
bsz, seqlen, hdim = hidden_states.shape
|
||||||
|
y, router_logits = _tg.moe_ffn_forward_grouped(
|
||||||
|
hidden_states, self.gate, self.experts, self.top_k
|
||||||
|
)
|
||||||
|
if y is None:
|
||||||
|
return orig_forward(self, hidden_states)
|
||||||
|
return y, router_logits
|
||||||
|
|
||||||
|
return _grouped_forward
|
||||||
|
|
||||||
|
for key, (mod_path, cls_names) in model_mods.items():
|
||||||
|
if not cls_names:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
modeling = importlib.import_module(mod_path)
|
||||||
|
names = cls_names if isinstance(cls_names, list) else [cls_names]
|
||||||
|
for name in names:
|
||||||
|
if not hasattr(modeling, name):
|
||||||
|
continue
|
||||||
|
block_cls = getattr(modeling, name)
|
||||||
|
orig_forward = getattr(block_cls, "forward", None)
|
||||||
|
if orig_forward is None:
|
||||||
|
continue
|
||||||
|
_patch_block_forward(block_cls, make_grouped_forward(orig_forward))
|
||||||
|
except Exception as e:
|
||||||
|
# Best effort; log and skip this entry
|
||||||
|
warnings.warn(f"Skipping MoE patch for {key}: {e}")
|
||||||
Reference in New Issue
Block a user