Compare commits

..

3 Commits

Author SHA1 Message Date
Wing Lian
2d13a06722 slow fsdp1 test 2026-02-10 13:23:52 -05:00
Wing Lian
ba27e830e8 triton versions for older pytorch 2026-02-10 11:09:03 -05:00
Wing Lian
8f7219e139 upgrade liger to 0.6.5 and triton to 3.5.1 2026-02-10 11:05:00 -05:00
247 changed files with 1382 additions and 30875 deletions

View File

@@ -68,12 +68,7 @@ You can skip certain CI checks by including specific keywords in your commit mes
### Code Style
axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.
Use the pre-commit linter to ensure that your code is formatted consistently.
```bash
pre-commit run --all-files
```
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
### Commit Messages
@@ -83,6 +78,6 @@ Write clear and concise commit messages that briefly describe the changes made i
- [GitHub Help](https://help.github.com/)
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
- [Ruff](https://docs.astral.sh/ruff/)
- [{codestyle}]({URLofCodestyle})
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!

View File

@@ -15,9 +15,6 @@ on:
- '.github/workflows/base.yml'
workflow_dispatch:
permissions:
contents: read
jobs:
build-base:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
@@ -54,30 +51,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
- cuda: "129"
cuda_version: 12.9.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -94,14 +75,6 @@ jobs:
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -127,7 +100,7 @@ jobs:
images: |
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -135,7 +108,7 @@ jobs:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
@@ -176,14 +149,6 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -192,30 +157,14 @@ jobs:
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
- cuda: "129"
cuda_version: 12.9.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-uv-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -232,14 +181,6 @@ jobs:
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -250,7 +191,7 @@ jobs:
images: |
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -258,7 +199,7 @@ jobs:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/${{ matrix.dockerfile }}

View File

@@ -13,9 +13,6 @@ on:
- ".pre-commit-config.yaml"
workflow_dispatch:
permissions:
contents: read
jobs:
pre-commit:
name: pre-commit

View File

@@ -8,9 +8,6 @@ on:
- "v*"
workflow_dispatch:
permissions:
contents: read
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -37,28 +34,16 @@ jobs:
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
@@ -101,89 +86,11 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-uv:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
- name: Build and export to Docker
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
@@ -205,28 +112,16 @@ jobs:
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
@@ -264,86 +159,11 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-uv:
needs: build-axolotl-uv
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-cloud-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: 128

View File

@@ -8,7 +8,6 @@ on:
- 'setup.py'
- 'pyproject.toml'
- '.github/workflows/multi-gpu-e2e.yml'
- 'scripts/cutcrossentropy_install.py'
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
- 'src/axolotl/utils/distributed.py'
workflow_dispatch:
@@ -20,9 +19,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions:
contents: read
env:
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
@@ -39,13 +35,19 @@ jobs:
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras: "fbgemm-gpu"
# num_gpus: 2
# dockerfile: "Dockerfile-uv.jinja"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
@@ -53,13 +55,6 @@ jobs:
axolotl_extras:
# axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -81,9 +76,8 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run -m cicd.multigpu

View File

@@ -5,9 +5,6 @@ on:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
permissions:
contents: read
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}

View File

@@ -5,8 +5,6 @@ on:
- cron: '0 0 1 * *' # Run monthly
workflow_dispatch: # Manual kickoff
permissions: {}
jobs:
auto-update:
runs-on: ubuntu-latest

View File

@@ -14,8 +14,14 @@ on:
- .github/workflows/preview-docs.yml
permissions:
contents: read
checks: write
contents: write
deployments: write
issues: write
discussions: write
pages: write
pull-requests: write
statuses: write
jobs:
preview:

View File

@@ -3,11 +3,9 @@ name: publish pypi
on:
push:
tags:
- "v*"
- 'v*'
workflow_dispatch:
permissions: {}
jobs:
setup_release:
name: Create Release
@@ -30,8 +28,7 @@ jobs:
name: pypi
url: https://pypi.org/p/axolotl
permissions:
contents: read
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v4
@@ -49,7 +46,7 @@ jobs:
- name: Extract tag name
id: tag
run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> "$GITHUB_OUTPUT"
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in VERSION file
run: |

View File

@@ -3,13 +3,6 @@ on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
pull_request:
types: [opened, synchronize, reopened, ready_for_review]
paths:
- '.github/workflows/tests-nightly.yml'
permissions:
contents: read
jobs:
pre-commit:
@@ -25,26 +18,15 @@ jobs:
env:
SKIP: no-commit-to-branch
prime-cdn-s3-cache:
name: Prefetch S3 once to prime the CDN cache
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
timeout-minutes: 10
steps:
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [prime-cdn-s3-cache]
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
python_version: ["3.11"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
timeout-minutes: 20
steps:
@@ -55,7 +37,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -66,7 +48,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -120,23 +102,16 @@ jobs:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
nightly_build: "true"
steps:
- name: Checkout
@@ -157,11 +132,9 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
docker-e2e-multigpu-tests:
@@ -202,8 +175,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.multigpu

View File

@@ -28,9 +28,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
permissions:
contents: read
env:
TRANSFORMERS_IS_CI: "yes"
@@ -49,32 +46,21 @@ jobs:
env:
SKIP: no-commit-to-branch
prime-cdn-s3-cache:
name: Prefetch S3 once to prime the CDN cache
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
timeout-minutes: 10
steps:
- name: Restore Cache from S3
id: hf-cache-restore-s3
run: |
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
pytest:
name: PyTest
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
needs: [prime-cdn-s3-cache]
# needs: [preload-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
@@ -89,7 +75,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -160,18 +146,17 @@ jobs:
name: PyTest from Source Dist
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
needs: [prime-cdn-s3-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
# exclude:
# - python_version: "3.14"
# pytorch_version: "2.8.0"
# - python_version: "3.14"
# pytorch_version: "2.9.1"
timeout-minutes: 30
python_version: ["3.11", "3.12"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
- name: cleanup node
@@ -185,7 +170,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -279,8 +264,8 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 130
cuda_version: 13.0.0
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
@@ -306,10 +291,9 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
@@ -342,12 +326,6 @@ jobs:
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.10.0
num_gpus: 1
axolotl_extras:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
@@ -375,10 +353,9 @@ jobs:
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
run: |
modal run cicd.e2e_tests
@@ -392,9 +369,9 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:
@@ -418,6 +395,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.cleanup

View File

@@ -11,7 +11,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.15.4
rev: v0.14.10
hooks:
- id: ruff
args: [--fix]
@@ -26,7 +26,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.9.4
rev: 1.9.2
hooks:
- id: bandit
args: [

View File

@@ -29,23 +29,8 @@
## 🎉 Latest Updates
- 2026/03:
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
- 2026/02:
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
- Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO).
- 2026/01:
- New integration for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), weights loss by entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), improves long context in attention.
- 2025/12:
- Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
- [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining.
- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
<details>
<summary>Expand older updates</summary>
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
- 2025/07:
@@ -54,10 +39,15 @@
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
<details>
<summary>Expand older updates</summary>
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
@@ -72,10 +62,10 @@ Axolotl is a free and open-source tool designed to streamline post-training and
Features:
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.

View File

@@ -1 +1 @@
0.16.0.dev0
0.15.0.dev0

View File

@@ -128,9 +128,11 @@ quartodoc:
- monkeypatch.mistral_attn_hijack_flash
- monkeypatch.multipack
- monkeypatch.relora
- monkeypatch.llama_expand_mask
- monkeypatch.lora_kernels
- monkeypatch.utils
- monkeypatch.btlm_attn_hijack_flash
- monkeypatch.llama_patch_multipack
- monkeypatch.stablelm_attn_hijack_flash
- monkeypatch.trainer_fsdp_optim
- monkeypatch.transformers_fa_utils
@@ -329,7 +331,6 @@ website:
- docs/sequence_parallelism.qmd
- docs/gradient_checkpointing.qmd
- docs/nd_parallelism.qmd
- docs/expert_quantization.qmd
- section: "Troubleshooting"
contents:

View File

@@ -1,208 +0,0 @@
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
"""
import gc
import statistics
import torch
import torch.nn.functional as F
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
"""Original chunked implementation (reference)."""
original_shape = logits.shape[:-1]
num_classes = logits.shape[-1]
flat_logits = logits.reshape(-1, num_classes)
entropies = []
for chunk in flat_logits.split(chunk_size, dim=0):
logps = F.log_softmax(chunk, dim=-1)
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
entropies.append(chunk_entropy)
return torch.cat(entropies, dim=0).reshape(original_shape)
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, logits, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
out = fn(logits, chunk_size=128)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
del out
return times
def profile_memory(fn, logits, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(logits, chunk_size=128)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(logits, chunk_size=128)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_contiguous():
print("=" * 60)
print(
f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(1, 16384),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
t_orig = profile_time(entropy_from_logits_original, logits)
t_triton = profile_time(entropy_from_logits, logits)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(entropy_from_logits_original, logits)
m_triton = profile_memory(entropy_from_logits, logits)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits
_clean_gpu()
def benchmark_noncontiguous():
print("\n" + "=" * 60)
print(
f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
)
print("=" * 60)
configs = [
(4, 2048, "transpose"),
(4, 8192, "transpose"),
(8, 2048, "transpose"),
(4, 4096, "slice_batch"),
]
for B, L, method in configs:
torch.manual_seed(42)
if method == "transpose":
raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw.transpose(0, 1)
raw_gb = L * B * V * 2 / 1e9
elif method == "slice_batch":
raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
logits_nc = raw[::2]
raw_gb = B * 2 * L * V * 2 / 1e9
else:
continue
if raw_gb > 28:
print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
del raw, logits_nc
torch.cuda.empty_cache()
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
print(f"{'' * 60}")
def original_with_copy(logits, chunk_size=128):
return entropy_from_logits_original(
logits.contiguous(), chunk_size=chunk_size
)
t_orig = profile_time(original_with_copy, logits_nc)
t_triton = profile_time(entropy_from_logits, logits_nc)
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" orig+copy: {fmt(t_orig, 'ms')}")
print(f" triton-strided:{fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(original_with_copy, logits_nc)
m_triton = profile_memory(entropy_from_logits, logits_nc)
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" orig+copy: {fmt(m_orig, 'MB')}")
print(f" triton-strided:{fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del raw, logits_nc
_clean_gpu()
if __name__ == "__main__":
benchmark_contiguous()
benchmark_noncontiguous()

View File

@@ -1,284 +0,0 @@
"""Benchmark for ScatterMoE LoRA Triton kernels.
Measures forward, backward dX, and backward dA/dB kernels at common MoE
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
and full fwd+bwd autograd throughput.
Usage:
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
"""
import argparse
import gc
import time
from functools import partial
import torch
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
lora_ops,
ops as base_ops,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
flatten_sort_count,
)
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
ScatterMoELoRA,
)
DEVICE = "cuda"
DTYPE = torch.bfloat16
WARMUP = 5
ITERS = 20
# ─── Model configs ──────────────────────────────────────────────────────────
BUILTIN_CONFIGS = {
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
"Qwen3-30B-A3B": (128, 2048, 768, 8),
"OLMoE-1B-7B": (64, 2048, 1024, 8),
"Mixtral-8x7B": (8, 4096, 14336, 2),
}
def _resolve_config(spec):
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
key = spec.lower().replace("/", "-")
for name, cfg in BUILTIN_CONFIGS.items():
if key in name.lower() or name.lower() in key:
return name, cfg
from transformers import AutoConfig
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
if callable(getattr(hf_cfg, "get_text_config", None)):
tc = hf_cfg.get_text_config()
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
hf_cfg = tc
hidden = hf_cfg.hidden_size
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
experts = (
getattr(hf_cfg, "num_experts", None)
or getattr(hf_cfg, "num_local_experts", None)
or getattr(hf_cfg, "n_routed_experts", None)
)
top_k = (
getattr(hf_cfg, "num_experts_per_tok", None)
or getattr(hf_cfg, "num_experts_per_token", None)
or 2
)
name = spec.split("/")[-1]
return name, (experts, hidden, inter, top_k)
# ─── Benchmark helpers ──────────────────────────────────────────────────────
def _clean():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
def _bench(fn, warmup=WARMUP, iters=ITERS):
for _ in range(warmup):
fn()
torch.cuda.synchronize()
times = []
for _ in range(iters):
torch.cuda.synchronize()
t0 = time.perf_counter()
fn()
torch.cuda.synchronize()
times.append((time.perf_counter() - t0) * 1000)
times.sort()
return times[len(times) // 2]
def _setup(num_experts, K, N, T, top_k, R):
torch.manual_seed(42)
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
logits = torch.randn(T, num_experts, device=DEVICE)
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
gx = base_ops.group(x, ssi, fan_out=top_k)
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
return lora_ops.scatter2scatter_lora(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
lora_A=lA,
lora_B=lB,
scaling=2.0,
)
def _call_base(x, W, sei, ssi, top_k):
return base_ops.scatter2scatter(
X=x,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=top_k,
)
def _call_dx(dy, W, sei, ssi, lA, lB):
return lora_ops.scatter2scatter_lora_dX(
DY=dy,
W=W,
sorted_expert_idxs=sei,
sorted_scattered_idxs=ssi,
k=1,
lora_A=lA,
lora_B=lB,
scaling=2.0,
dy_grouped=True,
dx_grouped=False,
)
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
return lora_ops.group_bwd_lora(
DY=dy,
X=gx,
lora_A=lA,
lora_B=lB,
expert_offsets=eo,
E=num_experts,
scaling=2.0,
)
# ─── Main ────────────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
parser.add_argument(
"--models",
"-m",
nargs="+",
help="Model names or HF IDs (default: all builtins)",
)
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
parser.add_argument("--seq-len", "-T", type=int, default=2048)
args = parser.parse_args()
T = args.seq_len
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"T={T}, ranks={args.ranks}\n")
if args.models:
configs = [_resolve_config(m) for m in args.models]
else:
configs = list(BUILTIN_CONFIGS.items())
for model_name, (num_experts, hidden, inter, top_k) in configs:
print(f"{'=' * 70}")
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
print(f"{'=' * 70}")
for R in args.ranks:
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
_clean()
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
num_experts, K, N, T, top_k, R
)
# Forward with LoRA (auto-dispatched: fused or split)
dispatch = (
"split"
if (
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
)
else "fused"
)
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
total = t_fwd + t_dx + t_bwd
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
print(
f" R={R:>2} {proj:<8} "
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
f"base={t_base:>6.2f}ms "
f"(+{overhead * 100:.0f}%) "
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
f"total={total:>6.2f}ms"
)
# Full autograd fwd+bwd with memory measurement
x_ag = x.clone().requires_grad_(True)
lA_ag = lA.clone().requires_grad_(True)
lB_ag = lB.clone().requires_grad_(True)
def _run_autograd(
_x=x_ag,
_W=W,
_k=top_k,
_sei=sei,
_ssi=ssi,
_eo=eo,
_lA=lA_ag,
_lB=lB_ag,
):
out = ScatterMoELoRA.apply(
_x,
_W,
_k,
_sei,
_ssi,
_eo,
_lA,
_lB,
2.0,
None,
None,
False,
False,
True,
False,
)
out.sum().backward()
_x.grad = None
_lA.grad = None
_lB.grad = None
t_full = _bench(_run_autograd)
_clean()
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
_run_autograd()
torch.cuda.synchronize()
mem_peak = torch.cuda.max_memory_allocated() - mem_before
print(
f" full_fwd_bwd={t_full:>6.2f}ms "
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
)
print()
if __name__ == "__main__":
main()

View File

@@ -1,191 +0,0 @@
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
"""
import gc
import statistics
import torch
from axolotl.monkeypatch.trainer.utils import (
selective_log_softmax,
selective_log_softmax_original,
)
V = 151936 # Qwen vocab
WARMUP = 5
BENCH_ITERS = 20
MEM_ITERS = 10
def _clean_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.synchronize()
def profile_time(fn, args, n_iters=BENCH_ITERS):
for _ in range(WARMUP):
fn(*args)
torch.cuda.synchronize()
times = []
for _ in range(n_iters):
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
fn(*args)
e.record()
torch.cuda.synchronize()
times.append(s.elapsed_time(e))
return times
def profile_memory(fn, args, n_iters=MEM_ITERS):
for _ in range(WARMUP):
out = fn(*args)
del out
torch.cuda.synchronize()
peaks = []
for _ in range(n_iters):
_clean_gpu()
base = torch.cuda.max_memory_allocated()
out = fn(*args)
torch.cuda.synchronize()
peaks.append(torch.cuda.max_memory_allocated() - base)
del out
return [p / 1e6 for p in peaks]
def fmt(values, unit=""):
mean = statistics.mean(values)
std = statistics.stdev(values) if len(values) > 1 else 0.0
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
def benchmark_forward():
print("=" * 60)
print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 28:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(selective_log_softmax_original, (logits, index))
t_triton = profile_time(selective_log_softmax, (logits, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(selective_log_softmax_original, (logits, index))
m_triton = profile_memory(selective_log_softmax, (logits, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits, index
_clean_gpu()
def benchmark_backward():
print("\n" + "=" * 60)
print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
print("=" * 60)
configs = [
(1, 2048),
(1, 8192),
(4, 4096),
(8, 2048),
(16, 2048),
(16, 4096),
]
def fwd_bwd_original(logits, index):
logits.grad = None
out = selective_log_softmax_original(logits, index)
out.sum().backward()
def fwd_bwd_triton(logits, index):
logits.grad = None
out = selective_log_softmax(logits, index)
out.sum().backward()
for B, L in configs:
mem_gb = B * L * V * 2 / 1e9
if mem_gb > 20:
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
continue
N = B * L
print(f"\n{'' * 60}")
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
print(f"{'' * 60}")
torch.manual_seed(42)
logits_orig = torch.randn(
B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
logits_tri = logits_orig.detach().clone().requires_grad_(True)
index = torch.randint(0, V, (B, L), device="cuda")
t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
orig_mean = statistics.mean(t_orig)
triton_mean = statistics.mean(t_triton)
print(" FWD+BWD TIME (ms):")
print(f" original: {fmt(t_orig, 'ms')}")
print(f" triton: {fmt(t_triton, 'ms')}")
print(f" speedup: {orig_mean / triton_mean:.2f}x")
m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
orig_peak = statistics.mean(m_orig)
triton_peak = statistics.mean(m_triton)
print(" FWD+BWD MEMORY (peak overhead):")
print(f" original: {fmt(m_orig, 'MB')}")
print(f" triton: {fmt(m_triton, 'MB')}")
print(f" saved: {orig_peak - triton_peak:.1f} MB")
del logits_orig, logits_tri, index
_clean_gpu()
if __name__ == "__main__":
benchmark_forward()
benchmark_backward()

View File

@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
@@ -31,9 +31,8 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==26.0 setuptools==78.1.1
RUN uv pip install packaging==26.0 setuptools==75.8.0
RUN uv pip install torchvision
RUN uv pip uninstall causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
ENV AXOLOTL_DATASET_NUM_PROC="8"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
WORKDIR /workspace
@@ -32,8 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
RUN pip uninstall -y causal_conv1d
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -3,14 +3,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
set -o pipefail
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
# hf download "NousResearch/Meta-Llama-3-8B"
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
# hf download "microsoft/Phi-4-reasoning"
# hf download "microsoft/Phi-3.5-mini-instruct"
# hf download "microsoft/Phi-3-medium-128k-instruct"
# Run unit tests with initial coverage report
pytest -v --durations=10 -n8 \
--ignore=tests/e2e/ \

View File

@@ -68,6 +68,10 @@ def run_cmd(cmd: str, run_folder: str):
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
# Propagate errors from subprocess.
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")
try:
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
if exit_code:
print(f"Command '{cmd}' failed with exit code {exit_code}")
return exit_code
except Exception as e: # pylint: disable=broad-except
print(f"Command '{cmd}' failed with exception {e}")

View File

@@ -37,7 +37,6 @@ coverage:
only_pulls: false
flags: null
paths: null
informational: true
parsers:
gcov:

View File

@@ -22,7 +22,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN pip uninstall -y causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \

View File

@@ -59,18 +59,34 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
elif [ "$CUDA" = "130" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
fi \
;; \
esac

View File

@@ -1,30 +0,0 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl-uv:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

View File

@@ -1,48 +0,0 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base-uv:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN uv pip uninstall causal_conv1d
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch && \
git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc

View File

@@ -6,7 +6,6 @@ ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
@@ -40,18 +39,28 @@ RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$TARGETARCH" = "amd64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
elif [ "$TARGETARCH" = "arm64" ]; then \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
fi \
fi \
;; \
esac

View File

@@ -13,10 +13,9 @@ sdp_attention: true
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention
## Flash Attention 2
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
based on your installed packages and GPU.
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
@@ -24,9 +23,11 @@ flash_attention: true
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Flash Attention 2
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
@@ -34,12 +35,11 @@ pip install flash-attn --no-build-isolation
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
Alternatively, try reinstall or downgrade a version.
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
:::
### Flash Attention 3
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
@@ -50,44 +50,6 @@ cd flash-attention/hopper
python setup.py install
```
### Flash Attention 4
Requirements: Hopper or Blackwell GPUs
```bash
pip install flash-attn-4
```
Or from source:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/flash_attn/cute
pip install -e .
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
# Remove it so Python can find the real FA4 module:
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
```
::: {.callout-note}
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
for training on Hopper, install from source using the instructions above.
:::
::: {.callout-warning}
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
and falls back to FA2/3.
:::
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
### AMD
Requirements: ROCm 6.0 and above.

View File

@@ -210,8 +210,6 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -220,7 +218,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4

View File

@@ -1,67 +0,0 @@
---
title: "MoE Expert Quantization"
description: "Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load"
---
Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors).
This means `bitsandbytes` can no longer quantize them during model loading, resulting in all expert
weights being loaded in full bf16 precision and causing massive VRAM usage.
`quantize_moe_experts` solves this by quantizing expert weights during model loading.
It intercepts the weight loading process, quantizes each expert tensor on the fly, and
immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.
For example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.
## Usage
Enable expert quantization in your Axolotl config:
```yaml
quantize_moe_experts: true
```
This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
### Expert LoRA targeting
You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router
```
::: {.callout-note}
`lora_dropout` must be `0` when using `lora_target_parameters`.
:::
## Requirements
- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
- CUDA GPUs only (not tested with ROCm or other backends)
- FSDP2 compatible for distributed training
## Limitations
- `lora_target_linear` is not compatible with `quantize_moe_experts`. See [Expert LoRA targeting](#expert-lora-targeting) instead.
- `cpu_ram_efficient_loading` hangs / takes long time with FSDP2 + QLoRA.
- Total model parameter count may display incorrectly (trainable param count is correct).
- FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.
- FSDP2 may use more VRAM per GPU than single GPU training due to not all layers being properly sharded across ranks.
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
- DeepSpeed has not been tested.
## Implementation details
The quantization is applied by patching transformers to intercept weight loading.
When a 3D+ CUDA tensor with "expert" in its name is detected:
- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).
- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to
transformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.
For full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439).

View File

@@ -1,5 +1,5 @@
---
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
title: Gradient Checkpointing and Activation Offloading
---
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
@@ -27,33 +27,3 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
### Enabling Layer Offloading
```yaml
layer_offloading: true
```
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
trainable adapter weights stay on GPU permanently.
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
prefetched asynchronously on a separate CUDA stream for overlap.
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
previous layer is prefetched.
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
that is kept on GPU at any given time.
**Requirements:**
- CUDA GPU (CPU-only training is not supported for this feature)
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
- Best combined with LoRA/QLoRA where most parameters are frozen

View File

@@ -13,14 +13,12 @@ format:
- [Pixtral](#sec-pixtral)
- [Llava-1.5](#sec-llava-15)
- [Mistral-Small-3.1](#sec-mistral-small-31)
- [Mistral-Small-4](#sec-mistral-small-4)
- [Magistral-Small-2509](#sec-magistral-small-2509)
- [Voxtral](#sec-voxtral)
- [Gemma-3](#sec-gemma-3)
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [Qwen3.5](#sec-qwen3-5)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
@@ -110,12 +108,6 @@ Please make sure to install vision lib via `pip install 'mistral-common[opencv]=
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
```
### Mistral-Small-4 {#sec-mistral-small-4}
```yaml
base_model: mistralai/Mistral-Small-4-119B-2603
```
### Magistral-Small-2509 {#sec-magistral-small-2509}
::: {.callout-tip}
@@ -192,14 +184,6 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### Qwen3.5 {#sec-qwen3-5}
```yaml
base_model: Qwen/Qwen3.5-9B
chat_template: qwen3_5
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.

View File

@@ -54,13 +54,6 @@ These techniques save VRAM by changing how activations are handled.
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
### Layer Offloading
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
- **Config:** `layer_offloading: true`
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
### Cut Cross Entropy (CCE)
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
@@ -73,15 +66,6 @@ Provides efficient Triton kernels to improve training speed and reduce memory us
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
### Expert Kernels
Optimized kernel implementations for Mixture of Experts (MoE) model training.
- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support.
- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.
- **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration)
## Long Context Models
Techniques to train models on sequences longer than their original context window.
@@ -147,10 +131,3 @@ Simulates quantization effects during training, helping the model adapt and pote
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
### MoE Expert Quantization
Quantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused `nn.Parameter` tensors.
- **Config:** `quantize_moe_experts: true`
- **Learn more:** [MoE Expert Quantization](expert_quantization.qmd)

View File

@@ -721,213 +721,6 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
#### Async GRPO
Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.
```yaml
trl:
use_data_producer: true # Enable data producer protocol
use_vllm: true
async_prefetch: true # Generate rollouts in background thread
prefetch_depth: 1 # Number of rollouts to prefetch
vllm_sync_interval: 2 # Sync weights to vLLM every N steps
```
::: {.callout-note}
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
:::
##### vLLM LoRA Sync
By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.
```yaml
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
trl:
vllm_lora_sync: true # Enable native LoRA sync
```
When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:
```bash
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
```
Then start training on a separate GPU:
```bash
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
::: {.callout-tip}
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
:::
##### Streaming Partial Batch
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.
```yaml
trl:
streaming_partial_batch: true
```
##### Importance Sampling Correction
When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.
```yaml
trl:
vllm_importance_sampling_correction: true # Enable IS correction
importance_sampling_level: token # 'token' or 'sequence'
off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
```
- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
- `importance_sampling_level: sequence` applies per-sequence IS ratios
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
##### Replay Buffer
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.
```yaml
trl:
replay_buffer_size: 100 # Max cached groups (0 = disabled)
replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
```
::: {.callout-note}
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
:::
##### Deferred Re-rolling
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
```yaml
trl:
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
reroll_max_groups: 1 # Max groups to replace per batch
```
##### Zero-Advantage Batch Skipping
When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.
```yaml
trl:
skip_zero_advantage_batches: true # default
```
##### Parallel Reward Workers
Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
```yaml
trl:
reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
```
##### Full Async GRPO Example
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
gpu_memory_utilization: 0.35
dtype: auto
adapter: lora
lora_r: 32
lora_alpha: 64
lora_target_linear: true
rl: grpo
trl:
use_data_producer: true
use_vllm: true
async_prefetch: true
prefetch_depth: 1
vllm_sync_interval: 2
vllm_lora_sync: true
streaming_partial_batch: true
vllm_importance_sampling_correction: true
off_policy_mask_threshold: 0.5
importance_sampling_level: token
num_generations: 8
max_completion_length: 512
reward_funcs:
- rewards.accuracy_reward
reroll_start_fraction: 0.5
replay_buffer_size: 100
reward_num_workers: 4
skip_zero_advantage_batches: true
datasets:
- path: AI-MO/NuminaMath-TIR
type: rewards.prompt_transform
split: train
gradient_accumulation_steps: 4
micro_batch_size: 2
max_steps: 500
learning_rate: 1e-5
bf16: true
gradient_checkpointing: true
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPU 1
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
```
##### Multi-GPU Async GRPO
Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.
**FSDP:**
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
gradient_checkpointing_kwargs:
use_reentrant: false
```
**DeepSpeed ZeRO-3:**
```yaml
deepspeed: deepspeed_configs/zero3_bf16.json
gradient_checkpointing_kwargs:
use_reentrant: true # Required for ZeRO-3
```
```bash
# Terminal 1: Start vLLM on GPU 0
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
# Terminal 2: Train on GPUs 0,1
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
```
::: {.callout-important}
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
:::
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b\""
]
},
{

View File

@@ -1,5 +1,8 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -24,11 +27,6 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,5 +1,8 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -24,11 +27,6 @@ datasets:
val_set_size: 0.0
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_r: 32
lora_alpha: 16

View File

@@ -1,5 +1,9 @@
base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
# gemma3 doesn't seem to play nice with ddp
@@ -20,11 +24,6 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
# Freeze vision tower
unfrozen_parameters:
- ^model\.language_model\..*
- ^lm_head\..*
adapter: qlora
lora_model_dir:

View File

@@ -1,72 +0,0 @@
# Finetune Z.ai's GLM-4.5-Air with Axolotl
[GLM-4.5-Air](https://huggingface.co/zai-org/GLM-4.5-Air) is a MoE model by Z.ai.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# QLoRA (1x80GB @ ~63.4GiB/GPU)
axolotl train examples/glm45/glm-45-air-qlora.yaml
```
### Dataset
In addition to the standard OpenAI Messages format, GLM-4.5 supports an extra parameter for thinking in the assistant section.
```json
{
"role": "assistant",
"reasoning_content": "...", // or have </think>...</think> in `content`
"content": "..."
}
```
Make sure you set the below extra attributes if needed:
```yaml
datasets:
- path: ...
type: chat_template
message_property_mappings:
role: role
content: content
# tool_calls: tool_calls # uncomment if using tools
# reasoning_content: reasoning_content # uncomment if have reasoning
# Uncomment if training on tool role (you would rarely if ever need this)
# eot_tokens:
# - <|observation|>
```
### Tips
- The role name for tools in this template is `tool`.
- You will see this Axolotl WARNING — this is expected as the template does not use EOS:
```
EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.
```
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config.
- **LoRA kernels**: Incompatible with this model. Must be explicitly disabled (`lora_*_kernel: false`).
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.5-Air on HuggingFace](https://huggingface.co/zai-org/GLM-4.5-Air)
- [GLM-4.5 Blog](https://z.ai/blog/glm-4.5)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,64 +0,0 @@
base_model: zai-org/GLM-4.5-Air
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: false
load_in_4bit: true
quantize_moe_experts: true # important
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,65 +0,0 @@
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
This guide shows how to fine-tune it with Axolotl.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
```bash
# QLoRA
# - no target experts (1x48GB @ ~24GiB/GPU)
# - target experts (1x48GB @ ~34GiB/GPU)
axolotl train examples/glm47-flash/qlora.yaml
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
axolotl train examples/glm47-flash/qlora_fsdp.yaml
```
```bash
# LoRA
# - no target experts (1x48GB @ ~35GiB/GPU)
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
axolotl train examples/glm47-flash/lora.yaml
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
axolotl train examples/glm47-flash/lora_fsdp.yaml
```
### MoE Expert Quantization & Expert LoRA
This model quantize expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
## Limitations
- **lora_target_linear**: Incompatible for this model.
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
### TIPS
- For inference, the official Z.ai team recommends these default settings (most tasks):
- `temperature: 1.0`
- `top_p: 0.95`
- `max_new_tokens: 131072`
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,65 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -1,75 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_8bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
adapter: lora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,65 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1

View File

@@ -1,75 +0,0 @@
base_model: zai-org/GLM-4.7-Flash
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
# Uncomment to also target MoE expert weights:
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# LoRA kernels incompatible with DSA attention
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,65 +0,0 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
datasets:
- path: yahma/alpaca-cleaned
type: alpaca
split: train[:95%]
output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/dataset_prepared
sequence_len: 2048
flash_attention: true
qat:
activation_dtype: mxfp4
weight_dtype: mxfp4
group_size: 32
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_checkpointing: true
activation_offloading: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_8bit
cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true
resume_from_checkpoint:
logging_steps: 1
evals_per_epoch: 1
saves_per_epoch: 1
warmup_ratio: 0.1
weight_decay: 0.0
special_tokens:
pad_token: <|finetune_right_pad_id|>
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,82 +0,0 @@
# Finetune Mistral Small 4 with Axolotl
Mistral Small 4 is a 119B parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities into a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603).
Thanks to the team at MistralAI for giving us early access to prepare for this release.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
3. Install transformers from main
```bash
pip install git+https://github.com/huggingface/transformers.git
```
4. Run one of the example configs:
```bash
# text-only
axolotl train examples/mistral4/qlora-text.yml # no experts ~69 GiB, experts ~93 GiB
axolotl train examples/mistral4/fft-text.yml
# text + vision
# run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
axolotl train examples/mistral4/qlora-vision.yml # no experts ~68 GiB
axolotl train examples/mistral4/fft-vision.yml
```
Note: FFT configs provided as reference. Please adjust hyperparameters as needed.
## Reasoning Effort
The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:
- `"none"` — instruct mode (default)
- `"high"` — reasoning mode with explicit thinking steps
Pass it via `chat_template_kwargs` under your dataset config:
```yaml
datasets:
- path: your/dataset
type: chat_template
chat_template_kwargs:
reasoning_effort: high
```
## Thinking Support
The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).
To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:
```yaml
datasets:
- path: your/thinking-dataset
type: chat_template
message_property_mappings:
role: role
content: content
thinking: thinking
chat_template_kwargs:
reasoning_effort: high
```
See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.
## Tips
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
## Related Resources
- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,58 +0,0 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true
# only train language model layers, freeze vision tower
unfrozen_parameters:
- model.language_model.*
- lm_head
- embed_tokens
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,57 +0,0 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_sonicmoe: true
# vision requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 2048
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
bf16: true
tf32: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: false
state_dict_type: FULL_STATE_DICT
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,58 +0,0 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
sequence_len: 2048
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# uncomment to train on expert layers
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,63 +0,0 @@
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
processor_type: AutoProcessor
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
load_in_4bit: true
quantize_moe_experts: true
# vision chat template requirements
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
datasets:
- path: Nanobit/text-vision-2k-test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
adapter: qlora
sequence_len: 2048
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
# uncomment to train on expert layers
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
# lora_mlp_kernel: false
# lora_qkv_kernel: false
# lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,57 +0,0 @@
base_model: nvidia/Nemotron-Mini-4B-Instruct
load_in_8bit: false
load_in_4bit: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/nemotron-mini-4b-qlora
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- up_proj
- down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
special_tokens:

View File

@@ -6,13 +6,30 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
Here is an example of how to install from main for pip:
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
python scripts/cutcrossentropy_install.py | sh
```
2. Install Qwen3-Next transformers commit
```bash
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
```
3. Install FLA for improved performance
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
```
4. Run the finetuning example:
@@ -21,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
```
This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.
This config uses about 45.62 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀

View File

@@ -9,8 +9,6 @@ plugins:
load_in_8bit: false
load_in_4bit: true
quantize_moe_experts: true
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
@@ -27,7 +25,7 @@ sample_packing: true
lora_r: 16
lora_alpha: 8
lora_dropout: 0
lora_dropout: 0.05
lora_target_modules:
- linear_attn.in_proj_ba
- linear_attn.in_proj_qkvz
@@ -36,19 +34,12 @@ lora_target_modules:
- shared_expert.down_proj
- shared_expert.gate_proj
- shared_expert_gate
- mlp.gate
- q_proj
- v_proj
- k_proj
- o_proj
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
wandb_project:
wandb_entity:
wandb_watch:

View File

@@ -1,84 +0,0 @@
base_model: Qwen/Qwen3.5-122B-A10B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,74 +0,0 @@
base_model: Qwen/Qwen3.5-122B-A10B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,59 +0,0 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
# Freeze vision encoder
unfrozen_parameters:
- model\.language_model\..*
- lm_head\..*
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,81 +0,0 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment below to also target the linear attention projections.
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: false
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,70 +0,0 @@
base_model: Qwen/Qwen3.5-27B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment below to also target the linear attention projections.
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,85 +0,0 @@
base_model: Qwen/Qwen3.5-35B-A3B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_config:
fsdp_version: 2
offload_params: true
cpu_ram_efficient_loading: false
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
state_dict_type: FULL_STATE_DICT
sharding_strategy: FULL_SHARD
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,74 +0,0 @@
base_model: Qwen/Qwen3.5-35B-A3B
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
strict: false
chat_template: qwen3_5
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
dataset_prepared_path: last_run_prepared
sequence_len: 2048
sample_packing: true
load_in_4bit: true
quantize_moe_experts: true
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
# Regex matching to target shared experts too
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
# Target experts
# lora_target_parameters:
# - mlp.experts.gate_up_proj
# - mlp.experts.down_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
lora_mlp_kernel: false
lora_qkv_kernel: false
lora_o_kernel: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,49 +0,0 @@
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# Required for multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
pad_to_sequence_len: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,66 +0,0 @@
base_model: Qwen/Qwen3.5-9B
processor_type: AutoProcessor
# These 3 lines are required for vision/multimodal training
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen3_5
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
# Targets the language model attention and MLP layers.
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- down_proj
- up_proj
# Uncomment to also target the linear attention (GatedDeltaNet) projections:
# - linear_attn.in_proj_qkv
# - linear_attn.in_proj_z
# - linear_attn.out_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,86 +0,0 @@
# Finetune Qwen3.5 with Axolotl
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
## Getting started
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
```bash
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
```
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
4. Pick any config from the table below and run:
```bash
axolotl train examples/qwen3.5/<config>.yaml
```
Available configs:
| Config | Model | Type | Peak VRAM |
|---|---|---|---|
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
### Gated DeltaNet Linear Attention
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
```yaml
lora_target_modules:
# ... standard projections ...
- linear_attn.in_proj_qkv
- linear_attn.in_proj_z
- linear_attn.out_proj
```
### Routed Experts (MoE)
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
```yaml
lora_target_parameters:
- mlp.experts.gate_up_proj
- mlp.experts.down_proj
# - mlp.gate.weight # router
```
### Shared Experts (MoE)
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
```yaml
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
```
### TIPS
- For inference hyp, please see the respective model card details.
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
## Optimization Guides
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
## Related Resources
- [Qwen3.5 Blog](https://qwenlm.github.io/blog/qwen3.5/)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -8,15 +8,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the finetuning example:
2. Run the finetuning example:
```bash
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
```
This config uses about 24.9 GiB VRAM (w/o CCE).
This config uses about 24.9 GiB VRAM.
Let us know how it goes. Happy finetuning! 🚀
@@ -31,6 +29,10 @@ Let us know how it goes. Happy finetuning! 🚀
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Limitations
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
## Related Resources
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)

View File

@@ -1,4 +1,5 @@
base_model: arcee-ai/Trinity-Nano-Preview
trust_remote_code: true
revision_of_model: 2ee94b0
# Automatically upload checkpoint and final model to HF

View File

@@ -61,11 +61,5 @@ skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.pytest.ini_options]
addopts = "-m 'not slow'"
markers = [
"slow: marks tests as slow",
]
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]

View File

@@ -5,25 +5,22 @@ bitsandbytes==0.49.1
triton>=3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.7.0
liger-kernel==0.6.5
# END section
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.3.0
accelerate==1.13.0
transformers==5.0.0
accelerate==1.12.0
datasets==4.5.0
deepspeed>=0.18.6,<0.19.0
trl==0.29.0
hf_xet==1.3.2
kernels==0.12.2
deepspeed>=0.18.3
trl==0.27.1
hf_xet==1.2.0
kernels==0.11.5
fla-core==0.4.1
flash-linear-attention==0.4.1
trackio>=0.16.1
trackio>=0.13.0
typing-extensions>=4.15.0
optimum==1.16.2
@@ -66,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.16.0
torchao==0.13.0
openenv-core==0.1.0
schedulefree==1.4.1
@@ -75,4 +72,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.10.0
mistral-common==1.8.8

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"'
)

View File

@@ -26,18 +26,6 @@ def parse_requirements(extras_require_map):
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip on ARM64
skip_packages = [
"torchao",
"fla-core",
"flash-linear-attention",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
@@ -81,23 +69,16 @@ def parse_requirements(extras_require_map):
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 10):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.5.0",
"fbgemm-gpu-genai==1.5.0",
]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.17.1"]
elif (major, minor) >= (2, 9):
if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.13.0"]
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:

View File

@@ -6,6 +6,5 @@ from axolotl.logging_config import configure_logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1")
configure_logging()

View File

@@ -3,7 +3,6 @@
import os
from pathlib import Path
import httpcore
from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
@@ -48,7 +47,7 @@ def check_user_token() -> bool:
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except (HTTPError, httpcore.ConnectError):
except HTTPError:
LOG.warning(
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
)

View File

@@ -90,8 +90,9 @@ class ModalCloud(Cloud):
# grab the sha256 hash from docker hub for this image+tag
# this ensures that we always get the latest image for this tag, even if it's already cached
try:
manifest = subprocess.check_output( # nosec
["docker", "manifest", "inspect", docker_image],
manifest = subprocess.check_output( # nosec B602
f"docker manifest inspect {docker_image}",
shell=True,
).decode("utf-8")
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
except subprocess.CalledProcessError:

View File

@@ -5,13 +5,13 @@ import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Optional, Union
from typing import Union
from urllib.parse import urlparse
import requests
import torch
import yaml
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.telemetry.errors import send_errors
@@ -32,63 +32,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
"""Coerce a string CLI value to its most likely Python type.
If an existing value is present in the config, its type is used to guide
casting. Otherwise, YAML-style inference is applied: booleans, ints,
floats, and None literals are recognised automatically.
Args:
value: The raw value (typically a string from the CLI).
existing: An optional existing config value whose type guides coercion.
Returns:
The value cast to the inferred or expected type.
"""
if not isinstance(value, str):
return value
# If the config already has a typed value, cast to match
if existing is not None:
if isinstance(existing, bool):
return value.lower() in ("true", "1", "yes")
if isinstance(existing, int):
try:
return int(value)
except (ValueError, TypeError):
return value
if isinstance(existing, float):
try:
return float(value)
except (ValueError, TypeError):
return value
# For other types (str, list, dict, etc.), return as-is
return value
# No existing value -- use YAML-style inference
lower = value.lower()
if lower in ("true", "yes"):
return True
if lower in ("false", "no"):
return False
if lower in ("null", "none", "~"):
return None
# Try int then float
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -265,42 +208,18 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
# Separate nested (dot-notation) kwargs from flat kwargs
nested_kwargs: dict[str, dict[str, Any]] = {}
flat_kwargs: dict[str, Any] = {}
for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
# Apply flat kwargs
for key, value in flat_kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
cfg[key] = _coerce_value(value, cfg.get(key))
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
for parent, children in nested_kwargs.items():
if parent not in cfg_keys and cfg.strict:
continue
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
LOG.warning(
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
)
cfg[parent] = {}
for child_key, child_value in children.items():
existing_child = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
else:
cfg[key] = value
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except (RuntimeError, AssertionError):
except:
gpu_version = None
prepare_plugins(cfg)
@@ -310,7 +229,6 @@ def load_cfg(
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"fp8": compute_supports_fp8(),
"tf32": is_torch_tf32_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},

View File

@@ -71,7 +71,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
merge_lora=True,
load_in_8bit=False,
load_in_4bit=False,
quantize_moe_experts=False,
flash_attention=False,
context_parallel_size=None,
deepspeed=None,

View File

@@ -196,10 +196,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
state.wait_for_everyone()
LOG.info(
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
main_process_only=True,
)
LOG.info(
"Merged weights are only the safetensors and doesn't include the model configuration "
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
main_process_only=True,
)

View File

@@ -2,7 +2,7 @@
import dataclasses
from functools import wraps
from types import NoneType, UnionType
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
@@ -20,8 +20,7 @@ def _strip_optional_type(field_type: type | str | None):
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
if is_union and type(None) in get_args(field_type):
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
@@ -88,70 +87,10 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
return decorator
def _is_pydantic_model(field_type: type) -> bool:
"""Check if a type is a Pydantic BaseModel subclass."""
try:
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
except TypeError:
return False
def _get_field_description(field) -> str | None:
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
if field.description:
return field.description
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
return field.json_schema_extra.get("description")
return None
def _add_nested_model_options(
function: Callable, parent_name: str, model_class: Type[BaseModel]
) -> Callable:
"""
Add Click options for all fields of a nested Pydantic model using dot-notation.
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
Args:
function: Click command function to add options to.
parent_name: Parent field name (e.g., "trl").
model_class: Nested Pydantic model class.
Returns:
Function with added Click options.
"""
for sub_name, sub_field in reversed(model_class.model_fields.items()):
sub_type = _strip_optional_type(sub_field.annotation)
# Use dot notation: --parent.sub_field
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
# The kwarg name uses double-underscore as separator
param_name = f"{parent_name}__{sub_name}"
description = _get_field_description(sub_field)
if sub_type is bool:
option_name = f"--{cli_name}/--no-{cli_name}"
function = click.option(
option_name, param_name, default=None, help=description
)(function)
else:
option_name = f"--{cli_name}"
click_type = {str: str, int: int, float: float}.get(sub_type)
function = click.option(
option_name, param_name, default=None, type=click_type, help=description
)(function)
return function
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
generated for each sub-field (e.g., ``--trl.beta=0.1``).
Args:
config_class: PyDantic model with fields to parse from the CLI
@@ -164,11 +103,6 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
# Handle nested Pydantic models with dot-notation options
if _is_pydantic_model(field_type):
function = _add_nested_model_options(function, name, field_type)
continue
if field_type is bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"

View File

@@ -38,18 +38,7 @@ def do_vllm_serve(
cfg = load_cfg(config)
model = cfg.base_model
# Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
serve_module = cli_args.get("serve_module") or getattr(
cfg.vllm, "serve_module", None
)
if (
serve_module is None
and getattr(cfg, "trl", None)
and getattr(cfg.trl, "vllm_lora_sync", False)
):
serve_module = "axolotl.scripts.vllm_serve_lora"
if serve_module is None:
serve_module = "trl.scripts.vllm_serve"
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
tensor_parallel_size = 1
data_parallel_size = 1
@@ -79,7 +68,7 @@ def do_vllm_serve(
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
)
base_kwargs = dict(
vllm_script_args = AxolotlScriptArguments(
model=model,
tensor_parallel_size=tensor_parallel_size,
data_parallel_size=data_parallel_size,
@@ -89,21 +78,7 @@ def do_vllm_serve(
dtype=dtype,
max_model_len=max_model_len,
enable_prefix_caching=enable_prefix_caching,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
# Use LoRAScriptArguments when serving with native LoRA support
if serve_module == "axolotl.scripts.vllm_serve_lora":
from axolotl.scripts.vllm_serve_lora import LoRAScriptArguments
lora_kwargs = {}
if hasattr(cfg, "lora_r") and cfg.lora_r:
lora_kwargs["max_lora_rank"] = cfg.lora_r
vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
else:
vllm_script_args = AxolotlScriptArguments(
**base_kwargs,
reasoning_parser=reasoning_parser,
enable_reasoning=enable_reasoning,
)
vllm_serve_main(vllm_script_args)

View File

@@ -12,15 +12,10 @@ MOE_ARCH_BLOCK = {
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"deepseek_v3": "DeepseekV3MoE",
"mistral4": "Mistral4MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
"afmoe": "AfmoeMoE",
"glm4_moe": "Glm4MoeDecoderLayer",
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
}

View File

@@ -67,7 +67,7 @@ class JsonToJsonlConverter:
self.json_parser = json_parser
self.jsonl_serializer = jsonl_serializer
def convert(self, input_file_path):
def convert(self, input_file_path, output_file_path):
content = self.file_reader.read(input_file_path)
data = self.json_parser.parse(content)
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations

View File

@@ -250,7 +250,7 @@ class TrainerBuilderBase(abc.ABC):
def _configure_precision_settings(self, training_args_kwargs: dict):
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
training_args_kwargs["tf32"] = True if self.cfg.tf32 is True else False
training_args_kwargs["tf32"] = self.cfg.tf32
if self.cfg.bf16 == "full":
training_args_kwargs["bf16_full_eval"] = True
else:
@@ -353,30 +353,6 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adamw":
from flashoptim import FlashAdamW
optimizer_cls = FlashAdamW
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_adam":
from flashoptim import FlashAdam
optimizer_cls = FlashAdam
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "flash_sgd":
from flashoptim import FlashSGD
optimizer_cls = FlashSGD
elif self.cfg.optimizer == "flash_sgdw":
from flashoptim import FlashSGDW
optimizer_cls = FlashSGDW
elif self.cfg.optimizer == "flash_lion":
from flashoptim import FlashLion
optimizer_cls = FlashLion
if "betas" in adam_kwargs:
optimizer_kwargs["betas"] = adam_kwargs["betas"]
else:
raise ValueError(
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
@@ -508,8 +484,6 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.layer_offloading:
training_args_kwargs["layer_offloading"] = True
if self.cfg.activation_offloading is True:
# don't use the HF gradient checkpointing, manually wrap
training_args_kwargs["gradient_checkpointing"] = False

View File

@@ -122,12 +122,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
if getattr(self.cfg, "generate_samples", False):
from axolotl.utils.callbacks.generation import SFTGenerationCallback
callbacks.append(SFTGenerationCallback(trainer))
LOG.info("SFT sample generation enabled")
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
@@ -252,8 +246,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters
)
if self.cfg.group_by_length:
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
@@ -421,13 +414,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
# config reflects this regardless of how the model was instantiated.
if (
self.cfg.reward_model
and getattr(self.model.config, "num_labels", None) != 1
):
self.model.config.num_labels = 1
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,

View File

@@ -11,6 +11,7 @@ from axolotl.core.trainers import (
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
@@ -52,18 +53,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
async_grpo = bool(
self.cfg.trl
and (
getattr(self.cfg.trl, "async_prefetch", False)
or getattr(self.cfg.trl, "use_data_producer", False)
)
)
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1,
async_grpo=async_grpo,
sequence_parallel=self.cfg.context_parallel_size > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
@@ -128,6 +119,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.use_wandb:
training_args_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_cls = None
blocklist_args_kwargs = []
if self.cfg.rl is RLType.SIMPO:
@@ -137,17 +133,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
blocklist_args_kwargs.append("max_prompt_length")
# Handle when max_prompt_length == max_length from defaults
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs.append("max_prompt_length")
blocklist_args_kwargs = ["max_prompt_length"]
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
@@ -157,18 +157,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
)
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
async_grpo = bool(
self.cfg.trl
and (
getattr(self.cfg.trl, "async_prefetch", False)
or getattr(self.cfg.trl, "use_data_producer", False)
)
)
training_args_cls = GRPOStrategy.get_training_args_class(
async_grpo=async_grpo
)
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:
@@ -208,11 +197,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.eval_dataset:
trainer_kwargs["eval_dataset"] = self.eval_dataset
if (
self.cfg.adapter
and self.peft_config
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
):
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
trainer_kwargs["precompute_ref_log_probs"] = (
@@ -238,36 +223,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_kwargs, trainer_cls
)
# Allow FP8-quantized models to be fine-tuned with LoRA adapters.
# transformers' validate_quantization_for_training blocks FP8 because
# hf_quantizer.is_trainable is False, but LoRA only trains the adapters
# (base weights stay frozen in FP8).
_orig_validate_quant = None
if (
self.cfg.adapter
and hasattr(self.model, "is_quantized")
and self.model.is_quantized
):
import transformers.trainer as _trainer_module
_orig_validate_quant = _trainer_module.validate_quantization_for_training
_trainer_module.validate_quantization_for_training = lambda model: None
try:
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
finally:
if _orig_validate_quant is not None:
import transformers.trainer as _trainer_module
_trainer_module.validate_quantization_for_training = (
_orig_validate_quant
)
trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
train_dataset=self.train_dataset,
callbacks=self.get_callbacks(),
**trainer_kwargs,
)
if self.cfg.fsdp_config or self.cfg.fsdp:
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:

View File

@@ -26,15 +26,13 @@ from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from trl.experimental.utils import pad_to_length
from trl.trainer.utils import pad_to_length
from typing_extensions import override
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
from axolotl.core.trainers.mixins import (
ActivationOffloadingMixin,
CheckpointSaveMixin,
DistributedParallelMixin,
LayerOffloadingMixin,
OptimizerMixin,
PackingMixin,
RngLoaderMixin,
@@ -53,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
TOKENS_STATE_FILE = "tokens_state."
REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
@@ -67,7 +67,6 @@ class AxolotlTrainer(
OptimizerMixin,
RngLoaderMixin,
CheckpointSaveMixin,
LayerOffloadingMixin,
ActivationOffloadingMixin,
DistributedParallelMixin,
Trainer,
@@ -720,20 +719,13 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
# fix for Context Parallel save: CP eval invalidates tensor storage
# pointers, so clone to CPU to get fresh valid storage for safetensors
if (
state_dict is not None
and self.axolotl_cfg
and self.axolotl_cfg.context_parallel_size
and self.axolotl_cfg.context_parallel_size > 1
):
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
state_dict = {
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
@@ -744,7 +736,6 @@ class AxolotlTrainer(
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes,
@@ -754,7 +745,6 @@ class AxolotlTrainer(
).save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
)
else:
LOG.info(
@@ -782,7 +772,11 @@ class AxolotlTrainer(
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -1 +0,0 @@
TOKENS_STATE_FILE = "tokens_state.json"

View File

@@ -25,13 +25,17 @@ class DPOStrategy:
# Label smoothing is not compatible with IPO
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_length"] = cfg.sequence_len
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
if cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
if cfg.dpo_padding_free is not None:
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
if cfg.dpo_norm_loss is not None:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
if cfg.dpo_use_liger_kernel is not None:
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
return training_args_kwargs

View File

@@ -2,8 +2,7 @@
Axolotl specific DPO args
"""
from dataclasses import dataclass, field
from typing import Optional
from dataclasses import dataclass
from trl import DPOConfig
@@ -17,4 +16,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
dpo_norm_loss: bool | None = False
rpo_alpha: Optional[float] = field(default=None)

View File

@@ -57,18 +57,16 @@ class AxolotlDPOTrainer(
def tokenize_row(
features,
processing_class,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
is_chat: bool = False,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length=max_prompt_length,
max_completion_length=max_completion_length,
add_special_tokens=add_special_tokens,
is_chat=is_chat,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
@@ -103,10 +101,10 @@ class AxolotlDPOTrainer(
) -> dict[str, torch.Tensor]:
if self.args.dpo_norm_loss:
# fmt: off
loss_type: list[str] = self.loss_type # type: ignore[has-type]
loss_type: str = self.loss_type # type: ignore[has-type]
# fmt: on
# concatenated_forward handles avg token logprob for ipo case already
self.loss_type = ["ipo"]
self.loss_type = "ipo"
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
self.loss_type = loss_type
return res

View File

@@ -9,9 +9,8 @@ from huggingface_hub import snapshot_download
from requests import HTTPError
from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
from axolotl.core.trainers.grpo.trainer import (
AxolotlAsyncGRPOTrainer,
AxolotlGRPOSequenceParallelTrainer,
AxolotlGRPOTrainer,
)
@@ -28,31 +27,14 @@ class GRPOStrategy:
@classmethod
def get_trainer_class(
cls,
sequence_parallel: bool,
async_grpo: bool = False,
) -> (
type[AxolotlGRPOTrainer]
| type[AxolotlGRPOSequenceParallelTrainer]
| type[AxolotlAsyncGRPOTrainer]
):
if sequence_parallel and async_grpo:
raise ValueError(
"sequence_parallel and async_grpo cannot both be enabled. "
"Disable one of context_parallel_size > 1 or async_prefetch/use_data_producer."
)
cls, sequence_parallel: bool
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
if sequence_parallel:
return AxolotlGRPOSequenceParallelTrainer
if async_grpo:
return AxolotlAsyncGRPOTrainer
return AxolotlGRPOTrainer
@classmethod
def get_training_args_class(
cls, async_grpo: bool = False
) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
if async_grpo:
return AxolotlAsyncGRPOConfig
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
return AxolotlGRPOConfig
@classmethod
@@ -142,63 +124,13 @@ class GRPOStrategy:
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
)
# Async GRPO fields
if getattr(trl, "use_data_producer", None) is not None:
grpo_args_kwargs["use_data_producer"] = trl.use_data_producer
if getattr(trl, "async_prefetch", None) is not None:
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
if getattr(trl, "prefetch_depth", None) is not None:
grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
if getattr(trl, "vllm_sync_interval", None) is not None:
grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
if getattr(trl, "streaming_partial_batch", None) is not None:
grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
if getattr(trl, "streaming_min_groups", None) is not None:
grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_correction"] = (
trl.vllm_importance_sampling_correction
)
if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_mode"] = (
trl.vllm_importance_sampling_mode
)
if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
grpo_args_kwargs["vllm_importance_sampling_cap"] = (
trl.vllm_importance_sampling_cap
)
if getattr(trl, "off_policy_mask_threshold", None) is not None:
grpo_args_kwargs["off_policy_mask_threshold"] = (
trl.off_policy_mask_threshold
)
if getattr(trl, "use_bias_correction_kl", None) is not None:
grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
# Fast Async GRPO fields
if getattr(trl, "reward_num_workers", None) is not None:
grpo_args_kwargs["reward_num_workers"] = trl.reward_num_workers
if getattr(trl, "replay_buffer_size", None) is not None:
grpo_args_kwargs["replay_buffer_size"] = trl.replay_buffer_size
if getattr(trl, "replay_recompute_logps", None) is not None:
grpo_args_kwargs["replay_recompute_logps"] = trl.replay_recompute_logps
if getattr(trl, "reroll_start_fraction", None) is not None:
grpo_args_kwargs["reroll_start_fraction"] = trl.reroll_start_fraction
if getattr(trl, "reroll_max_groups", None) is not None:
grpo_args_kwargs["reroll_max_groups"] = trl.reroll_max_groups
if getattr(trl, "skip_zero_advantage_batches", None) is not None:
grpo_args_kwargs["skip_zero_advantage_batches"] = (
trl.skip_zero_advantage_batches
)
if getattr(trl, "vllm_lora_sync", None) is not None:
grpo_args_kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
return grpo_args_kwargs
@classmethod

View File

@@ -6,7 +6,6 @@ from dataclasses import dataclass
from trl import GRPOConfig
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
from axolotl.core.training_args import AxolotlTrainingMixins
@@ -15,10 +14,3 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
"""Axolotl GRPO Config for GRPO training"""
context_parallel_size: int | None = None
@dataclass
class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig):
"""Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction."""
context_parallel_size: int | None = None

File diff suppressed because it is too large Load Diff

View File

@@ -1,768 +0,0 @@
# Copyright 2020-2026 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Experimental GRPO extensions: parallel reward workers, replay buffer,
deferred re-roll, and zero-advantage skipping.
These features are built as subclasses of GRPOTrainer and GRPODataProducer,
using the hook system (_compute_rewards_for_batch, _post_advantage_hook,
_pre_produce_hook) defined in the base classes.
"""
from __future__ import annotations
import asyncio
import logging
import threading
from dataclasses import dataclass, field
import torch
from torch import nn
from trl import GRPOTrainer
from axolotl.core.trainers.grpo.async_trainer import (
AsyncGRPOConfig,
AsyncGRPOTrainer,
GRPODataProducer,
)
from axolotl.core.trainers.grpo.replay_buffer import ReplayBuffer
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Extended config
# ---------------------------------------------------------------------------
@dataclass
class FastAsyncGRPOConfig(AsyncGRPOConfig):
"""GRPOConfig with additional experimental parameters."""
reward_num_workers: int = field(
default=1,
metadata={
"help": "Number of persistent subprocess workers for parallel reward computation. Each worker has its "
"own main thread so signal.alarm() (used by math_verify) works correctly. Work is sharded across "
"workers by prompt groups. Only used with use_data_producer=True and non-nn.Module reward functions."
},
)
replay_buffer_size: int = field(
default=0,
metadata={
"help": "[Experimental, disabled by default] Size of the replay buffer for storing high-signal rollout "
"groups. When > 0, groups with reward variance are cached and used to replace zero-signal groups "
"(where all rewards are identical). Set to 0 to disable. Only used with use_data_producer=True."
},
)
replay_recompute_logps: bool = field(
default=True,
metadata={
"help": "When True (default), recompute old_per_token_logps for replayed groups using the current "
"training model. This fixes the importance sampling mismatch that occurs when replaying stale data. "
"Only relevant when replay_buffer_size > 0."
},
)
reroll_start_fraction: float = field(
default=0.5,
metadata={
"help": "Fraction of total training steps after which deferred re-rolling begins. Zero-signal prompts "
"(where all rewards in a group are identical) are buffered and re-injected into later batches when the "
"model is more likely to solve them. Set to 1.0 to disable. Only used with use_data_producer=True."
},
)
reroll_max_groups: int = field(
default=1,
metadata={
"help": "Maximum number of prompt groups to replace with re-roll candidates per batch. Higher values "
"increase data utilization but reduce prompt diversity. Only used with use_data_producer=True."
},
)
skip_zero_advantage_batches: bool = field(
default=True,
metadata={
"help": "When True, skip gradient computation for micro-batches where all advantages are zero (no learning "
"signal). This avoids the forward/backward pass entirely when no learning signal is present. The step is "
"logged with skipped_zero_adv_batches=1 for monitoring."
},
)
vllm_lora_sync: bool = field(
default=False,
metadata={
"help": "When True, sync LoRA adapter weights to vLLM via filesystem instead of merging into base model "
"and NCCL-broadcasting all parameters. vLLM loads the adapter natively using Punica kernels. "
"Requires vllm_serve_lora serve module (auto-selected when this is True). "
"Syncs only LoRA adapter weights (much smaller) vs full merged model. Legacy merge behavior is used when False."
},
)
# ---------------------------------------------------------------------------
# Extended data producer with re-roll injection
# ---------------------------------------------------------------------------
class RerollDataProducer(GRPODataProducer):
"""GRPODataProducer that injects re-roll candidates into prompt batches.
Reads from the trainer's ``_reroll_buffer`` (populated by
``GRPOExperimentalTrainer._post_advantage_hook``) and replaces the
last N prompt groups with previously-failed prompts.
"""
def _pre_produce_hook(self, inputs: list, global_step: int) -> list:
trainer = self._trainer
reroll_buf = getattr(trainer, "_reroll_buffer", None)
reroll_lock = getattr(trainer, "_reroll_lock", None)
if reroll_buf is None or reroll_lock is None:
return inputs
max_steps = getattr(trainer.args, "max_steps", -1)
start_frac = getattr(trainer.args, "reroll_start_fraction", 1.0)
max_groups = getattr(trainer.args, "reroll_max_groups", 1)
reroll_start_step = (
max(1, int(max_steps * start_frac)) if max_steps > 0 else float("inf")
)
if global_step < reroll_start_step:
return inputs
with reroll_lock:
n_to_take = min(max_groups, len(reroll_buf))
reroll_prompts = [reroll_buf.pop(0) for _ in range(n_to_take)]
if reroll_prompts:
num_gen = self._num_generations
n_groups = len(inputs) // num_gen
for i, reroll_prompt in enumerate(reroll_prompts):
group_idx = n_groups - 1 - i
if group_idx < 0:
break
start = group_idx * num_gen
for j in range(num_gen):
inputs[start + j] = reroll_prompt
logger.info(
f"[REROLL] Step {global_step}: replaced {len(reroll_prompts)}/{n_groups} prompt groups "
f"with deferred re-roll candidates ({len(reroll_buf)} remaining)"
)
return inputs
# ---------------------------------------------------------------------------
# Persistent reward subprocess pool
# ---------------------------------------------------------------------------
def _persistent_reward_worker(conn):
"""Long-lived reward worker. Receives work items, returns results."""
while True:
try:
msg = conn.recv()
except EOFError:
break
if msg is None: # Shutdown signal
break
(
reward_funcs,
prompts,
completions,
completion_ids_list,
inputs,
reward_func_names,
) = msg
try:
keys = [
key
for key in inputs[0]
if key not in ["prompt", "completion", "completion_ids"]
]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
results = []
for reward_func, _reward_func_name in zip(
reward_funcs, reward_func_names, strict=True
):
output = reward_func(
prompts=prompts,
completions=completions,
completion_ids=completion_ids_list,
**reward_kwargs,
)
results.append(
[float(r) if r is not None else float("nan") for r in output]
)
conn.send(results)
except Exception:
conn.send(None)
# ---------------------------------------------------------------------------
# Extended trainer
# ---------------------------------------------------------------------------
class FastAsyncGRPOTrainer(AsyncGRPOTrainer):
"""GRPOTrainer with experimental extensions.
Adds:
- Parallel reward subprocess workers (``reward_num_workers``)
- Replay buffer for high-signal group reuse (``replay_buffer_size``)
- Deferred re-roll of failed prompts (``reroll_start_fraction``)
- Zero-advantage micro-batch skipping
"""
def __init__(self, *args, **kwargs):
# These must be initialized before super().__init__() because
# _create_data_producer (called during super().__init__) needs them.
self._reroll_buffer: list = []
self._reroll_lock = threading.Lock()
# Temporarily suppress the base class's Liger + OPSM validation check,
# since this subclass supports it via a custom compute_liger_loss override.
grpo_args = kwargs.get("args")
if grpo_args is None:
for a in args:
if hasattr(a, "off_policy_mask_threshold"):
grpo_args = a
break
saved_threshold = None
if grpo_args is not None and getattr(grpo_args, "use_liger_kernel", False):
saved_threshold = grpo_args.off_policy_mask_threshold
grpo_args.off_policy_mask_threshold = None
super().__init__(*args, **kwargs)
if saved_threshold is not None:
grpo_args.off_policy_mask_threshold = saved_threshold
self.off_policy_mask_threshold = saved_threshold
# Replay buffer
if getattr(self.args, "replay_buffer_size", 0) > 0:
self._replay_buffer = ReplayBuffer(max_size=self.args.replay_buffer_size)
else:
self._replay_buffer = None
self._replay_recompute_logps = getattr(
self.args, "replay_recompute_logps", True
)
# Reward worker pool (lazy-initialized)
self._reward_workers = None
# -- Factory override: use RerollDataProducer ----------------------------
def _create_data_producer(self, args, train_dataset):
"""Override to use RerollDataProducer for re-roll prompt injection."""
from axolotl.core.trainers.grpo.async_trainer import (
AsyncDataProducer,
ProducerConfig,
)
producer_config = ProducerConfig(
mini_epochs=args.num_iterations,
max_rollouts=None,
eval_during_produce=False,
empty_cache_before_produce=True,
empty_cache_after_produce=True,
async_prefetch=args.async_prefetch,
prefetch_depth=args.prefetch_depth,
)
data_producer = RerollDataProducer(
config=producer_config,
prompt_dataset=train_dataset,
num_generations=self.num_generations,
generation_batch_size=args.generation_batch_size,
train_batch_size=args.per_device_train_batch_size,
steps_per_generation=args.steps_per_generation,
shuffle_dataset=self.shuffle_dataset,
seed=args.seed,
)
data_producer.set_trainer(self)
if args.async_prefetch:
data_producer = AsyncDataProducer(
data_producer,
background_produce_kwargs={"skip_policy_logps": True},
)
return data_producer
# -- Reward worker pool --------------------------------------------------
def _get_reward_workers(self):
"""Return a list of persistent reward worker subprocesses (lazy-initialized)."""
import multiprocessing as _mp
num_workers = getattr(self.args, "reward_num_workers", 1)
if num_workers < 1:
num_workers = 1
if self._reward_workers is not None:
alive = all(proc.is_alive() for conn, proc in self._reward_workers)
if alive and len(self._reward_workers) == num_workers:
return self._reward_workers
self._shutdown_reward_workers()
workers = []
for _ in range(num_workers):
parent_conn, child_conn = _mp.Pipe()
proc = _mp.Process(
target=_persistent_reward_worker, args=(child_conn,), daemon=True
)
proc.start()
child_conn.close()
workers.append((parent_conn, proc))
self._reward_workers = workers
return workers
def _shutdown_reward_workers(self):
"""Shut down all persistent reward workers."""
if self._reward_workers is None:
return
for conn, proc in self._reward_workers:
try:
conn.send(None)
proc.join(timeout=5)
except Exception:
pass
try:
conn.close()
except Exception:
pass
self._reward_workers = None
# -- Hook overrides ------------------------------------------------------
def _compute_rewards_for_batch(
self, inputs, prompts, completions, completion_ids_list
):
"""Dispatch rewards to parallel subprocess workers (synchronous wrapper)."""
self._launch_reward_workers(inputs, prompts, completions, completion_ids_list)
return self._collect_reward_workers(
inputs, prompts, completions, completion_ids_list
)
def _launch_reward_workers(self, inputs, prompts, completions, completion_ids_list):
"""Send reward work to subprocess workers (non-blocking).
Results are collected later by _collect_reward_workers, allowing GPU
logprob computation to overlap with CPU reward computation.
"""
reward_can_bg = all(
callable(rf)
and not isinstance(rf, nn.Module)
and not asyncio.iscoroutinefunction(rf)
for rf in self.reward_funcs
)
num_workers = getattr(self.args, "reward_num_workers", 1)
if not reward_can_bg or num_workers <= 1:
# Can't parallelize — store args for sync fallback in collect
self._reward_workers_used = None
self._pending_reward_args = (
inputs,
prompts,
completions,
completion_ids_list,
)
return
workers = self._get_reward_workers()
num_generations = self.num_generations
num_prompts = len(prompts)
num_groups = num_prompts // num_generations
# Shard by prompt groups across workers
groups_per_worker = max(1, (num_groups + len(workers) - 1) // len(workers))
workers_used = []
for w_idx, (conn, _proc) in enumerate(workers):
g_start = w_idx * groups_per_worker
g_end = min((w_idx + 1) * groups_per_worker, num_groups)
if g_start >= num_groups:
break
s_start = g_start * num_generations
s_end = g_end * num_generations
conn.send(
(
self.reward_funcs,
prompts[s_start:s_end],
completions[s_start:s_end],
completion_ids_list[s_start:s_end],
inputs[s_start:s_end],
self.reward_func_names,
)
)
workers_used.append(conn)
self._reward_workers_used = workers_used
self._pending_reward_args = (inputs, prompts, completions, completion_ids_list)
def _collect_reward_workers(
self, inputs, prompts, completions, completion_ids_list
):
"""Collect reward results from subprocess workers (blocks until done)."""
from accelerate.utils import gather
workers_used = getattr(self, "_reward_workers_used", None)
args = getattr(self, "_pending_reward_args", None)
self._reward_workers_used = None
self._pending_reward_args = None
if workers_used is None:
# Sync fallback — compute on main thread
if args is not None:
return self._calculate_rewards(*args)
return self._calculate_rewards(
inputs, prompts, completions, completion_ids_list
)
device = self.accelerator.device
num_prompts = len(args[1]) if args else len(prompts)
# Collect results from workers
all_worker_results = []
any_failed = False
for conn in workers_used:
result = conn.recv()
if result is None:
any_failed = True
# Drain remaining workers to prevent stale results in pipes
for remaining_conn in workers_used:
if remaining_conn is not conn:
try:
remaining_conn.recv()
except Exception:
pass
break
all_worker_results.append(result)
if not any_failed:
rewards_per_func = torch.zeros(
num_prompts, len(self.reward_funcs), device=device
)
offset = 0
for worker_result in all_worker_results:
chunk_size = len(worker_result[0])
for i, result in enumerate(worker_result):
rewards_per_func[offset : offset + chunk_size, i] = torch.tensor(
result, dtype=torch.float32, device=device
)
offset += chunk_size
return gather(rewards_per_func)
# Fallback to main thread on failure
if args is not None:
return self._calculate_rewards(*args)
return self._calculate_rewards(
inputs, prompts, completions, completion_ids_list
)
def _post_advantage_hook(
self,
data: dict,
rewards_per_func,
advantages,
inputs: list,
num_generations: int,
mode: str,
s_start: int | None = None,
s_end: int | None = None,
is_last_chunk: bool = True,
) -> None:
"""Replay buffer store/replace + re-roll buffering."""
from trl.models.utils import disable_gradient_checkpointing
# -- Replay buffer: store high-signal groups --
if self._replay_buffer is not None:
local_grouped = rewards_per_func.view(
-1, num_generations, len(self.reward_funcs)
)
per_group_std = local_grouped.std(dim=1)
has_signal = (per_group_std > 0).any(dim=1)
offset = s_start or 0
if has_signal.any():
grouped_adv = advantages.view(-1, num_generations)
replay_scores = grouped_adv.abs().sum(dim=1) * per_group_std.sum(dim=1)
for group_idx in has_signal.nonzero(as_tuple=True)[0]:
gi = group_idx.item()
start = offset + gi * num_generations
end = start + num_generations
group_data = {}
for key in data:
val = data[key]
if (
isinstance(val, torch.Tensor)
and val.dim() > 0
and val.size(0) >= end
):
group_data[key] = val[start:end].clone()
self._replay_buffer.add(replay_scores[gi].item(), group_data)
# Replace zero-signal groups with high-signal replay buffer entries
# Only in non-streaming path (s_start is None) — streaming scores
# groups incrementally, so replacement + logprob recompute would be
# too expensive per chunk.
n_replaced = 0
if s_start is None:
no_signal = ~has_signal
replaced_ranges = []
if no_signal.any() and len(self._replay_buffer) > 0:
for group_idx in no_signal.nonzero(as_tuple=True)[0]:
sampled = self._replay_buffer.sample(1)
if sampled is None:
break
sampled_group = sampled[0]
gi = group_idx.item()
start = offset + gi * num_generations
end = start + num_generations
for key, val in sampled_group.items():
if key in data and isinstance(data[key], torch.Tensor):
src = val.to(data[key].device)
tgt_seq_len = (
data[key].size(1) if data[key].dim() > 1 else None
)
if start >= data[key].size(0) or end > data[key].size(
0
):
continue
if tgt_seq_len is not None:
if src.size(1) <= tgt_seq_len:
data[key][start:end] = 0
data[key][start:end, : src.size(1)] = src
else:
data[key][start:end] = src[:, :tgt_seq_len]
else:
data[key][start:end] = src
replaced_ranges.append((start, end))
n_replaced += 1
# Recompute old_per_token_logps for replayed groups
if (
n_replaced > 0
and self._replay_recompute_logps
and "old_per_token_logps" in data
):
with (
torch.no_grad(),
disable_gradient_checkpointing(
self.model, self.args.gradient_checkpointing_kwargs
),
):
for r_start, r_end in replaced_ranges:
r_ids = torch.cat(
[
data["prompt_ids"][r_start:r_end],
data["completion_ids"][r_start:r_end],
],
dim=1,
)
r_mask = torch.cat(
[
data["prompt_mask"][r_start:r_end],
data["completion_mask"][r_start:r_end],
],
dim=1,
)
r_logits_to_keep = data["completion_ids"].size(1)
r_fwd_kwargs = {}
for fk in (
"pixel_values",
"image_grid_thw",
"pixel_attention_mask",
"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):
if fk in data:
r_fwd_kwargs[fk] = data[fk]
r_logps, _ = self._get_per_token_logps_and_entropies(
self.model,
r_ids,
r_mask,
r_logits_to_keep,
r_end - r_start,
**r_fwd_kwargs,
)
data["old_per_token_logps"][r_start:r_end] = r_logps
if n_replaced > 0:
self._metrics[mode]["replay_buffer_replacements"].append(
float(n_replaced)
)
if is_last_chunk:
self._metrics[mode]["replay_buffer_size"].append(
float(len(self._replay_buffer))
)
# -- Re-roll buffer: store failed prompts --
if getattr(self.args, "reroll_start_fraction", 1.0) < 1.0:
grouped_rewards = rewards_per_func.view(
-1, num_generations, len(self.reward_funcs)
)
per_group_std = grouped_rewards.std(dim=1)
per_group_mean = grouped_rewards.mean(dim=1)
zero_signal = (per_group_std == 0).all(dim=1)
all_failed = (per_group_mean.abs() < 1e-6).all(dim=1)
should_reroll = zero_signal & all_failed
_n_buffered = 0
with self._reroll_lock:
for group_idx in should_reroll.nonzero(as_tuple=True)[0]:
idx = group_idx.item() * num_generations
if idx >= len(inputs):
continue
prompt_input = inputs[idx]
self._reroll_buffer.append(prompt_input)
_n_buffered += 1
if _n_buffered > 0:
self._metrics[mode]["reroll_buffered"].append(float(_n_buffered))
if is_last_chunk:
self._metrics[mode]["reroll_buffer_size"].append(
float(len(self._reroll_buffer))
)
# -- Zero-advantage skipping + Liger OPSM ---------------------------------
def compute_liger_loss(self, unwrapped_model, inputs):
"""Liger loss with zero-adv skipping and off-policy sequence masking (OPSM).
The base class Liger path doesn't support OPSM because the fused kernel
doesn't expose per-token logprobs needed for the KL computation. This
override computes them via chunked lm_head matmul (no grad, low memory)
and applies the OPSM to the loss mask before calling the kernel.
"""
if self.args.skip_zero_advantage_batches and torch.all(
inputs["advantages"] == 0
):
mode = "train" if self.model.training else "eval"
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
return torch.tensor(
0.0, device=inputs["advantages"].device, requires_grad=True
)
if self.off_policy_mask_threshold is None:
return super().compute_liger_loss(unwrapped_model, inputs)
# OPSM path: need per_token_logps for KL, which Liger kernel doesn't provide
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = (
inputs["completion_ids"],
inputs["completion_mask"],
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
last_hidden_state = self._get_last_hidden_state(
unwrapped_model,
input_ids,
attention_mask,
logits_to_keep,
inputs.get("pixel_values"),
inputs.get("image_grid_thw"),
inputs.get("pixel_attention_mask"),
inputs.get("image_sizes"),
)
loss_mask = (
completion_mask
if "tool_mask" not in inputs
else completion_mask * inputs["tool_mask"]
)
# Compute per_token_logps via chunked lm_head matmul (no grad, low memory)
lm_weight = unwrapped_model.lm_head.weight
lm_bias = unwrapped_model.lm_head.bias
with torch.no_grad():
per_token_logps_chunks = []
for i in range(last_hidden_state.size(0)):
chunk_logits = torch.matmul(last_hidden_state[i : i + 1], lm_weight.t())
if lm_bias is not None:
chunk_logits = chunk_logits + lm_bias
chunk_lps = (
chunk_logits.float()
.log_softmax(-1)
.gather(-1, completion_ids[i : i + 1].unsqueeze(-1))
.squeeze(-1)
)
per_token_logps_chunks.append(chunk_lps)
del chunk_logits
per_token_logps = torch.cat(per_token_logps_chunks, dim=0)
advantages = inputs["advantages"]
if advantages.dim() == 1:
advantages_2d = advantages.unsqueeze(1)
else:
advantages_2d = advantages
sampling_per_token_logps = inputs.get("sampling_per_token_logps")
if sampling_per_token_logps is None:
sampling_per_token_logps = inputs.get("old_per_token_logps")
if sampling_per_token_logps is None:
sampling_per_token_logps = per_token_logps
off_policy_mask = GRPOTrainer.get_off_policy_mask(
advantages=advantages_2d,
per_token_logps=per_token_logps,
sampling_per_token_logps=sampling_per_token_logps,
mask=loss_mask,
off_policy_threshold=self.off_policy_mask_threshold,
)
loss_mask = loss_mask * off_policy_mask
# Call the Liger fused kernel with OPSM-modified mask
loss, metrics = self.liger_grpo_loss(
_input=last_hidden_state,
lin_weight=unwrapped_model.lm_head.weight,
selected_token_ids=completion_ids,
attention_mask=loss_mask,
advantages=inputs["advantages"],
bias=unwrapped_model.lm_head.bias,
old_per_token_logps=inputs.get("old_per_token_logps"),
ref_per_token_logps=inputs.get("ref_per_token_logps"),
vllm_is_ratio=inputs.get("importance_sampling_ratio"),
)
mean_kl = metrics[0] if self.beta != 0.0 else None
clip_ratio = metrics[-1]
mode = "train" if self.model.training else "eval"
if self.beta != 0.0:
self._metrics[mode]["kl"].append(
self.accelerator.gather(mean_kl).mean().item()
)
self._metrics[mode]["clip_ratio"].append(
self.accelerator.gather(clip_ratio).mean().item()
)
normalizer = (
self.current_gradient_accumulation_steps if mode == "train" else 1.0
)
return loss / normalizer
def _compute_loss(self, model, inputs):
if self.args.skip_zero_advantage_batches and torch.all(
inputs["advantages"] == 0
):
mode = "train" if self.model.training else "eval"
self._metrics[mode]["skipped_zero_adv_batches"].append(1.0)
# Create zero loss with grad_fn. DeepSpeed requires grad_fn != None.
# With ZeRO-3, parameters are partitioned (shape=[0], requires_grad=False)
# so we can't just do `(p * 0).sum()`. Instead, do a tiny forward pass
# with a single token to create a proper computation graph.
prompt_ids = inputs["prompt_ids"][:1, :1] # (1, 1)
attn = torch.ones_like(prompt_ids)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
out = model(input_ids=prompt_ids, attention_mask=attn)
return out.logits.sum() * 0
return super()._compute_loss(model, inputs)

View File

@@ -1,44 +0,0 @@
"""Simple replay buffer for storing and sampling high-signal rollout groups."""
import heapq
import torch
class ReplayBuffer:
"""Min-heap replay buffer that keeps the highest-scoring rollout groups.
Groups are scored by signal quality (advantage magnitude * reward variance).
When sampling, groups are drawn proportional to their scores.
"""
def __init__(self, max_size: int):
self.max_size = max_size
self._heap: list[tuple[float, int, dict]] = [] # min-heap of (score, id, data)
self._counter = 0 # unique tiebreaker for heap
def __len__(self):
return len(self._heap)
def add(self, score: float, data: dict):
"""Add a group to the buffer. If full, replaces lowest-scoring entry."""
if self.max_size <= 0:
return
self._counter += 1
if len(self._heap) < self.max_size:
heapq.heappush(self._heap, (score, self._counter, data))
elif score > self._heap[0][0]:
heapq.heapreplace(self._heap, (score, self._counter, data))
def sample(self, num_samples: int) -> list[dict] | None:
"""Sample groups weighted by their scores. Returns None if buffer is empty."""
if self.max_size <= 0 or not self._heap:
return None
scores = torch.tensor([item[0] for item in self._heap], dtype=torch.float32)
scores = scores.clamp(min=1e-8) # avoid zero probabilities
probs = scores / scores.sum()
replacement = num_samples > len(self._heap)
indices = torch.multinomial(
probs, num_samples, replacement=replacement
).tolist()
return [self._heap[i][2] for i in indices]

View File

@@ -40,7 +40,6 @@ from trl.trainer.grpo_config import GRPOConfig
from trl.trainer.grpo_trainer import RewardFunc, nanstd
from trl.trainer.utils import pad
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOTrainer
from axolotl.core.trainers.grpo.sampler import SequenceParallelRepeatRandomSampler
from axolotl.core.trainers.mixins import (
DistributedParallelMixin,
@@ -67,19 +66,6 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
class AxolotlAsyncGRPOTrainer(
RngLoaderMixin,
SchedulerMixin,
OptimizerMixin,
OptimizerInitMixin,
DistributedParallelMixin,
FastAsyncGRPOTrainer,
):
"""Extend AsyncGRPOTrainer with axolotl helpers"""
_tag_names = ["trl", "grpo", "async", "axolotl"]
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -4,7 +4,6 @@
from .activation_checkpointing import ActivationOffloadingMixin
from .checkpoints import CheckpointSaveMixin
from .layer_offloading import LayerOffloadingMixin
from .distributed_parallel import DistributedParallelMixin
from .optimizer import OptimizerMixin
from .packing import PackingMixin

View File

@@ -19,4 +19,5 @@ class CheckpointSaveMixin(Trainer):
f"Trainer does not support saving optimizer and scheduler: {exc}\n"
"Optimizer and scheduler states were not saved - resuming from checkpoints "
"for this training run will not be possible.",
main_process_only=True,
)

View File

@@ -1,304 +0,0 @@
"""
Trainer mixin for layer-wise parameter offloading to CPU.
Offloads frozen (non-trainable) parameters in decoder layers to CPU, then uses
forward/backward hooks to stream them on/off GPU one layer at a time with CUDA
stream prefetching. Trainable parameters (e.g. LoRA weights) stay on GPU always.
Forward: pre-hook loads layer N's frozen params to GPU (prefetches N+1 on
transfer stream), post-hook offloads layer N-1's frozen params.
Backward: same in reverse order.
"""
import contextlib
import torch
import torch.nn as nn
from transformers import Trainer
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def _find_decoder_layers(model: nn.Module) -> tuple[nn.ModuleList | None, list[str]]:
"""Recursively search the model for the decoder layer ModuleList.
Finds any ModuleList whose children have 'DecoderLayer' in their class name.
Handles all common HF architectures including VLM wrappers (e.g. Qwen3.5-MoE
where layers are at model.language_model.layers).
"""
# BFS to find the first ModuleList containing decoder layers
queue = [model]
while queue:
m = queue.pop(0)
for _name, child in m.named_children():
if isinstance(child, nn.ModuleList) and len(child) > 0:
first_type = type(child[0]).__name__
if "DecoderLayer" in first_type or "TransformerBlock" in first_type:
layer_types = list({type(layer).__name__ for layer in child})
return child, layer_types
else:
queue.append(child)
return None, []
def _get_frozen_params(layer: nn.Module) -> list[tuple[str, nn.Parameter]]:
"""Get all non-trainable parameters in a layer."""
return [(n, p) for n, p in layer.named_parameters() if not p.requires_grad]
class LayerOffloadManager:
"""Manages offloading frozen decoder layer params to CPU and streaming
them back during forward/backward with CUDA stream overlap.
Only frozen (requires_grad=False) parameters are offloaded.
Trainable parameters (LoRA weights, etc.) remain on GPU at all times.
"""
def __init__(
self,
model: nn.Module,
num_prefetch: int = 1,
):
self.model = model
self.num_prefetch = num_prefetch
self._hooks: list = []
self._device = None
# Find decoder layers
self.layers, layer_types = _find_decoder_layers(model)
if self.layers is None:
LOG.warning(
"LayerOffloadManager: no decoder layers found, offloading disabled"
)
self.enabled = False
return
self.enabled = True
self.n_layers = len(self.layers)
LOG.info(
f"Layer offloading: found {self.n_layers} layers ({', '.join(layer_types)})"
)
# Determine GPU device
for p in model.parameters():
if p.device.type == "cuda":
self._device = p.device
break
if self._device is None:
LOG.warning("LayerOffloadManager: no CUDA parameters found")
self.enabled = False
return
# Transfer stream for async prefetch
self._transfer_stream = torch.cuda.Stream(device=self._device)
# Track which layers have their frozen params on GPU
self._on_gpu: set[int] = set(range(self.n_layers))
# Cache: frozen param references per layer (list of (name, param) tuples)
self._frozen_params: list[list[tuple[str, nn.Parameter]]] = [
_get_frozen_params(self.layers[i]) for i in range(self.n_layers)
]
# CPU storage: pinned tensors for each layer's frozen params
# Populated on first offload
self._cpu_data: list[dict[str, torch.Tensor]] = [
{} for _ in range(self.n_layers)
]
# Offload all layers upfront
self._offload_all()
# Release cached memory blocks back to the driver
torch.cuda.empty_cache()
def _offload_all(self):
"""Move all frozen params in all decoder layers to CPU."""
mem_before = torch.cuda.memory_allocated(self._device)
for i in range(self.n_layers):
self._offload_layer(i)
mem_after = torch.cuda.memory_allocated(self._device)
freed = (mem_before - mem_after) / 1e6
LOG.info(
f"Layer offloading: offloaded frozen params from {self.n_layers} layers, "
f"freed {freed:.0f} MB GPU memory"
)
def _offload_layer(self, idx: int):
"""Move frozen params of layer idx to CPU pinned memory."""
if idx not in self._on_gpu:
return
for name, param in self._frozen_params[idx]:
if param.device.type != "cuda":
continue
# Allocate pinned CPU tensor on first offload
if name not in self._cpu_data[idx]:
self._cpu_data[idx][name] = torch.empty_like(
param.data, device="cpu", pin_memory=True
)
cpu_buf = self._cpu_data[idx][name]
# Async copy GPU -> CPU (on transfer stream for overlap)
cpu_buf.copy_(param.data, non_blocking=True)
# Point parameter at a dummy CPU tensor to free GPU memory
param.data = cpu_buf
self._on_gpu.discard(idx)
def _load_layer(self, idx: int, stream=None):
"""Move frozen params of layer idx back to GPU."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
ctx = (
torch.cuda.stream(stream)
if stream is not None
else contextlib.nullcontext()
)
with ctx:
for _name, param in self._frozen_params[idx]:
if param.device.type == "cuda":
continue
gpu_data = param.data.to(self._device, non_blocking=True)
param.data = gpu_data
self._on_gpu.add(idx)
def _prefetch_layer(self, idx: int):
"""Async prefetch layer idx on the transfer stream."""
if idx in self._on_gpu or idx < 0 or idx >= self.n_layers:
return
self._transfer_stream.wait_stream(torch.cuda.default_stream(self._device))
self._load_layer(idx, stream=self._transfer_stream)
def _wait_transfer(self):
"""Make default stream wait for any in-flight transfers."""
torch.cuda.default_stream(self._device).wait_stream(self._transfer_stream)
def setup_hooks(self):
"""Register forward and backward hooks on each decoder layer."""
if not self.enabled:
return
for idx in range(self.n_layers):
layer = self.layers[idx]
def make_pre_fwd(i):
def hook(module, args):
# Ensure this layer is on GPU
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch next layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i + offset)
return hook
def make_post_fwd(i):
def hook(module, args, output):
# Offload previous layer (no longer needed in forward)
if i > 0:
self._offload_layer(i - 1)
# Offload last layer after forward
if i == self.n_layers - 1:
self._offload_layer(i)
return hook
def make_pre_bwd(i):
def hook(module, grad_output):
# Load this layer for backward
if i not in self._on_gpu:
self._load_layer(i)
self._wait_transfer()
# Prefetch previous layer(s)
for offset in range(1, self.num_prefetch + 1):
self._prefetch_layer(i - offset)
return hook
def make_post_bwd(i):
def hook(module, grad_input, grad_output):
# Offload the layer above
if i < self.n_layers - 1:
self._offload_layer(i + 1)
# Offload first layer after backward
if i == 0:
self._offload_layer(i)
return hook
h1 = layer.register_forward_pre_hook(make_pre_fwd(idx))
h2 = layer.register_forward_hook(make_post_fwd(idx))
h3 = layer.register_full_backward_pre_hook(make_pre_bwd(idx))
h4 = layer.register_full_backward_hook(make_post_bwd(idx))
self._hooks.extend([h1, h2, h3, h4])
def remove_hooks(self):
"""Remove all hooks and restore layers to GPU."""
for h in self._hooks:
h.remove()
self._hooks.clear()
if self.enabled:
for i in range(self.n_layers):
if i not in self._on_gpu:
self._load_layer(i)
def pre_step(self):
"""Called before each training step — ensure layers start offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for forward
self._prefetch_layer(0)
def post_step(self):
"""Called after each training step — ensure layers are offloaded."""
if not self.enabled:
return
for i in list(self._on_gpu):
self._offload_layer(i)
# Prefetch layer 0 for next step
self._prefetch_layer(0)
class _LayerOffloadContext:
"""Context manager wrapping pre_step / post_step around a training step."""
def __init__(self, manager: LayerOffloadManager):
self.manager = manager
def __enter__(self):
self.manager.pre_step()
return self
def __exit__(self, *args):
self.manager.post_step()
class LayerOffloadingMixin(Trainer):
"""
Trainer mixin class for layer-wise parameter offloading to CPU.
Offloads frozen decoder layer params to CPU at init, then streams them
on/off GPU one layer at a time during each training step.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if getattr(self.args, "layer_offloading", False):
LOG.info("Layer parameter offloading enabled")
self._layer_offload_manager = LayerOffloadManager(
model=self.model,
num_prefetch=1,
)
self._layer_offload_manager.setup_hooks()
self._layer_offload_ctx = _LayerOffloadContext(self._layer_offload_manager)
else:
self._layer_offload_manager = None
self._layer_offload_ctx = contextlib.nullcontext()
def training_step(self, *args, **kwargs):
with self._layer_offload_ctx:
return super().training_step(*args, **kwargs)

View File

@@ -104,7 +104,7 @@ class OptimizerMixin(Trainer):
return optimizer_grouped_parameters
def create_optimizer(self, model=None):
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
@@ -112,9 +112,9 @@ class OptimizerMixin(Trainer):
and self.args.lr_groups is None
and self.optimizer_cls_and_kwargs is None
):
return super().create_optimizer(model=model)
return super().create_optimizer()
opt_model = self.model if model is None else model
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if (
not self.optimizer

View File

@@ -25,7 +25,7 @@ class SchedulerMixin(Trainer):
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
) -> LRScheduler:
"""
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
@@ -45,13 +45,6 @@ class SchedulerMixin(Trainer):
and self.args.cosine_min_lr_ratio is not None
)
if optimizer is None:
if self.optimizer is None:
raise ValueError(
"Optimizer must be set before calling create_scheduler or passed as an argument."
)
optimizer = self.optimizer
# fmt: off
if self.lr_scheduler is None: # type: ignore
# fmt: on

Some files were not shown because too many files have changed in this diff Show More