TiledMLP support for FSDP2 (#2950)

* make TiledMLP work with FSDP

* cleanup/gc at start of train to prevent large VRAM spike

* chore: lint

* generic function for non-deepspeed training

* unify patch to fix imports

* update readme for ALST and add examples

* make deepspeed attribute on params check more robust

* update with new info from PR review
This commit is contained in:
Wing Lian
2025-07-25 07:15:03 -04:00
committed by GitHub
parent 460e0f9ed9
commit f7ea140838
13 changed files with 330 additions and 26 deletions

View File

@@ -25,6 +25,7 @@
## 🎉 Latest Updates
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!

9
examples/alst/README.md Normal file
View File

@@ -0,0 +1,9 @@
# Arctic Long Sequence Training (ALST)
Artic Long Sequence Training (ALST) is a technique for training long context models using a variety of optimization
techniques. It is a combination of:
- TiledMLP: Leverage tiling over the sequence dimension on MLP layers to reduce memory usage
- Tiled Loss: Using optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).

View File

@@ -0,0 +1,53 @@
base_model: meta-llama/Llama-3.1-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: togethercomputer/Long-Data-Collections
type: completion
field: text
data_files:
- pretrain/rp_sub.jsonl.zst
- path: princeton-nlp/TextbookChapters
type: completion
field: chapter
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 500_000
min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
sequence_parallel_degree: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 100
saves_per_epoch: 1
evals_per_epoch: 2
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -0,0 +1,59 @@
base_model: meta-llama/Llama-3.1-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
datasets:
- path: togethercomputer/Long-Data-Collections
type: completion
field: text
data_files:
- pretrain/rp_sub.jsonl.zst
- path: princeton-nlp/TextbookChapters
type: completion
field: chapter
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 500_000
min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
context_parallel_size: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: true
gradient_checkpointing: true
activation_offloading: legacy
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 100
saves_per_epoch: 1
evals_per_epoch: 2
weight_decay: 0.0
special_tokens:
pad_token: <|end_of_text|>
fsdp_version: 2
fsdp_config:
offload_params: false # offloading is currently not compatible with SP + torchao optimizer
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -57,8 +57,12 @@ class LigerArgs(BaseModel):
@model_validator(mode="before")
@classmethod
def check_tiled_mlp_conflict(cls, data):
if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
if (
data.get("liger_glu_activation") is True
and data.get("tiled_mlp") is True
and not data.get("tiled_mlp_use_original_mlp")
):
raise ValueError(
"You cannot have both `liger_glu_activation` and `tiled_mlp` set."
"You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
)
return data

View File

@@ -162,6 +162,7 @@ class ModelLoader:
# Build the model
PLUGIN_MANAGER.pre_model_load(self.cfg)
self.patch_manager.apply_post_plugin_pre_model_load_patches()
skip_move_to_device = self._build_model()
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)

View File

@@ -66,6 +66,9 @@ class PatchManager:
self._apply_self_attention_lora_patch()
self._apply_gemma3_conditional_generation_forward_patch()
self._apply_sequence_parallel_patches()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
self._apply_tiled_mlp(self.cfg.model_config_type)
def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -272,7 +275,9 @@ class PatchManager:
def _apply_tiled_mlp(self, model_type: str):
if self.cfg.tiled_mlp:
from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
from axolotl.monkeypatch.tiled_mlp import (
patch_tiled_mlp,
)
patch_tiled_mlp(
model_type,

View File

@@ -0,0 +1,11 @@
"""
TiledMLP monkey patches
"""
from .patch import (
patch_tiled_mlp,
)
__all__ = [
"patch_tiled_mlp",
]

View File

@@ -0,0 +1,153 @@
"""
TiledMLP support for DDP, FSDP, and single GPU
"""
import threading
from typing import List
import torch
class TiledMLP(torch.autograd.Function):
"""
TiledMLP implementation using gradient hooks
"""
@staticmethod
def forward(
ctx,
fn,
self,
x,
shards,
compute_params,
) -> torch.Tensor:
ctx.fn = fn
ctx.self = self
ctx.shards = shards
ctx.compute_params = [p for p in compute_params if p.requires_grad]
ctx.save_for_backward(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded
@staticmethod
def backward(ctx, *grads) -> torch.Tensor:
fn = ctx.fn
(x,) = ctx.saved_tensors
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params
x_requires_grad = x.requires_grad
x = x.detach()
x.requires_grad_(x_requires_grad)
incoming_grad = grads[0]
x_grad = torch.zeros_like(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
# Create a gradient accumulator for parameters
grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
shard_step = x_shards[0].numel()
for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)
shard_offset = i * shard_step
x_shard.grad = (
x_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
incoming_grad_shard = (
incoming_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
# Install hooks for this shard
is_last_shard = i + 1 == shards
grad_accumulator.install_hooks(is_last_shard)
with torch.enable_grad():
output = fn(self, x_shard)
torch.autograd.backward(output, incoming_grad_shard)
# Clean up hooks
grad_accumulator.cleanup()
del grad_accumulator
return (None, None, x_grad, None, None)
class GradientAccumulator:
"""
Manual gradient accumulator for TiledMLP with configurable precision
Accumulates in specified dtype and rescales the gradient at the end
"""
def __init__(
self,
params: List[torch.nn.Parameter],
total_shards: int,
dtype: torch.dtype | None = None,
):
self.params = params
self.total_shards = total_shards
self.grad_accumulation_dtype = dtype or torch.float32
self.accumulated_grads = {}
self.hooks = []
self.lock = threading.Lock()
self.gradient_scale = 1.0 / total_shards
# Initialize accumulated gradients in the specified dtype
for param in self.params:
if param.grad is not None:
self.accumulated_grads[param] = param.grad.to(
self.grad_accumulation_dtype
)
param.grad = None
else:
self.accumulated_grads[param] = torch.zeros_like(
param, dtype=self.grad_accumulation_dtype
)
def install_hooks(self, is_last_shard: bool):
"""Install gradient hooks that accumulate gradients in higher precision"""
def create_hook(param):
def hook(grad):
with self.lock:
grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype)
scaled_grad = grad_to_accum_dtype * self.gradient_scale
if param in self.accumulated_grads:
self.accumulated_grads[param] += scaled_grad
else:
self.accumulated_grads[param] = scaled_grad.clone()
# Only assign the averaged gradient on the last shard
if is_last_shard:
param.grad = self.accumulated_grads[param].to(param.dtype)
return param.grad
return None
return hook
# Install hooks on all parameters
for param in self.params:
if param.requires_grad:
hook = param.register_hook(create_hook(param))
self.hooks.append(hook)
def cleanup(self):
"""Remove all installed hooks"""
for hook in self.hooks:
hook.remove()
self.hooks.clear()
del self.accumulated_grads

View File

@@ -12,8 +12,12 @@ from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
from deepspeed.runtime.sequence_parallel.ulysses_sp import (
TiledMLP as DeepSpeedTiledMLP,
)
from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
try:
# Dynamically import the module and MLP class
@@ -36,6 +40,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
def tiled_mlp_forward(self, x):
# pylint: disable=protected-access
input_shape = x.shape
seqlen = input_shape[-2]
hidden = input_shape[-1]
@@ -48,14 +53,23 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
else:
num_shards = cfg_num_shards
if not self._compute_params: # pylint: disable=protected-access
self._compute_params = [ # pylint: disable=protected-access
p for p in self.parameters() if p.requires_grad
]
if not self._compute_params:
self._compute_params = [p for p in self.parameters() if p.requires_grad]
compute_params = self._compute_params # pylint: disable=protected-access
compute_params = self._compute_params
if not self._tiled_mlp_dist_impl:
if (
self._compute_params
and any(
hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group")
for p in self._compute_params
)
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
else:
self._tiled_mlp_dist_impl = TiledMLP
down_res = TiledMLP.apply(
down_res = self._tiled_mlp_dist_impl.apply(
mlp_forward,
self,
x,
@@ -66,6 +80,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
mlp_cls.forward = tiled_mlp_forward
mlp_cls._compute_params = [] # pylint: disable=protected-access
mlp_cls._tiled_mlp_dist_impl = None # pylint: disable=protected-access
LOG.info(
f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
main_process_only=True,

View File

@@ -867,10 +867,16 @@ class GCCallback(TrainerCallback):
torch.cuda.empty_cache()
gc.collect()
def on_train_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
self._gc()
def on_step_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if self.next_gc_on_begin_step == state.global_step:
# pylint: disable=consider-using-in
if self.next_gc_on_begin_step == state.global_step or state.global_step == 0:
self._gc()
def on_step_end(

View File

@@ -597,7 +597,7 @@ class AxolotlInputConfig(
)
tiled_mlp_use_original_mlp: bool | None = Field(
default=None,
default=True,
json_schema_extra={
"description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
},

View File

@@ -512,19 +512,6 @@ class TrainingValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_tiled_mlp_deepspeed(cls, data):
capabilities = data.get("capabilities")
n_gpu = 0
if capabilities and capabilities.get("n_gpu", 0) >= 1:
n_gpu = capabilities.get("n_gpu", 0)
if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
raise ValueError(
"tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
)
return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""