TiledMLP support (#2865)

Wing Lian
2025-07-07 15:23:49 -04:00
committed by GitHub
parent 22d4a838dc
commit 9c0d7ee761
4 changed files with 92 additions and 0 deletions
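
Adds opt-in ALST TiledMLP support: a monkeypatch that swaps a model's MLP forward for DeepSpeed's TiledMLP, which runs the MLP in sequence-dimension shards to reduce activation memory for long-context training. It is wired into PatchManager behind new tiled_mlp / tiled_mlp_num_shards config fields, with a validator requiring a DeepSpeed ZeRO config.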


@@ -66,6 +66,7 @@ class PatchManager:
        self._apply_self_attention_lora_patch()
        self._apply_gemma3_conditional_generation_forward_patch()
        self._apply_sequence_parallel_patches()
        self._apply_tiled_mlp(self.cfg.model_config_type)

    def apply_post_model_load_patches(self, model: PreTrainedModel):
        """Apply patches that require the model instance."""
@@ -257,6 +258,12 @@ class PatchManager:
        patch_prepare_data_loader()
        patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)

    def _apply_tiled_mlp(self, model_type: str):
        """Patch the model's MLP forward with DeepSpeed's TiledMLP when enabled."""
        if self.cfg.tiled_mlp:
            from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp

            patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)

    def _patch_attention(self):
        """Apply attention-specific patches based on model type."""
        if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
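
The patch rebinds the MLP class's forward at class level, so it runs in the pre-model-load pass, before from_pretrained() instantiates the layers. A minimal sketch (not part of the diff) of the equivalent manual call, using "llama" as a stand-in for cfg.model_config_type:

from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp

# cfg_num_shards=None selects the automatic ceil(seqlen / hidden) heuristic
patch_tiled_mlp("llama", cfg_num_shards=None)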


@@ -0,0 +1,64 @@
"""Monkeypatch for Tiled MLP implementation"""
import math
import torch
import torch.distributed as dist
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
try:
# Dynamically import the module and MLP class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
if use_original_mlp:
mlp_forward = mlp_cls.forward
else:
def generic_mlp_forward(self_, hs):
return self_.down_proj(
self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
)
mlp_forward = torch.compile(generic_mlp_forward)
def tiled_mlp_forward(self, x):
input_shape = x.shape
seqlen = input_shape[-2]
hidden = input_shape[-1]
if cfg_num_shards is None:
num_shards = math.ceil(seqlen / hidden)
num_shards_tensor = torch.tensor(num_shards, device=x.device)
dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
num_shards = num_shards_tensor.item()
else:
num_shards = cfg_num_shards
compute_params = [
self.down_proj.weight,
self.gate_proj.weight,
self.up_proj.weight,
]
down_res = TiledMLP.apply(
mlp_forward,
self,
x,
num_shards,
compute_params,
)
return down_res
mlp_cls.forward = tiled_mlp_forward
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Could not import MLP class for model_type: {model_type}. "
f"Error: {str(e)}"
) from e
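
For intuition, a small worked example (illustrative values, not from the commit): the auto heuristic picks roughly one shard per hidden_size tokens, and the class name is derived by capitalizing each underscore-separated part of model_type:

import math

model_type = "llama"
cls_name = "".join(part.capitalize() for part in model_type.split("_")) + "MLP"
print(cls_name)  # LlamaMLP

seqlen, hidden = 131072, 4096      # hypothetical 128k-token input
print(math.ceil(seqlen / hidden))  # 32 shards when tiled_mlp_num_shards is unset

Model types whose MLP class doesn't follow this naming convention surface as the RuntimeError above.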


@@ -549,6 +549,20 @@ class AxolotlInputConfig(
        },
    )
    tiled_mlp: bool | None = Field(
        default=None,
        json_schema_extra={
            "description": "Whether to use ALST TiledMLP for memory-efficient long-context training"
        },
    )
    tiled_mlp_num_shards: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Number of shards to use for ALST TiledMLP. If unset, it is derived from seqlen / hidden_size"
        },
    )
    llama4_linearized_experts: bool | None = None
    deepspeed: str | dict[str, Any] | None = Field(
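
A hedged usage sketch showing the new options as the raw dict pydantic validates; these map 1:1 to keys in an axolotl YAML config, and the DeepSpeed JSON path is hypothetical:

cfg = {
    "tiled_mlp": True,           # enable ALST TiledMLP
    "tiled_mlp_num_shards": 8,   # optional; ceil(seqlen / hidden_size) when unset
    "deepspeed": "deepspeed_configs/zero3.json",  # required, see the validator below
}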


@@ -476,6 +476,13 @@ class TrainingValidationMixin:
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tiled_mlp_deepspeed(cls, data):
        # TiledMLP comes from deepspeed.runtime.sequence_parallel.ulysses_sp,
        # so a DeepSpeed ZeRO config must be present
        if data.get("tiled_mlp", False) and not data.get("deepspeed"):
            raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")
        return data


class LoRAValidationMixin:
    """Validation methods related to LoRA/QLoRA configuration."""