diff --git a/src/axolotl/integrations/kernels/__init__.py b/src/axolotl/integrations/kernels/__init__.py new file mode 100644 index 000000000..d87942435 --- /dev/null +++ b/src/axolotl/integrations/kernels/__init__.py @@ -0,0 +1,7 @@ +from .args import KernelsArgs +from .plugin import KernelsPlugin + +__all__ = [ + "KernelsArgs", + "KernelsPlugin", +] diff --git a/src/axolotl/integrations/kernels/args.py b/src/axolotl/integrations/kernels/args.py new file mode 100644 index 000000000..66d6b6d53 --- /dev/null +++ b/src/axolotl/integrations/kernels/args.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, model_validator + +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + + +class KernelsArgs(BaseModel): + use_scattermoe: bool | None = True + + @model_validator(mode="before") + @classmethod + def check_use_kernels(cls, data): + if data.get("use_kernels") is not True: + LOG.warning( + "`use_kernels` must be set to True to use this. Automatically setting it to True." + ) + data["use_kernels"] = True + + return data + + @model_validator(mode="before") + @classmethod + def check_experts_implementation(cls, data): + experts_implementation = data.get("experts_implementation") + if experts_implementation is None: + # transformers may default to batched_mm when unset + data["experts_implementation"] = "eager" + elif experts_implementation != "eager": + LOG.warning( + "`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'." + ) + data["experts_implementation"] = "eager" + + return data diff --git a/src/axolotl/integrations/kernels/plugin.py b/src/axolotl/integrations/kernels/plugin.py new file mode 100644 index 000000000..c7fb79ff6 --- /dev/null +++ b/src/axolotl/integrations/kernels/plugin.py @@ -0,0 +1,61 @@ +from kernels import ( + LayerRepository, + Mode, + register_kernel_mapping, + replace_kernel_forward_from_hub, +) + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix + + +class KernelsPlugin(BasePlugin): + def get_input_args(self): + return "axolotl.integrations.kernels.KernelsArgs" + + def pre_model_load(self, cfg): + if cfg.use_scattermoe: + self._register_kernels() + self._kernelize_model(cfg.model_config_type) + + def _register_kernels(self): + register_kernel_mapping( + { + "HFScatterMoEParallelExperts": { + "cuda": { + Mode.TRAINING: LayerRepository( + repo_id="axolotl-ai-co/scattermoe", + layer_name="HFScatterMoEGatedMLP", + ), + Mode.INFERENCE: LayerRepository( + repo_id="axolotl-ai-co/scattermoe", + layer_name="HFScatterMoEGatedMLP", + ), + }, + } + } + ) + + def _kernelize_model(self, model_type: str): + if model_type == "olmoe": + from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock + + replace_kernel_forward_from_hub( + OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts" + ) + else: + try: + model_moe_cls = get_model_moe_block(model_type) + replace_kernel_forward_from_hub( + model_moe_cls, "HFScatterMoEParallelExperts" + ) + except Exception as err: + raise ValueError(f"Unsupported model type: {model_type}") from err + + +def get_model_moe_block(model_type: str): + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type) + module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"]) + model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock") + return model_cls diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 0133148eb..75684c1ae 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -225,6 +225,7 @@ class ModelLoader: ): self.model = self.model.merge_and_unload() + self._configure_experts_implementation() self._apply_activation_checkpointing() self._resize_token_embeddings() self._adjust_model_config() @@ -232,6 +233,10 @@ class ModelLoader: self._configure_qat() log_gpu_memory_usage(LOG, "Memory usage after model load", 0) + def _configure_experts_implementation(self): + if self.cfg.experts_implementation is not None: + self.model.set_experts_implementation(self.cfg.experts_implementation) + def _apply_activation_checkpointing(self): if self.cfg.activation_offloading is True: from axolotl.core.trainers.mixins.activation_checkpointing import ( diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 3621c0d89..653773273 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -619,6 +619,13 @@ class AxolotlInputConfig( }, ) + experts_implementation: str | None = Field( + default=None, + json_schema_extra={ + "description": "Which experts implementation to use for MoE models," + }, + ) + scaling_softmax: bool | None = Field( default=None, json_schema_extra={