TiledMLP support (#2865)
@@ -66,6 +66,7 @@ class PatchManager:
         self._apply_self_attention_lora_patch()
         self._apply_gemma3_conditional_generation_forward_patch()
         self._apply_sequence_parallel_patches()
+        self._apply_tiled_mlp(self.cfg.model_config_type)

     def apply_post_model_load_patches(self, model: PreTrainedModel):
         """Apply patches that require the model instance."""
@@ -257,6 +258,12 @@ class PatchManager:
         patch_prepare_data_loader()
         patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)

+    def _apply_tiled_mlp(self, model_type: str):
+        if self.cfg.tiled_mlp:
+            from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
+
+            patch_tiled_mlp(model_type, cfg_num_shards=self.cfg.tiled_mlp_num_shards)
+
     def _patch_attention(self):
         """Apply attention-specific patches based on model type."""
         if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
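The new _apply_tiled_mlp helper above only wires the patch in when cfg.tiled_mlp is set. The TiledMLP it ultimately delegates to (imported in the new file below from DeepSpeed's ulysses_sp module, part of ALST, Arctic Long Sequence Training) runs the MLP one sequence shard at a time so that only a single shard's gate/up activations are materialized at once. A forward-only sketch of that tiling idea, written under the assumption that the DeepSpeed class behaves this way; the real TiledMLP is an autograd Function that also handles recomputation and weight-gradient accumulation in the backward pass:

import torch

def tiled_forward_sketch(mlp_forward, module, x, num_shards):
    # Split along the sequence dimension (dim=-2) and run the MLP per shard,
    # concatenating the per-shard outputs back into the full sequence.
    shards = torch.tensor_split(x, num_shards, dim=-2)
    return torch.cat([mlp_forward(module, shard) for shard in shards], dim=-2)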
src/axolotl/monkeypatch/tiled_mlp.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+"""Monkeypatch for Tiled MLP implementation"""
+
+import math
+
+import torch
+import torch.distributed as dist
+
+
+def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
+    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
+
+    try:
+        # Dynamically import the module and MLP class
+        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
+        model_cls_prefix = "".join(
+            [part.capitalize() for part in model_type.split("_")]
+        )
+        module = __import__(module_path, fromlist=[f"{model_cls_prefix}MLP"])
+        mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
+
+        if use_original_mlp:
+            mlp_forward = mlp_cls.forward
+        else:
+
+            def generic_mlp_forward(self_, hs):
+                return self_.down_proj(
+                    self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
+                )
+
+            mlp_forward = torch.compile(generic_mlp_forward)
+
+        def tiled_mlp_forward(self, x):
+            input_shape = x.shape
+            seqlen = input_shape[-2]
+            hidden = input_shape[-1]
+            if cfg_num_shards is None:
+                num_shards = math.ceil(seqlen / hidden)
+                num_shards_tensor = torch.tensor(num_shards, device=x.device)
+                dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
+                num_shards = num_shards_tensor.item()
+            else:
+                num_shards = cfg_num_shards
+
+            compute_params = [
+                self.down_proj.weight,
+                self.gate_proj.weight,
+                self.up_proj.weight,
+            ]
+
+            down_res = TiledMLP.apply(
+                mlp_forward,
+                self,
+                x,
+                num_shards,
+                compute_params,
+            )
+            return down_res
+
+        mlp_cls.forward = tiled_mlp_forward
+    except (ImportError, AttributeError) as e:
+        raise RuntimeError(
+            f"Could not import MLP class for model_type: {model_type}. "
+            f"Error: {str(e)}"
+        ) from e
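Two behaviours of the new module are worth spelling out. The target MLP class is resolved purely from model_type, so only architectures whose class follows the {Prefix}MLP naming and the gate/up/down projection layout are patchable; anything else falls into the RuntimeError branch. And when tiled_mlp_num_shards is unset, the shard count defaults to ceil(seqlen / hidden_size), reduced with MAX across ranks so every rank tiles identically. A small illustration with example values (the model type and shapes below are assumptions, not taken from the diff):

import math

# Class-name derivation used by patch_tiled_mlp: "llama" -> LlamaMLP, "qwen2" -> Qwen2MLP.
model_type = "llama"
prefix = "".join(part.capitalize() for part in model_type.split("_"))
print(f"transformers.models.{model_type}.modeling_{model_type}.{prefix}MLP")
# -> transformers.models.llama.modeling_llama.LlamaMLP

# Default shard count when tiled_mlp_num_shards is unset (example shapes):
seqlen, hidden_size = 32768, 4096
print(math.ceil(seqlen / hidden_size))  # -> 8 shards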
@@ -549,6 +549,20 @@ class AxolotlInputConfig(
         },
     )

+    tiled_mlp: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to use ALST tiled mlp for memory efficient long context"
+        },
+    )
+
+    tiled_mlp_num_shards: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Number of shards to use for ALST tiled mlp. If unset, it will be set based on seqlen/hidden_size"
+        },
+    )
+
     llama4_linearized_experts: bool | None = None

     deepspeed: str | dict[str, Any] | None = Field(
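Both fields are opt-in and pair with the DeepSpeed requirement enforced by the validator below. A hypothetical config fragment, written as a Python dict for illustration (the DeepSpeed config path is an assumption, not taken from the diff):

# Hypothetical axolotl config values exercising the new options.
cfg = {
    "tiled_mlp": True,          # enable ALST tiled MLP
    "tiled_mlp_num_shards": 8,  # optional; defaults to ceil(seqlen / hidden_size)
    "deepspeed": "deepspeed_configs/zero3_bf16.json",  # required by the new validator
}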
@@ -476,6 +476,13 @@ class TrainingValidationMixin:

         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_tiled_mlp_deepspeed(cls, data):
+        if data.get("tiled_mlp", False) and not data.get("deepspeed"):
+            raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")
+        return data
+

 class LoRAValidationMixin:
     """Validation methods related to LoRA/QLoRA configuration."""
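Because the check runs in a mode="before" validator, a config that enables tiled_mlp without any deepspeed setting fails during config validation rather than mid-training. Illustrative inputs (values are examples, not from the diff):

# Shapes of the raw config dict that the new validator accepts and rejects.
accepted = {"tiled_mlp": True, "deepspeed": "deepspeed_configs/zero3_bf16.json"}
rejected = {"tiled_mlp": True}  # raises ValueError: tiled_mlp requires deepspeed ZeRO to be enabled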