Compare commits

..

4 Commits

Author SHA1 Message Date
Wing Lian
08aa74e418 fix llama modeling 2025-07-30 11:37:58 -04:00
Wing Lian
dfa14f87ab fix residuals and add llama support 2025-07-30 10:22:38 -04:00
Wing Lian
fbe1b504da add custom modeling for gemma3 using liger fused add rms 2025-07-30 08:21:03 -04:00
Wing Lian
5b8370969c actually call the register method on plugins 2025-07-30 08:05:25 -04:00
414 changed files with 15534 additions and 18189 deletions

View File

@@ -1,3 +1,3 @@
[bandit]
exclude = tests
skips = B101,B615,B102,B110
skips = B101,B615

View File

@@ -12,6 +12,5 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: false
chat:
auto_reply: true

5
.flake8 Normal file
View File

@@ -0,0 +1,5 @@
[flake8]
max-line-length = 88
select = C,E,F,W,B,B950
extend-ignore = E203, E501, W503

View File

@@ -57,13 +57,6 @@ We welcome ideas for improvements and new features. To suggest an enhancement, o
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
#### Skipping CI Checks
You can skip certain CI checks by including specific keywords in your commit messages:
- `[skip ci]` or `skip ci` - Skips all CI checks for that commit
- `[skip-e2e]` or `skip-e2e` - Skips only end-to-end tests while running other CI checks. You may also include this in the title of your PR to disable end-to-end tests for the entire PR.
## Style Guidelines
### Code Style

View File

@@ -54,7 +54,7 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
@@ -64,16 +64,9 @@ jobs:
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.8.0
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
# python_version: "3.11"
# pytorch: nightly
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base-nightly"
dockerfile: "Dockerfile-base-nightly"
# # "next" is for release candidates of pytorch
# - cuda: "128"
# cuda_version: 12.8.1
@@ -129,13 +122,6 @@ jobs:
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -143,13 +129,6 @@ jobs:
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -24,22 +24,16 @@ jobs:
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -103,23 +97,12 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -167,24 +150,6 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -36,15 +36,15 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
pytorch: 2.7.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]

View File

@@ -53,7 +53,7 @@ jobs:
- name: Netlify Publish
uses: nwtgck/actions-netlify@v3.0
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
if: ${{ secrets.NETLIFY_AUTH_TOKEN != '' }}
id: netlify
with:
publish-dir: './_site'
@@ -68,7 +68,7 @@ jobs:
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
- name: Update PR with preview link
if: ${{ steps.netlify.outcome == 'success' }}
if: ${{ steps.netlify.outcome == 'success' && secrets.NETLIFY_AUTH_TOKEN != '' }}
uses: marocchino/sticky-pull-request-comment@v2
with:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -105,8 +105,7 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
@@ -130,7 +129,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -180,52 +179,21 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest -v --durations=10 tests/patched/
pytest -v --durations=10 tests/cli/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
gate-skip-e2e:
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
steps:
- uses: actions/github-script@v7
id: compute
with:
script: |
const token = /\[skip-e2e\]/i;
let msg = '';
if (context.eventName === 'push') {
msg = context.payload.head_commit?.message || '';
} else if (context.eventName === 'pull_request') {
const { owner, repo } = context.repo;
const prNumber = context.payload.pull_request.number;
const commits = await github.paginate(
github.rest.pulls.listCommits,
{ owner, repo, pull_number: prNumber, per_page: 100 }
);
msg = commits.at(-1)?.commit?.message || '';
}
const title = context.payload.pull_request?.title || '';
const body = context.payload.pull_request?.body || '';
const skip = token.test(msg) || token.test(title) || token.test(body);
core.setOutput('skip', String(skip));
docker-e2e-tests-1st:
# Run this job first as a gate for running the remainder of the test matrix
if: >
github.repository_owner == 'axolotl-ai-cloud' &&
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
needs.gate-skip-e2e.outputs.skip != 'true'
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
needs: [pre-commit, pytest, pytest-sdist]
strategy:
fail-fast: false
@@ -240,7 +208,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -271,16 +239,13 @@ jobs:
modal run cicd.e2e_tests
docker-e2e-tests:
if: >
github.repository_owner == 'axolotl-ai-cloud' &&
(github.event_name != 'pull_request' || !github.event.pull_request.draft) &&
needs.gate-skip-e2e.outputs.skip != 'true'
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && !github.event.pull_request.draft }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, gate-skip-e2e, docker-e2e-tests-1st]
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
fail-fast: false
@@ -298,13 +263,6 @@ jobs:
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
num_gpus: 1
gpu_type: "B200"
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -325,7 +283,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
@@ -342,10 +299,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
steps:

4
.isort.cfg Normal file
View File

@@ -0,0 +1,4 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -3,21 +3,31 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v6.0.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.12
- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 7.3.0
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.7
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.17.1
rev: v1.17.0
hooks:
- id: mypy
additional_dependencies:

15
.pylintrc Normal file
View File

@@ -0,0 +1,15 @@
[MASTER]
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
[pylint.messages_control]
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -185,6 +185,7 @@ datasets:
| `flash_attention` | `false` | Use flash attention |
| `flash_attn_cross_entropy` | `false` | Flash attention cross entropy |
| `flash_attn_rms_norm` | `false` | Flash attention RMS norm |
| `flash_attn_fuse_qkv` | `false` | Fuse QKV operations |
| `flash_attn_fuse_mlp` | `false` | Fuse MLP operations |
| `sdp_attention` | `false` | Use scaled dot product |
| `s2_attention` | `false` | Use shifted sparse attention |

View File

@@ -296,6 +296,7 @@
# flash_attention:
# flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
# flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
# flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# # Whether to use scaled-dot-product attention
# # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
@@ -540,6 +541,7 @@ xformers_attention: ${XFORMERS_ATTENTION}
flash_attention: ${FLASH_ATTENTION}
flash_attn_cross_entropy: ${FLASH_ATTN_CROSS_ENTROPY}
flash_attn_rms_norm: ${FLASH_ATTN_RMS_NORM}
flash_attn_fuse_qkv: ${FLASH_ATTN_FUSE_QKV}
flash_attn_fuse_mlp: ${FLASH_ATTN_FUSE_MLP}
sdp_attention: ${SDP_ATTENTION}
s2_attention: ${S2_ATTENTION}

View File

@@ -1,10 +0,0 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Post-Training for AI Models"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"
repository-code: "https://github.com/axolotl-ai-cloud/axolotl"
url: "https://axolotl.ai/"
license: Apache-2.0
date-released: "2023-05-30"

View File

@@ -17,7 +17,6 @@
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<a href="https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google-colab" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
@@ -26,28 +25,17 @@
## 🎉 Latest Updates
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
<details>
<summary>Expand older updates</summary>
- 2025/07: Voxtral with mistral-common tokenizer support has been integrated in Axolotl. Read the [docs](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral)!
- 2025/07: TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
</details>
## ✨ Overview
Axolotl is a tool designed to streamline post-training for various AI models.
@@ -71,10 +59,6 @@ Features:
- Python 3.11
- PyTorch ≥2.6.0
### Google Colab
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
### Installation
#### Using pip
@@ -154,20 +138,6 @@ Contributions are welcome! Please see our [Contributing Guide](https://github.co
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
## 📝 Citing Axolotl
If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
title = {Axolotl: Post-Training for AI Models},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},
year = {2023}
}
```
## 📜 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.

10
TODO.md Normal file
View File

@@ -0,0 +1,10 @@
# todo list
- [] Validation of parameters for combinations that won't work
## things that are known not to work
- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
- adamw_bnb_8bit doesn't play well with FSDP offload

View File

@@ -35,30 +35,25 @@ quartodoc:
- cli.train
- cli.evaluate
- cli.args
- cli.art
- cli.checks
- cli.config
- cli.delinearize_llama4
- cli.inference
- cli.merge_lora
- cli.merge_sharded_fsdp_weights
- cli.preprocess
- cli.quantize
- cli.sweeps
- cli.utils
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.utils
- cli.utils.args
- cli.utils.fetch
- cli.utils.load
- cli.utils.sweeps
- cli.utils.train
- cli.quantize
- title: Trainers
desc: Training implementations
contents:
- core.trainers.base
- core.trainers.trl
- core.trainers.mamba
- core.trainers.relora
- core.trainers.dpo.trainer
- core.trainers.grpo.trainer
- core.trainers.grpo.sampler
@@ -153,7 +148,7 @@ quartodoc:
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.streaming
- utils.data.pretraining
- utils.data.sft
- utils.quantization
- title: Schemas
@@ -272,10 +267,9 @@ website:
contents:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/streaming.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- docs/gradient_accumulation.qmd
- section: "Advanced Features"
contents:
@@ -285,7 +279,6 @@ website:
- docs/custom_integrations.qmd
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/nd_parallelism.qmd
- section: "Troubleshooting"
contents:

View File

@@ -2,6 +2,8 @@
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -61,7 +63,7 @@ def run_cmd(cmd: str, run_folder: str):
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code)
exit(exit_code) # pylint: disable=consider-using-sys-exit
@app.function(

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v --durations=10 -n2 \
pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -1,5 +1,7 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -57,16 +59,12 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_TYPE = os.environ.get("GPU_TYPE", "L40S")
GPU_CONFIG = f"{GPU_TYPE}:{N_GPUS}"
GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code)
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -12,7 +12,7 @@ coverage:
default:
# basic
target: auto
threshold: 1%
threshold: 0%
base: auto
# advanced
branches: null
@@ -27,7 +27,7 @@ coverage:
default:
# basic
target: auto
threshold: 1%
threshold: 0%
base: auto
# advanced
branches: null

View File

@@ -16,10 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y --no-install-recommends \
wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
ibverbs-providers ibverbs-utils infiniband-diags \
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config \
&& rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \
&& wget \
@@ -37,7 +34,7 @@ WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge

View File

@@ -15,7 +15,7 @@ COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \

View File

@@ -23,20 +23,6 @@ axolotl <command> [config.yml] [options]
The config file can be local or a URL to a raw YAML file.
### Launcher Arguments
For commands that support multi-GPU (`train`, `evaluate`, ...), you can pass launcher-specific arguments using the `--` separator:
```bash
# Pass torchrun arguments
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
# Pass accelerate arguments
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml --num_processes=4
```
Arguments after `--` are passed directly to the launcher (torchrun, accelerate launch, etc.).
## Command Reference
### fetch
@@ -94,11 +80,7 @@ axolotl train config.yml \
--num-epochs 3
# Training without accelerate
axolotl train config.yml --launcher python
# Pass launcher-specific arguments using -- separator
axolotl train config.yml --launcher torchrun -- --nproc_per_node=2 --nnodes=1
axolotl train config.yml --launcher accelerate -- --config_file=accelerate_config.yml
axolotl train config.yml --no-accelerate
# Resume training from checkpoint
axolotl train config.yml --resume-from-checkpoint path/to/checkpoint
@@ -193,9 +175,6 @@ Evaluates a model's performance (loss etc) on the train and eval datasets.
```bash
# Basic evaluation
axolotl evaluate config.yml
# Evaluation with launcher arguments
axolotl evaluate config.yml --launcher torchrun -- --nproc_per_node=2
```
### lm-eval
@@ -308,6 +287,9 @@ axolotl preprocess config.yml --cloud cloud_config.yml
# Train on cloud
axolotl train config.yml --cloud cloud_config.yml
# Train without accelerate on cloud
axolotl train config.yml --cloud cloud_config.yml --no-accelerate
# Run lm-eval on cloud
axolotl lm-eval config.yml --cloud cloud_config.yml
```

View File

@@ -212,11 +212,10 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
Example config for Llama4:
```yaml
chat_template: llama4
datasets:
- path: Nanobit/text-tools-2k-test
- path: ...
type: chat_template
# field_tools: tools # default is `tools`
```

View File

@@ -134,7 +134,7 @@ For providers supporting Docker:
### Google Colab {#sec-colab}
[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
## Platform-Specific Instructions {#sec-platform-specific}

View File

@@ -63,6 +63,15 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
:::
::: {.callout-tip}
Using ZeRO Stage 3 with Single-GPU training
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
:::
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
::: {.callout-note}

View File

@@ -69,19 +69,11 @@ export NCCL_BUFFSIZE=2097152
Run the following on each node:
### Option 1: New Axolotl CLI with launcher args (Recommended)
```bash
axolotl train config.yaml --launcher torchrun -- --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port"
```
### Option 2: Direct torchrun (Legacy)
```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```
Please make sure to substitute the placeholder variables:
Please make sure to substitute the placeholder variables.
- `num_nodes`: Number of nodes (containing GPUs)
- `gpu_per_node`: Number of gpus per node
@@ -89,6 +81,8 @@ Please make sure to substitute the placeholder variables:
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
- `rdzv_id`: A unique job ID that is used by the job across nodes.
The new CLI approach (Option 1) is recommended as it provides consistent argument handling and works seamlessly with other Axolotl CLI features.
::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
:::
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)

View File

@@ -13,13 +13,10 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
## Usage
@@ -34,7 +31,7 @@ skip_prepare_dataset: true
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
sample_packing: false # not yet supported with multimodal
chat_template: # see in next section if specified
chat_template: # see in next section
# example dataset
datasets:
@@ -100,16 +97,6 @@ base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
chat_template: mistral_v7_tekken
```
### Voxtral {#sec-voxtral}
::: {.callout-tip}
Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral_common[audio]==1.8.3'`
:::
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
```
### Gemma-3 {#sec-gemma-3}
::: {.callout-tip}
@@ -156,26 +143,6 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}
Please make sure to install `num2words` via `pip3 install num2words==0.5.14`
:::
```yaml
base_model: HuggingFaceTB/SmolVLM2-500M-Video-Instruct
```
### LFM2-VL {#sec-lfm2-vl}
::: {.callout-warning}
Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
:::
```yaml
base_model: LiquidAI/LFM2-VL-450M
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
@@ -214,20 +181,6 @@ You may need to install `librosa` via `pip3 install librosa==0.11.0`.
:::
### Video
::: {.callout-warning}
This is not well tested at the moment. We welcome contributors!
:::
For video loading, you can use the following keys within `content` alongside `"type": "video"`:
- `"path": "/path/to/video.mp4"`
- `"url": "https://example.com/video.mp4"`
- `"video": np.ndarray | list[PIL.Image.Image] | torch.Tensor` (or list of the aforementioned)
### Example
Here is an example of a multi-modal dataset:

View File

@@ -1,108 +0,0 @@
---
title: "N-D Parallelism (Beta)"
---
Axolotl enables training models at scale by composing different parallelism techniques. This is essential when:
- A model's weights are too large to fit on a single GPU's memory.
- A model's activations, especially with very long contexts, are too large for a single GPU.
- You want to accelerate training by using multiple GPUs or nodes.
or combinations of the above!
## Core Concepts
Parallelism strategies can be combined. The key is understanding how each one divides the workload. PyTorch's `DeviceMesh` is the modern way to manage these combinations, creating a logical grid of your GPUs and assigning different parallel strategies to different dimensions of the grid.
### Data Parallelism {#sec-dp}
Data Parallelism focuses on splitting the global data batch across GPUs.
- Distributed Data Parallel (DDP): The classic approach. The full model is replicated on every GPU. Each GPU processes a different slice of the data batch. Gradients are then averaged across all GPUs after the backward pass to keep the models synchronized. This can substantially improve data throughput compared to single-device training, but requires that each GPU is able to hold the entire model, its gradients, and optimizer states.
- [Fully Sharded Data Parallel (FSDP)](multi-gpu.qmd#fully-sharded-data-parallel-(fsdp)): A highly memory-efficient form of data parallelism (inspired by DeepSpeed's ZeRO). Instead of replicating the model, FSDP shards the model's *parameters, gradients, and optimizer states* across the GPUs in the data-parallel group. During computation, each GPU receives the specific parameters it needs via an `all_gather` operation just before they are used, and they can be discarded immediately after (`reshard-after-forward`).
- FSDP maps to ZeRO stages:
- ZeRO-2 (`reshard_after_forward=False`): Shards gradients and optimizer states. Model weights are replicated on each GPU.
- ZeRO-3 (`reshard_after_forward=True`): Shards gradients, optimizer states, AND model parameters. This provides the most memory savings at the cost of more communication (re-gathering parameters for both forward and backward passes).
### [Experimental] Tensor Parallelism (TP) {#sec-tp}
Also known as "horizontal model parallelism," as described in the [Megatron-LM paper](https://arxiv.org/pdf/1909.08053.pdf). Instead of splitting the batch, TP splits the model's layers themselves across GPUs.
- How it works: For a linear layer `Y = XA`, the weight matrix `A` is split column-wise (`A = [A_1, A_2]`). The computation becomes `Y_1 = XA_1` and `Y_2 = XA_2`, which can happen in parallel on different GPUs. The final output `Y` is simply the concatenation of `Y_1` and `Y_2`. Check [this comment](https://github.com/huggingface/transformers/issues/10321#issuecomment-783543530) for more detailed info.
- Requirement: TP involves frequent, small communications within a forward/backward pass. It requires a very fast interconnect between GPUs (e.g., NVLink) and is typically not recommended across different nodes.
### Context Parallelism (CP) {#sec-cp}
Context Parallelism, also called [Sequence Parallelism](sequence_parallelism.qmd), addresses the memory bottleneck from long sequences. The input sequence itself is split along the sequence length dimension and distributed across GPUs.
- How it works: If you have a sequence of 8192 tokens and a `context_parallel_size` of 4, each GPU will only handle a chunk of 2048 tokens.
- The Challenge: Attention is not local; every token needs to "attend to" every other token. Splitting the sequence breaks this.
- The Solution (`ring-flash-attention`): An efficient communication protocol is used. To compute attention for its local sequence chunk, each GPU passes its Key-Value (KV) cache to its neighbor in a "ring." After `N-1` steps, every GPU has seen the KV-cache from all other GPUs, allowing it to compute the correct attention values for its chunk. This is implemented using the highly optimized `flash-attention` kernel at each step.
### Hybrid Sharding Data Parallel (HSDP) {#sec-hsdp}
HSDP is a 2D strategy that intelligently combines FSDP and DDP, typically for multi-node training.
- Intra-Node (within a machine): Use FSDP. This is efficient because GPUs on the same node have fast interconnects (NVLink), making the `all_gather` operations for sharded parameters fast.
- Inter-Node (across machines): Use DDP. The gradient synchronization between nodes is less frequent than FSDP's parameter gathering, making it a better fit for the slower node-to-node network (e.g., Ethernet/Infiniband).
- Example: With 2 nodes of 8 GPUs each (16 total), you could have `dp_shard_size=8` (FSDP within each node) and `dp_replicate_size=2` (DDP across the two nodes).
## Usage
```yaml
# FSDP config. See https://docs.axolotl.ai/docs/multi-gpu.html#sec-fsdp
fsdp_version: 2
fsdp_config:
# ...
# The number of GPUs to shard the model parameters across (FSDP dimension).
dp_shard_size: 4
# The number of times to replicate the sharded model (DDP dimension).
dp_replicate_size: 2
# Number of GPUs for Tensor Parallelism.
tensor_parallel_size: 1 # (default is 1, no TP)
# Number of GPUs for Context/Sequence Parallelism.
context_parallel_size: 1 # (default is 1, no CP)
```
Note: We recommend FSDP. DeepSpeed is only compatible with `tensor_parallel_size`.
## Examples
::: {.callout-tip}
See our example configs [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/distributed-parallel).
:::
1. HSDP on 2 nodes with 4 GPUs each (8 GPUs total):
- You want FSDP within each node and DDP across nodes.
- Set `dp_shard_size: 4` and `dp_replicate_size: 2`.
2. FSDP + TP on a single 8-GPU node:
- You want to split the model across 4 GPUs using FSDP, and further split each layer across 2 GPUs with TP.
- Set `dp_shard_size: 4` and `tensor_parallel_size: 2`.
3. FSDP + CP on a single 8-GPU node for long context:
- You want to shard the model across all 8 GPUs and also split the sequence length across all 8 GPUs.
- Set `dp_shard_size: 8` and `context_parallel_size: 8`. Note: this means the data parallel group and context parallel group are the same. A more common setup might be to shard across a smaller group.
## Support Matrix
This matrix describes how different parallelism methods can be combined in Axolotl.
| Combination | `dp_replicate_size` | `dp_shard_size` | `tp_size` | `cp_size` | Status & Notes |
| --- | :---: | :---: |:---:|:---:|---|
| **FSDP** (ZeRO-3) | 1 | >1 | 1 | 1 | ✅ Fully supported. Shards model across all GPUs. |
| **HSDP** | >1 | >1 | 1 | 1 | ✅ Fully supported. FSDP intra-node, DDP inter-node. |
| **FSDP + TP** | 1 | >1 | >1 | 1 | ✅ **2D Parallelism**. Shards the model across a `dp_shard` group, and TP-splits layers within the `tp` group. |
| **HSDP + TP** | >1 | >1 | >1 | 1 | ✅ **3D Parallelism**. A powerful but complex combination. |
| **FSDP + CP** | 1 | >1 | 1 | >1 | ✅ **2D Parallelism**. Combines FSDP with context parallelism. |
| **FSDP + TP + CP**| 1 | >1 | >1| >1| ✅ **3D Parallelism**. Another advanced combination. |
| DDP + TP/CP | >1 | 1 | >1 | >1 | ❌ **Not Supported**. The `ParallelismConfig` explicitly prevents this, as composing pure DDP with TP or CP is currently not supported. You should use FSDP + TP/CP instead (`dp_shard_size > 1`). |
| Just TP / CP | 1 | 1 | >1 | >1 | ✅ Supported. Useful for inference or when the model fits on one GPU but context is too long. |
- `tp_size` refers to `tensor_parallel_size`
- `cp_size` refers to `context_parallel_size`

View File

@@ -1,129 +0,0 @@
---
title: Optimizers
description: Configuring optimizers
---
## Overview
Axolotl supports all optimizers supported by [transformers OptimizerNames](https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/training_args.py#L142-L187)
Here is a list of optimizers supported by transformers as of `v4.54.0`:
- `adamw_torch`
- `adamw_torch_fused`
- `adamw_torch_xla`
- `adamw_torch_npu_fused`
- `adamw_apex_fused`
- `adafactor`
- `adamw_anyprecision`
- `adamw_torch_4bit`
- `adamw_torch_8bit`
- `ademamix`
- `sgd`
- `adagrad`
- `adamw_bnb_8bit`
- `adamw_8bit` # alias for adamw_bnb_8bit
- `ademamix_8bit`
- `lion_8bit`
- `lion_32bit`
- `paged_adamw_32bit`
- `paged_adamw_8bit`
- `paged_ademamix_32bit`
- `paged_ademamix_8bit`
- `paged_lion_32bit`
- `paged_lion_8bit`
- `rmsprop`
- `rmsprop_bnb`
- `rmsprop_bnb_8bit`
- `rmsprop_bnb_32bit`
- `galore_adamw`
- `galore_adamw_8bit`
- `galore_adafactor`
- `galore_adamw_layerwise`
- `galore_adamw_8bit_layerwise`
- `galore_adafactor_layerwise`
- `lomo`
- `adalomo`
- `grokadamw`
- `schedule_free_radam`
- `schedule_free_adamw`
- `schedule_free_sgd`
- `apollo_adamw`
- `apollo_adamw_layerwise`
- `stable_adamw`
## Custom Optimizers
Enable custom optimizers by passing a string to the `optimizer` argument. Each optimizer will receive beta and epsilon args, however, some may accept additional args which are detailed below.
### optimi_adamw
```yaml
optimizer: optimi_adamw
```
### ao_adamw_4bit
Deprecated: Please use `adamw_torch_4bit`.
### ao_adamw_8bit
Deprecated: Please use `adamw_torch_8bit`.
### ao_adamw_fp8
```yaml
optimizer: ao_adamw_fp8
```
### adopt_adamw
GitHub: [https://github.com/iShohei220/adopt](https://github.com/iShohei220/adopt)
Paper: [https://arxiv.org/abs/2411.02853](https://arxiv.org/abs/2411.02853)
```yaml
optimizer: adopt_adamw
```
### came_pytorch
GitHub: [https://github.com/yangluo7/CAME/tree/master](https://github.com/yangluo7/CAME/tree/master)
Paper: [https://arxiv.org/abs/2307.02047](https://arxiv.org/abs/2307.02047)
```yaml
optimizer: came_pytorch
# optional args (defaults below)
adam_beta1: 0.9
adam_beta2: 0.999
adam_beta3: 0.9999
adam_epsilon: 1e-30
adam_epsilon2: 1e-16
```
### muon
Blog: [https://kellerjordan.github.io/posts/muon/](https://kellerjordan.github.io/posts/muon/)
Paper: [https://arxiv.org/abs/2502.16982v1](https://arxiv.org/abs/2502.16982v1)
```yaml
optimizer: muon
```
### dion
Microsoft's Dion (DIstributed OrthoNormalization) optimizer is a scalable and communication-efficient
orthonormalizing optimizer that uses low-rank approximations to reduce gradient communication.
GitHub: [https://github.com/microsoft/dion](https://github.com/microsoft/dion)
Paper: [https://arxiv.org/pdf/2504.05295](https://arxiv.org/pdf/2504.05295)
Note: Implementation written for PyTorch 2.7+ for DTensor
```yaml
optimizer: dion
dion_lr: 0.01
dion_momentum: 0.95
lr: 0.00001 # learning rate for embeddings and parameters that fallback to AdamW
```

View File

@@ -11,7 +11,6 @@ We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).
For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).
```yaml
base_model: google/gemma-2-2b

View File

@@ -47,6 +47,7 @@ class QuartoGenerator:
"""Check if a type is a Pydantic BaseModel."""
return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)
# pylint: disable=too-many-return-statements
def _extract_nested_type(self, field_type) -> Any:
"""Extract the actual type from complex type annotations."""
# Handle Annotated types (Python 3.9+)
@@ -123,6 +124,7 @@ class QuartoGenerator:
return field_type
# pylint: disable=too-many-return-statements
def _extract_all_pydantic_models_from_type(
self, field_type
) -> list[type[BaseModel]]:
@@ -316,6 +318,7 @@ class QuartoGenerator:
return all_groups
# pylint: disable=too-many-return-statements
def _extract_field_groups_from_source(
self, model_class: type[BaseModel]
) -> list[dict]:
@@ -500,7 +503,7 @@ class QuartoGenerator:
nested_schema = nested_model.model_json_schema()
nested_properties = nested_schema.get("properties", {})
nested_required = nested_schema.get("required", [])
except Exception:
except Exception: # pylint: disable=broad-exception-caught
# Fallback: use model fields directly
nested_properties = {}
nested_required = []
@@ -604,7 +607,7 @@ class QuartoGenerator:
schema = model_class.model_json_schema()
properties = schema.get("properties", {})
required = schema.get("required", [])
except Exception as e:
except Exception as e: # pylint: disable=broad-exception-caught
print(
f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
)

View File

@@ -22,7 +22,7 @@ To enable sequence parallelism, add the following to your configuration file:
```yaml
# Set to a divisor (> 1) of the number of GPUs available
context_parallel_size: 4 # Split sequences across 4 GPUs
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -30,7 +30,7 @@ heads_k_stride: 1
ring_attn_func:
```
The `context_parallel_size` should be a divisor of the total number of GPUs. For example:
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
- With 8 GPUs, valid values would be 2, 4, or 8
- With 4 GPUs, valid values would be 2 or 4
@@ -66,7 +66,7 @@ sequence_len: 8192
...
context_parallel_size: 4 # Split each sequence into 4 parts, one per GPU
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
heads_k_stride: 1
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
@@ -89,12 +89,12 @@ Sequence parallelism is compatible with Axolotl's sample packing functionality.
## Effect on Batch Size
When using sequence parallelism, your effective global batch size is **divided** by the `context_parallel_size`. This happens because:
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
- Each group of `context_parallel_size` GPUs works on the same batch (just different parts of each sequence)
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
- The number of batches processed per step decreases
For example:
- With 8 GPUs and no sequence parallelism: 8 different batches processed per step
- With 8 GPUs and `context_parallel_size=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
- If your per-GPU `micro_batch_size` is 2, the global batch size decreases from 16 to 4

View File

@@ -1,120 +0,0 @@
---
title: Streaming Datasets
description: How to use streaming mode for large-scale datasets and memory-efficient training
order: 10
---
Streaming enables memory-efficient training with large datasets by loading data
incrementally rather than loading the entire dataset into memory at once.
Use streaming when:
- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora)
- You want to start training immediately without preprocessing the entire dataset
Streaming works with both remote and locally stored datasets!
::: {.callout-note}
Streaming currently only supports a single dataset. Multi-dataset support will be added soon.
:::
## Configuration
### Basic Streaming
Enable streaming mode by setting the `streaming` flag:
```yaml
streaming: true
```
### Pretraining with Streaming
For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:
```yaml
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
### SFT with Streaming
For supervised fine-tuning with streaming:
```yaml
streaming: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
## Configuration Options
### `streaming_multipack_buffer_size`
Controls the buffer size for multipack streaming (default: 10,000). This determines how
many samples are buffered before packing. Larger buffers can improve packing efficiency
but use more memory.
### `shuffle_merged_datasets`
When enabled, shuffles the streaming dataset using the buffer. This requires additional
memory for the shuffle buffer.
## Sample Packing with Streaming
Sample packing is supported for streaming datasets. When enabled, multiple samples are
packed into a single sequence to maximize GPU utilization:
```yaml
sample_packing: true
streaming_multipack_buffer_size: 10000
# For SFT: attention is automatically isolated between packed samples
# For pretraining: control with pretrain_multipack_attn
pretrain_multipack_attn: true # prevent cross-attention between packed samples
```
For more information, see our [documentation](multipack.qmd) on multipacking.
## Important Considerations
### Memory Usage
While streaming reduces memory usage compared to loading entire datasets, you still need
to consider:
- You can control the memory usage by adjusting `streaming_multipack_buffer_size`
- Sample packing requires buffering multiple samples
- Shuffling requires additional memory for the shuffle buffer
### Performance
- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly
- Network speed and disk read speed are important when streaming from remote sources or a local dataset, respectively
- Consider using `axolotl preprocess` for smaller or more frequently used datasets
### Evaluation Datasets
Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
loaded normally even when training uses streaming.
## Examples
See the `examples/streaming/` directory for complete configuration examples:
- `pretrain.yaml`: Pretraining with streaming dataset
- `sft.yaml`: Supervised fine-tuning with streaming

View File

@@ -1,58 +0,0 @@
# Finetune Liquid Foundation Models 2 (LFM2) with Axolotl
[Liquid Foundation Models 2 (LFM2)](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) are a family of small, open-weight models from [Liquid AI](https://www.liquid.ai/) focused on quality, speed, and memory efficiency. Liquid AI released text-only [LFM2](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) and text+vision [LFM2-VL](https://huggingface.co/collections/LiquidAI/lfm2-vl-68963bbc84a610f7638d5ffa) models.
LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-range convolutions, and grouped query attention, enabling fast training and inference.
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Run one of the finetuning examples below.
**LFM2**
```bash
# FFT SFT (1x48GB @ 25GiB)
axolotl train examples/LiquidAI/lfm2-350m-fft.yaml
```
**LFM2-VL**
```bash
# LoRA SFT (1x48GB @ 2.7GiB)
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```
### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
```bash
pip uninstall -y causal-conv1d
```
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
- **Dataset Formats**:
- For LFM2 models, the dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For LFM2-VL models, Axolotl follows the multi-content Messages format. See our [Multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) for details.
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,58 +0,0 @@
base_model: LiquidAI/LFM2-VL-450M
trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -20,7 +20,7 @@ min_sample_len: 200_000
sample_packing: true
tiled_mlp: true
context_parallel_size: 8
sequence_parallel_degree: 8
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

View File

@@ -1,53 +0,0 @@
# Finetune ArceeAI's AFM with Axolotl
[Arcee Foundation Models (AFM)](https://huggingface.co/collections/arcee-ai/afm-45b-68823397c351603014963473) are a family of 4.5B parameter open weight models trained by Arcee.ai.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the AFM model.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as AFM is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:
```bash
axolotl train examples/arcee/afm-4.5b-qlora.yaml
```
This config uses about 7.8GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### TIPS
- For inference, the official Arcee.ai team recommends `top_p: 0.95`, `temperature: 0.5`, `top_k: 50`, and `repeat_penalty: 1.1`.
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [AFM Blog](https://docs.arcee.ai/arcee-foundation-models/introduction-to-arcee-foundation-models)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,64 +0,0 @@
base_model: arcee-ai/AFM-4.5B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -47,6 +47,7 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

View File

@@ -1,10 +0,0 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

File diff suppressed because it is too large Load Diff

View File

@@ -10,14 +10,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Devstral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.6.0+)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:

View File

@@ -1,52 +0,0 @@
# ND Parallelism Examples
This directory contains example configurations for training models using ND Parallelism in Axolotl. These examples demonstrate how to compose different parallelism strategies (FSDP, TP, CP, HSDP) for efficient multi-GPU training.
## Quick Start
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Run the command below:
```bash
# Train Qwen3 8B with FSDP + TP + CP on a single 8-GPU node
axolotl train examples/distributed-parallel/qwen3-8b-fsdp-tp-cp.yaml
# Train Llama 3.1 8B with HSDP + TP on 2 nodes (16 GPUs total)
axolotl train examples/distributed-parallel/llama-3_1-8b-hsdp-tp.yaml
```
## Example Configurations
### Single Node (8 GPUs)
**Qwen3 8B with FSDP + TP + CP** ([qwen3-8b-fsdp-tp-cp.yaml](./qwen3-8b-fsdp-tp-cp.yaml))
- Uses all 3 parallelism dimensions on a single node
- Ideal for: when model weights, activations, and/or context are too large to fit on single GPU
```yaml
dp_shard_size: 2 # FSDP across 2 GPUs
tensor_parallel_size: 2 # TP across 2 GPUs
context_parallel_size: 2 # CP across 2 GPUs
# Total: 2 × 2 × 2 = 8 GPUs
```
### Multi-Node
**Llama 3.1 8B with HSDP + TP** ([llama-3_1-8b-hsdp-tp.yaml](./llama-3_1-8b-hsdp-tp.yaml))
- FSDP & TP within nodes, DDP across nodes to minimize inter-node communication
- Ideal for: Scaling to multiple nodes while maintaining training efficiency
```yaml
dp_shard_size: 4 # FSDP within each 4-GPU group
tensor_parallel_size: 2 # TP within each node
dp_replicate_size: 2 # DDP across 2 groups
# Total: (4 × 2) × 2 = 16 GPUs (2 nodes)
```
## Learn More
- [ND Parallelism Documentation](https://docs.axolotl.ai/docs/nd_parallelism.html)
- [Blog: Accelerate ND-Parallel Guide](https://huggingface.co/blog/accelerate-nd-parallel)
- [Multi-GPU Training Guide](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,47 +0,0 @@
base_model: meta-llama/Llama-3.1-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 4
dp_replicate_size: 2
tensor_parallel_size: 2
# context_parallel_size: 2
dataset_prepared_path: last_run_prepared
special_tokens:
pad_token: <|end_of_text|>
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 2048
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1

View File

@@ -1,46 +0,0 @@
base_model: Qwen/Qwen3-8B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
dp_shard_size: 2
# dp_replicate_size: 1
context_parallel_size: 2
tensor_parallel_size: 2
dataset_prepared_path: last_run_prepared
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3DecoderLayer
reshard_after_forward: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
output_dir: ./outputs/ndp-out/
sequence_len: 8192
sample_packing: true
flash_attention: true
gradient_accumulation_steps: 1
micro_batch_size: 1 # must be 1 when using context parallel
num_epochs: 2
optimizer: adamw_torch_fused
lr_scheduler: constant_with_warmup
learning_rate: 2e-6
bf16: true
tf32: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:

View File

@@ -1,68 +0,0 @@
base_model: google/gemma-3-270m-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
# huggingface repo
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -4,14 +4,17 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Gemma3n is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
# Ensure you have Pytorch installed (Pytorch 2.6.0 min recommended)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. In addition to Axolotl's requirements, Gemma-3n requires:

View File

@@ -1,125 +0,0 @@
# Finetune OpenAI's GPT-OSS with Axolotl
[GPT-OSS](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4) are a family of open-weight MoE models trained by OpenAI, released in August 2025. There are two variants: 20B and 120B.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
```bash
# LoRA SFT linear layers (1x48GB @ ~44GiB)
axolotl train examples/gpt-oss/gpt-oss-20b-sft-lora-singlegpu.yaml
# FFT SFT with offloading (2x24GB @ ~21GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2-offload.yaml
# FFT SFT (8x48GB @ ~36GiB/GPU or 4x80GB @ ~46GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-20b-fft-fsdp2.yaml
```
Note: Memory usage taken from `device_mem_reserved(gib)` from logs.
### Training 120B
On 8xH100s, make sure you have ~3TB of free disk space. With each checkpoint clocking in at ~720GB, along with the base
model, and final model output, you may need at least 3TB of free disk space to keep at least 2 checkpoints.
```bash
# FFT SFT with offloading (8x80GB @ ~49GiB/GPU)
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
training of the 120B model using Baseten Truss. You can read more about this recipe on
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
be found on their
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
ERRATA: Transformers saves the model Architecture prefixed with `FSDP` which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
```bash
sed -i 's/FSDPGptOssForCausalLM/GptOssForCausalLM/g' ./outputs/gpt-oss-out/config.json
```
When using SHARDED_STATE_DICT with FSDP, the final checkpoint should automatically merge the sharded weights to your
configured `output_dir`. However, if that step fails due to a disk space error, you can take an additional step to
merge the sharded weights. This step will automatically determine the last checkpoint directory and merge the sharded
weights to `{output_dir}/merged`.
```bash
axolotl merge-sharded-fsdp-weights examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
```
### Inferencing your fine-tuned model
#### vLLM
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai docker image for inferencing with vLLM.
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
```
#### SGLang
SGLang has 0-day support in main, see https://github.com/sgl-project/sglang/issues/8833 for infomation on installing
SGLang from source. Once you've installed SGLang, run the following command to launch a SGLang server:
```bash
python3 -m sglang.launch_server --model ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-120b --host 0.0.0.0 --port 8888 --tp 8
```
### Tool use
GPT-OSS has a comprehensive tool understanding. Axolotl supports tool calling datasets for Supervised Fine-tuning.
Here is an example dataset config:
```yaml
datasets:
- path: Nanobit/text-tools-2k-test
type: chat_template
```
See [Nanobit/text-tools-2k-test](https://huggingface.co/datasets/Nanobit/text-tools-2k-test) for the sample dataset.
Refer to [our docs](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#using-tool-use) for more info.
### TIPS
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [GPT-OSS Blog](https://openai.com/index/introducing-gpt-oss/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,68 +0,0 @@
# the original mxfp4 quantized model is not supported with FSDP cpu_ram_efficient_loading
# FSDP cpu_ram_efficient_loading is used to reduce the initial CPU memory usage when loading the model
base_model: axolotl-ai-co/gpt-oss-120b-dequantized
use_kernels: false
dp_shard_size: 16 # requires 2x8xH100 nodes
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
save_total_limit: 2 # the 120B model can use up to 720GB of disk space per checkpoint, so let's only keep the last 2
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
cpu_ram_efficient_loading: true

View File

@@ -1,58 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
# choose the zero3 configuration that best fits your system capabilities
deepspeed: deepspeed_configs/zero3_bf16.json

View File

@@ -1,68 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused # 8bit optimizers do not work with FSDP2 offload
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: true
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true
# cpu_ram_efficient_loading cannot be used with MXFP4 model quantization.
# It can only be used with a dequantized model like `axolotl-ai-co/gpt-oss-120b-dequantized`

View File

@@ -1,64 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: false
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-5
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.03
special_tokens:
eot_tokens:
- "<|end|>"
fsdp_version: 2
fsdp_config:
offload_params: false
state_dict_type: SHARDED_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: GptOssDecoderLayer
reshard_after_forward: true
# cpu_ram_efficient_loading: true

View File

@@ -1,67 +0,0 @@
base_model: openai/gpt-oss-20b
use_kernels: true
model_quantization_config: Mxfp4Config
model_quantization_config_kwargs:
dequantize: true
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by not putting model to GPU before sharding
datasets:
- path: HuggingFaceH4/Multilingual-Thinking
type: chat_template
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
sequence_len: 4096
sample_packing: true
adapter: lora
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0 # dropout not supported when using LoRA over expert parameters
lora_target_linear: true
# TODO: not supported for now, see peft#2710
#lora_target_parameters: # target the experts in the last two layers
# - "22._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "22._checkpoint_wrapped_module.mlp.experts.down_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.gate_up_proj"
# - "23._checkpoint_wrapped_module.mlp.experts.down_proj"
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: constant_with_warmup
learning_rate: 2e-4
bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
gradient_checkpointing: true
activation_offloading: true
logging_steps: 1
saves_per_epoch: 1
warmup_ratio: 0.1
special_tokens:
eot_tokens:
- "<|end|>"

7
examples/lfm2/README.md Normal file
View File

@@ -0,0 +1,7 @@
# Liquid Foundation Models 2
LFM2 support in transformers exists in the main branch, but is not yet included in the transformers release.
```bash
pip install --upgrade --no-deps --force-reinstall git+https://github.com/huggingface/transformers.git
```

View File

@@ -2,6 +2,7 @@ base_model: LiquidAI/LFM2-350M
chunked_cross_entropy: true
chat_template: tokenizer_default
eot_tokens:
- "<|im_end|>"
datasets:

View File

@@ -45,6 +45,7 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

View File

@@ -49,6 +49,7 @@ logging_steps: 1
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_ratio: 0.1

View File

@@ -25,12 +25,9 @@ lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
relora: true
relora_prune_ratio: 0.9
relora_steps: 150
relora_warmup_ratio: 0.1
relora_cpu_offload: false
jagged_restart_steps: 150
jagged_restart_warmup_steps: 10
jagged_restart_anneal_steps: false
wandb_project:
wandb_entity:

View File

@@ -8,14 +8,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Run the finetuning example:

View File

@@ -27,6 +27,7 @@ sequence_len: 2048
sample_packing: true
eval_sample_packing: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -26,6 +26,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -26,6 +26,7 @@ lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05

View File

@@ -1,44 +0,0 @@
base_model: Skywork/Skywork-Reward-V2-Qwen3-8B
model_type: AutoModelForSequenceClassification
num_labels: 1
reward_model: true
center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability
chat_template: qwen3
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
deepspeed: deepspeed_configs/zero1.json
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: linear
learning_rate: 0.00002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
warmup_ratio: 0.1
logging_steps: 1
weight_decay: 0.01

View File

@@ -1,66 +0,0 @@
# SLURM Multi-Node Training
This directory contains an example SLURM script for running Axolotl training jobs across multiple nodes in a SLURM cluster.
## Prerequisites
- Access to a SLURM cluster with GPU nodes
- Axolotl installed on all nodes (see [installation docs](https://docs.axolotl.ai/docs/installation.html))
## Usage
### Standard SLURM Clusters
1. Copy [`axolotl.slurm`](./axolotl.slurm) to your working directory.
2. Place your Axolotl config file (`train.yaml`) in the same directory.
3. Set the appropriate environment variables for the job:
```bash
export HF_TOKEN="your-huggingface-token"
# metric tracking
# export WANDB_API_KEY="your-wandb-api-key"
# ...
```
4. Submit the job:
```bash
sbatch --export=ALL,NUM_NODES=2,NUM_TRAINERS=8,PRIMARY_ADDR=<master-node>,PRIMARY_PORT=29400 axolotl.slurm
```
Where:
- `NUM_NODES`: Number of nodes to use
- `NUM_TRAINERS`: GPUs per node (typically 8)
- `PRIMARY_ADDR`: Hostname/IP of the master node
- `PRIMARY_PORT`: Port for distributed training (default: 29400)
5. (Optional) Run other slurm commands:
```bash
# check job info
scontrol show job axolotl-cli
# check job queue
squeue
# check cluster status
sinfo
```
### RunPod Instant Clusters
Axolotl works with RunPod Instant Clusters. This feature provides managed SLURM clusters with zero configuration.
1. **Deploy a SLURM Cluster**:
- Go to [RunPod Instant Clusters](https://console.runpod.io/cluster)
- Click "Create a Cluster"
- Choose your GPU type, node count, and region
- Choose an [Axolotl cloud docker image](https://docs.axolotl.ai/docs/docker.html#cloud)
- Deploy the cluster
2. **Connect to the Controller Node**: Find the controller node in the RunPod console and connect via SSH
3. **Follow the instructions in [Standard SLURM Clusters](#standard-slurm-clusters)**
## Additional Resources
- [Axolotl Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [SLURM Documentation](https://slurm.schedmd.com/documentation.html)
- [RunPod SLURM Clusters Guide](https://docs.runpod.io/instant-clusters/slurm-clusters)

View File

@@ -1,20 +0,0 @@
#!/bin/bash
# Prior to running this script, export your HF_TOKEN and WANDB_API_KEY to your environment; i.e.
# export HF_TOKEN="..."
# export WANDB_API_KEY="..."
#
# ---------- SBATCH commands ---------- #
#SBATCH --job-name=axolotl-slurm-multinode
#SBATCH --ntasks-per-node=1
#SBATCH --nodes=$NUM_NODES
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=128
export TORCH_DIST_INIT_BARRIER=0
srun axolotl preprocess train.yaml
srun axolotl train train.yaml --launcher torchrun -- \
--nproc_per_node=$NUM_TRAINERS --nnodes=$NUM_NODES \
--rdzv_id axolotl-cli --rdzv_backend c10d --rdzv_endpoint "${PRIMARY_ADDR}:${PRIMARY_PORT}" --rdzv-conf="join_timeout=1800"

View File

@@ -1,49 +0,0 @@
# Finetune SmolVLM2 with Axolotl
[SmolVLM2](https://huggingface.co/collections/HuggingFaceTB/smolvlm2-smallest-video-lm-ever-67ab6b5e84bf8aaa60cb17c7) are a family of lightweight, open-source multimodal models from HuggingFace designed to analyze and understand video, image, and text content.
These models are built for efficiency, making them well-suited for on-device applications where computational resources are limited. Models are available in multiple sizes, including 2.2B, 500M, and 256M.
This guide shows how to fine-tune SmolVLM2 models with Axolotl.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
Here is an example of how to install from pip:
```bash
# Ensure you have a compatible version of Pytorch installed
pip3 install packaging setuptools wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install an extra dependency:
```bash
pip3 install num2words==0.5.14
```
3. Run the finetuning example:
```bash
# LoRA SFT (1x48GB @ 6.8GiB)
axolotl train examples/smolvlm2/smolvlm2-2B-lora.yaml
```
## TIPS
- **Dataset Format**: For video finetuning, your dataset must be compatible with the multi-content Messages format. For more details, see our documentation on [Multimodal Formats](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
- **Dataset Loading**: Read more on how to prepare and load your own datasets in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [SmolVLM2 Blog](https://huggingface.co/blog/smolvlm2)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,56 +0,0 @@
base_model: HuggingFaceTB/SmolVLM2-2.2B-Instruct
trust_remote_code: true
processor_type: AutoProcessor
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.text_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,50 +0,0 @@
# Streaming Dataset Examples
This directory contains example configurations for using Axolotl's streaming dataset
functionality, which enables memory-efficient training with large datasets.
## Examples
Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
`axolotl preprocess` required!
### Pretraining (`pretrain.yaml`)
Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
with SmolLM2-135M.
- Uses `pretraining_dataset` configuration for automatic streaming
- Multipack attention control to prevent cross-attention between packed sequences
- Buffer size configuration for memory management
### SFT (`sft.yaml`)
Shows how to use streaming for supervised fine-tuning with the Alpaca dataset.
- Explicit `streaming: true` flag for SFT datasets
- Memory-efficient training on instruction datasets
- Evaluation datasets are currently not streamed
## Key Configuration Options
### `streaming`
- Enables streaming mode for standard datasets
- Automatically enabled for `pretraining_dataset`
### `streaming_multipack_buffer_size`
- Controls buffer size for sample packing (default: 10,000)
- Larger values improve packing efficiency but use more memory
- Adjust based on available memory
### `shuffle_merged_datasets`
- Enables shuffling of streaming datasets
- Requires additional memory for shuffle buffer
### `sample_packing`
- Packs multiple samples into single sequences
- Minimize per-step padding tokens
## Performance Tips
- Download small / frequently-used datasets locally for better performance
- Larger buffer sizes improve packing efficiency

View File

@@ -1,57 +0,0 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Streaming pretraining configuration
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
name: sample-10BT
type: pretrain
text_column: text
split: train
# Streaming-specific settings
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-pretrain-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
pretrain_multipack_attn: true # Prevent cross-attention between packed sequences
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 8
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-4
warmup_ratio: 0.1
weight_decay: 0.01
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 250
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,55 +0,0 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Dataset configuration
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Streaming-specific settings
streaming: true
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-sft-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 4
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.1
weight_decay: 0.0
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 100
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -6,14 +6,17 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Voxtral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from pip:
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
pip3 install --no-build-isolation -e '.[flash-attn]'
```
2. Please install the below.

View File

@@ -26,34 +26,3 @@ include-package-data = true
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
[tool.ruff]
line-length = 88
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "W", "C90", "B"]
ignore = [
"E203", # Whitespace before ':'
"E501", # Line too long
"C901", # Too complex
"B019", # Use of functools.cache on methods
"E722", # Bare except
"F821", # Undefined name (for dynamic exec)
]
[tool.ruff.lint.isort]
known-third-party = ["wandb", "comet_ml"]
known-local-folder = ["src", "tests"]
# Black-compatible isort settings
force-single-line = false
combine-as-imports = true
split-on-trailing-comma = true
[tool.ruff.format]
# Use black's formatting style exactly
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false

View File

@@ -1,32 +1,30 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
bitsandbytes==0.46.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.6.1
liger-kernel==0.6.0
# END section
packaging==23.2
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
peft==0.16.0
transformers==4.54.0
tokenizers>=0.21.1
accelerate==1.10.0
accelerate==1.9.0
datasets==4.0.0
deepspeed>=0.17.0
trl==0.21.0
trl==0.19.1
hf_xet==1.1.5
kernels==0.9.0
trackio
optimum==1.16.2
hf_transfer
sentencepiece
gradio==5.41.1
gradio==5.23.3
modal==1.0.2
pydantic==2.10.6
@@ -68,6 +66,6 @@ torchao==0.12.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.5
axolotl-contribs-mit==0.0.3
mistral-common==1.8.3

View File

@@ -27,7 +27,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not field_messages:
raise ValueError(
f"No conversation field found in dataset: {', '.join(feature_keys)}"
f'No conversation field found in dataset: {", ".join(feature_keys)}'
)
ds_cfg["field_messages"] = field_messages
@@ -40,7 +40,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not message_property_mappings["role"]:
raise ValueError(
f"No role field found in messages: {', '.join(message_fields)}"
f'No role field found in messages: {", ".join(message_fields)}'
)
for key in ["content", "text", "value"]:
@@ -49,7 +49,7 @@ def parse_dataset(dataset=None, split="train"):
break
if not message_property_mappings["content"]:
raise ValueError(
f"No content field found in messages: {', '.join(message_fields)}"
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_property_mappings"] = message_property_mappings

View File

@@ -44,13 +44,8 @@ add_keys_to_authorized() {
chmod 700 -R ~/.ssh
}
# Set SSH port
if [ ! -z "$SSH_PORT" ]; then
sed -i "s/#Port 22/Port $SSH_PORT/" /etc/ssh/sshd_config
fi
if [[ $PUBLIC_KEY ]]; then
# runpod, prime intellect
# runpod
add_keys_to_authorized "$PUBLIC_KEY"
# Start the SSH service in the background
service ssh start
@@ -81,13 +76,5 @@ if [ ! -L "/workspace/axolotl/outputs" ]; then
ln -sf /workspace/data/axolotl-artifacts /workspace/axolotl/outputs
fi
# start the runpod slurm init
SLURM_INIT="${SLURM_INIT:-/slurm-init.sh}"
if [[ -f "$SLURM_INIT" ]]; then
echo "[entrypoint] running $SLURM_INIT..."
bash "$SLURM_INIT"
fi
# Execute the passed arguments (CMD)
exec "$@"

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@010c3ac3f1e725098961832830303eeb4142dd88"'
)

View File

@@ -1,10 +1,11 @@
# noqa
# pylint: skip-file
import sys
try:
import torch
except ImportError as error:
raise ImportError("Install torch via `pip install torch`") from error
except ImportError:
raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
use_uv = "--uv" in sys.argv[1:]

View File

@@ -64,9 +64,7 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
@@ -74,13 +72,12 @@ def parse_requirements(extras_require_map):
extras_require_map.pop("vllm")
else:
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm>=0.10.0"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map.pop("vllm")
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -120,14 +117,14 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.8.3"],
"flash-attn": ["flash-attn==2.8.2"],
"ring-flash-attn": [
"flash-attn==2.8.3",
"flash-attn==2.8.2",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.5",
"deepspeed==0.17.2",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -4,4 +4,4 @@ import pkgutil
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
__version__ = "0.13.0.dev"
__version__ = "0.12.0.dev"

View File

@@ -14,13 +14,9 @@ class PreprocessCliArgs:
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=False,
default=None,
metadata={
"help": (
"Deprecated in v0.13.0, will be removed in v0.14.0. For streaming "
"datasets, use 'axolotl train' and set 'streaming: true' in your YAML "
"config, or pass --streaming instead in the CLI."
)
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@@ -34,6 +30,8 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=0)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
main_process_port: Optional[int] = field(default=None)
num_processes: Optional[int] = field(default=None)
@dataclass
@@ -44,12 +42,6 @@ class VllmServeCliArgs:
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
data_parallel_size: Optional[int] = field(
default=None,
metadata={
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
},
)
host: Optional[str] = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},

View File

@@ -22,7 +22,7 @@ HAS_PRINTED_LOGO = False
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
global HAS_PRINTED_LOGO
global HAS_PRINTED_LOGO # pylint: disable=global-statement
if HAS_PRINTED_LOGO:
return
if is_main_process():

View File

@@ -3,17 +3,15 @@ launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Literal
from typing import Union
import yaml
from axolotl.cli.cloud.base import Cloud
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
"""Load and validate cloud configuration."""
# Load cloud configuration.
with open(cloud_config, encoding="utf-8") as file:
@@ -22,8 +20,8 @@ def load_cloud_cfg(cloud_config: Path | str) -> DictDefault:
def do_cli_preprocess(
cloud_config: Path | str,
config: Path | str,
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
@@ -33,39 +31,25 @@ def do_cli_preprocess(
def do_cli_train(
cloud_config: Path | str,
config: Path | str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
cloud_config: Union[Path, str],
config: Union[Path, str],
accelerate: bool = True,
cwd=None,
**kwargs,
) -> None:
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
local_dirs = {}
if cwd and not Path(cwd).joinpath("src", "axolotl").exists():
local_dirs = {"/workspace/mounts": cwd}
cloud.train(
config_yaml,
launcher=launcher,
launcher_args=launcher_args,
local_dirs=local_dirs,
**kwargs,
)
cloud.train(config_yaml, accelerate=accelerate, local_dirs=local_dirs, **kwargs)
def do_cli_lm_eval(
cloud_config: Path | str,
config: Path | str,
cloud_config: Union[Path, str],
config: Union[Path, str],
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)

View File

@@ -3,7 +3,6 @@ base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod
from typing import Literal
class Cloud(ABC):
@@ -16,12 +15,5 @@ class Cloud(ABC):
pass
@abstractmethod
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None,
**kwargs,
):
def train(self, config_yaml: str, accelerate: bool = True) -> str:
pass

View File

@@ -1,48 +0,0 @@
"""Baseten Cloud CLI"""
import shutil
import subprocess # nosec B404
import tempfile
from os.path import dirname
from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
class BasetenCloud(Cloud):
"""Baseten Cloud Axolotl CLI"""
def __init__(self, config: dict):
self.config = config
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
raise NotImplementedError(
"Separate preprocess function for Baseten is not "
"implemented and will happen during hte train step."
)
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
**kwargs,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
config["launcher"] = launcher
config["launcher_args"] = launcher_args
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
shutil.copyfile(
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)

View File

@@ -1,9 +0,0 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl preprocess train.yaml
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}

View File

@@ -1,71 +0,0 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = int(cloud_config.get("gpu_count", 1))
node_count = int(cloud_config.get("node_count", 1))
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
launcher = cloud_config.get("launcher", "accelerate")
launcher_args = cloud_config.get("launcher_args", [])
script_name = "run.sh"
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.0 for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.a
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
"AXOLOTL_LAUNCHER": launcher,
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -8,7 +8,7 @@ import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
from typing import Literal
from typing import Optional
import modal
@@ -41,7 +41,7 @@ def run_cmd(cmd: str, run_folder: str, volumes=None):
if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env
):
exit(exit_code)
exit(exit_code) # pylint: disable=consider-using-sys-exit
# Commit writes to volume.
if volumes:
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu126-2.7.1"
docker_tag = "main-py3.11-cu124-2.6.0"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -130,6 +130,7 @@ class ModalCloud(Cloud):
res = []
if self.config.secrets:
for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str):
if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val}))
@@ -176,8 +177,8 @@ class ModalCloud(Cloud):
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
*args,
volumes={k: v[0] for k, v in self.volumes.items()},
*args,
**kwargs,
)
@@ -186,7 +187,7 @@ class ModalCloud(Cloud):
return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours
def get_train_gpu(self):
def get_train_gpu(self): # pylint: disable=too-many-return-statements
count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s"
@@ -199,7 +200,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return f"H100:{count}"
return modal.gpu.H100(count=count)
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":
@@ -229,9 +230,8 @@ class ModalCloud(Cloud):
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None,
accelerate: bool = True,
local_dirs: Optional[dict[str, str]] = None,
**kwargs,
):
modal_fn = self.get_train_env(local_dirs)(_train)
@@ -239,8 +239,7 @@ class ModalCloud(Cloud):
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
launcher=launcher,
launcher_args=launcher_args,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
**kwargs,
)
@@ -271,35 +270,20 @@ def _preprocess(config_yaml: str, volumes=None):
)
def _train(
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
volumes=None,
**kwargs,
):
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/mounts"
launcher_args = launcher_args or []
# Build the base command
if launcher == "accelerate":
launcher_arg = "--launcher accelerate"
elif launcher == "torchrun":
launcher_arg = "--launcher torchrun"
if accelerate:
accelerate_args = "--accelerate"
else:
launcher_arg = "--launcher python"
# Build launcher args string
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
accelerate_args = "--no-accelerate"
num_processes_args = ""
if num_processes := kwargs.pop("num_processes", None):
num_processes_args = f"--num-processes {num_processes}"
run_cmd(
f"axolotl train {launcher_arg} /workspace/mounts/config.yaml {launcher_args_str}".strip(),
f"axolotl train {accelerate_args} {num_processes_args} /workspace/mounts/config.yaml",
run_folder,
volumes,
)

View File

@@ -153,14 +153,15 @@ def prepare_plugins(cfg: DictDefault):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def plugin_set_cfg(cfg: DictDefault):
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
plugin_manager.cfg = cfg
# now that we have the finalized cfg, register the plugins individually
for plugin in plugin_manager.plugins.values():
plugin.register(cfg)
def load_cfg(
@@ -199,18 +200,19 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
for key, value in kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[key] = value
cfg[k] = kwargs[k]
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except:
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)

View File

@@ -9,6 +9,7 @@ from typing import Generator, Union
import fire
import torch
from accelerate import init_empty_weights
from dotenv import load_dotenv
from transformers import AutoProcessor
@@ -151,4 +152,5 @@ def do_cli(model: Union[Path, str], output: Union[Path, str]) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli.args import TrainerCliArgs
@@ -12,6 +13,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils import patch_optimized_env
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -28,7 +30,10 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# pylint: disable=duplicate-code
check_accelerate_default_config()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
@@ -49,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
@@ -59,4 +64,5 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -9,12 +9,16 @@ from typing import Union
import fire
import torch
import transformers
from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -32,7 +36,7 @@ def get_multi_line_input() -> str:
instruction = ""
for line in sys.stdin:
instruction += line
instruction += line # pylint: disable=consider-using-join
return instruction
@@ -61,9 +65,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
@@ -158,13 +160,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -172,6 +168,7 @@ def do_inference_gradio(
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
@@ -256,7 +253,7 @@ def do_cli(
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, inference=True, rl=None, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser(InferenceCliArgs)
@@ -271,4 +268,5 @@ def do_cli(
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,10 +1,15 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import os
import subprocess # nosec B404
from typing import Literal, Optional
import tempfile
from pathlib import Path
from typing import Optional
import click
import yaml
from dotenv import load_dotenv
import axolotl
@@ -16,14 +21,13 @@ from axolotl.cli.args import (
VllmServeCliArgs,
)
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.sweeps import generate_sweep_configs
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
filter_none_kwargs,
generate_config_files,
launch_training,
)
from axolotl.integrations.lm_eval.cli import lm_eval
from axolotl.utils import patch_optimized_env
@@ -32,19 +36,12 @@ from axolotl.utils.schemas.config import AxolotlInputConfig
LOG = get_logger(__name__)
LAUNCHER_COMMAND_MAPPING = {
"accelerate": ["accelerate", "launch"],
"torchrun": ["torchrun"],
}
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""
print_axolotl_text_art()
load_dotenv()
patch_optimized_env()
@cli.command()
@@ -53,7 +50,7 @@ def cli():
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
"""
Preprocess datasets before training.
@@ -63,6 +60,7 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
patch_optimized_env()
if cloud:
from axolotl.cli.cloud import do_cli_preprocess
@@ -74,15 +72,12 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
do_cli(config=config, **kwargs)
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU training",
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
@@ -93,82 +88,126 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def train(
ctx: click.Context,
config: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
cloud: str | None = None,
sweep: str | None = None,
accelerate: bool,
cloud: Optional[str] = None,
sweep: Optional[str] = None,
**kwargs,
):
) -> None:
"""
Train or fine-tune a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
launcher: Launcher to use for multi-GPU training ("accelerate", "torchrun", or "python").
accelerate: Whether to use `accelerate` launcher.
cloud: Path to a cloud accelerator configuration file
sweep: Path to YAML config for sweeping hyperparameters.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
# Enable expandable segments for cuda allocation to improve VRAM usage
patch_optimized_env()
# Handle Ray launcher override
_launcher = None if kwargs.get("use_ray") else launcher
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if sweep:
# load the sweep configuration yaml file
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
with open(config, "r", encoding="utf-8") as fin:
base_config: dict[str, list] = yaml.safe_load(fin)
# Process each configuration
for cfg_file, is_group in generate_config_files(config, sweep):
# generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
def iter_configs():
for perm in permutations:
# open temp directory for temporary configurations
with tempfile.TemporaryDirectory() as temp_dir:
with open(
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
) as fout:
yaml.dump(perm, fout)
yield str(Path(temp_dir) / "config.yaml")
else:
def iter_configs():
yield config
for cfg_file in iter_configs():
# handle errors from subprocess so we can continue rest of sweeps
try:
use_exec = is_group is not True
launch_training(cfg_file, _launcher, cloud, kwargs, launcher_args, use_exec)
if accelerate:
if cloud:
from axolotl.cli.cloud import do_cli_train
cwd = os.getcwd()
do_cli_train(
cloud_config=cloud,
config=config,
accelerate=True,
cwd=cwd,
**kwargs,
)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num_processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
from axolotl.cli.cloud import do_cli_train
do_cli_train(
cloud_config=cloud, config=config, accelerate=False, **kwargs
)
else:
from axolotl.cli.train import do_cli
do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc:
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
finally:
# Only delete temp files, not the original config
if cfg_file != config:
os.unlink(cfg_file)
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU evaluation",
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
)
@add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
"""
Evaluate a model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
launcher: Launcher to use for multi-GPU evaluation ("accelerate", "torchrun", or "python").
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.evaluate"]
)
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
@@ -179,42 +218,30 @@ def evaluate(ctx: click.Context, config: str, launcher: str, **kwargs):
do_cli(config=config, **kwargs)
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for multi-GPU inference",
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU inference",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kwargs):
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
"""
Run inference with a trained model.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
launcher: Launcher to use for multi-GPU inference ("accelerate", "torchrun", or "python").
accelerate: Whether to use `accelerate` launcher.
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.inference"]
)
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
if config:
base_cmd.append(config)
if gradio:
@@ -227,42 +254,33 @@ def inference(ctx: click.Context, config: str, launcher: str, gradio: bool, **kw
do_cli(config=config, gradio=gradio, **kwargs)
@cli.command(
context_settings={"ignore_unknown_options": True, "allow_extra_args": True}
)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--launcher",
type=click.Choice(["accelerate", "torchrun", "python"]),
default="accelerate",
help="Launcher to use for weight merging",
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for weight merging",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
@click.pass_context
def merge_sharded_fsdp_weights(
ctx: click.Context, config: str, launcher: str, **kwargs
):
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
"""
Merge sharded FSDP model weights.
Args:
ctx: Click context for extra args.
config: Path to `axolotl` config YAML file.
launcher: Launcher to use for weight merging ("accelerate", "torchrun", or "python").
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Extract launcher args from extra args (after --)
launcher_args = ctx.args if ctx.args else []
if launcher in LAUNCHER_COMMAND_MAPPING:
base_cmd = (
LAUNCHER_COMMAND_MAPPING[launcher]
+ launcher_args
+ ["-m", "axolotl.cli.merge_sharded_fsdp_weights"]
)
if accelerate:
base_cmd = [
"accelerate",
"launch",
"-m",
"axolotl.cli.merge_sharded_fsdp_weights",
]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
@@ -278,7 +296,7 @@ def merge_sharded_fsdp_weights(
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_lora(config: str, **kwargs):
def merge_lora(config: str, **kwargs) -> None:
"""
Merge trained LoRA adapters into a base model.
@@ -295,7 +313,7 @@ def merge_lora(config: str, **kwargs):
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
def fetch(directory: str, dest: Optional[str]) -> None:
"""
Fetch example configs or other resources.
@@ -333,7 +351,7 @@ def quantize(config: str, **cli_args: QuantizeCliArgs):
@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))
def delinearize_llama4(model: str, output: str):
def delinearize_llama4(model: str, output: str) -> None:
from axolotl.cli.delinearize_llama4 import do_cli as do_delinearize_llama4
do_delinearize_llama4(model, output)
@@ -347,4 +365,5 @@ def main():
if __name__ == "__main__":
load_dotenv()
main()

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from typing import Union
import fire
from dotenv import load_dotenv
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
@@ -43,10 +44,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
@@ -72,7 +70,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
load_in_8bit=False,
load_in_4bit=False,
flash_attention=False,
context_parallel_size=None,
sequence_parallel_degree=None,
deepspeed=None,
fsdp=None,
fsdp_config=None,
@@ -90,4 +88,5 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -10,7 +10,6 @@ import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
from accelerate import PartialState
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
@@ -18,13 +17,13 @@ from accelerate.utils import (
WEIGHTS_NAME,
is_torch_version,
)
from dotenv import load_dotenv
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli.config import load_cfg
from axolotl.utils.logging import get_logger
from axolotl.utils.train import determine_last_checkpoint
LOG = get_logger(__name__)
@@ -32,7 +31,7 @@ LOG = get_logger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""A custom planner to cast tensors to bfloat16 on the fly during loading."""
def commit_tensor(self, read_item, tensor):
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
tensor.copy_(tensor.to(torch.bfloat16))
@@ -59,10 +58,10 @@ def _distributed_checkpoint_to_merged_weights(
state_dict: Dict = {}
save_path_ = Path(save_path)
save_path_.mkdir(exist_ok=True)
dist_cp_format_utils._load_state_dict(
dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=BFloat16CastPlanner(),
planner=BFloat16CastPlanner(), # pylint: disable=protected-access
no_dist=True,
)
@@ -145,6 +144,7 @@ def merge_fsdp_weights(
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
"""
checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState
if not is_torch_version(">=", "2.3.0"):
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
@@ -181,6 +181,7 @@ def merge_fsdp_weights(
if remove_checkpoint_dir:
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
shutil.rmtree(checkpoint_dir_)
state.wait_for_everyone()
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
@@ -191,37 +192,17 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
if not fsdp_dir.exists():
checkpoint_dir = determine_last_checkpoint(parsed_cfg, update=False)
if checkpoint_dir:
fsdp_dir = Path(checkpoint_dir) / "pytorch_model_fsdp_0"
if not fsdp_dir.exists():
raise ValueError(
f"Could not find FSDP checkpoint `pytorch_model_fsdp_0` in {checkpoint_dir}"
)
output_path = str(Path(parsed_cfg.output_dir) / "merged")
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=output_path,
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
safe_serialization=True,
)
state = PartialState()
state.wait_for_everyone()
LOG.info(
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
main_process_only=True,
)
LOG.info(
"Merged weights are only the safetensors and doesn't include the model configuration "
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
main_process_only=True,
)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -9,6 +9,7 @@ import fire
import transformers
from accelerate import init_empty_weights
from colorama import Fore
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM
from axolotl.cli.args import PreprocessCliArgs
@@ -35,20 +36,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key):
LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
)
return
@@ -83,7 +74,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)
except Exception: # nosec B110
except Exception as exc: # pylint: disable=broad-exception-caught,unused-variable # nosec B110 # noqa F841
pass
# fmt: on
@@ -105,10 +96,9 @@ def do_cli(
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
@@ -119,4 +109,5 @@ def do_cli(
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -84,6 +84,5 @@ def do_quantize(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -3,12 +3,11 @@
import random
from copy import deepcopy
from itertools import product
from typing import Any
def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, Any]]:
) -> list[dict[str, list]]:
"""
Recursively generates all possible configurations by applying sweeps to the base config.
@@ -49,10 +48,7 @@ def generate_sweep_configs(
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {
**dict(zip(param_names, reg_combo, strict=False)),
**paired_set,
}
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
@@ -61,7 +57,7 @@ def generate_sweep_configs(
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo, strict=False):
for param_name, param_value in zip(param_names, reg_combo):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)

Some files were not shown because too many files have changed in this diff Show More