* [gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing
Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.
HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing bleeds the residual across steps.
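The retention chain above can be reproduced without torch. This is a minimal sketch: `FakeKV`, `decoder_layer_forward`, and `make_checkpoint_fn` are hypothetical stand-ins, not HF or axolotl code. A weakref shows that the "tensor" stays alive exactly as long as the partial holding the dict does.

```python
import functools
import gc
import weakref


class FakeKV:
    """Hypothetical stand-in for a K/V tensor produced by a store_full_length_kv layer."""


def decoder_layer_forward(hidden, shared_kv_states=None):
    # Hypothetical layer body; it only matters that it accepts the kwarg.
    return hidden


def make_checkpoint_fn(**kwargs):
    # Mirrors partial(super().__call__, **kwargs): every kwarg, including the
    # mutable dict, is bound into the partial that the checkpoint keeps.
    return functools.partial(decoder_layer_forward, **kwargs)


kv = FakeKV()
probe = weakref.ref(kv)
shared_kv_states = {0: kv}
saved_fn = make_checkpoint_fn(shared_kv_states=shared_kv_states)

# The model drops its own references after forward...
del kv, shared_kv_states
gc.collect()
# ...but the partial still reaches the "tensor" through the dict.
alive_while_partial_held = probe() is not None

# Only releasing the partial (end of backward) frees it.
del saved_fn
gc.collect()
freed_after_partial_released = probe() is None

print(alive_while_partial_held, freed_after_partial_released)  # True True
```

The same reasoning motivates the invariant below: a tensor captured directly is tracked by autograd, while a tensor reachable only through a mutable container lives as long as whoever holds the container.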
Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
The all-flex-attention path keeps VRAM flat but is ~22x slower.
Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.
Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py):
* threading.local() store with _get/_set_shared_kv_states helpers
* _patch_decoder_layer_call(): monkeypatches
Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
and stash it in TLS before delegating to GradientCheckpointingLayer.
The partial formed downstream no longer references the dict.
* fused_forward reads TLS first, falls back to kwarg for callers that
bypass the patched __call__ (e.g. direct attention invocation).
* wired into patch_gemma4_fused_attn; idempotent via a sentinel.
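The shape of this fix can be sketched in plain Python. Everything here is hypothetical: `CheckpointingLayer` stands in for GradientCheckpointingLayer and records the kwargs it would bind into its partial, and `_shared_kv_patched` is an assumed sentinel name. This sketch overwrites the store on every call (storing None when the kwarg is absent).

```python
import threading

# Hypothetical TLS store mirroring the _get/_set_shared_kv_states helpers.
_tls = threading.local()


def _set_shared_kv_states(value):
    _tls.shared_kv_states = value


def _get_shared_kv_states():
    return getattr(_tls, "shared_kv_states", None)


class CheckpointingLayer:
    """Hypothetical stand-in for GradientCheckpointingLayer.

    Records the kwargs it receives; in HF these are what get bound into the
    checkpoint partial.
    """

    def __call__(self, *args, **kwargs):
        self.captured_kwargs = dict(kwargs)
        return self.forward(*args, **kwargs)


class DecoderLayer(CheckpointingLayer):
    def forward(self, hidden, **kwargs):
        # Read TLS first; fall back to the kwarg for direct callers.
        shared = _get_shared_kv_states()
        if shared is None:
            shared = kwargs.get("shared_kv_states")
        return hidden, shared


def patch_decoder_layer_call(layer_cls):
    if getattr(layer_cls, "_shared_kv_patched", False):
        return  # sentinel makes repeated patching a no-op
    original_call = layer_cls.__call__

    def patched_call(self, *args, **kwargs):
        # Pop the dict before delegating so the downstream partial never sees it.
        _set_shared_kv_states(kwargs.pop("shared_kv_states", None))
        return original_call(self, *args, **kwargs)

    layer_cls.__call__ = patched_call
    layer_cls._shared_kv_patched = True


patch_decoder_layer_call(DecoderLayer)
patch_decoder_layer_call(DecoderLayer)  # idempotent
layer = DecoderLayer()
kv = {0: "kv"}
_, seen = layer("hidden", shared_kv_states=kv)
print(seen is kv, "shared_kv_states" in layer.captured_kwargs)  # True False
```

The layer still sees the dict, but the kwargs that reach the checkpointing base class no longer contain it.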
TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.
Introduced by PR #3598 (commit b8358aa5).
* CodeRabbit comment: gemma4: clear TLS unconditionally in decoder-layer patched __call__
Overwrite the thread-local shared_kv_states store on every invocation
(including with None) instead of only when the kwarg is present.
The previous conditional write left stale dicts in TLS on any path that
reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
kwarg, e.g. generation, eval hooks, or future HF refactors that make
the kwarg optional. fused_forward would then silently consume a prior
step's K/V dict instead of falling back to its own kwarg path.
Unconditional write makes the invariant in the surrounding comment
("TLS is overwritten on each new step's first decoder-layer call, so
the previous step's dict is released promptly") actually hold.
No behavior change for the training happy path, which always passes
the kwarg. Addresses CodeRabbit review on PR #3611.
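The stale-dict hazard can be shown with a two-call simulation: one training call that passes the kwarg, then one call (standing in for generation or an eval hook) that does not. `conditional_write`, `unconditional_write`, and `simulate` are hypothetical helpers for illustration only.

```python
import threading

_tls = threading.local()


def read_store():
    return getattr(_tls, "shared_kv_states", None)


def conditional_write(kwargs):
    # Earlier behaviour: only touch the store when the kwarg is present.
    if "shared_kv_states" in kwargs:
        _tls.shared_kv_states = kwargs.pop("shared_kv_states")


def unconditional_write(kwargs):
    # Fixed behaviour: always overwrite, storing None when the kwarg is absent.
    _tls.shared_kv_states = kwargs.pop("shared_kv_states", None)


def simulate(write):
    _tls.shared_kv_states = None
    write({"shared_kv_states": {"step": 1}})  # training step passes the kwarg
    write({})  # e.g. a generation or eval call without the kwarg
    return read_store()


stale = simulate(conditional_write)      # prior step's dict leaks through
cleared = simulate(unconditional_write)  # consumer falls back to its kwarg
print(stale, cleared)  # {'step': 1} None
```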
* fix: swap threading.local() for module-level store so autograd worker threads see shared_kv_states during backward recompute
The previous commits fixed the memory leak on 31B but caused a TypeError on MoE Gemma4 variants; this commit fixes that.
PR #3611's TLS variant only works when recompute runs on the same thread
that set TLS during forward. PyTorch's C++ autograd engine
(_engine_run_backward) spawns per-device worker threads to dispatch
backward, and HF-Trainer gradient_checkpointing (stock
torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires
unpack_hook -> recompute_fn on those worker threads. TLS set on the main
thread during forward is invisible there, so _get_shared_kv_states()
returns None and the consumer-layer lookup crashes with
"'NoneType' object is not subscriptable" at
fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]).
A plain module-level dict is visible to all threads in the process.
Lifecycle is identical: the slot is overwritten each forward, releasing
the previous step's dict and allowing its K/V tensors to be GC'd, so
the original VRAM-leak fix still holds under FSDP2 AC too.
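The visibility difference can be checked with a plain thread standing in for an autograd worker thread. This is only an approximation of what the engine does, and `_STORE` and `read_on_other_thread` are hypothetical names, not the ones in fused_attn.py.

```python
import threading

_tls = threading.local()
_STORE = {"shared_kv_states": None}  # hypothetical module-level slot


def read_on_other_thread(reader):
    """Call reader() on a fresh thread, standing in for an autograd worker thread."""
    result = {}

    def target():
        result["value"] = reader()

    worker = threading.Thread(target=target)
    worker.start()
    worker.join()
    return result["value"]


kv = {0: "kv-from-forward"}
# "Forward" runs on the main thread and populates both stores.
_tls.shared_kv_states = kv
_STORE["shared_kv_states"] = kv

tls_seen = read_on_other_thread(lambda: getattr(_tls, "shared_kv_states", None))
module_seen = read_on_other_thread(lambda: _STORE["shared_kv_states"])
print(tls_seen, module_seen is kv)  # None True
```

The TLS read on the worker returns None, which is exactly the value that then fails when subscripted in the consumer layer.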
* scope gemma4 shared_kv_states side channel to checkpointed training
Gate the shared_kv_states side channel from PR #3611 on checkpointed training to avoid regressions in async flows.
Add unit tests for the kwargs pop, the store-clear regression, and flag gating. Condense verbose comments.
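The commit does not spell out the gate's exact condition. This hedged sketch assumes it checks the layer's `training` and `gradient_checkpointing` attributes; `side_channel_enabled`, `route_shared_kv_states`, and `_STORE` are hypothetical names, and clearing the slot on the ungated path is a design assumption rather than confirmed behaviour.

```python
from types import SimpleNamespace

_STORE = {"shared_kv_states": None}  # hypothetical module-level slot


def side_channel_enabled(layer):
    # Assumed gate: only divert the dict when training with checkpointing on.
    return bool(getattr(layer, "training", False)
                and getattr(layer, "gradient_checkpointing", False))


def route_shared_kv_states(layer, kwargs):
    """Return the kwargs to forward to the base __call__."""
    if side_channel_enabled(layer):
        _STORE["shared_kv_states"] = kwargs.pop("shared_kv_states", None)
    else:
        # Assumption: clear the slot so no consumer reads a stale dict,
        # and leave the kwarg in place for the ordinary path.
        _STORE["shared_kv_states"] = None
    return kwargs


train_layer = SimpleNamespace(training=True, gradient_checkpointing=True)
eval_layer = SimpleNamespace(training=False, gradient_checkpointing=True)

kv = {0: "kv"}
train_kwargs = route_shared_kv_states(train_layer, {"shared_kv_states": kv})
train_store = _STORE["shared_kv_states"]
eval_kwargs = route_shared_kv_states(eval_layer, {"shared_kv_states": kv})
eval_store = _STORE["shared_kv_states"]

print("shared_kv_states" in train_kwargs, train_store is kv)  # False True
print("shared_kv_states" in eval_kwargs, eval_store)  # True None
```

Outside checkpointed training the dict simply travels as an ordinary kwarg, so non-training flows keep their original behaviour.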
* add gemma4 cross-thread visibility test for shared_kv_states store
Additional regression test for MoE Gemma4 variants, added in response to the previously observed 'NoneType' error: asserts that the module-level store is readable from threads other than the one that set it.
* fix logger
---------
Co-authored-by: Wing Lian <wing@axolotl.ai>
A Free and Open Source LLM Fine-tuning Framework
🎉 Latest Updates
- 2026/03:
- New model support has been added in Axolotl for Mistral Small 4, Qwen3.5, Qwen3.5 MoE, GLM-4.7-Flash, GLM-4.6V, and GLM-4.5-Air.
- MoE expert quantization support (via quantize_moe_experts: true) greatly reduces VRAM when training MoE models (FSDP2 compat).
- 2026/02:
- ScatterMoE LoRA support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
- Axolotl now has support for SageAttention and GDPO (Generalized DPO).
- 2026/01:
- New integrations for EAFT (Entropy-Aware Focal Training), which weights the loss by the entropy of the top-k logit distribution, and Scalable Softmax, which improves long-context attention.
- 2025/12:
- Axolotl now includes support for Kimi-Linear, Plano-Orchestrator, MiMo, InternVL 3.5, Olmo3, Trinity, and Ministral3.
- Distributed Muon Optimizer support has been added for FSDP2 pretraining.
- 2025/10: New model support has been added in Axolotl for: Qwen3 Next, Qwen2.5-vl, Qwen3-vl, Qwen3, Qwen3MoE, Granite 4, HunYuan, Magistral 2509, Apertus, and Seed-OSS.
- 2025/09: Axolotl now has text diffusion training. Read more here.
- 2025/08: QAT has been updated to include NVFP4 support. See PR.
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the blog post for more info.
- Axolotl adds more models: GPT-OSS, Gemma 3n, Liquid Foundation Model 2 (LFM2), and Arcee Foundation Models (AFM).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via torchao. Get started here!
- Voxtral, Magistral 1.1, and Devstral with mistral-common tokenizer support have been integrated in Axolotl!
- TiledMLP support, covering single-GPU to multi-GPU training with DDP, DeepSpeed, and FSDP, has been added to enable Arctic Long Sequence Training (ALST). See examples for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See docs to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the docs to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See docs to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the blog and docs to learn how to scale your context length when fine-tuning.
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the docs to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the docs to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our blog and GRPO example and have some fun!
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See docs.
✨ Overview
Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).
Features:
- Multiple Model Support: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
- Multimodal Training: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
- Training Methods: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- Easy Configuration: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- Performance Optimizations: Multipacking, Flash Attention 2/3/4, Xformers, Flex Attention, SageAttention, Liger Kernel, Cut Cross Entropy, ScatterMoE, Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
- Flexible Dataset Handling: Load from local, Hugging Face, and cloud (S3, Azure, GCP, OCI) datasets.
- Cloud Ready: We ship Docker images and also PyPI packages for use on cloud platforms and local hardware.
🚀 Quick Start - LLM Fine-tuning in Minutes
Requirements:
- NVIDIA GPU (Ampere or newer for bf16 and Flash Attention) or AMD GPU
- Python ≥3.11 (3.12 recommended)
- PyTorch ≥2.9.1
Google Colab
Installation
# install uv if you don't already have it installed (restart shell after)
curl -LsSf https://astral.sh/uv/install.sh | sh
# change depending on system
export UV_TORCH_BACKEND=cu128
# create a new virtual environment
uv venv --python 3.12
source .venv/bin/activate
uv pip install torch==2.10.0 torchvision
uv pip install --no-build-isolation axolotl[deepspeed]
# Download example axolotl configs, deepspeed configs
axolotl fetch examples
axolotl fetch deepspeed_configs # OPTIONAL
Using Docker
Installing with Docker can be less error-prone than installing in your own environment.
docker run --gpus '"all"' --ipc=host --rm -it axolotlai/axolotl:main-latest
Other installation approaches are described here.
Your First Fine-tune
# Fetch axolotl examples
axolotl fetch examples
# Or, specify a custom path
axolotl fetch examples --dest path/to/folder
# Train a model using LoRA
axolotl train examples/llama-3/lora-1b.yml
That's it! Check out our Getting Started Guide for a more detailed walkthrough.
📚 Documentation
- Installation Options - Detailed setup instructions for different environments
- Configuration Guide - Full configuration options and examples
- Dataset Loading - Loading datasets from various sources
- Dataset Guide - Supported formats and how to use them
- Multi-GPU Training
- Multi-Node Training
- Multipacking
- API Reference - Auto-generated code documentation
- FAQ - Frequently asked questions
AI Agent Support
Axolotl ships with built-in documentation optimized for AI coding agents (Claude Code, Cursor, Copilot, etc.). These docs are bundled with the pip package, so no repo clone is needed.
# Show overview and available training methods
axolotl agent-docs
# Topic-specific references
axolotl agent-docs sft # supervised fine-tuning
axolotl agent-docs grpo # GRPO online RL
axolotl agent-docs preference_tuning # DPO, KTO, ORPO, SimPO
axolotl agent-docs reward_modelling # outcome and process reward models
axolotl agent-docs pretraining # continual pretraining
axolotl agent-docs --list # list all topics
# Dump config schema for programmatic use
axolotl config-schema
axolotl config-schema --field adapter
If you're working with the source repo, agent docs are also available at docs/agents/ and the project overview is in AGENTS.md.
🤝 Getting Help
- Join our Discord community for support
- Check out our Examples directory
- Read our Debugging Guide
- Need dedicated support? Please contact ✉️ wing@axolotl.ai for options
🌟 Contributing
Contributions are welcome! Please see our Contributing Guide for details.
📈 Telemetry
Axolotl has opt-out telemetry that helps us understand how the project is being used and prioritize improvements. We collect basic system information, model types, and error rates, but never personal data or file paths. Telemetry is enabled by default. To disable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our telemetry documentation.
❤️ Sponsors
Interested in sponsoring? Contact us at wing@axolotl.ai
📝 Citing Axolotl
If you use Axolotl in your research or projects, please cite it as follows:
@software{axolotl,
title = {Axolotl: Open Source LLM Post-Training},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},
year = {2023}
}
📜 License
This project is licensed under the Apache 2.0 License - see the LICENSE file for details.