Compare commits

..

13 Commits

Author SHA1 Message Date
Dan Saunders
64f349b7bb diffusion alt: custom loss impl 2025-08-18 20:50:34 +00:00
Dan Saunders
260ebe4c93 diffusion alt: custom loss impl 2025-08-18 20:50:20 +00:00
Dan Saunders
63d2280999 nits 2025-08-18 19:17:24 +00:00
Dan Saunders
b210db2d15 fixes 2025-08-18 19:09:09 +00:00
Dan Saunders
556a69118f sample generation, tests fixes 2025-08-18 18:25:04 +00:00
Dan Saunders
8569675b26 Merge branch 'main' into diffusion 2025-08-18 10:07:55 -04:00
Dan Saunders
077b5a4358 cleanup; tests draft 2025-08-16 02:44:44 +00:00
Dan Saunders
234b7b3126 nits 2025-08-16 00:14:44 +00:00
Dan Saunders
e19be0c2d9 add back in reinit_weights (clobbered?); masking / pretrain fixes 2025-08-15 02:21:25 +00:00
Dan Saunders
479a454ae3 fixes + improvements 2025-08-14 16:11:37 -04:00
Dan Saunders
0a9341acde nits 2025-08-14 01:53:24 -04:00
Dan Saunders
d8b63804bc cleanup 2025-08-14 01:51:13 -04:00
Dan Saunders
3156c605d4 diffusion training plugin 2025-08-14 01:48:22 -04:00
782 changed files with 15576 additions and 88390 deletions

View File

@@ -1,3 +1,3 @@
[bandit]
exclude = tests
skips = B101,B615,B102,B110
skips = B101,B615

View File

@@ -12,6 +12,5 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: false
chat:
auto_reply: true

5
.flake8 Normal file
View File

@@ -0,0 +1,5 @@
[flake8]
max-line-length = 88
select = C,E,F,W,B,B950
extend-ignore = E203, E501, W503

View File

@@ -68,12 +68,7 @@ You can skip certain CI checks by including specific keywords in your commit mes
### Code Style
axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.
Use the pre-commit linter to ensure that your code is formatted consistently.
```bash
pre-commit run --all-files
```
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
### Commit Messages
@@ -83,6 +78,6 @@ Write clear and concise commit messages that briefly describe the changes made i
- [GitHub Help](https://help.github.com/)
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
- [Ruff](https://docs.astral.sh/ruff/)
- [{codestyle}]({URLofCodestyle})
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!

6
.github/FUNDING.yml vendored
View File

@@ -1,13 +1,13 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
ko_fi: axolotl_ai # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
custom: ['https://quickchart.io/qr?text=bitcoin%3Abc1qxlgwlqwfea5s2cxm42xqsfmwjct0rj8w8ea5np&size=480&centerImageUrl=https%3A%2F%2Fupload.wikimedia.org%2Fwikipedia%2Fcommons%2Fthumb%2F4%2F46%2FBitcoin.svg%2F64px-Bitcoin.svg.png'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

View File

@@ -15,11 +15,6 @@
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## AI Usage Disclaimer
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
## Screenshots (if appropriate)
## Types of changes

View File

@@ -15,85 +15,58 @@ on:
- '.github/workflows/base.yml'
workflow_dispatch:
permissions:
contents: read
jobs:
build-base:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -117,21 +90,20 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-base
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v3
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
@@ -146,84 +118,38 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.7.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.0
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-uv-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.11"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -234,19 +160,17 @@ jobs:
images: |
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v3
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -12,9 +12,6 @@ jobs:
build-deploy:
runs-on: ubuntu-latest
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Quarto

View File

@@ -13,9 +13,6 @@ on:
- ".pre-commit-config.yaml"
workflow_dispatch:
permissions:
contents: read
jobs:
pre-commit:
name: pre-commit

View File

@@ -8,9 +8,6 @@ on:
- "v*"
workflow_dispatch:
permissions:
contents: read
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -18,43 +15,27 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.0
pytorch: 2.6.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.7.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.7.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -64,6 +45,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=ref,event=branch
@@ -80,7 +62,6 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
@@ -95,128 +76,40 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-uv:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
- name: Build and export to Docker
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.0
pytorch: 2.6.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.7.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.7.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -226,6 +119,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=ref,event=branch
@@ -241,7 +135,6 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
@@ -252,100 +145,30 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-uv:
needs: build-axolotl-uv
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-cloud-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 130
cuda_version: 13.0.0
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.7.1
axolotl_extras:
is_latest:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -355,6 +178,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud-term
axolotlai/axolotl-cloud-term
tags: |
type=ref,event=branch
@@ -370,7 +194,6 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: linux/amd64,linux/arm64
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}

View File

@@ -8,7 +8,6 @@ on:
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'scripts/cutcrossentropy_install.py'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
@@ -20,12 +19,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions:
contents: read
env:
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
jobs:
test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
@@ -33,27 +26,27 @@ jobs:
fail-fast: false
matrix:
include:
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.6.0
axolotl_extras:
# axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 128
cuda_version: 12.8.1
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.10.0
axolotl_extras: "fbgemm-gpu"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -66,7 +59,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.3.0.post1 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -75,9 +68,8 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run -m cicd.multigpu
modal run cicd.multigpu

View File

@@ -5,9 +5,6 @@ on:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
permissions:
contents: read
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -15,10 +12,15 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -29,6 +31,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
@@ -62,10 +65,15 @@ jobs:
strategy:
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.6.0
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -76,6 +84,7 @@ jobs:
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}

View File

@@ -2,11 +2,9 @@ name: Pre-commit auto-update
on:
schedule:
- cron: '0 0 1 * *' # Run monthly
- cron: '0 0 * * 0' # Run weekly
workflow_dispatch: # Manual kickoff
permissions: {}
jobs:
auto-update:
runs-on: ubuntu-latest

View File

@@ -11,21 +11,22 @@ on:
- '_quarto.yml'
- docs/scripts/generate_config_docs.py
- src/axolotl/utils/schemas/**.py
- .github/workflows/preview-docs.yml
permissions:
contents: read
checks: write
contents: write
deployments: write
issues: write
discussions: write
pages: write
pull-requests: write
statuses: write
jobs:
preview:
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository
uses: actions/checkout@v4
with:

View File

@@ -3,11 +3,9 @@ name: publish pypi
on:
push:
tags:
- "v*"
- 'v*'
workflow_dispatch:
permissions: {}
jobs:
setup_release:
name: Create Release
@@ -30,8 +28,7 @@ jobs:
name: pypi
url: https://pypi.org/p/axolotl
permissions:
contents: read
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v4
@@ -43,17 +40,17 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging==26.0
pip3 install wheel packaging==23.2
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Extract tag name
id: tag
run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> "$GITHUB_OUTPUT"
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in VERSION file
- name: Update version in setup.py
run: |
echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a source dist
run: |

View File

@@ -3,13 +3,6 @@ on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '.github/workflows/tests-nightly.yml'
permissions:
contents: read
jobs:
pre-commit:
@@ -25,26 +18,15 @@ jobs:
env:
SKIP: no-commit-to-branch
prime-cdn-s3-cache:
name: Prefetch S3 once to prime the CDN cache
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
timeout-minutes: 10
steps:
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [prime-cdn-s3-cache]
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.9.1", "2.10.0"]
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -55,7 +37,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -66,7 +48,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -117,26 +99,19 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.10.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
nightly_build: "true"
steps:
- name: Checkout
@@ -148,7 +123,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.3.0.post1 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -157,11 +132,9 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
docker-e2e-multigpu-tests:
@@ -175,10 +148,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.7.1
num_gpus: 2
axolotl_extras:
nightly_build: "true"
@@ -192,7 +165,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.3.0.post1 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -202,8 +175,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.multigpu

View File

@@ -28,9 +28,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions:
contents: read
env:
TRANSFORMERS_IS_CI: "yes"
@@ -49,46 +46,27 @@ jobs:
env:
SKIP: no-commit-to-branch
prime-cdn-s3-cache:
name: Prefetch S3 once to prime the CDN cache
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
timeout-minutes: 10
steps:
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
pytest:
name: PyTest
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
needs: [prime-cdn-s3-cache]
# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.12", "3.14"]
pytorch_version: ["2.9.1", "2.10.0"]
exclude:
- python_version: "3.14"
pytorch_version: "2.9.1"
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -99,24 +77,20 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-cache-dir --no-build-isolation -U -e .
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
@@ -127,25 +101,15 @@ jobs:
- name: Pre-Download dataset fixture
run: |
hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Show HF cache
run: hf cache ls
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
df -h
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
df -h
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
df -h
pytest -v --durations=10 tests/patched/ --cov=axolotl --cov-append --cov-report=xml
df -h
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Show HF cache
run: hf cache ls
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
@@ -154,35 +118,30 @@ jobs:
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
needs: [prime-cdn-s3-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.12", "3.14"]
pytorch_version: ["2.9.1", "2.10.0"]
exclude:
- python_version: "3.14"
pytorch_version: "2.9.1"
timeout-minutes: 30
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -193,25 +152,21 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
- name: Install PyTorch
run: |
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
pip3 install torch==${{ matrix.pytorch_version }} torchvision
- name: Install dependencies
run: |
pip3 show torch
python -m build --no-isolation --sdist
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
pip3 install --no-build-isolation dist/axolotl*.tar.gz
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
@@ -221,19 +176,20 @@ jobs:
axolotl --help
- name: Show HF cache
run: hf cache ls
run: huggingface-cli scan-cache
- name: Run tests
run: |
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/
- name: Show HF cache
run: hf cache ls
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
gate-skip-e2e:
needs: [pre-commit]
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
@@ -269,16 +225,22 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest]
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
strategy:
fail-fast: false
matrix:
include:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.9.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -292,7 +254,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.3.0.post1 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -302,10 +264,9 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
@@ -325,22 +286,16 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
steps:
@@ -353,7 +308,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.3.0.post1 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -363,11 +318,9 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
@@ -381,10 +334,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
steps:
@@ -397,7 +350,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==1.3.0.post1 jinja2
pip install modal==1.0.2 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -407,6 +360,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.cleanup

3
.gitignore vendored
View File

@@ -190,6 +190,3 @@ out/
# vim
*.swp
# scm auto-versioning
src/axolotl/_version.py

4
.isort.cfg Normal file
View File

@@ -0,0 +1,4 @@
[settings]
profile=black
known_third_party=wandb,comet_ml
known_local_folder=src,tests

View File

@@ -10,14 +10,24 @@ repos:
- id: trailing-whitespace
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.8
- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 7.3.0
hooks:
- id: flake8
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.8
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.19.1
rev: v1.17.1
hooks:
- id: mypy
additional_dependencies:
@@ -26,7 +36,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.9.4
rev: 1.8.6
hooks:
- id: bandit
args: [

15
.pylintrc Normal file
View File

@@ -0,0 +1,15 @@
[MASTER]
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=numpy.*, torch.*
[pylint.messages_control]
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -10,7 +10,6 @@ ARG BASE_VOLUME="/runpod-volume"
ENV BASE_VOLUME=$BASE_VOLUME
ENV HF_DATASETS_CACHE="${BASE_VOLUME}/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV HF_HUB_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
COPY .runpod/src /src

View File

@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_num_proc` | `4` | Number of preprocessing processes |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |

View File

@@ -39,6 +39,7 @@
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
@@ -106,7 +107,7 @@
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_num_proc: # defaults to os.cpu_count() if not set
# dataset_processes: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
@@ -223,6 +224,9 @@
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
@@ -348,6 +352,8 @@
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
@@ -406,7 +412,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_num_proc: ${DATASET_NUM_PROC}
dataset_processes: ${DATASET_PROCESSES}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
@@ -506,6 +512,7 @@ profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}

View File

@@ -1,94 +0,0 @@
# Axolotl
Fine-tuning framework for LLMs. Config-driven: every training run is defined by a single YAML file.
## Tech Stack
Python, PyTorch, HuggingFace Transformers, TRL, PEFT (LoRA/QLoRA), DeepSpeed, FSDP, vLLM (for GRPO generation).
## Commands
```bash
axolotl train config.yaml # Train (single or multi-GPU, auto-detected)
axolotl preprocess config.yaml # Tokenize dataset and validate config
axolotl preprocess config.yaml --debug # Inspect tokenized samples and label masking
axolotl inference config.yaml # Interactive inference
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
axolotl fetch examples # Download example configs
```
## Training Methods
| Method | Config Key | When to Use |
|--------|-----------|-------------|
| SFT | *(default)* | Input-output pairs, instruction tuning |
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
| KTO | `rl: kto` | Unpaired binary preference labels |
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
| EBFT | `rl: ebft` | Feature-matching rewards from internal representations |
Agent-specific references:
- [docs/agents/sft.md](docs/agents/sft.md) — supervised fine-tuning
- [docs/agents/preference_tuning.md](docs/agents/preference_tuning.md) — DPO, IPO, KTO, ORPO, SimPO
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
## Config Pattern
All training is config-driven. A YAML file specifies model, adapter, dataset(s), and hyperparameters:
```yaml
base_model: meta-llama/Llama-3.1-8B-Instruct
adapter: lora # or qlora, or omit for full fine-tune
datasets:
- path: my_dataset
type: chat_template # prompt strategy (see docs/dataset-formats/)
output_dir: ./outputs/lora-out
```
Config schema: `src/axolotl/utils/schemas/config.py` (AxolotlInputConfig).
## Project Structure
```
src/axolotl/
cli/ # CLI entry points (train, preprocess, inference, merge_lora, vllm_serve)
core/
builders/ # TrainerBuilder classes (causal.py for SFT, rl.py for RLHF)
trainers/ # Trainer classes, mixins (optimizer, scheduler, packing)
dpo/ # DPO trainer and config
grpo/ # GRPO trainer and sampler
loaders/ # Model, tokenizer, adapter, processor loading
prompt_strategies/ # Dataset format handlers (chat_template, alpaca, dpo/, kto/, orpo/)
utils/schemas/ # Pydantic config schemas (config, model, training, peft, trl, fsdp)
integrations/ # Plugins (liger, cut_cross_entropy, swanlab, nemo_gym)
monkeypatch/ # Runtime patches for HF transformers
examples/ # Example YAML configs by model (llama-3/, qwen2/, mistral/, ebft/)
deepspeed_configs/ # DeepSpeed JSON configs (zero2, zero3)
docs/ # Quarto documentation site
```
## Code Conventions
- Config-driven: features are toggled via YAML, not code changes
- Prompt strategies: `src/axolotl/prompt_strategies/` — each `type:` value maps to a function
- Plugin system: `plugins:` list in config loads integration modules
- Trainer mixins: `core/trainers/mixins/` for composable trainer behaviors
- Schemas: all config validation via Pydantic in `utils/schemas/`
## Key Documentation
- [Getting Started](docs/getting-started.qmd) — quickstart tutorial
- [Choosing a Method](docs/choosing_method.qmd) — SFT vs DPO vs GRPO decision guide
- [Config Reference](docs/config-reference.qmd) — all config options
- [Dataset Formats](docs/dataset-formats/) — chat_template, alpaca, input_output, completion
- [RLHF](docs/rlhf.qmd) — DPO, KTO, ORPO, GRPO, EBFT configs and dataset formats
- [GRPO Deep Dive](docs/grpo.qmd) — async training, custom rewards, scaling
- [vLLM Serving](docs/vllm_serving.qmd) — vLLM setup for GRPO/EBFT
- [Multi-GPU](docs/multi-gpu.qmd) — FSDP and DeepSpeed
- [Training Stability](docs/training_stability.qmd) — debugging loss, NaN, OOM
- [Debugging](docs/debugging.qmd) — VSCode setup, Docker debugging

View File

@@ -1,6 +1,6 @@
cff-version: 1.2.0
type: software
title: "Axolotl: Open Source LLM Post-Training"
title: "Axolotl: Post-Training for AI Models"
message: "If you use this software, please cite it as below."
authors:
- name: "Axolotl maintainers and contributors"

View File

@@ -5,9 +5,6 @@
<img alt="Axolotl" src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/887513285d98132142bf5db2a74eb5e0928787f1/image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
<p align="center">
<strong>A Free and Open Source LLM Fine-tuning Framework</strong><br>
</p>
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
@@ -20,7 +17,6 @@
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<a href="https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google-colab" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
@@ -29,35 +25,21 @@
## 🎉 Latest Updates
- 2026/03:
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
- 2026/02:
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
- Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO).
- 2026/01:
- New integration for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), weights loss by entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), improves long context in attention.
- 2025/12:
- Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
- [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining.
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gpt-oss), [Gemma 3n](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma3n), [Liquid Foundation Model 2 (LFM2)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/lfm2), and [Arcee Foundation Models (AFM)](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/afm).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/voxtral), [Magistral 1.1](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral), and [Devstral](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/devstral) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
<details>
<summary>Expand older updates</summary>
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
- 2025/07:
- ND Parallelism support has been added into Axolotl. Compose Context Parallelism (CP), Tensor Parallelism (TP), and Fully Sharded Data Parallelism (FSDP) within a single node and across multiple nodes. Check out the [blog post](https://huggingface.co/blog/accelerate-nd-parallel) for more info.
- Axolotl adds more models: [GPT-OSS](https://docs.axolotl.ai/docs/models/gpt-oss.html), [Gemma 3n](https://docs.axolotl.ai/docs/models/gemma3n.html), [Liquid Foundation Model 2 (LFM2)](https://docs.axolotl.ai/docs/models/LiquidAI.html), and [Arcee Foundation Models (AFM)](https://docs.axolotl.ai/docs/models/arcee.html).
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
@@ -67,38 +49,33 @@
## ✨ Overview
Axolotl is a free and open-source tool designed to streamline post-training and fine-tuning for the latest large language models (LLMs).
Axolotl is a tool designed to streamline post-training for various AI models.
Features:
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
## 🚀 Quick Start - LLM Fine-tuning in Minutes
## 🚀 Quick Start
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.11
- PyTorch ≥2.9.1
### Google Colab
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
- PyTorch ≥2.6.0
### Installation
#### Using pip
```bash
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs
@@ -168,13 +145,6 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
## 📈 Telemetry
Axolotl has opt-out telemetry that helps us understand how the project is being used
and prioritize improvements. We collect basic system information, model types, and
error rates—never personal data or file paths. Telemetry is enabled by default. To
disable it, set AXOLOTL_DO_NOT_TRACK=1. For more details, see our [telemetry documentation](https://docs.axolotl.ai/docs/telemetry.html).
## ❤️ Sponsors
Interested in sponsoring? Contact us at [wing@axolotl.ai](mailto:wing@axolotl.ai)
@@ -185,7 +155,7 @@ If you use Axolotl in your research or projects, please cite it as follows:
```bibtex
@software{axolotl,
title = {Axolotl: Open Source LLM Post-Training},
title = {Axolotl: Post-Training for AI Models},
author = {{Axolotl maintainers and contributors}},
url = {https://github.com/axolotl-ai-cloud/axolotl},
license = {Apache-2.0},

View File

@@ -1 +0,0 @@
0.16.0.dev0

View File

@@ -1,8 +1,6 @@
project:
type: website
pre-render:
- docs/scripts/generate_config_docs.py
- docs/scripts/generate_examples_docs.py
pre-render: docs/scripts/generate_config_docs.py
quartodoc:
dir: docs/api
@@ -128,9 +126,11 @@ quartodoc:
- monkeypatch.mistral_attn_hijack_flash
- monkeypatch.multipack
- monkeypatch.relora
- monkeypatch.llama_expand_mask
- monkeypatch.lora_kernels
- monkeypatch.utils
- monkeypatch.btlm_attn_hijack_flash
- monkeypatch.llama_patch_multipack
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
@@ -153,7 +153,7 @@ quartodoc:
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.streaming
- utils.data.pretraining
- utils.data.sft
- utils.quantization
- title: Schemas
@@ -238,51 +238,9 @@ website:
- section: "Getting Started"
contents:
- docs/getting-started.qmd
- docs/choosing_method.qmd
- docs/installation.qmd
- docs/inference.qmd
- section: "Model Guides"
contents:
- docs/models/kimi-linear.qmd
- docs/models/plano.qmd
- docs/models/mimo.qmd
- docs/models/internvl3_5.qmd
- docs/models/olmo3.qmd
- docs/models/trinity.qmd
- docs/models/arcee.qmd
- section: "Ministral3"
contents:
- docs/models/ministral3.qmd
- docs/models/ministral3/think.qmd
- docs/models/ministral3/vision.qmd
- section: "Magistral"
contents:
- docs/models/magistral.qmd
- docs/models/magistral/think.qmd
- docs/models/magistral/vision.qmd
- docs/models/ministral.qmd
- docs/models/mistral-small.qmd
- docs/models/voxtral.qmd
- docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd
- docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd
- docs/models/qwen3.qmd
- docs/models/gemma3n.qmd
- docs/models/apertus.qmd
- docs/models/gpt-oss.qmd
- docs/models/seed-oss.qmd
- docs/models/phi.qmd
- docs/models/smolvlm2.qmd
- docs/models/granite4.qmd
- docs/models/LiquidAI.qmd
- docs/models/hunyuan.qmd
- docs/models/jamba.qmd
- docs/models/orpheus.qmd
- docs/cli.qmd
- docs/telemetry.qmd
- docs/config-reference.qmd
- text: "API Reference"
href: docs/api
@@ -303,26 +261,20 @@ website:
contents:
- docs/multimodal.qmd
- docs/rlhf.qmd
- docs/grpo.qmd
- docs/ebft.qmd
- docs/vllm_serving.qmd
- docs/reward_modelling.qmd
- docs/lr_groups.qmd
- docs/lora_optims.qmd
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd
- docs/optimizations.qmd
- section: "Core Concepts"
contents:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/streaming.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features"
contents:
@@ -333,12 +285,10 @@ website:
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/nd_parallelism.qmd
- docs/expert_quantization.qmd
- section: "Troubleshooting"
contents:
- docs/faq.qmd
- docs/training_stability.qmd
- docs/debugging.qmd
- docs/nccl.qmd

View File

@@ -1,208 +0,0 @@
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
"""
import gc
import statistics
import torch
import torch.nn.functional as F
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
"""Original chunked implementation (reference)."""
original_shape = logits.shape[:-1]
num_classes = logits.shape[-1]
flat_logits = logits.reshape(-1, num_classes)
entropies = []
for chunk in flat_logits.split(chunk_size, dim=0):
logps = F.log_softmax(chunk, dim=-1)
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
entropies.append(chunk_entropy)
return torch.cat(entropies, dim=0).reshape(original_shape)
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, logits, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
out = fn(logits, chunk_size=128)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
del out
return times
def profile_memory(fn, logits, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(logits, chunk_size=128)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_contiguous():
print("=" * 60)
print(
f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(1, 16384),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
t_orig = profile_time(entropy_from_logits_original, logits)
t_triton = profile_time(entropy_from_logits, logits)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(entropy_from_logits_original, logits)
m_triton = profile_memory(entropy_from_logits, logits)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits
_clean_gpu()
def benchmark_noncontiguous():
print("\n" + "=" * 60)
print(
f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(4, 2048, "transpose"),
(4, 8192, "transpose"),
(8, 2048, "transpose"),
(4, 4096, "slice_batch"),
]
for B, L, method in configs:
torch.manual_seed(42)
if method == "transpose":
raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw.transpose(0, 1)
raw_gb = L * B * V * 2 / 1e9
elif method == "slice_batch":
raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw[::2]
raw_gb = B * 2 * L * V * 2 / 1e9
else:
continue
if raw_gb > 28:
print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
del raw, logits_nc
torch.cuda.empty_cache()
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
print(f"{'' * 60}")
def original_with_copy(logits, chunk_size=128):
return entropy_from_logits_original(
logits.contiguous(), chunk_size=chunk_size
)
t_orig = profile_time(original_with_copy, logits_nc)
t_triton = profile_time(entropy_from_logits, logits_nc)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" orig+copy: {fmt(t_orig, 'ms')}")
print(f" triton-strided:{fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(original_with_copy, logits_nc)
m_triton = profile_memory(entropy_from_logits, logits_nc)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" orig+copy: {fmt(m_orig, 'MB')}")
print(f" triton-strided:{fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del raw, logits_nc
_clean_gpu()
if __name__ == "__main__":
benchmark_contiguous()
benchmark_noncontiguous()

View File

@@ -1,284 +0,0 @@
"""Benchmark for ScatterMoE LoRA Triton kernels.
Measures forward, backward dX, and backward dA/dB kernels at common MoE
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
and full fwd+bwd autograd throughput.
Usage:
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
"""
import argparse
import gc
import time
from functools import partial
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
WARMUP = 5
ITERS = 20
# ─── Model configs ──────────────────────────────────────────────────────────
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
}
def _resolve_config(spec):
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
key = spec.lower().replace("/", "-")
for name, cfg in BUILTIN_CONFIGS.items():
if key in name.lower() or name.lower() in key:
return name, cfg
from transformers import AutoConfig
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
if callable(getattr(hf_cfg, "get_text_config", None)):
tc = hf_cfg.get_text_config()
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
hf_cfg = tc
hidden = hf_cfg.hidden_size
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
experts = (
getattr(hf_cfg, "num_experts", None)
or getattr(hf_cfg, "num_local_experts", None)
or getattr(hf_cfg, "n_routed_experts", None)
)
top_k = (
getattr(hf_cfg, "num_experts_per_tok", None)
or getattr(hf_cfg, "num_experts_per_token", None)
or 2
)
name = spec.split("/")[-1]
return name, (experts, hidden, inter, top_k)
# ─── Benchmark helpers ──────────────────────────────────────────────────────
def _clean():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def _bench(fn, warmup=WARMUP, iters=ITERS):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
torch.cuda.synchronize()
t0 = time.perf_counter()
fn()
torch.cuda.synchronize()
times.append((time.perf_counter() - t0) * 1000)
times.sort()
return times[len(times) // 2]
def _setup(num_experts, K, N, T, top_k, R):
torch.manual_seed(42)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, num_experts, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
gx = base_ops.group(x, ssi, fan_out=top_k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
return lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
lora_A=lA,
lora_B=lB,
scaling=2.0,
)
def _call_base(x, W, sei, ssi, top_k):
return base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
)
def _call_dx(dy, W, sei, ssi, lA, lB):
return lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=2.0,
dy_grouped=True,
dx_grouped=False,
)
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
return lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=num_experts,
scaling=2.0,
)
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
parser.add_argument(
"--models",
"-m",
nargs="+",
help="Model names or HF IDs (default: all builtins)",
)
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
parser.add_argument("--seq-len", "-T", type=int, default=2048)
args = parser.parse_args()
T = args.seq_len
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"T={T}, ranks={args.ranks}\n")
if args.models:
configs = [_resolve_config(m) for m in args.models]
else:
configs = list(BUILTIN_CONFIGS.items())
for model_name, (num_experts, hidden, inter, top_k) in configs:
print(f"{'=' * 70}")
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
print(f"{'=' * 70}")
for R in args.ranks:
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
num_experts, K, N, T, top_k, R
)
# Forward with LoRA (auto-dispatched: fused or split)
dispatch = (
"split"
if (
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
)
else "fused"
)
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
total = t_fwd + t_dx + t_bwd
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(
f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead * 100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms"
)
# Full autograd fwd+bwd with memory measurement
x_ag = x.clone().requires_grad_(True)
lA_ag = lA.clone().requires_grad_(True)
lB_ag = lB.clone().requires_grad_(True)
def _run_autograd(
_x=x_ag,
_W=W,
_k=top_k,
_sei=sei,
_ssi=ssi,
_eo=eo,
_lA=lA_ag,
_lB=lB_ag,
):
out = ScatterMoELoRA.apply(
_x,
_W,
_k,
_sei,
_ssi,
_eo,
_lA,
_lB,
2.0,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
_x.grad = None
_lA.grad = None
_lB.grad = None
t_full = _bench(_run_autograd)
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
_run_autograd()
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(
f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
)
print()
if __name__ == "__main__":
main()

View File

@@ -1,191 +0,0 @@
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
"""
import gc
import statistics
import torch
from axolotl.monkeypatch.trainer.utils import (
selective_log_softmax,
selective_log_softmax_original,
)
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, args, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
fn(*args)
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
fn(*args)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
return times
def profile_memory(fn, args, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(*args)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(*args)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_forward():
print("=" * 60)
print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(selective_log_softmax_original, (logits, index))
t_triton = profile_time(selective_log_softmax, (logits, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(selective_log_softmax_original, (logits, index))
m_triton = profile_memory(selective_log_softmax, (logits, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits, index
_clean_gpu()
def benchmark_backward():
print("\n" + "=" * 60)
print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
def fwd_bwd_original(logits, index):
logits.grad = None
out = selective_log_softmax_original(logits, index)
out.sum().backward()
def fwd_bwd_triton(logits, index):
logits.grad = None
out = selective_log_softmax(logits, index)
out.sum().backward()
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 20:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits_orig = torch.randn(
B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
logits_tri = logits_orig.detach().clone().requires_grad_(True)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" FWD+BWD TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" FWD+BWD MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits_orig, logits_tri, index
_clean_gpu()
if __name__ == "__main__":
benchmark_forward()
benchmark_backward()

View File

@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
@@ -31,9 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==26.0 setuptools==78.1.1
RUN uv pip install torchvision
RUN uv pip uninstall causal_conv1d
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -1,6 +1,6 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
@@ -9,10 +9,10 @@ ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
ENV AXOLOTL_DATASET_PROCESSES="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
@@ -32,8 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
RUN pip uninstall -y causal_conv1d
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -3,24 +3,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
set -o pipefail
for i in 1 2 3; do
if curl --silent --show-error --fail -L \
https://axolotl-ci.b-cdn.net/hf-cache.tar.zst \
| tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1; then
echo "HF cache extracted successfully"
break
fi
echo "Attempt $i failed, cleaning up and retrying in 15s..."
rm -rf "${HF_HOME}/hub/"*
sleep 15
done
# hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning"
# hf download "microsoft/Phi-3.5-mini-instruct"
# hf download "microsoft/Phi-3-medium-128k-instruct"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \
--ignore=tests/e2e/ \

View File

@@ -2,6 +2,8 @@
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -17,8 +19,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
@@ -28,11 +29,8 @@ df_args = {
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
}
dockerfile_contents = df_template.render(**df_args)
@@ -65,7 +63,7 @@ def run_cmd(cmd: str, run_folder: str):
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code)
exit(exit_code) # pylint: disable=consider-using-sys-exit
@app.function(

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v --durations=10 -n2 --maxfail=3 \
pytest -v --durations=10 -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -1,5 +1,7 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -57,17 +59,15 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_TYPE = os.environ.get("GPU_TYPE", "L40S")
GPU_CONFIG = f"{GPU_TYPE}:{N_GPUS}"
GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
sp_env = os.environ.copy()
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
# Propagate errors from subprocess.
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")
if exit_code := subprocess.call(cmd.split(), cwd=run_folder, env=sp_env): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit

View File

@@ -12,7 +12,7 @@ coverage:
default:
# basic
target: auto
threshold: 1%
threshold: 0%
base: auto
# advanced
branches: null
@@ -27,7 +27,7 @@ coverage:
default:
# basic
target: auto
threshold: 1%
threshold: 0%
base: auto
# advanced
branches: null
@@ -37,7 +37,6 @@ coverage:
only_pulls: false
flags: null
paths: null
informational: true
parsers:
gcov:

View File

@@ -13,7 +13,7 @@ datasets:
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model
dataset_prepared_path: temp_debug/axolotl_outputs/data
dataset_num_proc: 1
dataset_processes: 1
sequence_len: 4096
sample_packing: false

View File

@@ -6,7 +6,6 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,18 +20,13 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN pip uninstall -y causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \ python scripts/unsloth_install.py | sh && \
python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge

View File

@@ -2,16 +2,14 @@ ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
ARG CUDA="128"
ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -24,17 +22,11 @@ RUN apt-get update \
librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm \
&& rm -rf /var/cache/apt/archives \
&& rm -rf /var/lib/apt/lists/* \
&& if [ "$TARGETARCH" = "amd64" ]; then \
MINICONDA_ARCH="x86_64"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
MINICONDA_ARCH="aarch64"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi \
&& wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh -b \
&& rm -f Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main \
&& conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
@@ -43,34 +35,18 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge
RUN if [ "$CUDA" != "130" ] ; then \
CAUSAL_CONV1D_FORCE_CXX11_ABI=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@v1.5.4"; \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
python3 -m pip cache purge; \
fi
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
fi

View File

@@ -30,7 +30,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \

View File

@@ -1,31 +0,0 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl-uv:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
printf "source /workspace/axolotl-venv/bin/activate\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

View File

@@ -1,48 +0,0 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base-uv:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN uv pip uninstall causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch && \
git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc

View File

@@ -2,11 +2,9 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
@@ -32,26 +30,7 @@ RUN uv venv --no-project --relocatable axolotl-venv
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install torch==${PYTORCH_VERSION} \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$TARGETARCH" = "amd64" ]; then \
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
LINUX_TAG="manylinux_" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
arm64) ARCH_TAG="2_34_aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"

2
docs/.gitignore vendored
View File

@@ -3,5 +3,3 @@ _site/
/api/*.qmd
/api/*.html
config-reference.qmd
models/**/*.qmd
models/**/*.html

View File

@@ -1,71 +0,0 @@
# GRPO — Agent Reference
Online RL with verifiable reward functions. For full config reference, async features, and scaling, see [grpo.qmd](../grpo.qmd). For vLLM setup, see [vllm_serving.qmd](../vllm_serving.qmd).
## Architecture
```
Terminal 1 (GPU 0) Terminal 2 (GPU 1)
┌──────────────────────┐ ┌──────────────────────────────────┐
│ vLLM Server │ HTTP │ Trainer │
│ Serves base model │◄────────────►│ 1. Send prompts to vLLM │
│ + LoRA adapter │ /generate │ 2. Score completions (rewards) │
│ │ /set_lora │ 3. Compute advantages │
│ Punica kernels for │ │ 4. PPO-clip gradient update │
│ LoRA inference │ │ 5. Sync LoRA weights to vLLM │
└──────────────────────┘ └──────────────────────────────────┘
```
## Components Required
1. A YAML config with `rl: grpo`
2. A reward module (Python file with reward functions)
3. A running vLLM server (`axolotl vllm-serve config.yaml`)
## Reward Function Signature
```python
def my_reward(completions, **kwargs) -> list[float]:
# completions[i][0]["content"] = text of i-th completion
# **kwargs contains dataset columns not removed by transform
return [score_for_each_completion]
```
Multiple rewards: `reward_funcs: [r1, r2]` with `reward_weights: [1.0, 0.5]`.
## Key Async Features
| Feature | Config | Purpose |
|---------|--------|---------|
| Async prefetch | `async_prefetch: true` | Overlap generation with training |
| LoRA sync | `vllm_lora_sync: true` | Fast adapter sync via filesystem |
| Streaming scoring | `streaming_partial_batch: true` | Score one group at a time |
| Zero-adv skip | `skip_zero_advantage_batches: true` | Skip batches with no learning signal |
| Replay buffer | `replay_buffer_size: 100` | Cache high-signal groups |
| IS correction | `vllm_importance_sampling_correction: true` | Fix off-policy distribution shift |
## Health Checks
- `rewards/*/mean` > 0.15 within 20 steps (else: test reward function standalone)
- `reward_std` > 0 on most steps (else: no learning signal)
- `entropy` 0.05-0.5 (< 0.01 = mode collapse)
- `grad_norm` 0.001-1.0 (> 10 = unstable, 0.0 = zero-advantage skip)
See [training_stability.qmd](../training_stability.qmd) for detailed diagnostics.
## File Map
```
src/axolotl/
cli/train.py # Entry point
cli/vllm_serve.py # Entry point for vLLM server
core/trainers/grpo/
trainer.py # AxolotlGRPOTrainer
sampler.py # Sampling utilities
core/builders/rl.py # HFRLTrainerBuilder — routes rl type → trainer
scripts/vllm_serve_lora.py # vLLM serve script with LoRA sync support
utils/schemas/trl.py # TRL config schema (all trl: options)
docs/grpo.qmd # Full user docs: async, rewards, scaling, config reference
docs/vllm_serving.qmd # vLLM server modes, LoRA sync, weight sync
```

View File

@@ -1,121 +0,0 @@
# Preference Learning (RLHF) — Agent Reference
Reference for DPO, IPO, KTO, ORPO, and SimPO. For config templates and dataset format examples, see [rlhf.qmd](../rlhf.qmd). For GRPO, see [grpo.qmd](../grpo.qmd). For EBFT, see [ebft.qmd](../ebft.qmd).
## Method Overview
| Method | Data Requirement | Key Idea | Best For |
|--------|-----------------|----------|----------|
| **DPO** | Paired (chosen + rejected) | Implicit reward via preference pairs | General alignment, most common |
| **IPO** | Paired (chosen + rejected) | DPO with different loss (avoids overfitting) | When DPO overfits |
| **KTO** | Unpaired (completion + binary label) | Kahneman-Tversky loss, no pairs needed | When you only have thumbs-up/down |
| **ORPO** | Paired (chosen + rejected) | Combined SFT + preference, no ref model | Single-stage alignment, saves VRAM |
| **SimPO** | Paired (chosen + rejected) | Length-normalized, no ref model | Simple setup, length-robust |
Default: start with DPO. All methods require `sample_packing: false`.
## Architecture
```
┌──────────────┐ ┌───────────────┐ ┌───────────────┐
│ Policy Model │ │ Reference │ │ Preference │
│ (trainable) │ │ Model (frozen)│ │ Dataset │
└──────┬───────┘ └──────┬────────┘ └──────┬────────┘
└──────────┬───────┘ │
v │
Forward pass on chosen + rejected <─────┘
Preference Loss (DPO/IPO/KTO/...)
Backprop + Update
Exception: ORPO and SimPO do NOT use a reference model (~50% less VRAM).
```
No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference data.
## Method Selection
1. Paired preference data (chosen + rejected)?
- Default → `rl: dpo`
- Overfitting → `rl: ipo`
- VRAM-limited → `rl: orpo` (no ref model)
- Length-sensitive → `rl: simpo` (no ref model)
2. Only binary labels (good/bad)? → `rl: kto`
3. Single-stage training (no separate SFT)? → `rl: orpo`
| | DPO | IPO | KTO | ORPO | SimPO |
|---|---|---|---|---|---|
| **Reference model** | Yes | Yes | Yes | No | No |
| **VRAM overhead** | ~2x model | ~2x model | ~2x model | ~1x model | ~1x model |
| **TRL trainer class** | DPOTrainer | DPOTrainer | KTOTrainer | ORPOTrainer | CPOTrainer |
## Prompt Strategy Resolution
The `type` field resolves to a Python function:
```
type: "chatml.intel"
→ axolotl.prompt_strategies.dpo.chatml.intel(cfg, **kwargs)
→ returns transform_fn(sample) → {"prompt", "chosen", "rejected"}
type: "chat_template.default"
→ axolotl.prompt_strategies.dpo.chat_template.default(cfg, dataset_idx, **kwargs)
type: {"field_prompt": "prompt", ...} (dict)
→ axolotl.prompt_strategies.dpo.user_defined.default(...)
```
Module base: `axolotl.prompt_strategies.{rl_method}` — replace `dpo` with `kto` or `orpo`.
## Healthy Training Indicators
| Metric | Healthy Range | Problem |
|--------|--------------|---------|
| `train/loss` | Decreasing, 0.3-0.7 | Flat or increasing = broken data or too high LR |
| `rewards/chosen` | Increasing | Flat = model not learning preferences |
| `rewards/rejected` | Decreasing | Increasing = model prefers wrong responses |
| `rewards/margins` | Positive and increasing | Negative = prefers rejected over chosen |
| `rewards/accuracies` | > 0.5, toward 0.7+ | < 0.5 = worse than random |
| `logps/rejected` | Decreasing | Increasing = reward hacking |
| `grad_norm` | 0.01 - 10.0 | > 100 = exploding gradients |
Method-specific: DPO/IPO watch `rewards/margins`; KTO loss is noisier; ORPO monitor SFT + odds ratio components; SimPO check length-normalized reward separation.
## Known Issues
| Issue | Fix |
|-------|-----|
| Sample packing crash | Set `sample_packing: false` (required for all preference methods) |
| KTO `KeyError: 'label'` | Ensure dataset has boolean `label` column |
| ORPO/KTO `KeyError` during tokenization | Add `remove_unused_columns: false` |
| ORPO template not applied | ORPO requires explicit `chat_template` setting |
| OOM with ref model (DPO/IPO/KTO) | Use LoRA/QLoRA, or switch to ORPO/SimPO (no ref model) |
| IPO + label_smoothing | Do not set `dpo_label_smoothing` when `rl: ipo` |
Full troubleshooting: [training_stability.qmd](../training_stability.qmd)
## File Map
```
src/axolotl/
core/trainers/dpo/ # DPO trainer, args, strategy
core/builders/rl.py # HFRLTrainerBuilder — routes rl type → trainer class
core/training_args.py # AxolotlKTOConfig, AxolotlORPOConfig, AxolotlCPOConfig
prompt_strategies/
dpo/ # DPO/IPO/SimPO dataset strategies
chat_template.py # chat_template.default, chat_template.argilla_chat
chatml.py # chatml.default/intel/icr/argilla_chat/prompt_pairs/ultra
llama3.py # llama3 variants (same subtypes as chatml)
user_defined.py # Custom field mapping
passthrough.py # No transform
kto/ # KTO dataset strategies (chatml, llama3, user_defined)
orpo/ # ORPO dataset strategies (chat_template.argilla)
utils/schemas/enums.py # RLType enum (dpo, ipo, kto, orpo, simpo, grpo, gdpo, ebft)
utils/schemas/config.py # All rl/dpo/kto/orpo/simpo config fields
docs/rlhf.qmd # Full user docs: all dataset formats, config templates
docs/choosing_method.qmd # SFT vs DPO vs GRPO decision guide
examples/qwen2/dpo.yaml # DPO example
examples/llama-3/qlora-1b-kto.yaml # KTO example
```

View File

@@ -1,75 +0,0 @@
# Pretraining / Continual Pretraining — Agent Reference
Train on raw text with no input masking. Two approaches depending on dataset size.
## When to Use
- Continual pretraining on domain-specific corpora
- Adapting a base model to a new language or domain before fine-tuning
- Pretraining-style data where the entire text is the training signal
## Choosing an Approach
| | Non-streaming (`type: completion`) | Streaming (`pretraining_dataset`) |
|---|---|---|
| **Dataset size** | Fits in memory | Too large to fit in memory |
| **Tokenization** | Pre-tokenized before training | On-demand during training |
| **Config key** | `datasets:` | `pretraining_dataset:` |
| **Long text handling** | Splits texts exceeding `sequence_len` | Concatenates into fixed-length sequences |
| **Benefit** | Can preprocess on CPU, transfer to GPU | Start training immediately, no preprocessing |
## Non-Streaming: `type: completion`
For smaller datasets that fit in memory. Pre-tokenizes the entire dataset.
```yaml
datasets:
- path: my_corpus
type: completion
# field: text # Column name (default: "text")
```
## Streaming: `pretraining_dataset`
For large corpora. Streams data on-demand without loading everything into memory.
```yaml
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
max_steps: 1000 # Required — axolotl can't infer dataset size
streaming_multipack_buffer_size: 10000 # Buffer for sample packing
pretrain_multipack_attn: true # Prevent cross-attention between packed samples
```
`max_steps` is required for streaming — one step = `sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus` tokens.
Full streaming docs: [streaming.qmd](../streaming.qmd)
## Dataset Format
```json
{"text": "The complete document text goes here."}
```
## Key Settings
- `sample_packing: true` + `pad_to_sequence_len: true` — pack documents into fixed-length sequences
- `flash_attention: true` — required for sample packing
- No adapter — typically full fine-tune for pretraining
- `train_on_inputs: true` — default for completion (all tokens trained on)
## File Map
```
src/axolotl/
prompt_strategies/completion.py # Non-streaming: completion prompt strategy (no masking)
utils/data/sft.py # Non-streaming: dataset loading and processing
utils/data/streaming.py # Streaming: encode_streaming(), wrap_streaming_dataset()
utils/schemas/config.py # Config fields: pretraining_dataset, pretrain_multipack_attn, etc.
examples/streaming/pretrain.yaml # Full streaming pretraining example config
```

View File

@@ -1,48 +0,0 @@
# Reward Modelling — Agent Reference
Train models to score responses for use as reward signals in RL. For full docs, see [reward_modelling.qmd](../reward_modelling.qmd).
## Types
### Outcome Reward Models (ORM)
Train a classifier to predict preference over entire interactions. Uses `AutoModelForSequenceClassification`.
```yaml
base_model: google/gemma-2-2b
model_type: AutoModelForSequenceClassification
num_labels: 1
reward_model: true
chat_template: gemma
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
```
Dataset format: `{"system": "...", "input": "...", "chosen": "...", "rejected": "..."}`
### Process Reward Models (PRM)
Train a token classifier to score each reasoning step. Uses `AutoModelForTokenClassification`.
```yaml
base_model: Qwen/Qwen2.5-3B
model_type: AutoModelForTokenClassification
num_labels: 2
process_reward_model: true
datasets:
- path: trl-lib/math_shepherd
type: stepwise_supervised
```
Dataset format: see [stepwise_supervised.qmd](../dataset-formats/stepwise_supervised.qmd).
## File Map
```
src/axolotl/
core/builders/causal.py # Handles reward_model flag in trainer builder
prompt_strategies/bradley_terry/ # Bradley-Terry prompt strategies
prompt_strategies/stepwise_supervised.py # PRM dataset strategy
utils/schemas/config.py # reward_model, process_reward_model config fields
```

View File

@@ -1,115 +0,0 @@
# SFT — Agent Reference
Supervised fine-tuning pipeline reference. For config templates and dataset format examples, see [getting-started.qmd](../getting-started.qmd) and [dataset-formats/](../dataset-formats/).
## Architecture
```
YAML Config → axolotl train config.yaml
1. Load base model (+ quantization if QLoRA/8-bit)
2. Apply adapter layers (LoRA/QLoRA) if configured
3. Load + tokenize dataset(s)
- Apply prompt template (chat_template / alpaca / custom)
- Mask inputs (train_on_inputs: false)
- Pack samples into sequences (sample_packing: true)
4. Training loop (HuggingFace Trainer)
- forward → loss → backward → optimizer step → lr scheduler step
5. Save model / adapter weights + tokenizer
Multi-GPU: FSDP or DeepSpeed shards model across GPUs automatically.
```
## Components Required
1. A YAML config — model, dataset(s), adapter settings, hyperparameters
2. A dataset — HuggingFace Hub, local JSONL/JSON/Parquet, or S3/GCS path
3. (Optional) A custom prompt strategy — for non-standard dataset formats
No external server processes needed (unlike GRPO which requires vLLM).
## Dataset Format Decision Tree
```
Is your data in chat/message format?
├─ YES: OpenAI message format (role/content)?
│ ├─ YES ──────────────────────> type: chat_template (recommended)
│ └─ NO (custom field names) ──> type: chat_template + message_property_mappings
└─ NO: Instruction/response pairs?
├─ YES ──> type: alpaca (instruction, input, output)
└─ NO: Raw text?
├─ YES with segments ─────> type: input_output (template-free masking)
└─ YES continuous ────────> type: completion (pretraining-style)
```
Full format specs: [dataset-formats/](../dataset-formats/)
## Model Size to Adapter Choice
| Model Size | LoRA | QLoRA (4-bit) | Full Fine-Tune | VRAM (approx) |
|-----------|------|---------------|----------------|---------------|
| 1-3B | Preferred | Low-budget option | Single GPU OK | 8-16 GB (LoRA) |
| 7-8B | Preferred | Good balance | Needs multi-GPU | 16-24 GB (LoRA) |
| 13-14B | Preferred | Good balance | Multi-GPU required | 24-40 GB (LoRA) |
| 30-70B | LoRA or QLoRA | Preferred for single GPU | Multi-node | 40-80 GB (QLoRA) |
## Hyperparameter Ranges
| Parameter | LoRA | QLoRA | Full FT |
|-----------|------|-------|---------|
| `learning_rate` | 1e-4 to 3e-4 | 1e-4 to 3e-4 | 1e-5 to 5e-5 |
| `lora_r` | 16-64 | 16-64 | N/A |
| `lora_alpha` | 1-2x `lora_r` | 1-2x `lora_r` | N/A |
| `micro_batch_size` | 2-8 | 2-4 | 1-2 |
| `gradient_accumulation_steps` | 2-8 | 4-16 | 4-16 |
| `num_epochs` | 1-3 | 1-3 | 1-3 |
| `optimizer` | `adamw_8bit` | `adamw_bnb_8bit` | `adamw_torch_fused` |
Effective batch = micro_batch * grad_accum * num_gpus. Lower LR for larger models.
## Healthy Training Indicators
| Metric | Healthy | Problem |
|--------|---------|---------|
| `train_loss` | Decreasing, starting ~2-4 for chat models | Flat or increasing from step 1 — data or LR issue |
| `eval_loss` | Decreasing, tracks train_loss | Increasing while train_loss decreases — overfitting |
| `grad_norm` | 0.1-10, relatively stable | Spikes >100 — instability. 0.0 — frozen weights |
| `learning_rate` | Follows scheduler curve | Flat or NaN — config issue |
Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss goes to 0 quickly (overfitting), eval_loss diverging (reduce epochs, add regularization). See [training_stability.qmd](../training_stability.qmd).
## Known Issues
| Issue | Fix |
|-------|-----|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
| Missing chat template error | Set `chat_template: chatml` explicitly |
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
| Tokenizer pad token / infinite loss | Set `special_tokens: pad_token: "<\|end_of_text\|>"` |
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
## File Map
```
src/axolotl/
cli/train.py # Entry point for `axolotl train`
cli/preprocess.py # Entry point for `axolotl preprocess`
core/builders/causal.py # HFCausalTrainerBuilder — wires config → SFT trainer
core/trainers/base.py # AxolotlTrainer — base trainer class
core/trainers/mixins/ # Packing, optimizer, scheduler, checkpoints
prompt_strategies/ # Format handlers: chat_template, alpaca, completion, input_output
utils/schemas/config.py # AxolotlInputConfig — main config schema
utils/schemas/datasets.py # SFTDataset, DatasetConfig
utils/schemas/peft.py # LoraConfig — LoRA parameters
integrations/liger/ # Liger kernel plugin
examples/llama-3/ # LoRA, QLoRA, full FT example configs
docs/getting-started.qmd # Quickstart with config templates
docs/optimizations.qmd # Flash attention, gradient checkpointing, sample packing
docs/multi-gpu.qmd # FSDP and DeepSpeed setup
```

View File

@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
Download a base model using the Hugging Face CLI:
```bash
hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
```
### 10. Create Axolotl Configuration

View File

@@ -1,178 +0,0 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
based on your installed packages and GPU.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Flash Attention 2
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
Alternatively, try reinstall or downgrade a version.
:::
### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### Flash Attention 4
Requirements: Hopper or Blackwell GPUs
```bash
pip install flash-attn-4
```
Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```
::: {.callout-note}
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
for training on Hopper, install from source using the instructions above.
:::
::: {.callout-warning}
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend using with Turing GPUs or below (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::

View File

@@ -1,86 +0,0 @@
---
title: "Checkpoint Saving"
format:
html:
toc: true
toc-depth: 2
number-sections: true
execute:
enabled: false
---
## Overview
Axolotl supports on-demand checkpoint saving during training. You can trigger checkpoints via file-based triggers (for programmatic control) or Control+C (for interactive use).
## File-Based Checkpoint Trigger
### Configuration
Enable in your config:
```yaml
dynamic_checkpoint:
enabled: true
check_interval: 100 # Optional: check every N steps (default: 100)
trigger_file_path: "axolotl_checkpoint.save" # Optional: custom filename
```
**Options:**
- `enabled`: `true` to enable (required)
- `check_interval`: Steps between file checks. Default: 100. Lower = faster response, higher I/O overhead.
- `trigger_file_path`: Custom trigger filename. Default: `axolotl_checkpoint.save`
### How It Works
1. Rank 0 checks for trigger file every `check_interval` steps in `output_dir`
2. When detected, file is deleted and checkpoint is saved
3. In distributed training, rank 0 broadcasts to synchronize all ranks
### Usage
**Command line:**
```bash
touch /path/to/output_dir/axolotl_checkpoint.save
```
**Programmatic:**
```python
from pathlib import Path
Path("/path/to/output_dir/axolotl_checkpoint.save").touch()
```
Checkpoint saves within the next `check_interval` steps. The trigger file is auto-deleted after detection, so you can create it multiple times.
**Custom filename:**
```yaml
dynamic_checkpoint:
enabled: true
trigger_file_path: "my_trigger.save"
```
```bash
touch /path/to/output_dir/my_trigger.save
```
## Control+C (SIGINT) Checkpoint
Pressing `Ctrl+C` during training saves the model state and exits gracefully. **Note:** This saves only the model weights, not optimizer state. For resumable checkpoints, use the file-based trigger.
## Best Practices
- **Check interval**: Lower values (10-50) for fast training, default 100 for slower training
- **Distributed training**: Create trigger file once; rank 0 handles synchronization
- **Resume**: Dynamic checkpoints can be resumed like regular checkpoints via `resume_from_checkpoint`
## Example
```yaml
output_dir: ./outputs/lora-out
save_steps: 500 # Scheduled checkpoints
dynamic_checkpoint:
enabled: true
check_interval: 50
```
This enables scheduled checkpoints every 500 steps plus on-demand saves via file trigger (checked every 50 steps).

View File

@@ -1,206 +0,0 @@
---
title: "Which Fine-Tuning Method Should I Use?"
description: "A decision guide for choosing the right fine-tuning method, adapter, and hardware configuration in Axolotl."
format:
html:
toc: true
toc-depth: 3
number-sections: true
execute:
enabled: false
---
## Overview {#sec-overview}
Axolotl supports four broad categories of fine-tuning, each suited to different data types, objectives, and resource constraints.
| Method | What It Does | Data You Need |
|--------|-------------|---------------|
| **Supervised Fine-Tuning (SFT)** | Teaches the model to produce specific outputs given inputs | Input-output pairs (instructions, conversations, completions) |
| **Preference Learning (DPO/KTO/ORPO)** | Steers the model toward preferred outputs and away from dispreferred ones | Chosen/rejected response pairs (DPO, ORPO) or binary labels (KTO) |
| **Reinforcement Learning (GRPO)** | Optimizes the model against a reward signal through online generation | A reward function (code or model-based) and a prompt dataset |
| **Reward Modeling** | Trains a model to score responses, for use as a reward signal in RL | Preference pairs ranked by quality |
Each method is configured through a YAML file with `rl: <method>` (or omitted for SFT). All methods support LoRA, QLoRA, and full fine-tuning unless otherwise noted.
## Decision Tree {#sec-decision-tree}
Use the following flowchart to choose your method. Start at the top and follow the path that matches your situation.
```
Do you have a reward function (code-based or model-based)?
├── YES
│ └── Use GRPO (rl: grpo)
│ The model generates its own completions and learns from reward scores.
│ Best for: math, code, reasoning, tasks with verifiable answers.
│ See: rlhf.qmd#grpo
└── NO
Do you have preference pairs (chosen vs. rejected responses)?
├── YES
│ │
│ Are they paired (same prompt, one chosen, one rejected)?
│ ├── YES → Use DPO (rl: dpo)
│ │ Direct optimization without a separate reward model.
│ │ See: rlhf.qmd#dpo
│ │
│ └── NO (only binary good/bad labels)
│ └── Use KTO (rl: kto)
│ Works with unpaired preference data.
│ See: rlhf.qmd#kto
└── NO
Do you have input-output examples?
├── YES → Use SFT
│ The simplest and most common method.
│ See: getting-started.qmd
└── NO
└── You need to create training data first.
Consider generating preference pairs with an LLM judge,
or writing a reward function for GRPO.
```
::: {.callout-tip}
**When in doubt, start with SFT.** It is the most straightforward method and works well for most tasks. You can always move to preference learning or RL later to further refine behavior.
:::
### Method Comparison at a Glance
| Criterion | SFT | DPO | KTO | GRPO |
|-----------|-----|-----|-----|------|
| Data complexity | Low (input-output pairs) | Medium (preference pairs) | Medium (binary labels) | Low (prompts + reward code) |
| Compute cost | Low | Medium | Medium | High (requires vLLM server) |
| Learning signal | Supervised | Contrastive | Contrastive | Online reward |
| Online generation | No | No | No | Yes |
| Reward model needed | No | No | No | No (uses reward functions) |
| Best for | Task adaptation, instruction following | Safety, style alignment | Unpaired preference data | Reasoning, math, code |
::: {.callout-note}
**ORPO** is an alternative to DPO that combines SFT and preference optimization in a single training stage, removing the need for a separate SFT step. Configure with `rl: orpo`. See [rlhf.qmd](rlhf.qmd) for details.
:::
## Adapter Selection {#sec-adapter-selection}
Once you have chosen a method, decide how to apply the parameter updates. The three main options trade off VRAM usage against model quality.
### QLoRA
- **How it works**: The base model is loaded in 4-bit (NF4) quantization. Small low-rank adapter matrices are trained in higher precision on top.
- **VRAM savings**: Roughly 4x reduction in model memory compared to full fine-tuning.
- **Quality**: Slight degradation due to quantization noise, but often negligible for task-specific fine-tuning.
- **When to use**: When your GPU cannot fit the model in full precision, or when you want fast experimentation.
```yaml
adapter: qlora
load_in_4bit: true
lora_r: 32
lora_alpha: 64
lora_target_linear: true
```
### LoRA
- **How it works**: The base model is loaded at full precision (or 8-bit). Low-rank adapter matrices are trained alongside.
- **VRAM savings**: Roughly 2-3x reduction compared to full fine-tuning (model weights are frozen, only adapters + optimizer states for adapters are stored).
- **Quality**: Very close to full fine-tuning for most tasks, especially with higher rank values.
- **When to use**: When you have enough VRAM for the base model but not for full optimizer states.
```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
```
::: {.callout-tip}
For GRPO training, LoRA is strongly recommended. The vLLM server needs to sync weights from the trainer, and LoRA sync (`trl.vllm_lora_sync: true`) is far more efficient than syncing full merged weights. See [vLLM Serving](vllm_serving.qmd) for details.
:::
### Full Fine-Tuning
- **How it works**: All model parameters are updated during training. No adapters.
- **VRAM savings**: None. Requires memory for model weights, gradients, and optimizer states (roughly 4x model size in bf16 with AdamW).
- **Quality**: Highest potential quality, especially for large distribution shifts.
- **When to use**: When you have ample GPU memory or multi-GPU setups, and need maximum performance. Also required for pre-training.
```yaml
# No adapter or load_in_* lines needed
micro_batch_size: 1
gradient_accumulation_steps: 16
```
### Quick Comparison
| | QLoRA | LoRA | Full |
|---|---|---|---|
| Trainable params | ~0.1-1% | ~0.1-1% | 100% |
| Model memory | ~25% of full | ~50-100% of full | 100% |
| Optimizer memory | Tiny (adapters only) | Tiny (adapters only) | 2x model size (AdamW) |
| Training speed | Slower (dequantization overhead) | Baseline | Faster per-step (no adapter overhead) |
| Inference | Merge or serve with adapter | Merge or serve with adapter | Direct |
| Multi-GPU required? | Rarely | For 13B+ models | For 7B+ models |
## Hardware Mapping {#sec-hardware-mapping}
The tables below provide approximate GPU memory requirements. Actual usage depends on context length, batch size, and optimizer choice.
### SFT / Preference Learning
| Model Size | QLoRA (4-bit) | LoRA (bf16) | Full (bf16 + AdamW) |
|------------|--------------|-------------|---------------------|
| 1-3B | 6-8 GB | 8-12 GB | 24-32 GB |
| 7-8B | 10-14 GB | 16-24 GB | 60-80 GB |
| 13-14B | 16-20 GB | 28-40 GB | 120+ GB |
| 30-34B | 24-32 GB | 64-80 GB | 2-4x 80 GB |
| 70-72B | 40-48 GB | 2x 80 GB | 4-8x 80 GB |
::: {.callout-important}
These estimates assume a short context length (512-2048 tokens) and micro_batch_size of 1-2. Longer sequences and larger batches increase memory significantly due to activations. Use [gradient checkpointing](gradient_checkpointing.qmd) to reduce activation memory at the cost of ~30% slower training.
:::
### GRPO (RL Training)
GRPO requires additional GPU(s) for the vLLM generation server. Plan for at least two GPUs: one for training, one for vLLM.
| Model Size | Training GPU (LoRA, bf16) | vLLM GPU | Total GPUs |
|------------|--------------------------|----------|------------|
| 0.5-3B | 1x 24 GB | 1x 24 GB | 2x 24 GB |
| 7-8B | 1x 80 GB | 1x 80 GB | 2x 80 GB |
| 13-14B | 1-2x 80 GB | 1-2x 80 GB | 2-4x 80 GB |
| 30-72B | 2-4x 80 GB (FSDP/DeepSpeed) | 2-4x 80 GB (tensor parallel) | 4-8x 80 GB |
::: {.callout-tip}
For single-GPU GRPO, use `vllm_mode: colocate` with `vllm_enable_sleep_mode: true`. The vLLM engine shares the GPU and offloads VRAM when not generating. This works for smaller models (up to ~3B on a 24 GB GPU) but is slower than the two-GPU server mode.
:::
### Multi-GPU Threshold
You need multi-GPU training when:
- **Full fine-tuning** of models 7B+ (use FSDP or DeepSpeed ZeRO)
- **LoRA** of models 30B+ (or 13B+ with long contexts)
- **GRPO** almost always (separate vLLM server), unless using colocate mode
See [Multi-GPU Training](multi-gpu.qmd) for FSDP and DeepSpeed configuration.
## Quick Links {#sec-quick-links}
| Method | Config Key | Documentation | Example Config |
|--------|-----------|---------------|----------------|
| SFT | *(default, no `rl:` key)* | [Getting Started](getting-started.qmd) | `examples/llama-3/lora-1b.yml` |
| DPO | `rl: dpo` | [RLHF - DPO](rlhf.qmd#dpo) | See rlhf.qmd |
| KTO | `rl: kto` | [RLHF - KTO](rlhf.qmd#kto) | See rlhf.qmd |
| ORPO | `rl: orpo` | [RLHF - ORPO](rlhf.qmd#orpo) | See rlhf.qmd |
| GRPO | `rl: grpo` | [RLHF - GRPO](rlhf.qmd#grpo), [vLLM Serving](vllm_serving.qmd) | See rlhf.qmd |
| Reward Modeling | `rl: reward_trainer` | [Reward Modelling](reward_modelling.qmd) | See reward_modelling.qmd |
### Related Guides
- [Configuration Reference](config-reference.qmd) -- Full list of all config options
- [Dataset Formats](dataset-formats) -- How to structure your training data
- [Optimizations](optimizations.qmd) -- Flash attention, gradient checkpointing, mixed precision
- [Multi-GPU Training](multi-gpu.qmd) -- FSDP and DeepSpeed setup
- [vLLM Serving](vllm_serving.qmd) -- Setting up vLLM for GRPO training

View File

@@ -210,8 +210,6 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -220,7 +218,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4

View File

@@ -212,21 +212,6 @@ Instead of passing `tools` via the system prompt, an alternative method would be
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
:::
::: {.callout-warning}
If you have tool arguments with same name but different dtypes (like `"time": string` and `"time": number`), please save `arguments: ` as JSON string to prevent `datasets` from having casting issues.
```
"arguments": "{\"...\": \"...\"}"
```
The same is applicable for tool parameters.
```
"parameters": "{\"...\": \"...\"}"
```
:::
Example config for Llama4:
```yaml
chat_template: llama4

View File

@@ -22,47 +22,90 @@ For `pretraining_dataset:` specifically, please refer to the [Pre-training secti
## Pre-training
Pre-training trains on raw text corpora with no input masking. The dataset format is simple:
When aiming to train on large corpora of text datasets, pre-training is your go-to choice. Due to the size of these datasets, downloading the entire-datasets before beginning training would be prohibitively time-consuming. Axolotl supports [streaming](https://huggingface.co/docs/datasets/en/stream) to only load batches into memory at a time.
A sample format for a pre-training dataset is as follows:
```json
{"text": "first row"}
{"text": "second row"}
...
```
Axolotl supports two approaches:
It is typically recommended to save your dataset as `.jsonl` due to its flexibility and simplicity.
### Streaming (large datasets)
Axolotl supports loading from a Hugging Face hub repo or from local files.
For large corpora that don't fit in memory, use `pretraining_dataset` with [streaming](../streaming.qmd). Data is tokenized on-demand during training.
### Pre-training from Hugging Face hub datasets
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
```yaml
pretraining_dataset: hf_org/name
```
### Pre-training from local dataset files
Given a few corpus files: `A.jsonl`, `B.jsonl`, and `C.jsonl`, your config will look like the below:
```yaml
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
- path: json
data_files:
- A.jsonl
- B.jsonl
- C.jsonl
```
::: {.callout-important}
Streaming requires `max_steps` in your config — Axolotl cannot infer the dataset size. One step = `sequence_len * micro_batch_size * gradient_accumulation_steps * num_gpus` tokens.
:::
While we recommend `.jsonl`, you can also use the other formats (`csv`, `parquet`, `arrow`, `SQL`, `Webdataset`) that are supported by [`Dataset.load_dataset`](https://huggingface.co/docs/datasets/loading#local-and-remote-files)
See [Streaming Datasets](../streaming.qmd) for full configuration details.
### Pre-training without streaming
### Non-streaming (smaller datasets)
On the rare case that the dataset is small and can be loaded entirely into memory, another approach to running pre-training is to use the `completion` format. This would mean that the entire dataset is pre-tokenized instead of on-demand in streaming.
For datasets that fit in memory, use `type: completion` under `datasets:`. The entire dataset is pre-tokenized before training, which can be done on a CPU-only machine.
One benefit of this is that the tokenization can be performed separately on a CPU-only machine, and then transferred to a GPU machine for training to save costs.
From Hugging Face:
```yaml
datasets:
- path: my_corpus
- path: hf_org/name
type: completion
```
::: {.callout-note}
With `completion`, texts exceeding `sequence_len` are split into multiple samples automatically.
From local files:
```yaml
datasets:
- path: A.jsonl
type: completion
- path: B.jsonl
type: completion
```
::: {.callout-important}
For `completion` only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts. If you are interested in having this for `pretraining_dataset` too, please let us know or help make a PR!
:::
### Pre-training dataset configuration tips
#### Setting max_steps
When using streaming for large datasets, Axolotl does not know in advance how large the dataset is and does not know when to stop.
Therefore, it is necessary to set `max_steps: int` in your config for pre-training to run, so that Axolotl knows when to stop training.
One step is equal to `sequence_len * micro_batch_size * gradient_accumulation_steps * total_num_gpus` tokens.
#### Group_by_length
It is recommended to leave this off if downloading from Hugging Face hub as it would download the entire dataset which can be very large.
### Reference
Please see docs [here](pretraining.qmd).
## Supervised fine-tuning (SFT)
Supervised fine-tuning is the process of training models to respond to an instruction or chat input.

View File

@@ -4,9 +4,29 @@ description: Data format for a pre-training completion task.
order: 1
---
::: {.callout-note}
Pre-training documentation has been consolidated:
For pretraining, there is no prompt template or roles. The only required field is `text`:
```{.json filename="data.jsonl"}
{"text": "first row"}
{"text": "second row"}
...
```
:::{.callout-note}
### Streaming is recommended for large datasets
Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```{.yaml filename="config.yaml"}
pretraining_dataset:
- name:
path:
split:
text_column: # column in dataset with the data, usually `text`
type: pretrain
trust_remote_code:
skip: # number of rows of data to skip over from the beginning
```
- **Streaming pretraining** (large datasets): See [Streaming Datasets](../streaming.qmd#pretraining-with-streaming)
- **Non-streaming pretraining** (`type: completion`): See [Dataset Formats](index.qmd#pre-training)
:::

View File

@@ -6,10 +6,6 @@ description: How to debug Axolotl
This document provides some tips and tricks for debugging Axolotl. It also provides an example configuration for debugging with VSCode. A good debugging setup is essential to understanding how Axolotl code works behind the scenes.
::: {.callout-tip}
For training-specific debugging (loss spikes, NaN gradients, OOM errors, RL training stability), see [Training Stability & Debugging](training_stability.qmd).
:::
## Table of Contents
- [General Tips](#general-tips)
@@ -33,7 +29,7 @@ While debugging it's helpful to simplify your test scenario as much as possible.
1. **Make sure you are using the latest version of axolotl**: This project changes often and bugs get fixed fast. Check your git branch and make sure you have pulled the latest changes from `main`.
1. **Eliminate concurrency**: Restrict the number of processes to 1 for both training and data preprocessing:
- Set `CUDA_VISIBLE_DEVICES` to a single GPU, ex: `export CUDA_VISIBLE_DEVICES=0`.
- Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
- Set `dataset_processes: 1` in your axolotl config or run the training command with `--dataset_processes=1`.
2. **Use a small dataset**: Construct or use a small dataset from HF Hub. When using a small dataset, you will often have to make sure `sample_packing: False` and `eval_sample_packing: False` to avoid errors. If you are in a pinch and don't have time to construct a small dataset but want to use from the HF Hub, you can shard the data (this will still tokenize the entire dataset, but will only use a fraction of the data for training. For example, to shard the dataset into 20 pieces, add the following to your axolotl config):
```yaml
@@ -89,7 +85,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 axolotl train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```json
// .vscode/launch.json
@@ -105,7 +101,7 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 axolotl
"-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_num_proc=1", // limits data preprocessing to one process
"--dataset_processes=1", // limits data preprocessing to one process
"--max_steps=1", // limits training to just one step
"--batch_size=1", // minimizes batch size
"--micro_batch_size=1", // minimizes batch size
@@ -246,6 +242,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
</div>
<br>
[^1]: The VSCode config uses `accelerate.commands.launch` as the Python module entry point, which is what `axolotl train` invokes under the hood.
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).

View File

@@ -32,8 +32,11 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
Tags examples:
- `main-base-py3.11-cu128-2.8.0`
- `main-base-py3.11-cu128-2.9.1`
- `main-base-py3.11-cu128-2.7.1`
- `main-base-py3.11-cu126-2.7.1`
- `main-base-py3.11-cu126-2.7.0`
- `main-base-py3.11-cu126-2.6.0`
- `main-base-py3.11-cu124-2.6.0`
## Main
@@ -71,12 +74,15 @@ There may be some extra tags appended to the image, like `-vllm` which installs
Tags examples:
- `main-py3.11-cu128-2.8.0`
- `main-py3.11-cu128-2.9.1`
- `main-py3.11-cu128-2.7.1`
- `main-py3.11-cu126-2.7.1`
- `main-py3.11-cu126-2.7.0`
- `main-py3.11-cu126-2.6.0`
- `main-py3.11-cu124-2.6.0`
- `main-latest`
- `main-20250303-py3.11-cu124-2.6.0`
- `main-20250303-py3.11-cu126-2.6.0`
- `0.12.0`
- `0.10.1`
## Cloud

View File

@@ -1,556 +0,0 @@
---
title: "EBFT Training"
description: "Energy-Based Fine-Tuning uses feature-matching rewards from internal representations to train language models without external reward functions."
order: 9
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---
## Overview
Energy-Based Fine-Tuning (EBFT) is a training method that optimizes language models by matching the **internal feature representations** of generated text to those of ground-truth completions. Instead of relying on external reward models or hand-crafted reward functions, EBFT extracts hidden states from intermediate layers of a frozen copy of the model and uses cosine similarity between generated and reference features as the reward signal.
Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)
### How EBFT Differs from Other RL Methods
| Method | Reward Signal | Requires | Best For |
|--------|--------------|----------|----------|
| **GRPO** | External reward function(s) | Custom reward code or reward model | Tasks with verifiable answers (math, code) |
| **DPO** | Preference pairs (chosen vs rejected) | Paired preference data | Alignment with human preferences |
| **EBFT** | Feature similarity to ground truth | Ground-truth completions | Any task with reference outputs |
EBFT's key advantage is that it needs only ground-truth completions -- no reward engineering, no preference annotation, and no reward model training. The model's own internal representations serve as the reward signal. This makes it particularly effective for:
- Code generation (match features of known-good solutions)
- Instruction following with reference outputs
- Continual pretraining on unstructured text (strided mode)
- Multi-turn dialogue with reference conversations
### Reward Formulation
The EBFT reward for each generated completion is:
```
reward = alignment_coef * cosine_similarity(gen_features, gt_features)
- diversity_coef * mean_pairwise_similarity(gen_features)
```
- **Alignment**: How closely the generated output's internal representations match the ground truth. Higher is better.
- **Diversity**: Penalizes generated samples that are too similar to each other (prevents mode collapse). Lower is better.
- **CFM loss** (Cross-Feature Matching): Tracks `||mean(gen_features) - gt_features||^2` as a diagnostic. This is the quantity that EBFT ultimately minimizes.
## Modes
EBFT supports three operational modes, each suited to different use cases.
### Structured Mode (Sync)
Uses vLLM on a separate GPU for generation, with sequential generate-score-train steps. This is the simplest mode and recommended for getting started.
```
GPU 0: vLLM Server (generates completions, receives weight syncs)
GPU 1: Trainer (feature extraction, reward computation, GRPO training)
```
**When to use**: Standard instruction-following or QA datasets where you have prompt/completion pairs. Requires 2 GPUs.
### Structured Mode (Async)
Same architecture as sync, but overlaps generation of the next batch with training on the current batch. Faster throughput at the cost of slightly stale weights during generation.
**When to use**: Same data as sync mode, but when you want faster training and can tolerate weight staleness (controlled by `vllm_sync_interval`).
### Strided Mode
Runs entirely on a single GPU with no vLLM dependency. Places anchor points throughout a document and generates short rollouts at each anchor using block-parallel attention patterns.
```
Single GPU: Base model + LoRA adapter
- Strided block-parallel generation (flex_attention)
- Feature extraction via disable_adapter()
- No vLLM needed
```
**When to use**: Unstructured text data (raw code, prose, documents) where there is no natural prompt/completion split. Also works with structured data that includes prompt boundaries. Requires only 1 GPU.
## Quick Start
### Structured Mode
This minimal example fine-tunes Qwen2-0.5B on code data using EBFT with vLLM generation.
**Step 1**: Create a config file `ebft_quickstart.yaml`:
```yaml
base_model: Qwen/Qwen2-0.5B-Instruct
rl: ebft
ebft:
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
alignment_coef: 1.0
diversity_coef: 1.0
trl:
num_generations: 4
max_completion_length: 256
temperature: 0.7
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_lora_sync: true
vllm_sync_interval: 3
use_data_producer: true
async_prefetch: false
scale_rewards: true
loss_type: grpo
vllm:
gpu_memory_utilization: 0.5
max_model_len: 1024
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:500]
# Standard training settings (see getting-started.qmd for details)
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
sequence_len: 1024
micro_batch_size: 2
gradient_accumulation_steps: 4
max_steps: 20
learning_rate: 5.0e-6
bf16: auto
flash_attention: true
gradient_checkpointing: true
output_dir: ./outputs/ebft-quickstart
```
**Step 2**: Start vLLM on GPU 0:
```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve ebft_quickstart.yaml
```
**Step 3**: Wait approximately 30 seconds for vLLM to initialize, then start training on GPU 1:
```bash
CUDA_VISIBLE_DEVICES=1 axolotl train ebft_quickstart.yaml
```
::: {.callout-important}
The `micro_batch_size` must be divisible by `num_generations`. For example, with `num_generations: 4`, valid values are 4, 8, 12, etc.
:::
### Dataset Format
Structured mode datasets must produce two fields after the transform:
- `prompt`: Either a string or a list of chat messages (`[{"role": "user", "content": "..."}]`)
- `ground_truth`: A string containing the reference completion
Example raw dataset row:
```json
{
"input": "Write a function to compute fibonacci numbers.",
"output": "def fibonacci(n):\n if n <= 1:\n return n\n return fibonacci(n-1) + fibonacci(n-2)"
}
```
The `ebft_opencode.transform` converts this to the required `{prompt, ground_truth}` format automatically.
## Feature Extraction
EBFT extracts hidden states from intermediate transformer layers and pools them into per-sequence embeddings. These embeddings are compared between generated and ground-truth completions to compute rewards.
### Feature Layers
The `feature_layers` parameter specifies which layers to extract, as fractions of total model depth:
```yaml
ebft:
feature_layers: [0.25, 0.5, 0.75] # Quarter, middle, three-quarter depth
```
For a 32-layer model, this extracts layers 8, 16, and 24. The hidden states from all selected layers are concatenated along the feature dimension, producing embeddings of size `num_layers * hidden_dim`.
::: {.callout-tip}
Using multiple layers captures both low-level syntactic features (early layers) and high-level semantic features (later layers). The default `[0.25, 0.5, 0.75]` works well across model sizes.
:::
### Embed Methods
The `embed_method` controls how per-token hidden states are pooled into a single vector per sequence:
| Method | Description | Output Shape | Notes |
|--------|-------------|-------------|-------|
| `last_token` | Hidden state at the last non-padding token | `(B, D)` | Default. Good for autoregressive models where the last token summarizes the sequence. |
| `mean_pooling` | Mean of all non-padding token states | `(B, D)` | Considers the entire sequence equally. |
| `completion_mean` | Mean over completion tokens only (excludes prompt) | `(B, D)` | Focuses reward signal on generated content. Requires prompt length information. |
| `concat` | Concatenation of states at 25%, 50%, 75% positions | `(B, 3*D)` | Captures positional structure. Higher dimensional. |
```yaml
ebft:
embed_method: completion_mean # Focus on completion features
```
### SVD Whitening
Whitening decorrelates the feature dimensions so that no single direction dominates the feature-matching loss. This is computed via SVD on the generated embeddings, with the same transform applied to the ground-truth embeddings.
```yaml
ebft:
use_whitening: true
```
When whitening is enabled, the reward computation applies a whitening matrix `W = U @ diag(1/S) @ U^T` derived from the SVD of generated embeddings. This ensures all feature dimensions contribute equally to the alignment reward.
::: {.callout-note}
Singular values scale with `sqrt(batch_size)`, so reward magnitudes are batch-size dependent. This is acceptable because the number of samples per prompt (`n_samples_per_prompt` or `num_generations`) is fixed during training.
:::
### Alignment and Diversity Coefficients
The two reward components are weighted by coefficients:
```yaml
ebft:
alignment_coef: 1.0 # Weight for cosine similarity with ground truth
diversity_coef: 1.0 # Weight for pairwise similarity penalty
```
Both values are scaled by 2 internally (per paper equation 7). The final reward per sample is:
```
reward_j = 2 * alignment_coef * cos(gen_j, gt)
- 2 * diversity_coef * (1/(n-1)) * sum_{j' != j} dot(gen_j, gen_j')
```
Setting `diversity_coef: 0.0` disables the diversity penalty entirely, which may be appropriate when `num_generations` is small (e.g., 2).
## Strided Mode
Strided mode is designed for training on unstructured text data where there is no natural prompt/completion boundary. Instead of generating full completions with vLLM, it places **anchor points** at regular intervals throughout each document and generates short rollouts at each anchor using block-parallel attention.
### How Block-Parallel Generation Works
Given a document of length `S` tokens:
1. **Anchor placement**: Starting at position `anchor_offset`, place anchors every `stride` tokens. Each anchor defines a block.
2. **Context window**: Each block sees `context_length` tokens of preceding context from the original document.
3. **Generation**: At each anchor, generate `generate_max_len` tokens autoregressively, conditioned only on the context window.
4. **Parallelism**: All blocks are processed in a single forward pass using a specialized attention mask that prevents information leakage between blocks.
```
Document: [tok0, tok1, ..., tok_S]
| | |
anchor_0 anchor_1 anchor_2
| | |
[ctx][gen] [ctx][gen] [ctx][gen]
```
The attention mask ensures:
- Prompt tokens use standard causal attention
- Each generated block attends to its own context window and its own preceding generated tokens
- Blocks do not attend to each other's generated tokens
When `flex_attention` is available (PyTorch >= 2.5), the mask is compiled into efficient fused kernels. Otherwise, a dense 4D attention mask is used as a fallback.
### Strided Mode Configuration
```yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft
ebft:
mode: strided
stride: 8 # Tokens between anchor points
context_length: 8 # Context window per block
generate_max_len: 8 # Tokens to generate per block
n_samples_per_prompt: 4 # Independent rollouts per document
temperature: 0.6
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: true
alignment_coef: 1.0
diversity_coef: 1.0
rl_coef: 1.0 # RL policy gradient loss weight
ce_coef: 0.03 # Cross-entropy loss on GT tokens
advantage_estimator: rloo # rloo, group_norm, or reinforce
min_completion_prefix: 8 # Skip anchors in prompt region
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_strided_structured.transform
split: train[:1%]
sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 2
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
bf16: auto
flex_attention: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required with flex_attention
```
Run with a single command (no vLLM needed):
```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```
### Advantage Estimators
Strided mode supports three advantage estimation methods:
| Estimator | Formula | Requirements |
|-----------|---------|-------------|
| `rloo` | Leave-one-out baseline: `reward_j - mean(rewards_{-j})` | `n_samples_per_prompt >= 2` |
| `group_norm` | Group normalization: `(reward_j - mean) / std` | `n_samples_per_prompt >= 2` |
| `reinforce` | Raw reward as advantage (no baseline) | Works with `n_samples_per_prompt = 1` |
::: {.callout-warning}
When `n_samples_per_prompt: 1`, the trainer automatically falls back to `reinforce` and disables the diversity penalty (which requires multiple samples).
:::
### Strided Mode Constraints
- **`flex_attention: true`** is strongly recommended. Without it, dense 4D masks consume significantly more memory.
- **`torch_compile: true`** must NOT be set. `flex_attention` compiles its own kernels internally; adding `torch_compile` causes conflicts and OOM.
- **Gradient checkpointing** must use `use_reentrant: true`. Non-reentrant checkpointing causes `CheckpointError` with `flex_attention` block masks.
- **`activation_offloading`** is incompatible with `flex_attention`.
### Cross-Entropy Loss
Strided mode supports an optional cross-entropy loss term on ground-truth tokens. This acts as a regularizer to prevent the model from drifting too far from the original distribution:
```yaml
ebft:
ce_coef: 0.03 # Small CE coefficient
rl_coef: 1.0 # RL loss coefficient
```
The total loss is `rl_coef * rl_loss + ce_coef * ce_loss`. For structured mode, `ce_coef` is typically `0.0` since vLLM generation provides sufficient learning signal.
## Dataset Formats
EBFT provides several built-in dataset transforms in `src/axolotl/prompt_strategies/ebft/`.
### Built-In Transforms
| Transform | Input Format | Output Fields | Use Case |
|-----------|-------------|---------------|----------|
| `ebft_opencode.transform` | `{input, output}` | `{prompt, ground_truth}` | OpenCodeInstruct, structured QA |
| `ebft_strided_structured.transform` | `{input, output}` | `{input_ids, labels, prompt_length}` | Strided mode with structured data |
| `ebft_strided_chat.transform` | `{messages: [...]}` | `{input_ids, labels, prompt_length}` | Strided mode with chat data |
| `ebft_chat_multiturn.transform` | `{messages: [...]}` | `{prompt, ground_truth, remaining_turns}` | Multi-turn: first-turn target |
| `ebft_chat_multiturn.transform_last_turn` | `{messages: [...]}` | `{prompt, ground_truth}` | Multi-turn: last-turn target |
| `ebft_chat_multiturn.transform_all_turns` | `{messages: [...]}` | `{prompt[], ground_truth[]}` | Multi-turn: one example per turn |
| `ebft_reasoning.transform` | `{messages: [...]}` (with `<think>`) | `{prompt, ground_truth}` | Reasoning/thinking datasets |
### Structured Mode Datasets
For structured (sync/async) mode, the transform must produce `prompt` and `ground_truth` fields:
```yaml
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:500]
```
### Multi-Turn Datasets
Multi-turn transforms extract conversation data for sequential rollout. The `transform` variant targets the first assistant turn, while `transform_last_turn` targets the final turn:
```yaml
datasets:
- path: your/multiturn-dataset
type: ebft_chat_multiturn.transform
```
When `remaining_turns` is present in the dataset output, the trainer performs sequential rollouts: it generates the first assistant turn with vLLM, then continues generating subsequent turns by building up the conversation history.
### Strided Mode Datasets
Strided transforms tokenize the full document and produce `input_ids`, `labels`, and `prompt_length`:
```yaml
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_strided_structured.transform
split: train[:1%]
```
### Custom Transforms
To use your own dataset format, write a transform function:
```python
def transform(cfg, **kwargs):
def transform_fn(example, tokenizer=None):
return {
"prompt": [{"role": "user", "content": example["question"]}],
"ground_truth": example["answer"],
}
return transform_fn, {"remove_columns": "__all__"}
```
The `"__all__"` sentinel removes all original dataset columns after the mapping step. Reference this transform in your config:
```yaml
datasets:
- path: your/dataset
type: your_module.transform
```
## Configuration Reference
### Common Parameters (All Modes)
These parameters are set under the `ebft:` key in the YAML config.
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `mode` | `"structured"` or `"strided"` | `"structured"` | EBFT operating mode |
| `feature_layers` | `list[float]` | `[0.25, 0.5, 0.75]` | Fractional layer depths for feature extraction |
| `embed_method` | `string` | `"last_token"` | Pooling method: `last_token`, `mean_pooling`, `completion_mean`, or `concat` |
| `use_whitening` | `bool` | `false` | Apply SVD whitening to feature embeddings before reward computation |
| `alignment_coef` | `float` | `1.0` | Weight for alignment reward (cosine similarity with ground truth) |
| `diversity_coef` | `float` | `1.0` | Weight for diversity penalty (pairwise dot product between samples) |
| `ce_coef` | `float` | `0.0` | Cross-entropy loss coefficient on ground-truth tokens |
| `adaptive_max_tokens` | `bool` | `true` | Dynamically set vLLM `max_tokens` based on ground-truth length (structured mode) |
| `gt_length_multiplier` | `float` | `1.5` | Multiplier for ground-truth token count when computing adaptive max tokens (min 0.1) |
### Strided Mode Parameters
These additional parameters apply only when `mode: strided`.
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `stride` | `int` | `8` | Number of tokens between anchor points (must be >= 1) |
| `context_length` | `int` | `8` | Context window size for each generated block (must be >= 1) |
| `generate_max_len` | `int` | `8` | Number of tokens to generate per block (must be >= 1) |
| `n_samples_per_prompt` | `int` | `4` | Number of independent rollouts per document (must be >= 1) |
| `temperature` | `float` | `0.6` | Sampling temperature for strided generation |
| `top_p` | `float` | `1.0` | Top-p nucleus sampling threshold |
| `rl_coef` | `float` | `1.0` | RL policy gradient loss coefficient |
| `advantage_estimator` | `string` | `"rloo"` | Advantage estimation method: `rloo`, `group_norm`, or `reinforce` |
| `min_completion_prefix` | `int` | `0` | Minimum tokens into the completion span before placing anchors |
### Structured Mode TRL Parameters
These are set under the `trl:` key and control the GRPO training loop.
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `num_generations` | `int` | -- | Number of completions generated per prompt |
| `max_completion_length` | `int` | -- | Maximum tokens per generated completion |
| `temperature` | `float` | `0.7` | Sampling temperature for vLLM generation |
| `use_vllm` | `bool` | -- | Enable vLLM generation backend |
| `vllm_lora_sync` | `bool` | `false` | Sync LoRA adapters via filesystem (recommended) |
| `vllm_sync_interval` | `int` | `1` | Steps between weight syncs to vLLM |
| `use_data_producer` | `bool` | -- | Required for sync mode with LoRA sync |
| `async_prefetch` | `bool` | `false` | Enable async generation (overlaps with training) |
| `streaming_partial_batch` | `bool` | `false` | Score groups incrementally (async mode) |
| `skip_zero_advantage_batches` | `bool` | `false` | Skip micro-batches where all advantages are zero |
| `scale_rewards` | `bool` | -- | Normalize rewards within each prompt group |
| `loss_type` | `string` | `"grpo"` | Loss type for policy optimization |
| `epsilon` | `float` | `0.2` | Clipping parameter for importance sampling |
### Stop Tokens
vLLM needs explicit stop token IDs for generation. Common configurations:
```yaml
trl:
generation_kwargs:
stop_token_ids: [151645, 151643] # Qwen: <|im_end|>, <|endoftext|>
```
### Multi-Turn Chat Settings
For multi-turn conversations with Qwen3.5, disable thinking mode to prevent `<think>` tags in completions:
```yaml
trl:
chat_template_kwargs:
enable_thinking: false
```
## Monitoring
### Key Metrics
EBFT logs several custom metrics to wandb and the training console. Here is what to watch for:
| Metric | Healthy Range | Interpretation |
|--------|--------------|----------------|
| `ebft/alignment` | 0.3 -- 0.9, trending upward | Cosine similarity between generated and ground-truth features. Higher means the model is learning to produce representations that match the reference. |
| `ebft/diversity` | 0.01 -- 0.1 | Mean pairwise similarity between different generations for the same prompt. Values above 1.0 indicate mode collapse. |
| `ebft/cfm_loss` | Below 10, trending downward | Cross-Feature Matching loss. This is the core quantity being minimized. Consistently above 100 indicates instability. |
| `ebft/reward` | Trending upward (may start negative) | Combined reward signal. If stuck at -1.0, the diversity penalty is dominating alignment. |
| `grad_norm` | 0.1 -- 3.0 | Gradient magnitude. Values of 0.0 indicate zero-advantage skip (normal). Values above 10 suggest instability. |
| `entropy` | 0.05 -- 0.5 | Policy entropy. Values below 0.01 suggest mode collapse. |
| `IS ratio min` | Above 0.1 | Importance sampling ratio minimum. Near-zero values mean the policy is too far off-policy; increase `vllm_sync_interval`. |
### Console Log Example
During training, you will see periodic EBFT reward logs:
```
ebft reward | align +0.412 ^ | divers +0.023 v | cfm 4.231 v | reward +0.389 ^
```
The arrows indicate the desired direction: alignment and reward should trend upward, while diversity and CFM loss should trend downward.
### Troubleshooting
| Symptom | Likely Cause | Fix |
|---------|-------------|-----|
| `alignment` stays below 0.1 | Feature layers not capturing useful information | Try different `feature_layers` or `embed_method` |
| `diversity` exceeds 1.0 | Mode collapse -- generations are too similar | Increase `diversity_coef` or `temperature` |
| `reward` stuck at -1.0 | Diversity penalty dominates alignment | Reduce `diversity_coef` or increase `alignment_coef` |
| `grad_norm` consistently 0.0 | All micro-batches have zero advantage | Increase `num_generations` or check data quality |
| `CheckpointError` in strided mode | Incompatible gradient checkpointing settings | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| OOM during training | Logits tensor too large | Reduce `sequence_len` or `micro_batch_size`; strided mode uses chunked lm_head to mitigate this |
| vLLM 500 errors | `truncate_prompt_tokens` not supported | Ensure you are using `axolotl vllm-serve` (not `trl vllm-serve`) |
### Feature Network Memory
In PEFT (LoRA) mode, the feature network shares base weights with the actor model by using the `disable_adapter()` context manager. This saves an entire model copy in VRAM (approximately 1--16 GB depending on model size). For non-PEFT training, a separate frozen deepcopy is created.
::: {.callout-note}
The `disable_adapter()` approach relies on an invariant: `merge_adapter()` is never called on the base weights. All weight sync paths (LoRA sync, HTTP, NCCL) compute merged weights as new tensors or save the adapter to the filesystem, leaving base weights unmodified.
:::
## Examples
Complete example configurations are available in `examples/ebft/`:
| Config | Model | Mode | Description |
|--------|-------|------|-------------|
| `llama-1b-ebft-strided-structured.yaml` | Llama 3.2 1B | Strided | Single-GPU strided training on code data |
| `qwen3-4b-ebft-structured.yaml` | Qwen3 4B | Structured (sync) | Two-GPU structured training |
| `qwen3-4b-ebft-structured-async.yaml` | Qwen3 4B | Structured (async) | Two-GPU async training with prefetch |
| `qwen3-8b-ebft-structured.yaml` | Qwen3 8B | Structured (sync) | Two-GPU structured training for larger model |
| `qwen35-4b-ebft-structured.yaml` | Qwen3.5 4B | Structured (sync) | Two-GPU with Qwen3.5 |
| `qwen35-4b-ebft-structured-async.yaml` | Qwen3.5 4B | Structured (async) | Two-GPU async with Qwen3.5 |
| `qwen35-9b-ebft-structured.yaml` | Qwen3.5 9B | Structured (sync) | Two-GPU structured for 9B model |

View File

@@ -1,67 +0,0 @@
---
title: "MoE Expert Quantization"
description: "Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load"
---
Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors).
This means `bitsandbytes` can no longer quantize them during model loading, resulting in all expert
weights being loaded in full bf16 precision and causing massive VRAM usage.
`quantize_moe_experts` solves this by quantizing expert weights during model loading.
It intercepts the weight loading process, quantizes each expert tensor on the fly, and
immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.
For example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.
## Usage
Enable expert quantization in your Axolotl config:
```yaml
quantize_moe_experts: true
```
This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
### Expert LoRA targeting
You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router
```
::: {.callout-note}
`lora_dropout` must be `0` when using `lora_target_parameters`.
:::
## Requirements
- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
- CUDA GPUs only (not tested with ROCm or other backends)
- FSDP2 compatible for distributed training
## Limitations
- `lora_target_linear` is not compatible with `quantize_moe_experts`. See [Expert LoRA targeting](#expert-lora-targeting) instead.
- `cpu_ram_efficient_loading` hangs / takes long time with FSDP2 + QLoRA.
- Total model parameter count may display incorrectly (trainable param count is correct).
- FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.
- FSDP2 may use more VRAM per GPU than single GPU training due to not all layers being properly sharded across ranks.
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
- DeepSpeed has not been tested.
## Implementation details
The quantization is applied by patching transformers to intercept weight loading.
When a 3D+ CUDA tensor with "expert" in its name is detected:
- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).
- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to
transformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.
For full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439).

View File

@@ -63,14 +63,6 @@ description: Frequently asked questions
> A: There seems to be a wheel issue with FA2 2.8.0 on CUDA 12.4. Try CUDA 12.6 instead or downgrade to FA2 2.7.4. Please refer to the upstream issue: https://github.com/Dao-AILab/flash-attention/issues/1717.
**Q: Can we mix text and text+image datasets for VLM training?**
> A: Yes, you can for newer VLM arch. The ones that would not work are LLaVA / Pixtral arch. If you notice one not working, please let us know!
**Q: Why is `memory/max_*` different from `nvidia-smi`?**
> A: We use `torch` APIs to retrieve this information. You can see https://docs.pytorch.org/docs/stable/notes/cuda.html#cuda-memory-management for more information.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
@@ -148,7 +140,3 @@ description: Frequently asked questions
**Q: `ValueError("Backward pass should have cleared tracker of all tensors")`
> A: This may happen due to edge cases in using the modern OffloadActivations context manager for CUDA streams. If you encounter this error, you may have success using the naive implementation with `offload_activations: legacy` in your YAML.
**Q: `Error parsing tool_calls arguments as JSON.`
> A: There is an error parsing string arguments to a dict. Please check your dataset and the error message for more details.

View File

@@ -1,5 +1,5 @@
---
title: "FSDP + QLoRA"
title: "FDSP + QLoRA"
description: Use FSDP with QLoRA to fine-tune large LLMs on consumer GPUs.
format:
html:
@@ -23,12 +23,6 @@ To enable `QLoRA` with `FSDP`, you need to perform the following steps:
2. Enable FSDP in your axolotl config, as [described here](multi-gpu.qmd#sec-fsdp).
3. Use one of the supported model types: `llama`, `mistral` or `mixtral`.
## Enabling Swap for FSDP2
If available memory is insufficient even after FSDP's CPU offloading, you can enable swap memory usage by setting `cpu_offload_pin_memory: false` alongside `offload_params: true` in FSDP config.
This disables memory pinning, allowing FSDP to use disk swap space as fallback. Disabling memory pinning itself incurs performance overhead, and actually having to use swap adds more, but it may enable training larger models that would otherwise cause OOM errors on resource constrained systems.
## Example Config
[examples/llama-2/qlora-fsdp.yml](../examples/llama-2/qlora-fsdp.yml) contains an example of how to enable QLoRA + FSDP in axolotl.

View File

@@ -170,26 +170,17 @@ More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
## Next Steps {#sec-next-steps}
Now that you have the basics, explore these guides based on what you want to do:
Now that you have the basics, you might want to:
**Choose your path:**
- Try different model architectures
- Experiment with hyperparameters
- Use more advanced training methods
- Scale up to larger models
- [Choosing a Fine-Tuning Method](choosing_method.qmd) — SFT vs LoRA vs QLoRA vs GRPO vs DPO, with hardware recommendations
Check our other guides for details on these topics:
**Core guides:**
- [Dataset Loading](dataset_loading.qmd) — Loading datasets from various sources
- [Dataset Formats](dataset-formats) — Working with different data formats
- [Optimizations](optimizations.qmd) — Flash attention, gradient checkpointing, sample packing
- [Training Stability & Debugging](training_stability.qmd) — Monitoring metrics, fixing NaN, OOM debugging
**Advanced training methods:**
- [RLHF / Preference Learning](rlhf.qmd) — DPO, KTO, GRPO, EBFT
- [GRPO Training](grpo.qmd) — RL with custom rewards and vLLM generation
- [vLLM Serving](vllm_serving.qmd) — Setting up vLLM for GRPO
**Scaling up:**
- [Multi-GPU Training](multi-gpu.qmd) — DeepSpeed, FSDP, DDP
- [Multi-Node Training](multi-node.qmd) — Distributed training across machines
- [Configuration Guide](config-reference.qmd) - Full configuration options
- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
- [Dataset Formats](dataset-formats) - Working with different data formats
- [Multi-GPU Training](multi-gpu.qmd)
- [Multi-Node Training](multi-node.qmd)

View File

@@ -1,5 +1,5 @@
---
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
title: Gradient Checkpointing and Activation Offloading
---
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
@@ -27,33 +27,3 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
### Enabling Layer Offloading
```yaml
layer_offloading: true
```
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
trainable adapter weights stay on GPU permanently.
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
prefetched asynchronously on a separate CUDA stream for overlap.
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
previous layer is prefetched.
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
that is kept on GPU at any given time.
**Requirements:**
- CUDA GPU (CPU-only training is not supported for this feature)
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
- Best combined with LoRA/QLoRA where most parameters are frozen

View File

@@ -1,611 +0,0 @@
---
title: "GRPO Training"
description: "Group Relative Policy Optimization — a reinforcement learning method for training language models with verifiable reward functions."
order: 8
---
## Overview
Group Relative Policy Optimization (GRPO) is a reinforcement learning method that improves language models by generating multiple completions per prompt, scoring them with reward functions, and using the relative ranking within each group to compute advantage estimates. Unlike DPO, which requires pre-collected preference pairs, GRPO generates its own training data online and can work with any programmatic reward signal (math correctness, format compliance, code execution results, etc.).
Use GRPO when you have a task with a verifiable reward signal and want the model to discover solution strategies on its own. Use DPO when you already have human preference data. Use SFT when you have gold-standard completions to imitate directly.
Axolotl's GRPO implementation builds on TRL and adds async generation, streaming scoring, importance sampling correction, replay buffers, and multi-GPU scaling via FSDP and DeepSpeed.
## Architecture
GRPO training uses a two-process architecture: a vLLM server for fast generation and a trainer process for scoring and gradient updates.
```
Terminal 1 (GPU 0) Terminal 2 (GPU 1)
┌──────────────────────┐ ┌──────────────────────────────────┐
│ vLLM Server │ │ Trainer │
│ │ HTTP │ │
│ Serves base model │◄────────────►│ Background thread: │
│ + LoRA adapter │ /generate │ Send prompts to vLLM │
│ │ /set_lora │ Pad & collate completions │
│ Punica kernels for │ │ │
│ LoRA inference │ │ Main thread: │
│ │ │ Score completions (rewards) │
└──────────────────────┘ │ Compute policy log-probs │
│ Calculate advantages │
│ PPO-clip gradient update │
│ Sync LoRA weights to vLLM │
└──────────────────────────────────┘
```
**Data flow for each training step:**
1. The background thread sends prompts to vLLM, which generates `num_generations` completions per prompt.
2. The main thread scores completions using your reward functions.
3. Advantages are computed within each prompt group (group-relative normalization).
4. Policy log-probabilities are computed by running a forward pass on the training model.
5. The PPO-clip loss is computed and gradients are applied.
6. Periodically, LoRA adapter weights are synced back to vLLM so future generations reflect the updated policy.
With async prefetch enabled, step 1 for the *next* batch runs concurrently with steps 2-6 for the *current* batch.
## Quick Start
A GRPO training run requires three components: a YAML config, a reward module (Python file), and a running vLLM server.
### 1. Write a reward module
Create a file called `rewards.py` in your working directory:
```python
# rewards.py
import re
def accuracy_reward(completions, answer, **kwargs) -> list[float]:
"""Check if the completion contains the correct numerical answer."""
rewards = []
for completion, correct in zip(completions, answer):
text = completion[0]["content"]
# Extract the last number from the completion
numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
predicted = numbers[-1] if numbers else ""
rewards.append(1.0 if predicted == str(correct) else 0.0)
return rewards
def format_reward(completions, **kwargs) -> list[float]:
"""Reward completions that use a structured thinking format."""
rewards = []
for completion in completions:
text = completion[0]["content"]
has_think = "<think>" in text and "</think>" in text
has_answer = "<answer>" in text and "</answer>" in text
rewards.append(1.0 if has_think and has_answer else 0.0)
return rewards
def prompt_transform(cfg, *args, **kwargs):
"""Convert GSM8K dataset rows into chat prompts."""
def transform_fn(example, tokenizer=None):
label = example["answer"].split("####")[-1].strip().replace(",", "")
return {
"prompt": [
{"role": "system", "content": "Solve the math problem. Show your reasoning in <think> tags and your final numerical answer in <answer> tags."},
{"role": "user", "content": example["question"]},
],
"answer": label,
}
return transform_fn, {"remove_columns": ["question"]}
```
### 2. Write the config
Create `config.yaml`:
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
rl: grpo
chat_template: tokenizer_default
vllm:
host: 0.0.0.0
port: 8000
gpu_memory_utilization: 0.85
dtype: auto
max_model_len: 2048
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
trl:
use_vllm: true
use_data_producer: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_server_timeout: 300
vllm_lora_sync: true
num_generations: 8
max_completion_length: 512
temperature: 0.7
reward_funcs:
- rewards.accuracy_reward
- rewards.format_reward
reward_weights:
- 1.0
- 0.5
datasets:
- path: openai/gsm8k
name: main
type: rewards.prompt_transform
split: train
skip_prepare_dataset: true
val_set_size: 0.0
sequence_len: 512
micro_batch_size: 2
gradient_accumulation_steps: 4
max_steps: 200
learning_rate: 5.0e-6
optimizer: adamw_torch_fused
lr_scheduler: cosine
warmup_steps: 10
bf16: true
flash_attention: true
gradient_checkpointing: true
special_tokens:
pad_token: "<|endoftext|>"
output_dir: ./grpo-output
logging_steps: 1
```
### 3. Start vLLM and train
```bash
# Terminal 1: Start vLLM server on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Wait 30-90 seconds for model loading and CUDA graph capture
# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
:::{.callout-tip}
Use `tmux` or separate terminal sessions to manage the two processes. The vLLM server must remain running for the entire training duration.
:::
## Custom Reward Functions
### Function signature
TRL calls reward functions with this signature:
```python
def my_reward(completions, **kwargs) -> list[float]:
```
- `completions` is a list of single-element lists, where each element is a dict `{"role": "assistant", "content": "..."}`. So `completions[i][0]["content"]` gives you the text of the i-th completion.
- `**kwargs` contains all dataset columns that were *not* removed by the dataset transform. This is how you pass ground truth answers, metadata, or any other information to your reward function.
- Return a `list[float]` with the same length as `completions`. You may return `None` for individual elements to exclude them from aggregation.
### Example: accuracy reward with answer extraction
```python
def accuracy_reward(completions, answer, **kwargs) -> list[float]:
rewards = []
for completion, correct_answer in zip(completions, answer):
text = completion[0]["content"]
# Extract answer from <answer>...</answer> tags
match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
predicted = match.group(1).strip() if match else ""
rewards.append(1.0 if predicted == str(correct_answer) else 0.0)
return rewards
```
### Example: length penalty
```python
def length_penalty(completions, **kwargs) -> list[float]:
"""Penalize very short or very long completions."""
rewards = []
for completion in completions:
length = len(completion[0]["content"])
if length < 50:
rewards.append(-0.5)
elif length > 2000:
rewards.append(-0.2)
else:
rewards.append(0.0)
return rewards
```
### Multiple rewards and weighting
You can combine multiple reward functions with different weights:
```yaml
trl:
reward_funcs:
- rewards.accuracy_reward
- rewards.format_reward
- rewards.length_penalty
reward_weights:
- 1.0 # accuracy is most important
- 0.5 # format compliance
- 0.1 # mild length preference
```
Rewards are combined by the `multi_objective_aggregation` strategy:
- `sum_then_normalize` (default): weights and sums all rewards first, then normalizes across the group.
- `normalize_then_sum` (GDPO): normalizes each reward independently, then sums. This prevents one reward from dominating and is recommended when using multiple reward functions with different scales.
```yaml
trl:
multi_objective_aggregation: normalize_then_sum
```
### Dataset transforms
The dataset transform converts raw HuggingFace dataset rows into chat-format prompts:
```python
def prompt_transform(cfg, *args, **kwargs):
def map_fn(example, tokenizer=None):
return {
"prompt": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": example["question"]},
],
# Keep 'answer' column for the reward function
"answer": example["answer"],
}
# Remove columns consumed by the transform; keep columns needed by rewards
return map_fn, {"remove_columns": ["question"]}
```
The transform returns a tuple of `(map_function, kwargs_dict)`. The `remove_columns` in the kwargs dict removes columns that are no longer needed. Columns that your reward functions reference via `**kwargs` (like `answer`) must *not* be removed.
:::{.callout-warning}
The reward module must be importable from the directory where you run `axolotl train`. If your reward file is `rewards.py`, the import path is `rewards.accuracy_reward`. If it is inside a package `my_rewards/scoring.py`, use `my_rewards.scoring.accuracy_reward`.
:::
### Reward models (neural network rewards)
Instead of a Python function, you can pass a HuggingFace model path as a reward function. TRL will load it as a reward model and use its scalar output as the reward:
```yaml
trl:
reward_funcs:
- OpenAssistant/reward-model-deberta-v3-large-v2
- rewards.format_reward
reward_weights:
- 1.0
- 0.3
```
### Using math_verify
The `math_verify` library provides robust mathematical answer verification but uses `signal.alarm()` internally, which only works in the main thread. If you use `math_verify` in a reward function, set `reward_num_workers` to use subprocess workers:
```yaml
trl:
reward_num_workers: 4
```
Each worker runs in its own subprocess with its own main thread, so `signal.alarm()` works correctly.
## vLLM Setup
GRPO requires a running vLLM server for generation. For a complete guide on server modes, LoRA sync, weight synchronization, and restart procedures, see [vLLM Serving](vllm_serving.qmd).
The minimal setup:
```yaml
vllm:
host: 0.0.0.0
port: 8000
gpu_memory_utilization: 0.85
trl:
use_vllm: true
vllm_lora_sync: true # Recommended with LoRA — faster sync, no NCCL contention
vllm_sync_interval: 5 # Sync weights every 5 steps
```
```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml # GPU 0: vLLM
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml # GPU 1: training
```
:::{.callout-warning}
vLLM must be restarted between experiments — stale weight syncs corrupt server state. See [Restart Requirements](vllm_serving.qmd#sec-restart).
:::
## Async Training Features
Async GRPO overlaps generation and training to reduce wall-clock time. While the model trains on the current batch, the next batch is already being generated by vLLM.
### Enabling async prefetch
```yaml
trl:
use_data_producer: true
async_prefetch: true
prefetch_depth: 1
vllm_sync_interval: 2
```
- `use_data_producer: true` enables the data producer protocol (required for all async features).
- `async_prefetch: true` runs generation in a background thread.
- `prefetch_depth` controls how many batches to prefetch ahead (1 is usually sufficient).
- `vllm_sync_interval` controls how often LoRA weights are synced to vLLM (every N optimizer steps). Lower values mean fresher generations but more sync overhead.
:::{.callout-tip}
Because the background thread generates with slightly stale model weights, async mode benefits from importance sampling correction (see next section). Enable `vllm_importance_sampling_correction: true` when using `async_prefetch: true`.
:::
### Streaming partial batch
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This reduces peak memory during scoring and enables finer-grained zero-advantage skipping.
```yaml
trl:
streaming_partial_batch: true
streaming_min_groups: 1
```
`streaming_min_groups` controls the minimum number of prompt groups scored per chunk. Setting it to 1 gives maximum granularity.
### Zero-advantage batch skipping
When all advantages in a micro-batch are zero (every completion in the group got the same reward), there is no learning signal. This feature skips the forward/backward pass entirely for such micro-batches.
```yaml
trl:
skip_zero_advantage_batches: true # default
```
This is enabled by default and logged as `skipped_zero_adv_batches` in training metrics. It is a safety net, not a major optimization -- it only saves significant time when the model cannot solve any prompts in the batch.
### Replay buffer
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and replaces zero-signal groups in later batches. This improves data utilization when many prompts yield no reward variance.
```yaml
trl:
replay_buffer_size: 100
replay_recompute_logps: true
```
:::{.callout-warning}
When `replay_recompute_logps: false`, replayed data uses stale log-probabilities which creates an IS mismatch. Keep the default `true` unless you have a specific reason to disable it.
:::
### Deferred re-rolling
Prompts where the model gets zero reward for all generations are buffered and re-injected into later batches, when the model may have improved enough to produce useful completions.
```yaml
trl:
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
reroll_max_groups: 1 # Max groups to replace per batch
```
Set `reroll_start_fraction: 1.0` to disable. This is most useful for tasks where the model starts weak but steadily improves.
### Parallel reward workers
Reward functions that use `signal.alarm()` (like `math_verify`) only work in the main thread. Parallel reward workers run each function in its own subprocess:
```yaml
trl:
reward_num_workers: 4
```
Work is sharded across workers by prompt group. For simple reward functions, a single worker is usually sufficient -- the overhead of IPC can exceed the computation time.
## Importance Sampling and Off-Policy Correction
When using async prefetch, completions are generated from a slightly older policy. IS correction adjusts the gradient to account for this mismatch.
```yaml
trl:
vllm_importance_sampling_correction: true
importance_sampling_level: token # 'token' recommended (especially with Liger kernel)
off_policy_mask_threshold: 0.5 # KL threshold — masks sequences that are too off-policy
```
Use `token` level IS. Sequence-level has numerical issues with Liger's chunked computation. The `off_policy_mask_threshold` (OPSM) is a safety net that drops sequences where KL divergence exceeds the threshold — 0.5 is a reasonable starting point.
For detailed coverage of IS modes (`token_mask`, `token_truncate`, etc.), capping, and bias-corrected KL, see [vLLM Serving — IS Correction](vllm_serving.qmd#sec-weight-sync).
## Scaling
### FP8 training
FP8 quantization halves model VRAM usage with minimal impact on training quality. It does not significantly speed up computation for small models but allows larger models to fit in memory.
```yaml
fp8: true
torch_compile: true
```
:::{.callout-warning}
FP8 requires patching for zero-padding edge cases. The `act_quant_kernel` can produce NaN when input is all zeros (padding positions). If you see NaN in grad norms, check whether your padding token embedding is non-zero.
:::
### FSDP (Fully Sharded Data Parallel)
FSDP distributes model parameters across multiple GPUs for training while vLLM runs on a separate GPU:
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
use_reentrant: false
```
Launch with:
```bash
# GPU 0: vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# GPUs 0,1: Training (FSDP will use both visible GPUs)
CUDA_VISIBLE_DEVICES=0,1 axolotl train config.yaml
```
:::{.callout-warning}
`async_prefetch: true` can deadlock with FSDP because background threads perform unsynchronized FSDP collectives across ranks. With multi-GPU FSDP, only rank 0 generates in the background thread and results are broadcast to all ranks. If you still see hangs, set `async_prefetch: false`.
:::
### DeepSpeed ZeRO-3
```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
use_reentrant: true # Required -- non-reentrant causes CheckpointError with ZeRO-3
```
:::{.callout-note}
DeepSpeed ZeRO-3 requires `use_reentrant: true` for gradient checkpointing. This is the opposite of the FSDP recommendation. Non-reentrant checkpointing causes tensor metadata mismatches during recomputation with ZeRO-3's parameter partitioning.
:::
### Multi-GPU considerations
| Concern | Recommendation |
|---------|---------------|
| vLLM GPU allocation | Dedicate one or more GPUs to vLLM; do not share with trainer GPUs |
| Weight sync contention | Use `vllm_lora_sync: true` to avoid NCCL contention between training and vLLM |
| FSDP + async | Use `async_prefetch: false` or rely on rank-0-only background generation |
| DeepSpeed + gradient checkpoint | Must use `use_reentrant: true` |
| OOM during scoring | Reduce `micro_batch_size` or `num_generations`. The logits tensor scales with `batch_size * vocab_size` |
## Monitoring and Debugging
For detailed metric ranges, failure diagnosis, and OOM debugging, see [Training Stability & Debugging](training_stability.qmd).
Quick health checks during GRPO training:
- `rewards/*/mean` should be > 0.15 within 20 steps — if it stays at 0, test your reward function standalone
- `reward_std` should be > 0 on most steps — all-zero means no learning signal
- `entropy` in 0.05-0.5 — below 0.01 suggests mode collapse
- `grad_norm` in 0.001-1.0 — > 10 is unstable, 0.0 is expected when zero-advantage skip fires
:::{.callout-tip}
Pipe training output to a log file: `axolotl train config.yaml 2>&1 | tee /tmp/training.log`
:::
## Configuration Reference
All GRPO-specific options live under the `trl:` key in your config. Standard training options (`learning_rate`, `micro_batch_size`, etc.) are set at the top level as usual.
### Core GRPO
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_vllm` | bool | `false` | Enable vLLM for generation |
| `vllm_mode` | `"server"` or `"colocate"` | `null` | vLLM deployment mode |
| `vllm_server_host` | str | `"0.0.0.0"` | vLLM server hostname |
| `vllm_server_port` | int | `8000` | vLLM server port |
| `vllm_server_timeout` | int | `null` | Timeout (seconds) for vLLM responses |
| `num_generations` | int | `null` | Completions generated per prompt |
| `generation_batch_size` | int | `null` | Number of unique prompts per generation step |
| `max_completion_length` | int | `null` | Maximum tokens per completion |
| `beta` | float | `null` | KL penalty coefficient |
| `num_iterations` | int | `null` | Iterations per batch (mu in the GRPO paper) |
| `epsilon` | float | `null` | PPO clipping lower bound |
| `epsilon_high` | float | `null` | PPO clipping upper bound |
| `loss_type` | str | `null` | Loss formulation: `grpo`, `bnpo`, or `dr_grpo` |
| `scale_rewards` | bool | `true` | Normalize rewards by standard deviation |
| `mask_truncated_completions` | bool | `false` | Exclude truncated completions from loss |
### Reward functions
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `reward_funcs` | list[str] | `null` | Import paths to reward functions or HF model IDs |
| `reward_weights` | list[float] | `null` | Relative weights for each reward function |
| `multi_objective_aggregation` | str | `null` | `"sum_then_normalize"` (GRPO) or `"normalize_then_sum"` (GDPO) |
| `rollout_func` | str | `null` | Import path to custom rollout function for OpenEnv-style tasks |
### Generation parameters
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `temperature` | float | `null` | Sampling temperature |
| `top_p` | float | `null` | Nucleus sampling probability |
| `top_k` | int | `null` | Top-k sampling |
| `min_p` | float | `null` | Minimum probability threshold |
| `repetition_penalty` | float | `null` | Penalty for repeated tokens |
| `generation_kwargs` | dict | `null` | Additional vLLM SamplingParams (e.g., `stop_token_ids`) |
| `chat_template_kwargs` | dict | `null` | Chat template kwargs (e.g., `{enable_thinking: false}`) |
| `vllm_guided_decoding_regex` | str | `null` | Regex constraint for guided decoding |
### Async pipeline
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_data_producer` | bool | `false` | Enable data producer protocol (required for async features) |
| `async_prefetch` | bool | `false` | Generate next batch in background thread |
| `prefetch_depth` | int | `null` | Number of batches to prefetch ahead |
| `vllm_sync_interval` | int | `null` | Sync LoRA weights to vLLM every N steps |
| `vllm_lora_sync` | bool | `false` | Use filesystem LoRA sync instead of NCCL merge |
| `streaming_partial_batch` | bool | `null` | Score prompt groups incrementally |
| `streaming_min_groups` | int | `null` | Minimum groups per streaming chunk |
| `skip_zero_advantage_batches` | bool | `true` | Skip micro-batches with zero learning signal |
| `reward_num_workers` | int | `1` | Subprocess workers for reward computation |
| `vllm_enable_sleep_mode` | bool | `null` | Offload vLLM weights when idle (colocate mode) |
### Importance sampling
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `vllm_importance_sampling_correction` | bool | `null` | Enable IS correction for async distribution shift |
| `importance_sampling_level` | `"token"` or `"sequence"` | `null` | Granularity of IS ratios. Use `token` with Liger |
| `vllm_importance_sampling_mode` | str | `null` | `token_mask`, `token_truncate`, `sequence_mask`, or `sequence_truncate` |
| `vllm_importance_sampling_cap` | float | `null` | Cap C for IS ratio clipping/masking |
| `off_policy_mask_threshold` | float | `null` | KL threshold for off-policy sequence masking (OPSM) |
| `use_bias_correction_kl` | bool | `null` | Apply IS correction to KL divergence term |
### Replay and re-roll
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `replay_buffer_size` | int | `0` | Max cached high-signal groups. 0 = disabled |
| `replay_recompute_logps` | bool | `true` | Recompute log-probs for replayed data with current model |
| `reroll_start_fraction` | float | `1.0` | Start re-rolling failed prompts after this fraction of training. 1.0 = disabled |
| `reroll_max_groups` | int | `1` | Max prompt groups to replace with re-rolls per batch |
### Reference model
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `sync_ref_model` | bool | `false` | Periodically sync reference model with training model |
| `ref_model_mixup_alpha` | float | `0.9` | EMA coefficient for reference model sync |
| `ref_model_sync_steps` | int | `64` | Sync reference model every N steps |
### Logging
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `log_completions` | bool | `false` | Log sample completions to W&B |
| `num_completions_to_print` | int | `null` | Number of completions to print per step |
| `use_liger_loss` | bool | `null` | Use Liger fused kernel for GRPO loss (reduces VRAM) |

View File

@@ -26,7 +26,7 @@ Follow the instructions at: [https://pytorch.org/get-started/locally/](https://p
:::
::: {.callout-important}
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
:::
### PyPI Installation (Recommended) {#sec-pypi}
@@ -111,7 +111,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
:::
::: {.callout-important}
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
:::
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
@@ -134,7 +134,7 @@ For providers supporting Docker:
### Google Colab {#sec-colab}
[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
## Platform-Specific Instructions {#sec-platform-specific}
@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
```
4. (Optional) Login to Hugging Face:
```{.bash}
hf auth login
huggingface-cli login
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -5,11 +5,10 @@ description: "Custom autograd functions and Triton kernels in Axolotl for optimi
Inspired by [Unsloth](https://github.com/unslothai/unsloth), we've implemented two
optimizations for LoRA and QLoRA fine-tuning, supporting both single GPU and multi-GPU
(including the DDP, DeepSpeed, and FSDP2 settings) training. These include (1) SwiGLU
and GEGLU activation function Triton kernels, and (2) LoRA MLP and attention custom
autograd functions. Our goal was to leverage operator fusion and tensor re-use in order
to improve speed and reduce memory usage during the forward and backward passes of
these calculations.
(in the DDP and DeepSpeed settings) training. These include (1) SwiGLU and GEGLU activation function
Triton kernels, and (2) LoRA MLP and attention custom autograd functions. Our goal was
to leverage operator fusion and tensor re-use in order to improve speed and reduce
memory usage during the forward and backward passes of these calculations.
We currently support several common model architectures, including (but not limited to):
@@ -89,10 +88,6 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
::: {.callout-warning}
LoRA kernels do not support remote modeling code.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
@@ -136,5 +131,6 @@ computation path.
## Future Work
- Support for additional model architectures
- Support for the FSDP setting
- Support for dropout and bias
- Additional operator fusions

View File

@@ -27,9 +27,3 @@ learning_rate: 2e-5
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
self attention `q_proj` module.
::: {.callout-note}
We currently only support varying `lr` for now. If you're interested in adding support for others (`weight_decay`), we welcome PRs. See https://github.com/axolotl-ai-cloud/axolotl/blob/613bcf90e58f3ab81d3827e7fc572319908db9fb/src/axolotl/core/trainers/mixins/optimizer.py#L17
:::

View File

@@ -4,7 +4,7 @@ format:
html:
toc: true
toc-depth: 3
# number-sections: true
number-sections: true
code-tools: true
execute:
enabled: false
@@ -14,18 +14,12 @@ This guide covers advanced training configurations for multi-GPU setups using Ax
## Overview {#sec-overview}
When training on multiple GPUs, Axolotl supports 3 sharding/parallelism strategies. Additionally, you can layer specific optimization features on top of that strategy.
Axolotl supports several methods for multi-GPU training:
You generally cannot combine these strategies; they are mutually exclusive.
1. **DeepSpeed**: Powerful optimization library, supports ZeRO stages 1-3.
2. **FSDP (Fully Sharded Data Parallel)**: PyTorch's native sharding implementation (Recommended).
3. **DDP (Distributed Data Parallel)**: PyTorch's native parallelism implementation (Default if neither of the above are selected).
These features can often be combined with the strategies above:
* **Sequence Parallelism**: Splits long sequences across GPUs (Compatible with DDP, DeepSpeed, and FSDP).
* **FSDP + QLoRA**: Combines 4-bit quantization with FSDP (Specific to FSDP).
- DeepSpeed (recommended)
- FSDP (Fully Sharded Data Parallel)
- Sequence parallelism
- FSDP + QLoRA
## DeepSpeed {#sec-deepspeed}
@@ -69,9 +63,16 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
:::
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
::: {.callout-tip}
FSDP allows you to shard model parameters, gradients, and optimizer states across data parallel workers.
Using ZeRO Stage 3 with Single-GPU training
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
:::
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
::: {.callout-note}
@@ -79,10 +80,6 @@ FSDP2 is recommended for new users. FSDP1 is deprecated and will be removed in a
:::
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
### Migrating from FSDP1 to FSDP2 {#sec-migrate-fsdp1-fsdp2}
To migrate your config from FSDP1 to FSDP2, you must use the `fsdp_version` top-level config field to specify the FSDP version, and
@@ -100,7 +97,6 @@ fsdp_sync_module_states | **REMOVED**
fsdp_cpu_ram_efficient_loading | cpu_ram_efficient_loading
fsdp_state_dict_type | state_dict_type
fsdp_use_orig_params | **REMOVED**
fsdp_activation_checkpointing | activation_checkpointing
For more details, please see the migration guide in the [torchtitan repo](https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). In Axolotl,
if you were using the following FSDP1 config:
@@ -157,6 +153,10 @@ single sequence causes OOM errors during model training.
See our [dedicated guide](sequence_parallelism.qmd) for more information.
### FSDP + QLoRA {#sec-fsdp-qlora}
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
## Performance Optimization {#sec-performance}
### Liger Kernel Integration {#sec-liger}

View File

@@ -13,18 +13,13 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Mistral-Small-4](#sec-mistral-small-4)
- [Magistral-Small-2509](#sec-magistral-small-2509)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Qwen3.5](#sec-qwen3-5)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl)
## Usage
@@ -46,6 +41,7 @@ datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
# (optional) if doing lora, only finetune the Language model,
# leave the vision model and vision tower frozen
@@ -60,14 +56,10 @@ image_resize_algorithm: bilinear
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
::: {.callout-tip}
::: {.callout-warning}
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
:::
::: {.callout-note}
As of now, we do not truncate nor drop samples based on `sequence_len` as each arch has different ways to process non-text tokens. We are looking for help on this.
:::
### Mllama {#sec-mllama}
```yaml
@@ -102,28 +94,10 @@ chat_template: llava
### Mistral-Small-3.1 {#sec-mistral-small-31}
::: {.callout-tip}
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
:::
```yaml
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
```
### Mistral-Small-4 {#sec-mistral-small-4}
```yaml
base_model: mistralai/Mistral-Small-4-119B-2603
```
### Magistral-Small-2509 {#sec-magistral-small-2509}
::: {.callout-tip}
Please make sure to install vision lib via `pip install 'mistral-common[opencv]==1.8.5'`
:::
```yaml
base_model: mistralai/Magistral-Small-2509
chat_template: mistral_v7_tekken
```
### Voxtral {#sec-voxtral}
@@ -134,8 +108,6 @@ Please make sure to install audio lib via `pip3 install librosa==0.11.0 'mistral
```yaml
base_model: mistralai/Voxtral-Mini-3B-2507
processor_type: VoxtralProcessor
```
### Gemma-3 {#sec-gemma-3}
@@ -184,34 +156,6 @@ base_model: Qwen/Qwen2.5-VL-7B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3-VL {#sec-qwen3-vl}
```yaml
base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3.5 {#sec-qwen3-5}
```yaml
base_model: Qwen/Qwen3.5-9B
chat_template: qwen3_5
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
```yaml
# GLM-4.6V (106B MoE version)
base_model: zai-org/GLM-4.6V
# OR GLM-4.6V-Flash (9B version)
base_model: zai-org/GLM-4.6V-Flash
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}
@@ -232,16 +176,6 @@ Please uninstall `causal-conv1d` via `pip3 uninstall -y causal-conv1d`
base_model: LiquidAI/LFM2-VL-450M
```
### Intern-VL {#sec-intern-vl}
::: {.callout-tip}
Please make sure to install `timm` via `pip3 install timm==1.0.19`
:::
```yaml
base_model: OpenGVLab/InternVL3_5-8B
```
## Dataset Format
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.

View File

@@ -1,156 +0,0 @@
---
title: Optimizations Guide
description: A guide to the performance and memory optimizations available in Axolotl.
---
Axolotl includes numerous optimizations to speed up training, reduce memory usage, and handle large models.
This guide provides a high-level overview and directs you to the detailed documentation for each feature.
## Speed Optimizations
These optimizations focus on increasing training throughput and reducing total training time.
### Sample Packing
Improves GPU utilization by combining multiple short sequences into a single packed sequence for training. This requires enabling one of the [attention](#attention-implementations) implementations below.
- **Config:** `sample_packing: true`
- **Learn more:** [Sample Packing](multipack.qmd)
### Attention Implementations
Using an optimized attention implementation is critical for training speed.
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
*Note: You should only enable one attention backend.*
### LoRA Optimizations
Leverages optimized kernels to accelerate LoRA training and reduce memory usage.
- **Learn more:** [LoRA Optimizations Documentation](lora_optims.qmd)
## Memory Optimizations
These techniques help you fit larger models or use bigger batch sizes on your existing hardware.
### Parameter Efficient Finetuning (LoRA & QLoRA)
Drastically reduces memory by training a small set of "adapter" parameters instead of the full model. This is the most common and effective memory-saving technique.
- Examples: Find configs with `lora` or `qlora` in the [examples directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-3).
- Config Reference: See `adapter`, `load_in_4bit`, and `load_in_8bit` in the [Configuration Reference](config-reference.qmd).
### Gradient Checkpointing & Activation Offloading
These techniques save VRAM by changing how activations are handled.
- Gradient Checkpointing: re-computes activations during the backward pass, trading compute time for VRAM.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
### Layer Offloading
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
- **Config:** `layer_offloading: true`
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
### Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
- **Learn more:** [Custom Integrations - CCE](custom_integrations.qmd#cut-cross-entropy)
### Liger Kernels
Provides efficient Triton kernels to improve training speed and reduce memory usage.
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
### Expert Kernels
Optimized kernel implementations for Mixture of Experts (MoE) model training.
- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support.
- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.
- **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration)
## Long Context Models
Techniques to train models on sequences longer than their original context window.
### RoPE Scaling
Extends a model's context window by interpolating its Rotary Position Embeddings.
- **Config:** Pass the `rope_scaling` config under the `overrides_of_model_config: `. To learn how to set RoPE, check the respective model config.
### Sequence Parallelism
Splits long sequences across multiple GPUs, enabling training with sequence lengths that would not fit on a single device.
- **Learn more:** [Sequence Parallelism Documentation](sequence_parallelism.qmd)
### Artic Long Sequence Training (ALST)
ALST is a recipe that combines several techniques to train long-context models efficiently. It typically involves:
- TiledMLP to reduce memory usage in MLP layers.
- Tiled Loss functions (like [CCE](#cut-cross-entropy-(cce) or [Liger](#liger-kernels)).
- Activation Offloading to CPU.
- Example: [ALST Example Configuration](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst)
## Large Models (Distributed Training)
To train models that don't fit on a single GPU, you'll need to use a distributed training strategy like FSDP or DeepSpeed. These frameworks shard the model weights, gradients, and optimizer states across multiple GPUs and nodes.
- **Learn more:** [Multi-GPU Guide](multi-gpu.qmd)
- **Learn more:** [Multi-Node Guide](multi-node.qmd)
### N-D Parallelism (Beta)
For advanced scaling, Axolotl allows you to compose different parallelism techniques (e.g., Data, Tensor, Sequence Parallelism). This is a powerful approach to train an extremely large model by overcoming multiple bottlenecks at once.
- **Learn more:** [N-D Parallelism Guide](nd_parallelism.qmd)
## Quantization
Techniques to reduce the precision of model weights for memory savings.
### 4-bit Training (QLoRA)
The recommended approach for quantization-based training. It loads the base model in 4-bit using `bitsandbytes` and then trains QLoRA adapters. See [Adapter Finetuning](#adapter-finetuning-lora-qlora) for details.
### FP8 Training
Enables training with 8-bit floating point precision on supported hardware (e.g., NVIDIA Hopper series GPUs) for significant speed and memory gains.
- **Example:** [Llama 3 FP8 FSDP Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-3/3b-fp8-fsdp2.yaml)
### Quantization Aware Training (QAT)
Simulates quantization effects during training, helping the model adapt and potentially improving the final accuracy of the quantized model.
- **Learn more:** [QAT Documentation](qat.qmd)
### GPTQ
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
### MoE Expert Quantization
Quantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused `nn.Parameter` tensors.
- **Config:** `quantize_moe_experts: true`
- **Learn more:** [MoE Expert Quantization](expert_quantization.qmd)

View File

@@ -23,18 +23,10 @@ To enable QAT in axolotl, add the following to your configuration file:
```yaml
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
```
We support the following quantization schemas:
- `Int4WeightOnly` (requires the `fbgemm-gpu` extra when installing Axolotl)
- `Int8DynamicActivationInt4Weight`
- `Float8DynamicActivationFloat8Weight`
- `Float8DynamicActivationInt4Weight`
- `NVFP4`
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.

View File

@@ -22,8 +22,8 @@ Quantization is configured using the `quantization` key in your configuration fi
```yaml
base_model: # The path to the model to quantize.
quantization:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4", "int8", "float8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4", "fp8", and "nvfp4".
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
@@ -39,8 +39,9 @@ you used to train the model:
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int4
weight_dtype: int8
group_size: 256
quantize_embedding: true
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```
@@ -50,11 +51,3 @@ axolotl quantize qat.yml
```
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
::: {.callout-note}
If you have configured pushing to hub with `hub_model_id`, your model hub name will have the quantization schema appended to it,
e.g. `axolotl-ai-cloud/qat-nvfp4-llama3B` will become `axolotl-ai-cloud/qat-nvfp4-llama3B-nvfp4w`
:::

View File

@@ -11,7 +11,6 @@ We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained using data which contains preference annotations for an entire interaction between the user and model (e.g. rather than per-turn or per-step).
For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).
```yaml
base_model: google/gemma-2-2b

View File

@@ -16,12 +16,7 @@ feedback. Various methods include, but not limited to:
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo) — see also the [GRPO deep dive](grpo.qmd) for async features, custom rewards, and scaling
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
- [Energy-Based Fine-Tuning (EBFT)](#ebft) — see also the [EBFT guide](ebft.qmd) for detailed mode comparisons and configuration
- [NeMo Gym Integration](#nemo-gym-integration)
For help choosing between these methods, see [Choosing a Fine-Tuning Method](choosing_method.qmd).
- [Group Relative Policy Optimization (GRPO)](#grpo)
## RLHF using Axolotl
@@ -224,21 +219,6 @@ DPO supports the following types with the following dataset format:
}
```
#### chat_template.argilla_chat
```json
{
"chosen": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
],
"rejected": [
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."}
]
}
```
#### chat_template.default
```yaml
@@ -517,7 +497,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO
::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code). For a comprehensive guide covering async training, custom rewards, importance sampling, and scaling, see the [GRPO deep dive](grpo.qmd).
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).
:::
In the latest GRPO implementation, `vLLM` is used to significantly speedup trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:
@@ -602,116 +582,6 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
#### OpenEnv Rollout Functions
GRPO supports custom rollout functions for OpenEnv-style environments, enabling interactive tasks like web browsing, code execution, or tool use. This allows you to implement custom generation logic that interacts with external environments.
For example, to implement a simple math-solving environment with step-by-step verification:
```python
# math_env.py
import re
def math_solver_rollout(model, processing_class, prompts, generation_config=None):
"""
Custom rollout function that generates step-by-step math solutions.
Args:
model: The language model
processing_class: The tokenizer/processing_class
prompts: List of prompt dicts (with 'messages' key for chat format)
generation_config: Optional generation configuration
Returns:
List of completion strings
"""
completions = []
for prompt in prompts:
# Apply chat template to prompt
messages = prompt.get("messages", [])
formatted_prompt = processing_class.apply_chat_template(
messages, processing_class=False, add_generation_prompt=True
)
# Generate step-by-step solution
full_response = ""
for step in range(5): # Max 5 reasoning steps
current_input = formatted_prompt + full_response + "\nNext step:"
inputs = processing_class(current_input, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=100,
generation_config=generation_config,
)
step_text = processing_class.decode(
outputs[0][inputs.input_ids.shape[1]:],
skip_special_tokens=True
)
# Check if solution is complete
if "FINAL ANSWER:" in step_text:
full_response += step_text
break
full_response += step_text + "\n"
completions.append(full_response)
return completions
def math_reward(prompts, completions, answers, **kwargs):
"""Reward function that checks mathematical correctness"""
rewards = []
for completion, correct_answer in zip(completions, answers):
# Extract predicted answer
match = re.search(r"FINAL ANSWER:\s*(.+)", completion)
predicted = match.group(1).strip() if match else ""
# Compare with correct answer
reward = 1.0 if predicted == str(correct_answer) else 0.0
rewards.append(reward)
return rewards
def math_transform(cfg, *args, **kwargs):
"""Transform dataset to GRPO format with answer field"""
def transform_fn(example, processing_class=None):
return {
"prompt": [{"role": "user", "content": example["question"]}],
"answer": str(example["answer"]),
}
return transform_fn, {"remove_columns": ["question"]}
```
```yaml
rl: grpo
trl:
beta: 0.001
max_completion_length: 512
num_generations: 4
rollout_func: "math_env.math_solver_rollout" # Custom rollout function
reward_funcs: ["math_env.math_reward"]
reward_weights: [1.0]
datasets:
- path: openai/gsm8k
name: main
type: math_env.math_transform
```
The `rollout_func` parameter accepts a fully qualified name (e.g., `module_name.function_name`) that points to a callable function in your local directory. The function receives:
- `model`: The language model
- `processing_class`: The tokenizer/processing class
- `prompts`: List of prompt dictionaries
- `generation_config` (optional): Generation configuration
And should return a list of completion strings.
For more OpenEnv examples, see [TRL OpenEnv Documentation](https://huggingface.co/docs/trl/main/en/openenv).
#### GRPO with DAPO/Dr. GRPO loss
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.
@@ -725,309 +595,6 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
#### Async GRPO
Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.
```yaml
trl:
use_data_producer: true # Enable data producer protocol
use_vllm: true
async_prefetch: true # Generate rollouts in background thread
prefetch_depth: 1 # Number of rollouts to prefetch
vllm_sync_interval: 2 # Sync weights to vLLM every N steps
```
::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
:::
##### vLLM LoRA Sync
By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.
```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
trl:
vllm_lora_sync: true # Enable native LoRA sync
```
When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:
```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```
Then start training on a separate GPU:
```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
::: {.callout-tip}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::
##### Streaming Partial Batch
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.
```yaml
trl:
streaming_partial_batch: true
```
##### Importance Sampling Correction
When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.
```yaml
trl:
vllm_importance_sampling_correction: true # Enable IS correction
importance_sampling_level: token # 'token' or 'sequence'
off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
```
- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
##### Replay Buffer
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.
```yaml
trl:
replay_buffer_size: 100 # Max cached groups (0 = disabled)
replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
```
::: {.callout-note}
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
:::
##### Deferred Re-rolling
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
```yaml
trl:
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
reroll_max_groups: 1 # Max groups to replace per batch
```
##### Zero-Advantage Batch Skipping
When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.
```yaml
trl:
skip_zero_advantage_batches: true # default
```
##### Parallel Reward Workers
Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
```yaml
trl:
reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
```
##### Full Async GRPO Example
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
gpu_memory_utilization: 0.35
dtype: auto
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
rl: grpo
trl:
use_data_producer: true
use_vllm: true
async_prefetch: true
prefetch_depth: 1
vllm_sync_interval: 2
vllm_lora_sync: true
streaming_partial_batch: true
vllm_importance_sampling_correction: true
off_policy_mask_threshold: 0.5
importance_sampling_level: token
num_generations: 8
max_completion_length: 512
reward_funcs:
- rewards.accuracy_reward
reroll_start_fraction: 0.5
replay_buffer_size: 100
reward_num_workers: 4
skip_zero_advantage_batches: true
datasets:
- path: AI-MO/NuminaMath-TIR
type: rewards.prompt_transform
split: train
gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
##### Multi-GPU Async GRPO
Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.
**FSDP:**
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
use_reentrant: false
```
**DeepSpeed ZeRO-3:**
```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
use_reentrant: true # Required for ZeRO-3
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 axolotl train config.yaml
```
::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
::: {.callout-tip}
Use GDPO when training with multiple reward functions. For single reward, GRPO and GDPO produce equivalent results.
:::
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
rl: gdpo
trl:
beta: 0.001
max_completion_length: 256
use_vllm: true
num_generations: 4
reward_funcs:
- rewards.format_reward
- rewards.correctness_reward
reward_weights: [1.0, 2.0]
datasets:
- path: openai/gsm8k
name: main
type: rewards.oai_gsm8k_transform
```
You can also use GRPO with explicit aggregation control:
```yaml
rl: grpo
trl:
multi_objective_aggregation: normalize_then_sum # GDPO behavior
# or: sum_then_normalize # Default GRPO behavior
```
#### GDPO vs GRPO
| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |
#### Why GDPO?
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
```
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```
GDPO normalizes each reward independently, preserving their relative differences.
#### Reward Functions
GDPO uses the same reward function format as GRPO:
```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
return [1.0 if len(c) > 10 else 0.0 for c in completions]
def correctness_reward(completions, answers, **kwargs) -> list[float]:
rewards = []
for completion, answer in zip(completions, answers):
# Your scoring logic here
rewards.append(score)
return rewards
```
#### Sequence Parallelism
GDPO supports sequence parallelism for long-context training:
```yaml
rl: gdpo
context_parallel_size: 2
```
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
@@ -1041,306 +608,6 @@ simpo_gamma: 0.5 # default in CPOTrainer
This method uses the same dataset format as [DPO](#dpo).
### EBFT {#ebft}
::: {.callout-tip}
For a detailed guide on EBFT modes, feature extraction, and configuration, see the [EBFT guide](ebft.qmd).
:::
EBFT (Energy-Based Fine-Tuning) fine-tunes language models by optimizing a **feature-matching loss** rather than relying on external reward functions. A frozen copy of the model extracts embeddings from both generated and ground-truth completions, and the generator is updated via REINFORCE to match the ground-truth feature moments.
Paper: ["Matching Features, Not Tokens: Energy-Based Fine-Tuning of Language Models"](https://arxiv.org/abs/2603.12248) (Jelassi et al., 2026)
**Key advantages:**
- No reward model or verifier required — works on any (prompt, completion) data
- Applicable to non-verifiable tasks (code, translation, creative writing)
- Operates on model rollouts (not teacher forcing), reducing distribution shift
EBFT supports two modes:
- **Structured mode**: For QA/instruction data with prompt + completion pairs. Uses vLLM for generation (like GRPO).
- **Strided mode**: For unstructured text without prompt/completion splits. Uses strided block-parallel generation with flex_attention — no vLLM needed.
#### Structured Mode
```yaml
base_model: Qwen/Qwen3-4B
rl: ebft
ebft:
feature_layers: [0.25, 0.5, 0.75] # Extract features at 25%, 50%, 75% depth
embed_method: last_token
use_whitening: false
alignment_coef: 1.0 # Cosine similarity reward weight
diversity_coef: 1.0 # Pairwise dot product penalty
ce_coef: 0.0 # Cross-entropy on GT tokens (0 = off)
trl:
num_generations: 4
max_completion_length: 256
temperature: 0.7
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_lora_sync: true # LoRA adapter sync (recommended)
vllm_sync_interval: 3
use_data_producer: true
async_prefetch: true # Set false for sync mode
scale_rewards: true
loss_type: grpo
epsilon: 0.2
vllm:
gpu_memory_utilization: 0.5
max_model_len: 2048
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_opencode.transform
split: train[:500]
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_linear: true
```
```bash
# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
#### Strided Mode
For unstructured text (raw code, prose). No vLLM needed — runs on a single GPU.
```yaml
base_model: meta-llama/Llama-3.2-1B
rl: ebft
ebft:
mode: strided
stride: 8
context_length: 8
generate_max_len: 8
n_samples_per_prompt: 4
temperature: 0.6
feature_layers: [0.25, 0.5, 0.75]
embed_method: last_token
use_whitening: true
alignment_coef: 1.0
diversity_coef: 1.0
rl_coef: 1.0
ce_coef: 0.03
advantage_estimator: rloo
datasets:
- path: nvidia/OpenCodeInstruct
type: ebft_strided_structured.transform
split: train[:1%]
flash_attention: false
flex_attention: true # Strided mode uses flex_attention
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true # Required for flex_attention
```
```bash
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```
::: {.callout-tip}
See `examples/ebft/` for complete example configs covering Llama 1B/3B/8B and Qwen3 4B/8B models in both modes.
:::
#### EBFT Configuration Reference
| Parameter | Default | Description |
|-----------|---------|-------------|
| `ebft.feature_layers` | `[0.25, 0.5, 0.75]` | Layer depths for feature extraction (fractional) |
| `ebft.embed_method` | `last_token` | Feature pooling: `last_token`, `mean_pooling`, `concat` |
| `ebft.use_whitening` | `false` | SVD whitening of feature dimensions |
| `ebft.alignment_coef` | `1.0` | Cosine similarity reward weight |
| `ebft.diversity_coef` | `1.0` | Pairwise dot product penalty weight |
| `ebft.ce_coef` | `0.0` | Cross-entropy loss on ground-truth tokens |
| `ebft.mode` | `structured` | `structured` (vLLM) or `strided` (no vLLM) |
| `ebft.stride` | — | Tokens between anchor points (strided mode) |
| `ebft.context_length` | — | Context window per block (strided mode) |
| `ebft.generate_max_len` | — | Tokens to generate per block (strided mode) |
| `ebft.n_samples_per_prompt` | — | Rollouts per document (strided mode) |
| `ebft.advantage_estimator` | `grpo` | `grpo` or `rloo` (strided mode) |
### NeMo Gym Integration
[NeMo Gym](https://github.com/NVIDIA-NeMo/Gym) provides 50+ verified RL environments (math, coding, tool-use, reasoning) with deterministic reward signals. The axolotl integration supports both **single-turn** (call `/verify` after generation) and **multi-turn** (agent-based tool execution via `/run`).
#### Single-Turn (Simplest)
For environments that only need answer verification (math, coding challenges). No agent server needed — the reward function calls `/verify` directly on the resource server.
```yaml
base_model: Qwen/Qwen2.5-0.5B-Instruct
rl: grpo
chat_template: tokenizer_default
trl:
use_vllm: false # Colocate mode (single GPU)
num_generations: 4
max_completion_length: 128
temperature: 0.9
reward_funcs:
- axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify
plugins:
- axolotl.integrations.nemo_gym.NemoGymPlugin
nemo_gym_enabled: true
nemo_gym_dir: ~/Gym
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_datasets:
- path: resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
server_name: reasoning_gym
datasets:
- path: ~/Gym/resources_servers/reasoning_gym/data/train_basic_arithmetic.jsonl
type: chat_template
field_messages: responses_create_params.input
message_field_content: content
message_field_role: role
```
```bash
# Terminal 1: Start NeMo Gym resource server
cd ~/Gym && .venv/bin/ng_run \
"+config_paths=[resources_servers/reasoning_gym/configs/resources_only.yaml]" \
"+skip_venv_if_present=true"
# Terminal 2: Train
CUDA_VISIBLE_DEVICES=0 axolotl train config.yaml
```
::: {.callout-note}
`nemo_gym_datasets.path` is relative to `nemo_gym_dir`. Don't use absolute paths or they will be double-joined.
:::
#### Multi-Turn with Async GRPO (Recommended)
For environments with tool-use (weather, search, databases). An agent server orchestrates multi-turn interactions: generate → parse tool calls → execute tools → feed results back → repeat until done.
```yaml
base_model: Qwen/Qwen3-0.6B
rl: grpo
chat_template: tokenizer_default
adapter: lora
lora_r: 16
lora_alpha: 32
lora_target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
trl:
use_vllm: true
vllm_mode: server
vllm_server_host: localhost
vllm_server_port: 8000
vllm_lora_sync: true
vllm_sync_interval: 5
use_data_producer: true
async_prefetch: true # 3x speedup
num_generations: 4
max_completion_length: 512
temperature: 0.8
reward_funcs:
- axolotl.integrations.nemo_gym.rewards.reward_env
plugins:
- axolotl.integrations.nemo_gym.NemoGymPlugin
nemo_gym_enabled: true
nemo_gym_auto_start: false
nemo_gym_head_port: 11000
nemo_gym_multi_turn: true
nemo_gym_verify_timeout: 120
nemo_gym_datasets:
- path: resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
server_name: example_single_tool_call
datasets:
- path: ~/Gym/resources_servers/example_single_tool_call/data/weather_tool_calling.jsonl
type: chat_template
field_messages: responses_create_params.input
message_field_content: content
message_field_role: role
vllm:
gpu_memory_utilization: 0.85
max_model_len: 2048
```
Multi-turn requires three services running:
```bash
# Terminal 1: vLLM with LoRA + tool calling
VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 CUDA_VISIBLE_DEVICES=0 \
python -m vllm.entrypoints.openai.api_server \
--model Qwen/Qwen3-0.6B --max-model-len 2048 \
--gpu-memory-utilization 0.85 \
--enable-lora --max-lora-rank 64 \
--enable-auto-tool-choice --tool-call-parser hermes
# Terminal 2: NeMo Gym servers (resource + model proxy + agent)
cd ~/Gym && .venv/bin/ng_run \
"+config_paths=[configs/axolotl_tool_calling.yaml]" \
"+skip_venv_if_present=true"
# Terminal 3: Training
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
::: {.callout-important}
Multi-turn requires a NeMo Gym agent config YAML that defines three components: a resource server (tools + `/verify`), a model server proxy (forwards to your vLLM), and an agent server (orchestrates `/run`). See the [NeMo Gym README](https://github.com/NVIDIA-NeMo/Gym) for agent config format.
:::
#### NeMo Gym Prerequisites
```bash
# Clone and set up NeMo Gym
git clone https://github.com/NVIDIA-NeMo/Gym.git ~/Gym
cd ~/Gym
uv venv --python 3.12 && source .venv/bin/activate && uv sync
# Fix pycosat build (GCC 13+)
CFLAGS="" uv pip install pycosat --python .venv/bin/python --no-build-isolation
```
#### NeMo Gym Configuration Reference
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nemo_gym_enabled` | bool | — | Enable the NeMo Gym integration |
| `nemo_gym_dir` | str | `~/Gym` | Path to NeMo Gym repo |
| `nemo_gym_auto_start` | bool | `true` | Auto-start resource servers |
| `nemo_gym_head_port` | int | `11000` | Head server port |
| `nemo_gym_multi_turn` | bool | `false` | Enable multi-turn via agent `/run` |
| `nemo_gym_verify_timeout` | int | `30` | Per-request timeout (seconds) |
| `nemo_gym_datasets` | list | required | Dataset configs with `path` and `server_name` |
#### Reward Functions
| Function | Mode | Description |
|----------|------|-------------|
| `axolotl.integrations.nemo_gym.rewards.reward_nemo_gym_verify` | Single-turn | Calls `/verify`, returns binary reward |
| `axolotl.integrations.nemo_gym.rewards.reward_env` | Multi-turn | Passthrough reward from agent `/run` |
### Using local dataset files
```yaml

View File

@@ -1,90 +0,0 @@
examples:
# December 2025
- name: kimi-linear
title: Kimi Linear
- name: plano
title: Plano Orchestrator
- name: mimo
title: MiMo
- name: internvl3_5
title: InternVL 3.5
# AllenAI
- name: olmo3
title: OLMo 3
# ArceeAI
- name: trinity
title: Trinity
- name: arcee
title: Arcee AFM
# MistralAI
- name: ministral3/think
title: Ministral 3 Thinking
- name: ministral3/vision
title: Ministral 3 Vision
- name: magistral/think
title: Magistral Thinking
- name: magistral/vision
title: Magistral Vision
- name: ministral
title: Ministral
- name: mistral-small
title: Mistral Small 3.1/3.2
- name: voxtral
title: Voxtral
- name: devstral
title: Devstral
- name: mistral
title: Mistral 7B
# Meta
- name: llama-4
title: Llama 4
- name: llama-2
title: Llama 2
# Alibaba
- name: qwen3-next
title: Qwen 3 Next
- name: qwen3
title: Qwen 3
# Google
- name: gemma3n
title: Gemma 3n
# Swiss AI
- name: apertus
title: Apertus
# GPT-OSS
- name: gpt-oss
title: GPT-OSS
- name: seed-oss
title: Seed-OSS
# Microsoft
- name: phi
title: Phi
# SmolVLM
- name: smolvlm2
title: SmolVLM 2
# IBM
- name: granite4
title: Granite 4
# LiquidAI
- name: LiquidAI
title: Liquid Foundation Models 2
# Other
- name: hunyuan
title: Hunyuan
- name: jamba
title: Jamba
- name: orpheus
title: Orpheus

View File

@@ -47,6 +47,7 @@ class QuartoGenerator:
"""Check if a type is a Pydantic BaseModel."""
return inspect.isclass(type_obj) and issubclass(type_obj, BaseModel)
# pylint: disable=too-many-return-statements
def _extract_nested_type(self, field_type) -> Any:
"""Extract the actual type from complex type annotations."""
# Handle Annotated types (Python 3.9+)
@@ -123,6 +124,7 @@ class QuartoGenerator:
return field_type
# pylint: disable=too-many-return-statements
def _extract_all_pydantic_models_from_type(
self, field_type
) -> list[type[BaseModel]]:
@@ -316,6 +318,7 @@ class QuartoGenerator:
return all_groups
# pylint: disable=too-many-return-statements
def _extract_field_groups_from_source(
self, model_class: type[BaseModel]
) -> list[dict]:
@@ -500,7 +503,7 @@ class QuartoGenerator:
nested_schema = nested_model.model_json_schema()
nested_properties = nested_schema.get("properties", {})
nested_required = nested_schema.get("required", [])
except Exception:
except Exception: # pylint: disable=broad-exception-caught
# Fallback: use model fields directly
nested_properties = {}
nested_required = []
@@ -604,7 +607,7 @@ class QuartoGenerator:
schema = model_class.model_json_schema()
properties = schema.get("properties", {})
required = schema.get("required", [])
except Exception as e:
except Exception as e: # pylint: disable=broad-exception-caught
print(
f"Warning: Could not generate JSON schema ({e}). Using model fields instead."
)

View File

@@ -1,424 +0,0 @@
"""
auto generate example docs from allowlist
"""
import re
import shutil
import sys
from pathlib import Path
import yaml
# Paths
THIS = Path(__file__).resolve()
ROOT = THIS.parents[2] # repo root (docs/scripts -> docs -> ROOT)
EXAMPLES_DIR = ROOT / "examples"
OUTPUT_DIR = ROOT / "docs" / "models"
ALLOWLIST_YML = THIS.parent / "examples-allowlist.yml"
def slugify(name: str) -> str:
"""Convert a name to a slug (lowercase, hyphens for spaces)."""
s = re.sub(r"[^a-zA-Z0-9\s\-]+", "", name.strip())
s = re.sub(r"\s+", "-", s).strip("-").lower()
return s or "example"
def read_allowlist():
with open(ALLOWLIST_YML, "r", encoding="utf-8") as f:
data = yaml.safe_load(f) or {}
items = data.get("examples", [])
if not isinstance(items, list):
raise ValueError("`examples` must be a list in examples-allowlist.yml")
return items
def find_readme(folder: Path) -> Path | None:
for name in ("README.md", "Readme.md", "readme.md"):
p = folder / name
if p.exists():
return p
return None
def remove_first_h1(md: str) -> tuple[str, str | None]:
"""
Remove the first H1 from markdown and return (modified_md, h1_title).
The H1 is removed since we use the frontmatter title instead.
"""
lines = md.splitlines()
result = []
h1_title = None
skipped_first = False
for line in lines:
if not skipped_first and line.startswith("# "):
h1_title = line[2:].strip()
skipped_first = True
continue
result.append(line)
return "\n".join(result), h1_title
IMG_RE = re.compile(r"!\[[^\]]*\]\(([^)]+)\)")
LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
def rewrite_and_copy_assets(md: str, src_dir: Path, dest_assets_root: Path) -> str:
"""
Copy local image assets referenced in markdown to
docs/examples/assets/... and rewrite the links.
"""
dest_assets = dest_assets_root / "assets"
def repl(m):
url = m.group(1).strip()
if re.match(r"^(https?:)?//", url):
return m.group(0) # leave remote URLs
src_path = (src_dir / url).resolve()
if not src_path.exists():
return m.group(0) # leave as-is if not found
rel = src_path.relative_to(src_dir)
# Create a unique asset path based on source directory name
asset_name = src_dir.name.replace("/", "-")
dest_path = dest_assets / asset_name / rel
dest_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src_path, dest_path)
new_rel = f"assets/{asset_name}/{rel.as_posix()}"
return m.group(0).replace(url, new_rel)
return IMG_RE.sub(repl, md)
def rewrite_readme_links(
md: str,
src_dir: Path,
examples_dir: Path,
parent_index_only: set,
current_src_path: str,
allowlist_entries: set,
current_output_path: str,
) -> str:
"""
Rewrite links between README.md files to point to the correct .qmd files.
"""
def repl(m):
text = m.group(1)
url = m.group(2).strip()
# Skip remote URLs and anchor links
if re.match(r"^(https?:)?//", url) or url.startswith("#"):
return m.group(0)
# Skip non-markdown files
if not url.lower().endswith(".md"):
return m.group(0)
# Resolve the target path
try:
target_path = (src_dir / url).resolve()
# Check if target is outside examples_dir
try:
rel_path = target_path.relative_to(examples_dir)
except ValueError:
# Target is outside examples_dir, leave as-is
return m.group(0)
parts = list(rel_path.parts)
# Determine the output path for the target
if len(parts) > 0 and parts[-1].lower() in ("readme.md", "readme"):
# This is a README link
if len(parts) == 1:
# Link to root README -> index.qmd
target_output = "index.qmd"
elif len(parts) == 2:
if parts[0] == ".":
# Current directory README
target_output = "index.qmd"
else:
# subdir/README.md
parent_dir = parts[0]
if parent_dir in parent_index_only:
target_output = f"{parent_dir}/index.qmd"
else:
target_output = f"{parent_dir}.qmd"
else:
# Deeper nesting: parent/subdir/README.md
# Build the full path like "parent/subdir"
full_path = "/".join(parts[:-1]) # Remove README.md
# Check if this exact path is in allowlist
if full_path in allowlist_entries:
# This is a sub-entry with its own entry -> use .qmd
target_output = f"{full_path}.qmd"
elif parts[0] == ".":
# ./subdir/README.md -> check if subdir has own entry
subdir = parts[1]
if subdir in parent_index_only:
target_output = f"{subdir}/index.qmd"
else:
target_output = f"{subdir}.qmd"
else:
# parent/subdir where parent doesn't have own entry
target_output = f"{full_path}/index.qmd"
else:
# Regular .md file -> convert to .qmd, keep path structure
target_output = "/".join(parts)[:-2] + "qmd"
# Compute relative path from current output file to target
current_parts = current_output_path.split("/")
target_parts = target_output.split("/")
# Special case: if current is a subdir file and target is a single-component file at root
# Example: current="magistral/vision", target="magistral.qmd"
if len(current_parts) > 1 and len(target_parts) == 1:
# Current is in subdir, target is at root level
# Go up to root: ../ for each level
up_count = len(current_parts) - 1
rel_parts = [".."] * up_count + [target_parts[0]]
new_url = "/".join(rel_parts)
else:
# Find common prefix
i = 0
while (
i < min(len(current_parts) - 1, len(target_parts))
and current_parts[i] == target_parts[i]
):
i += 1
# Build relative path: go up (../) then down to target
up_count = len(current_parts) - 1 - i
rel_parts = [".."] * up_count + target_parts[i:]
if not rel_parts or rel_parts == [".."]:
# Points to same directory or parent
new_url = "/".join(rel_parts) if rel_parts else "."
else:
new_url = "/".join(rel_parts)
return f"[{text}]({new_url})"
except (ValueError, IndexError):
return m.group(0)
return LINK_RE.sub(repl, md)
def write_qmd(out_path: Path, title: str, body_md: str):
out_path.parent.mkdir(parents=True, exist_ok=True)
fm = f"---\ntitle: {title!r}\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
out_path.write_text(fm + body_md, encoding="utf-8")
def update_quarto_yml(generated: list[tuple[str, str, str]]):
"""
Update _quarto.yml with the generated example files in the correct order.
This keeps the sidebar in sync with the allowlist.
Model Guides is now nested under "Getting Started" section.
Creates nested sections for models with sub-entries (e.g., magistral, ministral3).
Parent pages are now flat files (e.g., ministral3.qmd) with sub-pages in subdirs.
"""
quarto_yml = ROOT / "_quarto.yml"
if not quarto_yml.exists():
print(f"[WARN] {quarto_yml} not found, skipping update", file=sys.stderr)
return
content = quarto_yml.read_text(encoding="utf-8")
# First pass: find all parents that have sub-entries
parents_with_subs = set()
for path, _name, _title in generated:
if "/" in path:
parent = path.split("/")[0]
parents_with_subs.add(parent)
# Build the YAML contents while preserving allowlist order
lines = []
processed_sections = set()
for path, _name, title in generated:
# Check if this is a parent page that has sub-pages
if path in parents_with_subs:
# This is a parent page with sub-pages - create a nested section
if path not in processed_sections:
processed_sections.add(path)
section_title = (
title or path.replace("-", " ").replace("_", " ").title()
)
lines.append(f' - section: "{section_title}"')
lines.append(" contents:")
# Add the parent page first
lines.append(f" - docs/models/{path}.qmd")
# Then add all sub-pages
for sub_path, _sub_name, _sub_title in generated:
if "/" in sub_path and sub_path.split("/")[0] == path:
lines.append(
f" - docs/models/{sub_path}.qmd"
)
elif "/" not in path:
# This is a flat item with no sub-pages
# Skip if it was already included as part of a parent section
if path not in processed_sections:
lines.append(f" - docs/models/{path}.qmd")
yaml_content = "\n".join(lines) + "\n"
# Pattern to match only the Model Guides contents, stopping at the next item
# in Getting Started (lines starting with 12 spaces: same level as the section)
pattern = r'( - section: "Model Guides"\n contents:)([^\n]*|.*?)(?=\n - |\n - section:|\n\nformat:)'
def replacement(match):
prefix = match.group(1)
return prefix + "\n" + yaml_content
new_content = re.sub(pattern, replacement, content, flags=re.DOTALL)
if new_content != content:
quarto_yml.write_text(new_content, encoding="utf-8")
print(f"Updated {quarto_yml}")
else:
print(f"No changes needed for {quarto_yml}")
def main():
allow = read_allowlist()
if not EXAMPLES_DIR.exists():
print(f"[WARN] {EXAMPLES_DIR} not found", file=sys.stderr)
return
(OUTPUT_DIR / "assets").mkdir(parents=True, exist_ok=True)
# First pass: identify which parents have their own entry vs only sub-entries
parent_entries = set() # Parents that have their own entry
parent_with_subs = set() # Parents that have sub-entries
allowlist_entries = set() # All entries in allowlist
for item in allow:
if isinstance(item, str):
name = item
else:
name = item.get("name")
allowlist_entries.add(name)
if "/" in name:
parent = name.split("/")[0]
parent_with_subs.add(parent)
else:
parent_entries.add(name)
# Parents with subs that DON'T have their own entry -> use index.qmd
parent_index_only = parent_with_subs - parent_entries
generated = []
seen_dirs = set() # Track which parent directories we've created index for
for item in allow:
if isinstance(item, str):
name = item
title = None
else:
name = item.get("name")
title = item.get("title")
if not name:
print(f"[WARN] Skipping item without name: {item}", file=sys.stderr)
continue
src_dir = EXAMPLES_DIR / name
if not src_dir.exists() or not src_dir.is_dir():
print(f"[WARN] Skipping {name} (not a directory)", file=sys.stderr)
continue
readme = find_readme(src_dir)
if not readme:
print(f"[WARN] Skipping {name} (no README.md)", file=sys.stderr)
continue
md = readme.read_text(encoding="utf-8")
# Determine output path first (needed for link rewriting)
parts = name.split("/")
if len(parts) == 1:
# Simple case: no subdirectory
out_path = OUTPUT_DIR / f"{parts[0]}.qmd"
sidebar_path = parts[0]
else:
# Has subdirectory: e.g., magistral/think
parent = parts[0]
child = "-".join(parts[1:]) # handle nested subdirs
out_path = OUTPUT_DIR / parent / f"{child}.qmd"
sidebar_path = f"{parent}/{child}"
# Remove the first H1 (we use frontmatter title instead)
md, _ = remove_first_h1(md)
# Rewrite links between README files
md = rewrite_readme_links(
md,
src_dir,
EXAMPLES_DIR,
parent_index_only,
name,
allowlist_entries,
sidebar_path,
)
md = rewrite_and_copy_assets(md, src_dir, OUTPUT_DIR)
# Handle parent page generation for sub-entries
if len(parts) > 1:
# Has subdirectory: e.g., magistral/think
parent = parts[0]
# Create parent.qmd if not already done and parent doesn't have own entry
if parent not in seen_dirs and parent in parent_index_only:
parent_readme = find_readme(EXAMPLES_DIR / parent)
if parent_readme:
parent_md = parent_readme.read_text(encoding="utf-8")
parent_md, _ = remove_first_h1(parent_md)
parent_md = rewrite_readme_links(
parent_md,
EXAMPLES_DIR / parent,
EXAMPLES_DIR,
parent_index_only,
parent,
allowlist_entries,
parent,
)
parent_md = rewrite_and_copy_assets(
parent_md, EXAMPLES_DIR / parent, OUTPUT_DIR
)
parent_title = parent.replace("-", " ").replace("_", " ").title()
write_qmd(OUTPUT_DIR / f"{parent}.qmd", parent_title, parent_md)
generated.append((parent, parent, parent_title))
seen_dirs.add(parent)
if not title:
title = name.replace("/", " ").replace("-", " ").title()
write_qmd(out_path, title, md)
generated.append((sidebar_path, name, title))
# Index page - preserve allowlist order
if generated:
listing = "\n".join(
[f"- [{title}]({path}.qmd)" for path, name, title in generated]
)
index_md = (
"# Model Guides\n\nBelow are the curated examples for training various model architectures:\n\n"
+ listing
+ "\n"
)
index_fm = (
"---\nexecute:\n eval: false\nformat:\n html:\n toc: true\n---\n\n"
)
(OUTPUT_DIR / "index.qmd").write_text(index_fm + index_md, encoding="utf-8")
# Auto-update _quarto.yml to keep sidebar in sync
update_quarto_yml(generated)
if __name__ == "__main__":
main()

View File

@@ -1,120 +0,0 @@
---
title: Streaming Datasets
description: How to use streaming mode for large-scale datasets and memory-efficient training
order: 10
---
Streaming enables memory-efficient training with large datasets by loading data
incrementally rather than loading the entire dataset into memory at once.
Use streaming when:
- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora)
- You want to start training immediately without preprocessing the entire dataset
Streaming works with both remote and locally stored datasets!
::: {.callout-note}
Streaming currently only supports a single dataset. Multi-dataset support will be added soon.
:::
## Configuration
### Basic Streaming
Enable streaming mode by setting the `streaming` flag:
```yaml
streaming: true
```
### Pretraining with Streaming
For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:
```yaml
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
### SFT with Streaming
For supervised fine-tuning with streaming:
```yaml
streaming: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
## Configuration Options
### `streaming_multipack_buffer_size`
Controls the buffer size for multipack streaming (default: 10,000). This determines how
many samples are buffered before packing. Larger buffers can improve packing efficiency
but use more memory.
### `shuffle_merged_datasets`
When enabled, shuffles the streaming dataset using the buffer. This requires additional
memory for the shuffle buffer.
## Sample Packing with Streaming
Sample packing is supported for streaming datasets. When enabled, multiple samples are
packed into a single sequence to maximize GPU utilization:
```yaml
sample_packing: true
streaming_multipack_buffer_size: 10000
# For SFT: attention is automatically isolated between packed samples
# For pretraining: control with pretrain_multipack_attn
pretrain_multipack_attn: true # prevent cross-attention between packed samples
```
For more information, see our [documentation](multipack.qmd) on multipacking.
## Important Considerations
### Memory Usage
While streaming reduces memory usage compared to loading entire datasets, you still need
to consider:
- You can control the memory usage by adjusting `streaming_multipack_buffer_size`
- Sample packing requires buffering multiple samples
- Shuffling requires additional memory for the shuffle buffer
### Performance
- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly
- Network speed and disk read speed are important when streaming from remote sources or a local dataset, respectively
- Consider using `axolotl preprocess` for smaller or more frequently used datasets
### Evaluation Datasets
Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
loaded normally even when training uses streaming.
## Examples
See the `examples/streaming/` directory for complete configuration examples:
- `pretrain.yaml`: Pretraining with streaming dataset
- `sft.yaml`: Supervised fine-tuning with streaming

View File

@@ -1,61 +0,0 @@
---
title: Telemetry
description: A description of the telemetry implementation in Axolotl.
---
# Telemetry in Axolotl
Axolotl implements anonymous telemetry to help maintainers understand how the library
is used and where users encounter issues. This data helps prioritize features, optimize
performance, and fix bugs.
## Data Collection
We collect:
- System info: OS, Python version, Axolotl version, PyTorch version, Transformers
version, etc.
- Hardware info: CPU count, memory, GPU count and models
- Runtime metrics: Training progress, memory usage, timing information
- Usage patterns: Models (from a whitelist) and configurations used
- Error tracking: Stack traces and error messages (sanitized to remove personal
information)
Personally identifiable information (PII) is not collected.
## Implementation
Telemetry is implemented using PostHog and consists of:
- `axolotl.telemetry.TelemetryManager`: A singleton class that initializes the
telemetry system and provides methods for tracking events.
- `axolotl.telemetry.errors.send_errors`: A decorator that captures exceptions and
sends sanitized stack traces.
- `axolotl.telemetry.runtime_metrics.RuntimeMetricsTracker`: A class that tracks
runtime metrics during training.
- `axolotl.telemetry.callbacks.TelemetryCallback`: A Trainer callback that sends
runtime metrics telemetry.
The telemetry system will block training startup for 10 seconds to ensure users are
aware of data collection, unless telemetry is explicitly enabled or disabled.
## Opt-Out Mechanism
Telemetry is **enabled by default** on an opt-out basis. To disable it, set
`AXOLOTL_DO_NOT_TRACK=1` or `DO_NOT_TRACK=1`.
A warning message will be logged on start to clearly inform users about telemetry.
We will remove this after some period.
To hide the warning message about telemetry that is displayed on train, etc. startup,
explicitly set: `AXOLOTL_DO_NOT_TRACK=0` (enable telemetry) or `AXOLOTL_DO_NOT_TRACK=1`
(explicitly disable telemetry).
## Privacy
- All path-like config information is automatically redacted from telemetry data
- Model information is only collected for whitelisted organizations
- See `axolotl/telemetry/whitelist.yaml` for the set of whitelisted organizations
- Each run generates a unique anonymous ID
- This allows us to link different telemetry events in a single same training run
- Telemetry is only sent from the main process to avoid duplicate events

View File

@@ -1,399 +0,0 @@
---
title: "Training Stability & Debugging"
order: 15
description: "Guide to monitoring, debugging, and stabilizing training runs in axolotl"
---
This guide covers practical techniques for monitoring training health, diagnosing instability, and resolving common failures in both supervised fine-tuning (SFT) and reinforcement learning (GRPO/EBFT) workflows.
## Monitoring Training
### Key Metrics for SFT
Every SFT run should be monitored through at least these four metrics:
| Metric | What It Tells You | Healthy Range |
|--------|-------------------|---------------|
| `train/loss` | How well the model fits training data | Decreasing; typically 0.5--2.0 for chat fine-tuning |
| `eval/loss` | Generalization performance | Tracks train loss with small gap; divergence signals overfitting |
| `grad_norm` | Gradient magnitude | 0.1--10.0; spikes above 100 indicate instability |
| `learning_rate` | Current LR from scheduler | Should follow expected schedule (warmup then decay) |
::: {.callout-tip}
## Set Up Logging Early
Enable W&B or TensorBoard from the start. Debugging a failed run without metrics is guesswork.
```yaml
wandb_project: my-project
wandb_run_id: # optional, for resuming
logging_steps: 1
```
:::
### Key Metrics for RL (GRPO)
GRPO training logs a richer set of metrics. These are the critical ones:
| Metric | Healthy Range | Red Flag |
|--------|---------------|----------|
| `rewards/<name>/mean` | > 0.15 within 20 steps | Stays at 0 -- reward function is broken or task is too hard |
| `reward_std` | > 0 on most steps | Always 0 -- no learning signal (all completions get the same reward) |
| `frac_reward_zero_std` | < 0.8 | 1.0 on every step -- zero-advantage skip fires constantly, no gradient updates |
| `grad_norm` | 0.001--1.0 | 0.0 is acceptable occasionally (zero-adv skip); > 10.0 is unstable |
| `entropy` | 0.05--0.5 | < 0.01 suggests mode collapse; > 1.0 suggests the model is not converging |
| `kl` | 0.0--0.5 | > 2.0 suggests policy has diverged too far from reference |
| `sampling/sampling_logp_difference/mean` | < 0.1 | > 1.0 means policy has diverged far from vLLM server weights |
| `sampling/importance_sampling_ratio/min` | > 0.1 | Near 0 indicates stale off-policy data; increase `vllm_sync_interval` |
| `clip_ratio/region_mean` | < 0.1 | > 0.3 means PPO clipping is too aggressive |
| `completions/mean_length` | Task-dependent | Monotonically increasing to max length suggests reward hacking |
| `completions/clipped_ratio` | < 0.3 | > 0.8 means most completions hit `max_completion_length` -- increase it |
::: {.callout-note}
## EBFT-Specific Metrics
For EBFT training, also monitor `ebft/alignment` (should trend upward, healthy 0.3--0.9), `ebft/diversity` (healthy 0.01--0.1; > 1.0 indicates mode collapse), and `ebft/cfm_loss` (should trend downward, < 10).
:::
## SFT Stability
### Loss Plateau
**Symptom**: Loss stops decreasing early in training, well above expected values.
**Causes and fixes**:
- **Learning rate too low**: Increase by 2--5x. Typical ranges: full fine-tune 1e-5 to 5e-5, LoRA 1e-4 to 3e-4.
- **Insufficient warmup**: Set `warmup_steps` to 5--10% of total steps. Too-aggressive learning at the start can push the model into a flat region.
- **Data quality**: Check that labels are correctly masked. Use `axolotl preprocess` and inspect tokenized samples to confirm only the target tokens are trainable.
- **Weight decay too high**: Default 0.01 is usually fine. Values above 0.1 can suppress learning in LoRA.
### Loss Spikes
**Symptom**: Loss suddenly jumps by 2--10x then (possibly) recovers.
**Causes and fixes**:
- **Bad data samples**: A single malformed or extremely long example can cause a spike. Enable `sample_packing: false` temporarily and check if spikes correlate with specific batches.
- **Learning rate too high**: Reduce by 2--5x, or increase warmup.
- **Gradient accumulation mismatch**: Effective batch size = `micro_batch_size * gradient_accumulation_steps * num_gpus`. Very large effective batch sizes amplify gradient noise.
- **Mixed precision issues**: With `bf16: true`, some operations can lose precision. If spikes are severe, try `fp32` for diagnosis.
### Overfitting
**Symptom**: Train loss keeps decreasing but eval loss starts increasing.
**Fixes**:
- Increase `val_set_size` (e.g., 0.05) and monitor `eval/loss`.
- Reduce `num_epochs` or `max_steps`.
- Increase `weight_decay` (try 0.01--0.1).
- Use a smaller LoRA rank (`lora_r`). Typical values: 8--32.
- Increase dropout: `lora_dropout: 0.05`.
## RL/GRPO Stability
### Reward Never Increases
If `rewards/*/mean` stays at 0 for more than 20 steps:
1. **Test reward function standalone**: Run it outside training with known inputs to verify it returns nonzero values.
```bash
cd experiments && python -c "import my_rewards; print(my_rewards.accuracy_reward(...))"
```
2. **Check dataset columns**: The reward function receives `**kwargs` containing dataset columns. Verify the columns it needs (e.g., `answer`) are not removed by the dataset transform.
3. **Check completion content**: Enable `log_completions: true` in the `trl:` config and inspect logged completions in W&B. If completions are empty or incoherent, the model may be too weak for the task.
4. **Verify vLLM is serving the right model**: Hit the vLLM health endpoint and confirm the model name matches your config.
### Entropy Collapse (Mode Collapse)
**Symptom**: `entropy` drops below 0.01; all completions become nearly identical.
**Fixes**:
- Increase `temperature` in generation kwargs (try 0.8--1.0).
- Reduce learning rate.
- Add a KL penalty term (`beta` parameter in GRPO config).
- Check that `num_generations` is sufficient (16+ gives better advantage estimates).
### IS Ratio Divergence
**Symptom**: `sampling/importance_sampling_ratio/min` drops near 0, or `sampling/sampling_logp_difference/mean` exceeds 1.0.
This means the policy has diverged significantly from the weights used by vLLM for generation. The importance sampling correction becomes unreliable.
**Fixes**:
- Decrease `vllm_sync_interval` (sync weights more often).
- Enable `off_policy_mask_threshold` (e.g., 0.5) to mask stale off-policy samples.
- Use `importance_sampling_level: token` for finer-grained correction.
### Gradient Norm Instability
**Symptom**: `grad_norm` oscillates wildly or exceeds 10.0 regularly.
**Fixes**:
- Enable gradient clipping: `max_grad_norm: 1.0` (default in most configs).
- Reduce learning rate.
- Increase `gradient_accumulation_steps` to smooth out noisy batches.
- Check for NaN issues (see next section).
## NaN and Inf Handling
### Common Causes
| Cause | Where It Manifests | Detection |
|-------|-------------------|-----------|
| FP8 zero-scale division | Forward pass logits | `grad_norm: nan`, loss becomes NaN immediately |
| Gradient explosion | Backward pass | `grad_norm` spikes to inf, then loss goes NaN |
| Bad data (empty sequences) | Logprob computation | NaN in specific batches only |
| Numerical overflow in log-softmax | Loss computation | Large negative logprobs cause exp() overflow |
### FP8-Specific NaN Issues
FP8 quantization (`fp8: true`) can produce NaN when the activation quantization kernel divides by `max(abs(x)) / 448`. If the input tensor is all zeros (e.g., padding positions), the scale becomes 0, causing division by zero.
**Fixes applied in axolotl**:
- The `act_quant_kernel` has a zero-guard: `s = tl.where(s == 0, 1.0, s)`.
- A safety net `nan_to_num(logits, nan=0.0)` is applied in `_get_per_token_logps_and_entropies`.
- Embedding padding is zero-padded for FP8 compatibility.
::: {.callout-important}
## After Modifying Triton Kernels
If you patch any Triton JIT kernel (e.g., the FP8 quantization kernels in transformers), you must clear the Triton cache for changes to take effect:
```bash
rm -rf ~/.triton/cache
```
:::
### General NaN Debugging Steps
1. **Enable anomaly detection** (slow, but pinpoints the source):
```python
torch.autograd.set_detect_anomaly(True)
```
2. **Check grad_norm**: If it goes to NaN, the backward pass is the problem. If loss is NaN but grad_norm was fine on the previous step, the forward pass is the problem.
3. **Reduce to single GPU, single batch**: Eliminate distributed training variables.
4. **Inspect data**: Print the batch that triggers NaN. Look for empty sequences, extreme token IDs, or unexpected padding patterns.
## OOM Debugging
Out-of-memory errors are the most common training failure. Use this systematic approach, from least to most disruptive:
### Step 1: Reduce Batch Size
The single highest-impact change. VRAM scales roughly linearly with batch size.
```yaml
micro_batch_size: 1 # Start here
gradient_accumulation_steps: 16 # Increase to maintain effective batch size
```
For GRPO specifically, the logits tensor for policy logprob computation can be very large. `batch_size * num_generations * seq_len * vocab_size` in bf16. For example, with `num_generations: 16` and `micro_batch_size: 8`, the logits tensor alone is:
```
8 * 16 * 2048 * 151936 * 2 bytes = ~75 GB (way too large)
```
Reduce `micro_batch_size` to 2--4 for GRPO.
### Step 2: Enable Gradient Checkpointing
Trades compute for memory by recomputing activations during the backward pass instead of storing them.
```yaml
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false # Recommended default
```
::: {.callout-warning}
## Reentrant Checkpointing Exceptions
Some configurations require `use_reentrant: true`:
- DeepSpeed ZeRO-3 (non-reentrant causes `CheckpointError`)
- EBFT strided mode with flex_attention
:::
### Step 3: Use Quantization
Load the base model in reduced precision:
```yaml
# 4-bit QLoRA
adapter: qlora
load_in_4bit: true
# 8-bit
load_in_8bit: true
# FP8 (saves ~50% model VRAM, same compute speed as bf16)
fp8: true
```
### Step 4: Reduce Sequence Length
```yaml
sequence_len: 1024 # Down from 2048 or 4096
```
For GRPO, also reduce `max_completion_length`. Memory scales quadratically with sequence length when using standard attention.
### Step 5: Use Flash Attention
Reduces attention memory from O(n^2) to O(n):
```yaml
flash_attention: true
```
### Step 6: Offload with DeepSpeed
For extreme cases, offload optimizer states or parameters to CPU:
```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
```
### Diagnosing the Specific Culprit
Use the `profiler_steps` config option to capture GPU memory snapshots:
```yaml
profiler_steps: [1, 2]
```
This generates PyTorch profiler traces you can inspect to see exactly which tensor allocation caused the OOM.
## Common Errors
| Error Message | Likely Cause | Fix |
|---------------|-------------|-----|
| `exitcode: -9` | System RAM exhaustion | Reduce dataset size, `dataset_num_proc`, or number of data workers |
| `exitcode: -7` (DeepSpeed) | DeepSpeed version issue | `pip install -U deepspeed` |
| `CUDA out of memory` | GPU VRAM exhaustion | Follow OOM debugging steps above |
| `RuntimeError: NCCL communicator was aborted` | GPU communication failure | See [NCCL docs](nccl.qmd); check `NCCL_DEBUG=INFO` output |
| `ValueError: Asking to pad but the tokenizer does not have a padding token` | Missing pad token | Add `special_tokens: { pad_token: "<\|endoftext\|>" }` to config |
| `'DummyOptim' object has no attribute 'step'` | DeepSpeed on single GPU | Remove `deepspeed:` section from config |
| `unable to load strategy X` then `None is not callable` | Reward module not importable | Run `cd experiments && python -c "import my_rewards"` to check |
| `generation_batch_size not divisible by num_generations` | micro_batch_size too small | Set `micro_batch_size >= num_generations` and make it divisible |
| `'weight' must be 2-D` | FSDP1 flattened parameters | Use `fsdp_version: 2` or skip `unwrap_model` when FSDP is enabled |
| `CheckpointError` (tensor count mismatch) | Non-reentrant checkpointing + ZeRO-3 or flex_attention | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
| `BFloat16` TypeError during weight sync | NumPy does not support bf16 | Fixed in axolotl's `weight_serde.py` (auto bf16 to fp16 conversion) |
| `Content end boundary is before start boundary` | Chat template parsing issue | Check `eos_token` matches template; file a GitHub issue if persistent |
| `CAS service error` during data processing | HuggingFace XET issue | Set `export HF_HUB_DISABLE_XET=1` |
| Training hangs (multi-GPU) | FSDP + async prefetch deadlock | Set `async_prefetch: false` with FSDP |
## Profiling
### PyTorch Profiler
Axolotl supports PyTorch profiler integration via the config:
```yaml
profiler_steps: [1, 2, 3]
```
This captures profiler traces for the specified steps. View them in TensorBoard:
```bash
tensorboard --logdir output_dir/runs
```
Or open the `.json` trace file in `chrome://tracing`.
### CUDA Memory Snapshots
For detailed memory analysis, use PyTorch's memory snapshot API. Add this to your training script or use it interactively:
```python
import torch
# Enable memory history tracking
torch.cuda.memory._record_memory_history()
# ... run your training step ...
# Save snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
```
Visualize with PyTorch's memory visualizer:
```bash
python -m torch.cuda.memory._viz memory_snapshot.pickle
```
### Quick GPU Memory Check
During training, monitor GPU utilization in a separate terminal:
```bash
watch -n 1 nvidia-smi
```
For programmatic access within axolotl, the logged metrics `memory/max_alloc` and `memory/max_reserved` come from `torch.cuda.max_memory_allocated()` and `torch.cuda.max_memory_reserved()`. Note these report PyTorch's view of memory, which may differ from `nvidia-smi` (see [FAQ](faq.qmd)).
## W&B and Logging
### Enabling Logging
```yaml
wandb_project: my-project
wandb_entity: my-team # optional
wandb_run_id: run-123 # optional, for resuming
wandb_name: experiment-name # optional
logging_steps: 1 # log every step (recommended for RL)
```
### Debug Logging
For detailed axolotl-internal debug output:
```bash
AXOLOTL_LOG_LEVEL=DEBUG axolotl train config.yaml 2>&1 | tee /tmp/training.log
```
::: {.callout-tip}
## Always Log to a File
Pipe training output to a log file so you can inspect it after the run:
```bash
axolotl train config.yaml 2>&1 | tee /tmp/my_run.log
```
:::
### What Axolotl Logs
**SFT metrics** (logged every `logging_steps`):
- `train/loss`, `eval/loss` -- training and validation loss
- `train/grad_norm` -- gradient L2 norm (before clipping)
- `train/learning_rate` -- current learning rate
- `memory/max_alloc`, `memory/max_reserved` -- peak GPU memory
**GRPO/RL metrics** (logged every step):
- `rewards/<name>/mean`, `rewards/<name>/std` -- per-reward-function statistics
- `reward`, `reward_std` -- aggregated reward across all reward functions
- `frac_reward_zero_std` -- fraction of prompt groups where all completions got the same reward
- `completions/mean_length`, `completions/min_length`, `completions/max_length` -- completion token lengths
- `completions/clipped_ratio` -- fraction of completions that hit the max length
- `completions/mean_terminated_length`, `completions/min_terminated_length`, `completions/max_terminated_length` -- lengths of naturally terminated completions
- `kl` -- KL divergence between policy and reference
- `entropy` -- policy entropy (measure of output diversity)
- `clip_ratio/region_mean`, `clip_ratio/low_mean`, `clip_ratio/high_mean` -- PPO clipping statistics
- `sampling/sampling_logp_difference/mean`, `sampling/sampling_logp_difference/max` -- log-probability difference between policy and sampling distribution
- `sampling/importance_sampling_ratio/min`, `sampling/importance_sampling_ratio/mean`, `sampling/importance_sampling_ratio/max` -- IS ratio statistics for off-policy correction
- `num_tokens` -- total tokens processed
### Reading W&B Charts
For a healthy GRPO run, expect to see:
1. **`reward/mean`**: Gradual upward trend. May start near 0 and reach 0.3--0.8 depending on task difficulty. Not monotonic -- fluctuations are normal.
2. **`entropy`**: Gradual decrease from initial values (often 0.3--0.6) as the model becomes more confident. Should not collapse to near-zero.
3. **`grad_norm`**: Mostly in the 0.001--1.0 range. Occasional 0.0 values are fine (zero-advantage skip). Persistent values above 10.0 need investigation.
4. **`kl`**: Starts near 0 and grows slowly. If it shoots up rapidly, the policy is diverging from the reference.
5. **`completions/mean_length`**: Should reflect the task's natural answer length. If it steadily increases to `max_completion_length`, the model may be reward-hacking by generating longer outputs.

View File

@@ -1,318 +0,0 @@
---
title: "vLLM Serving for GRPO Training"
description: "How to configure and run vLLM as a generation backend for GRPO reinforcement learning in Axolotl."
format:
html:
toc: true
toc-depth: 3
number-sections: true
execute:
enabled: false
---
## Overview {#sec-overview}
GRPO (Group Relative Policy Optimization) trains a language model by generating completions, scoring them with reward functions, and updating the policy to favor higher-reward outputs. The generation step is the bottleneck: producing thousands of tokens per training step with the policy model is slow using standard HuggingFace generation.
Axolotl uses [vLLM](https://github.com/vllm-project/vllm) as a high-throughput generation backend. vLLM runs as a separate process (either on a dedicated GPU or colocated on the training GPU) and serves completions via an HTTP API. The trainer sends prompts to vLLM, receives completions, scores them, and performs gradient updates.
```
┌──────────────────────┐ HTTP ┌──────────────────────┐
│ Trainer (GPU 1) │ ───────────────── │ vLLM Server (GPU 0)│
│ │ prompts/compls │ │
│ - Policy model │ ◄──────────────── │ - Same base model │
│ - Reward scoring │ │ - Fast generation │
│ - Gradient updates │ weight sync │ - LoRA adapter │
│ - LoRA adapter │ ─────────────────►│ (periodically │
│ │ (every N steps) │ updated) │
└──────────────────────┘ └──────────────────────┘
```
::: {.callout-important}
vLLM must serve the **same base model** specified in your training config. If the models do not match, weight synchronization will silently produce incorrect results.
:::
## Server Mode {#sec-server-mode}
Server mode runs vLLM as an external process on dedicated GPU(s). This is the recommended configuration for most setups.
### Starting the Server
Use the `axolotl vllm-serve` command with your training config:
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml
```
```bash
# Terminal 2: Start training on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml
```
The server reads vLLM settings from the `vllm:` section of your config and starts an HTTP server (default: `http://0.0.0.0:8000`).
::: {.callout-tip}
Use `tmux` or `screen` to manage the vLLM server process. Typical startup time is 30-90 seconds depending on model size and whether CUDA graphs are captured.
:::
### Minimal Server Config
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
gpu_memory_utilization: 0.85
dtype: auto
max_model_len: 4096
rl: grpo
trl:
use_vllm: true
vllm_server_host: 0.0.0.0
vllm_server_port: 8000
vllm_server_timeout: 300
```
### Multi-GPU vLLM
For larger models, use tensor parallelism across multiple GPUs:
```yaml
vllm:
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
```
```bash
# vLLM on GPUs 2,3; training on GPUs 0,1
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve grpo_config.yaml
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo_config.yaml --num-processes 2
```
::: {.callout-note}
Due to how TRL maps vLLM device indices, the vLLM instance should use the **last** N GPUs (highest device indices), while training uses the first N.
:::
## Colocate Mode {#sec-colocate-mode}
Colocate mode runs vLLM on the same GPU as the trainer. This is useful when you only have a single GPU.
```yaml
trl:
use_vllm: true
vllm_mode: colocate
vllm_enable_sleep_mode: true
```
With `vllm_enable_sleep_mode: true`, vLLM offloads its VRAM allocation when not actively generating, freeing memory for training. When the trainer needs new completions, vLLM wakes up and reclaims VRAM.
::: {.callout-warning}
Colocate mode is significantly slower than server mode because generation and training cannot overlap. The GPU alternates between the two workloads. This mode is practical only for smaller models (up to ~3B on a 24 GB GPU).
:::
**When to use colocate mode:**
- You have exactly one GPU
- The model fits in memory with both vLLM and training active (with sleep mode), or is small enough to time-share
- You accept the performance tradeoff for simpler setup (no separate vLLM process to manage)
**When to use server mode:**
- You have two or more GPUs
- You want maximum throughput (generation overlaps with training via async prefetch)
- You are running larger models (7B+)
## LoRA Sync {#sec-lora-sync}
LoRA sync is the recommended weight synchronization method when training with LoRA adapters. Instead of merging adapter weights into the base model and broadcasting the full merged weights over NCCL, it saves only the LoRA adapter files to the filesystem and tells vLLM to load them natively.
### How It Works
1. The trainer calls `model.save_pretrained()` to write the LoRA adapter weights to a temporary directory
2. The trainer sends an HTTP POST to `/set_lora_adapter/` on the vLLM server
3. vLLM loads the adapter using its native LoRA support (Punica kernels)
4. Generation uses the updated adapter on the next request
### Benefits
- **Smaller sync payload**: Transfers ~40 MB of LoRA weights instead of ~1.4 GB+ of merged model weights (for a typical 0.5-3B model)
- **No NCCL communicator**: Eliminates the need for a cross-GPU NCCL communication channel, removing GPU contention between vLLM generation and weight sync
- **Faster sync**: ~200 ms per sync vs. 350 ms to 5+ seconds for NCCL merge sync
- **Simpler multi-GPU**: No need to set up NCCL groups between trainer and vLLM processes
### Configuration
```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
trl:
vllm_lora_sync: true # Enables LoRA sync mode
vllm_sync_interval: 5 # Sync every 5 training steps
```
Setting `vllm_lora_sync: true` automatically selects the LoRA-aware vLLM serve script (`axolotl.scripts.vllm_serve_lora`). You do not need to set `vllm.serve_module` manually.
::: {.callout-important}
LoRA sync requires that you are training with a LoRA adapter (`adapter: lora` or `adapter: qlora`). It is not applicable to full fine-tuning.
:::
## Weight Synchronization {#sec-weight-sync}
During GRPO training, the policy model on the trainer is continuously updated via gradient steps. The vLLM server, however, still holds the old weights. Periodically, the trainer must push updated weights to vLLM so that future generations reflect the improved policy.
### Sync Interval
The `vllm_sync_interval` parameter controls how often weights are synced:
```yaml
trl:
vllm_sync_interval: 5 # Sync every 5 optimizer steps
```
**Tradeoffs:**
- **Lower interval** (e.g., 1-3): Fresher generations, better on-policy data, but more sync overhead per step
- **Higher interval** (e.g., 5-10): Less overhead, but generations become increasingly off-policy between syncs
- **Recommended**: 3-5 for most setups. Axolotl includes importance sampling correction (`vllm_importance_sampling_correction: true`) to handle mild distribution mismatch from stale vLLM weights.
### Sync Methods
| Method | Config | Payload | Mechanism | Typical Time |
|--------|--------|---------|-----------|-------------|
| **LoRA sync** | `vllm_lora_sync: true` | LoRA adapter only (~40 MB) | Filesystem + HTTP | ~200 ms |
| **NCCL merge sync** | Default (no lora_sync) | Full merged weights (~1.4 GB+) | HTTP trigger + NCCL broadcast | 350 ms - 5 s |
::: {.callout-tip}
If you are training with LoRA (which is recommended for GRPO), always enable `vllm_lora_sync: true`. The performance difference is substantial, especially as training progresses and NCCL contention increases.
:::
### Importance Sampling Correction
When vLLM weights are stale (between syncs), the generated data is slightly off-policy. Axolotl can correct for this:
```yaml
trl:
vllm_importance_sampling_correction: true
importance_sampling_level: token # 'token' or 'sequence'
off_policy_mask_threshold: 0.5 # KL threshold for masking stale sequences
```
- **Token-level IS** is recommended when using Liger kernel (sequence-level has numerical issues with chunked computation)
- **Off-policy sequence masking (OPSM)** drops sequences that have diverged too far from the current policy, providing a safety net against stale data
## Restart Requirements {#sec-restart}
::: {.callout-warning}
**vLLM must be restarted between training runs.** Weight syncs from a previous run leave the server in a corrupted state. If you start a new training run against a stale vLLM server, the model may fail to learn.
:::
### When to Restart
- Before every new training experiment
- After a training run crashes or is interrupted
- If you change the base model in your config
### How to Restart
Killing vLLM reliably requires terminating both the main process and its background EngineCore subprocess:
```bash
# Kill all vLLM-related processes
pkill -9 -f "vllm|EngineCore"
# Verify GPU memory is freed
nvidia-smi
# Restart the server
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml
```
::: {.callout-tip}
A single `kill` often does not fully stop vLLM. Always use `kill -9` and verify with `nvidia-smi` that GPU memory has been released before restarting.
:::
### Health Check
The vLLM server exposes a health endpoint. Wait for it to return 200 before starting training:
```bash
# For the LoRA serve script (trailing slash required)
curl http://localhost:8000/health/
# For the default TRL serve script
curl http://localhost:8000/health
```
## Configuration Reference {#sec-config-reference}
### vLLM Server Options (`vllm:` section)
These control the vLLM server process started by `axolotl vllm-serve`.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `host` | str | `0.0.0.0` | Host address for the vLLM server |
| `port` | int | `8000` | Port for the vLLM server |
| `device` | str | `auto` | Device to use for vLLM |
| `tensor_parallel_size` | int | `None` | Number of GPUs for tensor parallelism |
| `data_parallel_size` | int | `None` | Number of data parallel replicas |
| `gpu_memory_utilization` | float | `0.9` | Fraction of GPU memory for vLLM (0.0-1.0) |
| `dtype` | str | `auto` | Data type (`auto`, `float16`, `bfloat16`) |
| `max_model_len` | int | `None` | Maximum model context length. Set explicitly if the default is too large for your GPU |
| `enable_prefix_caching` | bool | `None` | Enable prefix caching for repeated prompt prefixes |
| `enable_reasoning` | bool | `None` | Enable reasoning mode for models with thinking tokens |
| `reasoning_parser` | str | `None` | Parser for reasoning output |
| `enforce_eager` | bool | `None` | Disable CUDA graph capture (required for some architectures like Qwen3.5 hybrid attention) |
| `serve_module` | str | `None` | Python module for vLLM serve script. Auto-set when `vllm_lora_sync: true` |
| `worker_extension_cls` | str | `None` | vLLM worker extension class for weight sync |
### Trainer vLLM Options (`trl:` section)
These control how the trainer interacts with vLLM.
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `use_vllm` | bool | `false` | Enable vLLM for generation |
| `vllm_mode` | str | `None` | `server` (external process) or `colocate` (same GPU) |
| `vllm_server_host` | str | `0.0.0.0` | Host of the vLLM server to connect to |
| `vllm_server_port` | int | `8000` | Port of the vLLM server to connect to |
| `vllm_server_timeout` | int | `None` | Timeout in seconds for vLLM requests |
| `vllm_lora_sync` | bool | `false` | Sync LoRA adapters via filesystem instead of NCCL merge |
| `vllm_sync_interval` | int | `None` | Sync weights every N optimizer steps |
| `vllm_enable_sleep_mode` | bool | `None` | Offload vLLM VRAM when idle (colocate mode) |
| `vllm_guided_decoding_regex` | str | `None` | Regex constraint for guided decoding |
For async pipeline and off-policy correction options, see the [GRPO Configuration Reference](grpo.qmd#configuration-reference).
## Complete Example {#sec-complete-example}
For a full working GRPO config including vLLM, LoRA sync, async generation, rewards, and dataset setup, see the [GRPO Quick Start](grpo.qmd#quick-start). That config includes all the vLLM settings covered in this guide.
```bash
# Terminal 1: Start vLLM
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve grpo_config.yaml
# Wait for health check to pass
curl http://localhost:8000/health/
# Terminal 2: Start training
CUDA_VISIBLE_DEVICES=1 axolotl train grpo_config.yaml
```
## Troubleshooting {#sec-troubleshooting}
| Problem | Likely Cause | Solution |
|---------|-------------|----------|
| Training hangs waiting for vLLM | Server not started or wrong port | Check `curl http://localhost:8000/health/` and verify `vllm_server_host`/`vllm_server_port` match |
| OOM on vLLM GPU | `gpu_memory_utilization` too high or `max_model_len` too large | Reduce `gpu_memory_utilization` to 0.7 or set `max_model_len` explicitly |
| OOM on training GPU | Batch too large for policy logprobs | Reduce `micro_batch_size` or `num_generations` |
| Accuracy stays at zero | Stale vLLM from previous run | Restart vLLM: `pkill -9 -f "vllm\|EngineCore"`, verify with `nvidia-smi`, restart |
| `ResponseValidationError` from vLLM | Missing logprobs in response | Ensure you are using the correct serve module (auto-selected with `vllm_lora_sync: true`) |
| Weight sync takes 5+ seconds | NCCL contention with vLLM generation | Switch to `vllm_lora_sync: true` to eliminate NCCL |
| `async_prefetch` deadlocks with FSDP | Background threads run unsynchronized FSDP collectives | Set `async_prefetch: false` when using FSDP or DeepSpeed multi-GPU |

View File

@@ -6,8 +6,6 @@ LFM2 features a new hybrid Liquid architecture with multiplicative gates, short-
This guide shows how to fine-tune both the LFM2 and LFM2-VL models with Axolotl.
Thanks to the team at LiquidAI for giving us early access to prepare for these releases.
## Getting Started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
@@ -33,14 +31,6 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
axolotl train examples/LiquidAI/lfm2-vl-lora.yaml
```
**LFM2-MoE**
```bash
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
# LoRA SFT (1x48GB @ 16.2GiB)
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
```
### TIPS
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
@@ -55,13 +45,14 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
## Optimization Guides
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
## Related Resources
- [LFM2 Blog](https://www.liquid.ai/blog/liquid-foundation-models-v2-our-second-series-of-generative-ai-models)
- [LFM2-VL Blog](https://www.liquid.ai/blog/lfm2-vl-efficient-vision-language-models)
- [LFM2-MoE Blog](https://www.liquid.ai/blog/lfm2-8b-a1b-an-efficient-on-device-mixture-of-experts)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,7 +1,6 @@
base_model: LiquidAI/LFM2-350M
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
chunked_cross_entropy: true
eot_tokens:
- "<|im_end|>"

View File

@@ -1,59 +0,0 @@
base_model: LiquidAI/LFM2-8B-A1B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
eot_tokens:
- "<|im_end|>"
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 4
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
bf16: true
tf32: true
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 2
saves_per_epoch: 1
weight_decay: 0.0
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -3,9 +3,6 @@ trust_remote_code: true
model_type: AutoModelForImageTextToText
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false

View File

@@ -7,24 +7,3 @@ techniques. It is a combination of:
- Activation Offloading: Offload activations to CPU RAM to reduce memory usage
For more information, you can check out the ALST paper [here](https://www.arxiv.org/abs/2506.13996).
## Usage
```yaml
tiled_mlp: true
# See Sequence Parallelism docs
# https://docs.axolotl.ai/docs/sequence_parallelism.html
context_parallel_size: int
plugins:
# See Cut Cross Entropy docs
# https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# or Liger Kernel docs
# https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels
- axolotl.integrations.liger.LigerPlugin
# ...
```

View File

@@ -1,110 +0,0 @@
# Finetune Swiss-AI's Apertus with Axolotl
[Apertus](https://huggingface.co/collections/swiss-ai/apertus-llm-68b699e65415c231ace3b059) is a family of opensource models trained by Swiss-ai.
This guide shows how to fine-tune it with Axolotl with multi-turn conversations and proper masking.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Apertus is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. (Optional, highly recommended) Install XIELU CUDA
```bash
## Recommended for reduced VRAM and faster speeds
# Point to CUDA toolkit directory
# For those using our Docker image, use the below path.
export CUDA_HOME=/usr/local/cuda
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
3. Run the finetuning example:
```bash
axolotl train examples/apertus/apertus-8b-qlora.yaml
```
This config uses about 8.7 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
### Tips
- For inference, the official Apertus team recommends `top_p=0.9` and `temperature=0.8`.
- You can instead use full paremter fine-tuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
### XIELU Installation Issues
#### `ModuleNotFoundError: No module named 'torch'`
Please check these one by one:
- Running in correct environment
- Env has PyTorch installed
- CUDA toolkit is at `CUDA_HOME`
If those didn't help, please try the below solutions:
1. Pass env for CMAKE and try install again:
```bash
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
```
2. Git clone the repo and manually hardcode python path:
```bash
git clone https://github.com/nickjbrowning/XIELU
cd xielu
git checkout 59d6031
cd xielu
nano CMakeLists.txt # or vi depending on your preference
```
```diff
execute_process(
- COMMAND ${Python_EXECUTABLE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
+ COMMAND /root/miniconda3/envs/py3.11/bin/python -c "import torch.utils; print(torch.utils.cmake_prefix_path)"
RESULT_VARIABLE TORCH_CMAKE_PATH_RESULT
OUTPUT_VARIABLE TORCH_CMAKE_PATH_OUTPUT
ERROR_VARIABLE TORCH_CMAKE_PATH_ERROR
)
```
```bash
pip3 install . --no-build-isolation --no-deps
```
## Optimization Guides
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
## Related Resources
- [Apertus Tech Report](https://github.com/swiss-ai/apertus-tech-report/blob/main/Apertus_Tech_Report.pdf)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,64 +0,0 @@
base_model: swiss-ai/Apertus-8B-Instruct-2509
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -17,11 +17,8 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Run the finetuning example:

View File

@@ -9,6 +9,10 @@ strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -1,10 +0,0 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

File diff suppressed because it is too large Load Diff

View File

@@ -9,6 +9,10 @@ strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -9,6 +9,10 @@ strict: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_property_mappings:
role: role
content: content
dataset_prepared_path:
val_set_size: 0.05

View File

@@ -16,17 +16,11 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
```bash
python scripts/cutcrossentropy_install.py | sh
```
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/devstral/devstral-small-qlora.yml

Some files were not shown because too many files have changed in this diff Show More