Files
axolotl/src/axolotl/integrations/kernels/plugin.py
Wing Lian 3738978394 Add support for batched_mm, grouped_mm and scattermoe for MoE models (#3377)
* kernels plugin for moe for v5

* add support for native batched_mm or grouped_mm
2026-01-29 14:25:47 -05:00

62 lines
2.1 KiB
Python

from kernels import (
LayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
class KernelsPlugin(BasePlugin):
    """Plugin that wires the ScatterMoE hub kernel into sparse MoE blocks.

    When ``cfg.use_scattermoe`` is truthy, ``pre_model_load`` registers a
    kernel mapping for the ``HFScatterMoEParallelExperts`` layer name and then
    passes the model's sparse MoE block class to
    ``replace_kernel_forward_from_hub`` under that same layer name.
    """

    def get_input_args(self):
        """Return the dotted path of the args class associated with this plugin."""
        return "axolotl.integrations.kernels.KernelsArgs"

    def pre_model_load(self, cfg):
        """Register the kernel mapping and kernelize the MoE block class.

        Does nothing unless ``cfg.use_scattermoe`` is truthy.

        Args:
            cfg: Config object; ``use_scattermoe`` and ``model_config_type``
                are read from it.
        """
        if not cfg.use_scattermoe:
            return
        self._register_kernels()
        self._kernelize_model(cfg.model_config_type)

    def _register_kernels(self):
        """Register the ScatterMoE layer repository for training and inference."""
        # One separate LayerRepository per mode, both pointing at the same
        # repo/layer, keyed in TRAINING-then-INFERENCE order.
        cuda_layers = {
            mode: LayerRepository(
                repo_id="axolotl-ai-co/scattermoe",
                layer_name="HFScatterMoEGatedMLP",
            )
            for mode in (Mode.TRAINING, Mode.INFERENCE)
        }
        register_kernel_mapping(
            {"HFScatterMoEParallelExperts": {"cuda": cuda_layers}}
        )

    def _kernelize_model(self, model_type: str):
        """Pass the model's sparse MoE block class to the hub kernel hook.

        Args:
            model_type: Model type string; ``"olmoe"`` is imported directly,
                any other value is resolved through ``get_model_moe_block``.

        Raises:
            ValueError: If resolving or kernelizing the block class fails for
                a model type other than ``"olmoe"``.
        """
        if model_type == "olmoe":
            # Imported lazily so transformers' olmoe module is only loaded
            # when this model type is actually requested.
            from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock

            replace_kernel_forward_from_hub(
                OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
            )
            return

        try:
            moe_block_cls = get_model_moe_block(model_type)
            replace_kernel_forward_from_hub(
                moe_block_cls, "HFScatterMoEParallelExperts"
            )
        except Exception as err:
            raise ValueError(f"Unsupported model type: {model_type}") from err
def get_model_moe_block(model_type: str):
    """Return the ``<Prefix>SparseMoeBlock`` class for a transformers model type.

    The class is looked up in
    ``transformers.models.<model_type>.modeling_<model_type>``, with the class
    name prefix taken from ``get_causal_lm_model_cls_prefix(model_type)``.

    Args:
        model_type: Model type string used to build the module path.

    Returns:
        The ``<Prefix>SparseMoeBlock`` attribute of the modeling module.

    Raises:
        ImportError: If the modeling module cannot be imported.
        AttributeError: If the module has no ``<Prefix>SparseMoeBlock``.
    """
    # Function-scope import keeps this helper self-contained.
    import importlib

    model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
    # import_module returns the leaf module directly, so the ``fromlist``
    # workaround required by the ``__import__`` builtin is unnecessary.
    module = importlib.import_module(
        f"transformers.models.{model_type}.modeling_{model_type}"
    )
    return getattr(module, f"{model_cls_prefix}SparseMoeBlock")