TiledMLP support for FSDP2 (#2950)

* make TiledMLP work with FSDP

* cleanup/gc at start of train to prevent large VRAM spike

* chore: lint

* generic function for non-deepspeed training

* unify patch to fix imports

* update readme for ALST and add examples

* make the check for deepspeed attributes on params more robust

* update with new info from PR review
Authored by Wing Lian on 2025-07-25 07:15:03 -04:00; committed by GitHub
parent 460e0f9ed9
commit f7ea140838
13 changed files with 330 additions and 26 deletions

View File

@@ -57,8 +57,12 @@ class LigerArgs(BaseModel):
     @model_validator(mode="before")
     @classmethod
     def check_tiled_mlp_conflict(cls, data):
-        if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
+        if (
+            data.get("liger_glu_activation") is True
+            and data.get("tiled_mlp") is True
+            and not data.get("tiled_mlp_use_original_mlp")
+        ):
             raise ValueError(
-                "You cannot have both `liger_glu_activation` and `tiled_mlp` set."
+                "You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
             )
         return data

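With this change, `liger_glu_activation` and `tiled_mlp` may coexist as long as `tiled_mlp_use_original_mlp` is truthy (its default becomes `true` in the config change later in this commit). A minimal standalone sketch of the relaxed check; the example dicts are hypothetical configs, not taken from this PR:

    # Minimal sketch of the relaxed conflict check; the example dicts below
    # are hypothetical configs used only for illustration.
    def check_tiled_mlp_conflict(data: dict) -> dict:
        if (
            data.get("liger_glu_activation") is True
            and data.get("tiled_mlp") is True
            and not data.get("tiled_mlp_use_original_mlp")
        ):
            raise ValueError(
                "You cannot have both `liger_glu_activation` and `tiled_mlp` set "
                "without `tiled_mlp_use_original_mlp: true`."
            )
        return data

    check_tiled_mlp_conflict({"tiled_mlp": True})  # ok: liger GLU not enabled
    check_tiled_mlp_conflict(
        {"liger_glu_activation": True, "tiled_mlp": True, "tiled_mlp_use_original_mlp": True}
    )  # ok: both set, original MLP forward retained
    # check_tiled_mlp_conflict({"liger_glu_activation": True, "tiled_mlp": True,
    #                           "tiled_mlp_use_original_mlp": False})  # raises ValueError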
View File

@@ -162,6 +162,7 @@ class ModelLoader:
         # Build the model
         PLUGIN_MANAGER.pre_model_load(self.cfg)
+        self.patch_manager.apply_post_plugin_pre_model_load_patches()
         skip_move_to_device = self._build_model()

         PLUGIN_MANAGER.post_model_build(self.cfg, self.model)

View File

@@ -66,6 +66,9 @@ class PatchManager:
         self._apply_self_attention_lora_patch()
         self._apply_gemma3_conditional_generation_forward_patch()
         self._apply_sequence_parallel_patches()

+    def apply_post_plugin_pre_model_load_patches(self):
+        """Apply post plugin-pre_model_load load patches based on config."""
+        self._apply_tiled_mlp(self.cfg.model_config_type)
+
     def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -272,7 +275,9 @@ class PatchManager:
     def _apply_tiled_mlp(self, model_type: str):
         if self.cfg.tiled_mlp:
-            from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
+            from axolotl.monkeypatch.tiled_mlp import (
+                patch_tiled_mlp,
+            )

             patch_tiled_mlp(
                 model_type,
View File

@@ -0,0 +1,11 @@
"""
TiledMLP monkey patches
"""
from .patch import (
patch_tiled_mlp,
)
__all__ = [
"patch_tiled_mlp",
]

View File

@@ -0,0 +1,153 @@
"""
TiledMLP support for DDP, FSDP, and single GPU
"""
import threading
from typing import List
import torch
class TiledMLP(torch.autograd.Function):
"""
TiledMLP implementation using gradient hooks
"""
@staticmethod
def forward(
ctx,
fn,
self,
x,
shards,
compute_params,
) -> torch.Tensor:
ctx.fn = fn
ctx.self = self
ctx.shards = shards
ctx.compute_params = [p for p in compute_params if p.requires_grad]
ctx.save_for_backward(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded
@staticmethod
def backward(ctx, *grads) -> torch.Tensor:
fn = ctx.fn
(x,) = ctx.saved_tensors
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params
x_requires_grad = x.requires_grad
x = x.detach()
x.requires_grad_(x_requires_grad)
incoming_grad = grads[0]
x_grad = torch.zeros_like(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
# Create a gradient accumulator for parameters
grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
shard_step = x_shards[0].numel()
for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)
shard_offset = i * shard_step
x_shard.grad = (
x_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
incoming_grad_shard = (
incoming_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
# Install hooks for this shard
is_last_shard = i + 1 == shards
grad_accumulator.install_hooks(is_last_shard)
with torch.enable_grad():
output = fn(self, x_shard)
torch.autograd.backward(output, incoming_grad_shard)
# Clean up hooks
grad_accumulator.cleanup()
del grad_accumulator
return (None, None, x_grad, None, None)
class GradientAccumulator:
"""
Manual gradient accumulator for TiledMLP with configurable precision
Accumulates in specified dtype and rescales the gradient at the end
"""
def __init__(
self,
params: List[torch.nn.Parameter],
total_shards: int,
dtype: torch.dtype | None = None,
):
self.params = params
self.total_shards = total_shards
self.grad_accumulation_dtype = dtype or torch.float32
self.accumulated_grads = {}
self.hooks = []
self.lock = threading.Lock()
self.gradient_scale = 1.0 / total_shards
# Initialize accumulated gradients in the specified dtype
for param in self.params:
if param.grad is not None:
self.accumulated_grads[param] = param.grad.to(
self.grad_accumulation_dtype
)
param.grad = None
else:
self.accumulated_grads[param] = torch.zeros_like(
param, dtype=self.grad_accumulation_dtype
)
def install_hooks(self, is_last_shard: bool):
"""Install gradient hooks that accumulate gradients in higher precision"""
def create_hook(param):
def hook(grad):
with self.lock:
grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype)
scaled_grad = grad_to_accum_dtype * self.gradient_scale
if param in self.accumulated_grads:
self.accumulated_grads[param] += scaled_grad
else:
self.accumulated_grads[param] = scaled_grad.clone()
# Only assign the averaged gradient on the last shard
if is_last_shard:
param.grad = self.accumulated_grads[param].to(param.dtype)
return param.grad
return None
return hook
# Install hooks on all parameters
for param in self.params:
if param.requires_grad:
hook = param.register_hook(create_hook(param))
self.hooks.append(hook)
def cleanup(self):
"""Remove all installed hooks"""
for hook in self.hooks:
hook.remove()
self.hooks.clear()
del self.accumulated_grads

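The generic `TiledMLP` above chunks the MLP input along dim=1 (the sequence dimension for `[batch, seqlen, hidden]` inputs), runs the forward shard-by-shard under `no_grad`, and recomputes each shard with grad enabled during backward, while `GradientAccumulator` installs per-parameter hooks that scale each shard's gradient by `1/shards`, accumulate in a higher-precision dtype, and assign `param.grad` on the last shard. A hedged usage sketch of the calling convention; the toy module, shapes, and shard count below are hypothetical and not part of the PR:

    # Usage sketch for the generic TiledMLP above; ToyMLP and all shapes are
    # hypothetical and only illustrate the autograd.Function's signature.
    import torch
    import torch.nn as nn

    class ToyMLP(nn.Module):
        def __init__(self, hidden=64, intermediate=256):
            super().__init__()
            self.up = nn.Linear(hidden, intermediate)
            self.down = nn.Linear(intermediate, hidden)

    def mlp_forward(module, x):
        # plain, untiled forward that TiledMLP re-runs once per sequence shard
        return module.down(nn.functional.silu(module.up(x)))

    mlp = ToyMLP()
    x = torch.randn(2, 512, 64, requires_grad=True)  # (batch, seqlen, hidden)
    compute_params = [p for p in mlp.parameters() if p.requires_grad]

    # forward: the 4 sequence shards are computed under no_grad and concatenated;
    # backward: each shard is recomputed with grad enabled and parameter grads
    # are accumulated across shards by GradientAccumulator
    out = TiledMLP.apply(mlp_forward, mlp, x, 4, compute_params)
    out.sum().backward()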
View File

@@ -12,8 +12,12 @@ from axolotl.utils.logging import get_logger
 LOG = get_logger(__name__)


-def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
-    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
+def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
+    from deepspeed.runtime.sequence_parallel.ulysses_sp import (
+        TiledMLP as DeepSpeedTiledMLP,
+    )
+
+    from axolotl.monkeypatch.tiled_mlp.base import TiledMLP

     try:
         # Dynamically import the module and MLP class
@@ -36,6 +40,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
     is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1

     def tiled_mlp_forward(self, x):
+        # pylint: disable=protected-access
         input_shape = x.shape
         seqlen = input_shape[-2]
         hidden = input_shape[-1]
@@ -48,14 +53,23 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
         else:
             num_shards = cfg_num_shards

-        if not self._compute_params:  # pylint: disable=protected-access
-            self._compute_params = [  # pylint: disable=protected-access
-                p for p in self.parameters() if p.requires_grad
-            ]
+        if not self._compute_params:
+            self._compute_params = [p for p in self.parameters() if p.requires_grad]

-        compute_params = self._compute_params  # pylint: disable=protected-access
+        compute_params = self._compute_params
+
+        if not self._tiled_mlp_dist_impl:
+            if (
+                self._compute_params
+                and any(
+                    hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group")
+                    for p in self._compute_params
+                )
+            ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
+                self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
+            else:
+                self._tiled_mlp_dist_impl = TiledMLP

-        down_res = TiledMLP.apply(
+        down_res = self._tiled_mlp_dist_impl.apply(
             mlp_forward,
             self,
             x,
@@ -66,6 +80,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
     mlp_cls.forward = tiled_mlp_forward
     mlp_cls._compute_params = []  # pylint: disable=protected-access
+    mlp_cls._tiled_mlp_dist_impl = None  # pylint: disable=protected-access

     LOG.info(
         f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
         main_process_only=True,

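The patched forward now picks a tiling implementation lazily and caches it on the MLP class: DeepSpeed's `TiledMLP` when the parameters look ZeRO-managed (`ds_id` / `param_idx_in_group` attributes) or when `ACCELERATE_USE_DEEPSPEED=true`, otherwise the new generic implementation for DDP, FSDP, and single-GPU runs. A standalone restatement of that selection logic; the helper name below is hypothetical:

    # Standalone restatement of the dispatch above; select_tiled_mlp_impl is a
    # hypothetical helper name, not part of the PR.
    import os

    def select_tiled_mlp_impl(params, deepspeed_impl, generic_impl):
        looks_like_zero = any(
            hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group") for p in params
        )
        use_deepspeed = (
            looks_like_zero
            or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true"
        )
        # DeepSpeed ZeRO path reuses deepspeed's Ulysses TiledMLP; every other
        # setup (DDP, FSDP, single GPU) gets the generic autograd.Function.
        return deepspeed_impl if use_deepspeed else generic_impl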
View File

@@ -867,10 +867,16 @@ class GCCallback(TrainerCallback):
         torch.cuda.empty_cache()
         gc.collect()

+    def on_train_begin(
+        self, args, state, control, **kwargs  # pylint: disable=unused-argument
+    ):
+        self._gc()
+
     def on_step_begin(
         self, args, state, control, **kwargs  # pylint: disable=unused-argument
     ):
-        if self.next_gc_on_begin_step == state.global_step:
+        # pylint: disable=consider-using-in
+        if self.next_gc_on_begin_step == state.global_step or state.global_step == 0:
             self._gc()

     def on_step_end(

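`GCCallback` now also garbage-collects at `on_train_begin` and at global step 0, matching the commit note about preventing the large VRAM spike at the start of training. A minimal sketch of an equivalent standalone callback; the class name is hypothetical and the hooks are the standard `transformers.TrainerCallback` API:

    # Minimal sketch of a start-of-training GC callback; StartOfTrainGC is a
    # hypothetical name and only mirrors the behaviour added above.
    import gc

    import torch
    from transformers import TrainerCallback

    class StartOfTrainGC(TrainerCallback):
        def _gc(self):
            torch.cuda.empty_cache()
            gc.collect()

        def on_train_begin(self, args, state, control, **kwargs):
            self._gc()

        def on_step_begin(self, args, state, control, **kwargs):
            # also collect right before the very first optimizer step
            if state.global_step == 0:
                self._gc()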
View File

@@ -597,7 +597,7 @@ class AxolotlInputConfig(
     )
     tiled_mlp_use_original_mlp: bool | None = Field(
-        default=None,
+        default=True,
         json_schema_extra={
             "description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
         },

View File

@@ -512,19 +512,6 @@ class TrainingValidationMixin:
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_tiled_mlp_deepspeed(cls, data):
-        capabilities = data.get("capabilities")
-        n_gpu = 0
-        if capabilities and capabilities.get("n_gpu", 0) >= 1:
-            n_gpu = capabilities.get("n_gpu", 0)
-        if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
-            raise ValueError(
-                "tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
-            )
-        return data
-

 class LoRAValidationMixin:
     """Validation methods related to LoRA/QLoRA configuration."""