TiledMLP support for FSDP2 (#2950)
* make TiledMLP work with FSDP
* cleanup/gc at start of train to prevent a large VRAM spike
* chore: lint
* generic TiledMLP function for non-DeepSpeed training
* unify patch to fix imports
* update README for ALST and add examples
* make the check for DeepSpeed attributes on params more robust
* update with new info from PR review
README.md
@@ -25,6 +25,7 @@
 
 ## 🎉 Latest Updates
 
+- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed, and FSDP has been added to support Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
 - 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
 - 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
 - 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
examples/alst/README.md (new file, 9 lines)
@@ -0,0 +1,9 @@
+# Arctic Long Sequence Training (ALST)
+
+Arctic Long Sequence Training (ALST) is a technique for training long-context models using a variety of optimization
+techniques. It is a combination of:
+- TiledMLP: leverage tiling over the sequence dimension in MLP layers to reduce memory usage
+- Tiled Loss: use an optimized loss function like Liger-Kernel or Cut Cross Entropy to reduce memory usage
+- Activation Offloading: offload activations to CPU RAM to reduce memory usage
+
+For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
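The key property TiledMLP exploits is that an MLP is applied position-wise, so splitting the sequence dimension into tiles and concatenating the per-tile outputs changes nothing mathematically. A minimal sketch of that equivalence (illustrative only; `naive_tiled_mlp` is a hypothetical helper, not Axolotl's implementation, which additionally wraps the tiling in a custom `torch.autograd.Function` so each tile is recomputed during backward — that recomputation is where the training-time memory saving comes from):

```python
import torch
import torch.nn as nn


def naive_tiled_mlp(mlp: nn.Module, x: torch.Tensor, num_shards: int) -> torch.Tensor:
    """Run an MLP tile-by-tile over the sequence dimension (dim=1)."""
    # MLPs have no cross-position interaction, so this matches mlp(x) exactly;
    # peak activation size now scales with the tile length, not the full seqlen
    x_shards = torch.chunk(x, chunks=num_shards, dim=1)
    return torch.cat([mlp(shard) for shard in x_shards], dim=1)


mlp = nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, 64))
x = torch.randn(1, 8192, 64)  # (batch, seqlen, hidden)
assert torch.allclose(naive_tiled_mlp(mlp, x, num_shards=4), mlp(x), atol=1e-6)
```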
examples/alst/llama3-8b-deepspeed-alst.yaml (new file, 53 lines)
@@ -0,0 +1,53 @@
+base_model: meta-llama/Llama-3.1-8B
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+datasets:
+  - path: togethercomputer/Long-Data-Collections
+    type: completion
+    field: text
+    data_files:
+      - pretrain/rp_sub.jsonl.zst
+  - path: princeton-nlp/TextbookChapters
+    type: completion
+    field: chapter
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 500_000
+min_sample_len: 200_000
+sample_packing: true
+
+tiled_mlp: true
+sequence_parallel_degree: 8
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+activation_offloading: legacy
+
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 100
+saves_per_epoch: 1
+evals_per_epoch: 2
+weight_decay: 0.0
+special_tokens:
+  pad_token: <|end_of_text|>
+
+deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
examples/alst/llama3-8b-fsdp2-alst.yaml (new file, 59 lines)
@@ -0,0 +1,59 @@
+base_model: meta-llama/Llama-3.1-8B
+# Automatically upload checkpoint and final model to HF
+# hub_model_id: username/custom_model_name
+
+datasets:
+  - path: togethercomputer/Long-Data-Collections
+    type: completion
+    field: text
+    data_files:
+      - pretrain/rp_sub.jsonl.zst
+  - path: princeton-nlp/TextbookChapters
+    type: completion
+    field: chapter
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+sequence_len: 500_000
+min_sample_len: 200_000
+sample_packing: true
+
+tiled_mlp: true
+context_parallel_size: 8
+plugins:
+  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_torch_8bit
+lr_scheduler: cosine
+learning_rate: 2e-5
+
+bf16: auto
+tf32: true
+
+gradient_checkpointing: true
+activation_offloading: legacy
+
+resume_from_checkpoint:
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 100
+saves_per_epoch: 1
+evals_per_epoch: 2
+weight_decay: 0.0
+special_tokens:
+  pad_token: <|end_of_text|>
+
+fsdp_version: 2
+fsdp_config:
+  offload_params: false # offloading is currently not compatible with SP + torchao optimizer
+  state_dict_type: SHARDED_STATE_DICT
+  auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  reshard_after_forward: true
+
+# save_first_step: true # uncomment this to validate checkpoint saving works with your config
@@ -57,8 +57,12 @@ class LigerArgs(BaseModel):

     @model_validator(mode="before")
     @classmethod
     def check_tiled_mlp_conflict(cls, data):
-        if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True:
+        if (
+            data.get("liger_glu_activation") is True
+            and data.get("tiled_mlp") is True
+            and not data.get("tiled_mlp_use_original_mlp")
+        ):
             raise ValueError(
-                "You cannot have both `liger_glu_activation` and `tiled_mlp` set."
+                "You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`."
             )
         return data
@@ -162,6 +162,7 @@ class ModelLoader:

         # Build the model
         PLUGIN_MANAGER.pre_model_load(self.cfg)
+        self.patch_manager.apply_post_plugin_pre_model_load_patches()
         skip_move_to_device = self._build_model()
         PLUGIN_MANAGER.post_model_build(self.cfg, self.model)

@@ -66,6 +66,9 @@ class PatchManager:
         self._apply_self_attention_lora_patch()
         self._apply_gemma3_conditional_generation_forward_patch()
         self._apply_sequence_parallel_patches()

+    def apply_post_plugin_pre_model_load_patches(self):
+        """Apply patches that must run after plugin pre_model_load hooks, based on config."""
+        self._apply_tiled_mlp(self.cfg.model_config_type)
+
     def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -272,7 +275,9 @@ class PatchManager:

     def _apply_tiled_mlp(self, model_type: str):
         if self.cfg.tiled_mlp:
-            from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp
+            from axolotl.monkeypatch.tiled_mlp import (
+                patch_tiled_mlp,
+            )

             patch_tiled_mlp(
                 model_type,
src/axolotl/monkeypatch/tiled_mlp/__init__.py (new file, 11 lines)
@@ -0,0 +1,11 @@
+"""
+TiledMLP monkey patches
+"""
+
+from .patch import (
+    patch_tiled_mlp,
+)
+
+__all__ = [
+    "patch_tiled_mlp",
+]
src/axolotl/monkeypatch/tiled_mlp/base.py (new file, 153 lines)
@@ -0,0 +1,153 @@
+"""
+TiledMLP support for DDP, FSDP, and single GPU
+"""
+
+import threading
+from typing import List
+
+import torch
+
+
+class TiledMLP(torch.autograd.Function):
+    """
+    TiledMLP implementation using gradient hooks
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        fn,
+        self,
+        x,
+        shards,
+        compute_params,
+    ) -> torch.Tensor:
+        ctx.fn = fn
+        ctx.self = self
+        ctx.shards = shards
+        ctx.compute_params = [p for p in compute_params if p.requires_grad]
+        ctx.save_for_backward(x)
+
+        # Forward runs shard-by-shard under no_grad, so no activations are kept
+        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+        with torch.no_grad():
+            output_shards = [fn(self, x_shard) for x_shard in x_shards]
+        output_unsharded = torch.cat(output_shards, dim=1)
+
+        return output_unsharded
+
+    @staticmethod
+    def backward(ctx, *grads) -> torch.Tensor:
+        fn = ctx.fn
+        (x,) = ctx.saved_tensors
+        self = ctx.self
+        shards = ctx.shards
+        compute_params = ctx.compute_params
+
+        x_requires_grad = x.requires_grad
+        x = x.detach()
+        # detach() unsets requires_grad, so restore it
+        x.requires_grad_(x_requires_grad)
+
+        incoming_grad = grads[0]
+        x_grad = torch.zeros_like(x)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+
+        # Create a gradient accumulator for parameters
+        grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
+
+        shard_step = x_shards[0].numel()
+        for i, x_shard in enumerate(x_shards):
+            x_shard.requires_grad_(x_requires_grad)
+
+            # Point this shard's .grad at the matching slice of the preallocated
+            # x_grad buffer so autograd writes the input gradient in place
+            shard_offset = i * shard_step
+            x_shard.grad = (
+                x_grad.view(-1)
+                .narrow(0, shard_offset, x_shard.numel())
+                .view_as(x_shard)
+            )
+            incoming_grad_shard = (
+                incoming_grad.view(-1)
+                .narrow(0, shard_offset, x_shard.numel())
+                .view_as(x_shard)
+            )
+
+            # Install hooks for this shard
+            is_last_shard = i + 1 == shards
+            grad_accumulator.install_hooks(is_last_shard)
+
+            # Recompute this shard's forward with grad enabled, then backprop
+            with torch.enable_grad():
+                output = fn(self, x_shard)
+            torch.autograd.backward(output, incoming_grad_shard)
+
+            # Clean up this shard's hooks; accumulated grads persist across shards
+            grad_accumulator.cleanup()
+
+        # Dropping the accumulator releases the high-precision grad buffers
+        del grad_accumulator
+
+        return (None, None, x_grad, None, None)
+
+
+class GradientAccumulator:
+    """
+    Manual gradient accumulator for TiledMLP with configurable precision.
+    Accumulates in the specified dtype and rescales the gradient at the end.
+    """
+
+    def __init__(
+        self,
+        params: List[torch.nn.Parameter],
+        total_shards: int,
+        dtype: torch.dtype | None = None,
+    ):
+        self.params = params
+        self.total_shards = total_shards
+        self.grad_accumulation_dtype = dtype or torch.float32
+        self.accumulated_grads = {}
+        self.hooks = []
+        self.lock = threading.Lock()
+        self.gradient_scale = 1.0 / total_shards
+
+        # Initialize accumulated gradients in the specified dtype
+        for param in self.params:
+            if param.grad is not None:
+                self.accumulated_grads[param] = param.grad.to(
+                    self.grad_accumulation_dtype
+                )
+                param.grad = None
+            else:
+                self.accumulated_grads[param] = torch.zeros_like(
+                    param, dtype=self.grad_accumulation_dtype
+                )
+
+    def install_hooks(self, is_last_shard: bool):
+        """Install gradient hooks that accumulate gradients in higher precision"""
+
+        def create_hook(param):
+            def hook(grad):
+                with self.lock:
+                    grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype)
+                    scaled_grad = grad_to_accum_dtype * self.gradient_scale
+
+                    if param in self.accumulated_grads:
+                        self.accumulated_grads[param] += scaled_grad
+                    else:
+                        self.accumulated_grads[param] = scaled_grad.clone()
+
+                    # Only assign the averaged gradient on the last shard
+                    if is_last_shard:
+                        param.grad = self.accumulated_grads[param].to(param.dtype)
+                        return param.grad
+                    return None
+
+            return hook
+
+        # Install hooks on all parameters
+        for param in self.params:
+            if param.requires_grad:
+                hook = param.register_hook(create_hook(param))
+                self.hooks.append(hook)
+
+    def cleanup(self):
+        """Remove the hooks installed for the current shard; the accumulated
+        gradient buffers are kept so later shards keep accumulating into them"""
+        for hook in self.hooks:
+            hook.remove()
+        self.hooks.clear()
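For orientation, a minimal usage sketch of the generic `TiledMLP` autograd function defined above (illustrative; `ToyMLP` and `mlp_forward` are hypothetical stand-ins for the patched model's MLP class and the captured original forward, mirroring how the patch in `patch.py` invokes it):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from axolotl.monkeypatch.tiled_mlp.base import TiledMLP


class ToyMLP(nn.Module):
    def __init__(self, hidden: int = 64, inner: int = 256):
        super().__init__()
        self.up = nn.Linear(hidden, inner)
        self.down = nn.Linear(inner, hidden)

    def forward(self, x):
        return self.down(F.silu(self.up(x)))


def mlp_forward(module, x):
    # the plain forward; used both for the no-grad forward pass and for the
    # per-shard recompute during backward
    return module.down(F.silu(module.up(x)))


mlp = ToyMLP()
x = torch.randn(1, 1024, 64, requires_grad=True)  # (batch, seqlen, hidden)
compute_params = [p for p in mlp.parameters() if p.requires_grad]

# forward runs shard-by-shard under no_grad; backward recomputes each shard
out = TiledMLP.apply(mlp_forward, mlp, x, 4, compute_params)
out.sum().backward()
print(x.grad.shape, mlp.up.weight.grad.shape)
```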
@@ -12,8 +12,12 @@ from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)


-def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
-    from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP
+def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
+    from deepspeed.runtime.sequence_parallel.ulysses_sp import (
+        TiledMLP as DeepSpeedTiledMLP,
+    )
+
+    from axolotl.monkeypatch.tiled_mlp.base import TiledMLP

     try:
         # Dynamically import the module and MLP class
@@ -36,6 +40,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):
     is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1

     def tiled_mlp_forward(self, x):
+        # pylint: disable=protected-access
         input_shape = x.shape
         seqlen = input_shape[-2]
         hidden = input_shape[-1]
@@ -48,14 +53,23 @@
         else:
             num_shards = cfg_num_shards

-        if not self._compute_params:  # pylint: disable=protected-access
-            self._compute_params = [  # pylint: disable=protected-access
-                p for p in self.parameters() if p.requires_grad
-            ]
+        if not self._compute_params:
+            self._compute_params = [p for p in self.parameters() if p.requires_grad]

-        compute_params = self._compute_params  # pylint: disable=protected-access
+        compute_params = self._compute_params
+        if not self._tiled_mlp_dist_impl:
+            if (
+                self._compute_params
+                and any(
+                    hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group")
+                    for p in self._compute_params
+                )
+            ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
+                self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
+            else:
+                self._tiled_mlp_dist_impl = TiledMLP

-        down_res = TiledMLP.apply(
+        down_res = self._tiled_mlp_dist_impl.apply(
             mlp_forward,
             self,
             x,
@@ -66,6 +80,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None):

     mlp_cls.forward = tiled_mlp_forward
     mlp_cls._compute_params = []  # pylint: disable=protected-access
+    mlp_cls._tiled_mlp_dist_impl = None  # pylint: disable=protected-access
     LOG.info(
         f"Successfully monkey-patched TiledMLP for model_type: {model_type}",
         main_process_only=True,
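Read in isolation, the dispatch above selects DeepSpeed's Ulysses `TiledMLP` whenever parameters look ZeRO-managed (they carry DeepSpeed attributes such as `ds_id` or `param_idx_in_group`) or Accelerate signals DeepSpeed through `ACCELERATE_USE_DEEPSPEED`; everything else (DDP, FSDP, single GPU) falls back to the generic implementation added in this PR. A standalone sketch of that heuristic (illustrative; `select_tiled_mlp_impl` is a hypothetical helper, not an Axolotl API):

```python
import os

import torch


def select_tiled_mlp_impl(params: list[torch.nn.Parameter]):
    """Pick the TiledMLP implementation the same way the patched forward does."""
    from axolotl.monkeypatch.tiled_mlp.base import TiledMLP

    # ZeRO-managed params carry DeepSpeed-specific attributes
    looks_like_deepspeed = any(
        hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group") for p in params
    ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true"

    if looks_like_deepspeed:
        from deepspeed.runtime.sequence_parallel.ulysses_sp import (
            TiledMLP as DeepSpeedTiledMLP,
        )

        return DeepSpeedTiledMLP
    return TiledMLP
```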
@@ -867,10 +867,16 @@ class GCCallback(TrainerCallback):
         torch.cuda.empty_cache()
         gc.collect()

+    def on_train_begin(
+        self, args, state, control, **kwargs  # pylint: disable=unused-argument
+    ):
+        self._gc()
+
     def on_step_begin(
         self, args, state, control, **kwargs  # pylint: disable=unused-argument
     ):
-        if self.next_gc_on_begin_step == state.global_step:
+        # pylint: disable=consider-using-in
+        if self.next_gc_on_begin_step == state.global_step or state.global_step == 0:
             self._gc()

     def on_step_end(
@@ -597,7 +597,7 @@ class AxolotlInputConfig(
     )

     tiled_mlp_use_original_mlp: bool | None = Field(
-        default=None,
+        default=True,
         json_schema_extra={
             "description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama."
         },
@@ -512,19 +512,6 @@ class TrainingValidationMixin:

         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_tiled_mlp_deepspeed(cls, data):
-        capabilities = data.get("capabilities")
-        n_gpu = 0
-        if capabilities and capabilities.get("n_gpu", 0) >= 1:
-            n_gpu = capabilities.get("n_gpu", 0)
-        if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")):
-            raise ValueError(
-                "tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu"
-            )
-        return data
-

 class LoRAValidationMixin:
     """Validation methods related to LoRA/QLoRA configuration."""