From 9c0d7ee761574d15970d5cc0c94096b17c30be0f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 7 Jul 2025 15:23:49 -0400
Subject: [PATCH] TiledMLP support (#2865)

---
 src/axolotl/loaders/patch_manager.py    |  7 +++
 src/axolotl/monkeypatch/tiled_mlp.py    | 64 +++++++++++++++++++++++++
 src/axolotl/utils/schemas/config.py     | 14 ++++++
 src/axolotl/utils/schemas/validation.py |  7 +++
 4 files changed, 92 insertions(+)
 create mode 100644 src/axolotl/monkeypatch/tiled_mlp.py

diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index 221a5fce8..48ee78cbc 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -66,6 +66,7 @@ class PatchManager:
         self._apply_self_attention_lora_patch()
         self._apply_gemma3_conditional_generation_forward_patch()
         self._apply_sequence_parallel_patches()
+        self._apply_tiled_mlp(self.cfg.model_config_type)
 
     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
@@ -257,6 +258,12 @@ class PatchManager:
         patch_prepare_data_loader()
         patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
 
+    def _apply_tiled_mlp(self, model_type: str):
+        if self.cfg.tiled_mlp:
+            from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
+
+            patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)
+
     def _patch_attention(self):
         """Apply attention-specific patches based on model type."""
         if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
diff --git a/src/axolotl/monkeypatch/tiled_mlp.py b/src/axolotl/monkeypatch/tiled_mlp.py
new file mode 100644
index 000000000..4862ae78c
--- /dev/null
+++ b/src/axolotl/monkeypatch/tiled_mlp.py
@@ -0,0 +1,64 @@
+"""Monkeypatch for Tiled MLP implementation"""
+
+import math
+
+import torch
+import torch.distributed as dist
+
+
+def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
+    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
+
+    try:
+        # Dynamically import the module and MLP class
+        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
+        model_cls_prefix = "".join(
+            [part.capitalize() for part in model_type.split("_")]
+        )
+        module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
+        mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
+
+        if use_original_mlp:
+            mlp_forward = mlp_cls.forward
+        else:
+
+            def generic_mlp_forward(self_, hs):
+                return self_.down_proj(
+                    self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
+                )
+
+            mlp_forward = torch.compile(generic_mlp_forward)
+
+        def tiled_mlp_forward(self, x):
+            input_shape = x.shape
+            seqlen = input_shape[-2]
+            hidden = input_shape[-1]
+            if cfg_num_shards is None:
+                num_shards = math.ceil(seqlen / hidden)
+                num_shards_tensor = torch.tensor(num_shards, device=x.device)
+                dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
+                num_shards = num_shards_tensor.item()
+            else:
+                num_shards = cfg_num_shards
+
+            compute_params = [
+                self.down_proj.weight,
+                self.gate_proj.weight,
+                self.up_proj.weight,
+            ]
+
+            down_res = TiledMLP.apply(
+                mlp_forward,
+                self,
+                x,
+                num_shards,
+                compute_params,
+            )
+            return down_res
+
+        mlp_cls.forward = tiled_mlp_forward
+    except (ImportError, AttributeError) as e:
+        raise RuntimeError(
+            f"Could not import MLP class for model_type: {model_type}. "
+            f"Error: {str(e)}"
+        ) from e
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index d39b70219..94df7cde8 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -549,6 +549,20 @@ class AxolotlInputConfig(
         },
     )
 
+    tiled_mlp: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to use ALST tiled MLP for memory-efficient long context"
+        },
+    )
+
+    tiled_mlp_num_shards: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of shards to use for ALST tiled MLP. If unset, it will be set based on seqlen/hidden_size"
+        },
+    )
+
     llama4_linearized_experts: bool | None = None
 
     deepspeed: str | dict[str, Any] | None = Field(
diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py
index 3a0c9cc9f..af1341cda 100644
--- a/src/axolotl/utils/schemas/validation.py
+++ b/src/axolotl/utils/schemas/validation.py
@@ -476,6 +476,13 @@ class TrainingValidationMixin:
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_tiled_mlp_deepspeed(cls, data):
+        if data.get("tiled_mlp", False) and not data.get("deepspeed"):
+            raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")
+        return data
+
 
 class LoRAValidationMixin:
     """Validation methods related to LoRA/QLoRA configuration."""