Compare commits
42 Commits
08fc7de87e
...
b7ec06b8a1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b7ec06b8a1 | ||
|
|
e2f01de0e8 | ||
|
|
5352d41d32 | ||
|
|
c15f6cffe2 | ||
|
|
e4032fc90f | ||
|
|
6136ae627b | ||
|
|
e662972a29 | ||
|
|
ebbd7fa847 | ||
|
|
ac77da96da | ||
|
|
798c8fba89 | ||
|
|
17fc747f99 | ||
|
|
901f2356bc | ||
|
|
1bf65c500e | ||
|
|
bcbe049c21 | ||
|
|
90090fa9e8 | ||
|
|
7420fd4de6 | ||
|
|
05113bc91a | ||
|
|
e562e149ce | ||
|
|
9de5b76336 | ||
|
|
323da791eb | ||
|
|
6990478163 | ||
|
|
63a58cfec1 | ||
|
|
3985ec2f67 | ||
|
|
a44edda6d7 | ||
|
|
66c3e5a3fd | ||
|
|
b8358aa5ab | ||
|
|
e079cf16a2 | ||
|
|
e2f69828d2 | ||
|
|
122b50bad6 | ||
|
|
e77a185e86 | ||
|
|
29fa4dedbb | ||
|
|
315cdeede9 | ||
|
|
e7a6a5b529 | ||
|
|
bfb4da1d25 | ||
|
|
4dfa0a59b2 | ||
|
|
4ef608dda3 | ||
|
|
7daf7d96f1 | ||
|
|
7c56809c7f | ||
|
|
149178ddb7 | ||
|
|
dc638e723f | ||
|
|
6f15da4cac | ||
|
|
900eec7988 |
6
.github/CONTRIBUTING.md
vendored
6
.github/CONTRIBUTING.md
vendored
@@ -31,7 +31,11 @@ PRs are **greatly welcome**!
|
|||||||
|
|
||||||
Please run below to setup env
|
Please run below to setup env
|
||||||
```bash
|
```bash
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
# Install axolotl + dev and test dependencies
|
||||||
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
|
uv venv --no-project --relocatable
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
|
||||||
pre-commit install
|
pre-commit install
|
||||||
|
|
||||||
# test
|
# test
|
||||||
|
|||||||
16
.github/workflows/base.yml
vendored
16
.github/workflows/base.yml
vendored
@@ -30,14 +30,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: "128"
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.9.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-base"
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -168,14 +160,6 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "128"
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.9.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-uv-base"
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
|
|||||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -6,7 +6,7 @@ on:
|
|||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'requirements.txt'
|
- 'pyproject.toml'
|
||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
- "*.[q]md"
|
- "*.[q]md"
|
||||||
- "examples/**/*.y[a]?ml"
|
- "examples/**/*.y[a]?ml"
|
||||||
|
|||||||
12
.github/workflows/main.yml
vendored
12
.github/workflows/main.yml
vendored
@@ -18,12 +18,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.9.0
|
|
||||||
axolotl_extras:
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -180,12 +174,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.9.0
|
|
||||||
axolotl_extras:
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
35
.github/workflows/multi-gpu-e2e.yml
vendored
35
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -3,17 +3,15 @@ name: docker-multigpu-tests-biweekly
|
|||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'tests/e2e/multigpu/**.py'
|
- "tests/e2e/multigpu/**.py"
|
||||||
- 'requirements.txt'
|
- "pyproject.toml"
|
||||||
- 'setup.py'
|
- ".github/workflows/multi-gpu-e2e.yml"
|
||||||
- 'pyproject.toml'
|
- "scripts/cutcrossentropy_install.py"
|
||||||
- '.github/workflows/multi-gpu-e2e.yml'
|
- "src/axolotl/core/trainers/mixins/sequence_parallel.py"
|
||||||
- 'scripts/cutcrossentropy_install.py'
|
- "src/axolotl/utils/distributed.py"
|
||||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
|
||||||
- 'src/axolotl/utils/distributed.py'
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
- cron: "0 0 * * 1,4" # Runs at 00:00 UTC every monday & thursday
|
||||||
|
|
||||||
# Cancel jobs on the same ref if a new one is triggered
|
# Cancel jobs on the same ref if a new one is triggered
|
||||||
concurrency:
|
concurrency:
|
||||||
@@ -33,19 +31,19 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
# - cuda: 129
|
# - cuda: 129
|
||||||
# cuda_version: 12.9.1
|
# cuda_version: 12.9.1
|
||||||
# python_version: "3.12"
|
# python_version: "3.12"
|
||||||
# pytorch: 2.9.1
|
# pytorch: 2.9.1
|
||||||
# axolotl_extras: "fbgemm-gpu"
|
# axolotl_extras: "fbgemm-gpu"
|
||||||
# num_gpus: 2
|
# num_gpus: 2
|
||||||
# dockerfile: "Dockerfile-uv.jinja"
|
# dockerfile: "Dockerfile-uv.jinja"
|
||||||
- cuda: 130
|
- cuda: 130
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
# axolotl_extras: fbgemm-gpu
|
# axolotl_extras: fbgemm-gpu
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
@@ -53,7 +51,6 @@ jobs:
|
|||||||
pytorch: 2.10.0
|
pytorch: 2.10.0
|
||||||
axolotl_extras: "fbgemm-gpu"
|
axolotl_extras: "fbgemm-gpu"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
dockerfile: "Dockerfile-uv.jinja"
|
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 120
|
timeout-minutes: 120
|
||||||
steps:
|
steps:
|
||||||
@@ -75,7 +72,7 @@ jobs:
|
|||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
env:
|
env:
|
||||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
|||||||
13
.github/workflows/pypi.yml
vendored
13
.github/workflows/pypi.yml
vendored
@@ -8,6 +8,9 @@ on:
|
|||||||
|
|
||||||
permissions: {}
|
permissions: {}
|
||||||
|
|
||||||
|
env:
|
||||||
|
UV_SYSTEM_PYTHON: "1"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
setup_release:
|
setup_release:
|
||||||
name: Create Release
|
name: Create Release
|
||||||
@@ -41,11 +44,15 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install wheel packaging==26.0
|
uv pip install wheel packaging
|
||||||
pip3 install --no-build-isolation -e .
|
uv pip install --no-build-isolation -e .
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
|
|
||||||
- name: Extract tag name
|
- name: Extract tag name
|
||||||
id: tag
|
id: tag
|
||||||
|
|||||||
55
.github/workflows/tests-nightly.yml
vendored
55
.github/workflows/tests-nightly.yml
vendored
@@ -2,15 +2,18 @@ name: Tests Nightly against upstream main
|
|||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
- cron: "0 0 * * *" # Runs at 00:00 UTC every day
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '.github/workflows/tests-nightly.yml'
|
- ".github/workflows/tests-nightly.yml"
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
|
|
||||||
|
env:
|
||||||
|
UV_SYSTEM_PYTHON: "1"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pre-commit:
|
pre-commit:
|
||||||
name: pre-commit
|
name: pre-commit
|
||||||
@@ -20,7 +23,7 @@ jobs:
|
|||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: "pip" # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.1
|
- uses: pre-commit/action@v3.0.1
|
||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
@@ -43,7 +46,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||||
pytorch_version: ["2.9.1", "2.10.0"]
|
pytorch_version: ["2.9.1", "2.10.0"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
@@ -61,36 +64,34 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
|
||||||
|
|
||||||
- name: upgrade pip
|
- name: Install uv
|
||||||
run: |
|
uses: astral-sh/setup-uv@v7
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||||
|
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||||
- name: Update requirements.txt
|
|
||||||
run: |
|
|
||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
|
||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
|
||||||
pip3 install --no-build-isolation -U -e .
|
python scripts/cutcrossentropy_install.py --uv | sh
|
||||||
python scripts/unsloth_install.py | sh
|
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
|
||||||
|
- name: Override with nightly HF packages
|
||||||
|
run: |
|
||||||
|
uv pip install --no-deps \
|
||||||
|
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
||||||
|
"peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||||
|
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||||
|
"trl @ git+https://github.com/huggingface/trl.git@main" \
|
||||||
|
"datasets @ git+https://github.com/huggingface/datasets.git@main"
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
run: |
|
run: |
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
- name: Ensure axolotl CLI was installed
|
||||||
run: |
|
run: |
|
||||||
@@ -102,9 +103,6 @@ jobs:
|
|||||||
pytest -v --durations=10 tests/patched/
|
pytest -v --durations=10 tests/patched/
|
||||||
pytest -v --durations=10 tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
|
||||||
run: |
|
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
@@ -136,7 +134,6 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
dockerfile: "Dockerfile-uv.jinja"
|
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -157,7 +154,7 @@ jobs:
|
|||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
env:
|
env:
|
||||||
|
|||||||
97
.github/workflows/tests.yml
vendored
97
.github/workflows/tests.yml
vendored
@@ -6,21 +6,19 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- "**.py"
|
||||||
- 'requirements.txt'
|
- "pyproject.toml"
|
||||||
- '.github/workflows/*.yml'
|
- ".github/workflows/*.yml"
|
||||||
- 'requirements-tests.txt'
|
- "cicd/cicd.sh"
|
||||||
- 'cicd/cicd.sh'
|
- "cicd/Dockerfile-uv.jinja"
|
||||||
- 'cicd/Dockerfile.jinja'
|
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- "**.py"
|
||||||
- 'requirements.txt'
|
- "pyproject.toml"
|
||||||
- '.github/workflows/*.yml'
|
- ".github/workflows/*.yml"
|
||||||
- 'requirements-tests.txt'
|
- "cicd/cicd.sh"
|
||||||
- 'cicd/cicd.sh'
|
- "cicd/Dockerfile-uv.jinja"
|
||||||
- 'cicd/Dockerfile.jinja'
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
# Cancel jobs on the same ref if a new one is triggered
|
# Cancel jobs on the same ref if a new one is triggered
|
||||||
@@ -33,6 +31,7 @@ permissions:
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
TRANSFORMERS_IS_CI: "yes"
|
TRANSFORMERS_IS_CI: "yes"
|
||||||
|
UV_SYSTEM_PYTHON: "1"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pre-commit:
|
pre-commit:
|
||||||
@@ -44,7 +43,7 @@ jobs:
|
|||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: "pip" # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.1
|
- uses: pre-commit/action@v3.0.1
|
||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
@@ -73,7 +72,7 @@ jobs:
|
|||||||
exclude:
|
exclude:
|
||||||
- python_version: "3.14"
|
- python_version: "3.14"
|
||||||
pytorch_version: "2.9.1"
|
pytorch_version: "2.9.1"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 25
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: cleanup node
|
- name: cleanup node
|
||||||
@@ -94,32 +93,25 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
|
||||||
|
|
||||||
- name: upgrade pip
|
- name: Install uv
|
||||||
run: |
|
uses: astral-sh/setup-uv@v7
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||||
|
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
|
||||||
pip3 install --no-cache-dir --no-build-isolation -U -e .
|
python scripts/cutcrossentropy_install.py --uv | sh
|
||||||
python scripts/unsloth_install.py | sh
|
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
|
||||||
|
|
||||||
- name: cleanup pip cache
|
|
||||||
run: |
|
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
run: |
|
run: |
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
- name: Ensure axolotl CLI was installed
|
||||||
run: |
|
run: |
|
||||||
@@ -188,38 +180,42 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
|
||||||
|
|
||||||
- name: upgrade pip
|
- name: Install uv
|
||||||
run: |
|
uses: astral-sh/setup-uv@v7
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||||
|
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
uv pip install packaging setuptools_scm build wheel psutil
|
||||||
python -m build --no-isolation --sdist
|
python -m build --no-isolation --sdist
|
||||||
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
|
uv pip install --no-build-isolation dist/axolotl*.tar.gz --override /tmp/torch-pin.txt
|
||||||
python scripts/unsloth_install.py | sh
|
python scripts/cutcrossentropy_install.py --uv | sh
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
|
|
||||||
- name: cleanup pip cache
|
|
||||||
run: |
|
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
run: |
|
run: |
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
- name: Ensure axolotl CLI was installed
|
||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Verify agent docs are discoverable
|
||||||
|
run: |
|
||||||
|
# Agent docs live in docs/agents/ (source of truth) and are resolved
|
||||||
|
# at runtime from the repo checkout or via `axolotl fetch docs`
|
||||||
|
axolotl agent-docs --list
|
||||||
|
axolotl agent-docs | grep -q "Fine-tuning framework"
|
||||||
|
axolotl agent-docs grpo | grep -q "GRPO"
|
||||||
|
axolotl agent-docs sft | grep -q "SFT"
|
||||||
|
python -c "from axolotl.cli.agent_docs import get_doc, list_topics; assert len(list_topics()) >= 5; assert 'GRPO' in get_doc('grpo')"
|
||||||
|
|
||||||
- name: Show HF cache
|
- name: Show HF cache
|
||||||
run: hf cache ls
|
run: hf cache ls
|
||||||
|
|
||||||
@@ -281,7 +277,6 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
dockerfile: "Dockerfile-uv.jinja"
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -302,7 +297,7 @@ jobs:
|
|||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
env:
|
env:
|
||||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||||
@@ -364,7 +359,7 @@ jobs:
|
|||||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
|
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
|
||||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
env:
|
env:
|
||||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ axolotl inference config.yaml # Interactive inference
|
|||||||
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
|
axolotl merge-lora config.yaml # Merge LoRA adapter into base model
|
||||||
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
|
axolotl vllm-serve config.yaml # Start vLLM server for GRPO/EBFT training
|
||||||
axolotl fetch examples # Download example configs
|
axolotl fetch examples # Download example configs
|
||||||
|
axolotl agent-docs # Show agent-optimized docs (bundled with pip package)
|
||||||
|
axolotl agent-docs grpo # Topic-specific agent reference
|
||||||
|
axolotl config-schema # Dump config JSON schema
|
||||||
```
|
```
|
||||||
|
|
||||||
## Training Methods
|
## Training Methods
|
||||||
@@ -23,7 +26,7 @@ axolotl fetch examples # Download example configs
|
|||||||
| Method | Config Key | When to Use |
|
| Method | Config Key | When to Use |
|
||||||
|--------|-----------|-------------|
|
|--------|-----------|-------------|
|
||||||
| SFT | *(default)* | Input-output pairs, instruction tuning |
|
| SFT | *(default)* | Input-output pairs, instruction tuning |
|
||||||
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
|
| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
|
||||||
| KTO | `rl: kto` | Unpaired binary preference labels |
|
| KTO | `rl: kto` | Unpaired binary preference labels |
|
||||||
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
|
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
|
||||||
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
|
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
|
||||||
@@ -35,6 +38,8 @@ Agent-specific references:
|
|||||||
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
|
- [docs/agents/grpo.md](docs/agents/grpo.md) — GRPO online RL with reward functions
|
||||||
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
|
- [docs/agents/reward_modelling.md](docs/agents/reward_modelling.md) — outcome and process reward models
|
||||||
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
|
- [docs/agents/pretraining.md](docs/agents/pretraining.md) — continual pretraining
|
||||||
|
- [docs/agents/model_architectures.md](docs/agents/model_architectures.md) — model-specific quirks (Gemma4, Qwen3.5 MoE, etc.)
|
||||||
|
- [docs/agents/new_model_support.md](docs/agents/new_model_support.md) — debugging and adding support for new model architectures
|
||||||
|
|
||||||
## Config Pattern
|
## Config Pattern
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
include requirements.txt
|
|
||||||
include README.md
|
include README.md
|
||||||
include LICENSE
|
include LICENSE
|
||||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
include VERSION
|
||||||
include src/axolotl/utils/chat_templates/templates/*.jinja
|
include src/axolotl/utils/chat_templates/templates/*.jinja
|
||||||
|
include AGENTS.md
|
||||||
|
recursive-include docs/agents *.md
|
||||||
recursive-include axolotl *.py
|
recursive-include axolotl *.py
|
||||||
|
|||||||
46
README.md
46
README.md
@@ -29,6 +29,9 @@
|
|||||||
|
|
||||||
## 🎉 Latest Updates
|
## 🎉 Latest Updates
|
||||||
|
|
||||||
|
- 2026/04:
|
||||||
|
- New model support has been added in Axolotl for [Mistral Medium 3.5](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral-medium-3_5) and [Gemma 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/gemma4).
|
||||||
|
- Axolotl is now [uv-first](https://github.com/axolotl-ai-cloud/axolotl/pull/3545) and has [SonicMoE fused LoRA](https://github.com/axolotl-ai-cloud/axolotl/pull/3519) support.
|
||||||
- 2026/03:
|
- 2026/03:
|
||||||
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
|
- New model support has been added in Axolotl for [Mistral Small 4](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/mistral4), [Qwen3.5, Qwen3.5 MoE](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/qwen3.5), [GLM-4.7-Flash](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm47-flash), [GLM-4.6V](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm46v), and [GLM-4.5-Air](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/glm45).
|
||||||
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
|
- [MoE expert quantization](https://docs.axolotl.ai/docs/expert_quantization.html) support (via `quantize_moe_experts: true`) greatly reduces VRAM when training MoE models (FSDP2 compat).
|
||||||
@@ -86,7 +89,7 @@ Features:
|
|||||||
**Requirements**:
|
**Requirements**:
|
||||||
|
|
||||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python 3.11
|
- Python >=3.11 (3.12 recommended)
|
||||||
- PyTorch ≥2.9.1
|
- PyTorch ≥2.9.1
|
||||||
|
|
||||||
### Google Colab
|
### Google Colab
|
||||||
@@ -95,11 +98,19 @@ Features:
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
#### Using pip
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
|
# install uv if you don't already have it installed (restart shell after)
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
|
||||||
|
# change depending on system
|
||||||
|
export UV_TORCH_BACKEND=cu128
|
||||||
|
|
||||||
|
# create a new virtual environment
|
||||||
|
uv venv --python 3.12
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
uv pip install torch==2.10.0 torchvision
|
||||||
|
uv pip install --no-build-isolation axolotl[deepspeed]
|
||||||
|
|
||||||
# Download example axolotl configs, deepspeed configs
|
# Download example axolotl configs, deepspeed configs
|
||||||
axolotl fetch examples
|
axolotl fetch examples
|
||||||
@@ -110,7 +121,7 @@ axolotl fetch deepspeed_configs # OPTIONAL
|
|||||||
|
|
||||||
Installing with Docker can be less error prone than installing in your own environment.
|
Installing with Docker can be less error prone than installing in your own environment.
|
||||||
```bash
|
```bash
|
||||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
docker run --gpus '"all"' --ipc=host --rm -it axolotlai/axolotl:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
||||||
@@ -157,6 +168,29 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
|
|||||||
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
|
- [API Reference](https://docs.axolotl.ai/docs/api/) - Auto-generated code documentation
|
||||||
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
|
- [FAQ](https://docs.axolotl.ai/docs/faq.html) - Frequently asked questions
|
||||||
|
|
||||||
|
## AI Agent Support
|
||||||
|
|
||||||
|
Axolotl ships with built-in documentation optimized for AI coding agents (Claude Code, Cursor, Copilot, etc.). These docs are bundled with the pip package — no repo clone needed.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Show overview and available training methods
|
||||||
|
axolotl agent-docs
|
||||||
|
|
||||||
|
# Topic-specific references
|
||||||
|
axolotl agent-docs sft # supervised fine-tuning
|
||||||
|
axolotl agent-docs grpo # GRPO online RL
|
||||||
|
axolotl agent-docs preference_tuning # DPO, KTO, ORPO, SimPO
|
||||||
|
axolotl agent-docs reward_modelling # outcome and process reward models
|
||||||
|
axolotl agent-docs pretraining # continual pretraining
|
||||||
|
axolotl agent-docs --list # list all topics
|
||||||
|
|
||||||
|
# Dump config schema for programmatic use
|
||||||
|
axolotl config-schema
|
||||||
|
axolotl config-schema --field adapter
|
||||||
|
```
|
||||||
|
|
||||||
|
If you're working with the source repo, agent docs are also available at `docs/agents/` and the project overview is in `AGENTS.md`.
|
||||||
|
|
||||||
## 🤝 Getting Help
|
## 🤝 Getting Help
|
||||||
|
|
||||||
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
|
- Join our [Discord community](https://discord.gg/HhrNrHJPRb) for support
|
||||||
|
|||||||
@@ -134,7 +134,6 @@ quartodoc:
|
|||||||
- monkeypatch.stablelm_attn_hijack_flash
|
- monkeypatch.stablelm_attn_hijack_flash
|
||||||
- monkeypatch.trainer_fsdp_optim
|
- monkeypatch.trainer_fsdp_optim
|
||||||
- monkeypatch.transformers_fa_utils
|
- monkeypatch.transformers_fa_utils
|
||||||
- monkeypatch.unsloth_
|
|
||||||
- monkeypatch.data.batch_dataset_fetcher
|
- monkeypatch.data.batch_dataset_fetcher
|
||||||
- monkeypatch.mixtral
|
- monkeypatch.mixtral
|
||||||
- monkeypatch.gradient_checkpointing.offload_cpu
|
- monkeypatch.gradient_checkpointing.offload_cpu
|
||||||
@@ -312,6 +311,7 @@ website:
|
|||||||
- docs/dataset_loading.qmd
|
- docs/dataset_loading.qmd
|
||||||
- docs/qat.qmd
|
- docs/qat.qmd
|
||||||
- docs/quantize.qmd
|
- docs/quantize.qmd
|
||||||
|
- docs/1_58bit_finetuning.qmd
|
||||||
- docs/optimizations.qmd
|
- docs/optimizations.qmd
|
||||||
|
|
||||||
- section: "Core Concepts"
|
- section: "Core Concepts"
|
||||||
@@ -327,7 +327,6 @@ website:
|
|||||||
- section: "Advanced Features"
|
- section: "Advanced Features"
|
||||||
contents:
|
contents:
|
||||||
- docs/fsdp_qlora.qmd
|
- docs/fsdp_qlora.qmd
|
||||||
- docs/unsloth.qmd
|
|
||||||
- docs/torchao.qmd
|
- docs/torchao.qmd
|
||||||
- docs/custom_integrations.qmd
|
- docs/custom_integrations.qmd
|
||||||
- docs/sequence_parallelism.qmd
|
- docs/sequence_parallelism.qmd
|
||||||
|
|||||||
@@ -22,15 +22,6 @@ WORKDIR /workspace/axolotl
|
|||||||
RUN git fetch origin +$GITHUB_REF && \
|
RUN git fetch origin +$GITHUB_REF && \
|
||||||
git checkout FETCH_HEAD
|
git checkout FETCH_HEAD
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
|
||||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|
||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
RUN uv pip install packaging==26.0 setuptools==78.1.1
|
RUN uv pip install packaging==26.0 setuptools==78.1.1
|
||||||
RUN uv pip install torchvision
|
RUN uv pip install torchvision
|
||||||
RUN uv pip uninstall causal_conv1d
|
RUN uv pip uninstall causal_conv1d
|
||||||
@@ -40,11 +31,21 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|||||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py --uv | sh
|
# Override with nightly HF packages for nightly builds
|
||||||
|
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||||
|
uv pip install --no-deps \
|
||||||
|
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
||||||
|
"peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||||
|
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||||
|
"trl @ git+https://github.com/huggingface/trl.git@main" \
|
||||||
|
"datasets @ git+https://github.com/huggingface/datasets.git@main"; \
|
||||||
|
fi
|
||||||
|
|
||||||
RUN python scripts/cutcrossentropy_install.py --uv | sh
|
RUN python scripts/cutcrossentropy_install.py --uv | sh
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
|
RUN uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
|
|||||||
@@ -1,54 +0,0 @@
|
|||||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
|
||||||
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
|
||||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
|
||||||
ENV CUDA="{{ CUDA }}"
|
|
||||||
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
|
||||||
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
|
||||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
|
||||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
|
||||||
ENV HF_HOME="{{ HF_HOME }}"
|
|
||||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
|
||||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
|
||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
|
||||||
|
|
||||||
RUN git fetch origin +$GITHUB_REF && \
|
|
||||||
git checkout FETCH_HEAD
|
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
|
||||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|
||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
|
|
||||||
RUN pip uninstall -y causal_conv1d
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
|
||||||
else \
|
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
|
||||||
RUN python scripts/cutcrossentropy_install.py | sh
|
|
||||||
|
|
||||||
# So we can test the Docker image
|
|
||||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
|
||||||
git config --get remote.origin.fetch
|
|
||||||
|
|
||||||
# helper for huggingface-login cli
|
|
||||||
RUN git config --global credential.helper store
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__, f'Expected torch $PYTORCH_VERSION but got {torch.__version__}'"
|
||||||
|
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
for i in 1 2 3; do
|
for i in 1 2 3; do
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
|||||||
template_env = jinja2.Environment(
|
template_env = jinja2.Environment(
|
||||||
loader=template_loader, autoescape=select_autoescape()
|
loader=template_loader, autoescape=select_autoescape()
|
||||||
)
|
)
|
||||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
|
||||||
df_template = template_env.get_template(dockerfile)
|
df_template = template_env.get_template(dockerfile)
|
||||||
|
|
||||||
df_args = {
|
df_args = {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
|||||||
template_env = jinja2.Environment(
|
template_env = jinja2.Environment(
|
||||||
loader=template_loader, autoescape=select_autoescape()
|
loader=template_loader, autoescape=select_autoescape()
|
||||||
)
|
)
|
||||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
|
||||||
df_template = template_env.get_template(dockerfile)
|
df_template = template_env.get_template(dockerfile)
|
||||||
|
|
||||||
df_args = {
|
df_args = {
|
||||||
|
|||||||
@@ -24,15 +24,15 @@ WORKDIR /workspace/axolotl
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||||
RUN pip uninstall -y causal_conv1d
|
RUN pip uninstall -y causal_conv1d
|
||||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
BASE_EXTRAS="optimizers,ray"; \
|
||||||
else \
|
else \
|
||||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
BASE_EXTRAS="deepspeed,optimizers,ray"; \
|
||||||
fi && \
|
fi && \
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
fi && \ python scripts/unsloth_install.py | sh && \
|
fi && \
|
||||||
python scripts/cutcrossentropy_install.py | sh && \
|
python scripts/cutcrossentropy_install.py | sh && \
|
||||||
pip install pytest && \
|
pip install pytest && \
|
||||||
pip cache purge
|
pip cache purge
|
||||||
|
|||||||
@@ -58,19 +58,3 @@ RUN git lfs install --skip-repo && \
|
|||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||||
pip3 cache purge
|
pip3 cache purge
|
||||||
|
|
||||||
# Map Python version (e.g., 3.12 -> cp312)
|
|
||||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
|
||||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
|
||||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
|
||||||
# Map architecture
|
|
||||||
case "$TARGETARCH" in \
|
|
||||||
amd64) ARCH_TAG="x86_64" ;; \
|
|
||||||
arm64) ARCH_TAG="aarch64" ;; \
|
|
||||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
|
||||||
esac && \
|
|
||||||
WHL_VERSION="v0.7.16" && \
|
|
||||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
|
||||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
|
||||||
pip3 install --no-cache-dir "${WHL_FILE}" && \
|
|
||||||
rm "${WHL_FILE}"
|
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -24,16 +24,15 @@ WORKDIR /workspace/axolotl
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||||
RUN uv pip uninstall causal_conv1d
|
RUN uv pip uninstall causal_conv1d
|
||||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
BASE_EXTRAS="optimizers,ray"; \
|
||||||
else \
|
else \
|
||||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
BASE_EXTRAS="deepspeed,optimizers,ray"; \
|
||||||
fi && \
|
fi && \
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
fi && \
|
fi && \
|
||||||
python scripts/unsloth_install.py --uv | sh && \
|
|
||||||
python scripts/cutcrossentropy_install.py --uv | sh && \
|
python scripts/cutcrossentropy_install.py --uv | sh && \
|
||||||
uv pip install pytest && \
|
uv pip install pytest && \
|
||||||
uv cache clean
|
uv cache clean
|
||||||
|
|||||||
@@ -38,20 +38,3 @@ RUN uv pip install packaging setuptools wheel psutil \
|
|||||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||||
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
|
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Map Python version (e.g., 3.12 -> cp312)
|
|
||||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
|
||||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
|
||||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
|
||||||
LINUX_TAG="manylinux_" && \
|
|
||||||
# Map architecture
|
|
||||||
case "$TARGETARCH" in \
|
|
||||||
amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
|
|
||||||
arm64) ARCH_TAG="2_34_aarch64" ;; \
|
|
||||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
|
||||||
esac && \
|
|
||||||
WHL_VERSION="v0.7.16" && \
|
|
||||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
|
|
||||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
|
||||||
uv pip install --no-cache-dir "${WHL_FILE}" && \
|
|
||||||
rm "${WHL_FILE}"
|
|
||||||
|
|||||||
70
docs/1_58bit_finetuning.qmd
Normal file
70
docs/1_58bit_finetuning.qmd
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
---
|
||||||
|
title: "1.58-bit Finetuning"
|
||||||
|
back-to-top-navigation: true
|
||||||
|
toc: true
|
||||||
|
toc-expand: 2
|
||||||
|
toc-depth: 4
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
1.58-bit finetuning allows you to finetune BitNet models when their prequantized weights are provided. In theory, it will be possible to fine-tune any LLM in 1.58bit format but the performance degradation will be dramatic.
|
||||||
|
|
||||||
|
Axolotl supports 1.58-bit finetuning via the [`onebitllms`](https://github.com/tiiuae/onebitllms) library, which replaces standard linear layers with BitNet-compatible counterparts ready to use for training.
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
LoRA is not supported for BitNet models
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Install the `onebitllms` package before using this feature:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install onebitllms
|
||||||
|
```
|
||||||
|
|
||||||
|
Or from source:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install git+https://github.com/tiiuae/onebitllms
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported models
|
||||||
|
|
||||||
|
For now, only `Falcon-E` series of models are supported. Make sure to use their `-prequantized` version:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tiiuae/Falcon-E-3B-Base-prequantized
|
||||||
|
tiiuae/Falcon-E-1B-Base-prequantized
|
||||||
|
```
|
||||||
|
|
||||||
|
In theory, any other model would 'work' but the performance degradation will be huge. This remains an area of exploration.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
To enable 1.58-bit finetuning, set the following in your configuration file:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: tiiuae/Falcon-E-3B-Base-prequantized # A BitNet-compatible model
|
||||||
|
|
||||||
|
use_onebitllms: true
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
For BitNet models, it is recommended to use a higher learning rate than classic models (usually in the order of magnitude of 10x).
|
||||||
|
:::
|
||||||
|
|
||||||
|
## Considerations after training
|
||||||
|
|
||||||
|
Once your model has been trained with 1.58bit fine-tuning, you can convert the trained model in ternary format using the `onebitllms` CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
onebitllms quantize_to_1bit INPUT_PATH OUTPUT_PATH
|
||||||
|
```
|
||||||
|
|
||||||
|
After that, you can use supported packages such as `llama.cpp` or Apple MLX package to run the trained model.
|
||||||
|
|
||||||
|
## Example Configuration
|
||||||
|
|
||||||
|
You can find example configurations in `examples/falcon-e` which contain one configuration for SFT and one configuration for DPO.
|
||||||
198
docs/agents/model_architectures.md
Normal file
198
docs/agents/model_architectures.md
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Model Architectures — Agent Reference
|
||||||
|
|
||||||
|
Model-specific quirks, required settings, and known issues. Check this before debugging training failures on specific model families.
|
||||||
|
|
||||||
|
## VLM (Vision Language Model) Quick Start
|
||||||
|
|
||||||
|
All VLM configs require these four lines:
|
||||||
|
```yaml
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
```
|
||||||
|
|
||||||
|
Decision tree for VLM config:
|
||||||
|
```text
|
||||||
|
Is the model multimodal (has vision/audio encoder)?
|
||||||
|
├─ YES: Add `freeze_mm_modules: true` if training text only
|
||||||
|
│ Add `chat_template: <model_template>` (e.g. gemma4, qwen3_5, gemma3)
|
||||||
|
│ LoRA: use regex `lora_target_modules` to restrict to language model
|
||||||
|
└─ NO: Train as a regular text model
|
||||||
|
|
||||||
|
Is the model MoE (e.g. Gemma4 26B-A4B, Qwen3.5 35B-A3B)?
|
||||||
|
├─ YES: Add `lora_target_parameters` for expert LoRA
|
||||||
|
│ Consider ScatterMoE kernels (see Plugins section)
|
||||||
|
└─ NO: Standard LoRA config
|
||||||
|
```
|
||||||
|
|
||||||
|
## Plugins & Optimizations
|
||||||
|
|
||||||
|
### Cut Cross Entropy (CCE)
|
||||||
|
|
||||||
|
Computes loss from hidden states + lm_head weight without materializing the full logits tensor, saving significant VRAM. Install if not already present:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
```
|
||||||
|
|
||||||
|
### ScatterMoE Kernels
|
||||||
|
|
||||||
|
Fuses expert + LoRA computation into a single kernel for MoE models. Significant speedup for models with many experts.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
|
||||||
|
# Expert LoRA targets (3D parameter tensors, not nn.Linear):
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported: Gemma4 (`gemma4_text`), Mixtral, Qwen MoE variants. The plugin auto-detects model type and routing function. Without ScatterMoE, expert LoRA still works but runs base expert matmul and LoRA as separate operations.
|
||||||
|
|
||||||
|
## Gemma 4
|
||||||
|
|
||||||
|
**Models**: `google/gemma-4-26B-A4B` (MoE), `google/gemma-4-31B` (dense), `google/gemma-4-E2B`, `google/gemma-4-E4B`
|
||||||
|
|
||||||
|
**Architecture**: Multimodal wrapper (`Gemma4ForConditionalGeneration`) over a text backbone (`Gemma4TextModel`), with optional vision/audio encoders. All Gemma4 HF repos have `model_type: "gemma4"` — even text-only variants load as multimodal with a vision tower.
|
||||||
|
|
||||||
|
### Required settings
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Always needed for Gemma4:
|
||||||
|
freeze_mm_modules: true # Freeze vision/audio encoders for text-only training
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false # Shared per-layer norms cause "marked ready twice" with reentrant
|
||||||
|
|
||||||
|
# LoRA target — restrict to language model only (DO NOT use lora_target_linear: true):
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Auto-detection
|
||||||
|
|
||||||
|
Axolotl auto-detects Gemma4 and applies:
|
||||||
|
- `use_reentrant: false` for gradient checkpointing
|
||||||
|
- `ddp_find_unused_parameters: true` for DDP (skipped when `activation_offloading: true`)
|
||||||
|
|
||||||
|
### Multi-GPU
|
||||||
|
|
||||||
|
| Strategy | Works? | Notes |
|
||||||
|
|----------|--------|-------|
|
||||||
|
| DDP | Yes | Auto-sets `ddp_find_unused_parameters=True` |
|
||||||
|
| DDP + activation_offloading | Yes | `find_unused_parameters` is skipped (conflicts with checkpoint wrappers) |
|
||||||
|
| FSDP1 | No | OOM during dequantization/sharding with QLoRA |
|
||||||
|
| FSDP2 | Yes | Use `Gemma4TextDecoderLayer` (not `Gemma4DecoderLayer`) as wrap class |
|
||||||
|
| FSDP2 + activation_offloading | Yes | Lowest VRAM (~26 GiB/GPU for 26B-A4B) |
|
||||||
|
|
||||||
|
FSDP2 config:
|
||||||
|
```yaml
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_version: 2
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer
|
||||||
|
```
|
||||||
|
|
||||||
|
### MoE (26B-A4B)
|
||||||
|
|
||||||
|
- `enable_moe_block: true`, 256 experts, top-k routing
|
||||||
|
- No separate `SparseMoeBlock` — MoE is embedded in each decoder layer
|
||||||
|
- Expert LoRA targets 3D parameter tensors:
|
||||||
|
```yaml
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
- ScatterMoE kernel acceleration:
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
```
|
||||||
|
|
||||||
|
### VLM (Vision) Training
|
||||||
|
|
||||||
|
All Gemma4 models load as `Gemma4ForConditionalGeneration` with a vision tower. No custom `ProcessingStrategy` needed — the base class auto-detects the image token.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
freeze_mm_modules: true
|
||||||
|
chat_template: gemma4
|
||||||
|
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
```
|
||||||
|
|
||||||
|
A starting VLM loss of ~8-15 is typical. In most runs, loss converges below 1.0 within ~30-50 steps, though results may vary across configurations.
|
||||||
|
|
||||||
|
For the 26B-A4B MoE variant with ScatterMoE + expert LoRA + CCE, add:
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
|
||||||
|
### Common issues
|
||||||
|
|
||||||
|
| Symptom | Cause | Fix |
|
||||||
|
|---------|-------|-----|
|
||||||
|
| `mm_token_type_ids is required` in DDP | `model.config` not accessible through DDP wrapper | Already fixed — `unwrap_model()` in `compute_loss` and `prediction_step` |
|
||||||
|
| `marked a variable ready twice` in DDP | `ddp_find_unused_parameters=True` + activation_offloading checkpoint wrappers | Auto-handled — `find_unused_parameters` is skipped when `activation_offloading: true` |
|
||||||
|
| Loss ~12 instead of ~0.5 | Using `lora_target_linear: true` (applies LoRA to vision/audio modules) | Use the regex `lora_target_modules` pattern instead |
|
||||||
|
| FSDP2 `Could not find Gemma4AudioLayer` | Auto-wrap detects `_no_split_modules` including audio layers that don't exist | Explicitly set `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer` |
|
||||||
|
| `Gemma4ClippableLinear not supported` by PEFT | Vision tower uses a non-standard linear wrapper | Axolotl patches this automatically via `_patch_peft_clippable_linear()` |
|
||||||
|
|
||||||
|
### E2B/E4B dense models
|
||||||
|
|
||||||
|
These have `hidden_size_per_layer_input: 256` (per-layer input embeddings) and `attention_k_eq_v: False`. Known issue: loss starts higher than expected (~12 vs ~0.5 for 26B). Root cause under investigation — may be related to the per-layer input mechanism or the `Gemma4ForConditionalGeneration` loss computation.
|
||||||
|
|
||||||
|
## Gemma 3
|
||||||
|
|
||||||
|
**Models**: `google/gemma-3-*`
|
||||||
|
|
||||||
|
- `ddp_find_unused_parameters: true` needed (multimodal unused params)
|
||||||
|
- `use_reentrant: false` recommended
|
||||||
|
- Attention mask must be dropped for sample packing (handled automatically)
|
||||||
|
- Multi-GPU test currently skipped (`tests/e2e/multigpu/test_gemma3.py`)
|
||||||
|
|
||||||
|
## Qwen 3.5 MoE
|
||||||
|
|
||||||
|
**Models**: `Qwen/Qwen3.5-35B-A3B`
|
||||||
|
|
||||||
|
- Hybrid architecture: DeltaNet linear attention (30 layers) + full attention (10 layers)
|
||||||
|
- 256 experts, 8 active per token
|
||||||
|
- Known weight scale drift in late DeltaNet layers (36-38) due to AdamW + rare expert interaction
|
||||||
|
- Fix: `normalize_weight_scales` config to detect and rescale outliers:
|
||||||
|
```yaml
|
||||||
|
normalize_weight_scales:
|
||||||
|
- name_pattern: 'linear_attn\.conv1d\.weight'
|
||||||
|
threshold: 1.3
|
||||||
|
```
|
||||||
|
|
||||||
|
## General MoE Notes
|
||||||
|
|
||||||
|
- `lora_target_linear: true` with multimodal MoE models will apply LoRA to ALL linear modules including vision/audio encoders — use regex `lora_target_modules` to restrict to language model only
|
||||||
|
- Rare experts get larger effective learning rate from AdamW (small second-moment estimates) — can cause weight drift in recurrent/SSM components. Use `normalize_weight_scales` with `dry_run: true` to detect.
|
||||||
|
- For ScatterMoE kernel support, set `experts_implementation: scattermoe` and add the KernelsPlugin
|
||||||
181
docs/agents/new_model_support.md
Normal file
181
docs/agents/new_model_support.md
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
# New Model Support — Agent Reference
|
||||||
|
|
||||||
|
Guide for debugging and adding support for new model architectures in axolotl. Based on lessons learned from Gemma4, Gemma3, Qwen2-VL, and other multimodal/MoE models.
|
||||||
|
|
||||||
|
## Quick Validation Checklist
|
||||||
|
|
||||||
|
When testing a new model, run through these checks in order:
|
||||||
|
|
||||||
|
1. **Does the model load?** `axolotl preprocess config.yaml` — catches config schema errors
|
||||||
|
2. **Does LoRA apply?** Check for "Unsupported layer type" warnings from PEFT
|
||||||
|
3. **Is the initial loss sane?** First-step loss for a pretrained model should be 0.5–2.0 for SFT
|
||||||
|
4. **Does sample packing work?** Compare loss with `sample_packing: true` vs `false` — should be similar
|
||||||
|
5. **Is CCE active?** Check for "Applying Cut Cross Entropy" log and verify peak VRAM is lower
|
||||||
|
|
||||||
|
## Loss Debugging
|
||||||
|
|
||||||
|
### Expected initial loss
|
||||||
|
A pretrained model doing SFT should start with loss roughly in the 0.5–2.0 range. If loss starts above 3.0, something is wrong. If it's near `log(vocab_size)` (≈ 12 for 262K vocab), the model is predicting at random — attention masking or model weights are broken.
|
||||||
|
|
||||||
|
### Direct comparison technique
|
||||||
|
The fastest way to isolate a loss issue — bypass the trainer entirely:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Load model via axolotl's pipeline (applies all patches)
|
||||||
|
from axolotl.cli.config import load_cfg
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
|
from axolotl.loaders.tokenizer import load_tokenizer
|
||||||
|
from axolotl.loaders.model import ModelLoader
|
||||||
|
|
||||||
|
cfg = load_cfg("your_config.yaml")
|
||||||
|
normalize_config(cfg)
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
tokenizer = load_tokenizer(cfg)
|
||||||
|
model, _ = ModelLoader(cfg, tokenizer).load()
|
||||||
|
|
||||||
|
# Forward pass on preprocessed data
|
||||||
|
model.train()
|
||||||
|
out = model(input_ids, labels=labels)
|
||||||
|
print(f"Direct loss: {out.loss.item()}") # Compare to trainer's reported loss
|
||||||
|
```
|
||||||
|
|
||||||
|
If direct loss is correct (~1.0) but trainer reports 3–4x higher, check `model_accepts_loss_kwargs` (see below).
|
||||||
|
|
||||||
|
### `model_accepts_loss_kwargs` inflation
|
||||||
|
HF Trainer checks if the model's `forward()` has `**kwargs` and sets `model_accepts_loss_kwargs=True`. This changes loss normalization: the trainer does NOT divide loss by `gradient_accumulation_steps` before logging. The gradient is correct — only the logged loss is inflated.
|
||||||
|
|
||||||
|
**Symptom**: Logged loss ≈ actual_loss × gradient_accumulation_steps.
|
||||||
|
|
||||||
|
**Which models are affected**: Any model with `**kwargs` in forward (common in multimodal models for extra inputs like `mm_token_type_ids`, `pixel_values`, etc.).
|
||||||
|
|
||||||
|
**Fix location**: `src/axolotl/core/trainers/base.py` `__init__()` — after `super().__init__()`, check if the unwrapped model actually has `num_items_in_batch` in its forward signature. If not, set `self.model_accepts_loss_kwargs = False`.
|
||||||
|
|
||||||
|
## Multimodal Models (ForConditionalGeneration)
|
||||||
|
|
||||||
|
Many recent models use `ForConditionalGeneration` as the top-level class, not `ForCausalLM`:
|
||||||
|
- Gemma3 → `Gemma3ForConditionalGeneration`
|
||||||
|
- Gemma4 → `Gemma4ForConditionalGeneration`
|
||||||
|
- Qwen2-VL → `Qwen2VLForConditionalGeneration`
|
||||||
|
- LLaVA → `LlavaForConditionalGeneration`
|
||||||
|
|
||||||
|
### Why this matters
|
||||||
|
|
||||||
|
| Component | Targets `ForCausalLM` | Needs `ForConditionalGeneration` |
|
||||||
|
|-----------|----------------------|--------------------------------|
|
||||||
|
| CCE patches | ✅ (default) | ❌ silently inactive if not patched |
|
||||||
|
| PEFT LoRA | ✅ | May fail on custom layer types |
|
||||||
|
| HF Trainer label handling | ✅ | May need extra inputs |
|
||||||
|
|
||||||
|
### Required extra inputs
|
||||||
|
Multimodal models require special inputs during training even for text-only data:
|
||||||
|
|
||||||
|
| Model | Required Input | Value for Text-Only |
|
||||||
|
|-------|---------------|-------------------|
|
||||||
|
| Gemma4 | `mm_token_type_ids` | `torch.zeros_like(input_ids)` |
|
||||||
|
| Gemma3 | `token_type_ids` | `torch.zeros_like(input_ids)` |
|
||||||
|
|
||||||
|
Auto-inject in `compute_loss()` when not provided by the data collator. See `core/trainers/base.py`.
|
||||||
|
|
||||||
|
### Custom layer types and PEFT
|
||||||
|
Vision towers often use custom module wrappers that PEFT doesn't support:
|
||||||
|
|
||||||
|
| Model | Custom Layer | Wraps | Fix |
|
||||||
|
|-------|-------------|-------|-----|
|
||||||
|
| Gemma4 | `Gemma4ClippableLinear` | `nn.Linear` | Redirect to `.linear` child |
|
||||||
|
|
||||||
|
Fix location: `src/axolotl/loaders/adapter.py` `_patch_peft_clippable_linear()`.
|
||||||
|
|
||||||
|
## Sample Packing
|
||||||
|
|
||||||
|
### How packed sequence detection works (transformers ≥ 5.x)
|
||||||
|
`transformers.masking_utils._preprocess_mask_arguments()` detects packed sequences from `position_ids` resets. But **only when `attention_mask is None`**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# From masking_utils.py:
|
||||||
|
if position_ids is not None and attention_mask is None and past_key_values is None:
|
||||||
|
packed_sequence_mask = find_packed_sequence_indices(position_ids)
|
||||||
|
```
|
||||||
|
|
||||||
|
If the collator provides an all-ones `attention_mask`, packing detection is **skipped** and the model builds a single causal mask spanning all packed sequences → cross-sequence attention leakage → very high loss.
|
||||||
|
|
||||||
|
### Fix for models using `create_causal_mask_mapping`
|
||||||
|
For Gemma3, Gemma4, and similar models that use the new transformers masking system, remove `attention_mask` from inputs when sample packing is active:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# In compute_loss():
|
||||||
|
if (
|
||||||
|
self.args.sample_packing
|
||||||
|
and model_type in ("gemma4", "gemma3")
|
||||||
|
and "attention_mask" in inputs
|
||||||
|
and "position_ids" in inputs
|
||||||
|
):
|
||||||
|
del inputs["attention_mask"]
|
||||||
|
```
|
||||||
|
|
||||||
|
Fix location: `src/axolotl/core/trainers/base.py` `compute_loss()`.
|
||||||
|
|
||||||
|
### Models that DON'T need this fix
|
||||||
|
Older models that use `_prepare_4d_causal_attention_mask` (Llama, Mistral, Qwen2, etc.) handle sample packing via axolotl's multipack attention monkeypatch instead. Only models using the new `create_causal_mask_mapping` / `create_causal_mask` masking system need the `attention_mask` removal.
|
||||||
|
|
||||||
|
## Attention Backend Selection
|
||||||
|
|
||||||
|
| Backend | Config | head_dim limit | torch_compile | Notes |
|
||||||
|
|---------|--------|---------------|---------------|-------|
|
||||||
|
| FA2 | `attn_implementation: flash_attention_2` | 256 | ✅ | Fastest when supported |
|
||||||
|
| FA4 | auto with `attn_implementation: flash_attention_2` | 256 (SM90+) | ✅ | Auto-detected on H100+ |
|
||||||
|
| SDPA | `attn_implementation: sdpa` | None | ✅ | Universal fallback |
|
||||||
|
| flex | `attn_implementation: flex_attention` | None | ⚠️ Triton OOM for large head_dim | Good for variable head dims |
|
||||||
|
| eager | `attn_implementation: eager` | None | ✅ | Slowest, always works |
|
||||||
|
|
||||||
|
**Check model support**: Look at `_supports_flash_attn_2`, `_supports_flex_attn`, `_supports_sdpa` attributes on the model class.
|
||||||
|
|
||||||
|
**head_dim gotcha**: The 256 limit is specific to flash-attn CUDA kernels, NOT PyTorch-level. SDPA and flex_attention both handle arbitrary head_dim. Models with `global_head_dim > 256` (Gemma4: 512) must use SDPA or flex.
|
||||||
|
|
||||||
|
**flex + compile gotcha**: `torch_compile` with flex_attention can hit Triton shared memory OOM for large head_dim. Falls back to eager per-function (not a crash, but slower). Unsloth disables flex for Gemma4 for this reason.
|
||||||
|
|
||||||
|
## Cut Cross Entropy (CCE)
|
||||||
|
|
||||||
|
### How CCE patches work
|
||||||
|
CCE replaces the model's `forward()` with a fused version that computes loss from hidden states + lm_head weight without materializing the full logits tensor. This saves ~`batch × seq_len × vocab_size × dtype_bytes` of VRAM.
|
||||||
|
|
||||||
|
### Adding CCE for a new model
|
||||||
|
1. Check if the model type is in `cut_cross_entropy.transformers.patch.PATCH_FNS`
|
||||||
|
2. If not, axolotl's generic fallback (`integrations/cut_cross_entropy/__init__.py` `patch_llama_like()`) patches `{Prefix}ForCausalLM.forward` with `cce_forward`
|
||||||
|
3. For multimodal models (`ForConditionalGeneration`), a model-specific patch is needed in `ml-cross-entropy` repo
|
||||||
|
4. The multimodal `cce_forward` must accept all extra kwargs (pixel_values, mm_token_type_ids, etc.) and pop any that would conflict before calling `self.model()`
|
||||||
|
|
||||||
|
### Common CCE pitfall
|
||||||
|
If CCE appears active (log says "Applying Cut Cross Entropy") but peak VRAM doesn't decrease, check which class was patched. If the model loads as `ForConditionalGeneration` but CCE patched `ForCausalLM`, the patch is silently inactive.
|
||||||
|
|
||||||
|
## MoE Models
|
||||||
|
|
||||||
|
### Dense MLP vs MoE experts
|
||||||
|
Some MoE models (e.g., Gemma4) have BOTH dense MLP layers and MoE expert layers at every decoder layer:
|
||||||
|
- `gate_proj/up_proj/down_proj` → targets the **dense MLP** (`Gemma4TextMLP`)
|
||||||
|
- `experts.gate_up_proj/experts.down_proj` → targets the **MoE experts** (`Gemma4TextExperts`)
|
||||||
|
|
||||||
|
LoRA on the dense MLP works normally. Expert LoRA via `lora_target_parameters` requires PEFT support for the specific expert module type (may warn "Unsupported layer type").
|
||||||
|
|
||||||
|
### ScatterMoE kernels
|
||||||
|
`use_scattermoe: true` with `experts_implementation: scattermoe` registers fused expert kernels via transformers' `ExpertsInterface`. Significant speedup for MoE models. Requires the kernels plugin:
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
```
|
||||||
|
|
||||||
|
## Where to Add Model-Specific Fixes
|
||||||
|
|
||||||
|
| What | Where | Example |
|
||||||
|
|------|-------|---------|
|
||||||
|
| Missing forward inputs | `core/trainers/base.py` `compute_loss()` | mm_token_type_ids injection |
|
||||||
|
| Attention mask fixes | `core/trainers/base.py` `compute_loss()` | Sample packing mask removal |
|
||||||
|
| Loss logging fixes | `core/trainers/base.py` `__init__()` | model_accepts_loss_kwargs override |
|
||||||
|
| PEFT/LoRA patches | `loaders/adapter.py` | ClippableLinear redirect |
|
||||||
|
| Attention patches | `monkeypatch/attention/` | FA4 tuple fix |
|
||||||
|
| Model-specific patches | `loaders/patch_manager.py` `_apply_model_specific_patches()` | Llama4, Kimi, NemotronH |
|
||||||
|
| CCE patches | `ml-cross-entropy` repo `transformers/` | Per-model cce_forward |
|
||||||
|
| Example configs | `examples/<model>/` | Validated YAML |
|
||||||
|
| Config validation | `utils/schemas/validation.py` | Compatibility checks |
|
||||||
@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da
|
|||||||
|
|
||||||
1. Paired preference data (chosen + rejected)?
|
1. Paired preference data (chosen + rejected)?
|
||||||
- Default → `rl: dpo`
|
- Default → `rl: dpo`
|
||||||
- Overfitting → `rl: ipo`
|
- Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
|
||||||
- VRAM-limited → `rl: orpo` (no ref model)
|
- VRAM-limited → `rl: orpo` (no ref model)
|
||||||
- Length-sensitive → `rl: simpo` (no ref model)
|
- Length-sensitive → `rl: simpo` (no ref model)
|
||||||
2. Only binary labels (good/bad)? → `rl: kto`
|
2. Only binary labels (good/bad)? → `rl: kto`
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
|
|||||||
| Issue | Fix |
|
| Issue | Fix |
|
||||||
|-------|-----|
|
|-------|-----|
|
||||||
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
|
| OOM during training | Reduce `micro_batch_size`, enable `gradient_checkpointing`, reduce `sequence_len` |
|
||||||
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `flash_attention: true` or disable `sample_packing` |
|
| `sample_packing` + SDPA + bf16 = 0.0 loss | Use `attn_implementation: flash_attention_2` or disable `sample_packing` |
|
||||||
| Missing chat template error | Set `chat_template: chatml` explicitly |
|
| Missing chat template error | Set `chat_template: chatml` explicitly |
|
||||||
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
|
| Label masking wrong | Run `axolotl preprocess config.yaml --debug` and inspect labels |
|
||||||
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
|
| Loss NaN | Use `bf16: auto`, lower LR, check data for empty samples |
|
||||||
@@ -91,6 +91,30 @@ Watch for: loss never decreasing (check `train_on_inputs`, dataset, LR), loss go
|
|||||||
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
|
| FSDP save hangs | Use `fsdp_state_dict_type: FULL_STATE_DICT` |
|
||||||
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
|
| DeepSpeed CheckpointError | Set `use_reentrant: true` in `gradient_checkpointing_kwargs` |
|
||||||
|
|
||||||
|
## Profiling
|
||||||
|
|
||||||
|
To profile training and identify optimization opportunities:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Profile steps 3-7 (after warmup/autotuning settles)
|
||||||
|
profiler_steps_start: 3
|
||||||
|
profiler_steps: 5
|
||||||
|
```
|
||||||
|
|
||||||
|
This produces `profiler_trace.json` (Chrome trace) and `snapshot.pickle` (memory snapshot) in `output_dir`.
|
||||||
|
View the Chrome trace at `chrome://tracing`.
|
||||||
|
|
||||||
|
To programmatically inspect the trace:
|
||||||
|
```bash
|
||||||
|
python scripts/analyze_profile.py output_dir/
|
||||||
|
```
|
||||||
|
|
||||||
|
The trace shows per-kernel CUDA times, memory allocations, and operator-level breakdown. Look for:
|
||||||
|
- **Large matmul kernels**: candidates for fusion or quantization
|
||||||
|
- **Memory copies (H2D/D2H)**: unnecessary data movement
|
||||||
|
- **Small frequent kernels**: candidates for kernel fusion
|
||||||
|
- **Gaps between kernels**: pipeline bubbles from CPU overhead
|
||||||
|
|
||||||
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
|
Full troubleshooting: [training_stability.qmd](../training_stability.qmd), [debugging.qmd](../debugging.qmd)
|
||||||
|
|
||||||
## File Map
|
## File Map
|
||||||
|
|||||||
@@ -3,28 +3,71 @@ title: Attention
|
|||||||
description: Supported attention modules in Axolotl
|
description: Supported attention modules in Axolotl
|
||||||
---
|
---
|
||||||
|
|
||||||
## SDP Attention
|
Axolotl routes attention via a single config field:
|
||||||
|
|
||||||
This is the default built-in attention in PyTorch.
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
sdp_attention: true
|
attn_implementation: <backend>
|
||||||
```
|
```
|
||||||
|
|
||||||
For more details: [PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
`attn_implementation` is passed through to `transformers` verbatim (via
|
||||||
|
`model.config._attn_implementation`). Accepted values are the HF-native
|
||||||
|
backends, axolotl-registered backends, or a hub-kernel path.
|
||||||
|
|
||||||
## Flash Attention
|
## Backends
|
||||||
|
|
||||||
Axolotl supports Flash Attention 2, 3, and 4. The best available version is used automatically
|
| `attn_implementation` | Description |
|
||||||
based on your installed packages and GPU.
|
|---|---|
|
||||||
|
| `eager` | Plain PyTorch attention. No packing support. |
|
||||||
|
| `sdpa` | PyTorch `scaled_dot_product_attention`. No packing support. |
|
||||||
|
| `flash_attention_2` | Dao-AILab Flash Attention 2. |
|
||||||
|
| `flash_attention_3` | Dao-AILab Flash Attention 3 (Hopper+). |
|
||||||
|
| `flex_attention` | Torch Flex Attention (requires torch ≥ 2.6). |
|
||||||
|
| `xformers` | xFormers memory-efficient attention. |
|
||||||
|
| `sage` | SageAttention (QK int8 / PV fp16). |
|
||||||
|
| `s2` | Shifted-Sparse Attention (LLaMA only, FA2 under the hood). |
|
||||||
|
| `fp8` | torchao FP8 low-precision attention (requires SM90+, torch ≥ 2.11). Loaded as SDPA and patched post-load. |
|
||||||
|
| `kernels-community/flash-attn3` | HF hub FA3 kernel. |
|
||||||
|
| `kernels-community/sage-attention` | HF hub SageAttention kernel. |
|
||||||
|
| Other `<org>/<name>` path | Any hub-kernel path supported by `transformers`. |
|
||||||
|
|
||||||
|
Short-form aliases (`flash`, `fa2`, `flex`, `sdp`, etc.) are **not accepted** —
|
||||||
|
set the canonical name above.
|
||||||
|
|
||||||
|
### Capability flags
|
||||||
|
|
||||||
|
Axolotl derives three boolean capability flags from `attn_implementation` and
|
||||||
|
exposes them on the validated config:
|
||||||
|
|
||||||
|
- `cfg.attn_supports_packing` — backend supports varlen sample packing via
|
||||||
|
`position_ids`. Gates multipack patches and `sample_packing_drop_attention_mask`.
|
||||||
|
- `cfg.attn_uses_flash_lib` — backend needs the `flash_attn` (Dao-AILab)
|
||||||
|
monkeypatches (FA4 auto, LLaMA flash hijack, ring-FA).
|
||||||
|
- `cfg.attn_needs_dtype_cast` — backend requires fp16/bf16 embeddings
|
||||||
|
(everything except `eager` and `sdpa`).
|
||||||
|
|
||||||
|
These are **computed** — they cannot be overridden from YAML.
|
||||||
|
|
||||||
|
## Per-backend notes
|
||||||
|
|
||||||
|
### SDPA
|
||||||
|
|
||||||
|
Default PyTorch attention. See
|
||||||
|
[PyTorch docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flash_attention: true
|
attn_implementation: sdpa
|
||||||
```
|
```
|
||||||
|
|
||||||
For more details: [Flash Attention](https://github.com/Dao-AILab/flash-attention/)
|
### Flash Attention
|
||||||
|
|
||||||
### Flash Attention 2
|
Axolotl supports FA2, FA3, and FA4. The best available version is used
|
||||||
|
automatically based on your installed packages and GPU.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
attn_implementation: flash_attention_2 # or flash_attention_3
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Flash Attention 2
|
||||||
|
|
||||||
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
Requirements: Ampere, Ada, or Hopper GPUs (Turing or lower not supported)
|
||||||
|
|
||||||
@@ -39,23 +82,25 @@ Alternatively, try reinstall or downgrade a version.
|
|||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
### Flash Attention 3
|
#### Flash Attention 3
|
||||||
|
|
||||||
Requirements: Hopper GPUs only; CUDA 12.8 recommended
|
Requirements: Hopper GPUs only; CUDA 12.8 recommended
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||||
cd flash-attention/hopper
|
cd flash-attention/hopper
|
||||||
|
|
||||||
python setup.py install
|
python setup.py install
|
||||||
```
|
```
|
||||||
|
|
||||||
### Flash Attention 4
|
#### Flash Attention 4
|
||||||
|
|
||||||
Requirements: Hopper or Blackwell GPUs
|
Requirements: Hopper or Blackwell GPUs. Auto-applied when `attn_uses_flash_lib`
|
||||||
|
is true and FA4 is importable.
|
||||||
|
|
||||||
|
FA4 is still a pre-release on PyPI, so `--pre` is required:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install flash-attn-4
|
pip install --pre flash-attn-4
|
||||||
```
|
```
|
||||||
|
|
||||||
Or from source:
|
Or from source:
|
||||||
@@ -63,7 +108,6 @@ Or from source:
|
|||||||
```bash
|
```bash
|
||||||
git clone https://github.com/Dao-AILab/flash-attention.git
|
git clone https://github.com/Dao-AILab/flash-attention.git
|
||||||
cd flash-attention/flash_attn/cute
|
cd flash-attention/flash_attn/cute
|
||||||
|
|
||||||
pip install -e .
|
pip install -e .
|
||||||
|
|
||||||
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
# FA2's flash_attn package includes a cute/ stub that shadows FA4.
|
||||||
@@ -86,93 +130,113 @@ and falls back to FA2/3.
|
|||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
For more details: [flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)
|
|
||||||
|
|
||||||
### AMD
|
### AMD
|
||||||
|
|
||||||
Requirements: ROCm 6.0 and above.
|
Requirements: ROCm 6.0 and above. See
|
||||||
|
[Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
||||||
|
|
||||||
See [Flash Attention AMD docs](https://github.com/Dao-AILab/flash-attention/tree/main?tab=readme-ov-file#amd-rocm-support).
|
### Flex Attention
|
||||||
|
|
||||||
## Flex Attention
|
|
||||||
|
|
||||||
A flexible PyTorch API for attention used in combination with `torch.compile`.
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flex_attention: true
|
attn_implementation: flex_attention
|
||||||
|
torch_compile: true # recommended
|
||||||
# recommended
|
|
||||||
torch_compile: true
|
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-note}
|
Requires torch ≥ 2.6. See [PyTorch docs](https://pytorch.org/blog/flexattention/).
|
||||||
|
|
||||||
We recommend using latest stable version of PyTorch for best performance.
|
### SageAttention
|
||||||
|
|
||||||
:::
|
Requirements: Ampere, Ada, or Hopper GPUs.
|
||||||
|
|
||||||
For more details: [PyTorch docs](https://pytorch.org/blog/flexattention/)
|
|
||||||
|
|
||||||
## SageAttention
|
|
||||||
|
|
||||||
Attention kernels with QK Int8 and PV FP16 accumulator.
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
sage_attention: true
|
attn_implementation: sage
|
||||||
```
|
```
|
||||||
|
|
||||||
Requirements: Ampere, Ada, or Hopper GPUs
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install sageattention==2.2.0 --no-build-isolation
|
pip install sageattention==2.2.0 --no-build-isolation
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-warning}
|
::: {.callout-warning}
|
||||||
|
|
||||||
Only LoRA/QLoRA recommended at the moment. We found loss drop to 0 for full finetuning. See [GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
Only LoRA/QLoRA recommended. Full finetuning has been observed to drop loss to 0. See
|
||||||
|
[GitHub Issue](https://github.com/thu-ml/SageAttention/issues/198).
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention)
|
For more details: [Sage Attention](https://github.com/thu-ml/SageAttention).
|
||||||
|
|
||||||
::: {.callout-note}
|
### xFormers
|
||||||
|
|
||||||
We do not support SageAttention 3 at the moment. If you are interested on adding this or improving SageAttention implementation, please make an Issue.
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
|
|
||||||
## xFormers
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
|
|
||||||
We recommend using with Turing GPUs or below (such as on Colab).
|
Recommended for Turing GPUs or below (e.g. Colab T4).
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
For more details: [xFormers](https://github.com/facebookresearch/xformers)
|
### Shifted Sparse Attention
|
||||||
|
|
||||||
## Shifted Sparse Attention
|
|
||||||
|
|
||||||
::: {.callout-warning}
|
::: {.callout-warning}
|
||||||
|
|
||||||
We plan to deprecate this! If you use this feature, we recommend switching to methods above.
|
Planned for deprecation. Prefer one of the backends above.
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|
||||||
Requirements: LLaMA model architecture
|
Requirements: LLaMA model architecture. Loaded as FA2 under the hood and
|
||||||
|
patched to implement shifted-sparse attention. Does not support sample packing.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flash_attention: true
|
attn_implementation: s2
|
||||||
s2_attention: true
|
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-tip}
|
### FP8
|
||||||
|
|
||||||
No sample packing support!
|
torchao low-precision attention. Loaded as SDPA and patched post-load.
|
||||||
|
|
||||||
|
Requirements: SM90+ (Hopper/Blackwell), PyTorch ≥ 2.11, torchao ≥ 0.17,
|
||||||
|
flash-attn with FA3. KV caching must be disabled.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
attn_implementation: fp8
|
||||||
|
```
|
||||||
|
|
||||||
|
### Hub kernels
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
attn_implementation: kernels-community/flash-attn3
|
||||||
|
```
|
||||||
|
|
||||||
|
Passed through to `transformers`; axolotl does not install the kernel itself.
|
||||||
|
For recognized hub paths the capability flags are set automatically; for
|
||||||
|
arbitrary paths axolotl uses conservative defaults (`attn_supports_packing=False`,
|
||||||
|
`attn_uses_flash_lib=False`).
|
||||||
|
|
||||||
|
## Migrating from legacy boolean flags
|
||||||
|
|
||||||
|
The following legacy config fields are **deprecated** and will be removed in a
|
||||||
|
future release. Each emits a `DeprecationWarning` when set and is stripped from
|
||||||
|
the validated config.
|
||||||
|
|
||||||
|
| Legacy | Canonical |
|
||||||
|
|---|---|
|
||||||
|
| `flash_attention: true` | `attn_implementation: flash_attention_2` |
|
||||||
|
| `sdp_attention: true` | `attn_implementation: sdpa` |
|
||||||
|
| `xformers_attention: true` | `attn_implementation: xformers` |
|
||||||
|
| `flex_attention: true` | `attn_implementation: flex_attention` |
|
||||||
|
| `sage_attention: true` | `attn_implementation: sage` |
|
||||||
|
| `s2_attention: true` | `attn_implementation: s2` |
|
||||||
|
| `eager_attention: true` | `attn_implementation: eager` |
|
||||||
|
|
||||||
|
Combining `attn_implementation` with a legacy flag (e.g. `attn_implementation:
|
||||||
|
flash_attention_2` **and** `flash_attention: true`) raises an error — pick one.
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
Existing example configs under `examples/` still use the legacy flags. They
|
||||||
|
continue to work with a deprecation warning; they will be migrated in a
|
||||||
|
follow-up pass.
|
||||||
|
|
||||||
:::
|
:::
|
||||||
|
|||||||
@@ -108,6 +108,14 @@ datasets:
|
|||||||
type: chat_template
|
type: chat_template
|
||||||
```
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
`chat_template_jinja` also accepts a file path to a `.jinja2` file instead of an inline string:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
chat_template_jinja: ./path/to/my_template.jinja2
|
||||||
|
```
|
||||||
|
:::
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
Please make sure that your `tokenizer.eos_token` is the same as the EOS (End-of-Sequence) token in the template. Otherwise, set `eos_token` under `special_tokens: `.
|
Please make sure that your `tokenizer.eos_token` is the same as the EOS (End-of-Sequence) token in the template. Otherwise, set `eos_token` under `special_tokens: `.
|
||||||
:::
|
:::
|
||||||
@@ -294,6 +302,113 @@ datasets:
|
|||||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
#### Content parts with per-part training control
|
||||||
|
|
||||||
|
Instead of using character offsets with `train_detail`, you can split a message's content into a list of parts, each with its own training flag. This is useful when you want to mask specific sections of a response (e.g., mask reasoning but train on the answer).
|
||||||
|
|
||||||
|
```{.json filename="data.jsonl"}
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Let me think step by step...", "train": false},
|
||||||
|
{"type": "text", "text": " The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The configuration is the same as standard `chat_template` — no extra fields needed:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
datasets:
|
||||||
|
- path: ...
|
||||||
|
type: chat_template
|
||||||
|
roles_to_train: ["assistant"]
|
||||||
|
```
|
||||||
|
|
||||||
|
Each content part supports:
|
||||||
|
|
||||||
|
- `type`: `"text"` (required)
|
||||||
|
- `text`: the text value (also accepts `content` or `value` as the key)
|
||||||
|
- `train`: `true`/`false` (optional) — whether to train on this part
|
||||||
|
- `weight`: `0`/`1` (optional) — alternative to `train`
|
||||||
|
|
||||||
|
If a part has no `train` or `weight` flag, it inherits the turn-level training decision (from `roles_to_train`, `message_field_training`, or `train_on_inputs`).
|
||||||
|
|
||||||
|
::: {.callout-warning title="Whitespace at part boundaries"}
|
||||||
|
BPE tokenizers (used by Llama, Qwen, Mistral, GPT, etc.) encode a leading space as part of the following word token. For example, `" answer"` is a single token — the space is part of it. This means **where you place whitespace between content parts matters**:
|
||||||
|
|
||||||
|
**Split BEFORE spaces** (space goes with the next part):
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{"type": "text", "text": "Let me think...", "train": false},
|
||||||
|
{"type": "text", "text": " The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
**DON'T put trailing spaces** on a part (the space merges with the next word into one token that straddles the boundary, and straddling tokens are masked):
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{"type": "text", "text": "Let me think... ", "train": false},
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
In the bad example, `" The"` becomes a single token that spans both parts. Because it straddles the boundary, it is conservatively **masked** (not trained) — even though the second part has `train: true`.
|
||||||
|
|
||||||
|
**Newlines** typically merge with preceding punctuation (e.g., `":\n"` is one token). Keep newlines with the preceding part:
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{"type": "text", "text": "Thinking:\n", "train": false},
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Axolotl will log a warning if it detects trailing whitespace at a boundary between parts with different training flags.
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
When all content parts in a message are strings, they are concatenated before being passed to the chat template. This means content parts work with **any** Jinja template — the template sees a plain string, and the per-part training flags are applied during tokenization.
|
||||||
|
:::
|
||||||
|
|
||||||
|
##### Per-part training on reasoning_content
|
||||||
|
|
||||||
|
For templates that support a separate `reasoning_content` field (e.g., `qwen3`), the same content-parts format works on `reasoning_content`. This is useful for masking incorrect reasoning steps while training on self-corrections:
|
||||||
|
|
||||||
|
```{.json filename="data.jsonl"}
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"reasoning_content": [
|
||||||
|
{"type": "text", "text": "Hmm maybe 2+2=5.", "train": false},
|
||||||
|
{"type": "text", "text": " Wait no, 2+2=4.", "train": true}
|
||||||
|
],
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The answer is 4.", "train": true}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
The `reasoning_content` and `content` fields are handled independently — each has its own token boundaries and per-part masking. No additional configuration is needed beyond what the template already requires.
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
When `reasoning_content` is provided as a separate field, `split_thinking` is not needed — the reasoning is already separated from the content in the data.
|
||||||
|
:::
|
||||||
|
|
||||||
|
The same whitespace rules apply to `reasoning_content` parts as to `content` parts — split before spaces, keep newlines with the preceding part.
|
||||||
|
|
||||||
|
|
||||||
#### Reasoning split
|
#### Reasoning split
|
||||||
|
|
||||||
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||||
|
|||||||
@@ -76,8 +76,10 @@ datasets:
|
|||||||
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
|
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install packaging
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
uv venv --no-project --relocatable
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Remote Hosts
|
#### Remote Hosts
|
||||||
@@ -208,17 +210,18 @@ cd axolotl
|
|||||||
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
|
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl-uv:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
>[!Tip]
|
>[!Tip]
|
||||||
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
||||||
|
|
||||||
You will now be in the container. Next, perform an editable install of Axolotl:
|
You will now be in the container. Next, install Axolotl with dev dependencies:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install packaging
|
uv venv --no-project --relocatable
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
source .venv/bin/activate
|
||||||
|
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
|
||||||
```
|
```
|
||||||
|
|
||||||
### Attach To Container
|
### Attach To Container
|
||||||
|
|||||||
@@ -6,23 +6,33 @@ format:
|
|||||||
toc-depth: 4
|
toc-depth: 4
|
||||||
---
|
---
|
||||||
|
|
||||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
This section describes the different Docker images that are released by AxolotlAI at
|
||||||
|
[Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
### Switch to the `-uv` images
|
||||||
|
|
||||||
|
Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
|
||||||
|
(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
|
||||||
|
(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
|
||||||
|
same format as their non-uv counterparts.
|
||||||
|
|
||||||
|
**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
|
||||||
|
build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
|
||||||
|
the uv image.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Base
|
## Base
|
||||||
|
|
||||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image.
|
||||||
|
It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||||
|
|
||||||
#### Image
|
#### Image
|
||||||
|
|
||||||
```
|
| Variant | Image | Docker Hub |
|
||||||
axolotlai/axolotl-base
|
|---------|-------|------------|
|
||||||
```
|
| pip | `axolotlai/axolotl-base` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base) |
|
||||||
|
| uv | `axolotlai/axolotl-base-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base-uv) |
|
||||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
|
|
||||||
|
|
||||||
#### Tags format
|
#### Tags format
|
||||||
|
|
||||||
@@ -32,8 +42,10 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
|||||||
|
|
||||||
Tags examples:
|
Tags examples:
|
||||||
|
|
||||||
- `main-base-py3.11-cu128-2.8.0`
|
|
||||||
- `main-base-py3.11-cu128-2.9.1`
|
- `main-base-py3.11-cu128-2.9.1`
|
||||||
|
- `main-base-py3.12-cu128-2.10.0`
|
||||||
|
- `main-base-py3.12-cu130-2.9.1`
|
||||||
|
- `main-base-py3.12-cu130-2.10.0`
|
||||||
|
|
||||||
## Main
|
## Main
|
||||||
|
|
||||||
@@ -41,11 +53,10 @@ The main image is the image that is used to run Axolotl. It is based on the `axo
|
|||||||
|
|
||||||
#### Image
|
#### Image
|
||||||
|
|
||||||
```
|
| Variant | Image | Docker Hub |
|
||||||
axolotlai/axolotl
|
|---------|-------|------------|
|
||||||
```
|
| pip | `axolotlai/axolotl` | [Link](https://hub.docker.com/r/axolotlai/axolotl) |
|
||||||
|
| uv | `axolotlai/axolotl-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-uv) |
|
||||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
|
||||||
|
|
||||||
#### Tags format {#sec-main-tags}
|
#### Tags format {#sec-main-tags}
|
||||||
|
|
||||||
@@ -53,7 +64,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
|||||||
# on push to main
|
# on push to main
|
||||||
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||||
|
|
||||||
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
|
# latest main (currently torch 2.9.1, python 3.11, cuda 12.8)
|
||||||
main-latest
|
main-latest
|
||||||
|
|
||||||
# nightly build
|
# nightly build
|
||||||
@@ -71,12 +82,13 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
|||||||
|
|
||||||
Tags examples:
|
Tags examples:
|
||||||
|
|
||||||
- `main-py3.11-cu128-2.8.0`
|
|
||||||
- `main-py3.11-cu128-2.9.1`
|
- `main-py3.11-cu128-2.9.1`
|
||||||
|
- `main-py3.12-cu128-2.10.0`
|
||||||
|
- `main-py3.12-cu130-2.9.1`
|
||||||
|
- `main-py3.12-cu130-2.10.0`
|
||||||
- `main-latest`
|
- `main-latest`
|
||||||
- `main-20250303-py3.11-cu124-2.6.0`
|
- `main-20260315-py3.11-cu128-2.9.1`
|
||||||
- `main-20250303-py3.11-cu126-2.6.0`
|
- `0.16.1`
|
||||||
- `0.12.0`
|
|
||||||
|
|
||||||
## Cloud
|
## Cloud
|
||||||
|
|
||||||
@@ -90,11 +102,10 @@ Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variab
|
|||||||
|
|
||||||
#### Image
|
#### Image
|
||||||
|
|
||||||
```
|
| Variant | Image | Docker Hub |
|
||||||
axolotlai/axolotl-cloud
|
|---------|-------|------------|
|
||||||
```
|
| pip | `axolotlai/axolotl-cloud` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud) |
|
||||||
|
| uv | `axolotlai/axolotl-cloud-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud-uv) |
|
||||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
|
|
||||||
|
|
||||||
#### Tags format
|
#### Tags format
|
||||||
|
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ gradient_accumulation_steps: 4
|
|||||||
max_steps: 20
|
max_steps: 20
|
||||||
learning_rate: 5.0e-6
|
learning_rate: 5.0e-6
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
output_dir: ./outputs/ebft-quickstart
|
output_dir: ./outputs/ebft-quickstart
|
||||||
```
|
```
|
||||||
@@ -304,7 +304,7 @@ lora_alpha: 32
|
|||||||
lora_target_linear: true
|
lora_target_linear: true
|
||||||
|
|
||||||
bf16: auto
|
bf16: auto
|
||||||
flex_attention: true
|
attn_implementation: flex_attention
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: true # Required with flex_attention
|
use_reentrant: true # Required with flex_attention
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
**Q: vLLM is not working with Axolotl**
|
**Q: vLLM is not working with Axolotl**
|
||||||
|
|
||||||
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
|
> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).
|
||||||
|
|
||||||
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
|
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
|
||||||
|
|
||||||
|
|||||||
@@ -154,7 +154,7 @@ lr_scheduler: cosine
|
|||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
|
|
||||||
bf16: true
|
bf16: true
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
|
|||||||
@@ -15,64 +15,30 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python ≥3.11
|
- Python ≥3.11
|
||||||
- PyTorch ≥2.6.0
|
- PyTorch ≥2.9.1
|
||||||
|
|
||||||
## Installation Methods {#sec-installation-methods}
|
## Installation {#sec-installation}
|
||||||
|
|
||||||
::: {.callout-important}
|
|
||||||
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
|
|
||||||
|
|
||||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
|
||||||
:::
|
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use PyTorch 2.9.1 and CUDA 12.8.
|
For Blackwell GPUs, please use PyTorch 2.9.1 and CUDA 12.8.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
### PyPI Installation (Recommended) {#sec-pypi}
|
### Quick Install {#sec-uv}
|
||||||
|
|
||||||
```{.bash}
|
Axolotl uses [uv](https://docs.astral.sh/uv/) as its package manager. uv is a fast, reliable Python package installer and resolver built in Rust.
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
|
||||||
```
|
|
||||||
|
|
||||||
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
|
Install uv if not already installed:
|
||||||
installed) in order not to clobber it, and so that we set the correct version of
|
|
||||||
dependencies that are specific to the PyTorch version or other installed
|
|
||||||
co-dependencies.
|
|
||||||
|
|
||||||
### uv Installation {#sec-uv}
|
|
||||||
|
|
||||||
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
|
|
||||||
|
|
||||||
Install uv if not already installed
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
source $HOME/.local/bin/env
|
source $HOME/.local/bin/env
|
||||||
```
|
```
|
||||||
|
|
||||||
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
|
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
|
||||||
then create the venv and activate
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
export UV_TORCH_BACKEND=cu126
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
uv venv --no-project --relocatable
|
uv venv
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
```
|
uv pip install --no-build-isolation axolotl[deepspeed]
|
||||||
|
|
||||||
Install PyTorch
|
|
||||||
- PyTorch 2.6.0 recommended
|
|
||||||
```{.bash}
|
|
||||||
uv pip install packaging setuptools wheel
|
|
||||||
uv pip install torch==2.6.0
|
|
||||||
uv pip install awscli pydantic
|
|
||||||
```
|
|
||||||
|
|
||||||
Install axolotl from PyPi
|
|
||||||
```{.bash}
|
|
||||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
|
|
||||||
|
|
||||||
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
|
|
||||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Edge/Development Build {#sec-edge-build}
|
### Edge/Development Build {#sec-edge-build}
|
||||||
@@ -82,14 +48,16 @@ For the latest features between releases:
|
|||||||
```{.bash}
|
```{.bash}
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install --no-build-isolation -e '.[deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
### Docker {#sec-docker}
|
### Docker {#sec-docker}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
docker run --gpus '"all"' --rm -it --ipc=host axolotlai/axolotl-uv:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
For development with Docker:
|
For development with Docker:
|
||||||
@@ -106,12 +74,12 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
|||||||
--ulimit memlock=-1 --ulimit stack=67108864 \
|
--ulimit memlock=-1 --ulimit stack=67108864 \
|
||||||
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
|
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
|
||||||
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
|
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
|
||||||
axolotlai/axolotl:main-latest
|
axolotlai/axolotl-uv:main-latest
|
||||||
```
|
```
|
||||||
:::
|
:::
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
For Blackwell GPUs, please use `axolotlai/axolotl-uv:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud-uv:main-py3.11-cu128-2.9.1`.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||||
@@ -122,7 +90,7 @@ Please refer to the [Docker documentation](docker.qmd) for more information on t
|
|||||||
|
|
||||||
For providers supporting Docker:
|
For providers supporting Docker:
|
||||||
|
|
||||||
- Use `axolotlai/axolotl-cloud:main-latest`
|
- Use `axolotlai/axolotl-cloud-uv:main-latest`
|
||||||
- Available on:
|
- Available on:
|
||||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||||
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
|
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
|
||||||
@@ -141,7 +109,7 @@ For providers supporting Docker:
|
|||||||
### macOS {#sec-macos}
|
### macOS {#sec-macos}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
pip3 install --no-build-isolation -e '.'
|
uv pip install --no-build-isolation -e '.'
|
||||||
```
|
```
|
||||||
|
|
||||||
See @sec-troubleshooting for Mac-specific issues.
|
See @sec-troubleshooting for Mac-specific issues.
|
||||||
@@ -152,21 +120,44 @@ See @sec-troubleshooting for Mac-specific issues.
|
|||||||
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Environment Managers {#sec-env-managers}
|
## Migrating from pip to uv {#sec-migrating}
|
||||||
|
|
||||||
### Conda/Pip venv {#sec-conda}
|
If you have an existing pip-based Axolotl installation, you can migrate to uv:
|
||||||
|
|
||||||
1. Install Python ≥3.11
|
```{.bash}
|
||||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
# Install uv
|
||||||
3. Install Axolotl:
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
```{.bash}
|
source $HOME/.local/bin/env
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
# Create a fresh venv (recommended for a clean start)
|
||||||
```
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
4. (Optional) Login to Hugging Face:
|
uv venv
|
||||||
```{.bash}
|
source .venv/bin/activate
|
||||||
hf auth login
|
|
||||||
```
|
# Reinstall axolotl
|
||||||
|
uv pip install --no-build-isolation axolotl[deepspeed]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using pip (Alternative) {#sec-pip}
|
||||||
|
|
||||||
|
If you are unable to install uv, you can still use pip directly.
|
||||||
|
|
||||||
|
::: {.callout-important}
|
||||||
|
Please make sure to have PyTorch installed before installing Axolotl with pip.
|
||||||
|
|
||||||
|
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||||
|
:::
|
||||||
|
|
||||||
|
```{.bash}
|
||||||
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
|
pip3 install --no-build-isolation axolotl[deepspeed]
|
||||||
|
```
|
||||||
|
|
||||||
|
For editable/development installs:
|
||||||
|
```{.bash}
|
||||||
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
|
pip3 install --no-build-isolation -e '.[deepspeed]'
|
||||||
|
```
|
||||||
|
|
||||||
## Troubleshooting {#sec-troubleshooting}
|
## Troubleshooting {#sec-troubleshooting}
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ format:
|
|||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
|
- [Gemma-4](#sec-gemma-4) *(NEW)*
|
||||||
- [Mllama](#sec-mllama)
|
- [Mllama](#sec-mllama)
|
||||||
- [Llama4](#sec-llama4)
|
- [Llama4](#sec-llama4)
|
||||||
- [Pixtral](#sec-pixtral)
|
- [Pixtral](#sec-pixtral)
|
||||||
@@ -138,6 +139,40 @@ base_model: mistralai/Voxtral-Mini-3B-2507
|
|||||||
processor_type: VoxtralProcessor
|
processor_type: VoxtralProcessor
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Gemma-4 {#sec-gemma-4}
|
||||||
|
|
||||||
|
All Gemma 4 variants (E2B, E4B, 26B-A4B, 31B) load as multimodal models even for text-only training.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: google/gemma-4-E2B-it # or E4B-it, 26B-A4B, 31B
|
||||||
|
|
||||||
|
chat_template: gemma4
|
||||||
|
freeze_mm_modules: true # freeze vision/audio encoders for text-only or vision LoRA
|
||||||
|
|
||||||
|
# For the 26B-A4B MoE model, enable ScatterMoE and expert LoRA:
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
- axolotl.integrations.kernels.KernelsPlugin
|
||||||
|
use_kernels: true
|
||||||
|
use_scattermoe: true
|
||||||
|
experts_implementation: scattermoe
|
||||||
|
|
||||||
|
lora_target_modules: 'model.language_model.layers.[\d]+.(_checkpoint_wrapped_module.)?(mlp|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# MoE expert LoRA (3D tensors, not nn.Linear) — only for 26B-A4B:
|
||||||
|
lora_target_parameters:
|
||||||
|
- experts.gate_up_proj
|
||||||
|
- experts.down_proj
|
||||||
|
```
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
Gemma 4 VLM training starts with high loss (~8-15). This is expected — see the [training stability guide](training_stability.qmd) for details.
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
For DDP training, axolotl auto-detects Gemma4 and sets `use_reentrant=False` and `ddp_find_unused_parameters=True`. However, when `activation_offloading: true`, `ddp_find_unused_parameters` is skipped (checkpoint wrappers conflict with it); use `freeze_mm_modules: true` instead to handle unused vision/audio params. For FSDP2, use `fsdp_transformer_layer_cls_to_wrap: Gemma4TextDecoderLayer`.
|
||||||
|
:::
|
||||||
|
|
||||||
### Gemma-3 {#sec-gemma-3}
|
### Gemma-3 {#sec-gemma-3}
|
||||||
|
|
||||||
::: {.callout-tip}
|
::: {.callout-tip}
|
||||||
|
|||||||
84
docs/multimodal_assistant_mask.md
Normal file
84
docs/multimodal_assistant_mask.md
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
# Multimodal assistant-only loss masking
|
||||||
|
|
||||||
|
## Correct placement
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Top-level: only train_on_inputs lives here.
|
||||||
|
train_on_inputs: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: data/train.jsonl
|
||||||
|
type: chat_template
|
||||||
|
roles_to_train: # per-dataset — this is what the MM scanner reads
|
||||||
|
- assistant
|
||||||
|
train_on_eos: turn # per-dataset — same
|
||||||
|
|
||||||
|
test_datasets:
|
||||||
|
- path: data/val.jsonl
|
||||||
|
type: chat_template
|
||||||
|
split: train
|
||||||
|
roles_to_train:
|
||||||
|
- assistant
|
||||||
|
train_on_eos: turn
|
||||||
|
```
|
||||||
|
|
||||||
|
## How to verify at runtime
|
||||||
|
|
||||||
|
`build_collator` logs the resolved knobs at INFO:
|
||||||
|
|
||||||
|
```text
|
||||||
|
MM collator: train_on_inputs=False roles_to_train=['assistant'] train_on_eos=turn role_boundaries_override=none
|
||||||
|
```
|
||||||
|
|
||||||
|
If `roles_to_train` logs as `None`, the YAML knobs are not reaching the
|
||||||
|
scanner — check that they are under `datasets[0]`, not at the root.
|
||||||
|
|
||||||
|
Each verified strategy additionally logs its resolved boundary token ids at
|
||||||
|
strategy init (e.g. `<|turn>model` → `[105, 4368]`, `<turn|>` → `[106]` for
|
||||||
|
Gemma 4). If a strategy emits the "has no built-in role boundaries ... only
|
||||||
|
pad and media tokens are masked" one-shot warning instead, it is on the
|
||||||
|
fallback path — declare per-role markers in YAML via `cfg.role_boundaries`
|
||||||
|
(below) to activate masking. The strategies currently on this path are
|
||||||
|
listed in the audit table above under `fallback + warn`.
|
||||||
|
|
||||||
|
## Config-based override: `cfg.role_boundaries`
|
||||||
|
|
||||||
|
For the "unverified" strategies above, or for custom chat templates that
|
||||||
|
don't match a built-in strategy's markers, users can declare role boundaries
|
||||||
|
directly in YAML without subclassing:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
role_boundaries:
|
||||||
|
- role: assistant
|
||||||
|
start: "<|turn>model"
|
||||||
|
end: "<turn|>"
|
||||||
|
- role: user
|
||||||
|
start: "<|turn>user"
|
||||||
|
end: "<turn|>"
|
||||||
|
# Optional keys:
|
||||||
|
# include_start: false # default False
|
||||||
|
# include_end: true # default True, respects cfg.train_on_eos
|
||||||
|
# end: eos_token # sentinel: resolves to tokenizer.eos_token_id
|
||||||
|
# end: null # span runs to end of sequence
|
||||||
|
```
|
||||||
|
|
||||||
|
Semantics:
|
||||||
|
|
||||||
|
- `start` and `end` are literal strings; axolotl encodes them at strategy
|
||||||
|
init via `tokenizer.encode(..., add_special_tokens=False)` and logs the
|
||||||
|
resolved token-id sequences at INFO level.
|
||||||
|
- The special value `end: eos_token` is the portable way to express
|
||||||
|
"Pixtral-style assistant turns end at EOS" without hard-coding an id.
|
||||||
|
- `role_boundaries` is an **opt-in override**. A non-empty list **replaces**
|
||||||
|
the strategy's built-in declarations wholesale (partial overlays are
|
||||||
|
intentionally unsupported — they're hard to reason about at review time).
|
||||||
|
Leaving the field unset *or* setting it to an empty list (`[]`) both mean
|
||||||
|
"use the strategy's built-ins." Writing `role_boundaries: []` is almost
|
||||||
|
always a typo or leftover — honoring it literally would produce all-masked
|
||||||
|
labels and zero gradient, so it is treated the same as unset.
|
||||||
|
- `cfg.roles_to_train` still governs which declared roles contribute to
|
||||||
|
loss. You can declare `user` and `assistant` boundaries and set
|
||||||
|
`roles_to_train: ["assistant"]` to have the scanner correctly identify
|
||||||
|
user spans as masking boundaries without training on their content.
|
||||||
|
- Invalid specs fail loudly at strategy init (missing `role`/`start`,
|
||||||
|
unencodable markers), not silently at loss-compute time.
|
||||||
@@ -22,12 +22,12 @@ Improves GPU utilization by combining multiple short sequences into a single pac
|
|||||||
|
|
||||||
Using an optimized attention implementation is critical for training speed.
|
Using an optimized attention implementation is critical for training speed.
|
||||||
|
|
||||||
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `flash_attention: true`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
- **[Flash Attention 2](https://github.com/Dao-AILab/flash-attention)**: `attn_implementation: flash_attention_2`. **(Recommended)** The industry standard for fast attention on modern GPUs. Requires Ampere or higher. For AMD, check [AMD Support](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#amd-rocm-support).
|
||||||
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `flex_attention: true`.
|
- **[Flex Attention](https://pytorch.org/blog/flexattention/)**: `attn_implementation: flex_attention`.
|
||||||
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `sdp_attention: true`. PyTorch's native implementation.
|
- **[SDP Attention](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)**: `attn_implementation: sdpa`. PyTorch's native implementation.
|
||||||
- **[Xformers](https://github.com/facebookresearch/xformers)**: `xformers_attention: true`. Works with FP16.
|
- **[Xformers](https://github.com/facebookresearch/xformers)**: `attn_implementation: xformers`. Works with FP16.
|
||||||
|
|
||||||
*Note: You should only enable one attention backend.*
|
See [Attention](attention.qmd) for the full list of backends and the canonical values.
|
||||||
|
|
||||||
### LoRA Optimizations
|
### LoRA Optimizations
|
||||||
|
|
||||||
|
|||||||
@@ -320,8 +320,10 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: ipo
|
rl: dpo
|
||||||
|
dpo_loss_type: ["ipo"]
|
||||||
```
|
```
|
||||||
|
*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.
|
||||||
|
|
||||||
### ORPO
|
### ORPO
|
||||||
|
|
||||||
@@ -1145,8 +1147,7 @@ datasets:
|
|||||||
type: ebft_strided_structured.transform
|
type: ebft_strided_structured.transform
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
|
|
||||||
flash_attention: false
|
attn_implementation: flex_attention # Strided mode uses flex_attention
|
||||||
flex_attention: true # Strided mode uses flex_attention
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: true # Required for flex_attention
|
use_reentrant: true # Required for flex_attention
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ examples:
|
|||||||
title: Arcee AFM
|
title: Arcee AFM
|
||||||
|
|
||||||
# MistralAI
|
# MistralAI
|
||||||
|
- name: mistral-medium-3_5
|
||||||
|
title: Mistral Medium 3.5
|
||||||
- name: ministral3/think
|
- name: ministral3/think
|
||||||
title: Ministral 3 Thinking
|
title: Ministral 3 Thinking
|
||||||
- name: ministral3/vision
|
- name: ministral3/vision
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ To use sequence parallelism, you need:
|
|||||||
|
|
||||||
## Limitations
|
## Limitations
|
||||||
|
|
||||||
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
- Flash attention must be enabled for this to work (`attn_implementation: flash_attention_2` in config YAML)
|
||||||
- May have a small performance overhead due to communication between GPUs
|
- May have a small performance overhead due to communication between GPUs
|
||||||
|
|
||||||
## Example
|
## Example
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ For GRPO, also reduce `max_completion_length`. Memory scales quadratically with
|
|||||||
Reduces attention memory from O(n^2) to O(n):
|
Reduces attention memory from O(n^2) to O(n):
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
```
|
```
|
||||||
|
|
||||||
### Step 6: Offload with DeepSpeed
|
### Step 6: Offload with DeepSpeed
|
||||||
|
|||||||
@@ -1,53 +0,0 @@
|
|||||||
---
|
|
||||||
title: "Unsloth"
|
|
||||||
description: "Hyper-optimized QLoRA finetuning for single GPUs"
|
|
||||||
---
|
|
||||||
|
|
||||||
### Overview
|
|
||||||
|
|
||||||
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
|
|
||||||
standard industry baselines.
|
|
||||||
|
|
||||||
::: {.callout-important}
|
|
||||||
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
|
|
||||||
|
|
||||||
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
|
|
||||||
:::
|
|
||||||
|
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
The following will install the correct unsloth and extras from source.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/unsloth_install.py | sh
|
|
||||||
```
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
|
|
||||||
|
|
||||||
Our unsloth integration is currently limited to the following model architectures:
|
|
||||||
- llama
|
|
||||||
|
|
||||||
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
|
|
||||||
```yaml
|
|
||||||
unsloth_lora_mlp: true
|
|
||||||
unsloth_lora_qkv: true
|
|
||||||
unsloth_lora_o: true
|
|
||||||
```
|
|
||||||
|
|
||||||
These options are composable and can be used with multi-gpu finetuning
|
|
||||||
```yaml
|
|
||||||
unsloth_cross_entropy_loss: true
|
|
||||||
unsloth_rms_norm: true
|
|
||||||
unsloth_rope: true
|
|
||||||
```
|
|
||||||
|
|
||||||
### Limitations
|
|
||||||
|
|
||||||
- Single GPU only; e.g. no multi-gpu support
|
|
||||||
- No deepspeed or FSDP support (requires multi-gpu)
|
|
||||||
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
|
|
||||||
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
|
|
||||||
- No MoE support.
|
|
||||||
@@ -15,8 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have a compatible version of Pytorch installed
|
# Ensure you have a compatible version of Pytorch installed
|
||||||
pip3 install packaging setuptools wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Run one of the finetuning examples below.
|
2. Run one of the finetuning examples below.
|
||||||
@@ -35,7 +34,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
|||||||
|
|
||||||
**LFM2-MoE**
|
**LFM2-MoE**
|
||||||
```bash
|
```bash
|
||||||
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
uv pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||||
|
|
||||||
# LoRA SFT (1x48GB @ 16.2GiB)
|
# LoRA SFT (1x48GB @ 16.2GiB)
|
||||||
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
|
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
|
||||||
@@ -45,7 +44,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
|||||||
|
|
||||||
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
||||||
```bash
|
```bash
|
||||||
pip uninstall -y causal-conv1d
|
uv pip uninstall causal-conv1d
|
||||||
```
|
```
|
||||||
|
|
||||||
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ tf32: true
|
|||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 2
|
evals_per_epoch: 2
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 2
|
evals_per_epoch: 2
|
||||||
|
|||||||
@@ -50,8 +50,7 @@ tf32: true
|
|||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
eager_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ activation_offloading: legacy
|
|||||||
|
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ activation_offloading: legacy
|
|||||||
|
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
saves_per_epoch: 1
|
saves_per_epoch: 1
|
||||||
|
|||||||
@@ -11,12 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation -e '.'
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
@@ -31,7 +30,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
# For those using our Docker image, use the below path.
|
# For those using our Docker image, use the below path.
|
||||||
export CUDA_HOME=/usr/local/cuda
|
export CUDA_HOME=/usr/local/cuda
|
||||||
|
|
||||||
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||||
```
|
```
|
||||||
|
|
||||||
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||||
@@ -67,7 +66,7 @@ If those didn't help, please try the below solutions:
|
|||||||
1. Pass env for CMAKE and try install again:
|
1. Pass env for CMAKE and try install again:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
Python_EXECUTABLE=$(which python) uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Git clone the repo and manually hardcode python path:
|
2. Git clone the repo and manually hardcode python path:
|
||||||
@@ -92,7 +91,7 @@ If those didn't help, please try the below solutions:
|
|||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install . --no-build-isolation --no-deps
|
uv pip install . --no-build-isolation --no-deps
|
||||||
```
|
```
|
||||||
|
|
||||||
## Optimization Guides
|
## Optimization Guides
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -13,12 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation -e '.'
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -59,8 +59,7 @@ gradient_checkpointing: false
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
|
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
sdp_attention:
|
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
|
|||||||
@@ -39,8 +39,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -43,8 +43,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -73,8 +73,7 @@ early_stopping_patience: 3
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
auto_resume_from_checkpoints: true
|
auto_resume_from_checkpoints: true
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -40,8 +40,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -36,8 +36,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -37,8 +37,7 @@ bf16: auto
|
|||||||
tf32: true
|
tf32: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 5
|
logging_steps: 5
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ bf16: auto
|
|||||||
tf32: true
|
tf32: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 5
|
logging_steps: 5
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ tf32: false
|
|||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ tf32: false
|
|||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention:
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ bf16: auto
|
|||||||
tf32: true
|
tf32: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 5
|
logging_steps: 5
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ tf32: true
|
|||||||
gradient_checkpointing:
|
gradient_checkpointing:
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
flash_attn_cross_entropy: false
|
flash_attn_cross_entropy: false
|
||||||
flash_attn_rms_norm: true
|
flash_attn_rms_norm: true
|
||||||
flash_attn_fuse_mlp: true
|
flash_attn_fuse_mlp: true
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
flash_attn_cross_entropy: false
|
flash_attn_cross_entropy: false
|
||||||
flash_attn_rms_norm: true
|
flash_attn_rms_norm: true
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -47,7 +47,6 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: false
|
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 0
|
evals_per_epoch: 0
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 4
|
evals_per_epoch: 4
|
||||||
|
|||||||
@@ -71,8 +71,7 @@ early_stopping_patience: 3
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
auto_resume_from_checkpoints: true
|
auto_resume_from_checkpoints: true
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention: true
|
attn_implementation: xformers
|
||||||
flash_attention:
|
|
||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ load_in_4bit: true
|
|||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
bf16: auto
|
bf16: auto
|
||||||
tf32: false
|
tf32: false
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
special_tokens:
|
special_tokens:
|
||||||
bos_token: "<|startoftext|>"
|
bos_token: "<|startoftext|>"
|
||||||
eos_token: "<|endoftext|>"
|
eos_token: "<|endoftext|>"
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch:
|
evals_per_epoch:
|
||||||
|
|||||||
@@ -36,12 +36,7 @@
|
|||||||
"id": "msOCO4NRmRLa"
|
"id": "msOCO4NRmRLa"
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
|
||||||
"%%capture\n",
|
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@63b15e6\""
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ tf32: true
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 1
|
evals_per_epoch: 1
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 2
|
evals_per_epoch: 2
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ gradient_checkpointing_kwargs:
|
|||||||
use_reentrant: false
|
use_reentrant: false
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
warmup_ratio: 0.1
|
warmup_ratio: 0.1
|
||||||
evals_per_epoch: 2
|
evals_per_epoch: 2
|
||||||
|
|||||||
@@ -15,9 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ lora_model_dir:
|
|||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0
|
lora_dropout: 0
|
||||||
@@ -51,8 +50,8 @@ tf32: false
|
|||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
scaling_softmax: true
|
# scaling_softmax: true # needs flex_attention
|
||||||
|
|
||||||
loss_watchdog_threshold: 5.0
|
loss_watchdog_threshold: 5.0
|
||||||
loss_watchdog_patience: 3
|
loss_watchdog_patience: 3
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ output_dir: ./outputs/ndp-out/
|
|||||||
|
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
flash_attention: true
|
attn_implementation: flash_attention_2
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
micro_batch_size: 1
|
micro_batch_size: 1
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user