Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
208f8b253f add validation for DFT 2026-01-13 09:33:04 -05:00
Wing Lian
75ad1a9932 use dynamic finetuning with chunked cross entropy 2026-01-13 09:33:04 -05:00
189 changed files with 829 additions and 9041 deletions

View File

@@ -15,11 +15,6 @@
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## AI Usage Disclaimer
<!--- Was AI (e.g., ChatGPT, Claude, Copilot) used to generate or assist with this PR? -->
<!--- Please indicate: No / Yes (specify which tool and to what extent) -->
## Screenshots (if appropriate)
## Types of changes

View File

@@ -21,8 +21,6 @@ jobs:
timeout-minutes: 480
# this job needs to be run on self-hosted GPU runners...
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
@@ -34,7 +32,6 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -42,7 +39,6 @@ jobs:
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -50,23 +46,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -74,23 +53,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "128"
# cuda_version: 12.8.1
# cudnn_version: ""
@@ -117,7 +79,7 @@ jobs:
axolotlai/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
if: ${{ github.event_name != 'pull_request' && secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -128,7 +90,7 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
platforms: linux/amd64,linux/arm64
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
@@ -143,8 +105,6 @@ jobs:
if: ${{ github.repository_owner == 'axolotl-ai-cloud' && (github.event_name != 'pull_request' || !github.event.pull_request.draft) }}
timeout-minutes: 480
runs-on: ubuntu-latest-m
env:
HAS_DOCKERHUB_CREDS: ${{ secrets.DOCKERHUB_USERNAME != '' && secrets.DOCKERHUB_TOKEN != '' }}
strategy:
fail-fast: false
matrix:
@@ -156,7 +116,6 @@ jobs:
pytorch: 2.8.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -164,7 +123,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
@@ -172,23 +130,6 @@ jobs:
pytorch: 2.9.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
# - cuda: "129"
# cuda_version: 12.9.1
# cudnn_version: ""
# python_version: "3.12"
# pytorch: 2.9.1
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
# dockerfile: "Dockerfile-uv-base"
# platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
@@ -196,23 +137,6 @@ jobs:
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.9.1
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
- cuda: "130"
cuda_version: 13.0.0
cudnn_version: ""
python_version: "3.12"
pytorch: 2.10.0
torch_cuda_arch_list: "9.0+PTX"
dockerfile: "Dockerfile-uv-base"
platforms: "linux/amd64,linux/arm64"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -224,7 +148,6 @@ jobs:
axolotlai/axolotl-base-uv
- name: Login to Docker Hub
uses: docker/login-action@v2
if: ${{ github.event_name != 'pull_request' && env.HAS_DOCKERHUB_CREDS == 'true' }}
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
@@ -235,7 +158,6 @@ jobs:
with:
context: .
file: ./docker/${{ matrix.dockerfile }}
platforms: ${{ matrix.platforms }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -20,44 +20,22 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
platforms: "linux/amd64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -83,7 +61,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
platforms: linux/amd64,linux/arm64
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
@@ -98,77 +76,6 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-uv:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
- name: Build and export to Docker
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
file: ./docker/Dockerfile-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -181,44 +88,22 @@ jobs:
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
platforms: "linux/amd64"
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
# - cuda: 129
# cuda_version: 12.9.1
# python_version: "3.12"
# pytorch: 2.9.1
# axolotl_extras:
# platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -243,7 +128,7 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
platforms: linux/amd64,linux/arm64
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
@@ -254,73 +139,6 @@ jobs:
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-uv:
needs: build-axolotl-uv
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
include:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
is_latest: true
platforms: "linux/amd64,linux/arm64"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.10.0
axolotl_extras:
platforms: "linux/amd64,linux/arm64"
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
with:
images: |
axolotlai/axolotl-cloud-uv
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build
uses: docker/build-push-action@v5
with:
context: .
platforms: ${{ matrix.platforms }}
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-uv
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -331,11 +149,11 @@ jobs:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
pytorch: 2.8.0
axolotl_extras:
is_latest: true
- cuda: 130
cuda_version: 13.0.0
is_latest:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:

View File

@@ -35,26 +35,21 @@ jobs:
pytorch: 2.8.0
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
axolotl_extras: fbgemm-gpu
num_gpus: 2
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
pytorch: 2.9.1
axolotl_extras: "fbgemm-gpu"
num_gpus: 2
dockerfile: "Dockerfile-uv.jinja"
nightly_build: "true"
- cuda: 130
cuda_version: 13.0.0
python_version: "3.11"
pytorch: 2.9.1
axolotl_extras:
# axolotl_extras: fbgemm-gpu
axolotl_extras: fbgemm-gpu
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -76,8 +71,8 @@ jobs:
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run -m cicd.multigpu

View File

@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging==26.0
pip3 install wheel packaging==23.2
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt
@@ -48,9 +48,9 @@ jobs:
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in VERSION file
- name: Update version in setup.py
run: |
echo "${{ steps.tag.outputs.TAG_NAME }}" | sed 's/^v//' > VERSION
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a source dist
run: |

View File

@@ -37,7 +37,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p /home/runner/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
- name: Setup Python
uses: actions/setup-python@v5
@@ -48,7 +48,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |

View File

@@ -54,13 +54,8 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12"]
python_version: ["3.11"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
@@ -75,7 +70,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -87,7 +82,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -115,10 +110,10 @@ jobs:
- name: Pre-Download dataset fixture
run: |
hf download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Show HF cache
run: hf cache ls
run: hf cache scan
- name: Run tests
run: |
@@ -132,7 +127,7 @@ jobs:
pytest -v --durations=10 tests/cli/ --cov=axolotl --cov-append --cov-report=xml
- name: Show HF cache
run: hf cache ls
run: hf cache scan
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
@@ -149,13 +144,8 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.11", "3.12"]
python_version: ["3.11"]
pytorch_version: ["2.8.0", "2.9.0", "2.9.1"]
exclude:
- python_version: "3.12"
pytorch_version: "2.8.0"
- python_version: "3.12"
pytorch_version: "2.9.0"
timeout-minutes: 20
steps:
@@ -170,7 +160,7 @@ jobs:
id: hf-cache-restore-s3
run: |
mkdir -p ~/.cache/huggingface/hub
curl -L https://axolotl-ci.b-cdn.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xpf - -C ~/.cache/huggingface/hub/ --use-compress-program unzstd --strip-components=1
ls -ltr ~/.cache/huggingface/hub/
- name: Setup Python
@@ -182,7 +172,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel psutil
- name: Install PyTorch
run: |
@@ -210,7 +200,7 @@ jobs:
axolotl --help
- name: Show HF cache
run: hf cache ls
run: hf cache scan
- name: Run tests
run: |
@@ -219,10 +209,10 @@ jobs:
pytest -v --durations=10 tests/cli/
- name: Show HF cache
run: hf cache ls
run: hf cache scan
gate-skip-e2e:
needs: [pre-commit]
needs: [pre-commit, pytest, pytest-sdist]
runs-on: ubuntu-latest
outputs:
skip: ${{ steps.compute.outputs.skip }}
@@ -258,16 +248,16 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 120
needs: [pre-commit, pytest]
needs: [pre-commit, pytest, pytest-sdist, gate-skip-e2e]
strategy:
fail-fast: false
matrix:
include:
- cuda: 130
cuda_version: 13.0.0
python_version: "3.12"
pytorch: 2.9.1
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -369,9 +359,9 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 129
cuda_version: 12.9.1
python_version: "3.12"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.9.1
num_gpus: 1
axolotl_extras:

View File

@@ -123,7 +123,7 @@ datasets:
| --------------------------------- | -------------------------- | ----------------------------------- |
| `dataset_prepared_path` | `"data/last_run_prepared"` | Path for prepared dataset |
| `push_dataset_to_hub` | `""` | Push dataset to HF hub |
| `dataset_num_proc` | `4` | Number of preprocessing processes |
| `dataset_processes` | `4` | Number of preprocessing processes |
| `dataset_keep_in_memory` | `false` | Keep dataset in memory |
| `shuffle_merged_datasets` | `true` | Shuffle merged datasets |
| `shuffle_before_merging_datasets` | `false` | Shuffle each dataset before merging |

View File

@@ -39,6 +39,7 @@
# type: # linear | dynamic
# factor: # float
# # Whether you are training a 4-bit GPTQ quantized model
# gptq: true
# gptq_groupsize: 128 # group size
@@ -106,7 +107,7 @@
# push_dataset_to_hub: # repo path
# # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# # if not set.
# dataset_num_proc: # defaults to os.cpu_count() if not set
# dataset_processes: # defaults to os.cpu_count() if not set
# # push checkpoints to hub
# hub_model_id: # repo path to push finetuned model
# # how to push checkpoints to hub
@@ -223,6 +224,9 @@
# eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
# eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# # Save model as safetensors (require safetensors package)
# save_safetensors:
# # Whether to mask out or include the human's prompt from the training labels
# train_on_inputs: false
# # Group similarly sized data to minimize padding.
@@ -348,6 +352,8 @@
# # Allow overwrite yml config using from cli
# strict:
base_model: ${BASE_MODEL}
base_model_ignore_patterns: ${BASE_MODEL_IGNORE_PATTERNS}
base_model_config: ${BASE_MODEL_CONFIG}
@@ -406,7 +412,7 @@ chat_template_jinja: ${CHAT_TEMPLATE_JINJA}
default_system_message: ${DEFAULT_SYSTEM_MESSAGE}
dataset_prepared_path: ${DATASET_PREPARED_PATH}
push_dataset_to_hub: ${PUSH_DATASET_TO_HUB}
dataset_num_proc: ${DATASET_NUM_PROC}
dataset_processes: ${DATASET_PROCESSES}
dataset_keep_in_memory: ${DATASET_KEEP_IN_MEMORY}
hub_model_id: ${HUB_MODEL_ID}
hub_strategy: ${HUB_STRATEGY}
@@ -506,6 +512,7 @@ profiler_steps: ${PROFILER_STEPS}
loss_watchdog_threshold: ${LOSS_WATCHDOG_THRESHOLD}
loss_watchdog_patience: ${LOSS_WATCHDOG_PATIENCE}
save_safetensors: ${SAVE_SAFETENSORS}
train_on_inputs: ${TRAIN_ON_INPUTS}
group_by_length: ${GROUP_BY_LENGTH}
gradient_checkpointing: ${GRADIENT_CHECKPOINTING}

View File

@@ -88,7 +88,7 @@ Features:
#### Using pip
```bash
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs

View File

@@ -1 +0,0 @@
0.15.0.dev0

View File

@@ -251,6 +251,7 @@ website:
- docs/models/olmo3.qmd
- docs/models/trinity.qmd
- docs/models/arcee.qmd
- docs/models/mistral.qmd
- section: "Ministral3"
contents:
- docs/models/ministral3.qmd
@@ -265,7 +266,6 @@ website:
- docs/models/mistral-small.qmd
- docs/models/voxtral.qmd
- docs/models/devstral.qmd
- docs/models/mistral.qmd
- docs/models/llama-4.qmd
- docs/models/llama-2.qmd
- docs/models/qwen3-next.qmd
@@ -320,7 +320,6 @@ website:
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd
- docs/attention.qmd
- section: "Advanced Features"
contents:

View File

@@ -31,7 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN uv pip install packaging==26.0 setuptools==75.8.0
RUN uv pip install packaging==23.2 setuptools==75.8.0
RUN uv pip install torchvision
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \

View File

@@ -32,7 +32,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==26.0 setuptools==75.8.0 psutil
RUN pip install packaging==23.2 setuptools==75.8.0 psutil
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -17,8 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
df_template = template_env.get_template(dockerfile)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
@@ -28,11 +27,8 @@ df_args = {
"CUDA": os.environ.get("CUDA", "126"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
}
dockerfile_contents = df_template.render(**df_args)

View File

@@ -2,7 +2,7 @@
set -e
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v --durations=10 -n2 --maxfail=3 \
pytest -v --durations=10 -n2 --maxfail=4 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \

View File

@@ -6,7 +6,6 @@ ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -21,17 +20,13 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \ python scripts/unsloth_install.py | sh && \
python scripts/unsloth_install.py | sh && \
python scripts/cutcrossentropy_install.py | sh && \
pip install pytest && \
pip cache purge

View File

@@ -43,7 +43,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel psutil && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel psutil && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip cache purge
@@ -59,18 +59,34 @@ RUN git lfs install --skip-repo && \
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
pip3 cache purge
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
pip3 install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
elif [ "$CUDA" = "130" ]; then \
if [ "$TARGETARCH" = "amd64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl"; \
WHL_VERSION="v0.5.4"; \
elif [ "$TARGETARCH" = "arm64" ]; then \
WHL_FILE="flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_aarch64.whl"; \
WHL_VERSION="v0.6.4"; \
else \
echo "Unsupported architecture: $TARGETARCH"; exit 1; \
fi; \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}; \
pip3 install --no-cache-dir ${WHL_FILE}; \
rm ${WHL_FILE}; \
fi \
;; \
esac

View File

@@ -30,7 +30,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==26.0 setuptools==75.8.0 wheel && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \

View File

@@ -1,30 +0,0 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl-uv:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN uv pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt update && \
apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/* && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh && \
echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
ENTRYPOINT ["/root/cloud-entrypoint.sh"]
CMD ["sleep", "infinity"]

View File

@@ -1,47 +0,0 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base-uv:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ARG PYTORCH_VERSION="2.1.2"
ARG TARGETARCH
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev rsync s3fs && \
rm -rf /var/cache/apt/archives && \
rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
RUN if [ "$TARGETARCH" = "arm64" ]; then \
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
else \
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
fi && \
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
fi && \
python scripts/unsloth_install.py --uv | sh && \
python scripts/cutcrossentropy_install.py --uv | sh && \
uv pip install pytest && \
uv cache clean
# fix so that git fetch/pull from remote works with shallow clone
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch && \
git config --global credential.helper store
COPY .axolotl-complete.bash /root/.axolotl-complete.bash
RUN chmod +x /root/.axolotl-complete.bash && \
echo 'source /root/.axolotl-complete.bash' >> ~/.bashrc

View File

@@ -2,11 +2,9 @@ ARG CUDA_VERSION="12.6.3"
ARG CUDNN_VERSION=""
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
ARG TARGETARCH
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ARG TARGETARCH
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="2.6.0"
ARG CUDA="126"
@@ -33,25 +31,20 @@ ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
RUN uv pip install packaging setuptools wheel psutil \
&& uv pip install torch==${PYTORCH_VERSION} torchvision \
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
&& uv pip install awscli pydantic
RUN if [ "$TARGETARCH" = "amd64" ]; then \
uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main"; \
uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"; \
fi
# Map Python version (e.g., 3.12 -> cp312)
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
# Map architecture
case "$TARGETARCH" in \
amd64) ARCH_TAG="x86_64" ;; \
arm64) ARCH_TAG="aarch64" ;; \
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
esac && \
WHL_VERSION="v0.7.16" && \
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
uv pip install --no-cache-dir "${WHL_FILE}" && \
rm "${WHL_FILE}"
RUN case "$PYTORCH_VERSION" in \
2.9.[0-9]*) \
if [ "$CUDA" = "128" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
elif [ "$CUDA" = "130" ]; then \
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
fi \
;; \
esac

View File

@@ -86,7 +86,7 @@ export HF_DATASETS_OFFLINE=1
Download a base model using the Hugging Face CLI:
```bash
hf download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
```
### 10. Create Axolotl Configuration

View File

@@ -1,140 +0,0 @@
---
title: Attention
description: Supported attention modules in Axolotl
---
## SDP Attention
This is the default built-in attention in PyTorch.
```yaml
sdp_attention: true
```
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
## Flash Attention 2
Uses efficient kernels to compute attention.
```yaml
flash_attention: true
```
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
### Nvidia
Requirements: Ampere, Ada, or Hopper GPUs
Note: For Turing GPUs or lower, please use other attention methods.
```bash
pip install flash-attn --no-build-isolation
```
::: {.callout-tip}
If you get `undefined symbol` while training, ensure you installed PyTorch prior to Axolotl. Alternatively, try reinstall or downgrade a version.
:::
#### Flash Attention 3
Requirements: Hopper only and CUDA 12.8 (recommended)
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```
### AMD
Requirements: ROCm 6.0 and above.
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
## Flex Attention
A flexible PyTorch API for attention used in combination with `torch.compile`.
```yaml
flex_attention: true
# recommended
torch_compile: true
```
::: {.callout-note}
We recommend using latest stable version of PyTorch for best performance.
:::
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
## SageAttention
Attention kernels with QK Int8 and PV FP16 accumulator.
```yaml
sage_attention: true
```
Requirements: Ampere, Ada, or Hopper GPUs
```bash
pip install sageattention==2.2.0 --no-build-isolation
```
::: {.callout-warning}
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
:::
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
::: {.callout-note}
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
:::
## xFormers
```yaml
xformers_attention: true
```
::: {.callout-tip}
We recommend using with Turing GPUs or below (such as on Colab).
:::
For more details: [xFormers](https://github.com/facebookresearch/xformers)
## Shifted Sparse Attention
::: {.callout-warning}
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
:::
Requirements: LLaMA model architecture
```yaml
flash_attention: true
s2_attention: true
```
::: {.callout-tip}
No sample packing support!
:::

View File

@@ -210,8 +210,6 @@ axolotl lm-eval config.yml
Configuration options:
```yaml
lm_eval_model: # model to evaluate (local or hf path)
# List of tasks to evaluate
lm_eval_tasks:
- arc_challenge
@@ -220,7 +218,7 @@ lm_eval_batch_size: # Batch size for evaluation
output_dir: # Directory to save evaluation results
```
See [LM Eval Harness integration docs](https://docs.axolotl.ai/docs/custom_integrations.html#language-model-evaluation-harness-lm-eval) for full configuration details.
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4

View File

@@ -165,7 +165,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
```
4. (Optional) Login to Hugging Face:
```{.bash}
hf auth login
huggingface-cli login
```
## Troubleshooting {#sec-troubleshooting}

View File

@@ -89,10 +89,6 @@ lora_o_kernel: true
Currently, LoRA kernels are not supported for RLHF training, only SFT.
:::
::: {.callout-warning}
LoRA kernels do not support remote modeling code.
:::
## Requirements
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)

View File

@@ -19,7 +19,6 @@ format:
- [Gemma-3n](#sec-gemma-3n)
- [Qwen2-VL](#sec-qwen2-vl)
- [Qwen2.5-VL](#sec-qwen25-vl)
- [GLM-4.6V](#sec-glm-4-6v)
- [SmolVLM2](#sec-smolvlm2)
- [LFM2-VL](#sec-lfm2-vl)
- [Intern-VL](#sec-intern-vl)
@@ -184,18 +183,6 @@ base_model: Qwen/Qwen3-VL-4B-Instruct
chat_template: qwen2_vl # same as qwen2-vl
```
### GLM-4.6V {#sec-glm-4-6v}
Both GLM-4.6V (106B MoE) and GLM-4.6V-Flash (9B) are supported.
```yaml
# GLM-4.6V (106B MoE version)
base_model: zai-org/GLM-4.6V
# OR GLM-4.6V-Flash (9B version)
base_model: zai-org/GLM-4.6V-Flash
```
### SmolVLM2 {#sec-smolvlm2}
::: {.callout-tip}

View File

@@ -17,7 +17,6 @@ feedback. Various methods include, but not limited to:
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Group Reward-Decoupled Policy Optimization (GDPO)](#gdpo)
## RLHF using Axolotl
@@ -721,102 +720,6 @@ trl:
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
### GDPO
GDPO (Group Reward-Decoupled Policy Optimization) extends GRPO for multi-reward training. It addresses the **reward advantage collapse** problem by normalizing each reward function independently before combining them.
::: {.callout-tip}
Use GDPO when training with multiple reward functions. For single reward, GRPO and GDPO produce equivalent results.
:::
Paper: [https://arxiv.org/pdf/2501.05242](https://arxiv.org/pdf/2501.05242)
GDPO uses TRL's native `multi_objective_aggregation` parameter under the hood. When you set `rl: gdpo`, axolotl automatically configures TRL to use `normalize_then_sum` aggregation.
```yaml
base_model: Qwen/Qwen2.5-1.5B-Instruct
vllm:
host: 0.0.0.0
port: 8000
tensor_parallel_size: 2
gpu_memory_utilization: 0.85
rl: gdpo
trl:
beta: 0.001
max_completion_length: 256
use_vllm: true
num_generations: 4
reward_funcs:
- rewards.format_reward
- rewards.correctness_reward
reward_weights: [1.0, 2.0]
datasets:
- path: openai/gsm8k
name: main
type: rewards.oai_gsm8k_transform
```
You can also use GRPO with explicit aggregation control:
```yaml
rl: grpo
trl:
multi_objective_aggregation: normalize_then_sum # GDPO behavior
# or: sum_then_normalize # Default GRPO behavior
```
#### GDPO vs GRPO
| Aspect | GRPO | GDPO |
|--------|------|------|
| **Aggregation** | `sum_then_normalize` | `normalize_then_sum` |
| **Multi-reward** | May collapse advantages | Preserves reward signals |
| **Single reward** | Standard behavior | Equivalent to GRPO |
#### Why GDPO?
When using multiple rewards with GRPO, different reward combinations can produce identical advantages:
```
# Example: format + correctness rewards
[format=0, correct=3] → sum=3
[format=1, correct=2] → sum=3 ← GRPO sees these as equal!
[format=2, correct=1] → sum=3
[format=3, correct=0] → sum=3
```
GDPO normalizes each reward independently, preserving their relative differences.
#### Reward Functions
GDPO uses the same reward function format as GRPO:
```python
# rewards.py
def format_reward(completions, **kwargs) -> list[float]:
return [1.0 if len(c) > 10 else 0.0 for c in completions]
def correctness_reward(completions, answers, **kwargs) -> list[float]:
rewards = []
for completion, answer in zip(completions, answers):
# Your scoring logic here
rewards.append(score)
return rewards
```
#### Sequence Parallelism
GDPO supports sequence parallelism for long-context training:
```yaml
rl: gdpo
context_parallel_size: 2
```
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -17,7 +17,7 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2\""
]
},
{

View File

@@ -16,7 +16,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -1,77 +0,0 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
chat_template: gemma3
eot_tokens:
- <end_of_turn>
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/eaft-gemma-3-1b
use_eaft: true
eaft_alpha: 1.0
eaft_k: 20
sequence_len: 1024
sample_packing: false
adapter:
lora_model_dir:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
max_steps: 1000
evaluation_strategy: "no"
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_ratio: 0.1
weight_decay: 0.0
debug:
deepspeed:
fsdp:
fsdp_config:
special_tokens:

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-1b-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048

View File

@@ -1,7 +1,6 @@
base_model: google/gemma-3-270m-it
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
@@ -30,7 +29,7 @@ output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048

View File

@@ -2,7 +2,6 @@ base_model: google/gemma-3-4b-it
# Need to set else transformers tries to load vision too
model_type: Gemma3ForCausalLM
cls_model_config: Gemma3TextConfig
load_in_4bit: true
@@ -33,8 +32,8 @@ sample_packing: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_target_linear: true
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:

View File

@@ -31,7 +31,7 @@ pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0
lora_dropout: 0.05
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:

View File

@@ -10,7 +10,7 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -1,44 +0,0 @@
# Finetune GLM-4.6V with Axolotl
GLM-4.6V is a family of vision-language models from ZhipuAI found on [HuggingFace](https://huggingface.co/zai-org/GLM-4.6V). This guide shows how to fine-tune it with Axolotl for vision-language tasks.
## Getting started
1. Install Axolotl from source following the [installation guide](https://docs.axolotl.ai/docs/installation.html#sec-edge-build).
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
3. Run the fine-tuning:
glm-4-6v-flash(9B)
```bash
axolotl train examples/glm46v/glm-4-6v-flash-qlora.yaml
```
Let us know how it goes. Happy finetuning! 🚀
## Tips
- Vision datasets should follow the format described in the [multimodal docs](https://docs.axolotl.ai/docs/multimodal.html#dataset-format)
- You can run a **full finetuning** by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
- Read more on how to load your own dataset in the [dataset loading docs](https://docs.axolotl.ai/docs/dataset_loading.html).
## Supported Models
- **GLM-4.6V**: Full vision-language model (`zai-org/GLM-4.6V`)
- **GLM-4.6V-Flash**: Faster variant (`zai-org/GLM-4.6V-Flash`)
## Optimization Guides
Please check the [Optimizations doc](https://docs.axolotl.ai/docs/optimizations.html).
## Related Resources
- [ZhipuAI GLM-4.6V](https://huggingface.co/zai-org/GLM-4.6V)
- [Axolotl Docs](https://docs.axolotl.ai)
- [Axolotl Website](https://axolotl.ai)
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)

View File

@@ -1,53 +0,0 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
ddp_find_unused_parameters: true
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -1,50 +0,0 @@
base_model: zai-org/GLM-4.6V-Flash
trust_remote_code: true
processor_type: AutoProcessor
load_in_4bit: true
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
output_dir: ./outputs/glm-4-6v-flash-qlora
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
adapter: qlora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
logging_steps: 1
sdp_attention: true
warmup_ratio: 0.1
evals_per_epoch: 0
saves_per_epoch: 1
weight_decay: 0.0

View File

@@ -14,7 +14,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -13,7 +13,7 @@ Tencent released a family of opensource models called HunYuan with varying param
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -19,6 +19,7 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true

View File

@@ -1,68 +0,0 @@
base_model: meta-llama/Llama-3.2-1B-Instruct
chat_template: llama3
rl: gdpo
trl:
beta: 0.001
max_completion_length: 128
num_generations: 2
temperature: 0.7
top_p: 0.95
use_vllm: false
multi_objective_aggregation: normalize_then_sum
reward_funcs:
- rwd.format_reward
- rwd.correctness_reward
reward_weights: [1.0, 2.0]
log_completions: true
num_completions_to_print: 3
scale_rewards: true
datasets:
- path: openai/gsm8k
name: main
split: train[:1000]
type: rwd.gsm8k_transform
val_set_size: 0.0
output_dir: ./outputs/llama3-gdpo-out
sequence_len: 512
sample_packing: false
pad_to_sequence_len: false
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
max_steps: 100
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-5
weight_decay: 0.01
warmup_steps: 10
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
flash_attention: true
logging_steps: 1
save_steps: 50
save_safetensors: true
special_tokens:
pad_token: "<|end_of_text|>"
seed: 42

View File

@@ -12,6 +12,7 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora

View File

@@ -14,7 +14,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
```bash
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -47,5 +47,6 @@ saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
tokens:
save_safetensors: False
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -15,7 +15,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn]'
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy

View File

@@ -12,7 +12,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
```bash
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
```

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
build-backend = "setuptools.build_meta"
[project]
@@ -24,9 +24,6 @@ Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
include-package-data = true
[tool.setuptools.dynamic]
version = { file = "VERSION" }
[tool.setuptools.cmdclass]
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
@@ -60,6 +57,3 @@ indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
docstring-code-format = false
[tool.uv.extra-build-dependencies]
axolotl = ["huggingface_hub"]

View File

@@ -2,25 +2,25 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.49.1
triton>=3.4.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
liger-kernel==0.7.0
liger-kernel==0.6.4
# END section
packaging==26.0
huggingface_hub>=1.1.7
peft>=0.18.1
tokenizers>=0.22.1
transformers==5.2.0
accelerate==1.12.0
datasets==4.5.0
deepspeed>=0.18.3
trl==0.28.0
hf_xet==1.2.0
kernels==0.12.1
packaging==23.2
trackio>=0.16.1
huggingface_hub>=0.36.0
peft>=0.18.0
tokenizers>=0.22.1
transformers==4.57.1
accelerate==1.12.0
datasets==4.4.2
deepspeed>=0.18.3
trl==0.25.1
hf_xet==1.2.0
kernels==0.11.5
trackio>=0.13.0
typing-extensions>=4.15.0
optimum==1.16.2
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.16.0
torchao==0.13.0
openenv-core==0.1.0
schedulefree==1.4.1
@@ -72,4 +72,4 @@ axolotl-contribs-mit==0.0.6
# telemetry
posthog==6.7.11
mistral-common==1.8.8
mistral-common==1.8.6

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"'
)

View File

@@ -1,5 +1,6 @@
"""setup.py for axolotl"""
import ast
import os
import platform
import re
@@ -25,12 +26,6 @@ def parse_requirements(extras_require_map):
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
install_xformers = platform.machine() != "aarch64"
if platform.machine() == "aarch64":
# skip torchao on ARM64
_install_requires = [
req for req in _install_requires if "torchao" not in req
]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
@@ -67,68 +62,44 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
torch_parts = torch_version.split("+")
if len(torch_parts) == 2:
torch_cuda_version = torch_parts[1]
_dependency_links.append(
f"https://download.pytorch.org/whl/{torch_cuda_version}"
)
if (major, minor) >= (2, 9):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = [
"fbgemm-gpu==1.4.0",
"fbgemm-gpu-genai==1.4.2",
]
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.4.1"]
extras_require_map["vllm"] = ["vllm==0.11.1"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
extras_require_map["vllm"] = ["vllm==0.13.0"]
if patch == 0:
extras_require_map["vllm"] = ["vllm==0.13.0"]
else:
extras_require_map["vllm"] = ["vllm==0.14.0"]
elif (major, minor) >= (2, 8):
extras_require_map.pop("fbgemm-gpu")
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
extras_require_map["vllm"] = ["vllm==0.11.0"]
if not install_xformers:
_install_requires.pop(_install_requires.index(xformers_version))
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
if install_xformers:
_install_requires.append("xformers==0.0.30")
_install_requires.append("xformers==0.0.30")
# vllm 0.9.x is incompatible with latest transformers
extras_require_map.pop("vllm")
else:
if install_xformers:
_install_requires.append("xformers==0.0.31")
_install_requires.append("xformers==0.0.31")
extras_require_map["vllm"] = ["vllm==0.10.1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
if install_xformers:
_install_requires.append("xformers==0.0.29.post3")
_install_requires.append("xformers==0.0.29.post3")
# since we only support 2.6.0+cu126
_dependency_links.append("https://download.pytorch.org/whl/cu126")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if install_xformers:
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers>=0.0.28.post3")
extras_require_map.pop("vllm")
elif (major, minor) >= (2, 4):
extras_require_map.pop("vllm")
if install_xformers:
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.27")
else:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.28.post1")
else:
raise ValueError("axolotl requires torch>=2.4")
@@ -139,11 +110,15 @@ def parse_requirements(extras_require_map):
def get_package_version():
with open(
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
Path(os.path.dirname(os.path.abspath(__file__)))
/ "src"
/ "axolotl"
/ "__init__.py",
"r",
encoding="utf-8",
) as fin:
version_ = fin.read().strip()
version_match = re.search(r"^__version__\s*=\s*(.*)$", fin.read(), re.MULTILINE)
version_ = ast.literal_eval(version_match.group(1))
return version_

View File

@@ -1,11 +1,7 @@
"""Axolotl - Train and fine-tune large language models"""
import pkgutil
from importlib.metadata import PackageNotFoundError, version
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
try:
__version__ = version("axolotl")
except PackageNotFoundError:
__version__ = "unknown"
__version__ = "0.13.0.dev"

View File

@@ -5,6 +5,6 @@ import os
from axolotl.logging_config import configure_logging
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("HF_XET_HIGH_PERFORMANCE", "1")
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
configure_logging()

View File

@@ -44,7 +44,7 @@ def check_user_token() -> bool:
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `hf auth login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
except HTTPError:

View File

@@ -5,7 +5,7 @@ import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Optional, Union
from typing import Union
from urllib.parse import urlparse
import requests
@@ -32,63 +32,6 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
"""Coerce a string CLI value to its most likely Python type.
If an existing value is present in the config, its type is used to guide
casting. Otherwise, YAML-style inference is applied: booleans, ints,
floats, and None literals are recognised automatically.
Args:
value: The raw value (typically a string from the CLI).
existing: An optional existing config value whose type guides coercion.
Returns:
The value cast to the inferred or expected type.
"""
if not isinstance(value, str):
return value
# If the config already has a typed value, cast to match
if existing is not None:
if isinstance(existing, bool):
return value.lower() in ("true", "1", "yes")
if isinstance(existing, int):
try:
return int(value)
except (ValueError, TypeError):
return value
if isinstance(existing, float):
try:
return float(value)
except (ValueError, TypeError):
return value
# For other types (str, list, dict, etc.), return as-is
return value
# No existing value -- use YAML-style inference
lower = value.lower()
if lower in ("true", "yes"):
return True
if lower in ("false", "no"):
return False
if lower in ("null", "none", "~"):
return None
# Try int then float
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -265,37 +208,13 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
# Separate nested (dot-notation) kwargs from flat kwargs
nested_kwargs: dict[str, dict[str, Any]] = {}
flat_kwargs: dict[str, Any] = {}
for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
# Apply flat kwargs
for key, value in flat_kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
cfg[key] = _coerce_value(value, cfg.get(key))
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
for parent, children in nested_kwargs.items():
if parent not in cfg_keys and cfg.strict:
continue
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
LOG.warning(
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
)
cfg[parent] = {}
for child_key, child_value in children.items():
existing_child = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
else:
cfg[key] = value
try:
device_props = torch.cuda.get_device_properties("cuda")

View File

@@ -24,6 +24,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
@@ -41,6 +42,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(

View File

@@ -14,6 +14,8 @@ from accelerate import PartialState
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_torch_version,
)
from huggingface_hub import split_torch_state_dict_into_shards
@@ -38,15 +40,17 @@ class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
def _distributed_checkpoint_to_merged_weights(
checkpoint_dir: Union[str, Path],
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
) -> Path:
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
save under `save_path` as `model.safetensors`.
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
Args:
checkpoint_dir: Directory where distributed checkpoint is saved.
save_path: Path to save model to.
safe_serialization: Whether to save in safetensors format.
max_shard_size: Max size of model shards to save.
Returns:
@@ -72,7 +76,11 @@ def _distributed_checkpoint_to_merged_weights(
if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
state_dict[key] = value.to(torch.bfloat16)
filename_pattern = SAFE_WEIGHTS_NAME.replace(".safetensors", "{suffix}.safetensors")
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
@@ -90,12 +98,19 @@ def _distributed_checkpoint_to_merged_weights(
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors}
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
if safe_serialization:
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_path_, shard_file))
if index is not None:
save_index_file = os.path.join(save_path_, SAFE_WEIGHTS_INDEX_NAME)
save_index_file = (
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
)
save_index_file = os.path.join(save_path_, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as fout:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
@@ -108,11 +123,13 @@ def _distributed_checkpoint_to_merged_weights(
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,
safe_serialization: bool = False,
remove_checkpoint_dir: bool = False,
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors`.
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.
Note: this is a CPU-bound process.
@@ -121,6 +138,8 @@ def merge_fsdp_weights(
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
@@ -158,7 +177,7 @@ def merge_fsdp_weights(
if state.is_main_process:
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
save_path = _distributed_checkpoint_to_merged_weights(
checkpoint_dir_, output_path
checkpoint_dir_, output_path, safe_serialization
)
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
@@ -191,6 +210,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=output_path,
safe_serialization=True,
)
state = PartialState()
state.wait_for_everyone()

View File

@@ -102,10 +102,12 @@ def do_quantize(
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
model.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
@@ -119,7 +121,7 @@ def do_quantize(
hub_model_id.rstrip("-")
+ f"-{quantization_config_to_str[type(quantization_config)]}"
)
model.push_to_hub(hub_model_id)
model.push_to_hub(hub_model_id, safe_serialization=False)
tokenizer.push_to_hub(hub_model_id)
if processor:
processor.push_to_hub(hub_model_id)

View File

@@ -2,7 +2,7 @@
import dataclasses
from functools import wraps
from types import NoneType, UnionType
from types import NoneType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
@@ -20,8 +20,7 @@ def _strip_optional_type(field_type: type | str | None):
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
if is_union and type(None) in get_args(field_type):
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
@@ -88,70 +87,10 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
return decorator
def _is_pydantic_model(field_type: type) -> bool:
"""Check if a type is a Pydantic BaseModel subclass."""
try:
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
except TypeError:
return False
def _get_field_description(field) -> str | None:
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
if field.description:
return field.description
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
return field.json_schema_extra.get("description")
return None
def _add_nested_model_options(
function: Callable, parent_name: str, model_class: Type[BaseModel]
) -> Callable:
"""
Add Click options for all fields of a nested Pydantic model using dot-notation.
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
Args:
function: Click command function to add options to.
parent_name: Parent field name (e.g., "trl").
model_class: Nested Pydantic model class.
Returns:
Function with added Click options.
"""
for sub_name, sub_field in reversed(model_class.model_fields.items()):
sub_type = _strip_optional_type(sub_field.annotation)
# Use dot notation: --parent.sub_field
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
# The kwarg name uses double-underscore as separator
param_name = f"{parent_name}__{sub_name}"
description = _get_field_description(sub_field)
if sub_type is bool:
option_name = f"--{cli_name}/--no-{cli_name}"
function = click.option(
option_name, param_name, default=None, help=description
)(function)
else:
option_name = f"--{cli_name}"
click_type = {str: str, int: int, float: float}.get(sub_type)
function = click.option(
option_name, param_name, default=None, type=click_type, help=description
)(function)
return function
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
generated for each sub-field (e.g., ``--trl.beta=0.1``).
Args:
config_class: PyDantic model with fields to parse from the CLI
@@ -164,11 +103,6 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
# Handle nested Pydantic models with dot-notation options
if _is_pydantic_model(field_type):
function = _add_nested_model_options(function, name, field_type)
continue
if field_type is bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"

View File

@@ -216,7 +216,7 @@ class TrainerBuilderBase(abc.ABC):
def _configure_warmup_and_logging(
self, total_num_steps: int, training_args_kwargs: dict
):
warmup_steps: int | float = 0
warmup_steps = 0
warmup_ratio = 0.0
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
@@ -230,10 +230,6 @@ class TrainerBuilderBase(abc.ABC):
else:
warmup_ratio = 0.03
# transformers v5
if warmup_ratio > 0.0 and warmup_steps == 0:
warmup_steps = warmup_ratio
if warmup_steps == 1:
warmup_steps = 2
@@ -246,6 +242,7 @@ class TrainerBuilderBase(abc.ABC):
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_args_kwargs["warmup_ratio"] = warmup_ratio
training_args_kwargs["warmup_steps"] = warmup_steps
def _configure_precision_settings(self, training_args_kwargs: dict):
@@ -409,9 +406,6 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.hub_strategy:
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.hub_revision:
training_args_kwargs["hub_revision"] = self.cfg.hub_revision
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
# save_strategy and save_steps
if self.cfg.save_steps:
@@ -536,7 +530,9 @@ class TrainerBuilderBase(abc.ABC):
"loraplus_lr_ratio",
"loraplus_lr_embedding",
"output_dir",
"save_safetensors",
"save_only_model",
"include_tokens_per_second",
"weight_decay",
"seed",
"dion_momentum",
@@ -549,7 +545,6 @@ class TrainerBuilderBase(abc.ABC):
arg_map = {
"dion_learning_rate": "dion_lr",
"include_num_input_tokens_seen": "include_tokens_per_second",
}
for kwarg, cfg_arg in arg_map.items():
if hasattr(self.cfg, cfg_arg) and getattr(self.cfg, cfg_arg) is not None:

View File

@@ -122,12 +122,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
if getattr(self.cfg, "generate_samples", False):
from axolotl.utils.callbacks.generation import SFTGenerationCallback
callbacks.append(SFTGenerationCallback(trainer))
LOG.info("SFT sample generation enabled")
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
@@ -252,8 +246,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
ddp_find_unused_parameters
)
if self.cfg.group_by_length:
training_arguments_kwargs["train_sampling_strategy"] = "group_by_length"
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
@@ -380,18 +373,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = multiple
if self.cfg.use_eaft:
from functools import partial
from axolotl.monkeypatch.loss.eaft import eaft_loss
configured_eaft_loss = partial(
eaft_loss,
alpha=self.cfg.eaft_alpha if self.cfg.eaft_alpha is not None else 1.0,
k=self.cfg.eaft_k if self.cfg.eaft_k is not None else 20,
)
trainer_kwargs["compute_loss_func"] = configured_eaft_loss
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
@@ -456,9 +437,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn) or (
self.cfg.micro_batch_size == 1 and is_eval is False
):
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
return None
if self.cfg.model_config_type == "mamba":

View File

@@ -11,6 +11,7 @@ from axolotl.core.trainers import (
)
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
@@ -51,13 +52,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
trainer_cls = None
trainer_cls_args = [self.model]
if self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.context_parallel_size > 1
)
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
@@ -134,17 +134,19 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
blocklist_args_kwargs.append("max_prompt_length")
# Handle when max_prompt_length == max_length from defaults
# CPOTrainer requires strictly less than
if (
training_args_kwargs["max_prompt_length"]
== training_args_kwargs["max_length"]
):
training_args_kwargs["max_prompt_length"] -= 1
elif self.cfg.rl is RLType.ORPO:
training_args_cls = AxolotlORPOConfig
blocklist_args_kwargs.append("max_prompt_length")
elif self.cfg.rl is RLType.KTO:
training_args_cls = AxolotlKTOConfig
# KTOConfig in TRL >= 0.27.0 no longer accepts max_prompt_length
blocklist_args_kwargs.append("max_prompt_length")
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
@@ -153,16 +155,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0
)
elif self.cfg.rl in {RLType.GRPO, RLType.GDPO}:
from axolotl.core.trainers.grpo import GRPOStrategy
elif self.cfg.rl is RLType.GRPO:
training_args_cls = GRPOStrategy.get_training_args_class()
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
if self.cfg.rl is RLType.GDPO:
training_args_kwargs.setdefault(
"multi_objective_aggregation", "normalize_then_sum"
)
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
training_args_cls = AxolotlDPOConfig

View File

@@ -25,7 +25,7 @@ from torch.utils.data import (
from transformers import PreTrainedModel, Trainer
from transformers.trainer import TRAINING_ARGS_NAME
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length, seed_worker
from transformers.utils import SAFE_WEIGHTS_NAME, is_peft_available
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, is_peft_available
from trl.trainer.utils import pad_to_length
from typing_extensions import override
@@ -719,16 +719,6 @@ class AxolotlTrainer(
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
LOG.info(f"Saving model checkpoint to {output_dir}")
# fix for Context Parallel save
if state_dict is None:
state_dict = self.accelerator.get_state_dict(self.model)
if state_dict is not None:
state_dict = {
k: v.clone() if isinstance(v, torch.Tensor) else v
for k, v in state_dict.items()
}
supported_classes = (
(PreTrainedModel,)
if not is_peft_available()
@@ -739,7 +729,6 @@ class AxolotlTrainer(
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(
self.accelerator.unwrap_model(self.model, keep_torch_compile=False),
supported_classes,
@@ -749,31 +738,43 @@ class AxolotlTrainer(
).save_pretrained(
output_dir,
state_dict=state_dict,
is_main_process=self.accelerator.is_main_process,
safe_serialization=self.args.save_safetensors,
)
else:
LOG.info(
"Trainer.model is not a `PreTrainedModel`, only saving its state dict."
)
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict,
os.path.join(output_dir, SAFE_WEIGHTS_NAME),
metadata={"format": "pt"},
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
self.model.save_pretrained(output_dir, state_dict=state_dict)
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
self.model.save_pretrained(
output_dir,
state_dict=state_dict,
safe_serialization=self.args.save_safetensors,
is_main_process=self.accelerator.is_main_process,
)
self.data_collator.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
if self.processing_class is not None:
self.processing_class.save_pretrained(output_dir)
elif (
self.data_collator is not None
and hasattr(self.data_collator, "tokenizer")
and self.data_collator.tokenizer is not None
):
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -57,18 +57,16 @@ class AxolotlDPOTrainer(
def tokenize_row(
features,
processing_class,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
is_chat: bool = False,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length=max_prompt_length,
max_completion_length=max_completion_length,
add_special_tokens=add_special_tokens,
is_chat=is_chat,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:

View File

@@ -126,10 +126,8 @@ class GRPOStrategy:
if trl.use_liger_loss is not None:
grpo_args_kwargs["use_liger_loss"] = trl.use_liger_loss
if trl.multi_objective_aggregation is not None:
grpo_args_kwargs["multi_objective_aggregation"] = (
trl.multi_objective_aggregation
)
if trl.rollout_func:
grpo_args_kwargs["rollout_func"] = cls.get_rollout_func(trl.rollout_func)
return grpo_args_kwargs
@@ -151,8 +149,6 @@ class GRPOStrategy:
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
if cfg.trl and cfg.trl.rollout_func:
trainer_kwargs["rollout_func"] = cls.get_rollout_func(cfg.trl.rollout_func)
return trainer_kwargs
@@ -163,12 +159,7 @@ class GRPOStrategy:
@classmethod
def get_blocklist_args_kwargs(cls) -> list[str]:
return [
"dataset_num_proc",
"max_length",
"include_tokens_per_second",
"max_prompt_length",
]
return ["dataset_num_proc", "max_length", "include_tokens_per_second"]
@classmethod
def get_reward_func(cls, reward_func_fqn: str) -> RewardFunc:

View File

@@ -25,7 +25,7 @@ class SchedulerMixin(Trainer):
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: None | torch.optim.Optimizer = None
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
) -> LRScheduler:
"""
Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
@@ -45,13 +45,6 @@ class SchedulerMixin(Trainer):
and self.args.cosine_min_lr_ratio is not None
)
if optimizer is None:
if self.optimizer is None:
raise ValueError(
"Optimizer must be set before calling create_scheduler or passed as an argument."
)
optimizer = self.optimizer
# fmt: off
if self.lr_scheduler is None: # type: ignore
# fmt: on

View File

@@ -1,10 +1,12 @@
"""Module for TRL RL trainers"""
from trl import RewardTrainer
from trl.experimental.cpo import CPOTrainer
from trl.experimental.kto import KTOTrainer
from trl.experimental.orpo import ORPOTrainer
from trl.experimental.prm import PRMTrainer
from trl import (
CPOTrainer,
KTOTrainer,
ORPOTrainer,
PRMTrainer,
RewardTrainer,
)
from axolotl.core.trainers.mixins import DistributedParallelMixin, RngLoaderMixin
from axolotl.core.trainers.mixins.optimizer import OptimizerInitMixin, OptimizerMixin

View File

@@ -8,11 +8,7 @@ from dataclasses import dataclass, field
from typing import Optional, Type
from transformers import TrainingArguments
from trl import RewardConfig
from trl.experimental.cpo import CPOConfig
from trl.experimental.kto import KTOConfig
from trl.experimental.orpo import ORPOConfig
from trl.experimental.prm import PRMConfig
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.integrations.config import merge_training_args

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"
```
## Usage
@@ -31,13 +31,11 @@ plugins:
## Supported Models
- afmoe
- apertus
- arcee
- cohere
- cohere2
- deepseek_v3
- exaone4
- gemma
- gemma2
- gemma3
@@ -47,17 +45,13 @@ plugins:
- glm
- glm4
- glm4_moe
- glm4_moe_lite
- glm46v
- glm4v
- glm4v_moe
- glm_image
- glm_moe_dsa
- gpt_oss
- granite
- granitemoe
- granitemoehybrid
- granitemoeshared
- granitemoehybrid
- hunyuan_v1_dense
- hunyuan_v1_moe
- internvl
@@ -78,26 +72,20 @@ plugins:
- olmo
- olmo2
- olmo3
- olmoe
- phi
- phi3
- phi4_multimodal
- qwen2
- qwen2_5_vl
- qwen2_moe
- qwen2_vl
- qwen2_moe
- qwen2_5_vl
- qwen3
- qwen3_5
- qwen3_5_moe
- qwen3_5_moe_vl
- qwen3_5_vl
- qwen3_moe
- qwen3_next
- qwen3_vl
- qwen3_vl_moe
- seed_oss
- qwen3_next
- smollm3
- step3p5
- seed_oss
- voxtral
## Citation

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@58d6572"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@318b7e2"`'
)
@@ -104,7 +104,7 @@ class CutCrossEntropyPlugin(BasePlugin):
def patch_llama_like(
self,
model_type_to_patch: str,
model_type: str,
) -> None:
"""
Generic patch for model architectures with causal lm similar to llama
@@ -112,10 +112,7 @@ class CutCrossEntropyPlugin(BasePlugin):
from cut_cross_entropy.transformers.patch import PATCH_FNS
def patch_generic(
maybe_model,
patch_options,
remote_model_id: str | None,
model_type: str,
maybe_model, patch_options, model_type: str, remote_model_id: str | None
):
import cut_cross_entropy.transformers.llama
from cut_cross_entropy.transformers.llama import cce_forward
@@ -139,13 +136,11 @@ class CutCrossEntropyPlugin(BasePlugin):
f"Error: {str(e)}"
) from e
if model_type_to_patch not in PATCH_FNS:
if model_type not in PATCH_FNS:
LOG.warning_once(
"Setting up generic cce patch for model type: %s", model_type_to_patch
"Setting up generic cce patch for model type: %s", model_type
)
LOG.warning_once(
f"Generic Cut Cross Entropy + {model_type_to_patch} support is experimental and may not work as expected."
)
PATCH_FNS[model_type_to_patch] = partial(
patch_generic, model_type=model_type_to_patch
f"Generic Cut Cross Entropy + {model_type} support is experimental and may not work as expected."
)
PATCH_FNS[model_type] = partial(patch_generic, model_type=model_type)

View File

@@ -1,44 +0,0 @@
# Kernels Integration
MoE (Mixture of Experts) kernels speed up training for MoE layers and reduce VRAM costs. In transformers v5, `batched_mm` and `grouped_mm` were integrated as built-in options via the `experts_implementation` config kwarg:
```python
class ExpertsInterface(GeneralInterface):
_global_mapping = {
"batched_mm": batched_mm_experts_forward,
"grouped_mm": grouped_mm_experts_forward,
}
```
In our custom integration, we add support for **ScatterMoE**, which is even more efficient and faster than `grouped_mm`.
## Usage
Add the following to your axolotl YAML config:
```yaml
plugins:
- axolotl.integrations.kernels.KernelsPlugin
use_kernels: true
use_scattermoe: true
```
**Important:** Setting `experts_implementation` is incompatible with `use_scattermoe`.
## How It Works
The `KernelsPlugin` runs before model loading and:
1. Registers the ScatterMoE kernel from the [`axolotl-ai-co/scattermoe`](https://huggingface.co/axolotl-ai-co/scattermoe) Hub repo.
2. Patches the model's `SparseMoeBlock` forward method with the optimized ScatterMoE implementation.
This works for any MoE model in transformers that uses a `SparseMoeBlock` class (Mixtral, Qwen2-MoE, OLMoE, etc.).
## Limitations
ScatterMoE uses a softmax -> topk routing, so results may be different for some model arch as baseline (GPT-OSS, GLM_MOE_DSA).
## Note on MegaBlocks
We tested [MegaBlocks](https://huggingface.co/kernels-community/megablocks) but were unable to ensure numerical accuracy, so we did not integrate it. It was also incompatible with many newer model architectures in transformers.

View File

@@ -1,7 +0,0 @@
from .args import KernelsArgs
from .plugin import KernelsPlugin
__all__ = [
"KernelsArgs",
"KernelsPlugin",
]

View File

@@ -1,48 +0,0 @@
from pydantic import BaseModel, model_validator
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
class KernelsArgs(BaseModel):
use_scattermoe: bool | None = True
@model_validator(mode="before")
@classmethod
def check_use_kernels(cls, data):
if data.get("use_kernels") is not True:
LOG.warning(
"`use_kernels` must be set to True to use this. Automatically setting it to True."
)
data["use_kernels"] = True
return data
@model_validator(mode="before")
@classmethod
def check_experts_implementation(cls, data):
experts_implementation = data.get("experts_implementation")
if experts_implementation is None:
# transformers may default to batched_mm when unset
data["experts_implementation"] = "eager"
elif experts_implementation != "eager":
LOG.warning(
"`experts_implementation` must be set to 'eager' to use this. Automatically setting it to 'eager'."
)
data["experts_implementation"] = "eager"
return data
@model_validator(mode="before")
@classmethod
def disable_mlp_kernel_scattermoe(cls, data):
if data.get("use_scattermoe") is True:
if data.get("lora_mlp_kernel") is True:
LOG.warning(
"Disabling lora_mlp_kernel when using scattermoe due to compatibility issues."
)
data["lora_mlp_kernel"] = False
data["mlp_kernel"] = False
return data

View File

@@ -1,18 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import layers
from .lora_ops import ParallelExperts
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
__all__ = [
"layers",
"ParallelExperts",
"flatten_sort_count",
"parallel_linear",
"ScatterMoELoRA",
"parallel_linear_lora",
"lora_ops",
]

View File

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
from . import lora_ops, ops
__all__ = ["ops", "lora_ops"]

View File

@@ -1,645 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
from typing import Optional
import torch
import triton
import triton.language as tl
BLOCK_M = 128
ALLOW_TF32 = True
@triton.jit
def _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=True,
):
K_block = tl.arange(0, BLOCK_K)
X_blk_ptrs = X_ptr + M_in_idx[:, None] * stride_xm + K_block[None, :] * stride_xk
W_blk_ptrs = (
W_ptr
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
+ E_idx * stride_we
)
iters = tl.cdiv(K, BLOCK_K)
for K_block_id in range(iters):
if no_k_mask:
x = tl.load(X_blk_ptrs, mask=E_mask[:, None])
w = tl.load(W_blk_ptrs, mask=N_mask[None, :])
else:
K_mask = (K_block_id * BLOCK_K + K_block) < K
x = tl.load(X_blk_ptrs, mask=E_mask[:, None] & K_mask[None, :])
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :])
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
acc = tl.dot(x, w, acc, allow_tf32=allow_tf32)
return acc
def _scatter2scatter_configs():
return [
triton.Config({"BLOCK_N": 128, "BLOCK_K": 32}, num_stages=4, num_warps=4),
]
@triton.autotune(
configs=_scatter2scatter_configs(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _scatter2scatter(
X_ptr,
stride_xm: tl.constexpr,
stride_xk: tl.constexpr,
W_ptr,
stride_we,
stride_wk: tl.constexpr,
stride_wn: tl.constexpr,
Y_ptr,
stride_ym: tl.constexpr,
stride_yn: tl.constexpr,
B_ptr,
stride_be: tl.constexpr,
stride_bn: tl.constexpr,
grouped_idx_ptr,
expert_idxs_ptr,
# block_start_idx_ptr,
FAN_OUT: tl.constexpr,
M,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
# OUT_M,
allow_tf32: tl.constexpr,
x_grouped: tl.constexpr,
y_grouped: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_BLOCK_COUNT = tl.cdiv(N, BLOCK_N)
M_block_id = pid // N_BLOCK_COUNT
N_block_id = pid % N_BLOCK_COUNT
M_block = M_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
M_boundary_mask = M_block < (FAN_OUT * M)
E_idxs = tl.load(expert_idxs_ptr + M_block, mask=M_boundary_mask, other=E)
no_k_mask = K % BLOCK_K == 0
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
E_first_idx = tl.min(E_idxs)
E_last_idx = tl.minimum(tl.max(E_idxs), E - 1)
M_idx = tl.load(grouped_idx_ptr + M_block, mask=M_boundary_mask).to(tl.int32)
for E_idx in range(E_first_idx, E_last_idx + 1):
E_mask = E_idxs == E_idx
E_M_idx = M_idx
if x_grouped:
M_in_idx = M_block
else:
M_in_idx = E_M_idx // FAN_OUT
acc = _compute_expert_block(
E_idx,
E_mask,
M_in_idx,
N_block,
N_mask,
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
K,
acc,
no_k_mask,
BLOCK_K,
allow_tf32=allow_tf32,
)
if B_ptr is not None:
B_blk_ptrs = B_ptr + E_idxs[:, None] * stride_be + N_block[None, :] * stride_bn
acc += tl.load(B_blk_ptrs, mask=M_boundary_mask[:, None] & N_mask[None, :])
if y_grouped:
M_out_idx = M_block
else:
M_out_idx = M_idx
Y_blk_ptrs = Y_ptr + (M_out_idx[:, None] * stride_ym + N_block[None, :] * stride_yn)
tl.store(Y_blk_ptrs, acc, mask=M_boundary_mask[:, None] & N_mask[None, :])
def scatter2scatter(
X,
W,
sorted_expert_idxs,
sorted_scattered_idxs,
k,
b=None,
x_grouped=False,
y_grouped=False,
out=None,
):
assert sorted_scattered_idxs.size(0) == sorted_expert_idxs.size(0)
assert sorted_scattered_idxs.size(0) == X.size(0) * k
# Pre-kernel setup
y_dim = W.size(-1)
L_scattered = sorted_expert_idxs.size(0)
if out is None:
output = torch.empty((L_scattered, y_dim), device=X.device, dtype=X.dtype)
else:
assert out.size(0) == L_scattered and out.size(1) == y_dim
output = out
scatter2scatter_compileable(
output,
W,
X,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
b,
x_grouped,
y_grouped,
)
return output
@torch.library.custom_op("scattermoe::scatter2scatter", mutates_args={"output"})
def scatter2scatter_compileable(
output: torch.Tensor,
W: torch.Tensor,
X: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
b: Optional[torch.Tensor],
x_grouped: bool,
y_grouped: bool,
) -> None:
def grid(META):
grid_num = (
triton.cdiv(sorted_expert_idxs.size(0), META["BLOCK_M"])
* triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid_num
if b is None:
b = None
stride_be = stride_bn = 0
else:
stride_be, stride_bn = b.stride()
_scatter2scatter[grid](
# X_ptr, stride_xm, stride_xk,
X,
X.stride(0),
X.stride(1),
# W_ptr, stride_we, stride_wk, stride_wn,
W,
W.stride(0),
W.stride(1),
W.stride(2),
# Y_ptr, stride_ym, stride_yn,
output,
output.stride(0),
output.stride(1),
# B_ptr, stride_be, stride_bn
b,
stride_be,
stride_bn,
grouped_idx_ptr=sorted_scattered_idxs,
expert_idxs_ptr=sorted_expert_idxs,
# block_start_idx_ptr=padded_block_idxs,
FAN_OUT=k,
M=X.size(0),
K=X.size(1),
N=output.size(1),
E=W.size(0),
BLOCK_M=BLOCK_M,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
x_grouped=x_grouped,
y_grouped=y_grouped,
)
def _config_XtY():
return [
triton.Config(
{"BLOCK_N": 128, "BLOCK_K": 128, "BLOCK_M": 32}, num_stages=4, num_warps=4
),
]
def group_bwd_W(DY, X, expert_offsets, E, has_bias=False):
DWt = torch.zeros((E, DY.size(-1), X.size(-1)), device=DY.device, dtype=DY.dtype)
DW = DWt.permute(0, 2, 1)
if has_bias:
Db = torch.zeros((E, DY.size(-1)), device=DY.device, dtype=DY.dtype)
else:
Db = None
groupXtY_compileable(E, DW, Db, DY, X, expert_offsets)
return DW, Db
@torch.library.custom_op("scattermoe::groupXtY", mutates_args={"DW", "Db"})
def groupXtY_compileable(
E: int,
DW: torch.Tensor,
Db: Optional[torch.Tensor],
DY: torch.Tensor,
X: torch.Tensor,
expert_offsets: torch.Tensor,
) -> None:
def grid(META):
grid = (
E * triton.cdiv(META["K"], META["BLOCK_K"]),
triton.cdiv(META["N"], META["BLOCK_N"]),
)
return grid
if Db is None:
stride_dbe = 0
stride_dbn = 0
else:
stride_dbe, stride_dbn = Db.stride()
_groupXtY[grid](
# DY_ptr, stride_dym, stride_dyk,
DY,
DY.stride(0),
DY.stride(1),
# X_ptr, stride_xm, stride_xn,
X,
X.stride(0),
X.stride(1),
# DW_ptr, stride_dwe, stride_dwk, stride_dwn,
DW,
DW.stride(0),
DW.stride(1),
DW.stride(2),
# Db_ptr, stride_dwe, stride_dbn,
Db,
stride_dbe,
stride_dbn,
# expert_offsets_ptr,
expert_offsets,
# K: tl.constexpr, N: tl.constexpr,
M=DY.size(0),
N=DY.size(-1),
K=X.size(-1),
# ACC_TYPE: tl.constexpr,
ACC_TYPE=tl.float32,
allow_tf32=ALLOW_TF32,
)
@triton.autotune(
configs=_config_XtY(),
key=["M", "N", "K"],
)
@triton.heuristics(
{
"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0,
"NO_N_MASK": lambda args: (args["N"] % args["BLOCK_N"]) == 0,
}
)
@triton.jit
def _groupXtY(
DY_ptr,
stride_dym,
stride_dyk,
X_ptr,
stride_xm,
stride_xn,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
expert_offsets_ptr,
M,
K: tl.constexpr,
N: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
allow_tf32: tl.constexpr,
NO_K_MASK: tl.constexpr,
NO_N_MASK: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
num0 = tl.num_programs(0)
num1 = tl.num_programs(1)
# pid1, pid0 = tl.swizzle2d(pid1, pid0, num1, num0, 128)
pid0, pid1 = tl.swizzle2d(pid0, pid1, num0, num1, 4)
K_BLOCK_COUNT = tl.cdiv(K, BLOCK_K)
E_idx = pid0 // K_BLOCK_COUNT
K_block_id = pid0 % K_BLOCK_COUNT
N_block_id = pid1
if E_idx == 0:
start_idx = 0
else:
start_idx = tl.load(expert_offsets_ptr + E_idx - 1).to(tl.int32)
end_idx = tl.load(expert_offsets_ptr + E_idx).to(tl.int32)
if end_idx > start_idx:
M_block = tl.max_contiguous(start_idx + tl.arange(0, BLOCK_M), BLOCK_M)
K_block = K_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
K_mask = K_block < K
K_block = tl.max_contiguous(tl.multiple_of(K_block % K, BLOCK_K), BLOCK_K)
N_block = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_block < N
N_block = tl.max_contiguous(tl.multiple_of(N_block % N, BLOCK_N), BLOCK_N)
M_idxs = M_block
xt_blk_ptrs = X_ptr + K_block[:, None] * stride_xn + M_idxs[None, :] * stride_xm
dy_blk_ptrs = (
DY_ptr + M_idxs[:, None] * stride_dym + N_block[None, :] * stride_dyk
)
if (Db_ptr is not None) and (K_block_id == 0):
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=True,
)
else:
_xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias=False,
)
@triton.jit
def _xty_and_bias(
E_idx,
start_idx,
end_idx,
M_block,
K_block,
K_mask,
N_block,
N_mask,
dy_blk_ptrs,
stride_dym,
xt_blk_ptrs,
stride_xm,
DW_ptr,
stride_dwe,
stride_dwk,
stride_dwn,
Db_ptr,
stride_dbe,
stride_dbn,
BLOCK_M,
BLOCK_N,
BLOCK_K,
ACC_TYPE,
allow_tf32,
NO_K_MASK,
NO_N_MASK,
compute_bias: tl.constexpr,
):
if compute_bias:
db_acc = tl.zeros((BLOCK_N,), dtype=ACC_TYPE)
else:
db_acc = None
acc = tl.zeros((BLOCK_K, BLOCK_N), dtype=ACC_TYPE)
iters = tl.cdiv(end_idx - start_idx, BLOCK_M)
for i in range(0, iters):
M_mask = (i * BLOCK_M + M_block) < end_idx
if NO_K_MASK:
xt = tl.load(xt_blk_ptrs, mask=M_mask[None, :])
else:
xt = tl.load(xt_blk_ptrs, mask=K_mask[:, None] & M_mask[None, :])
if NO_N_MASK:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None])
else:
dy = tl.load(dy_blk_ptrs, mask=M_mask[:, None] & N_mask[None, :])
acc += tl.dot(xt, dy, out_dtype=ACC_TYPE, allow_tf32=allow_tf32)
xt_blk_ptrs += BLOCK_M * stride_xm
dy_blk_ptrs += BLOCK_M * stride_dym
if compute_bias:
db_acc += tl.sum(dy, axis=0)
DW_blk_ptrs = (
DW_ptr
+ E_idx * stride_dwe
+ K_block[:, None] * stride_dwk
+ N_block[None, :] * stride_dwn
)
acc = acc.to(DW_blk_ptrs.dtype.element_ty)
tl.store(DW_blk_ptrs, acc, mask=K_mask[:, None] & N_mask[None, :])
if compute_bias:
Db_blk_ptrs = Db_ptr + E_idx * stride_dbe + N_block * stride_dbn
tl.store(Db_blk_ptrs, db_acc, mask=N_mask)
def _config_grouping():
return [
triton.Config({"BLOCK_N": 256, "BLOCK_K": 128}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 128, 'BLOCK_K': 64}, num_stages=4, num_warps=4),
# triton.Config({'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=4, num_warps=4),
]
def group(A, sorted_expert_idxs, coeff=None, fan_out=1, out=None):
N = sorted_expert_idxs.size(0)
K = A.size(1)
assert A.size(0) * fan_out == N
if out is not None:
Y = out
else:
Y = torch.empty((N, K), dtype=A.dtype, device=A.device)
group_compileable(A, K, N, Y, coeff, coeff is not None, fan_out, sorted_expert_idxs)
return Y
@torch.library.custom_op("scattermoe::group", mutates_args={"Y"})
def group_compileable(
A: torch.Tensor,
K: int,
N: int,
Y: torch.Tensor,
coeff: Optional[torch.Tensor],
has_coeff: bool,
fan_out: int,
sorted_expert_idxs: torch.Tensor,
) -> None:
def grid(META):
grid_num = (triton.cdiv(META["N"], META["BLOCK_N"]),)
return grid_num
_group[grid](
# A_ptr, stride_an, stride_ai,
A,
A.stride(0),
A.stride(1),
has_coeff,
coeff,
fan_out,
# Y_ptr, stride_yn, stride_yk,
Y,
Y.stride(0),
Y.stride(1),
# grouped_idx_ptr,
sorted_expert_idxs,
# N: tl.constexpr, K: tl.constexpr,
N,
K,
)
@triton.autotune(configs=_config_grouping(), key=["K"])
@triton.heuristics({"NO_K_MASK": lambda args: (args["K"] % args["BLOCK_K"]) == 0})
@triton.jit
def _group(
src_ptr,
stride_sn,
stride_sk,
has_coeff: tl.constexpr,
coeff_ptr,
FAN_OUT: tl.constexpr,
tgt_ptr,
stride_tn,
stride_ti,
grouped_idx_ptr,
N,
K: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
NO_K_MASK: tl.constexpr,
):
pid = tl.program_id(axis=0)
N_block_id = pid
N_blk = N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
N_mask = N_blk < N
N_blk = tl.max_contiguous(tl.multiple_of(N_blk % N, BLOCK_N), BLOCK_N)
N_idx = tl.load(grouped_idx_ptr + N_blk, mask=N_mask, other=0)
K_blk = tl.arange(0, BLOCK_K)
src_blk_ptrs = (
src_ptr + (N_idx // FAN_OUT)[:, None] * stride_sn + K_blk[None, :] * stride_sk
)
tgt_blk_ptrs = tgt_ptr + N_blk[:, None] * stride_tn + K_blk[None, :] * stride_ti
if has_coeff:
c = tl.load(coeff_ptr + N_idx, mask=N_mask)[:, None]
iters = tl.cdiv(K, BLOCK_K)
for i in range(0, iters):
if NO_K_MASK or i < iters - 1:
block = tl.load(src_blk_ptrs, mask=N_mask[:, None])
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=N_mask[:, None])
else:
K_mask = (i * BLOCK_K + K_blk) < K
mask = N_mask[:, None] & K_mask[None, :]
block = tl.load(src_blk_ptrs, mask=mask)
if has_coeff:
block *= c
tl.store(tgt_blk_ptrs, block, mask=mask)
src_blk_ptrs += BLOCK_K * stride_sk
tgt_blk_ptrs += BLOCK_K * stride_ti

View File

@@ -1,98 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
import torch
import triton
import triton.language as tl
@triton.jit
def _single2scatter(
X_ptr,
stride_xm,
stride_xk,
W_ptr,
stride_we,
stride_wk,
stride_wn,
Y_ptr,
stride_ym,
stride_yn,
expert_idxs_ptr,
FAN_OUT: tl.constexpr,
K: tl.constexpr,
N: tl.constexpr,
E: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
pid0 = tl.program_id(axis=0)
pid1 = tl.program_id(axis=1)
N_block_id = pid0
if FAN_OUT == 1:
in_idx = pid1
else:
in_idx = 0
out_idx = pid1
K_block = tl.arange(0, BLOCK_K)
N_block = tl.max_contiguous(
tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N),
BLOCK_N,
)
E_idx = tl.load(expert_idxs_ptr + pid1)
X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
W_blk_ptrs = (
W_ptr
+ E_idx * stride_we
+ K_block[:, None] * stride_wk
+ N_block[None, :] * stride_wn
)
N_mask = N_block < N
acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
for _K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
K_mask = K_block < K
x = tl.load(X_blk_ptrs, mask=K_mask[:, None], other=0.0)
w = tl.load(W_blk_ptrs, mask=K_mask[:, None] & N_mask[None, :], other=0.0)
acc += tl.sum(x * w, axis=0)[None, :]
X_blk_ptrs += BLOCK_K * stride_xk
W_blk_ptrs += BLOCK_K * stride_wk
K_block += BLOCK_K
Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
tl.store(Y_blk_ptrs, acc, mask=N_mask[None, :])
def single2scatter(X, W, expert_idxs):
E, xdim, ydim = W.size()
k = expert_idxs.size(1)
assert X.size(0) == k or X.size(0) == 1
Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
BLOCK_N = 128
BLOCK_K = 128
grid = triton.cdiv(ydim, BLOCK_N), k
_single2scatter[grid](
X,
X.stride(0),
X.stride(1),
W,
W.stride(0),
W.stride(1),
W.stride(2),
Y,
Y.stride(0),
Y.stride(1),
expert_idxs,
FAN_OUT=Y.size(0) // X.size(0),
K=xdim,
N=ydim,
E=E,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
ACC_TYPE=tl.float32,
)
return Y

View File

@@ -1,439 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
#
# Original work Copyright (c) Shawn Tan and ScatterMoE Contributors
# Adapted from https://github.com/shawntan/scattermoe
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
#
# Modifications and LoRA adaptation Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ScatterMoE layer replacements for HuggingFace MoE architectures.
Provides drop-in forward replacements that use ScatterMoE kernels for
acceleration. When used via the HF ``kernels`` library
(``replace_kernel_forward_from_hub``), these classes replace the forward
method of the original MoE block.
LoRA support
------------
When peft wraps parameters via ``target_parameters``, the ``self.experts``
submodule becomes a chain of ``ParamWrapper`` objects and the ``self.gate``
router may also become a ``ParamWrapper``. The ``HFScatterMoEGatedMLP``
forward detects this and automatically:
1. Unwraps ``self.gate`` to the base router, applying gate LoRA delta
2. Unwraps ``self.experts`` to the base ``OlmoeExperts`` module
3. Extracts LoRA A/B weights and scaling from each wrapper
4. Converts B layout from peft rank-major to scattermoe expert-major
5. Routes to ``parallel_linear_lora`` for fused LoRA computation
6. Passes through ``self.shared_expert`` / ``self.shared_expert_gate``
(peft wraps their linear layers with standard LoRA, no special handling)
"""
import torch
from torch import nn
from torch.nn import functional as F
from .parallel_experts import flatten_sort_count, parallel_linear
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
# =============================================================================
# LoRA layout conversion utilities (peft <-> scattermoe)
# =============================================================================
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
expert-major ``[N, r*E]``.
peft reshapes B to ``[out, r, E]`` (rank-major).
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
"""
N = peft_B.shape[0]
return (
peft_B.reshape(N, rank, num_experts)
.permute(0, 2, 1)
.contiguous()
.reshape(N, num_experts * rank)
)
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Convert peft LoRA weights to scattermoe layout (with A<->B swap).
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
where ``in_features=dim1, out_features=dim2``. ScatterMoE transposes the
parameter (``W = param.transpose(2, 1)``) giving ``[E, dim2, dim1]`` with
``K=dim2, N=dim1``. Because of this transposition, peft's A and B roles
are swapped relative to scattermoe's convention.
peft gives:
lora_A ``[r*E, dim1]``, lora_B ``[dim2, r*E]``
scattermoe needs:
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
This function swaps A<->B and converts B from rank-major to expert-major.
Uses vectorized tensor operations (no Python loop over experts).
Works for **both** gate_up_proj and down_proj since the transposition
issue is the same for any parameter.
"""
peft_B_em = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
dim1 = peft_A.shape[1] # peft in_features -> scattermoe N
dim2 = peft_B_em.shape[0] # peft out_features -> scattermoe K
# smoe_A: per expert, transpose B_e [dim2, r] -> [r, dim2]
# [dim2, E*r] -> [dim2, E, r] -> [E, r, dim2] -> [E*r, dim2]
smoe_A = (
peft_B_em.reshape(dim2, num_experts, rank)
.permute(1, 2, 0)
.contiguous()
.reshape(rank * num_experts, dim2)
)
# smoe_B: per expert, transpose A_e [r, dim1] -> [dim1, r]
# [E*r, dim1] -> [E, r, dim1] -> [dim1, E, r] -> [dim1, E*r]
smoe_B = (
peft_A.reshape(num_experts, rank, dim1)
.permute(2, 0, 1)
.contiguous()
.reshape(dim1, num_experts * rank)
)
return smoe_A, smoe_B
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
# =============================================================================
# ParamWrapper unwrapping
# =============================================================================
def _unwrap_gate_lora(gate_module):
"""Unwrap peft ``ParamWrapper`` on the router gate.
When peft targets ``gate.weight``, ``self.gate`` becomes::
ParamWrapper(weight)
-> base_layer: OlmoeTopKRouter (the real module)
This function detects the wrapping and returns the base router, its
weight tensor, and an optional LoRA delta tensor.
Returns:
(base_gate, gate_weight, gate_lora_delta_or_None)
``base_gate`` is the original router module (with ``.top_k``,
``.num_experts``, ``.norm_topk_prob``).
``gate_weight`` is the base router weight (may be a DTensor under FSDP).
``gate_lora_delta_or_None`` is the LoRA delta tensor if LoRA is active,
else ``None``. Kept separate to avoid mixing DTensor + Tensor in an add.
"""
if hasattr(gate_module, "base_layer") and hasattr(gate_module, "lora_A"):
base_gate = gate_module.base_layer
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gate_module)
if lora_A is not None:
# gate weight: [num_experts, hidden_size]
# lora_A: [r, hidden_size], lora_B: [num_experts, r]
# delta = scaling * B @ A = [num_experts, hidden_size]
delta = scaling * (lora_B @ lora_A)
return base_gate, base_gate.weight, delta
else:
return base_gate, base_gate.weight, None
else:
# No wrapping — gate is the original module
return gate_module, gate_module.weight, None
def _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling):
"""Convert peft LoRA weights to scattermoe layout."""
smoe_A, smoe_B = peft_lora_to_scattermoe(lora_A, lora_B, num_experts, rank)
return (smoe_A, smoe_B, scaling)
def _unwrap_experts_lora(experts_module):
"""Walk a peft ``ParamWrapper`` chain on ``self.experts``.
When peft targets ``experts.gate_up_proj`` and ``experts.down_proj`` via
``target_parameters``, ``self.experts`` becomes a nested chain::
ParamWrapper(down_proj)
-> base_layer: ParamWrapper(gate_up_proj)
-> base_layer: OlmoeExperts (the real module)
This function walks the chain, collects LoRA params keyed by
``parameter_name``, and returns the base experts module.
Returns:
(base_experts, gup_lora, down_lora)
Each ``*_lora`` is either ``(smoe_A, smoe_B, scaling)`` or ``None``.
A/B are already in scattermoe layout.
"""
# Collect ParamWrapper layers by their parameter_name
wrappers = {}
module = experts_module
while hasattr(module, "base_layer") and hasattr(module, "lora_A"):
param_name = getattr(module, "parameter_name", None)
if param_name is not None:
wrappers[param_name] = module
module = module.base_layer
base_experts = module
if not wrappers:
return base_experts, None, None
# Determine num_experts from base module
num_experts = getattr(base_experts, "num_experts", None)
if num_experts is None:
# Fallback: infer from parameter shape
gup = getattr(base_experts, "gate_up_proj", None)
if gup is not None:
num_experts = gup.shape[0]
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
gup_lora = None
gup_wrapper = wrappers.get("gate_up_proj")
if gup_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(gup_wrapper)
if lora_A is not None:
rank = lora_A.shape[0] // num_experts
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
# Extract down_proj LoRA (needs A<->B swap due to transposition)
down_lora = None
down_wrapper = wrappers.get("down_proj")
if down_wrapper is not None:
lora_A, lora_B, scaling = get_lora_params_from_wrapper(down_wrapper)
if lora_A is not None:
rank = lora_A.shape[0] // num_experts
down_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
return base_experts, gup_lora, down_lora
# =============================================================================
# Layer classes
# =============================================================================
class ScatterMoEGatedMLP(nn.Module):
def forward(self, layer_input):
"""
Forward pass of the mixture of experts layer.
Args:
layer_input (Tensor):
Input tensor.
Returns:
Tensor:
Output tensor.
"""
bsz, length, emb_size = layer_input.size()
layer_input = layer_input.reshape(-1, emb_size)
# compute the top_k routing decision
router_logits = self.router.layer(layer_input)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.router.top_k, dim=-1
)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(layer_input.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
selected_experts, num_experts=self.router.num_experts
)
# compute experts
gates, h = parallel_linear(
layer_input,
self.input_linear.weight.transpose(2, 1),
self.router.top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=False,
grouped_out=True,
).chunk(2, dim=-1)
h = self.activation(gates) * h
layer_output = parallel_linear(
h,
self.output_linear.weight.transpose(2, 1),
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
)
layer_output = layer_output.view(bsz, length, emb_size)
return layer_output
class HFScatterMoEGatedMLP(nn.Module):
"""
ScatterMoE-accelerated forward pass for HF MoEs (OLMoE / Qwen2MoE).
Used as a kernel layer via the HF ``kernels`` library. The ``forward``
method replaces the original ``OlmoeSparseMoeBlock.forward``.
Supports both full-parameter training and LoRA fine-tuning:
* **Full-param**: uses ``parallel_linear`` (base ScatterMoE kernel)
* **LoRA**: detects peft ``ParamWrapper`` on ``self.experts``, extracts
adapter weights, and uses ``parallel_linear_lora`` (fused kernel)
"""
@staticmethod
def forward(self: nn.Module, layer_input: torch.Tensor):
"""
Forward pass using ScatterMoE kernels.
Args:
self: The MoeSparseMoeBlock module containing:
- self.gate: Router (or peft ParamWrapper wrapping it)
- self.experts: Experts module (or peft ParamWrapper chain)
- self.shared_expert: Optional shared expert (e.g. Qwen2MoE)
- self.shared_expert_gate: Optional shared expert gate
layer_input: Input tensor [batch_size, seq_len, hidden_size]
Returns:
Tensor: [batch_size, seq_len, hidden_size]
"""
batch_size, sequence_length, hidden_dim = layer_input.shape
hidden_states_flat = layer_input.view(-1, hidden_dim)
# ====================================================================
# Shared Expert (if present, e.g. Qwen2MoE)
# ====================================================================
# peft wraps individual linear layers inside shared_expert with
# standard LoRA — calling forward() handles this transparently.
if hasattr(self, "shared_expert") and self.shared_expert is not None:
shared_expert_output = self.shared_expert(hidden_states_flat)
# shared_expert_gate may also be peft-wrapped (standard LoRA
# on nn.Linear), its forward() applies LoRA automatically.
shared_expert_gate_output = F.sigmoid(
self.shared_expert_gate(hidden_states_flat)
)
shared_expert_output = shared_expert_output * shared_expert_gate_output
else:
shared_expert_output = None
# ====================================================================
# Router Computation (with optional gate LoRA)
# ====================================================================
base_gate, gate_weight, gate_lora_delta = _unwrap_gate_lora(self.gate)
router_logits = F.linear(hidden_states_flat, gate_weight)
if gate_lora_delta is not None:
router_logits = router_logits + F.linear(
hidden_states_flat, gate_lora_delta
)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
top_k = base_gate.top_k
num_experts = base_gate.num_experts
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if base_gate.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states_flat.dtype)
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
selected_experts, num_experts=num_experts
)
# ====================================================================
# Detect LoRA (peft ParamWrapper) and extract adapter weights
# ====================================================================
experts, gup_lora, down_lora = _unwrap_experts_lora(self.experts)
# ====================================================================
# Gate + Up projection
# ====================================================================
gate_up_W = experts.gate_up_proj.transpose(2, 1) # [E, hidden, 2*inter]
if gup_lora is not None:
gup_A, gup_B, gup_scaling = gup_lora
gup = parallel_linear_lora(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=gup_A,
lora_B=gup_B,
scaling=gup_scaling,
grouped_in=False,
grouped_out=True,
use_fused_dX=True,
use_fused_gather=True,
)
else:
gup = parallel_linear(
hidden_states_flat,
gate_up_W,
top_k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=False,
grouped_out=True,
)
gates, h = gup.chunk(2, dim=-1)
h = experts.act_fn(gates) * h
# ====================================================================
# Down projection
# ====================================================================
down_W = experts.down_proj.transpose(2, 1) # [E, inter, hidden]
if down_lora is not None:
down_A, down_B, down_scaling = down_lora
expert_output = parallel_linear_lora(
h,
down_W,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=down_A,
lora_B=down_B,
scaling=down_scaling,
gates=routing_weights,
grouped_in=True,
grouped_out=False,
use_fused_dX=True,
use_fused_gather=True,
)
else:
expert_output = parallel_linear(
h,
down_W,
1,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
grouped_in=True,
grouped_out=False,
gates=routing_weights,
)
# ====================================================================
# Combine with shared expert and reshape
# ====================================================================
if shared_expert_output is not None:
expert_output = expert_output + shared_expert_output
expert_output = expert_output.view(batch_size, sequence_length, hidden_dim)
return expert_output

View File

@@ -1,99 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ParallelExperts module with LoRA support.
Provides a drop-in replacement for ScatterMoE's ParallelExperts that
uses the fused LoRA kernel when adapter weights are attached.
"""
from typing import Optional
import torch
import torch.nn as nn
from .parallel_linear_lora import parallel_linear_lora
class ParallelExperts(nn.Module):
"""
Parallel Experts with fused LoRA support.
Drop-in replacement for the original ParallelExperts. When LoRA parameters
are attached via set_lora(), the forward pass uses a fused kernel:
Y = X @ W + scaling * (X @ A^T) @ B^T
"""
def __init__(
self,
num_experts: int,
input_size: int,
output_size: int,
bias: bool = False,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
if bias:
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
else:
self.bias = None
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
self._lora_A: torch.Tensor | None = None
self._lora_B: torch.Tensor | None = None
self._lora_scaling: float | None = None
self.reset_parameters()
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.02)
if self.bias is not None:
nn.init.zeros_(self.bias)
def extra_repr(self) -> str:
return (
f"num_experts={self.num_experts}, "
f"input_size={self.input_size}, "
f"output_size={self.output_size}"
)
def set_lora(self, lora_A: torch.Tensor, lora_B: torch.Tensor, scaling: float):
"""Attach LoRA parameters for fused computation."""
self._lora_A = lora_A
self._lora_B = lora_B
self._lora_scaling = scaling
def clear_lora(self):
"""Remove LoRA parameters."""
self._lora_A = None
self._lora_B = None
self._lora_scaling = None
def forward(
self,
inputs: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
) -> torch.Tensor:
return parallel_linear_lora(
inputs,
self.weight.permute(0, 2, 1), # [E, input, output]
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A=self._lora_A,
lora_B=self._lora_B,
scaling=self._lora_scaling if self._lora_scaling is not None else 1.0,
expert_biases=self.bias,
gates=gates,
grouped_in=grouped_in,
grouped_out=grouped_out,
)

View File

@@ -1,253 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/shawntan/scattermoe
# Copyright (c) Shawn Tan and ScatterMoE Contributors
# Licensed under the Apache License, Version 2.0
# See https://github.com/shawntan/scattermoe/blob/main/LICENSE
from typing import Optional
import torch
import torch.nn as nn
from . import kernels
@torch.library.custom_op("scattermoe::bincount", mutates_args={})
def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
return x.bincount(minlength=minlength)
@compileable_bincount.register_fake
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
return torch.empty(minlength, dtype=torch.long, device=x.device)
@torch.compile
def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
with torch.no_grad():
flattened_expert_idxs = expert_idxs.flatten()
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
expert_counts = compileable_bincount(
flattened_expert_idxs, minlength=num_experts
)
expert_offsets = expert_counts.cumsum(-1)
return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
class ParallelLinear(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
):
with torch.device(x.device):
output = kernels.ops.scatter2scatter(
X=x,
W=expert_weights,
b=expert_biases,
k=k,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
x_grouped=grouped_in,
y_grouped=grouped_out,
)
if gates is not None:
output_expanded = output.view(
gates.size(0), gates.size(1), output.size(-1)
)
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x,
expert_weights,
expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
)
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
with torch.device(grad_out.device):
(
x,
expert_weights,
expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
) = ctx.saved_tensors
k = ctx.k
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
if gates is not None:
# calculate gates gradient
# d_gates = torch.bmm(output_expanded, grad_out[:, :, None]).squeeze(-1)
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
grouped_grad_out = output_expanded.flatten(
0, 1
) # reuse expanded buffer later
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = kernels.ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x
d_weights, d_biases = kernels.ops.group_bwd_W(
DY=grouped_grad_out,
X=grouped_x,
expert_offsets=expert_offsets,
E=expert_weights.size(0),
has_bias=expert_biases is not None,
)
d_expanded_input = kernels.ops.scatter2scatter(
X=grouped_grad_out,
x_grouped=True,
W=expert_weights.permute(0, 2, 1),
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input, # Reuse grouped_x buffer
)
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(
x.size(0), k, d_expanded_input.size(-1)
).sum(-2)
return (
# x, expert_weights,
d_input,
d_weights,
# k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
None,
None,
None,
None,
# bias, gates
d_biases,
d_gates,
# grouped_in, grouped_out,
None,
None,
)
def parallel_linear(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases=None,
gates=None,
grouped_in=False,
grouped_out=False,
):
results = ParallelLinear.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases,
gates,
grouped_in,
grouped_out,
)
return results
class ParallelExperts(nn.Module):
def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
if bias:
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
else:
self.bias = None
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
self.reset_parameters()
def extra_repr(self):
return "num_experts={}, input_size={}, output_size={}".format(
self.num_experts, self.input_size, self.output_size
)
def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.02)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(
self,
inputs,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates=None,
grouped_in=False,
grouped_out=False,
):
results = parallel_linear(
inputs,
self.weight.permute(0, 2, 1),
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases=self.bias,
gates=gates,
grouped_in=grouped_in,
grouped_out=grouped_out,
)
return results

View File

@@ -1,480 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Axolotl AI
# Licensed under the Apache License, Version 2.0
"""
ScatterMoE + LoRA Autograd Function
====================================
Provides the autograd function and Python interface for fused ScatterMoE + LoRA.
Key design for LoRA training:
- Expert weights W are FROZEN (no gradient computed for W).
- Only LoRA adapter weights (A, B) receive gradients.
- The input gradient dX is still computed (needed for upstream layers).
- This avoids the expensive group_bwd_W computation entirely.
Forward:
Y = X @ W + scaling * (X @ A^T) @ B^T
Backward (W frozen):
dX = dY @ W^T + scaling * (dY @ B) @ A (via scatter2scatter for base, separate for LoRA)
dA = scaling * (dY @ B)^T @ X (per-expert, on grouped data)
dB = scaling * dY^T @ (X @ A^T) (per-expert, on grouped data)
"""
from typing import Optional
import torch
from .kernels import ops as base_ops
from .kernels.lora_ops import (
group_bwd_lora,
group_bwd_lora_fused,
scatter2scatter_lora,
scatter2scatter_lora_dX,
)
class ScatterMoELoRA(torch.autograd.Function):
"""
Autograd function for fused ScatterMoE + LoRA with frozen expert weights.
This function is optimized for the LoRA fine-tuning scenario where:
- Expert weights W are frozen (requires_grad=False)
- Only LoRA A and B matrices receive gradients
- Input gradients are computed for upstream layer backprop
"""
@staticmethod
def forward(
ctx,
x: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
scaling: float,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
with torch.device(x.device):
# Fused forward: Y = X @ W + scaling * (X @ A^T) @ B^T
output = scatter2scatter_lora(
X=x,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=k,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
b=expert_biases,
x_grouped=grouped_in,
y_grouped=grouped_out,
)
# Handle gating (weighted combination of top-k expert outputs)
if gates is not None:
output_expanded = output.view(
gates.size(0), gates.size(1), output.size(-1)
)
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x,
lora_A,
lora_B,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
)
# Store frozen weights as plain Python attributes instead of
# save_for_backward. This avoids:
# 1. Version-check conflicts with FSDP unshard/reshard
# 2. Pinning all-gathered parameters via saved_tensors hooks
# 3. Interfering with activation offloading pack/unpack hooks
# Safe because expert_weights are frozen (requires_grad=False).
ctx.expert_weights = expert_weights
ctx.expert_biases = expert_biases
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
ctx.scaling = scaling
ctx.use_fused_dX = use_fused_dX
ctx.use_fused_gather = use_fused_gather
return output
@staticmethod
def backward(ctx, grad_out: torch.Tensor):
with torch.device(grad_out.device):
(
x,
lora_A,
lora_B,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded,
) = ctx.saved_tensors
expert_weights = ctx.expert_weights
k = ctx.k
scaling = ctx.scaling
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
E = expert_weights.size(0)
# ------------------------------------------------------------------
# Gate gradients (if using top-k gating with routing weights)
# ------------------------------------------------------------------
if gates is not None:
# d_gates[t, j] = output_expanded[t, j, :] . grad_out[t, :]
d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
gates_flat = gates.flatten()
gate_fan = gates.size(1)
# Reuse output_expanded buffer for grouped_grad_out
grouped_grad_out = output_expanded.flatten(0, 1)
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
# ------------------------------------------------------------------
# LoRA gradients (dA, dB) and setup for dX
# ------------------------------------------------------------------
# Fused gather uses sorted_scattered_idxs for indirect X access
# in the Triton kernel, avoiding the group(x) allocation.
#
# can_fuse_gather: X is ungrouped and not too large for scatter loads
# - When gates is None and grouped_out=False: both DY and X ungrouped
# - When grouped_out=True (gate_up_proj): DY already grouped, X ungrouped
# -> use dy_grouped=True in the fused kernel
M_total = sorted_scattered_idxs.size(0)
K_dim = x.size(-1)
N_dim = expert_weights.size(-1)
fuse_gather_workload = M_total * max(K_dim, N_dim)
_FUSE_GATHER_THRESHOLD = 2**24 # ~16M elements
can_fuse_gather = (
ctx.use_fused_gather
and not grouped_in # X must be ungrouped for scatter access
and gates is None # gate coeff requires multiplicative gather
and fuse_gather_workload < _FUSE_GATHER_THRESHOLD
)
if can_fuse_gather:
# ------------------------------------------------------------------
# Fused path: skip group(x) entirely
# ------------------------------------------------------------------
d_expanded_input = None
d_lora_A, d_lora_B = group_bwd_lora_fused(
DY=grad_out,
X=x,
lora_A=lora_A,
lora_B=lora_B,
expert_offsets=expert_offsets,
sorted_scattered_idxs=sorted_scattered_idxs,
E=E,
k=k,
scaling=scaling,
dy_grouped=grouped_out,
)
# Prepare grouped_grad_out for the dX path (needed by both
# the fused dX kernel when grouped_out=True, and the non-fused path)
if grouped_out:
grouped_grad_out = grad_out
elif not ctx.use_fused_dX:
grouped_grad_out = base_ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
else:
# ------------------------------------------------------------------
# Original path: explicit group() calls
# ------------------------------------------------------------------
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = base_ops.group(
grad_out,
sorted_scattered_idxs,
fan_out=gate_fan,
coeff=gates_flat,
out=grouped_grad_out,
)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = base_ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x # Will be overwritten; reuse buffer
d_lora_A, d_lora_B = group_bwd_lora(
DY=grouped_grad_out,
X=grouped_x,
lora_A=lora_A,
lora_B=lora_B,
expert_offsets=expert_offsets,
E=E,
scaling=scaling,
)
# ------------------------------------------------------------------
# Input gradient: dX = dY @ W^T + scaling * (dY @ B) @ A
# ------------------------------------------------------------------
if ctx.use_fused_dX:
if can_fuse_gather and not grouped_out:
# Fully fused: read ungrouped DY via scatter pattern
d_expanded_input = scatter2scatter_lora_dX(
DY=grad_out,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
dy_grouped=False,
dx_grouped=grouped_in,
out=d_expanded_input,
)
else:
# Fused dX only: read from pre-grouped DY
d_expanded_input = scatter2scatter_lora_dX(
DY=grouped_grad_out,
W=expert_weights,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
lora_A=lora_A,
lora_B=lora_B,
scaling=scaling,
dy_grouped=True,
dx_grouped=grouped_in,
out=d_expanded_input,
)
else:
# Original path: separate base scatter2scatter + LoRA Python loop
d_expanded_input = base_ops.scatter2scatter(
X=grouped_grad_out,
x_grouped=True,
W=expert_weights.permute(0, 2, 1), # [E, N, K]
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input,
)
# LoRA part: dX_lora = scaling * (dY @ B) @ A
if scaling != 0.0:
d_input_lora_grouped = _compute_lora_input_grad(
grouped_grad_out,
lora_A,
lora_B,
expert_offsets,
E,
scaling,
)
if grouped_in:
d_expanded_input.add_(d_input_lora_grouped)
else:
# Scatter-add LoRA gradient directly into d_expanded_input.
# Avoids allocating a zeros_like + add result
d_expanded_input[sorted_scattered_idxs] += d_input_lora_grouped
# Reduce over top-k if k > 1
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(
x.size(0), k, d_expanded_input.size(-1)
).sum(-2)
# W is frozen during LoRA training -- skip weight gradient
d_weights = (
torch.zeros_like(expert_weights)
if expert_weights.requires_grad
else None
)
d_biases = None
return (
d_input,
d_weights,
None,
None,
None,
None, # k, sorted indices, offsets
d_lora_A,
d_lora_B,
None, # lora_A, lora_B, scaling
d_biases,
d_gates,
None,
None, # grouped_in, grouped_out
None, # use_fused_dX
None, # use_fused_gather
)
def _compute_lora_input_grad(
grouped_grad_out: torch.Tensor,
lora_A: torch.Tensor,
lora_B: torch.Tensor,
expert_offsets: torch.Tensor,
E: int,
scaling: float,
) -> torch.Tensor:
"""
Compute the LoRA contribution to the input gradient:
dX_lora = scaling * (dY @ B) @ A
Uses PyTorch ops on expert-grouped data.
Each expert e: dX_e = scaling * (dY_e @ B_e) @ A_e
"""
R = lora_A.size(0) // E
K = lora_A.size(1)
M_total = grouped_grad_out.size(0)
d_input_lora = torch.zeros(
(M_total, K), device=grouped_grad_out.device, dtype=grouped_grad_out.dtype
)
compute_dtype = grouped_grad_out.dtype
prev_offset = 0
for e in range(E):
curr_offset = expert_offsets[e].item()
if curr_offset > prev_offset:
dy_e = grouped_grad_out[prev_offset:curr_offset] # [M_e, N]
a_e = lora_A[e * R : (e + 1) * R, :].to(compute_dtype) # [r, K]
b_e = lora_B[:, e * R : (e + 1) * R].to(compute_dtype) # [N, r]
# dX_e = scaling * (dY_e @ B_e) @ A_e
dy_b = dy_e @ b_e # [M_e, r]
dx_e = scaling * (dy_b @ a_e) # [M_e, K]
d_input_lora[prev_offset:curr_offset] = dx_e
prev_offset = curr_offset
return d_input_lora
# =============================================================================
# Helper: Extract LoRA params from PEFT ParamWrapper
# =============================================================================
def get_lora_params_from_wrapper(module) -> tuple:
"""
Extract LoRA parameters from a PEFT ParamWrapper.
Returns:
(lora_A, lora_B, scaling) if LoRA is active, else (None, None, None)
"""
if not hasattr(module, "lora_A") or not hasattr(module, "lora_B"):
return None, None, None
active_adapters = getattr(module, "active_adapters", ["default"])
if not active_adapters:
return None, None, None
adapter_name = active_adapters[0]
lora_A_dict = getattr(module, "lora_A", {})
lora_B_dict = getattr(module, "lora_B", {})
scaling_dict = getattr(module, "scaling", {})
if adapter_name not in lora_A_dict:
return None, None, None
lora_A = lora_A_dict[adapter_name].weight
lora_B = lora_B_dict[adapter_name].weight
scaling = scaling_dict[adapter_name]
return lora_A, lora_B, scaling
# =============================================================================
# Drop-in replacement for parallel_linear
# =============================================================================
def parallel_linear_lora(
inputs: torch.Tensor,
expert_weights: torch.Tensor,
k: int,
sorted_expert_idxs: torch.Tensor,
sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
lora_A: Optional[torch.Tensor] = None,
lora_B: Optional[torch.Tensor] = None,
scaling: float = 1.0,
expert_biases: Optional[torch.Tensor] = None,
gates: Optional[torch.Tensor] = None,
grouped_in: bool = False,
grouped_out: bool = False,
use_fused_dX: bool = False,
use_fused_gather: bool = False,
):
"""
Drop-in replacement for parallel_linear that supports LoRA.
If lora_A and lora_B are provided, uses fused LoRA kernel.
Otherwise falls back to standard scatter2scatter.
"""
if lora_A is not None and lora_B is not None:
return ScatterMoELoRA.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
lora_A,
lora_B,
scaling,
expert_biases,
gates,
grouped_in,
grouped_out,
use_fused_dX,
use_fused_gather,
)
else:
from .parallel_experts import ParallelLinear
return ParallelLinear.apply(
inputs,
expert_weights,
k,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
expert_biases,
gates,
grouped_in,
grouped_out,
)

View File

@@ -1,66 +0,0 @@
from pathlib import Path
from kernels import (
LocalLayerRepository,
Mode,
register_kernel_mapping,
replace_kernel_forward_from_hub,
)
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.models import get_causal_lm_model_cls_prefix
class KernelsPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kernels.KernelsArgs"
def pre_model_load(self, cfg):
if cfg.use_scattermoe:
self._register_kernels()
self._kernelize_model(cfg.model_config_type)
def _register_kernels(self):
plugin_root = Path(__file__).parent
register_kernel_mapping(
{
"HFScatterMoEParallelExperts": {
"cuda": {
Mode.TRAINING: LocalLayerRepository(
repo_path=plugin_root / "libs" / "scattermoe_lora",
package_name="scattermoe_lora",
layer_name="HFScatterMoEGatedMLP",
),
Mode.INFERENCE: LocalLayerRepository(
repo_path=plugin_root / "libs" / "scattermoe_lora",
package_name="scattermoe_lora",
layer_name="HFScatterMoEGatedMLP",
),
},
}
}
)
def _kernelize_model(self, model_type: str):
if model_type == "olmoe":
from transformers.models.olmoe.modeling_olmoe import OlmoeSparseMoeBlock
replace_kernel_forward_from_hub(
OlmoeSparseMoeBlock, "HFScatterMoEParallelExperts"
)
else:
try:
model_moe_cls = get_model_moe_block(model_type)
replace_kernel_forward_from_hub(
model_moe_cls, "HFScatterMoEParallelExperts"
)
except Exception as err:
raise ValueError(f"Unsupported model type: {model_type}") from err
def get_model_moe_block(model_type: str):
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix, _ = get_causal_lm_model_cls_prefix(model_type)
module = __import__(module_path, fromlist=[f"{model_cls_prefix}SparseMoeBlock"])
model_cls = getattr(module, f"{model_cls_prefix}SparseMoeBlock")
return model_cls

View File

@@ -12,6 +12,7 @@ def save_compressed_model(
model: PreTrainedModel,
output_dir: Union[str, bytes],
trainer: Trainer,
safe_serialization: bool = False,
save_compressed: bool = False,
) -> None:
"""
@@ -21,6 +22,7 @@ def save_compressed_model(
model (PreTrainedModel): The model to be saved.
output_dir (str or bytes): Path where the model files will be written.
trainer (Trainer): Hugging Face Trainer for process synchronization.
safe_serialization (bool): Use safe serialization if True.
save_compressed (bool): Write compressed tensors if True.
"""
trainer.accelerator.wait_for_everyone()
@@ -32,6 +34,7 @@ def save_compressed_model(
modify_save_pretrained(model)
model.save_pretrained(
output_dir,
safe_serialization=safe_serialization,
save_compressed=save_compressed,
skip_sparsity_compression_stats=not save_compressed,
)

View File

@@ -6,12 +6,6 @@ See https://github.com/EleutherAI/lm-evaluation-harness
## Usage
There are two ways to use the LM Eval integration:
### 1. Post-Training Evaluation
When training with the plugin enabled, evaluation runs automatically after training completes:
```yaml
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
@@ -22,50 +16,9 @@ lm_eval_tasks:
- arc_easy
lm_eval_batch_size: # Batch size for evaluation
# Directory to save evaluation results.
# The final model is loaded from this directory
# unless specified otherwise (see below)
output_dir:
output_dir: # Directory to save evaluation results
```
Run training as usual:
```bash
axolotl train config.yml
```
### 2. Standalone CLI Evaluation
Evaluate any model directly without training:
```yaml
lm_eval_model: meta-llama/Llama-2-7b-hf
plugins:
- axolotl.integrations.lm_eval.LMEvalPlugin
lm_eval_tasks:
- gsm8k
- hellaswag
- arc_easy
lm_eval_batch_size: 8
output_dir: ./outputs
```
Run evaluation:
```bash
axolotl lm-eval config.yml
```
## Model Selection Priority
The model to evaluate is selected in the following priority order:
1. **`lm_eval_model`** - Explicit model path or HuggingFace repo (highest priority)
2. **`hub_model_id`** - Trained model pushed to HuggingFace Hub
3. **`output_dir`** - Local checkpoint directory containing trained model weights
## Citation
```bib

View File

@@ -5,7 +5,7 @@ Module for the Plugin for LM Eval Harness
import subprocess # nosec
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command, get_model_path
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs as LMEvalArgs
@@ -29,7 +29,7 @@ class LMEvalPlugin(BasePlugin):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=get_model_path(cfg),
model=cfg.lm_eval_model or cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,

View File

@@ -13,21 +13,6 @@ import yaml
from axolotl.utils.dict import DictDefault
def get_model_path(cfg: DictDefault) -> str | None:
"""
Determine which model path to use for evaluation.
Priority order (highest to lowest):
1. lm_eval_model - Explicit model path override
2. hub_model_id - Model pushed to HuggingFace Hub
3. None - Falls back to output_dir in build_lm_eval_command
Returns:
Model path string or None to use output_dir fallback
"""
return cfg.lm_eval_model or cfg.hub_model_id or None
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
@@ -123,7 +108,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
wandb_name=cfg.wandb_name,
model=get_model_path(cfg),
model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,

View File

@@ -26,6 +26,7 @@ from torch.distributed import DeviceMesh
from transformers import (
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForVision2Seq,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
@@ -225,7 +226,6 @@ class ModelLoader:
):
self.model = self.model.merge_and_unload()
self._configure_experts_implementation()
self._apply_activation_checkpointing()
self._resize_token_embeddings()
self._adjust_model_config()
@@ -233,10 +233,6 @@ class ModelLoader:
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _configure_experts_implementation(self):
if self.cfg.experts_implementation is not None:
self.model.set_experts_implementation(self.cfg.experts_implementation)
def _apply_activation_checkpointing(self):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
@@ -338,12 +334,7 @@ class ModelLoader:
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so
# we need to convert them back to fp16/bf16 for flash-attn compatibility.
(
(
needs_fa2_dtype
or self.cfg.flash_attention
or self.cfg.flex_attention
or self.cfg.sage_attention
)
(needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
and not self.is_qlora_and_fsdp_enabled
)
or (
@@ -443,7 +434,7 @@ class ModelLoader:
"""
if self.cfg.is_multimodal:
self.auto_model_loader = MULTIMODAL_AUTO_MODEL_MAPPING.get(
self.model_config.model_type, AutoModelForImageTextToText
self.model_config.model_type, AutoModelForVision2Seq
)
if isinstance(self.auto_model_loader, str):
self.auto_model_loader = AutoModelForImageTextToText
@@ -485,7 +476,6 @@ class ModelLoader:
max_memory = None
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
self.model_kwargs["dtype"] = self.cfg.torch_dtype
is_ds_zero3 = is_deepspeed_zero3_enabled()
@@ -617,10 +607,6 @@ class ModelLoader:
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = "sdpa"
elif self.cfg.sage_attention:
# sets FA2 attention to re-use same internal handling like masking
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = "flash_attention_2"
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = "eager"
@@ -684,7 +670,7 @@ class ModelLoader:
Uses the selected loader when provided; otherwise falls back to the auto loader.
"""
loader = model_loader_class or self.auto_model_loader
if loader in [AutoModelForCausalLM, AutoModelForImageTextToText]:
if loader in [AutoModelForCausalLM, AutoModelForVision2Seq]:
model = loader.from_config(
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
@@ -802,7 +788,6 @@ class ModelLoader:
# Use auto model loader (handles gptq and default cases)
model_loader_class = self.auto_model_loader
self.model_kwargs["dtype"] = self.model_kwargs["torch_dtype"]
if self.cfg.reinit_weights:
self.model = self._load_model_from_config(model_loader_class)
else:

View File

@@ -10,7 +10,6 @@ from functools import cached_property
import addict
import transformers
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_flash_attention_utils import is_flash_attn_available
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import (
@@ -97,7 +96,6 @@ class PatchManager:
# self._apply_flex_attention_patches()
self._apply_flash_attention_patches()
self._apply_chunked_cross_entropy_patch()
self._apply_sageattn_patches()
self._apply_fsdp_patches()
self._apply_adapter_patches()
self._apply_model_specific_patches()
@@ -155,9 +153,12 @@ class PatchManager:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
patch_chunked_ce_loss_fn(
self.cfg.chunked_cross_entropy_num_chunks,
use_dft=self.cfg.use_dynamic_finetuning,
)
else:
patch_chunked_ce_loss_fn()
patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning)
def _apply_fsdp_patches(self):
"""Apply patches for FSDP configurations."""
@@ -203,13 +204,6 @@ class PatchManager:
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
def _apply_sageattn_patches(self):
"""Apply patches for SageAttention."""
if self.cfg.sage_attention:
from axolotl.monkeypatch.attention.sage_attn import patch_sageattn
patch_sageattn()
def _apply_model_specific_patches(self):
"""Apply patches specific to model architectures."""
if (
@@ -229,6 +223,13 @@ class PatchManager:
patch_qwen3_next_modeling_packing()
if self.cfg.model_config_type == "mistral3" and self.cfg.processor_type:
from axolotl.monkeypatch.models.mistral3.mistral_common_tokenizer import (
apply_mistral_tokenizer_image_patch,
)
apply_mistral_tokenizer_image_patch()
if self.cfg.model_config_type == "kimi_linear":
from axolotl.monkeypatch.models.kimi_linear.patch_kimi_linear import (
patch_kimi_model,
@@ -329,7 +330,7 @@ class PatchManager:
else:
has_remote_code = False
if has_remote_code and self.cfg.trust_remote_code is not None:
if has_remote_code and self.cfg.trust_remote_code is False:
# If explicitly set in YAML, prefer that
has_remote_code = self.cfg.trust_remote_code
@@ -501,7 +502,6 @@ class PatchManager:
and not self.cfg.trust_remote_code
and not self.cfg.gptq
and self.cfg.flash_attention
and is_flash_attn_available()
and not self.inference
):
# TODO(MengqingCao): split these patches separately

View File

@@ -19,11 +19,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
# Build common kwargs for processor loading
processor_kwargs = {}
if cfg.revision_of_model:
processor_kwargs["revision"] = cfg.revision_of_model
if cfg.tokenizer_use_mistral_common:
def _patch_mistralcommontokenizer():
@@ -36,7 +31,7 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
from axolotl.utils.mistral import HFMistralTokenizer
tokenization_mistral_common.MistralCommonBackend = HFMistralTokenizer
tokenization_mistral_common.MistralCommonTokenizer = HFMistralTokenizer
_patch_mistralcommontokenizer()
@@ -45,7 +40,6 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
if processor_cls == VoxtralProcessor:
return VoxtralProcessor.from_pretrained(
cfg.processor_config,
**processor_kwargs,
)
from axolotl.utils.mistral import Mistral3Processor
@@ -54,12 +48,10 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
tokenizer=tokenizer,
)
processor_kwargs["trust_remote_code"] = cfg.trust_remote_code or False
processor_kwargs["tokenizer"] = tokenizer
processor = processor_cls.from_pretrained(
cfg.processor_config,
**processor_kwargs,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
)
# Attempt to load image size from processor if available

View File

@@ -28,10 +28,7 @@ PLUGIN_MANAGER = PluginManager.get_instance()
def modify_tokenizer_files(
tokenizer_path: str,
token_mappings: dict[int, str],
output_dir: str,
revision: str = "main",
tokenizer_path: str, token_mappings: dict[int, str], output_dir: str
) -> str:
"""
Modify tokenizer files to replace added_tokens strings, save to output directory,
@@ -44,7 +41,6 @@ def modify_tokenizer_files(
tokenizer_path: Path or name of the original tokenizer
token_mappings: Dict mapping {token_id (int): new_token_string}
output_dir: Directory to save the modified tokenizer
revision: Model revision/branch/tag/commit to load from (HF Hub)
Returns:
Path to the modified tokenizer directory
@@ -57,9 +53,7 @@ def modify_tokenizer_files(
if is_local_main_process():
# Load the tokenizer
temp_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, use_fast=True, revision=revision
)
temp_tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
# Save the tokenizer to the output directory
temp_tokenizer.save_pretrained(tokenizer_dir)
@@ -140,10 +134,7 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
from axolotl.utils.mistral import HFMistralTokenizer
# Load the HF-compatible wrapper around MistralTokenizer
kwargs = {}
if cfg.revision_of_model:
kwargs["revision"] = cfg.revision_of_model
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config, **kwargs)
tokenizer = HFMistralTokenizer.from_pretrained(cfg.tokenizer_config)
return tokenizer
@@ -159,8 +150,6 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if cfg.revision_of_model:
tokenizer_kwargs["revision"] = cfg.revision_of_model
tokenizer_cls = AutoTokenizer
if cfg.tokenizer_type:
@@ -172,11 +161,8 @@ def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
# Apply token string overrides if specified
if cfg.added_tokens_overrides:
# Modify tokenizer files and get path to modified tokenizer
modify_kwargs = {"output_dir": cfg.output_dir}
if cfg.revision_of_model:
modify_kwargs["revision"] = cfg.revision_of_model
tokenizer_path = modify_tokenizer_files(
tokenizer_path, cfg.added_tokens_overrides, **modify_kwargs
tokenizer_path, cfg.added_tokens_overrides, output_dir=cfg.output_dir
)
tokenizer = tokenizer_cls.from_pretrained(

View File

@@ -5,7 +5,6 @@ from typing import Type
import addict
import torch
import transformers
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from axolotl.utils.dict import DictDefault
@@ -154,9 +153,6 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
This function determines the appropriate model config source, loads it, applies any
necessary overrides, and validates it for compatibility with the `axolotl` config.
If `cfg.cls_model_config` is set, a custom config class from transformers will be
used instead of `AutoConfig` (e.g., 'LlamaConfig', 'MistralConfig').
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
@@ -178,13 +174,8 @@ def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
if cfg.num_labels:
# num_labels is used to initialize classifier models
config_kwargs["num_labels"] = cfg.num_labels
config_cls = AutoConfig
if cfg.cls_model_config:
config_cls = getattr(transformers, cfg.cls_model_config)
try:
model_config = config_cls.from_pretrained(
model_config = AutoConfig.from_pretrained(
model_config_name,
trust_remote_code=trust_remote_code,
**config_kwargs,

View File

@@ -111,6 +111,7 @@ class MambaLMHeadModel(nn.Module, GenerationMixin):
self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
safe_serialization: Optional[bool] = None,
):
if state_dict is None:
state_dict = self.state_dict()

View File

@@ -1,211 +0,0 @@
"""
Monkeypatch for SageAttention for use with transformers.
https://github.com/thu-ml/SageAttention/
"""
import torch
from transformers.integrations.sdpa_attention import repeat_kv
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
sageattn = None # pylint: disable=invalid-name
sageattn_varlen = None # pylint: disable=invalid-name
def _is_sageattn_available():
"""Determine if SageAttention is available"""
try:
import sageattention # noqa: F401 # pylint: disable=unused-import
return True
except ImportError:
return False
if _is_sageattn_available():
# import sageattn here if available
from sageattention import sageattn, sageattn_varlen
def _check_sageattn_imported():
"""Check if SageAttention is imported. Raises an ImportError if not."""
if sageattn is None:
raise ImportError(
"SageAttention is not installed. Please install it from source: "
"`pip install git+https://github.com/thu-ml/SageAttention.git@1718ddc06dbc694bcf3c6b49ac28c1921aa2d8bd`"
)
def sage_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None = None,
dropout: float = 0.0,
scaling: float | None = None,
is_causal: bool | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
"""
Forward pass for SageAttention compatible with transformers attention interfaces.
https://github.com/thu-ml/SageAttention/
"""
_check_sageattn_imported()
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
raise NotImplementedError(
"SageAttention does not support `output_attentions=True` or `head_mask`."
)
# The base sageattn API does not support dropout.
if dropout > 0.0:
raise NotImplementedError("SageAttention does not support dropout.")
# Handle Grouped-Query Attention (GQA) and Multi-Query Attention (MQA)
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
# Calculate is_causal following transformers
assert is_causal is not False, "is_causal must be True or None"
is_causal = True
position_ids = kwargs.get("position_ids", None)
query_length = query.shape[2]
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", None)
max_length_q = kwargs.get("max_length_q", None)
max_length_k = kwargs.get("max_length_k", None)
# Sample packing uses position_ids, so we check for it first
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.size(0)
from transformers.modeling_flash_attention_utils import (
prepare_fa2_from_position_ids,
)
if cu_seqlens_q is None or cu_seqlens_k is None:
query, key, value, indices_q, cu_seq_lens, max_seq_lens = (
prepare_fa2_from_position_ids(query, key, value, position_ids)
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_length_q, max_length_k = max_seq_lens
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
is_causal=is_causal,
sm_scale=scaling,
smooth_k=False, # reduces loss 0 / nan grad norms
tensor_layout="NHD",
)
attn_output = attn_output_unpad.view(
batch_size, -1, attn_output_unpad.size(-2), attn_output_unpad.size(-1)
)
elif attention_mask is not None:
# NOTE: When used without `pad_to_sequence_len`, the loss becomes unstable after a few steps.
assert attention_mask.ndim == 2, "Attention mask must be 2D"
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
# transpose inputs to NHD layout for use with FA2 utils
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
batch_size = query.shape[0]
query, key, value, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_q, max_seqlen_k = max_seq_lens
attn_output_unpad = sageattn_varlen(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scaling,
tensor_layout="NHD",
)
from flash_attn.bert_padding import pad_input
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
# Use standard sageattn
# The input layout for transformers models is (batch_size, num_heads, seq_len, head_dim),
# which corresponds to SageAttention's "HND" layout.
attn_output = sageattn(
q=query,
k=key,
v=value,
tensor_layout="HND",
is_causal=is_causal,
sm_scale=scaling,
)
# SageAttention with "HND" returns (batch, heads, seq_len, head_dim)
# Transformers expects (batch, seq_len, heads, head_dim) for the output
# So we need to transpose dimensions 1 and 2
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
def patch_sageattn():
"""Patch SageAttention for use with transformers."""
_check_sageattn_imported()
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
# Replace flash attention with sage attention
ALL_ATTENTION_FUNCTIONS.register("flash_attention_2", sage_attention_forward)
# Note: New method after transformers refactor to use ALL_MASK_ATTENTION_FUNCTIONS
# Register sage_attention with the global attention interface
# ALL_ATTENTION_FUNCTIONS.register("sage_attention", sage_attention_forward)
# from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS, flash_attention_mask
# ALL_MASK_ATTENTION_FUNCTIONS.register("sage_attention", flash_attention_mask)
LOG.info("SageAttention patched successfully")

Some files were not shown because too many files have changed in this diff Show More