Compare commits
4 Commits
scattermoe
...
fix/gemma3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53a12282bc | ||
|
|
7271754902 | ||
|
|
6d5257d92e | ||
|
|
0e357b5df6 |
9
.github/CONTRIBUTING.md
vendored
9
.github/CONTRIBUTING.md
vendored
@@ -68,12 +68,7 @@ You can skip certain CI checks by including specific keywords in your commit mes
|
||||
|
||||
### Code Style
|
||||
|
||||
axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines.
|
||||
|
||||
Use the pre-commit linter to ensure that your code is formatted consistently.
|
||||
```bash
|
||||
pre-commit run --all-files
|
||||
```
|
||||
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
|
||||
|
||||
### Commit Messages
|
||||
|
||||
@@ -83,6 +78,6 @@ Write clear and concise commit messages that briefly describe the changes made i
|
||||
|
||||
- [GitHub Help](https://help.github.com/)
|
||||
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
|
||||
- [Ruff](https://docs.astral.sh/ruff/)
|
||||
- [{codestyle}]({URLofCodestyle})
|
||||
|
||||
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!
|
||||
|
||||
79
.github/workflows/base.yml
vendored
79
.github/workflows/base.yml
vendored
@@ -15,9 +15,6 @@ on:
|
||||
- '.github/workflows/base.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-base:
|
||||
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
|
||||
@@ -54,30 +51,14 @@ jobs:
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
- cuda: "129"
|
||||
cuda_version: 12.9.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "129"
|
||||
# cuda_version: 12.9.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-base"
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
@@ -94,14 +75,6 @@ jobs:
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
@@ -127,7 +100,7 @@ jobs:
|
||||
images: |
|
||||
axolotlai/axolotl-base
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v2
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
@@ -135,7 +108,7 @@ jobs:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
@@ -176,14 +149,6 @@ jobs:
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
@@ -192,30 +157,14 @@ jobs:
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
- cuda: "129"
|
||||
cuda_version: 12.9.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
pytorch: 2.9.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: "129"
|
||||
# cuda_version: 12.9.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-uv-base"
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
@@ -232,14 +181,6 @@ jobs:
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: "130"
|
||||
cuda_version: 13.0.0
|
||||
cudnn_version: ""
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
torch_cuda_arch_list: "9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -250,7 +191,7 @@ jobs:
|
||||
images: |
|
||||
axolotlai/axolotl-base-uv
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v2
|
||||
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
@@ -258,7 +199,7 @@ jobs:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
|
||||
3
.github/workflows/lint.yml
vendored
3
.github/workflows/lint.yml
vendored
@@ -13,9 +13,6 @@ on:
|
||||
- ".pre-commit-config.yaml"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
|
||||
196
.github/workflows/main.yml
vendored
196
.github/workflows/main.yml
vendored
@@ -8,9 +8,6 @@ on:
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
@@ -37,28 +34,16 @@ jobs:
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 129
|
||||
cuda_version: 12.9.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
@@ -101,89 +86,11 @@ jobs:
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-uv:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
is_latest: true
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-uv
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
||||
- name: Build and export to Docker
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
|
||||
file: ./docker/Dockerfile-uv
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
@@ -205,28 +112,16 @@ jobs:
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
- cuda: 129
|
||||
cuda_version: 12.9.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras:
|
||||
# platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
@@ -264,86 +159,11 @@ jobs:
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud-uv:
|
||||
needs: build-axolotl-uv
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras:
|
||||
platforms: "linux/amd64,linux/arm64"
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-cloud-uv
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
platforms: ${{ matrix.platforms }}
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
file: ./docker/Dockerfile-cloud-uv
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
build-axolotl-cloud-no-tmux:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
|
||||
34
.github/workflows/multi-gpu-e2e.yml
vendored
34
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,7 +8,6 @@ on:
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
- 'scripts/cutcrossentropy_install.py'
|
||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
||||
- 'src/axolotl/utils/distributed.py'
|
||||
workflow_dispatch:
|
||||
@@ -20,9 +19,6 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
MODAL_IMAGE_BUILDER_VERSION: "2025.06"
|
||||
|
||||
@@ -39,13 +35,19 @@ jobs:
|
||||
pytorch: 2.8.0
|
||||
axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
# - cuda: 129
|
||||
# cuda_version: 12.9.1
|
||||
# python_version: "3.12"
|
||||
# pytorch: 2.9.1
|
||||
# axolotl_extras: "fbgemm-gpu"
|
||||
# num_gpus: 2
|
||||
# dockerfile: "Dockerfile-uv.jinja"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras: "fbgemm-gpu"
|
||||
num_gpus: 2
|
||||
- cuda: 129
|
||||
cuda_version: 12.9.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
axolotl_extras: "fbgemm-gpu"
|
||||
num_gpus: 2
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
@@ -53,13 +55,6 @@ jobs:
|
||||
axolotl_extras:
|
||||
# axolotl_extras: fbgemm-gpu
|
||||
num_gpus: 2
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
axolotl_extras: "fbgemm-gpu"
|
||||
num_gpus: 2
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
@@ -81,9 +76,8 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run -m cicd.multigpu
|
||||
|
||||
3
.github/workflows/nightlies.yml
vendored
3
.github/workflows/nightlies.yml
vendored
@@ -5,9 +5,6 @@ on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
|
||||
2
.github/workflows/precommit-autoupdate.yml
vendored
2
.github/workflows/precommit-autoupdate.yml
vendored
@@ -5,8 +5,6 @@ on:
|
||||
- cron: '0 0 1 * *' # Run monthly
|
||||
workflow_dispatch: # Manual kickoff
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
auto-update:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
8
.github/workflows/preview-docs.yml
vendored
8
.github/workflows/preview-docs.yml
vendored
@@ -14,8 +14,14 @@ on:
|
||||
- .github/workflows/preview-docs.yml
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
checks: write
|
||||
contents: write
|
||||
deployments: write
|
||||
issues: write
|
||||
discussions: write
|
||||
pages: write
|
||||
pull-requests: write
|
||||
statuses: write
|
||||
|
||||
jobs:
|
||||
preview:
|
||||
|
||||
9
.github/workflows/pypi.yml
vendored
9
.github/workflows/pypi.yml
vendored
@@ -3,11 +3,9 @@ name: publish pypi
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
setup_release:
|
||||
name: Create Release
|
||||
@@ -30,8 +28,7 @@ jobs:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/axolotl
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
@@ -49,7 +46,7 @@ jobs:
|
||||
|
||||
- name: Extract tag name
|
||||
id: tag
|
||||
run: echo "TAG_NAME=$(echo $GITHUB_REF | cut -d / -f 3)" >> "$GITHUB_OUTPUT"
|
||||
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
|
||||
|
||||
- name: Update version in VERSION file
|
||||
run: |
|
||||
|
||||
44
.github/workflows/tests-nightly.yml
vendored
44
.github/workflows/tests-nightly.yml
vendored
@@ -3,13 +3,6 @@ on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
paths:
|
||||
- '.github/workflows/tests-nightly.yml'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
@@ -25,26 +18,15 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
prime-cdn-s3-cache:
|
||||
name: Prefetch S3 once to prime the CDN cache
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
needs: [prime-cdn-s3-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -55,7 +37,7 @@ jobs:
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -66,7 +48,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
|
||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -120,23 +102,16 @@ jobs:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.9.1
|
||||
pytorch: 2.8.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
nightly_build: "true"
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -157,11 +132,9 @@ jobs:
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
docker-e2e-multigpu-tests:
|
||||
@@ -202,8 +175,7 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.multigpu
|
||||
|
||||
74
.github/workflows/tests.yml
vendored
74
.github/workflows/tests.yml
vendored
@@ -28,9 +28,6 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
TRANSFORMERS_IS_CI: "yes"
|
||||
|
||||
@@ -49,32 +46,21 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
prime-cdn-s3-cache:
|
||||
name: Prefetch S3 once to prime the CDN cache
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
curl -v -H "Range: bytes=0-1023" -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst > /dev/null
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
needs: [prime-cdn-s3-cache]
|
||||
# needs: [preload-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||
# exclude:
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.8.0"
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.9.1"
|
||||
python_version: ["3.11", "3.12"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
exclude:
|
||||
- python_version: "3.12"
|
||||
pytorch_version: "2.8.0"
|
||||
- python_version: "3.12"
|
||||
pytorch_version: "2.9.0"
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -89,7 +75,7 @@ jobs:
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
- name: Setup Python
|
||||
@@ -160,18 +146,17 @@ jobs:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ !github.event.pull_request.draft }}
|
||||
needs: [prime-cdn-s3-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||
pytorch_version: ["2.8.0", "2.9.1", "2.10.0"]
|
||||
# exclude:
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.8.0"
|
||||
# - python_version: "3.14"
|
||||
# pytorch_version: "2.9.1"
|
||||
timeout-minutes: 30
|
||||
python_version: ["3.11", "3.12"]
|
||||
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
|
||||
exclude:
|
||||
- python_version: "3.12"
|
||||
pytorch_version: "2.8.0"
|
||||
- python_version: "3.12"
|
||||
pytorch_version: "2.9.0"
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: cleanup node
|
||||
@@ -185,7 +170,7 @@ jobs:
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p ~/.cache/huggingface/hub
|
||||
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
|
||||
ls -ltr ~/.cache/huggingface/hub/
|
||||
|
||||
- name: Setup Python
|
||||
@@ -279,8 +264,8 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
- cuda: 129
|
||||
cuda_version: 12.9.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
@@ -306,10 +291,9 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
@@ -342,12 +326,6 @@ jobs:
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.10.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 130
|
||||
cuda_version: 13.0.0
|
||||
python_version: "3.11"
|
||||
@@ -375,10 +353,9 @@ jobs:
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
env:
|
||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
@@ -392,9 +369,9 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
- cuda: 129
|
||||
cuda_version: 12.9.1
|
||||
python_version: "3.12"
|
||||
pytorch: 2.9.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
@@ -418,6 +395,7 @@ jobs:
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.cleanup
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -193,3 +193,6 @@ out/
|
||||
|
||||
# scm auto-versioning
|
||||
src/axolotl/_version.py
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
@@ -11,7 +11,7 @@ repos:
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.15.4
|
||||
rev: v0.14.10
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -26,7 +26,7 @@ repos:
|
||||
'pydantic>=2.5.3',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.9.4
|
||||
rev: 1.9.2
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: [
|
||||
|
||||
32
README.md
32
README.md
@@ -29,23 +29,8 @@
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2026/03:
|
||||
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
|
||||
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
|
||||
- 2026/02:
|
||||
- [ScatterMoE LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3410) support. LoRA fine-tuning directly on MoE expert weights using custom Triton kernels.
|
||||
- Axolotl now has support for [SageAttention](https://github.com/axolotl-ai-cloud/axolotl/pull/2823) and [GDPO](https://github.com/axolotl-ai-cloud/axolotl/pull/3353) (Generalized DPO).
|
||||
- 2026/01:
|
||||
- New integration for [EAFT](https://github.com/axolotl-ai-cloud/axolotl/pull/3366) (Entropy-Aware Focal Training), weights loss by entropy of the top-k logit distribution, and [Scalable Softmax](https://github.com/axolotl-ai-cloud/axolotl/pull/3338), improves long context in attention.
|
||||
- 2025/12:
|
||||
- Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
|
||||
- [Distributed Muon Optimizer](https://github.com/axolotl-ai-cloud/axolotl/pull/3264) support has been added for FSDP2 pretraining.
|
||||
- 2025/12: Axolotl now includes support for [Kimi-Linear](https://docs.axolotl.ai/docs/models/kimi-linear.html), [Plano-Orchestrator](https://docs.axolotl.ai/docs/models/plano.html), [MiMo](https://docs.axolotl.ai/docs/models/mimo.html), [InternVL 3.5](https://docs.axolotl.ai/docs/models/internvl3_5.html), [Olmo3](https://docs.axolotl.ai/docs/models/olmo3.html), [Trinity](https://docs.axolotl.ai/docs/models/trinity.html), and [Ministral3](https://docs.axolotl.ai/docs/models/ministral3.html).
|
||||
- 2025/10: New model support has been added in Axolotl for: [Qwen3 Next](https://docs.axolotl.ai/docs/models/qwen3-next.html), [Qwen2.5-vl, Qwen3-vl](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen2_5-vl), [Qwen3, Qwen3MoE](https://docs.axolotl.ai/docs/models/qwen3.html), [Granite 4](https://docs.axolotl.ai/docs/models/granite4.html), [HunYuan](https://docs.axolotl.ai/docs/models/hunyuan.html), [Magistral 2509](https://docs.axolotl.ai/docs/models/magistral/vision.html), [Apertus](https://docs.axolotl.ai/docs/models/apertus.html), and [Seed-OSS](https://docs.axolotl.ai/docs/models/seed-oss.html).
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Expand older updates</summary>
|
||||
|
||||
- 2025/09: Axolotl now has text diffusion training. Read more [here](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/diffusion).
|
||||
- 2025/08: QAT has been updated to include NVFP4 support. See [PR](https://github.com/axolotl-ai-cloud/axolotl/pull/3107).
|
||||
- 2025/07:
|
||||
@@ -54,10 +39,15 @@
|
||||
- FP8 finetuning with fp8 gather op is now possible in Axolotl via `torchao`. Get started [here](https://docs.axolotl.ai/docs/mixed_precision.html#sec-fp8)!
|
||||
- [Voxtral](https://docs.axolotl.ai/docs/models/voxtral.html), [Magistral 1.1](https://docs.axolotl.ai/docs/models/magistral.html), and [Devstral](https://docs.axolotl.ai/docs/models/devstral.html) with mistral-common tokenizer support has been integrated in Axolotl!
|
||||
- TiledMLP support for single-GPU to multi-GPU training with DDP, DeepSpeed and FSDP support has been added to support Arctic Long Sequence Training. (ALST). See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/alst) for using ALST with Axolotl!
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Expand older updates</summary>
|
||||
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [docs](https://docs.axolotl.ai/docs/models/magistral.html) to start training your own Magistral models with Axolotl!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [docs](https://docs.axolotl.ai/docs/models/llama-4.html) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||
@@ -72,10 +62,10 @@ Axolotl is a free and open-source tool designed to streamline post-training and
|
||||
Features:
|
||||
|
||||
- **Multiple Model Support**: Train various models like GPT-OSS, LLaMA, Mistral, Mixtral, Pythia, and many more models available on the Hugging Face Hub.
|
||||
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, GLM-4.6V, InternVL 3.5, Gemma 3n, and audio models like Voxtral with image, video, and audio support.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO, GDPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Multimodal Training**: Fine-tune vision-language models (VLMs) including LLaMA-Vision, Qwen2-VL, Pixtral, LLaVA, SmolVLM2, and audio models like Voxtral with image, video, and audio support.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Easy Configuration**: Re-use a single YAML configuration file across the full fine-tuning pipeline: dataset preprocessing, training, evaluation, quantization, and inference.
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention 2/3/4](https://docs.axolotl.ai/docs/attention.html#flash-attention), [Xformers](https://docs.axolotl.ai/docs/attention.html#xformers), [Flex Attention](https://docs.axolotl.ai/docs/attention.html#flex-attention), [SageAttention](https://docs.axolotl.ai/docs/attention.html#sageattention), [Liger Kernel](https://docs.axolotl.ai/docs/custom_integrations.html#liger-kernels), [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy), [ScatterMoE](https://docs.axolotl.ai/docs/custom_integrations.html#kernels-integration), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||
|
||||
|
||||
@@ -128,9 +128,11 @@ quartodoc:
|
||||
- monkeypatch.mistral_attn_hijack_flash
|
||||
- monkeypatch.multipack
|
||||
- monkeypatch.relora
|
||||
- monkeypatch.llama_expand_mask
|
||||
- monkeypatch.lora_kernels
|
||||
- monkeypatch.utils
|
||||
- monkeypatch.btlm_attn_hijack_flash
|
||||
- monkeypatch.llama_patch_multipack
|
||||
- monkeypatch.stablelm_attn_hijack_flash
|
||||
- monkeypatch.trainer_fsdp_optim
|
||||
- monkeypatch.transformers_fa_utils
|
||||
@@ -329,7 +331,6 @@ website:
|
||||
- docs/sequence_parallelism.qmd
|
||||
- docs/gradient_checkpointing.qmd
|
||||
- docs/nd_parallelism.qmd
|
||||
- docs/expert_quantization.qmd
|
||||
|
||||
- section: "Troubleshooting"
|
||||
contents:
|
||||
|
||||
@@ -1,208 +0,0 @@
|
||||
"""Benchmark for entropy_from_logits Triton kernel vs original chunked implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_entropy.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import entropy_from_logits
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def entropy_from_logits_original(logits: torch.Tensor, chunk_size: int = 128):
|
||||
"""Original chunked implementation (reference)."""
|
||||
original_shape = logits.shape[:-1]
|
||||
num_classes = logits.shape[-1]
|
||||
flat_logits = logits.reshape(-1, num_classes)
|
||||
entropies = []
|
||||
for chunk in flat_logits.split(chunk_size, dim=0):
|
||||
logps = F.log_softmax(chunk, dim=-1)
|
||||
chunk_entropy = -(torch.exp(logps) * logps).sum(-1)
|
||||
entropies.append(chunk_entropy)
|
||||
return torch.cat(entropies, dim=0).reshape(original_shape)
|
||||
|
||||
|
||||
def _clean_gpu():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, logits, n_iters=BENCH_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(logits, chunk_size=128)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = []
|
||||
for _ in range(n_iters):
|
||||
s = torch.cuda.Event(enable_timing=True)
|
||||
e = torch.cuda.Event(enable_timing=True)
|
||||
s.record()
|
||||
out = fn(logits, chunk_size=128)
|
||||
e.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(s.elapsed_time(e))
|
||||
del out
|
||||
return times
|
||||
|
||||
|
||||
def profile_memory(fn, logits, n_iters=MEM_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(logits, chunk_size=128)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
peaks = []
|
||||
for _ in range(n_iters):
|
||||
_clean_gpu()
|
||||
base = torch.cuda.max_memory_allocated()
|
||||
out = fn(logits, chunk_size=128)
|
||||
torch.cuda.synchronize()
|
||||
peaks.append(torch.cuda.max_memory_allocated() - base)
|
||||
del out
|
||||
return [p / 1e6 for p in peaks]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
|
||||
mean = statistics.mean(values)
|
||||
std = statistics.stdev(values) if len(values) > 1 else 0.0
|
||||
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
|
||||
|
||||
|
||||
def benchmark_contiguous():
|
||||
print("=" * 60)
|
||||
print(
|
||||
f"CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(1, 16384),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 28:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
t_orig = profile_time(entropy_from_logits_original, logits)
|
||||
t_triton = profile_time(entropy_from_logits, logits)
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(entropy_from_logits_original, logits)
|
||||
m_triton = profile_memory(entropy_from_logits, logits)
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
def benchmark_noncontiguous():
|
||||
print("\n" + "=" * 60)
|
||||
print(
|
||||
f"NON-CONTIGUOUS BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})"
|
||||
)
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(4, 2048, "transpose"),
|
||||
(4, 8192, "transpose"),
|
||||
(8, 2048, "transpose"),
|
||||
(4, 4096, "slice_batch"),
|
||||
]
|
||||
|
||||
for B, L, method in configs:
|
||||
torch.manual_seed(42)
|
||||
|
||||
if method == "transpose":
|
||||
raw = torch.randn(L, B, V, device="cuda", dtype=torch.bfloat16)
|
||||
logits_nc = raw.transpose(0, 1)
|
||||
raw_gb = L * B * V * 2 / 1e9
|
||||
elif method == "slice_batch":
|
||||
raw = torch.randn(B * 2, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
logits_nc = raw[::2]
|
||||
raw_gb = B * 2 * L * V * 2 / 1e9
|
||||
else:
|
||||
continue
|
||||
|
||||
if raw_gb > 28:
|
||||
print(f"\n skip B={B}, L={L}, {method} ({raw_gb:.1f} GB)")
|
||||
del raw, logits_nc
|
||||
torch.cuda.empty_cache()
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B}, L={L} {method} ({N} rows, raw {raw_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
def original_with_copy(logits, chunk_size=128):
|
||||
return entropy_from_logits_original(
|
||||
logits.contiguous(), chunk_size=chunk_size
|
||||
)
|
||||
|
||||
t_orig = profile_time(original_with_copy, logits_nc)
|
||||
t_triton = profile_time(entropy_from_logits, logits_nc)
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" orig+copy: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton-strided:{fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(original_with_copy, logits_nc)
|
||||
m_triton = profile_memory(entropy_from_logits, logits_nc)
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" orig+copy: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton-strided:{fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del raw, logits_nc
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_contiguous()
|
||||
benchmark_noncontiguous()
|
||||
@@ -1,284 +0,0 @@
|
||||
"""Benchmark for ScatterMoE LoRA Triton kernels.
|
||||
|
||||
Measures forward, backward dX, and backward dA/dB kernels at common MoE
|
||||
model shapes. Reports per-kernel timings, LoRA overhead vs base scatter2scatter,
|
||||
and full fwd+bwd autograd throughput.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --ranks 16 64
|
||||
CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_scattermoe_lora.py --models Qwen/Qwen3.5-35B-A3B
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.kernels import (
|
||||
lora_ops,
|
||||
ops as base_ops,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_experts import (
|
||||
flatten_sort_count,
|
||||
)
|
||||
from axolotl.integrations.kernels.libs.scattermoe_lora.parallel_linear_lora import (
|
||||
ScatterMoELoRA,
|
||||
)
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
WARMUP = 5
|
||||
ITERS = 20
|
||||
|
||||
# ─── Model configs ──────────────────────────────────────────────────────────
|
||||
|
||||
BUILTIN_CONFIGS = {
|
||||
"Qwen3.5-35B-A3B": (256, 2048, 512, 8), # E, H, I, k
|
||||
"Qwen3-30B-A3B": (128, 2048, 768, 8),
|
||||
"OLMoE-1B-7B": (64, 2048, 1024, 8),
|
||||
"Mixtral-8x7B": (8, 4096, 14336, 2),
|
||||
}
|
||||
|
||||
|
||||
def _resolve_config(spec):
|
||||
"""Resolve a model spec to (E, H, I, k). Accepts builtin names or HF IDs."""
|
||||
key = spec.lower().replace("/", "-")
|
||||
for name, cfg in BUILTIN_CONFIGS.items():
|
||||
if key in name.lower() or name.lower() in key:
|
||||
return name, cfg
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
hf_cfg = AutoConfig.from_pretrained(spec, trust_remote_code=True)
|
||||
if callable(getattr(hf_cfg, "get_text_config", None)):
|
||||
tc = hf_cfg.get_text_config()
|
||||
if hasattr(tc, "model_type") and tc.model_type != hf_cfg.model_type:
|
||||
hf_cfg = tc
|
||||
hidden = hf_cfg.hidden_size
|
||||
inter = getattr(hf_cfg, "moe_intermediate_size", None) or hf_cfg.intermediate_size
|
||||
experts = (
|
||||
getattr(hf_cfg, "num_experts", None)
|
||||
or getattr(hf_cfg, "num_local_experts", None)
|
||||
or getattr(hf_cfg, "n_routed_experts", None)
|
||||
)
|
||||
top_k = (
|
||||
getattr(hf_cfg, "num_experts_per_tok", None)
|
||||
or getattr(hf_cfg, "num_experts_per_token", None)
|
||||
or 2
|
||||
)
|
||||
name = spec.split("/")[-1]
|
||||
return name, (experts, hidden, inter, top_k)
|
||||
|
||||
|
||||
# ─── Benchmark helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _clean():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def _bench(fn, warmup=WARMUP, iters=ITERS):
|
||||
for _ in range(warmup):
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times = []
|
||||
for _ in range(iters):
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
fn()
|
||||
torch.cuda.synchronize()
|
||||
times.append((time.perf_counter() - t0) * 1000)
|
||||
times.sort()
|
||||
return times[len(times) // 2]
|
||||
|
||||
|
||||
def _setup(num_experts, K, N, T, top_k, R):
|
||||
torch.manual_seed(42)
|
||||
x = torch.randn(T, K, device=DEVICE, dtype=DTYPE)
|
||||
W = torch.randn(num_experts, K, N, device=DEVICE, dtype=DTYPE) * 0.02
|
||||
lora_A = torch.randn(R * num_experts, K, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
lora_B = torch.randn(N, R * num_experts, device=DEVICE, dtype=DTYPE) * 0.01
|
||||
logits = torch.randn(T, num_experts, device=DEVICE)
|
||||
_, top_idx = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)
|
||||
sei, ssi, eo = flatten_sort_count(top_idx, num_experts)
|
||||
gx = base_ops.group(x, ssi, fan_out=top_k)
|
||||
dy = torch.randn(gx.size(0), N, device=DEVICE, dtype=DTYPE)
|
||||
return x, W, lora_A, lora_B, sei, ssi, eo, gx, dy
|
||||
|
||||
|
||||
# ─── Kernel wrappers (avoid B023 loop-variable capture) ──────────────────────
|
||||
|
||||
|
||||
def _call_fwd(x, W, sei, ssi, top_k, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
def _call_base(x, W, sei, ssi, top_k):
|
||||
return base_ops.scatter2scatter(
|
||||
X=x,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=top_k,
|
||||
)
|
||||
|
||||
|
||||
def _call_dx(dy, W, sei, ssi, lA, lB):
|
||||
return lora_ops.scatter2scatter_lora_dX(
|
||||
DY=dy,
|
||||
W=W,
|
||||
sorted_expert_idxs=sei,
|
||||
sorted_scattered_idxs=ssi,
|
||||
k=1,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
scaling=2.0,
|
||||
dy_grouped=True,
|
||||
dx_grouped=False,
|
||||
)
|
||||
|
||||
|
||||
def _call_bwd(dy, gx, lA, lB, eo, num_experts):
|
||||
return lora_ops.group_bwd_lora(
|
||||
DY=dy,
|
||||
X=gx,
|
||||
lora_A=lA,
|
||||
lora_B=lB,
|
||||
expert_offsets=eo,
|
||||
E=num_experts,
|
||||
scaling=2.0,
|
||||
)
|
||||
|
||||
|
||||
# ─── Main ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ScatterMoE LoRA kernel benchmark")
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
"-m",
|
||||
nargs="+",
|
||||
help="Model names or HF IDs (default: all builtins)",
|
||||
)
|
||||
parser.add_argument("--ranks", "-r", nargs="+", type=int, default=[16, 32, 64])
|
||||
parser.add_argument("--seq-len", "-T", type=int, default=2048)
|
||||
args = parser.parse_args()
|
||||
|
||||
T = args.seq_len
|
||||
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||
print(f"T={T}, ranks={args.ranks}\n")
|
||||
|
||||
if args.models:
|
||||
configs = [_resolve_config(m) for m in args.models]
|
||||
else:
|
||||
configs = list(BUILTIN_CONFIGS.items())
|
||||
|
||||
for model_name, (num_experts, hidden, inter, top_k) in configs:
|
||||
print(f"{'=' * 70}")
|
||||
print(f" {model_name}: E={num_experts}, H={hidden}, I={inter}, k={top_k}")
|
||||
print(f"{'=' * 70}")
|
||||
|
||||
for R in args.ranks:
|
||||
for proj, K, N in [("gate_up", hidden, 2 * inter), ("down", inter, hidden)]:
|
||||
_clean()
|
||||
x, W, lA, lB, sei, ssi, eo, gx, dy = _setup(
|
||||
num_experts, K, N, T, top_k, R
|
||||
)
|
||||
|
||||
# Forward with LoRA (auto-dispatched: fused or split)
|
||||
dispatch = (
|
||||
"split"
|
||||
if (
|
||||
num_experts <= lora_ops._SPLIT_LORA_FWD_MAX_EXPERTS
|
||||
and K * N >= lora_ops._SPLIT_LORA_FWD_THRESHOLD
|
||||
)
|
||||
else "fused"
|
||||
)
|
||||
t_fwd = _bench(partial(_call_fwd, x, W, sei, ssi, top_k, lA, lB))
|
||||
t_base = _bench(partial(_call_base, x, W, sei, ssi, top_k))
|
||||
t_dx = _bench(partial(_call_dx, dy, W, sei, ssi, lA, lB))
|
||||
t_bwd = _bench(partial(_call_bwd, dy, gx, lA, lB, eo, num_experts))
|
||||
|
||||
total = t_fwd + t_dx + t_bwd
|
||||
overhead = t_fwd / t_base - 1 if t_base > 0 else 0
|
||||
|
||||
print(
|
||||
f" R={R:>2} {proj:<8} "
|
||||
f"fwd={t_fwd:>6.2f}ms [{dispatch}] "
|
||||
f"base={t_base:>6.2f}ms "
|
||||
f"(+{overhead * 100:.0f}%) "
|
||||
f"dx={t_dx:>6.2f}ms bwd={t_bwd:>6.2f}ms "
|
||||
f"total={total:>6.2f}ms"
|
||||
)
|
||||
|
||||
# Full autograd fwd+bwd with memory measurement
|
||||
x_ag = x.clone().requires_grad_(True)
|
||||
lA_ag = lA.clone().requires_grad_(True)
|
||||
lB_ag = lB.clone().requires_grad_(True)
|
||||
|
||||
def _run_autograd(
|
||||
_x=x_ag,
|
||||
_W=W,
|
||||
_k=top_k,
|
||||
_sei=sei,
|
||||
_ssi=ssi,
|
||||
_eo=eo,
|
||||
_lA=lA_ag,
|
||||
_lB=lB_ag,
|
||||
):
|
||||
out = ScatterMoELoRA.apply(
|
||||
_x,
|
||||
_W,
|
||||
_k,
|
||||
_sei,
|
||||
_ssi,
|
||||
_eo,
|
||||
_lA,
|
||||
_lB,
|
||||
2.0,
|
||||
None,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
)
|
||||
out.sum().backward()
|
||||
_x.grad = None
|
||||
_lA.grad = None
|
||||
_lB.grad = None
|
||||
|
||||
t_full = _bench(_run_autograd)
|
||||
|
||||
_clean()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
_run_autograd()
|
||||
torch.cuda.synchronize()
|
||||
mem_peak = torch.cuda.max_memory_allocated() - mem_before
|
||||
|
||||
print(
|
||||
f" full_fwd_bwd={t_full:>6.2f}ms "
|
||||
f"peak_delta={mem_peak / 1e6:>6.1f}MB"
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,191 +0,0 @@
|
||||
"""Benchmark for selective_log_softmax Triton kernel vs original implementation.
|
||||
|
||||
Usage: CUDA_VISIBLE_DEVICES=0 python benchmarks/bench_selective_logsoftmax.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import statistics
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.trainer.utils import (
|
||||
selective_log_softmax,
|
||||
selective_log_softmax_original,
|
||||
)
|
||||
|
||||
V = 151936 # Qwen vocab
|
||||
WARMUP = 5
|
||||
BENCH_ITERS = 20
|
||||
MEM_ITERS = 10
|
||||
|
||||
|
||||
def _clean_gpu():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def profile_time(fn, args, n_iters=BENCH_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
times = []
|
||||
for _ in range(n_iters):
|
||||
s = torch.cuda.Event(enable_timing=True)
|
||||
e = torch.cuda.Event(enable_timing=True)
|
||||
s.record()
|
||||
fn(*args)
|
||||
e.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(s.elapsed_time(e))
|
||||
return times
|
||||
|
||||
|
||||
def profile_memory(fn, args, n_iters=MEM_ITERS):
|
||||
for _ in range(WARMUP):
|
||||
out = fn(*args)
|
||||
del out
|
||||
torch.cuda.synchronize()
|
||||
|
||||
peaks = []
|
||||
for _ in range(n_iters):
|
||||
_clean_gpu()
|
||||
base = torch.cuda.max_memory_allocated()
|
||||
out = fn(*args)
|
||||
torch.cuda.synchronize()
|
||||
peaks.append(torch.cuda.max_memory_allocated() - base)
|
||||
del out
|
||||
return [p / 1e6 for p in peaks]
|
||||
|
||||
|
||||
def fmt(values, unit=""):
|
||||
mean = statistics.mean(values)
|
||||
std = statistics.stdev(values) if len(values) > 1 else 0.0
|
||||
return f"{mean:8.2f} ± {std:5.2f} {unit} [min={min(values):.2f}, max={max(values):.2f}]"
|
||||
|
||||
|
||||
def benchmark_forward():
|
||||
print("=" * 60)
|
||||
print(f"FORWARD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 28:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits = torch.randn(B, L, V, device="cuda", dtype=torch.bfloat16)
|
||||
index = torch.randint(0, V, (B, L), device="cuda")
|
||||
|
||||
t_orig = profile_time(selective_log_softmax_original, (logits, index))
|
||||
t_triton = profile_time(selective_log_softmax, (logits, index))
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(selective_log_softmax_original, (logits, index))
|
||||
m_triton = profile_memory(selective_log_softmax, (logits, index))
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits, index
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
def benchmark_backward():
|
||||
print("\n" + "=" * 60)
|
||||
print(f"FWD+BWD BENCHMARK (warmup={WARMUP}, time={BENCH_ITERS}, mem={MEM_ITERS})")
|
||||
print("=" * 60)
|
||||
|
||||
configs = [
|
||||
(1, 2048),
|
||||
(1, 8192),
|
||||
(4, 4096),
|
||||
(8, 2048),
|
||||
(16, 2048),
|
||||
(16, 4096),
|
||||
]
|
||||
|
||||
def fwd_bwd_original(logits, index):
|
||||
logits.grad = None
|
||||
out = selective_log_softmax_original(logits, index)
|
||||
out.sum().backward()
|
||||
|
||||
def fwd_bwd_triton(logits, index):
|
||||
logits.grad = None
|
||||
out = selective_log_softmax(logits, index)
|
||||
out.sum().backward()
|
||||
|
||||
for B, L in configs:
|
||||
mem_gb = B * L * V * 2 / 1e9
|
||||
if mem_gb > 20:
|
||||
print(f"\n skip B={B}, L={L} ({mem_gb:.1f} GB, need room for grads)")
|
||||
continue
|
||||
|
||||
N = B * L
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"B={B:2d}, L={L:5d} ({N:6d} rows, logits {mem_gb:.2f} GB)")
|
||||
print(f"{'─' * 60}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
logits_orig = torch.randn(
|
||||
B, L, V, device="cuda", dtype=torch.bfloat16, requires_grad=True
|
||||
)
|
||||
logits_tri = logits_orig.detach().clone().requires_grad_(True)
|
||||
index = torch.randint(0, V, (B, L), device="cuda")
|
||||
|
||||
t_orig = profile_time(fwd_bwd_original, (logits_orig, index))
|
||||
t_triton = profile_time(fwd_bwd_triton, (logits_tri, index))
|
||||
orig_mean = statistics.mean(t_orig)
|
||||
triton_mean = statistics.mean(t_triton)
|
||||
|
||||
print(" FWD+BWD TIME (ms):")
|
||||
print(f" original: {fmt(t_orig, 'ms')}")
|
||||
print(f" triton: {fmt(t_triton, 'ms')}")
|
||||
print(f" speedup: {orig_mean / triton_mean:.2f}x")
|
||||
|
||||
m_orig = profile_memory(fwd_bwd_original, (logits_orig, index))
|
||||
m_triton = profile_memory(fwd_bwd_triton, (logits_tri, index))
|
||||
orig_peak = statistics.mean(m_orig)
|
||||
triton_peak = statistics.mean(m_triton)
|
||||
|
||||
print(" FWD+BWD MEMORY (peak overhead):")
|
||||
print(f" original: {fmt(m_orig, 'MB')}")
|
||||
print(f" triton: {fmt(m_triton, 'MB')}")
|
||||
print(f" saved: {orig_peak - triton_peak:.1f} MB")
|
||||
|
||||
del logits_orig, logits_tri, index
|
||||
_clean_gpu()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_forward()
|
||||
benchmark_backward()
|
||||
@@ -11,7 +11,7 @@ ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
@@ -31,9 +31,8 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN uv pip install packaging==26.0 setuptools==78.1.1
|
||||
RUN uv pip install packaging==26.0 setuptools==75.8.0
|
||||
RUN uv pip install torchvision
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -12,7 +12,7 @@ ENV HF_HOME="{{ HF_HOME }}"
|
||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
@@ -32,8 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -3,14 +3,6 @@ set -e
|
||||
|
||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||
|
||||
set -o pipefail
|
||||
curl --silent --show-error --fail --retry 3 --retry-delay 5 -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C "${HF_HOME}/hub/" --use-compress-program unzstd --strip-components=1
|
||||
# hf download "NousResearch/Meta-Llama-3-8B"
|
||||
# hf download "NousResearch/Meta-Llama-3-8B-Instruct"
|
||||
# hf download "microsoft/Phi-4-reasoning"
|
||||
# hf download "microsoft/Phi-3.5-mini-instruct"
|
||||
# hf download "microsoft/Phi-3-medium-128k-instruct"
|
||||
|
||||
# Run unit tests with initial coverage report
|
||||
pytest -v --durations=10 -n8 \
|
||||
--ignore=tests/e2e/ \
|
||||
|
||||
@@ -68,6 +68,10 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
raise RuntimeError(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
try:
|
||||
exit_code = subprocess.call(cmd.split(), cwd=run_folder, env=sp_env) # nosec
|
||||
if exit_code:
|
||||
print(f"Command '{cmd}' failed with exit code {exit_code}")
|
||||
return exit_code
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
print(f"Command '{cmd}' failed with exception {e}")
|
||||
|
||||
@@ -37,7 +37,6 @@ coverage:
|
||||
only_pulls: false
|
||||
flags: null
|
||||
paths: null
|
||||
informational: true
|
||||
|
||||
parsers:
|
||||
gcov:
|
||||
|
||||
@@ -22,7 +22,6 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN pip uninstall -y causal_conv1d
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
else \
|
||||
|
||||
@@ -59,18 +59,34 @@ RUN git lfs install --skip-repo && \
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||
pip3 cache purge
|
||||
|
||||
# Map Python version (e.g., 3.12 -> cp312)
|
||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||
# Map architecture
|
||||
case "$TARGETARCH" in \
|
||||
amd64) ARCH_TAG="x86_64" ;; \
|
||||
arm64) ARCH_TAG="aarch64" ;; \
|
||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||
esac && \
|
||||
WHL_VERSION="v0.7.16" && \
|
||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||
pip3 install --no-cache-dir "${WHL_FILE}" && \
|
||||
rm "${WHL_FILE}"
|
||||
RUN case "$PYTORCH_VERSION" in \
|
||||
2.9.[0-9]*) \
|
||||
if [ "$CUDA" = "128" ]; then \
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
|
||||
WHL_VERSION="v0.5.4"; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
|
||||
WHL_VERSION="v0.6.4"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||
fi; \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
|
||||
pip3 install --no-cache-dir ${WHL_FILE}; \
|
||||
rm ${WHL_FILE}; \
|
||||
elif [ "$CUDA" = "130" ]; then \
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
|
||||
WHL_VERSION="v0.5.4"; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
|
||||
WHL_VERSION="v0.6.4"; \
|
||||
else \
|
||||
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
|
||||
fi; \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
|
||||
pip3 install --no-cache-dir ${WHL_FILE}; \
|
||||
rm ${WHL_FILE}; \
|
||||
fi \
|
||||
;; \
|
||||
esac
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl-uv:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||
|
||||
EXPOSE 8888
|
||||
EXPOSE 22
|
||||
|
||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
||||
COPY scripts/motd /etc/motd
|
||||
|
||||
RUN uv pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt update && \
|
||||
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
||||
chmod +x /root/cloud-entrypoint.sh && \
|
||||
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
|
||||
|
||||
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
|
||||
CMD ["sleep", "infinity"]
|
||||
@@ -1,48 +0,0 @@
|
||||
ARG BASE_TAG=main-base
|
||||
FROM axolotlai/axolotl-base-uv:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG AXOLOTL_ARGS=""
|
||||
ARG CUDA="118"
|
||||
ARG PYTORCH_VERSION="2.1.2"
|
||||
ARG TARGETARCH
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
|
||||
rm -rf /var/cache/apt/archives && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||
RUN uv pip uninstall causal_conv1d
|
||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
else \
|
||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
||||
fi && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||
fi && \
|
||||
python scripts/unsloth_install.py --uv | sh && \
|
||||
python scripts/cutcrossentropy_install.py --uv | sh && \
|
||||
uv pip install pytest && \
|
||||
uv cache clean
|
||||
|
||||
# fix so that git fetch/pull from remote works with shallow clone
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch && \
|
||||
git config --global credential.helper store
|
||||
|
||||
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
|
||||
RUN chmod +x /root/.axolotl-complete.bash && \
|
||||
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc
|
||||
@@ -6,7 +6,6 @@ ARG TARGETARCH
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ARG TARGETARCH
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="2.6.0"
|
||||
ARG CUDA="126"
|
||||
@@ -40,18 +39,28 @@ RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
|
||||
fi
|
||||
|
||||
# Map Python version (e.g., 3.12 -> cp312)
|
||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
||||
# Map architecture
|
||||
case "$TARGETARCH" in \
|
||||
amd64) ARCH_TAG="x86_64" ;; \
|
||||
arm64) ARCH_TAG="aarch64" ;; \
|
||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
||||
esac && \
|
||||
WHL_VERSION="v0.7.16" && \
|
||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
||||
uv pip install --no-cache-dir "${WHL_FILE}" && \
|
||||
rm "${WHL_FILE}"
|
||||
RUN case "$PYTORCH_VERSION" in \
|
||||
2.9.[0-9]*) \
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
if [ "$CUDA" = "128" ]; then \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
elif [ "$CUDA" = "130" ]; then \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||
fi \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
if [ "$CUDA" = "128" ]; then \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
elif [ "$CUDA" = "130" ]; then \
|
||||
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.6.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl; \
|
||||
fi \
|
||||
fi \
|
||||
;; \
|
||||
esac
|
||||
|
||||
@@ -13,10 +13,9 @@ sdp_attention: true
|
||||
|
||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
|
||||
## Flash Attention
|
||||
## Flash Attention 2
|
||||
|
||||
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
|
||||
based on your installed packages and GPU.
|
||||
Uses efficient kernels to compute attention.
|
||||
|
||||
```yaml
|
||||
flash_attention: true
|
||||
@@ -24,9 +23,11 @@ flash_attention: true
|
||||
|
||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
||||
|
||||
### Flash Attention 2
|
||||
### Nvidia
|
||||
|
||||
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
||||
Requirements: Ampere, Ada, or Hopper GPUs
|
||||
|
||||
Note: For Turing GPUs or lower, please use other attention methods.
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
@@ -34,12 +35,11 @@ pip install flash-attn --no-build-isolation
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl.
|
||||
Alternatively, try reinstall or downgrade a version.
|
||||
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
|
||||
|
||||
:::
|
||||
|
||||
### Flash Attention 3
|
||||
#### Flash Attention 3
|
||||
|
||||
Requirements: Hopper only and CUDA 12.8 (recommended)
|
||||
|
||||
@@ -50,44 +50,6 @@ cd flash-attention/hopper
|
||||
python setup.py install
|
||||
```
|
||||
|
||||
### Flash Attention 4
|
||||
|
||||
Requirements: Hopper or Blackwell GPUs
|
||||
|
||||
```bash
|
||||
pip install flash-attn-4
|
||||
```
|
||||
|
||||
Or from source:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||
cd flash-attention/flash_attn/cute
|
||||
|
||||
pip install -e .
|
||||
|
||||
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
||||
# Remove it so Python can find the real FA4 module:
|
||||
rm -r $(python -c "import flash_attn; print(flash_attn.__path__[0])")/cute
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
**Hopper (SM90) users**: The backward kernel is not yet included in the pip package. To use FA4
|
||||
for training on Hopper, install from source using the instructions above.
|
||||
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
FA4 only supports head dimensions up to 128 (`d ≤ 128`). The DeepSeek shape `(192, 128)` is
|
||||
also supported but only on Blackwell. Axolotl automatically detects incompatible head dimensions
|
||||
and falls back to FA2/3.
|
||||
|
||||
:::
|
||||
|
||||
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
|
||||
|
||||
### AMD
|
||||
|
||||
Requirements: ROCm 6.0 and above.
|
||||
|
||||
@@ -210,8 +210,6 @@ axolotl lm-eval config.yml
|
||||
Configuration options:
|
||||
|
||||
```yaml
|
||||
lm_eval_model: # model to evaluate (local or hf path)
|
||||
|
||||
# List of tasks to evaluate
|
||||
lm_eval_tasks:
|
||||
- arc_challenge
|
||||
@@ -220,7 +218,7 @@ lm_eval_batch_size: # Batch size for evaluation
|
||||
output_dir: # Directory to save evaluation results
|
||||
```
|
||||
|
||||
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
|
||||
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
|
||||
|
||||
### delinearize-llama4
|
||||
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
---
|
||||
title: "MoE Expert Quantization"
|
||||
description: "Reduce VRAM usage when training MoE model adapters by quantizing expert weights on load"
|
||||
---
|
||||
|
||||
Transformers v5 changed MoE expert layers from `nn.Linear` to fused `nn.Parameter` (3D+ tensors).
|
||||
This means `bitsandbytes` can no longer quantize them during model loading, resulting in all expert
|
||||
weights being loaded in full bf16 precision and causing massive VRAM usage.
|
||||
|
||||
`quantize_moe_experts` solves this by quantizing expert weights during model loading.
|
||||
It intercepts the weight loading process, quantizes each expert tensor on the fly, and
|
||||
immediately frees the original bf16 tensor from VRAM. This dramatically reduces peak memory.
|
||||
For example, GLM-4.7-Flash QLoRA drops from ~127GiB to ~23GiB reserved memory.
|
||||
|
||||
## Usage
|
||||
|
||||
Enable expert quantization in your Axolotl config:
|
||||
|
||||
```yaml
|
||||
quantize_moe_experts: true
|
||||
```
|
||||
|
||||
This works with both 4-bit (QLoRA) and 8-bit (LoRA) quantization.
|
||||
|
||||
### Expert LoRA targeting
|
||||
|
||||
You can optionally apply LoRA adapters directly to expert weights using `lora_target_parameters`:
|
||||
|
||||
```yaml
|
||||
lora_target_parameters:
|
||||
- mlp.experts.gate_up_proj
|
||||
- mlp.experts.down_proj
|
||||
# - mlp.gate.weight # router
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
`lora_dropout` must be `0` when using `lora_target_parameters`.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- Requires (`adapter: lora` and `load_in_8bit: true`) or (`adapter: qlora` and `load_in_4bit: true`)
|
||||
- CUDA GPUs only (not tested with ROCm or other backends)
|
||||
- FSDP2 compatible for distributed training
|
||||
|
||||
## Limitations
|
||||
|
||||
- `lora_target_linear` is not compatible with `quantize_moe_experts`. See [Expert LoRA targeting](#expert-lora-targeting) instead.
|
||||
- `cpu_ram_efficient_loading` hangs / takes long time with FSDP2 + QLoRA.
|
||||
- Total model parameter count may display incorrectly (trainable param count is correct).
|
||||
- FSDP LoRA (8-bit) may have a large initial VRAM spike at the first 1-2 steps, which then drops. QLoRA does not exhibit this.
|
||||
- FSDP2 may use more VRAM per GPU than single GPU training due to not all layers being properly sharded across ranks.
|
||||
- Model loading takes longer due to on-demand quantization, even on consecutive runs.
|
||||
- DeepSpeed has not been tested.
|
||||
|
||||
## Implementation details
|
||||
|
||||
The quantization is applied by patching transformers to intercept weight loading.
|
||||
When a 3D+ CUDA tensor with "expert" in its name is detected:
|
||||
|
||||
- **4-bit mode:** Uses bitsandbytes NF4 parametrization (configurable via `bnb_4bit_quant_type`).
|
||||
- **8-bit mode:** Uses a custom row-wise int8 parametrization with bitsandbytes dequantization.
|
||||
|
||||
The original bf16 tensor is freed immediately after quantization. Multiple sub-patches are applied to
|
||||
transformers, PEFT and accelerate FSDP2 to support these parametrized expert modules.
|
||||
|
||||
For full implementation details, see [PR #3439](https://github.com/axolotl-ai-cloud/axolotl/pull/3439).
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Gradient Checkpointing, Activation Offloading, and Layer Offloading
|
||||
title: Gradient Checkpointing and Activation Offloading
|
||||
---
|
||||
|
||||
Gradient checkpointing and activation offloading are techniques used to optimize the performance of deep learning
|
||||
@@ -27,33 +27,3 @@ The `activation_offloading: legacy` naively offloads activations to CPU and with
|
||||
|
||||
For resource constrained environments with limited CPU memory, `activation_offloading: disk` offloads
|
||||
activations to disk instead of CPU RAM so that much larger context lengths can be trained with minimal memory.
|
||||
|
||||
### Enabling Layer Offloading
|
||||
|
||||
```yaml
|
||||
layer_offloading: true
|
||||
```
|
||||
|
||||
Layer offloading reduces GPU memory usage by moving frozen (non-trainable) decoder layer parameters to CPU
|
||||
and streaming them back to GPU one layer at a time during the forward and backward passes. This is
|
||||
particularly useful for LoRA/QLoRA training where most of the model's parameters are frozen — only the
|
||||
trainable adapter weights stay on GPU permanently.
|
||||
|
||||
During training, forward and backward hooks on each decoder layer handle the transfer automatically:
|
||||
|
||||
- **Forward pass:** Before a layer executes, its frozen params are loaded to GPU. The next layer is
|
||||
prefetched asynchronously on a separate CUDA stream for overlap.
|
||||
- **Backward pass:** Same pattern in reverse — the current layer's frozen params are loaded and the
|
||||
previous layer is prefetched.
|
||||
|
||||
After each layer finishes, its frozen params are offloaded back to CPU pinned memory.
|
||||
|
||||
This approach trades some CPU-GPU transfer overhead for significant GPU memory savings — the freed memory
|
||||
is roughly equal to the size of all frozen parameters across all decoder layers, minus one layer's worth
|
||||
that is kept on GPU at any given time.
|
||||
|
||||
**Requirements:**
|
||||
|
||||
- CUDA GPU (CPU-only training is not supported for this feature)
|
||||
- Works with any HuggingFace model architecture that uses decoder layers (Llama, Mistral, Qwen, etc.)
|
||||
- Best combined with LoRA/QLoRA where most parameters are frozen
|
||||
|
||||
@@ -13,14 +13,12 @@ format:
|
||||
- [Pixtral](#sec-pixtral)
|
||||
- [Llava-1.5](#sec-llava-15)
|
||||
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||
- [Mistral-Small-4](#sec-mistral-small-4)
|
||||
- [Magistral-Small-2509](#sec-magistral-small-2509)
|
||||
- [Voxtral](#sec-voxtral)
|
||||
- [Gemma-3](#sec-gemma-3)
|
||||
- [Gemma-3n](#sec-gemma-3n)
|
||||
- [Qwen2-VL](#sec-qwen2-vl)
|
||||
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||
- [Qwen3.5](#sec-qwen3-5)
|
||||
- [GLM-4.6V](#sec-glm-4-6v)
|
||||
- [SmolVLM2](#sec-smolvlm2)
|
||||
- [LFM2-VL](#sec-lfm2-vl)
|
||||
@@ -110,12 +108,6 @@ Please make sure to install vision lib via `pip install 'mistral-common[opencv]=
|
||||
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||
```
|
||||
|
||||
### Mistral-Small-4 {#sec-mistral-small-4}
|
||||
|
||||
```yaml
|
||||
base_model: mistralai/Mistral-Small-4-119B-2603
|
||||
```
|
||||
|
||||
### Magistral-Small-2509 {#sec-magistral-small-2509}
|
||||
|
||||
::: {.callout-tip}
|
||||
@@ -192,14 +184,6 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
|
||||
chat_template: qwen2_vl # same as qwen2-vl
|
||||
```
|
||||
|
||||
### Qwen3.5 {#sec-qwen3-5}
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
|
||||
chat_template: qwen3_5
|
||||
```
|
||||
|
||||
### GLM-4.6V {#sec-glm-4-6v}
|
||||
|
||||
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
|
||||
|
||||
@@ -54,13 +54,6 @@ These techniques save VRAM by changing how activations are handled.
|
||||
- Activation Offloading: moves activations to CPU RAM or disk, trading I/O overhead for VRAM.
|
||||
- Learn more: [Gradient Checkpointing and Offloading Docs](gradient_checkpointing.qmd)
|
||||
|
||||
### Layer Offloading
|
||||
|
||||
Offloads frozen (non-trainable) decoder layer parameters to CPU and streams them back to GPU one layer at a time during forward/backward passes using CUDA stream prefetching. Especially effective for LoRA/QLoRA where most parameters are frozen.
|
||||
|
||||
- **Config:** `layer_offloading: true`
|
||||
- **Learn more:** [Layer Offloading Docs](gradient_checkpointing.qmd#enabling-layer-offloading)
|
||||
|
||||
### Cut Cross Entropy (CCE)
|
||||
|
||||
Reduces VRAM usage by using an optimized cross-entropy loss calculation.
|
||||
@@ -73,15 +66,6 @@ Provides efficient Triton kernels to improve training speed and reduce memory us
|
||||
|
||||
- **Learn more:** [Custom Integrations - Liger Kernels](custom_integrations.qmd#liger-kernels)
|
||||
|
||||
### Expert Kernels
|
||||
|
||||
Optimized kernel implementations for Mixture of Experts (MoE) model training.
|
||||
|
||||
- **ScatterMoE**: Triton-based MoE kernels with fused LoRA support.
|
||||
- **SonicMoE**: CUTLASS-based MoE kernels for NVIDIA Hopper and Blackwell GPUs.
|
||||
|
||||
- **Learn more:** [Custom Integrations - Kernels Integration](custom_integrations.qmd#kernels-integration)
|
||||
|
||||
## Long Context Models
|
||||
|
||||
Techniques to train models on sequences longer than their original context window.
|
||||
@@ -147,10 +131,3 @@ Simulates quantization effects during training, helping the model adapt and pote
|
||||
Allows you to finetune LoRA adapters on top of a model that has already been quantized using the GPTQ method.
|
||||
|
||||
- **Example:** [GPTQ LoRA Example](https://github.com/axolotl-ai-cloud/axolotl/blob/main/examples/llama-2/gptq-lora.yml)
|
||||
|
||||
### MoE Expert Quantization
|
||||
|
||||
Quantizes MoE expert weights on load to reduce VRAM when training MoE models with adapters. Required for Transformers v5+ MoE models where experts use fused `nn.Parameter` tensors.
|
||||
|
||||
- **Config:** `quantize_moe_experts: true`
|
||||
- **Learn more:** [MoE Expert Quantization](expert_quantization.qmd)
|
||||
|
||||
207
docs/rlhf.qmd
207
docs/rlhf.qmd
@@ -721,213 +721,6 @@ trl:
|
||||
|
||||
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
|
||||
|
||||
#### Async GRPO
|
||||
|
||||
Async GRPO overlaps vLLM generation with training by producing rollouts in a background thread. While the model trains on the current batch, the next batch is already being generated. This can significantly reduce wall-clock time per step.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
use_data_producer: true # Enable data producer protocol
|
||||
use_vllm: true
|
||||
async_prefetch: true # Generate rollouts in background thread
|
||||
prefetch_depth: 1 # Number of rollouts to prefetch
|
||||
vllm_sync_interval: 2 # Sync weights to vLLM every N steps
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Because the background thread generates completions with slightly stale model weights, async GRPO uses importance sampling correction to account for the distribution shift. This is controlled by `vllm_importance_sampling_correction: true` (default when async is enabled).
|
||||
:::
|
||||
|
||||
##### vLLM LoRA Sync
|
||||
|
||||
By default, weight sync to vLLM merges the LoRA adapter into the base model and broadcasts all parameters via NCCL. LoRA sync is a faster alternative that saves only the adapter weights to the filesystem and has vLLM load them natively using Punica kernels.
|
||||
|
||||
```yaml
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_linear: true
|
||||
|
||||
trl:
|
||||
vllm_lora_sync: true # Enable native LoRA sync
|
||||
```
|
||||
|
||||
When `vllm_lora_sync: true` is set, axolotl automatically selects the LoRA-aware vLLM serve module. Start vLLM as usual:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
|
||||
```
|
||||
|
||||
Then start training on a separate GPU:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
LoRA sync is especially beneficial with multi-GPU training (FSDP/DeepSpeed), where NCCL merge-sync can cause GPU contention with vLLM generation.
|
||||
:::
|
||||
|
||||
##### Streaming Partial Batch
|
||||
|
||||
Instead of scoring the entire batch at once, streaming mode scores one prompt group at a time. This enables finer-grained zero-advantage skipping and reduces peak memory usage during scoring.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
streaming_partial_batch: true
|
||||
```
|
||||
|
||||
##### Importance Sampling Correction
|
||||
|
||||
When using async prefetch, completions are generated from a slightly older version of the model. Importance sampling (IS) correction adjusts the policy gradient to account for this distribution shift.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
vllm_importance_sampling_correction: true # Enable IS correction
|
||||
importance_sampling_level: token # 'token' or 'sequence'
|
||||
off_policy_mask_threshold: 0.5 # Mask sequences with IS ratio below this
|
||||
```
|
||||
|
||||
- `importance_sampling_level: token` applies per-token IS ratios (recommended with Liger kernel)
|
||||
- `importance_sampling_level: sequence` applies per-sequence IS ratios
|
||||
- `off_policy_mask_threshold` masks out sequences where the IS ratio indicates they are too far off-policy
|
||||
|
||||
##### Replay Buffer
|
||||
|
||||
The replay buffer caches rollout groups that had learning signal (non-zero reward variance) and uses them to replace zero-signal groups in later batches.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
replay_buffer_size: 100 # Max cached groups (0 = disabled)
|
||||
replay_recompute_logps: true # Recompute log-probs for replayed data (recommended)
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
When `replay_recompute_logps: true` (default), old log-probabilities are recomputed using the current model weights. This fixes the IS mismatch that would otherwise occur when replaying stale data.
|
||||
:::
|
||||
|
||||
##### Deferred Re-rolling
|
||||
|
||||
Failed prompts (where the model produces zero reward for all generations) are buffered and re-injected into later batches when the model may be better equipped to solve them.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reroll_start_fraction: 0.5 # Start re-rolling after 50% of training
|
||||
reroll_max_groups: 1 # Max groups to replace per batch
|
||||
```
|
||||
|
||||
##### Zero-Advantage Batch Skipping
|
||||
|
||||
When all advantages in a micro-batch are zero (no learning signal), the forward/backward pass is skipped entirely. This is enabled by default and logged as `skipped_zero_adv_batches=1`.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
skip_zero_advantage_batches: true # default
|
||||
```
|
||||
|
||||
##### Parallel Reward Workers
|
||||
|
||||
Reward functions that use `signal.alarm()` (e.g., `math_verify`) must run in the main thread. Parallel reward workers use subprocesses to work around this limitation while enabling concurrent reward computation.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
reward_num_workers: 4 # Number of subprocess workers (1 = no parallelism)
|
||||
```
|
||||
|
||||
##### Full Async GRPO Example
|
||||
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||
|
||||
vllm:
|
||||
host: 0.0.0.0
|
||||
port: 8000
|
||||
gpu_memory_utilization: 0.35
|
||||
dtype: auto
|
||||
|
||||
adapter: lora
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_target_linear: true
|
||||
|
||||
rl: grpo
|
||||
trl:
|
||||
use_data_producer: true
|
||||
use_vllm: true
|
||||
async_prefetch: true
|
||||
prefetch_depth: 1
|
||||
vllm_sync_interval: 2
|
||||
vllm_lora_sync: true
|
||||
streaming_partial_batch: true
|
||||
vllm_importance_sampling_correction: true
|
||||
off_policy_mask_threshold: 0.5
|
||||
importance_sampling_level: token
|
||||
num_generations: 8
|
||||
max_completion_length: 512
|
||||
reward_funcs:
|
||||
- rewards.accuracy_reward
|
||||
reroll_start_fraction: 0.5
|
||||
replay_buffer_size: 100
|
||||
reward_num_workers: 4
|
||||
skip_zero_advantage_batches: true
|
||||
|
||||
datasets:
|
||||
- path: AI-MO/NuminaMath-TIR
|
||||
type: rewards.prompt_transform
|
||||
split: train
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
max_steps: 500
|
||||
learning_rate: 1e-5
|
||||
bf16: true
|
||||
gradient_checkpointing: true
|
||||
```
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start vLLM on GPU 0
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
|
||||
|
||||
# Terminal 2: Train on GPU 1
|
||||
CUDA_VISIBLE_DEVICES=1 axolotl train config.yaml
|
||||
```
|
||||
|
||||
##### Multi-GPU Async GRPO
|
||||
|
||||
Async GRPO supports FSDP and DeepSpeed ZeRO-3 for multi-GPU training. vLLM runs on one GPU while training is distributed across the remaining GPUs.
|
||||
|
||||
**FSDP:**
|
||||
|
||||
```yaml
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
```
|
||||
|
||||
**DeepSpeed ZeRO-3:**
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed_configs/zero3_bf16.json
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true # Required for ZeRO-3
|
||||
```
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start vLLM on GPU 0
|
||||
CUDA_VISIBLE_DEVICES=0 axolotl vllm-serve config.yaml
|
||||
|
||||
# Terminal 2: Train on GPUs 0,1
|
||||
CUDA_VISIBLE_DEVICES=0,1 accelerate launch --num_processes 2 -m axolotl.cli.train config.yaml
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
With multi-GPU async prefetch, only rank 0 generates completions in the background thread. Results are broadcast to all ranks on the main thread. This avoids FSDP/DeepSpeed collective deadlocks from unsynchronized background threads.
|
||||
:::
|
||||
|
||||
### GDPO
|
||||
|
||||
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
|
||||
|
||||
@@ -40,7 +40,7 @@
|
||||
"%%capture\n",
|
||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
|
||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
base_model: google/gemma-3-1b-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -24,11 +27,6 @@ datasets:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
base_model: google/gemma-3-270m-it
|
||||
|
||||
model_type: Gemma3ForCausalLM
|
||||
cls_model_config: Gemma3TextConfig
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
@@ -24,11 +27,6 @@ datasets:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
base_model: google/gemma-3-4b-it
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
|
||||
# gemma3 doesn't seem to play nice with ddp
|
||||
@@ -20,18 +23,12 @@ dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
# Freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- ^model\.language_model\..*
|
||||
- ^lm_head\..*
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
# Finetune Z.ai's GLM-4.5-Air with Axolotl
|
||||
|
||||
[GLM-4.5-Air](https://huggingface.co/zai-org/GLM-4.5-Air) is a MoE model by Z.ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# QLoRA (1x80GB @ ~63.4GiB/GPU)
|
||||
axolotl train examples/glm45/glm-45-air-qlora.yaml
|
||||
```
|
||||
|
||||
### Dataset
|
||||
|
||||
In addition to the standard OpenAI Messages format, GLM-4.5 supports an extra parameter for thinking in the assistant section.
|
||||
|
||||
```json
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning_content": "...", // or have </think>...</think> in `content`
|
||||
"content": "..."
|
||||
}
|
||||
```
|
||||
|
||||
Make sure you set the below extra attributes if needed:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
|
||||
# tool_calls: tool_calls # uncomment if using tools
|
||||
# reasoning_content: reasoning_content # uncomment if have reasoning
|
||||
|
||||
# Uncomment if training on tool role (you would rarely if ever need this)
|
||||
# eot_tokens:
|
||||
# - <|observation|>
|
||||
```
|
||||
|
||||
### Tips
|
||||
|
||||
- The role name for tools in this template is `tool`.
|
||||
- You will see this Axolotl WARNING — this is expected as the template does not use EOS:
|
||||
```
|
||||
EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.
|
||||
```
|
||||
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config.
|
||||
- **LoRA kernels**: Incompatible with this model. Must be explicitly disabled (`lora_*_kernel: false`).
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [GLM-4.5-Air on HuggingFace](https://huggingface.co/zai-org/GLM-4.5-Air)
|
||||
- [GLM-4.5 Blog](https://z.ai/blog/glm-4.5)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,64 +0,0 @@
|
||||
base_model: zai-org/GLM-4.5-Air
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
quantize_moe_experts: true # important
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 8
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,65 +0,0 @@
|
||||
# Finetune Z.ai's GLM-4.7-Flash with Axolotl
|
||||
|
||||
[GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model by Z.ai.
|
||||
|
||||
This guide shows how to fine-tune it with Axolotl.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
# QLoRA
|
||||
# - no target experts (1x48GB @ ~24GiB/GPU)
|
||||
# - target experts (1x48GB @ ~34GiB/GPU)
|
||||
axolotl train examples/glm47-flash/qlora.yaml
|
||||
|
||||
# QLoRA FSDP2 no target experts (2x48GB @ ~29GiB/GPU)
|
||||
axolotl train examples/glm47-flash/qlora_fsdp.yaml
|
||||
```
|
||||
|
||||
```bash
|
||||
# LoRA
|
||||
# - no target experts (1x48GB @ ~35GiB/GPU)
|
||||
# - target experts (1x48GB @ OOM. Projected ~45-50GiB/GPU)
|
||||
axolotl train examples/glm47-flash/lora.yaml
|
||||
|
||||
# LoRA FSDP2 no target experts (2x48GB @ ~43GiB/GPU)
|
||||
axolotl train examples/glm47-flash/lora_fsdp.yaml
|
||||
```
|
||||
|
||||
### MoE Expert Quantization & Expert LoRA
|
||||
|
||||
This model quantize expert weights on load. To learn about expert quantization, expert LoRA targeting, and related limitations, see the [MoE Expert Quantization](https://docs.axolotl.ai/docs/expert_quantization.html) docs.
|
||||
|
||||
## Limitations
|
||||
|
||||
- **lora_target_linear**: Incompatible for this model.
|
||||
- **LoRA kernels**: Incompatible with this model due to non-standard attention projections (DSA). Must be explicitly disabled (`lora_*_kernel: false`).
|
||||
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official Z.ai team recommends these default settings (most tasks):
|
||||
- `temperature: 1.0`
|
||||
- `top_p: 0.95`
|
||||
- `max_new_tokens: 131072`
|
||||
- You can run a full finetuning by removing `adapter: qlora`, `load_in_4bit: true`, and `quantize_moe_experts: true` from the config. This is heavy, so we have not tested this.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [GLM-4.7-Flash on HuggingFace](https://huggingface.co/zai-org/GLM-4.7-Flash)
|
||||
- [GLM-4.7 Blog](https://z.ai/blog/glm-4.7)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,65 +0,0 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
@@ -1,75 +0,0 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_8bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-lora-8bit-fsdp-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,65 +0,0 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
@@ -1,75 +0,0 @@
|
||||
base_model: zai-org/GLM-4.7-Flash
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/glm4.7-flash-qlora-fsdp-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# Uncomment to also target MoE expert weights:
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
# LoRA kernels incompatible with DSA attention
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Glm4MoeLiteDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,65 +0,0 @@
|
||||
base_model: meta-llama/Llama-3.2-3B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
datasets:
|
||||
- path: yahma/alpaca-cleaned
|
||||
type: alpaca
|
||||
split: train[:95%]
|
||||
|
||||
output_dir: ./outputs/qat_out/
|
||||
dataset_prepared_path: ./outputs/dataset_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
flash_attention: true
|
||||
|
||||
qat:
|
||||
activation_dtype: mxfp4
|
||||
weight_dtype: mxfp4
|
||||
group_size: 32
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_checkpointing: true
|
||||
activation_offloading: true
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_8bit
|
||||
|
||||
cosine_constant_lr_ratio: 0
|
||||
cosine_min_lr_ratio: 1.0
|
||||
learning_rate: 2e-5
|
||||
save_only_model: true
|
||||
bf16: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_ratio: 0.1
|
||||
weight_decay: 0.0
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
|
||||
# save_first_step: true # uncomment this to validate checkpoint saving works with your config
|
||||
@@ -1,82 +0,0 @@
|
||||
# Finetune Mistral Small 4 with Axolotl
|
||||
|
||||
Mistral Small 4 is a 119B parameter (6.5B active) multimodal MoE model from MistralAI that unifies instruct, reasoning, and coding capabilities into a single model. It is available on HuggingFace at [Mistral-Small-4-119B-2603](https://huggingface.co/mistralai/Mistral-Small-4-119B-2603).
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||
|
||||
3. Install transformers from main
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
|
||||
4. Run one of the example configs:
|
||||
|
||||
```bash
|
||||
# text-only
|
||||
axolotl train examples/mistral4/qlora-text.yml # no experts ~69 GiB, experts ~93 GiB
|
||||
axolotl train examples/mistral4/fft-text.yml
|
||||
|
||||
# text + vision
|
||||
# run: wget https://huggingface.co/datasets/Nanobit/text-vision-2k-test/resolve/main/African_elephant.jpg
|
||||
axolotl train examples/mistral4/qlora-vision.yml # no experts ~68 GiB
|
||||
axolotl train examples/mistral4/fft-vision.yml
|
||||
```
|
||||
|
||||
Note: FFT configs provided as reference. Please adjust hyperparameters as needed.
|
||||
|
||||
## Reasoning Effort
|
||||
|
||||
The chat template supports a `reasoning_effort` variable to control the model's reasoning depth:
|
||||
|
||||
- `"none"` — instruct mode (default)
|
||||
- `"high"` — reasoning mode with explicit thinking steps
|
||||
|
||||
Pass it via `chat_template_kwargs` under your dataset config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: your/dataset
|
||||
type: chat_template
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: high
|
||||
```
|
||||
|
||||
## Thinking Support
|
||||
|
||||
The chat template supports a `thinking` content type in assistant messages for training on reasoning traces (rendered as `[THINK]...[/THINK]` blocks).
|
||||
|
||||
To use thinking datasets, add the `thinking` mapping via `message_property_mappings`:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: your/thinking-dataset
|
||||
type: chat_template
|
||||
message_property_mappings:
|
||||
role: role
|
||||
content: content
|
||||
thinking: thinking
|
||||
chat_template_kwargs:
|
||||
reasoning_effort: high
|
||||
```
|
||||
|
||||
See the [Magistral thinking guide](../magistral/think/README.md) for dataset format details.
|
||||
|
||||
## Tips
|
||||
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The text dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- The vision model requires multi-modal dataset format as documented [here](https://docs.axolotl.ai/docs/multimodal.html#dataset-format).
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Mistral Small 4 Blog](https://mistral.ai/news/mistral-small-4)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -1,58 +0,0 @@
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.kernels.KernelsPlugin
|
||||
use_kernels: true
|
||||
use_sonicmoe: true
|
||||
|
||||
# only train language model layers, freeze vision tower
|
||||
unfrozen_parameters:
|
||||
- model.language_model.*
|
||||
- lm_head
|
||||
- embed_tokens
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,57 +0,0 @@
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
- axolotl.integrations.kernels.KernelsPlugin
|
||||
use_kernels: true
|
||||
use_sonicmoe: true
|
||||
|
||||
# vision requirements
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
datasets:
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Mistral4DecoderLayer
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,58 +0,0 @@
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
# uncomment to train on expert layers
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
# lora_mlp_kernel: false
|
||||
# lora_qkv_kernel: false
|
||||
# lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -1,63 +0,0 @@
|
||||
base_model: axolotl-ai-co/Mistral-Small-4-119B-2603-BF16
|
||||
processor_type: AutoProcessor
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
|
||||
# vision chat template requirements
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
datasets:
|
||||
- path: Nanobit/text-vision-2k-test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: qlora
|
||||
|
||||
sequence_len: 2048
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
# uncomment to train on expert layers
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
# lora_mlp_kernel: false
|
||||
# lora_qkv_kernel: false
|
||||
# lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -1,57 +0,0 @@
|
||||
base_model: nvidia/Nemotron-Mini-4B-Instruct
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/nemotron-mini-4b-qlora
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- up_proj
|
||||
- down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
special_tokens:
|
||||
@@ -1,12 +1,11 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
|
||||
@@ -7,6 +7,7 @@ load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
base_model: google/gemma-3-12b-it
|
||||
# Math finetuning configuration for Gemma3-12B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
|
||||
@@ -7,6 +7,7 @@ load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
base_model: google/gemma-3-27b-it
|
||||
# Math finetuning configuration for Gemma3-27B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
|
||||
@@ -7,6 +7,7 @@ load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.gemma3.Gemma3TextFromMultimodalPlugin
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
|
||||
@@ -6,13 +6,30 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Qwen3-Next is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
||||
|
||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
```
|
||||
|
||||
2. Install Qwen3-Next transformers commit
|
||||
```bash
|
||||
pip3 uninstall -y transformers && pip3 install "git+https://github.com/huggingface/transformers.git@b9282355bea846b54ed850a066901496b19da654"
|
||||
```
|
||||
|
||||
3. Install FLA for improved performance
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.3.2
|
||||
```
|
||||
|
||||
4. Run the finetuning example:
|
||||
@@ -21,7 +38,7 @@ pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
axolotl train examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about ~47 GiB (no target experts) and ~71GiB (target experts) VRAM.
|
||||
This config uses about 45.62 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
|
||||
@@ -9,8 +9,6 @@ plugins:
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
quantize_moe_experts: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
@@ -27,7 +25,7 @@ sample_packing: true
|
||||
|
||||
lora_r: 16
|
||||
lora_alpha: 8
|
||||
lora_dropout: 0
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- linear_attn.in_proj_ba
|
||||
- linear_attn.in_proj_qkvz
|
||||
@@ -36,19 +34,12 @@ lora_target_modules:
|
||||
- shared_expert.down_proj
|
||||
- shared_expert.gate_proj
|
||||
- shared_expert_gate
|
||||
- mlp.gate
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
base_model: Qwen/Qwen3.5-122B-A10B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: true
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,74 +0,0 @@
|
||||
base_model: Qwen/Qwen3.5-122B-A10B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,59 +0,0 @@
|
||||
base_model: Qwen/Qwen3.5-27B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
# Full fine-tune (FFT) of the text-only path of Qwen3.5-27B.
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
# Freeze vision encoder
|
||||
unfrozen_parameters:
|
||||
- model\.language_model\..*
|
||||
- lm_head\..*
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,81 +0,0 @@
|
||||
base_model: Qwen/Qwen3.5-27B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
# Uncomment below to also target the linear attention projections.
|
||||
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
|
||||
# - linear_attn.in_proj_qkv
|
||||
# - linear_attn.in_proj_z
|
||||
# - linear_attn.out_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: false
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3_5DecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,70 +0,0 @@
|
||||
base_model: Qwen/Qwen3.5-27B
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
# Uncomment below to also target the linear attention projections.
|
||||
# These use separate in_proj_qkv / in_proj_z / out_proj (Qwen3.5-specific).
|
||||
# - linear_attn.in_proj_qkv
|
||||
# - linear_attn.in_proj_z
|
||||
# - linear_attn.out_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,85 +0,0 @@
|
||||
base_model: Qwen/Qwen3.5-35B-A3B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
offload_params: true
|
||||
cpu_ram_efficient_loading: false
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: Qwen3_5MoeDecoderLayer
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
sharding_strategy: FULL_SHARD
|
||||
reshard_after_forward: true
|
||||
activation_checkpointing: true
|
||||
@@ -1,74 +0,0 @@
|
||||
base_model: Qwen/Qwen3.5-35B-A3B
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
|
||||
load_in_4bit: true
|
||||
quantize_moe_experts: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
|
||||
# Regex matching to target shared experts too
|
||||
# lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
|
||||
# Target experts
|
||||
# lora_target_parameters:
|
||||
# - mlp.experts.gate_up_proj
|
||||
# - mlp.experts.down_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
lora_mlp_kernel: false
|
||||
lora_qkv_kernel: false
|
||||
lora_o_kernel: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,49 +0,0 @@
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# Required for multimodal training
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
pad_to_sequence_len: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,66 +0,0 @@
|
||||
base_model: Qwen/Qwen3.5-9B
|
||||
processor_type: AutoProcessor
|
||||
|
||||
# These 3 lines are required for vision/multimodal training
|
||||
skip_prepare_dataset: true
|
||||
remove_unused_columns: false
|
||||
sample_packing: false
|
||||
|
||||
chat_template: qwen3_5
|
||||
datasets:
|
||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||
type: chat_template
|
||||
split: train[:1%]
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
pad_to_sequence_len: false
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
# Targets the language model attention and MLP layers.
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
# Uncomment to also target the linear attention (GatedDeltaNet) projections:
|
||||
# - linear_attn.in_proj_qkv
|
||||
# - linear_attn.in_proj_z
|
||||
# - linear_attn.out_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
@@ -1,86 +0,0 @@
|
||||
# Finetune Qwen3.5 with Axolotl
|
||||
|
||||
[Qwen3.5](https://huggingface.co/collections/Qwen/qwen35) is a hybrid architecture model series combining Gated DeltaNet linear attention with standard Transformer attention. All Qwen3.5 models are early-fusion vision-language models: dense variants use `Qwen3_5ForConditionalGeneration` and MoE variants use `Qwen3_5MoeForConditionalGeneration`.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
||||
```bash
|
||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
||||
```
|
||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||
|
||||
4. Pick any config from the table below and run:
|
||||
|
||||
```bash
|
||||
axolotl train examples/qwen3.5/<config>.yaml
|
||||
```
|
||||
|
||||
Available configs:
|
||||
|
||||
| Config | Model | Type | Peak VRAM |
|
||||
|---|---|---|---|
|
||||
| `9b-lora-vision.yaml` | Qwen3.5-9B | Vision+text LoRA, single GPU | — |
|
||||
| `9b-fft-vision.yaml` | Qwen3.5-9B | Vision+text FFT, single GPU | ~61 GiB |
|
||||
| `27b-qlora.yaml` | Qwen3.5-27B | Dense, text-only QLoRA | ~47 GiB |
|
||||
| `27b-fft.yaml` | Qwen3.5-27B | Dense, text-only FFT (vision frozen) | ~53 GiB |
|
||||
| `27b-qlora-fsdp.yaml` | Qwen3.5-27B | Dense, text-only QLoRA + FSDP2 | — |
|
||||
| `35b-a3b-moe-qlora.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA | — |
|
||||
| `35b-a3b-moe-qlora-fsdp.yaml` | Qwen3.5-35B-A3B | MoE, text-only QLoRA + FSDP2 | — |
|
||||
| `122b-a10b-moe-qlora.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA | — |
|
||||
| `122b-a10b-moe-qlora-fsdp.yaml` | Qwen3.5-122B-A10B | MoE, text-only QLoRA + FSDP2 | — |
|
||||
|
||||
### Gated DeltaNet Linear Attention
|
||||
|
||||
Qwen3.5 interleaves standard attention with Gated DeltaNet linear attention layers. To apply LoRA to them, add to `lora_target_modules`:
|
||||
|
||||
```yaml
|
||||
lora_target_modules:
|
||||
# ... standard projections ...
|
||||
- linear_attn.in_proj_qkv
|
||||
- linear_attn.in_proj_z
|
||||
- linear_attn.out_proj
|
||||
```
|
||||
|
||||
### Routed Experts (MoE)
|
||||
|
||||
To apply LoRA to routed expert parameters, add `lora_target_parameters`:
|
||||
|
||||
```yaml
|
||||
lora_target_parameters:
|
||||
- mlp.experts.gate_up_proj
|
||||
- mlp.experts.down_proj
|
||||
# - mlp.gate.weight # router
|
||||
```
|
||||
|
||||
### Shared Experts (MoE)
|
||||
|
||||
Routed experts and shared experts both have `gate_up_proj`/`down_proj`, so a plain module name in `lora_target_modules` would match both. Use a regex to target only attention and shared expert projections, while `lora_target_parameters` above handles routed experts separately:
|
||||
|
||||
```yaml
|
||||
lora_target_modules: 'model\.(language_model\.)?layers\.[\d]+\.(mlp|self_attn)\.(shared_expert\.)?(up|down|gate|gate_up|q|k|v|o)_proj'
|
||||
```
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference hyp, please see the respective model card details.
|
||||
- You can run a full finetuning of smaller configs by removing `adapter: qlora` and `load_in_4bit: true`. See [Multi-GPU](#optimization-guides) below.
|
||||
- Read more on loading your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format follows the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
- For **multimodal** finetuning, set `processor_type: AutoProcessor`, `skip_prepare_dataset: true`, and `remove_unused_columns: false` as shown in `9b-lora-vision.yaml`.
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Optimizations Guide](https://docs.axolotl.ai/docs/optimizations.html)
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Qwen3.5 Blog](https://qwenlm.github.io/blog/qwen3.5/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
@@ -8,15 +8,13 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
||||
|
||||
1. Install Axolotl following the main from the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
|
||||
|
||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||
|
||||
3. Run the finetuning example:
|
||||
2. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/trinity/trinity-nano-preview-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 24.9 GiB VRAM (w/o CCE).
|
||||
This config uses about 24.9 GiB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
@@ -31,6 +29,10 @@ Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
|
||||
|
||||
## Limitations
|
||||
|
||||
**Cut Cross Entropy (CCE)**: Currently not supported. We plan to include CCE support for Trinity in the near future.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [Trinity Blog](https://www.arcee.ai/blog/the-trinity-manifesto)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: arcee-ai/Trinity-Nano-Preview
|
||||
trust_remote_code: true
|
||||
revision_of_model: 2ee94b0
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
|
||||
@@ -61,11 +61,5 @@ skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
docstring-code-format = false
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-m 'not slow'"
|
||||
markers = [
|
||||
"slow: marks tests as slow",
|
||||
]
|
||||
|
||||
[tool.uv.extra-build-dependencies]
|
||||
axolotl = ["huggingface_hub"]
|
||||
|
||||
@@ -2,28 +2,25 @@
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.49.1
|
||||
triton>=3.4.0
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
liger-kernel==0.7.0
|
||||
liger-kernel==0.6.4
|
||||
# END section
|
||||
|
||||
packaging==26.0
|
||||
huggingface_hub>=1.1.7
|
||||
peft>=0.18.1
|
||||
tokenizers>=0.22.1
|
||||
transformers==5.3.0
|
||||
accelerate==1.13.0
|
||||
transformers==5.0.0
|
||||
accelerate==1.12.0
|
||||
datasets==4.5.0
|
||||
deepspeed>=0.18.6,<0.19.0
|
||||
trl==0.29.0
|
||||
hf_xet==1.3.2
|
||||
kernels==0.12.2
|
||||
deepspeed>=0.18.3
|
||||
trl==0.27.1
|
||||
hf_xet==1.2.0
|
||||
kernels==0.11.5
|
||||
|
||||
fla-core==0.4.1
|
||||
flash-linear-attention==0.4.1
|
||||
|
||||
trackio>=0.16.1
|
||||
trackio>=0.13.0
|
||||
typing-extensions>=4.15.0
|
||||
|
||||
optimum==1.16.2
|
||||
@@ -66,7 +63,7 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.16.0
|
||||
torchao==0.13.0
|
||||
openenv-core==0.1.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
@@ -75,4 +72,4 @@ axolotl-contribs-mit==0.0.6
|
||||
# telemetry
|
||||
posthog==6.7.11
|
||||
|
||||
mistral-common==1.10.0
|
||||
mistral-common==1.8.8
|
||||
|
||||
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6"'
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0d4ce4b"'
|
||||
)
|
||||
|
||||
225
scripts/merge_gemma3_multimodal_weights.py
Normal file
225
scripts/merge_gemma3_multimodal_weights.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Merge trained text-only Gemma3 weights back into a full multimodal checkpoint.
|
||||
|
||||
After training with the Gemma3TextFromMultimodalPlugin, the saved checkpoint
|
||||
contains only the language model weights (with ``model.language_model.*``
|
||||
prefix, reversed by transformers v5's key_mapping on save).
|
||||
|
||||
This script reconstructs a full ``Gemma3ForConditionalGeneration`` checkpoint by
|
||||
combining the trained language model weights with the original vision tower and
|
||||
projector weights from the base multimodal model.
|
||||
|
||||
Usage::
|
||||
|
||||
python scripts/merge_gemma3_multimodal_weights.py \\
|
||||
--original-model google/gemma-3-4b-it \\
|
||||
--trained-model /path/to/trained/output \\
|
||||
--output-dir /path/to/merged
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import load_file, save_file
|
||||
from transformers import AutoConfig
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def collect_safetensors(model_dir: Path) -> dict[str, torch.Tensor]:
|
||||
"""Load and merge all safetensors shard files in a directory."""
|
||||
shard_files = sorted(model_dir.glob("*.safetensors"))
|
||||
if not shard_files:
|
||||
raise FileNotFoundError(f"No safetensors files found in {model_dir}")
|
||||
|
||||
state_dict: dict[str, torch.Tensor] = {}
|
||||
for shard in shard_files:
|
||||
LOG.info("Loading %s", shard.name)
|
||||
state_dict.update(load_file(str(shard)))
|
||||
return state_dict
|
||||
|
||||
|
||||
def merge(
|
||||
original_model: str,
|
||||
trained_model: str,
|
||||
output_dir: str,
|
||||
*,
|
||||
trust_remote_code: bool = False,
|
||||
) -> None:
|
||||
original_path = Path(original_model)
|
||||
trained_path = Path(trained_model)
|
||||
out_path = Path(output_dir)
|
||||
out_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 1. Load the original multimodal checkpoint
|
||||
LOG.info("Loading original multimodal weights from %s", original_model)
|
||||
if original_path.is_dir():
|
||||
original_sd = collect_safetensors(original_path)
|
||||
else:
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
cached = Path(
|
||||
snapshot_download(original_model, allow_patterns=["*.safetensors"])
|
||||
)
|
||||
original_sd = collect_safetensors(cached)
|
||||
|
||||
# 2. Load trained text-only weights (already reversed to model.language_model.* by
|
||||
# transformers v5 key_mapping on save)
|
||||
LOG.info("Loading trained text-only weights from %s", trained_model)
|
||||
trained_sd = collect_safetensors(trained_path)
|
||||
|
||||
# 3. Classify original keys
|
||||
lang_keys = {k for k in original_sd if k.startswith("model.language_model.")}
|
||||
vision_keys = {k for k in original_sd if k.startswith("model.vision_tower.")}
|
||||
projector_keys = {
|
||||
k for k in original_sd if k.startswith("model.multi_modal_projector.")
|
||||
}
|
||||
other_keys = set(original_sd.keys()) - lang_keys - vision_keys - projector_keys
|
||||
|
||||
LOG.info(
|
||||
"Original checkpoint: %d language, %d vision, %d projector, %d other keys",
|
||||
len(lang_keys),
|
||||
len(vision_keys),
|
||||
len(projector_keys),
|
||||
len(other_keys),
|
||||
)
|
||||
|
||||
# 4. Classify trained keys (reverse mapping on save gives model.language_model.* prefix)
|
||||
trained_lang_keys = {k for k in trained_sd if k.startswith("model.language_model.")}
|
||||
trained_other = set(trained_sd.keys()) - trained_lang_keys
|
||||
|
||||
LOG.info(
|
||||
"Trained checkpoint: %d language keys, %d other keys",
|
||||
len(trained_lang_keys),
|
||||
len(trained_other),
|
||||
)
|
||||
|
||||
# 5. Build merged state dict
|
||||
merged: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Keep vision tower and projector from original
|
||||
for key in vision_keys | projector_keys:
|
||||
merged[key] = original_sd[key]
|
||||
|
||||
# Use trained language model weights (overwrite original)
|
||||
for key in trained_lang_keys:
|
||||
merged[key] = trained_sd[key]
|
||||
|
||||
# For other trained keys (like lm_head.weight), use trained version
|
||||
for key in trained_other:
|
||||
merged[key] = trained_sd[key]
|
||||
|
||||
# For any original other keys not covered by trained (shouldn't usually happen),
|
||||
# keep original
|
||||
for key in other_keys:
|
||||
if key not in merged:
|
||||
merged[key] = original_sd[key]
|
||||
|
||||
# Check for missing language keys that were in original but not in trained
|
||||
missing_lang = lang_keys - trained_lang_keys
|
||||
if missing_lang:
|
||||
LOG.warning(
|
||||
"%d language keys in original but not in trained; keeping original: %s",
|
||||
len(missing_lang),
|
||||
list(missing_lang)[:5],
|
||||
)
|
||||
for key in missing_lang:
|
||||
merged[key] = original_sd[key]
|
||||
|
||||
LOG.info("Merged checkpoint: %d total keys", len(merged))
|
||||
|
||||
# 6. Save merged weights (sharded at 50GB, matching transformers default)
|
||||
LOG.info("Saving merged weights to %s", out_path)
|
||||
state_dict_split = split_torch_state_dict_into_shards(merged, max_shard_size="50GB")
|
||||
|
||||
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
||||
shard = {name: merged[name] for name in tensors}
|
||||
save_file(shard, str(out_path / filename))
|
||||
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": {
|
||||
"total_size": sum(t.numel() * t.element_size() for t in merged.values())
|
||||
},
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
with open(out_path / "model.safetensors.index.json", "w") as f:
|
||||
json.dump(index, f, indent=2)
|
||||
LOG.info("Saved %d shards", len(state_dict_split.filename_to_tensors))
|
||||
|
||||
# 7. Copy/update config
|
||||
LOG.info("Writing config.json")
|
||||
original_config = AutoConfig.from_pretrained(
|
||||
original_model, trust_remote_code=trust_remote_code
|
||||
)
|
||||
|
||||
# Update text_config fields from trained model's config if available
|
||||
trained_config_path = trained_path / "config.json"
|
||||
if trained_config_path.exists():
|
||||
with open(trained_config_path) as f:
|
||||
trained_config_dict = json.load(f)
|
||||
|
||||
# The trained config is the text sub-config; merge its fields into
|
||||
# the original composite config's text_config
|
||||
if hasattr(original_config, "text_config"):
|
||||
for key, val in trained_config_dict.items():
|
||||
if key not in ("model_type", "_name_or_path", "architectures"):
|
||||
if hasattr(original_config.text_config, key):
|
||||
setattr(original_config.text_config, key, val)
|
||||
|
||||
original_config.save_pretrained(out_path)
|
||||
|
||||
# 8. Copy tokenizer files from trained model if present
|
||||
tokenizer_files = list(trained_path.glob("tokenizer*")) + list(
|
||||
trained_path.glob("special_tokens_map*")
|
||||
)
|
||||
if tokenizer_files:
|
||||
import shutil
|
||||
|
||||
for tok_file in tokenizer_files:
|
||||
shutil.copy2(tok_file, out_path / tok_file.name)
|
||||
LOG.info("Copied %d tokenizer files", len(tokenizer_files))
|
||||
|
||||
LOG.info("Merge complete. Output saved to %s", out_path)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Merge trained text-only Gemma3 weights back into a multimodal checkpoint."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--original-model",
|
||||
required=True,
|
||||
help="HuggingFace model ID or local path to the original multimodal model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trained-model",
|
||||
required=True,
|
||||
help="Local path to the trained text-only model output directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
required=True,
|
||||
help="Directory to save the merged multimodal checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Trust remote code when loading model config",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
merge(
|
||||
original_model=args.original_model,
|
||||
trained_model=args.trained_model,
|
||||
output_dir=args.output_dir,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
25
setup.py
25
setup.py
@@ -26,18 +26,6 @@ def parse_requirements(extras_require_map):
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
install_xformers = platform.machine() != "aarch64"
|
||||
if platform.machine() == "aarch64":
|
||||
# skip on ARM64
|
||||
skip_packages = [
|
||||
"torchao",
|
||||
"fla-core",
|
||||
"flash-linear-attention",
|
||||
]
|
||||
_install_requires = [
|
||||
req
|
||||
for req in _install_requires
|
||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||
]
|
||||
if "Darwin" in platform.system():
|
||||
# skip packages not compatible with OSX
|
||||
skip_packages = [
|
||||
@@ -81,23 +69,16 @@ def parse_requirements(extras_require_map):
|
||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
||||
)
|
||||
|
||||
if (major, minor) >= (2, 10):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.5.0",
|
||||
"fbgemm-gpu-genai==1.5.0",
|
||||
]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
extras_require_map["vllm"] = ["vllm==0.17.1"]
|
||||
elif (major, minor) >= (2, 9):
|
||||
if (major, minor) >= (2, 9):
|
||||
extras_require_map.pop("fbgemm-gpu")
|
||||
extras_require_map["fbgemm-gpu"] = [
|
||||
"fbgemm-gpu==1.4.0",
|
||||
"fbgemm-gpu-genai==1.4.2",
|
||||
]
|
||||
extras_require_map["vllm"] = ["vllm==0.11.1"]
|
||||
if not install_xformers:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
if patch == 0:
|
||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
||||
else:
|
||||
|
||||
@@ -6,6 +6,5 @@ from axolotl.logging_config import configure_logging
|
||||
|
||||
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
||||
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
|
||||
os.environ.setdefault("TRL_EXPERIMENTAL_SILENCE", "1")
|
||||
|
||||
configure_logging()
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpcore
|
||||
from accelerate.commands.config import config_args
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
@@ -48,7 +47,7 @@ def check_user_token() -> bool:
|
||||
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
except (HTTPError, httpcore.ConnectError):
|
||||
except HTTPError:
|
||||
LOG.warning(
|
||||
"Error accessing HuggingFace. This may be due to a network issue or rate limiting."
|
||||
)
|
||||
|
||||
@@ -90,8 +90,9 @@ class ModalCloud(Cloud):
|
||||
# grab the sha256 hash from docker hub for this image+tag
|
||||
# this ensures that we always get the latest image for this tag, even if it's already cached
|
||||
try:
|
||||
manifest = subprocess.check_output( # nosec
|
||||
["docker", "manifest", "inspect", docker_image],
|
||||
manifest = subprocess.check_output( # nosec B602
|
||||
f"docker manifest inspect {docker_image}",
|
||||
shell=True,
|
||||
).decode("utf-8")
|
||||
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
|
||||
except subprocess.CalledProcessError:
|
||||
|
||||
@@ -5,13 +5,13 @@ import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
from transformers.utils import is_torch_bf16_gpu_available, is_torch_tf32_available
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.telemetry.errors import send_errors
|
||||
@@ -32,63 +32,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
|
||||
"""Coerce a string CLI value to its most likely Python type.
|
||||
|
||||
If an existing value is present in the config, its type is used to guide
|
||||
casting. Otherwise, YAML-style inference is applied: booleans, ints,
|
||||
floats, and None literals are recognised automatically.
|
||||
|
||||
Args:
|
||||
value: The raw value (typically a string from the CLI).
|
||||
existing: An optional existing config value whose type guides coercion.
|
||||
|
||||
Returns:
|
||||
The value cast to the inferred or expected type.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
# If the config already has a typed value, cast to match
|
||||
if existing is not None:
|
||||
if isinstance(existing, bool):
|
||||
return value.lower() in ("true", "1", "yes")
|
||||
if isinstance(existing, int):
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
if isinstance(existing, float):
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
# For other types (str, list, dict, etc.), return as-is
|
||||
return value
|
||||
|
||||
# No existing value -- use YAML-style inference
|
||||
lower = value.lower()
|
||||
if lower in ("true", "yes"):
|
||||
return True
|
||||
if lower in ("false", "no"):
|
||||
return False
|
||||
if lower in ("null", "none", "~"):
|
||||
return None
|
||||
|
||||
# Try int then float
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
API_KEY_FIELDS = {"comet_api_key"}
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
@@ -265,42 +208,18 @@ def load_cfg(
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
|
||||
# Separate nested (dot-notation) kwargs from flat kwargs
|
||||
nested_kwargs: dict[str, dict[str, Any]] = {}
|
||||
flat_kwargs: dict[str, Any] = {}
|
||||
for key, value in kwargs.items():
|
||||
if "__" in key:
|
||||
parent, child = key.split("__", 1)
|
||||
nested_kwargs.setdefault(parent, {})[child] = value
|
||||
else:
|
||||
flat_kwargs[key] = value
|
||||
|
||||
# Apply flat kwargs
|
||||
for key, value in flat_kwargs.items():
|
||||
# If not strict, allow writing to cfg even if it's not in the yml already
|
||||
if key in cfg_keys or not cfg.strict:
|
||||
cfg[key] = _coerce_value(value, cfg.get(key))
|
||||
|
||||
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
|
||||
for parent, children in nested_kwargs.items():
|
||||
if parent not in cfg_keys and cfg.strict:
|
||||
continue
|
||||
if cfg[parent] is None:
|
||||
cfg[parent] = {}
|
||||
if not isinstance(cfg[parent], dict):
|
||||
LOG.warning(
|
||||
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
|
||||
)
|
||||
cfg[parent] = {}
|
||||
for child_key, child_value in children.items():
|
||||
existing_child = cfg[parent].get(child_key)
|
||||
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
|
||||
if isinstance(cfg[key], bool):
|
||||
cfg[key] = bool(value)
|
||||
else:
|
||||
cfg[key] = value
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||
except (RuntimeError, AssertionError):
|
||||
except:
|
||||
gpu_version = None
|
||||
|
||||
prepare_plugins(cfg)
|
||||
@@ -310,7 +229,6 @@ def load_cfg(
|
||||
capabilities={
|
||||
"bf16": is_torch_bf16_gpu_available(),
|
||||
"fp8": compute_supports_fp8(),
|
||||
"tf32": is_torch_tf32_available(),
|
||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||
"compute_capability": gpu_version,
|
||||
},
|
||||
|
||||
@@ -71,7 +71,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
merge_lora=True,
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
quantize_moe_experts=False,
|
||||
flash_attention=False,
|
||||
context_parallel_size=None,
|
||||
deepspeed=None,
|
||||
|
||||
@@ -196,10 +196,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
state.wait_for_everyone()
|
||||
LOG.info(
|
||||
f"FSDP SHARDED_STATE_DICT weights successfully merged to: {output_path}",
|
||||
main_process_only=True,
|
||||
)
|
||||
LOG.info(
|
||||
"Merged weights are only the safetensors and doesn't include the model configuration "
|
||||
f"or tokenizer which may be found in {parsed_cfg.output_dir}.",
|
||||
main_process_only=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import dataclasses
|
||||
from functools import wraps
|
||||
from types import NoneType, UnionType
|
||||
from types import NoneType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
@@ -20,8 +20,7 @@ def _strip_optional_type(field_type: type | str | None):
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
|
||||
if is_union and type(None) in get_args(field_type):
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
@@ -88,70 +87,10 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
return decorator
|
||||
|
||||
|
||||
def _is_pydantic_model(field_type: type) -> bool:
|
||||
"""Check if a type is a Pydantic BaseModel subclass."""
|
||||
try:
|
||||
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_field_description(field) -> str | None:
|
||||
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
|
||||
if field.description:
|
||||
return field.description
|
||||
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
|
||||
return field.json_schema_extra.get("description")
|
||||
return None
|
||||
|
||||
|
||||
def _add_nested_model_options(
|
||||
function: Callable, parent_name: str, model_class: Type[BaseModel]
|
||||
) -> Callable:
|
||||
"""
|
||||
Add Click options for all fields of a nested Pydantic model using dot-notation.
|
||||
|
||||
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
|
||||
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
|
||||
|
||||
Args:
|
||||
function: Click command function to add options to.
|
||||
parent_name: Parent field name (e.g., "trl").
|
||||
model_class: Nested Pydantic model class.
|
||||
|
||||
Returns:
|
||||
Function with added Click options.
|
||||
"""
|
||||
for sub_name, sub_field in reversed(model_class.model_fields.items()):
|
||||
sub_type = _strip_optional_type(sub_field.annotation)
|
||||
# Use dot notation: --parent.sub_field
|
||||
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
|
||||
# The kwarg name uses double-underscore as separator
|
||||
param_name = f"{parent_name}__{sub_name}"
|
||||
description = _get_field_description(sub_field)
|
||||
|
||||
if sub_type is bool:
|
||||
option_name = f"--{cli_name}/--no-{cli_name}"
|
||||
function = click.option(
|
||||
option_name, param_name, default=None, help=description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{cli_name}"
|
||||
click_type = {str: str, int: int, float: float}.get(sub_type)
|
||||
function = click.option(
|
||||
option_name, param_name, default=None, type=click_type, help=description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
|
||||
generated for each sub-field (e.g., ``--trl.beta=0.1``).
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
@@ -164,11 +103,6 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = _strip_optional_type(field.annotation)
|
||||
|
||||
# Handle nested Pydantic models with dot-notation options
|
||||
if _is_pydantic_model(field_type):
|
||||
function = _add_nested_model_options(function, name, field_type)
|
||||
continue
|
||||
|
||||
if field_type is bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
|
||||
@@ -38,18 +38,7 @@ def do_vllm_serve(
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
# Determine serve module: explicit CLI/config > auto-select from vllm_lora_sync > default
|
||||
serve_module = cli_args.get("serve_module") or getattr(
|
||||
cfg.vllm, "serve_module", None
|
||||
)
|
||||
if (
|
||||
serve_module is None
|
||||
and getattr(cfg, "trl", None)
|
||||
and getattr(cfg.trl, "vllm_lora_sync", False)
|
||||
):
|
||||
serve_module = "axolotl.scripts.vllm_serve_lora"
|
||||
if serve_module is None:
|
||||
serve_module = "trl.scripts.vllm_serve"
|
||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||
vllm_serve_main = __import__(serve_module, fromlist=["main"]).main
|
||||
tensor_parallel_size = 1
|
||||
data_parallel_size = 1
|
||||
@@ -79,7 +68,7 @@ def do_vllm_serve(
|
||||
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
|
||||
)
|
||||
|
||||
base_kwargs = dict(
|
||||
vllm_script_args = AxolotlScriptArguments(
|
||||
model=model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
data_parallel_size=data_parallel_size,
|
||||
@@ -89,21 +78,7 @@ def do_vllm_serve(
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
reasoning_parser=reasoning_parser,
|
||||
enable_reasoning=enable_reasoning,
|
||||
)
|
||||
|
||||
# Use LoRAScriptArguments when serving with native LoRA support
|
||||
if serve_module == "axolotl.scripts.vllm_serve_lora":
|
||||
from axolotl.scripts.vllm_serve_lora import LoRAScriptArguments
|
||||
|
||||
lora_kwargs = {}
|
||||
if hasattr(cfg, "lora_r") and cfg.lora_r:
|
||||
lora_kwargs["max_lora_rank"] = cfg.lora_r
|
||||
vllm_script_args = LoRAScriptArguments(**base_kwargs, **lora_kwargs)
|
||||
else:
|
||||
vllm_script_args = AxolotlScriptArguments(
|
||||
**base_kwargs,
|
||||
reasoning_parser=reasoning_parser,
|
||||
enable_reasoning=enable_reasoning,
|
||||
)
|
||||
|
||||
vllm_serve_main(vllm_script_args)
|
||||
|
||||
@@ -12,15 +12,10 @@ MOE_ARCH_BLOCK = {
|
||||
"mixtral": "MixtralSparseMoeBlock",
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"qwen3_5_moe": "Qwen3_5MoeSparseMoeBlock",
|
||||
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
"deepseek_v3": "DeepseekV3MoE",
|
||||
"mistral4": "Mistral4MoE",
|
||||
"gpt_oss": "GptOssDecoderLayer",
|
||||
"lfm2_moe": "Lfm2MoeSparseMoeBlock",
|
||||
"afmoe": "AfmoeMoE",
|
||||
"glm4_moe": "Glm4MoeDecoderLayer",
|
||||
"glm4_moe_lite": "Glm4MoeLiteDecoderLayer",
|
||||
"glm_moe_dsa": "GlmMoeDsaDecoderLayer",
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ class JsonToJsonlConverter:
|
||||
self.json_parser = json_parser
|
||||
self.jsonl_serializer = jsonl_serializer
|
||||
|
||||
def convert(self, input_file_path):
|
||||
def convert(self, input_file_path, output_file_path):
|
||||
content = self.file_reader.read(input_file_path)
|
||||
data = self.json_parser.parse(content)
|
||||
# data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations
|
||||
|
||||
@@ -250,7 +250,7 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
def _configure_precision_settings(self, training_args_kwargs: dict):
|
||||
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
||||
training_args_kwargs["tf32"] = True if self.cfg.tf32 is True else False
|
||||
training_args_kwargs["tf32"] = self.cfg.tf32
|
||||
if self.cfg.bf16 == "full":
|
||||
training_args_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
@@ -353,30 +353,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
adam_kwargs["eps"] = (eps1, eps2)
|
||||
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_adamw":
|
||||
from flashoptim import FlashAdamW
|
||||
|
||||
optimizer_cls = FlashAdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_adam":
|
||||
from flashoptim import FlashAdam
|
||||
|
||||
optimizer_cls = FlashAdam
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "flash_sgd":
|
||||
from flashoptim import FlashSGD
|
||||
|
||||
optimizer_cls = FlashSGD
|
||||
elif self.cfg.optimizer == "flash_sgdw":
|
||||
from flashoptim import FlashSGDW
|
||||
|
||||
optimizer_cls = FlashSGDW
|
||||
elif self.cfg.optimizer == "flash_lion":
|
||||
from flashoptim import FlashLion
|
||||
|
||||
optimizer_cls = FlashLion
|
||||
if "betas" in adam_kwargs:
|
||||
optimizer_kwargs["betas"] = adam_kwargs["betas"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
|
||||
@@ -508,8 +484,6 @@ class TrainerBuilderBase(abc.ABC):
|
||||
training_args_kwargs["accelerator_config"] = AcceleratorConfig()
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.layer_offloading:
|
||||
training_args_kwargs["layer_offloading"] = True
|
||||
if self.cfg.activation_offloading is True:
|
||||
# don't use the HF gradient checkpointing, manually wrap
|
||||
training_args_kwargs["gradient_checkpointing"] = False
|
||||
|
||||
@@ -122,12 +122,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
ColabCallback = colab_inference_post_train_callback(trainer)
|
||||
callbacks.append(ColabCallback(self.cfg))
|
||||
|
||||
if getattr(self.cfg, "generate_samples", False):
|
||||
from axolotl.utils.callbacks.generation import SFTGenerationCallback
|
||||
|
||||
callbacks.append(SFTGenerationCallback(trainer))
|
||||
LOG.info("SFT sample generation enabled")
|
||||
|
||||
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||
return callbacks
|
||||
|
||||
@@ -252,8 +246,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
ddp_find_unused_parameters
|
||||
)
|
||||
|
||||
if self.cfg.group_by_length:
|
||||
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
|
||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||
@@ -421,13 +414,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
# TRL's RewardTrainer validates num_labels=1 on pre-loaded models; ensure the
|
||||
# config reflects this regardless of how the model was instantiated.
|
||||
if (
|
||||
self.cfg.reward_model
|
||||
and getattr(self.model.config, "num_labels", None) != 1
|
||||
):
|
||||
self.model.config.num_labels = 1
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
|
||||
@@ -11,6 +11,7 @@ from axolotl.core.trainers import (
|
||||
)
|
||||
from axolotl.core.trainers.dpo import DPOStrategy
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import ensure_dtype
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
@@ -52,18 +53,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
async_grpo = bool(
|
||||
self.cfg.trl
|
||||
and (
|
||||
getattr(self.cfg.trl, "async_prefetch", False)
|
||||
or getattr(self.cfg.trl, "use_data_producer", False)
|
||||
)
|
||||
)
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1,
|
||||
async_grpo=async_grpo,
|
||||
sequence_parallel=self.cfg.context_parallel_size > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
@@ -128,6 +119,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
else:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
@@ -137,17 +133,21 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
blocklist_args_kwargs.append("max_prompt_length")
|
||||
# Handle when max_prompt_length == max_length from defaults
|
||||
# CPOTrainer requires strictly less than
|
||||
if (
|
||||
training_args_kwargs["max_prompt_length"]
|
||||
== training_args_kwargs["max_length"]
|
||||
):
|
||||
training_args_kwargs["max_prompt_length"] -= 1
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
|
||||
blocklist_args_kwargs.append("max_prompt_length")
|
||||
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
|
||||
blocklist_args_kwargs.append("max_prompt_length")
|
||||
blocklist_args_kwargs = ["max_prompt_length"]
|
||||
|
||||
training_args_kwargs["desirable_weight"] = (
|
||||
self.cfg.kto_desirable_weight or 1.0
|
||||
@@ -157,18 +157,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
|
||||
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
|
||||
async_grpo = bool(
|
||||
self.cfg.trl
|
||||
and (
|
||||
getattr(self.cfg.trl, "async_prefetch", False)
|
||||
or getattr(self.cfg.trl, "use_data_producer", False)
|
||||
)
|
||||
)
|
||||
training_args_cls = GRPOStrategy.get_training_args_class(
|
||||
async_grpo=async_grpo
|
||||
)
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
if self.cfg.rl is RLType.GDPO:
|
||||
@@ -208,11 +197,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.eval_dataset:
|
||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if (
|
||||
self.cfg.adapter
|
||||
and self.peft_config
|
||||
and self.cfg.rl not in (RLType.GRPO, RLType.ORPO)
|
||||
):
|
||||
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
@@ -238,36 +223,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
|
||||
# Allow FP8-quantized models to be fine-tuned with LoRA adapters.
|
||||
# transformers' validate_quantization_for_training blocks FP8 because
|
||||
# hf_quantizer.is_trainable is False, but LoRA only trains the adapters
|
||||
# (base weights stay frozen in FP8).
|
||||
_orig_validate_quant = None
|
||||
if (
|
||||
self.cfg.adapter
|
||||
and hasattr(self.model, "is_quantized")
|
||||
and self.model.is_quantized
|
||||
):
|
||||
import transformers.trainer as _trainer_module
|
||||
|
||||
_orig_validate_quant = _trainer_module.validate_quantization_for_training
|
||||
_trainer_module.validate_quantization_for_training = lambda model: None
|
||||
|
||||
try:
|
||||
trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
callbacks=self.get_callbacks(),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
finally:
|
||||
if _orig_validate_quant is not None:
|
||||
import transformers.trainer as _trainer_module
|
||||
|
||||
_trainer_module.validate_quantization_for_training = (
|
||||
_orig_validate_quant
|
||||
)
|
||||
trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
callbacks=self.get_callbacks(),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
if self.cfg.fsdp_config or self.cfg.fsdp:
|
||||
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
||||
|
||||
@@ -26,15 +26,13 @@ from transformers import PreTrainedModel, Trainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
|
||||
from trl.experimental.utils import pad_to_length
|
||||
from trl.trainer.utils import pad_to_length
|
||||
from typing_extensions import override
|
||||
|
||||
from axolotl.core.trainers.constants import TOKENS_STATE_FILE
|
||||
from axolotl.core.trainers.mixins import (
|
||||
ActivationOffloadingMixin,
|
||||
CheckpointSaveMixin,
|
||||
DistributedParallelMixin,
|
||||
LayerOffloadingMixin,
|
||||
OptimizerMixin,
|
||||
PackingMixin,
|
||||
RngLoaderMixin,
|
||||
@@ -53,6 +51,8 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
TOKENS_STATE_FILE = "tokens_state."
|
||||
|
||||
REDUCTION_FNS = {
|
||||
"mean": torch.mean,
|
||||
"min": torch.min,
|
||||
@@ -67,7 +67,6 @@ class AxolotlTrainer(
|
||||
OptimizerMixin,
|
||||
RngLoaderMixin,
|
||||
CheckpointSaveMixin,
|
||||
LayerOffloadingMixin,
|
||||
ActivationOffloadingMixin,
|
||||
DistributedParallelMixin,
|
||||
Trainer,
|
||||
@@ -720,20 +719,13 @@ class AxolotlTrainer(
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
LOG.info(f"Saving model checkpoint to {output_dir}")
|
||||
|
||||
# fix for Context Parallel save: CP eval invalidates tensor storage
|
||||
# pointers, so clone to CPU to get fresh valid storage for safetensors
|
||||
if (
|
||||
state_dict is not None
|
||||
and self.axolotl_cfg
|
||||
and self.axolotl_cfg.context_parallel_size
|
||||
and self.axolotl_cfg.context_parallel_size > 1
|
||||
):
|
||||
if state_dict is None:
|
||||
state_dict = self.accelerator.get_state_dict(self.model)
|
||||
if state_dict is not None:
|
||||
state_dict = {
|
||||
k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
|
||||
k: v.clone() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
supported_classes = (
|
||||
(PreTrainedModel,)
|
||||
if not is_peft_available()
|
||||
@@ -744,7 +736,6 @@ class AxolotlTrainer(
|
||||
if not isinstance(self.model, supported_classes):
|
||||
if state_dict is None:
|
||||
state_dict = self.model.state_dict()
|
||||
|
||||
if isinstance(
|
||||
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
|
||||
supported_classes,
|
||||
@@ -754,7 +745,6 @@ class AxolotlTrainer(
|
||||
).save_pretrained(
|
||||
output_dir,
|
||||
state_dict=state_dict,
|
||||
is_main_process=self.accelerator.is_main_process,
|
||||
)
|
||||
else:
|
||||
LOG.info(
|
||||
@@ -782,7 +772,11 @@ class AxolotlTrainer(
|
||||
LOG.info(
|
||||
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
|
||||
)
|
||||
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
save_jinja_files = True
|
||||
if self.axolotl_cfg:
|
||||
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
|
||||
self.data_collator.tokenizer.save_pretrained(
|
||||
output_dir, save_jinja_files=save_jinja_files
|
||||
)
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
TOKENS_STATE_FILE = "tokens_state.json"
|
||||
@@ -25,13 +25,17 @@ class DPOStrategy:
|
||||
# Label smoothing is not compatible with IPO
|
||||
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||
if cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||
if cfg.dpo_padding_free is not None:
|
||||
training_args_kwargs["padding_free"] = cfg.dpo_padding_free
|
||||
if cfg.dpo_norm_loss is not None:
|
||||
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
|
||||
if cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
|
||||
if cfg.dpo_use_liger_kernel is not None:
|
||||
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
|
||||
return training_args_kwargs
|
||||
|
||||
@@ -2,8 +2,7 @@
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trl import DPOConfig
|
||||
|
||||
@@ -17,4 +16,3 @@ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||
"""
|
||||
|
||||
dpo_norm_loss: bool | None = False
|
||||
rpo_alpha: Optional[float] = field(default=None)
|
||||
|
||||
@@ -57,18 +57,16 @@ class AxolotlDPOTrainer(
|
||||
def tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length: int | None = None,
|
||||
max_completion_length: int | None = None,
|
||||
add_special_tokens: bool = True,
|
||||
is_chat: bool = False,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length=max_prompt_length,
|
||||
max_completion_length=max_completion_length,
|
||||
add_special_tokens=add_special_tokens,
|
||||
is_chat=is_chat,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
@@ -103,10 +101,10 @@ class AxolotlDPOTrainer(
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if self.args.dpo_norm_loss:
|
||||
# fmt: off
|
||||
loss_type: list[str] = self.loss_type # type: ignore[has-type]
|
||||
loss_type: str = self.loss_type # type: ignore[has-type]
|
||||
# fmt: on
|
||||
# concatenated_forward handles avg token logprob for ipo case already
|
||||
self.loss_type = ["ipo"]
|
||||
self.loss_type = "ipo"
|
||||
res = super().concatenated_forward(model, batch, is_ref_model=is_ref_model)
|
||||
self.loss_type = loss_type
|
||||
return res
|
||||
|
||||
@@ -9,9 +9,8 @@ from huggingface_hub import snapshot_download
|
||||
from requests import HTTPError
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.args import AxolotlAsyncGRPOConfig, AxolotlGRPOConfig
|
||||
from axolotl.core.trainers.grpo.args import AxolotlGRPOConfig
|
||||
from axolotl.core.trainers.grpo.trainer import (
|
||||
AxolotlAsyncGRPOTrainer,
|
||||
AxolotlGRPOSequenceParallelTrainer,
|
||||
AxolotlGRPOTrainer,
|
||||
)
|
||||
@@ -28,31 +27,14 @@ class GRPOStrategy:
|
||||
|
||||
@classmethod
|
||||
def get_trainer_class(
|
||||
cls,
|
||||
sequence_parallel: bool,
|
||||
async_grpo: bool = False,
|
||||
) -> (
|
||||
type[AxolotlGRPOTrainer]
|
||||
| type[AxolotlGRPOSequenceParallelTrainer]
|
||||
| type[AxolotlAsyncGRPOTrainer]
|
||||
):
|
||||
if sequence_parallel and async_grpo:
|
||||
raise ValueError(
|
||||
"sequence_parallel and async_grpo cannot both be enabled. "
|
||||
"Disable one of context_parallel_size > 1 or async_prefetch/use_data_producer."
|
||||
)
|
||||
cls, sequence_parallel: bool
|
||||
) -> type[AxolotlGRPOTrainer] | type[AxolotlGRPOSequenceParallelTrainer]:
|
||||
if sequence_parallel:
|
||||
return AxolotlGRPOSequenceParallelTrainer
|
||||
if async_grpo:
|
||||
return AxolotlAsyncGRPOTrainer
|
||||
return AxolotlGRPOTrainer
|
||||
|
||||
@classmethod
|
||||
def get_training_args_class(
|
||||
cls, async_grpo: bool = False
|
||||
) -> type[AxolotlGRPOConfig] | type[AxolotlAsyncGRPOConfig]:
|
||||
if async_grpo:
|
||||
return AxolotlAsyncGRPOConfig
|
||||
def get_training_args_class(cls) -> type[AxolotlGRPOConfig]:
|
||||
return AxolotlGRPOConfig
|
||||
|
||||
@classmethod
|
||||
@@ -142,63 +124,13 @@ class GRPOStrategy:
|
||||
grpo_args_kwargs["epsilon_high"] = trl.epsilon_high
|
||||
|
||||
if trl.use_liger_loss is not None:
|
||||
grpo_args_kwargs["use_liger_kernel"] = trl.use_liger_loss
|
||||
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
|
||||
|
||||
if trl.multi_objective_aggregation is not None:
|
||||
grpo_args_kwargs["multi_objective_aggregation"] = (
|
||||
trl.multi_objective_aggregation
|
||||
)
|
||||
|
||||
# Async GRPO fields
|
||||
if getattr(trl, "use_data_producer", None) is not None:
|
||||
grpo_args_kwargs["use_data_producer"] = trl.use_data_producer
|
||||
if getattr(trl, "async_prefetch", None) is not None:
|
||||
grpo_args_kwargs["async_prefetch"] = trl.async_prefetch
|
||||
if getattr(trl, "prefetch_depth", None) is not None:
|
||||
grpo_args_kwargs["prefetch_depth"] = trl.prefetch_depth
|
||||
if getattr(trl, "vllm_sync_interval", None) is not None:
|
||||
grpo_args_kwargs["vllm_sync_interval"] = trl.vllm_sync_interval
|
||||
if getattr(trl, "streaming_partial_batch", None) is not None:
|
||||
grpo_args_kwargs["streaming_partial_batch"] = trl.streaming_partial_batch
|
||||
if getattr(trl, "streaming_min_groups", None) is not None:
|
||||
grpo_args_kwargs["streaming_min_groups"] = trl.streaming_min_groups
|
||||
if getattr(trl, "vllm_importance_sampling_correction", None) is not None:
|
||||
grpo_args_kwargs["vllm_importance_sampling_correction"] = (
|
||||
trl.vllm_importance_sampling_correction
|
||||
)
|
||||
if getattr(trl, "vllm_importance_sampling_mode", None) is not None:
|
||||
grpo_args_kwargs["vllm_importance_sampling_mode"] = (
|
||||
trl.vllm_importance_sampling_mode
|
||||
)
|
||||
if getattr(trl, "vllm_importance_sampling_cap", None) is not None:
|
||||
grpo_args_kwargs["vllm_importance_sampling_cap"] = (
|
||||
trl.vllm_importance_sampling_cap
|
||||
)
|
||||
if getattr(trl, "off_policy_mask_threshold", None) is not None:
|
||||
grpo_args_kwargs["off_policy_mask_threshold"] = (
|
||||
trl.off_policy_mask_threshold
|
||||
)
|
||||
if getattr(trl, "use_bias_correction_kl", None) is not None:
|
||||
grpo_args_kwargs["use_bias_correction_kl"] = trl.use_bias_correction_kl
|
||||
|
||||
# Fast Async GRPO fields
|
||||
if getattr(trl, "reward_num_workers", None) is not None:
|
||||
grpo_args_kwargs["reward_num_workers"] = trl.reward_num_workers
|
||||
if getattr(trl, "replay_buffer_size", None) is not None:
|
||||
grpo_args_kwargs["replay_buffer_size"] = trl.replay_buffer_size
|
||||
if getattr(trl, "replay_recompute_logps", None) is not None:
|
||||
grpo_args_kwargs["replay_recompute_logps"] = trl.replay_recompute_logps
|
||||
if getattr(trl, "reroll_start_fraction", None) is not None:
|
||||
grpo_args_kwargs["reroll_start_fraction"] = trl.reroll_start_fraction
|
||||
if getattr(trl, "reroll_max_groups", None) is not None:
|
||||
grpo_args_kwargs["reroll_max_groups"] = trl.reroll_max_groups
|
||||
if getattr(trl, "skip_zero_advantage_batches", None) is not None:
|
||||
grpo_args_kwargs["skip_zero_advantage_batches"] = (
|
||||
trl.skip_zero_advantage_batches
|
||||
)
|
||||
if getattr(trl, "vllm_lora_sync", None) is not None:
|
||||
grpo_args_kwargs["vllm_lora_sync"] = trl.vllm_lora_sync
|
||||
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -6,7 +6,6 @@ from dataclasses import dataclass
|
||||
|
||||
from trl import GRPOConfig
|
||||
|
||||
from axolotl.core.trainers.grpo.fast_async_trainer import FastAsyncGRPOConfig
|
||||
from axolotl.core.training_args import AxolotlTrainingMixins
|
||||
|
||||
|
||||
@@ -15,10 +14,3 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""Axolotl GRPO Config for GRPO training"""
|
||||
|
||||
context_parallel_size: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlAsyncGRPOConfig(AxolotlTrainingMixins, FastAsyncGRPOConfig):
|
||||
"""Axolotl Async GRPO Config — adds async prefetch, streaming scoring, and IS correction."""
|
||||
|
||||
context_parallel_size: int | None = None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user