TiledMLP support (#2865)
This commit is contained in:
@@ -66,6 +66,7 @@ class PatchManager:
|
|||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
self._apply_gemma3_conditional_generation_forward_patch()
|
self._apply_gemma3_conditional_generation_forward_patch()
|
||||||
self._apply_sequence_parallel_patches()
|
self._apply_sequence_parallel_patches()
|
||||||
|
self._apply_tiled_mlp(self.cfg.model_config_type)
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
@@ -257,6 +258,12 @@ class PatchManager:
|
|||||||
patch_prepare_data_loader()
|
patch_prepare_data_loader()
|
||||||
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
||||||
|
|
||||||
|
def _apply_tiled_mlp(self, model_type: str):
    """Install the tiled MLP monkeypatch when the config asks for it.

    Nothing happens unless ``self.cfg.tiled_mlp`` is truthy. The patch
    module is imported lazily, so it is only loaded when the feature is on.

    Args:
        model_type: Model type string forwarded to ``patch_tiled_mlp``.
    """
    if not self.cfg.tiled_mlp:
        return

    from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp

    # May be None; it is passed through unchanged as ``cfg_num_shards``.
    shard_count = self.cfg.tiled_mlp_num_shards
    patch_tiled_mlp(model_type, cfg_num_shards=shard_count)
|
||||||
|
|
||||||
def _patch_attention(self):
|
def _patch_attention(self):
|
||||||
"""Apply attention-specific patches based on model type."""
|
"""Apply attention-specific patches based on model type."""
|
||||||
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
||||||
|
|||||||
64
src/axolotl/monkeypatch/tiled_mlp.py
Normal file
64
src/axolotl/monkeypatch/tiled_mlp.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""Monkeypatch for Tiled MLP implementation"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
    """Replace the ``forward`` of a transformers MLP class with a tiled one.

    The MLP class is looked up as ``<Prefix>MLP`` in
    ``transformers.models.<model_type>.modeling_<model_type>``, where
    ``<Prefix>`` is ``model_type`` with every underscore-separated part
    capitalized (e.g. ``qwen2_moe`` -> ``Qwen2Moe``). Its ``forward`` is
    overwritten on the class itself, so every instance is affected.

    Args:
        model_type: Model type string used to build the module path and the
            MLP class name.
        use_original_mlp: When true, the class's existing ``forward`` is the
            function handed to ``TiledMLP``; otherwise a ``torch.compile``d
            ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` is used.
        cfg_num_shards: Fixed shard count. When ``None`` the count is
            ``ceil(seqlen / hidden)`` computed on every call and, if a
            process group is initialized, raised to the maximum across ranks.

    Raises:
        RuntimeError: If the modeling module or the MLP class cannot be
            resolved for ``model_type``.
    """
    import importlib

    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP

    # Only the name derivation and lookup are wrapped, so that unrelated
    # errors further down are not reported as an import failure.
    try:
        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
        model_cls_prefix = "".join(
            part.capitalize() for part in model_type.split("_")
        )
        module = importlib.import_module(module_path)
        mlp_cls = getattr(module, f"{model_cls_prefix}MLP")
    except (ImportError, AttributeError) as e:
        raise RuntimeError(
            f"Could not import MLP class for model_type: {model_type}. "
            f"Error: {str(e)}"
        ) from e

    if use_original_mlp:
        # Captured before the class attribute is overwritten below.
        mlp_forward = mlp_cls.forward
    else:

        def generic_mlp_forward(self_, hs):
            return self_.down_proj(
                self_.act_fn(self_.gate_proj(hs)) * self_.up_proj(hs)
            )

        mlp_forward = torch.compile(generic_mlp_forward)

    def tiled_mlp_forward(self, x):
        seqlen = x.shape[-2]
        hidden = x.shape[-1]

        if cfg_num_shards is None:
            num_shards = math.ceil(seqlen / hidden)
            # all_reduce raises when no default process group exists, so the
            # cross-rank agreement step is skipped in non-distributed runs.
            if dist.is_available() and dist.is_initialized():
                num_shards_tensor = torch.tensor(num_shards, device=x.device)
                dist.all_reduce(num_shards_tensor, op=dist.ReduceOp.MAX)
                num_shards = num_shards_tensor.item()
        else:
            num_shards = cfg_num_shards

        compute_params = [
            self.down_proj.weight,
            self.gate_proj.weight,
            self.up_proj.weight,
        ]

        return TiledMLP.apply(
            mlp_forward,
            self,
            x,
            num_shards,
            compute_params,
        )

    mlp_cls.forward = tiled_mlp_forward
|
||||||
@@ -549,6 +549,20 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tiled_mlp: bool | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Whether to use ALST tiled mlp for memory efficient long context"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
tiled_mlp_num_shards: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Number of shards to use for ALST tiled mlp. If unset, it will be set based on seqlen/hidden_size"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
llama4_linearized_experts: bool | None = None
|
llama4_linearized_experts: bool | None = None
|
||||||
|
|
||||||
deepspeed: str | dict[str, Any] | None = Field(
|
deepspeed: str | dict[str, Any] | None = Field(
|
||||||
|
|||||||
@@ -476,6 +476,13 @@ class TrainingValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def check_tiled_mlp_deepspeed(cls, data):
    """Reject configs that enable ``tiled_mlp`` without a ``deepspeed`` value.

    Registered as a ``mode="before"`` model validator; reads the
    ``tiled_mlp`` and ``deepspeed`` keys from ``data`` via ``.get``.

    Raises:
        ValueError: If ``tiled_mlp`` is truthy and ``deepspeed`` is missing
            or falsy.

    Returns:
        ``data`` unchanged.
    """
    wants_tiled_mlp = data.get("tiled_mlp", False)
    if wants_tiled_mlp:
        deepspeed_setting = data.get("deepspeed")
        if not deepspeed_setting:
            raise ValueError("tiled_mlp requires deepspeed ZeRO to be enabled")
    return data
|
||||||
|
|
||||||
|
|
||||||
class LoRAValidationMixin:
|
class LoRAValidationMixin:
|
||||||
"""Validation methods related to LoRA/QLoRA configuration."""
|
"""Validation methods related to LoRA/QLoRA configuration."""
|
||||||
|
|||||||
Reference in New Issue
Block a user