From f7ea140838e720cc23c6d71c4e578314e7daf52a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 25 Jul 2025 07:15:03 -0400 Subject: [PATCH] TiledMLP support for FSDP2 (#2950) * make TiledMLP work with FSDP * cleanup/gc at start of train to prevent large VRAM spike * chore: lint * generic function for non-deepspeed training * unify patch to fix imports * update readme for ALST and add examples * make deepspeed attribute on params check more robust * update with new info from PR review --- README.md | 1 + examples/alst/README.md | 9 ++ examples/alst/llama3-8b-deepspeed-alst.yaml | 53 ++++++ examples/alst/llama3-8b-fsdp2-alst.yaml | 59 +++++++ src/axolotl/integrations/liger/args.py | 8 +- src/axolotl/loaders/model.py | 1 + src/axolotl/loaders/patch_manager.py | 7 +- src/axolotl/monkeypatch/tiled_mlp/__init__.py | 11 ++ src/axolotl/monkeypatch/tiled_mlp/base.py | 153 ++++++++++++++++++ .../{tiled_mlp.py => tiled_mlp/patch.py} | 31 +++- src/axolotl/utils/callbacks/__init__.py | 8 +- src/axolotl/utils/schemas/config.py | 2 +- src/axolotl/utils/schemas/validation.py | 13 -- 13 files changed, 330 insertions(+), 26 deletions(-) create mode 100644 examples/alst/README.md create mode 100644 examples/alst/llama3-8b-deepspeed-alst.yaml create mode 100644 examples/alst/llama3-8b-fsdp2-alst.yaml create mode 100644 src/axolotl/monkeypatch/tiled_mlp/__init__.py create mode 100644 src/axolotl/monkeypatch/tiled_mlp/base.py rename src/axolotl/monkeypatch/{tiled_mlp.py => tiled_mlp/patch.py} (66%) diff --git a/README.md b/README.md index 406781039..b31703e2b 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ ## 🎉 Latest Updates +- 2025/07: TiledMLP support for single-GPU and multi-GPU training with DDP, DeepSpeed, and FSDP has been added to enable Arctic Long Sequence Training (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl! - 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl! - 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more! - 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version! diff --git a/examples/alst/README.md b/examples/alst/README.md new file mode 100644 index 000000000..7f194d299 --- /dev/null +++ b/examples/alst/README.md @@ -0,0 +1,9 @@ +# Arctic Long Sequence Training (ALST) + +Arctic Long Sequence Training (ALST) is a technique for training long-context models using a variety of optimization +techniques. It is a combination of: +- TiledMLP: Tiles MLP computation over the sequence dimension to reduce memory usage +- Tiled Loss: Uses optimized loss functions like Liger-Kernel or Cut Cross Entropy to reduce memory usage +- Activation Offloading: Offloads activations to CPU RAM to reduce memory usage + +For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
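The TiledMLP idea above can be sketched in a few lines: chunk the hidden states along the sequence dimension and run the MLP one chunk at a time, recomputing each chunk during the backward pass so that only one chunk's intermediate activations are live at any point. The snippet below is a minimal illustration only: the helper name and shard count are made up, and it uses `torch.utils.checkpoint` instead of the custom `torch.autograd.Function` this patch adds in `src/axolotl/monkeypatch/tiled_mlp/base.py`.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def tiled_mlp_forward(mlp: nn.Module, x: torch.Tensor, shards: int = 4) -> torch.Tensor:
    """Apply `mlp` to `x` of shape (batch, seqlen, hidden) in sequence-dim chunks.

    Checkpointing each chunk means only one chunk's intermediate activations are
    materialized at a time; chunks are recomputed during the backward pass.
    """
    x_shards = torch.chunk(x, chunks=shards, dim=1)
    out_shards = [checkpoint(mlp, shard, use_reentrant=False) for shard in x_shards]
    return torch.cat(out_shards, dim=1)
```

The example configs that follow enable the real implementation with `tiled_mlp: true` and can be launched as usual, e.g. `axolotl train examples/alst/llama3-8b-fsdp2-alst.yaml`.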
diff --git a/examples/alst/llama3-8b-deepspeed-alst.yaml b/examples/alst/llama3-8b-deepspeed-alst.yaml new file mode 100644 index 000000000..dc82fa3be --- /dev/null +++ b/examples/alst/llama3-8b-deepspeed-alst.yaml @@ -0,0 +1,53 @@ +base_model: meta-llama/Llama-3.1-8B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +datasets: + - path: togethercomputer/Long-Data-Collections + type: completion + field: text + data_files: + - pretrain/rp_sub.jsonl.zst + - path: princeton-nlp/TextbookChapters + type: completion + field: chapter +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 500_000 +min_sample_len: 200_000 +sample_packing: true + +tiled_mlp: true +sequence_parallel_degree: 8 +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: auto +tf32: true + +gradient_checkpointing: true +activation_offloading: legacy + +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_steps: 100 +saves_per_epoch: 1 +evals_per_epoch: 2 +weight_decay: 0.0 +special_tokens: + pad_token: <|end_of_text|> + +deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_all.json + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/examples/alst/llama3-8b-fsdp2-alst.yaml b/examples/alst/llama3-8b-fsdp2-alst.yaml new file mode 100644 index 000000000..c8a978264 --- /dev/null +++ b/examples/alst/llama3-8b-fsdp2-alst.yaml @@ -0,0 +1,59 @@ +base_model: meta-llama/Llama-3.1-8B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +datasets: + - path: togethercomputer/Long-Data-Collections + type: completion + field: text + data_files: + - pretrain/rp_sub.jsonl.zst + - path: princeton-nlp/TextbookChapters + type: completion + field: chapter +dataset_prepared_path: last_run_prepared +val_set_size: 0.0 +output_dir: ./outputs/out + +sequence_len: 500_000 +min_sample_len: 200_000 +sample_packing: true + +tiled_mlp: true +context_parallel_size: 8 +plugins: + - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin + +gradient_accumulation_steps: 1 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_8bit +lr_scheduler: cosine +learning_rate: 2e-5 + +bf16: auto +tf32: true + +gradient_checkpointing: true +activation_offloading: legacy + +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_steps: 100 +saves_per_epoch: 1 +evals_per_epoch: 2 +weight_decay: 0.0 +special_tokens: + pad_token: <|end_of_text|> + +fsdp_version: 2 +fsdp_config: + offload_params: false # offloading is currently not compatible with SP + torchao optimizer + state_dict_type: SHARDED_STATE_DICT + auto_wrap_policy: TRANSFORMER_BASED_WRAP + transformer_layer_cls_to_wrap: LlamaDecoderLayer + reshard_after_forward: true + +# save_first_step: true # uncomment this to validate checkpoint saving works with your config diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py index 94ba83dd5..0460bdbf5 100644 --- a/src/axolotl/integrations/liger/args.py +++ b/src/axolotl/integrations/liger/args.py @@ -57,8 +57,12 @@ class LigerArgs(BaseModel): @model_validator(mode="before") @classmethod def check_tiled_mlp_conflict(cls, data): - if data.get("liger_glu_activation") is True and data.get("tiled_mlp") is True: + if 
( + data.get("liger_glu_activation") is True + and data.get("tiled_mlp") is True + and not data.get("tiled_mlp_use_original_mlp") + ): raise ValueError( - "You cannot have both `liger_glu_activation` and `tiled_mlp` set." + "You cannot have both `liger_glu_activation` and `tiled_mlp` set without `tiled_mlp_use_original_mlp: true`." ) return data diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 1ce98ef31..4fc005457 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -162,6 +162,7 @@ class ModelLoader: # Build the model PLUGIN_MANAGER.pre_model_load(self.cfg) + self.patch_manager.apply_post_plugin_pre_model_load_patches() skip_move_to_device = self._build_model() PLUGIN_MANAGER.post_model_build(self.cfg, self.model) diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py index 533bd0f7a..f1bb3ae67 100644 --- a/src/axolotl/loaders/patch_manager.py +++ b/src/axolotl/loaders/patch_manager.py @@ -66,6 +66,9 @@ class PatchManager: self._apply_self_attention_lora_patch() self._apply_gemma3_conditional_generation_forward_patch() self._apply_sequence_parallel_patches() + + def apply_post_plugin_pre_model_load_patches(self): + """Apply patches that must run after plugins' pre_model_load hooks, based on config.""" self._apply_tiled_mlp(self.cfg.model_config_type) def apply_post_model_load_patches(self, model: PreTrainedModel): @@ -272,7 +275,9 @@ class PatchManager: def _apply_tiled_mlp(self, model_type: str): if self.cfg.tiled_mlp: - from axolotl.monkeypatch.tiled_mlp import patch_tiled_mlp + from axolotl.monkeypatch.tiled_mlp import ( + patch_tiled_mlp, + ) patch_tiled_mlp( model_type, diff --git a/src/axolotl/monkeypatch/tiled_mlp/__init__.py b/src/axolotl/monkeypatch/tiled_mlp/__init__.py new file mode 100644 index 000000000..4ea154991 --- /dev/null +++ b/src/axolotl/monkeypatch/tiled_mlp/__init__.py @@ -0,0 +1,11 @@ +""" +TiledMLP monkey patches +""" + +from .patch import ( + patch_tiled_mlp, +) + +__all__ = [ + "patch_tiled_mlp", +] diff --git a/src/axolotl/monkeypatch/tiled_mlp/base.py b/src/axolotl/monkeypatch/tiled_mlp/base.py new file mode 100644 index 000000000..3b7326bdb --- /dev/null +++ b/src/axolotl/monkeypatch/tiled_mlp/base.py @@ -0,0 +1,153 @@ +""" +TiledMLP support for DDP, FSDP, and single GPU +""" + +import threading +from typing import List + +import torch + + +class TiledMLP(torch.autograd.Function): + """ + TiledMLP implementation using gradient hooks + """ + + @staticmethod + def forward( + ctx, + fn, + self, + x, + shards, + compute_params, + ) -> torch.Tensor: + ctx.fn = fn + ctx.self = self + ctx.shards = shards + ctx.compute_params = [p for p in compute_params if p.requires_grad] + ctx.save_for_backward(x) + + x_shards = list(torch.chunk(x, chunks=shards, dim=1)) + with torch.no_grad(): + output_shards = [fn(self, x_shard) for x_shard in x_shards] + output_unsharded = torch.cat(output_shards, dim=1) + + return output_unsharded + + @staticmethod + def backward(ctx, *grads) -> tuple: + fn = ctx.fn + (x,) = ctx.saved_tensors + self = ctx.self + shards = ctx.shards + compute_params = ctx.compute_params + + x_requires_grad = x.requires_grad + x = x.detach() + x.requires_grad_(x_requires_grad) + + incoming_grad = grads[0] + x_grad = torch.zeros_like(x) + x_shards = list(torch.chunk(x, chunks=shards, dim=1)) + + # Create a gradient accumulator for parameters + grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype) + + shard_step = x_shards[0].numel() + for i, x_shard in
enumerate(x_shards): + x_shard.requires_grad_(x_requires_grad) + + shard_offset = i * shard_step + x_shard.grad = ( + x_grad.view(-1) + .narrow(0, shard_offset, x_shard.numel()) + .view_as(x_shard) + ) + incoming_grad_shard = ( + incoming_grad.view(-1) + .narrow(0, shard_offset, x_shard.numel()) + .view_as(x_shard) + ) + + # Install hooks for this shard + is_last_shard = i + 1 == shards + grad_accumulator.install_hooks(is_last_shard) + + with torch.enable_grad(): + output = fn(self, x_shard) + torch.autograd.backward(output, incoming_grad_shard) + + # Clean up hooks + grad_accumulator.cleanup() + del grad_accumulator + + return (None, None, x_grad, None, None) + + +class GradientAccumulator: + """ + Manual gradient accumulator for TiledMLP with configurable precision + Accumulates in specified dtype and rescales the gradient at the end + """ + + def __init__( + self, + params: List[torch.nn.Parameter], + total_shards: int, + dtype: torch.dtype | None = None, + ): + self.params = params + self.total_shards = total_shards + self.grad_accumulation_dtype = dtype or torch.float32 + self.accumulated_grads = {} + self.hooks = [] + self.lock = threading.Lock() + self.gradient_scale = 1.0 / total_shards + + # Initialize accumulated gradients in the specified dtype + for param in self.params: + if param.grad is not None: + self.accumulated_grads[param] = param.grad.to( + self.grad_accumulation_dtype + ) + param.grad = None + else: + self.accumulated_grads[param] = torch.zeros_like( + param, dtype=self.grad_accumulation_dtype + ) + + def install_hooks(self, is_last_shard: bool): + """Install gradient hooks that accumulate gradients in higher precision""" + + def create_hook(param): + def hook(grad): + with self.lock: + grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype) + scaled_grad = grad_to_accum_dtype * self.gradient_scale + + if param in self.accumulated_grads: + self.accumulated_grads[param] += scaled_grad + else: + self.accumulated_grads[param] = scaled_grad.clone() + + # Only assign the averaged gradient on the last shard + if is_last_shard: + param.grad = self.accumulated_grads[param].to(param.dtype) + return param.grad + return None + + return hook + + # Install hooks on all parameters + for param in self.params: + if param.requires_grad: + hook = param.register_hook(create_hook(param)) + self.hooks.append(hook) + + def cleanup(self): + """Remove all installed hooks""" + for hook in self.hooks: + hook.remove() + self.hooks.clear() + del self.accumulated_grads diff --git a/src/axolotl/monkeypatch/tiled_mlp.py b/src/axolotl/monkeypatch/tiled_mlp/patch.py similarity index 66% rename from src/axolotl/monkeypatch/tiled_mlp.py rename to src/axolotl/monkeypatch/tiled_mlp/patch.py index 02bb3a8cb..419c73104 100644 --- a/src/axolotl/monkeypatch/tiled_mlp.py +++ b/src/axolotl/monkeypatch/tiled_mlp/patch.py @@ -12,8 +12,12 @@ from axolotl.utils.logging import get_logger LOG = get_logger(__name__) -def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): - from deepspeed.runtime.sequence_parallel.ulysses_sp import TiledMLP +def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None): + from deepspeed.runtime.sequence_parallel.ulysses_sp import ( + TiledMLP as DeepSpeedTiledMLP, + ) + + from axolotl.monkeypatch.tiled_mlp.base import TiledMLP try: # Dynamically import the module and MLP class @@ -36,6 +40,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): is_distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1 def 
tiled_mlp_forward(self, x): + # pylint: disable=protected-access input_shape = x.shape seqlen = input_shape[-2] hidden = input_shape[-1] @@ -48,14 +53,23 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): else: num_shards = cfg_num_shards - if not self._compute_params: # pylint: disable=protected-access - self._compute_params = [ # pylint: disable=protected-access - p for p in self.parameters() if p.requires_grad - ] + if not self._compute_params: + self._compute_params = [p for p in self.parameters() if p.requires_grad] - compute_params = self._compute_params # pylint: disable=protected-access + compute_params = self._compute_params + if not self._tiled_mlp_dist_impl: + if ( + self._compute_params + and any( + hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group") + for p in self._compute_params + ) + ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true": + self._tiled_mlp_dist_impl = DeepSpeedTiledMLP + else: + self._tiled_mlp_dist_impl = TiledMLP - down_res = TiledMLP.apply( + down_res = self._tiled_mlp_dist_impl.apply( mlp_forward, self, x, @@ -66,6 +80,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=False, cfg_num_shards=None): mlp_cls.forward = tiled_mlp_forward mlp_cls._compute_params = [] # pylint: disable=protected-access + mlp_cls._tiled_mlp_dist_impl = None # pylint: disable=protected-access LOG.info( f"Successfully monkey-patched TiledMLP for model_type: {model_type}", main_process_only=True, diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 2d031aa03..c64d8d351 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -867,10 +867,16 @@ class GCCallback(TrainerCallback): torch.cuda.empty_cache() gc.collect() + def on_train_begin( + self, args, state, control, **kwargs # pylint: disable=unused-argument + ): + self._gc() + def on_step_begin( self, args, state, control, **kwargs # pylint: disable=unused-argument ): - if self.next_gc_on_begin_step == state.global_step: + # pylint: disable=consider-using-in + if self.next_gc_on_begin_step == state.global_step or state.global_step == 0: self._gc() def on_step_end( diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 96b694043..a0e0b9604 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -597,7 +597,7 @@ class AxolotlInputConfig( ) tiled_mlp_use_original_mlp: bool | None = Field( - default=None, + default=True, json_schema_extra={ "description": "Whether to use original mlp for ALST tiled mlp. Otherwise uses a generic MLP based on llama." }, diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index cfa759cad..9ca33f456 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -512,19 +512,6 @@ class TrainingValidationMixin: return data - @model_validator(mode="before") - @classmethod - def check_tiled_mlp_deepspeed(cls, data): - capabilities = data.get("capabilities") - n_gpu = 0 - if capabilities and capabilities.get("n_gpu", 0) >= 1: - n_gpu = capabilities.get("n_gpu", 0) - if data.get("tiled_mlp", False) and (n_gpu > 1 and not data.get("deepspeed")): - raise ValueError( - "tiled_mlp requires deepspeed ZeRO to be enabled for multi-gpu" - ) - return data - class LoRAValidationMixin: """Validation methods related to LoRA/QLoRA configuration."""
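One detail worth calling out from the `patch.py` changes above: the patched MLP forward now picks between DeepSpeed's `TiledMLP` and the new generic implementation at runtime and caches the choice in `_tiled_mlp_dist_impl`. Restated as a standalone helper for illustration (the function name is hypothetical; the attribute and environment-variable names are the ones the patch inspects), the dispatch is roughly:

```python
import os
from typing import Iterable

import torch


def prefers_deepspeed_tiled_mlp(params: Iterable[torch.nn.Parameter]) -> bool:
    """Rough restatement of the dispatch check in tiled_mlp/patch.py."""
    # DeepSpeed ZeRO attaches bookkeeping attributes (e.g. `ds_id`) to parameters
    # it manages, so their presence indicates DeepSpeed is driving training.
    ds_params = any(
        hasattr(p, "ds_id") or hasattr(p, "param_idx_in_group") for p in params
    )
    # Accelerate exports this environment variable when its DeepSpeed plugin is active.
    ds_env = os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true"
    return ds_params or ds_env
```

Non-DeepSpeed setups (DDP, FSDP, single GPU) fall through to the generic `TiledMLP` in `base.py`.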