Compare commits
10 Commits
swe-rebenc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
798c8fba89 | ||
|
|
17fc747f99 | ||
|
|
901f2356bc | ||
|
|
1bf65c500e | ||
|
|
bcbe049c21 | ||
|
|
90090fa9e8 | ||
|
|
7420fd4de6 | ||
|
|
05113bc91a | ||
|
|
e562e149ce | ||
|
|
9de5b76336 |
6
.github/CONTRIBUTING.md
vendored
6
.github/CONTRIBUTING.md
vendored
@@ -31,7 +31,11 @@ PRs are **greatly welcome**!
|
|||||||
|
|
||||||
Please run below to setup env
|
Please run below to setup env
|
||||||
```bash
|
```bash
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
# Install axolotl + dev and test dependencies
|
||||||
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
|
uv venv --no-project --relocatable
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
|
||||||
pre-commit install
|
pre-commit install
|
||||||
|
|
||||||
# test
|
# test
|
||||||
|
|||||||
16
.github/workflows/base.yml
vendored
16
.github/workflows/base.yml
vendored
@@ -30,14 +30,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: "128"
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.9.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-base"
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
@@ -168,14 +160,6 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
dockerfile: "Dockerfile-uv-base"
|
dockerfile: "Dockerfile-uv-base"
|
||||||
platforms: "linux/amd64,linux/arm64"
|
platforms: "linux/amd64,linux/arm64"
|
||||||
- cuda: "128"
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
cudnn_version: ""
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.9.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
dockerfile: "Dockerfile-uv-base"
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
|
|||||||
2
.github/workflows/lint.yml
vendored
2
.github/workflows/lint.yml
vendored
@@ -6,7 +6,7 @@ on:
|
|||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'requirements.txt'
|
- 'pyproject.toml'
|
||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
- "*.[q]md"
|
- "*.[q]md"
|
||||||
- "examples/**/*.y[a]?ml"
|
- "examples/**/*.y[a]?ml"
|
||||||
|
|||||||
12
.github/workflows/main.yml
vendored
12
.github/workflows/main.yml
vendored
@@ -18,12 +18,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.9.0
|
|
||||||
axolotl_extras:
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -180,12 +174,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 128
|
|
||||||
cuda_version: 12.8.1
|
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.9.0
|
|
||||||
axolotl_extras:
|
|
||||||
platforms: "linux/amd64,linux/arm64"
|
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
|
|||||||
35
.github/workflows/multi-gpu-e2e.yml
vendored
35
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -3,17 +3,15 @@ name: docker-multigpu-tests-biweekly
|
|||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'tests/e2e/multigpu/**.py'
|
- "tests/e2e/multigpu/**.py"
|
||||||
- 'requirements.txt'
|
- "pyproject.toml"
|
||||||
- 'setup.py'
|
- ".github/workflows/multi-gpu-e2e.yml"
|
||||||
- 'pyproject.toml'
|
- "scripts/cutcrossentropy_install.py"
|
||||||
- '.github/workflows/multi-gpu-e2e.yml'
|
- "src/axolotl/core/trainers/mixins/sequence_parallel.py"
|
||||||
- 'scripts/cutcrossentropy_install.py'
|
- "src/axolotl/utils/distributed.py"
|
||||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
|
||||||
- 'src/axolotl/utils/distributed.py'
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
- cron: "0 0 * * 1,4" # Runs at 00:00 UTC every monday & thursday
|
||||||
|
|
||||||
# Cancel jobs on the same ref if a new one is triggered
|
# Cancel jobs on the same ref if a new one is triggered
|
||||||
concurrency:
|
concurrency:
|
||||||
@@ -33,19 +31,19 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
# - cuda: 129
|
# - cuda: 129
|
||||||
# cuda_version: 12.9.1
|
# cuda_version: 12.9.1
|
||||||
# python_version: "3.12"
|
# python_version: "3.12"
|
||||||
# pytorch: 2.9.1
|
# pytorch: 2.9.1
|
||||||
# axolotl_extras: "fbgemm-gpu"
|
# axolotl_extras: "fbgemm-gpu"
|
||||||
# num_gpus: 2
|
# num_gpus: 2
|
||||||
# dockerfile: "Dockerfile-uv.jinja"
|
# dockerfile: "Dockerfile-uv.jinja"
|
||||||
- cuda: 130
|
- cuda: 130
|
||||||
cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
# axolotl_extras: fbgemm-gpu
|
# axolotl_extras: fbgemm-gpu
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
- cuda: 128
|
- cuda: 128
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
@@ -53,7 +51,6 @@ jobs:
|
|||||||
pytorch: 2.10.0
|
pytorch: 2.10.0
|
||||||
axolotl_extras: "fbgemm-gpu"
|
axolotl_extras: "fbgemm-gpu"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
dockerfile: "Dockerfile-uv.jinja"
|
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 120
|
timeout-minutes: 120
|
||||||
steps:
|
steps:
|
||||||
@@ -75,7 +72,7 @@ jobs:
|
|||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
env:
|
env:
|
||||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
|||||||
13
.github/workflows/pypi.yml
vendored
13
.github/workflows/pypi.yml
vendored
@@ -8,6 +8,9 @@ on:
|
|||||||
|
|
||||||
permissions: {}
|
permissions: {}
|
||||||
|
|
||||||
|
env:
|
||||||
|
UV_SYSTEM_PYTHON: "1"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
setup_release:
|
setup_release:
|
||||||
name: Create Release
|
name: Create Release
|
||||||
@@ -41,11 +44,15 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install wheel packaging==26.0
|
uv pip install wheel packaging
|
||||||
pip3 install --no-build-isolation -e .
|
uv pip install --no-build-isolation -e .
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
|
|
||||||
- name: Extract tag name
|
- name: Extract tag name
|
||||||
id: tag
|
id: tag
|
||||||
|
|||||||
55
.github/workflows/tests-nightly.yml
vendored
55
.github/workflows/tests-nightly.yml
vendored
@@ -2,15 +2,18 @@ name: Tests Nightly against upstream main
|
|||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
|
- cron: "0 0 * * *" # Runs at 00:00 UTC every day
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '.github/workflows/tests-nightly.yml'
|
- ".github/workflows/tests-nightly.yml"
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
contents: read
|
contents: read
|
||||||
|
|
||||||
|
env:
|
||||||
|
UV_SYSTEM_PYTHON: "1"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pre-commit:
|
pre-commit:
|
||||||
name: pre-commit
|
name: pre-commit
|
||||||
@@ -20,7 +23,7 @@ jobs:
|
|||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: "pip" # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.1
|
- uses: pre-commit/action@v3.0.1
|
||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
@@ -43,7 +46,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
python_version: ["3.12"] # TODO include py3.14 once https://github.com/mistralai/mistral-common/pull/194 is merged
|
||||||
pytorch_version: ["2.9.1", "2.10.0"]
|
pytorch_version: ["2.9.1", "2.10.0"]
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
@@ -61,36 +64,34 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
|
||||||
|
|
||||||
- name: upgrade pip
|
- name: Install uv
|
||||||
run: |
|
uses: astral-sh/setup-uv@v7
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging==26.0 setuptools==78.1.1 wheel
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install torch==${{ matrix.pytorch_version }} torchvision
|
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||||
|
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||||
- name: Update requirements.txt
|
|
||||||
run: |
|
|
||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
|
||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
|
||||||
pip3 install --no-build-isolation -U -e .
|
python scripts/cutcrossentropy_install.py --uv | sh
|
||||||
python scripts/unsloth_install.py | sh
|
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
|
||||||
|
- name: Override with nightly HF packages
|
||||||
|
run: |
|
||||||
|
uv pip install --no-deps \
|
||||||
|
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
||||||
|
"peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||||
|
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||||
|
"trl @ git+https://github.com/huggingface/trl.git@main" \
|
||||||
|
"datasets @ git+https://github.com/huggingface/datasets.git@main"
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
run: |
|
run: |
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
- name: Ensure axolotl CLI was installed
|
||||||
run: |
|
run: |
|
||||||
@@ -102,9 +103,6 @@ jobs:
|
|||||||
pytest -v --durations=10 tests/patched/
|
pytest -v --durations=10 tests/patched/
|
||||||
pytest -v --durations=10 tests/cli/
|
pytest -v --durations=10 tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
|
||||||
run: |
|
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
@@ -136,7 +134,6 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
dockerfile: "Dockerfile-uv.jinja"
|
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -157,7 +154,7 @@ jobs:
|
|||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
env:
|
env:
|
||||||
|
|||||||
85
.github/workflows/tests.yml
vendored
85
.github/workflows/tests.yml
vendored
@@ -6,21 +6,19 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- "**.py"
|
||||||
- 'requirements.txt'
|
- "pyproject.toml"
|
||||||
- '.github/workflows/*.yml'
|
- ".github/workflows/*.yml"
|
||||||
- 'requirements-tests.txt'
|
- "cicd/cicd.sh"
|
||||||
- 'cicd/cicd.sh'
|
- "cicd/Dockerfile-uv.jinja"
|
||||||
- 'cicd/Dockerfile.jinja'
|
|
||||||
pull_request:
|
pull_request:
|
||||||
types: [opened, synchronize, reopened, ready_for_review]
|
types: [opened, synchronize, reopened, ready_for_review]
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- "**.py"
|
||||||
- 'requirements.txt'
|
- "pyproject.toml"
|
||||||
- '.github/workflows/*.yml'
|
- ".github/workflows/*.yml"
|
||||||
- 'requirements-tests.txt'
|
- "cicd/cicd.sh"
|
||||||
- 'cicd/cicd.sh'
|
- "cicd/Dockerfile-uv.jinja"
|
||||||
- 'cicd/Dockerfile.jinja'
|
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
# Cancel jobs on the same ref if a new one is triggered
|
# Cancel jobs on the same ref if a new one is triggered
|
||||||
@@ -33,6 +31,7 @@ permissions:
|
|||||||
|
|
||||||
env:
|
env:
|
||||||
TRANSFORMERS_IS_CI: "yes"
|
TRANSFORMERS_IS_CI: "yes"
|
||||||
|
UV_SYSTEM_PYTHON: "1"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pre-commit:
|
pre-commit:
|
||||||
@@ -44,7 +43,7 @@ jobs:
|
|||||||
- uses: actions/setup-python@v5
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.11"
|
python-version: "3.11"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: "pip" # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.1
|
- uses: pre-commit/action@v3.0.1
|
||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
@@ -94,32 +93,25 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
|
||||||
|
|
||||||
- name: upgrade pip
|
- name: Install uv
|
||||||
run: |
|
uses: astral-sh/setup-uv@v7
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 wheel
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||||
|
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
uv pip install --no-build-isolation -e . --override /tmp/torch-pin.txt
|
||||||
pip3 install --no-cache-dir --no-build-isolation -U -e .
|
python scripts/cutcrossentropy_install.py --uv | sh
|
||||||
python scripts/unsloth_install.py | sh
|
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
|
||||||
|
|
||||||
- name: cleanup pip cache
|
|
||||||
run: |
|
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
run: |
|
run: |
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
- name: Ensure axolotl CLI was installed
|
||||||
run: |
|
run: |
|
||||||
@@ -188,33 +180,27 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
|
||||||
|
|
||||||
- name: upgrade pip
|
- name: Install uv
|
||||||
run: |
|
uses: astral-sh/setup-uv@v7
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging==26.0 setuptools==75.8.0 setuptools_scm build wheel psutil
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install --no-cache-dir torch==${{ matrix.pytorch_version }} torchvision
|
uv pip install torch==${{ matrix.pytorch_version }} torchvision
|
||||||
|
uv pip freeze | grep -E "^(torch|torchvision)==" > /tmp/torch-pin.txt
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
uv pip install packaging setuptools_scm build wheel psutil
|
||||||
python -m build --no-isolation --sdist
|
python -m build --no-isolation --sdist
|
||||||
pip3 install --no-cache-dir --no-build-isolation dist/axolotl*.tar.gz
|
uv pip install --no-build-isolation dist/axolotl*.tar.gz --override /tmp/torch-pin.txt
|
||||||
python scripts/unsloth_install.py | sh
|
python scripts/cutcrossentropy_install.py --uv | sh
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
|
|
||||||
- name: cleanup pip cache
|
|
||||||
run: |
|
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
- name: Make sure PyTorch version wasn't clobbered
|
||||||
run: |
|
run: |
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__, f'Expected torch ${{ matrix.pytorch_version }} but got {torch.__version__}'"
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
- name: Ensure axolotl CLI was installed
|
||||||
run: |
|
run: |
|
||||||
@@ -291,7 +277,6 @@ jobs:
|
|||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
dockerfile: "Dockerfile-uv.jinja"
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -312,7 +297,7 @@ jobs:
|
|||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
env:
|
env:
|
||||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||||
@@ -374,7 +359,7 @@ jobs:
|
|||||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
|
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
|
||||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile-uv.jinja'}}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
env:
|
env:
|
||||||
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ axolotl config-schema # Dump config JSON schema
|
|||||||
| Method | Config Key | When to Use |
|
| Method | Config Key | When to Use |
|
||||||
|--------|-----------|-------------|
|
|--------|-----------|-------------|
|
||||||
| SFT | *(default)* | Input-output pairs, instruction tuning |
|
| SFT | *(default)* | Input-output pairs, instruction tuning |
|
||||||
| DPO/IPO | `rl: dpo` / `rl: ipo` | Paired preference data (chosen vs rejected) |
|
| DPO/IPO | `rl: dpo` / `rl: dpo, dpo_loss_type: ["ipo"]` | Paired preference data (chosen vs rejected) |
|
||||||
| KTO | `rl: kto` | Unpaired binary preference labels |
|
| KTO | `rl: kto` | Unpaired binary preference labels |
|
||||||
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
|
| ORPO | `rl: orpo` | Single-stage alignment, no ref model |
|
||||||
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
|
| GRPO | `rl: grpo` | RL with verifiable reward functions (math, code) |
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
include requirements.txt
|
|
||||||
include README.md
|
include README.md
|
||||||
include LICENSE
|
include LICENSE
|
||||||
include src/setuptools_axolotl_dynamic_dependencies.py
|
include VERSION
|
||||||
include src/axolotl/utils/chat_templates/templates/*.jinja
|
include src/axolotl/utils/chat_templates/templates/*.jinja
|
||||||
include AGENTS.md
|
include AGENTS.md
|
||||||
recursive-include docs/agents *.md
|
recursive-include docs/agents *.md
|
||||||
|
|||||||
26
README.md
26
README.md
@@ -95,14 +95,11 @@ Features:
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
#### Using uv (recommended)
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# install uv if you don't already have it installed
|
# install uv if you don't already have it installed (restart shell after)
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
source $HOME/.local/bin/env
|
|
||||||
|
|
||||||
# CUDA 12.8.1 tends to have better package compatibility
|
# change depending on system
|
||||||
export UV_TORCH_BACKEND=cu128
|
export UV_TORCH_BACKEND=cu128
|
||||||
|
|
||||||
# create a new virtual environment
|
# create a new virtual environment
|
||||||
@@ -112,23 +109,6 @@ source .venv/bin/activate
|
|||||||
uv pip install torch==2.10.0 torchvision
|
uv pip install torch==2.10.0 torchvision
|
||||||
uv pip install --no-build-isolation axolotl[deepspeed]
|
uv pip install --no-build-isolation axolotl[deepspeed]
|
||||||
|
|
||||||
# recommended - install cut-cross-entropy
|
|
||||||
uv pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@main"
|
|
||||||
|
|
||||||
# (optional) - prefetch flash-attn2 and causal-conv1d kernels
|
|
||||||
uv run --python 3.12 python -c "from kernels import get_kernel; get_kernel('kernels-community/flash-attn2'); get_kernel('kernels-community/causal-conv1d')"
|
|
||||||
|
|
||||||
# Download example axolotl configs, deepspeed configs
|
|
||||||
axolotl fetch examples
|
|
||||||
axolotl fetch deepspeed_configs # OPTIONAL
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Using pip
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip3 install -U packaging==26.0 setuptools==75.8.0 wheel ninja
|
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
|
||||||
|
|
||||||
# Download example axolotl configs, deepspeed configs
|
# Download example axolotl configs, deepspeed configs
|
||||||
axolotl fetch examples
|
axolotl fetch examples
|
||||||
axolotl fetch deepspeed_configs # OPTIONAL
|
axolotl fetch deepspeed_configs # OPTIONAL
|
||||||
@@ -138,7 +118,7 @@ axolotl fetch deepspeed_configs # OPTIONAL
|
|||||||
|
|
||||||
Installing with Docker can be less error prone than installing in your own environment.
|
Installing with Docker can be less error prone than installing in your own environment.
|
||||||
```bash
|
```bash
|
||||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
docker run --gpus '"all"' --ipc=host --rm -it axolotlai/axolotl:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|||||||
@@ -134,7 +134,6 @@ quartodoc:
|
|||||||
- monkeypatch.stablelm_attn_hijack_flash
|
- monkeypatch.stablelm_attn_hijack_flash
|
||||||
- monkeypatch.trainer_fsdp_optim
|
- monkeypatch.trainer_fsdp_optim
|
||||||
- monkeypatch.transformers_fa_utils
|
- monkeypatch.transformers_fa_utils
|
||||||
- monkeypatch.unsloth_
|
|
||||||
- monkeypatch.data.batch_dataset_fetcher
|
- monkeypatch.data.batch_dataset_fetcher
|
||||||
- monkeypatch.mixtral
|
- monkeypatch.mixtral
|
||||||
- monkeypatch.gradient_checkpointing.offload_cpu
|
- monkeypatch.gradient_checkpointing.offload_cpu
|
||||||
@@ -327,7 +326,6 @@ website:
|
|||||||
- section: "Advanced Features"
|
- section: "Advanced Features"
|
||||||
contents:
|
contents:
|
||||||
- docs/fsdp_qlora.qmd
|
- docs/fsdp_qlora.qmd
|
||||||
- docs/unsloth.qmd
|
|
||||||
- docs/torchao.qmd
|
- docs/torchao.qmd
|
||||||
- docs/custom_integrations.qmd
|
- docs/custom_integrations.qmd
|
||||||
- docs/sequence_parallelism.qmd
|
- docs/sequence_parallelism.qmd
|
||||||
|
|||||||
@@ -22,15 +22,6 @@ WORKDIR /workspace/axolotl
|
|||||||
RUN git fetch origin +$GITHUB_REF && \
|
RUN git fetch origin +$GITHUB_REF && \
|
||||||
git checkout FETCH_HEAD
|
git checkout FETCH_HEAD
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
|
||||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|
||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
RUN uv pip install packaging==26.0 setuptools==78.1.1
|
RUN uv pip install packaging==26.0 setuptools==78.1.1
|
||||||
RUN uv pip install torchvision
|
RUN uv pip install torchvision
|
||||||
RUN uv pip uninstall causal_conv1d
|
RUN uv pip uninstall causal_conv1d
|
||||||
@@ -40,11 +31,21 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|||||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py --uv | sh
|
# Override with nightly HF packages for nightly builds
|
||||||
|
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||||
|
uv pip install --no-deps \
|
||||||
|
"transformers @ git+https://github.com/huggingface/transformers.git@main" \
|
||||||
|
"peft @ git+https://github.com/huggingface/peft.git@main" \
|
||||||
|
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
|
||||||
|
"trl @ git+https://github.com/huggingface/trl.git@main" \
|
||||||
|
"datasets @ git+https://github.com/huggingface/datasets.git@main"; \
|
||||||
|
fi
|
||||||
|
|
||||||
RUN python scripts/cutcrossentropy_install.py --uv | sh
|
RUN python scripts/cutcrossentropy_install.py --uv | sh
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
|
RUN uv pip install black mypy pre-commit types-requests quartodoc jupyter blobfile tiktoken \
|
||||||
|
codecov codecov-cli pytest pytest-cov pytest-retry pytest-sugar pytest-xdist tbparse
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
|
|||||||
@@ -1,54 +0,0 @@
|
|||||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
|
||||||
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
|
||||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
|
||||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
|
||||||
ENV CUDA="{{ CUDA }}"
|
|
||||||
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
|
||||||
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
|
||||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
|
||||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
|
||||||
ENV HF_HOME="{{ HF_HOME }}"
|
|
||||||
ENV AXOLOTL_DATASET_NUM_PROC="8"
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
|
||||||
apt-get install -y --allow-change-held-packages vim curl nano zstd libnccl2 libnccl-dev ibverbs-providers ibverbs-utils infiniband-diags librdmacm-dev librdmacm1 rdmacm-utils slurm-wlm
|
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
|
|
||||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
|
||||||
|
|
||||||
WORKDIR /workspace/axolotl
|
|
||||||
|
|
||||||
RUN git fetch origin +$GITHUB_REF && \
|
|
||||||
git checkout FETCH_HEAD
|
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
|
||||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|
||||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
|
||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
RUN pip install packaging==26.0 setuptools==78.1.1 psutil
|
|
||||||
RUN pip uninstall -y causal_conv1d
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
|
||||||
else \
|
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
|
||||||
RUN python scripts/cutcrossentropy_install.py | sh
|
|
||||||
|
|
||||||
# So we can test the Docker image
|
|
||||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
|
||||||
git config --get remote.origin.fetch
|
|
||||||
|
|
||||||
# helper for huggingface-login cli
|
|
||||||
RUN git config --global credential.helper store
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__, f'Expected torch $PYTORCH_VERSION but got {torch.__version__}'"
|
||||||
|
|
||||||
set -o pipefail
|
set -o pipefail
|
||||||
for i in 1 2 3; do
|
for i in 1 2 3; do
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
|||||||
template_env = jinja2.Environment(
|
template_env = jinja2.Environment(
|
||||||
loader=template_loader, autoescape=select_autoescape()
|
loader=template_loader, autoescape=select_autoescape()
|
||||||
)
|
)
|
||||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
|
||||||
df_template = template_env.get_template(dockerfile)
|
df_template = template_env.get_template(dockerfile)
|
||||||
|
|
||||||
df_args = {
|
df_args = {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
|||||||
template_env = jinja2.Environment(
|
template_env = jinja2.Environment(
|
||||||
loader=template_loader, autoescape=select_autoescape()
|
loader=template_loader, autoescape=select_autoescape()
|
||||||
)
|
)
|
||||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile-uv.jinja")
|
||||||
df_template = template_env.get_template(dockerfile)
|
df_template = template_env.get_template(dockerfile)
|
||||||
|
|
||||||
df_args = {
|
df_args = {
|
||||||
|
|||||||
@@ -24,15 +24,15 @@ WORKDIR /workspace/axolotl
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||||
RUN pip uninstall -y causal_conv1d
|
RUN pip uninstall -y causal_conv1d
|
||||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
BASE_EXTRAS="optimizers,ray"; \
|
||||||
else \
|
else \
|
||||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
BASE_EXTRAS="deepspeed,optimizers,ray"; \
|
||||||
fi && \
|
fi && \
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||||
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
fi && \ python scripts/unsloth_install.py | sh && \
|
fi && \
|
||||||
python scripts/cutcrossentropy_install.py | sh && \
|
python scripts/cutcrossentropy_install.py | sh && \
|
||||||
pip install pytest && \
|
pip install pytest && \
|
||||||
pip cache purge
|
pip cache purge
|
||||||
|
|||||||
@@ -58,19 +58,3 @@ RUN git lfs install --skip-repo && \
|
|||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||||
pip3 cache purge
|
pip3 cache purge
|
||||||
|
|
||||||
# Map Python version (e.g., 3.12 -> cp312)
|
|
||||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
|
||||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
|
||||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
|
||||||
# Map architecture
|
|
||||||
case "$TARGETARCH" in \
|
|
||||||
amd64) ARCH_TAG="x86_64" ;; \
|
|
||||||
arm64) ARCH_TAG="aarch64" ;; \
|
|
||||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
|
||||||
esac && \
|
|
||||||
WHL_VERSION="v0.7.16" && \
|
|
||||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-linux_${ARCH_TAG}.whl" && \
|
|
||||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
|
||||||
pip3 install --no-cache-dir "${WHL_FILE}" && \
|
|
||||||
rm "${WHL_FILE}"
|
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
|||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,mamba-ssm] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
|
|||||||
@@ -24,16 +24,15 @@ WORKDIR /workspace/axolotl
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
# If AXOLOTL_EXTRAS is set, append it in brackets; don't install deepspeed with arm64
|
||||||
RUN uv pip uninstall causal_conv1d
|
RUN uv pip uninstall causal_conv1d
|
||||||
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
RUN if [ "$TARGETARCH" = "arm64" ]; then \
|
||||||
BASE_EXTRAS="flash-attn,ring-flash-attn,optimizers,ray"; \
|
BASE_EXTRAS="optimizers,ray"; \
|
||||||
else \
|
else \
|
||||||
BASE_EXTRAS="deepspeed,flash-attn,ring-flash-attn,optimizers,ray"; \
|
BASE_EXTRAS="deepspeed,optimizers,ray"; \
|
||||||
fi && \
|
fi && \
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
if [ "$AXOLOTL_EXTRAS" != "" ]; then \
|
||||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
uv pip install --no-build-isolation -e .[$BASE_EXTRAS,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
uv pip install --no-build-isolation -e .[$BASE_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
fi && \
|
fi && \
|
||||||
python scripts/unsloth_install.py --uv | sh && \
|
|
||||||
python scripts/cutcrossentropy_install.py --uv | sh && \
|
python scripts/cutcrossentropy_install.py --uv | sh && \
|
||||||
uv pip install pytest && \
|
uv pip install pytest && \
|
||||||
uv cache clean
|
uv cache clean
|
||||||
|
|||||||
@@ -38,20 +38,3 @@ RUN uv pip install packaging setuptools wheel psutil \
|
|||||||
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
RUN if [ "$TARGETARCH" = "amd64" ]; then \
|
||||||
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
|
MAMBA_SKIP_CUDA_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE uv pip install --no-build-isolation mamba_ssm causal_conv1d; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Map Python version (e.g., 3.12 -> cp312)
|
|
||||||
RUN PYTHON_CP="cp$(echo $PYTHON_VERSION | tr -d '.')" && \
|
|
||||||
# Map PyTorch version (e.g., 2.9.1 -> torch2.9, 2.10.0 -> torch2.10)
|
|
||||||
TORCH_TAG="torch$(echo $PYTORCH_VERSION | grep -oP '^\d+\.\d+')" && \
|
|
||||||
LINUX_TAG="manylinux_" && \
|
|
||||||
# Map architecture
|
|
||||||
case "$TARGETARCH" in \
|
|
||||||
amd64) ARCH_TAG="2_24_x86_64.manylinux_2_28_x86_64" ;; \
|
|
||||||
arm64) ARCH_TAG="2_34_aarch64" ;; \
|
|
||||||
*) echo "Unsupported architecture: $TARGETARCH"; exit 1 ;; \
|
|
||||||
esac && \
|
|
||||||
WHL_VERSION="v0.7.16" && \
|
|
||||||
WHL_FILE="flash_attn-2.8.3+cu${CUDA}${TORCH_TAG}-${PYTHON_CP}-${PYTHON_CP}-${LINUX_TAG}${ARCH_TAG}.whl" && \
|
|
||||||
wget -nv "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/${WHL_VERSION}/${WHL_FILE}" && \
|
|
||||||
uv pip install --no-cache-dir "${WHL_FILE}" && \
|
|
||||||
rm "${WHL_FILE}"
|
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ No vLLM server needed (unlike GRPO). Offline RL with pre-collected preference da
|
|||||||
|
|
||||||
1. Paired preference data (chosen + rejected)?
|
1. Paired preference data (chosen + rejected)?
|
||||||
- Default → `rl: dpo`
|
- Default → `rl: dpo`
|
||||||
- Overfitting → `rl: ipo`
|
- Overfitting → `rl: dpo, dpo_loss_type: ["ipo"]`
|
||||||
- VRAM-limited → `rl: orpo` (no ref model)
|
- VRAM-limited → `rl: orpo` (no ref model)
|
||||||
- Length-sensitive → `rl: simpo` (no ref model)
|
- Length-sensitive → `rl: simpo` (no ref model)
|
||||||
2. Only binary labels (good/bad)? → `rl: kto`
|
2. Only binary labels (good/bad)? → `rl: kto`
|
||||||
|
|||||||
@@ -76,8 +76,10 @@ datasets:
|
|||||||
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
|
Make sure you have an [editable install](https://setuptools.pypa.io/en/latest/userguide/development_mode.html) of Axolotl, which ensures that changes you make to the code are reflected at runtime. Run the following commands from the root of this project:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install packaging
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
uv venv --no-project --relocatable
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Remote Hosts
|
#### Remote Hosts
|
||||||
@@ -208,17 +210,18 @@ cd axolotl
|
|||||||
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
|
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl-uv:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
>[!Tip]
|
>[!Tip]
|
||||||
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
||||||
|
|
||||||
You will now be in the container. Next, perform an editable install of Axolotl:
|
You will now be in the container. Next, install Axolotl with dev dependencies:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install packaging
|
uv venv --no-project --relocatable
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
source .venv/bin/activate
|
||||||
|
uv pip install --no-build-isolation -e '.[deepspeed]' --group dev --group test
|
||||||
```
|
```
|
||||||
|
|
||||||
### Attach To Container
|
### Attach To Container
|
||||||
|
|||||||
@@ -6,23 +6,33 @@ format:
|
|||||||
toc-depth: 4
|
toc-depth: 4
|
||||||
---
|
---
|
||||||
|
|
||||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
This section describes the different Docker images that are released by AxolotlAI at
|
||||||
|
[Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
### Switch to the `-uv` images
|
||||||
|
|
||||||
|
Each image below ships a **uv variant** that uses [uv](https://docs.astral.sh/uv/) with a relocatable venv
|
||||||
|
(`/workspace/axolotl-venv`) instead of Miniconda + pip. Append `-uv` to the image name
|
||||||
|
(e.g. `axolotlai/axolotl-uv`, `axolotlai/axolotl-base-uv`, `axolotlai/axolotl-cloud-uv`). Tags follow the
|
||||||
|
same format as their non-uv counterparts.
|
||||||
|
|
||||||
|
**We recommend switching to the `-uv` images early.** In the near future we will publish the uv-based
|
||||||
|
build to the non-uv tags as well. The non-uv names will continue to work, but they will start serving
|
||||||
|
the uv image.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Base
|
## Base
|
||||||
|
|
||||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image.
|
||||||
|
It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||||
|
|
||||||
#### Image
|
#### Image
|
||||||
|
|
||||||
```
|
| Variant | Image | Docker Hub |
|
||||||
axolotlai/axolotl-base
|
|---------|-------|------------|
|
||||||
```
|
| pip | `axolotlai/axolotl-base` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base) |
|
||||||
|
| uv | `axolotlai/axolotl-base-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-base-uv) |
|
||||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
|
|
||||||
|
|
||||||
#### Tags format
|
#### Tags format
|
||||||
|
|
||||||
@@ -32,8 +42,10 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
|||||||
|
|
||||||
Tags examples:
|
Tags examples:
|
||||||
|
|
||||||
- `main-base-py3.11-cu128-2.8.0`
|
|
||||||
- `main-base-py3.11-cu128-2.9.1`
|
- `main-base-py3.11-cu128-2.9.1`
|
||||||
|
- `main-base-py3.12-cu128-2.10.0`
|
||||||
|
- `main-base-py3.12-cu130-2.9.1`
|
||||||
|
- `main-base-py3.12-cu130-2.10.0`
|
||||||
|
|
||||||
## Main
|
## Main
|
||||||
|
|
||||||
@@ -41,11 +53,10 @@ The main image is the image that is used to run Axolotl. It is based on the `axo
|
|||||||
|
|
||||||
#### Image
|
#### Image
|
||||||
|
|
||||||
```
|
| Variant | Image | Docker Hub |
|
||||||
axolotlai/axolotl
|
|---------|-------|------------|
|
||||||
```
|
| pip | `axolotlai/axolotl` | [Link](https://hub.docker.com/r/axolotlai/axolotl) |
|
||||||
|
| uv | `axolotlai/axolotl-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-uv) |
|
||||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
|
||||||
|
|
||||||
#### Tags format {#sec-main-tags}
|
#### Tags format {#sec-main-tags}
|
||||||
|
|
||||||
@@ -53,7 +64,7 @@ Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
|||||||
# on push to main
|
# on push to main
|
||||||
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||||
|
|
||||||
# latest main (currently torch 2.6.0, python 3.11, cuda 12.4)
|
# latest main (currently torch 2.9.1, python 3.11, cuda 12.8)
|
||||||
main-latest
|
main-latest
|
||||||
|
|
||||||
# nightly build
|
# nightly build
|
||||||
@@ -71,12 +82,13 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
|||||||
|
|
||||||
Tags examples:
|
Tags examples:
|
||||||
|
|
||||||
- `main-py3.11-cu128-2.8.0`
|
|
||||||
- `main-py3.11-cu128-2.9.1`
|
- `main-py3.11-cu128-2.9.1`
|
||||||
|
- `main-py3.12-cu128-2.10.0`
|
||||||
|
- `main-py3.12-cu130-2.9.1`
|
||||||
|
- `main-py3.12-cu130-2.10.0`
|
||||||
- `main-latest`
|
- `main-latest`
|
||||||
- `main-20250303-py3.11-cu124-2.6.0`
|
- `main-20260315-py3.11-cu128-2.9.1`
|
||||||
- `main-20250303-py3.11-cu126-2.6.0`
|
- `0.16.1`
|
||||||
- `0.12.0`
|
|
||||||
|
|
||||||
## Cloud
|
## Cloud
|
||||||
|
|
||||||
@@ -90,11 +102,10 @@ Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variab
|
|||||||
|
|
||||||
#### Image
|
#### Image
|
||||||
|
|
||||||
```
|
| Variant | Image | Docker Hub |
|
||||||
axolotlai/axolotl-cloud
|
|---------|-------|------------|
|
||||||
```
|
| pip | `axolotlai/axolotl-cloud` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud) |
|
||||||
|
| uv | `axolotlai/axolotl-cloud-uv` | [Link](https://hub.docker.com/r/axolotlai/axolotl-cloud-uv) |
|
||||||
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
|
|
||||||
|
|
||||||
#### Tags format
|
#### Tags format
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
**Q: vLLM is not working with Axolotl**
|
**Q: vLLM is not working with Axolotl**
|
||||||
|
|
||||||
> A: We currently recommend torch 2.6.0 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.11-cu124-2.6.0` tag.
|
> A: We currently recommend torch 2.10 for use with `vllm`. Please ensure you use the right version. For Docker, please use the `main-py3.12-cu128-2.10.0` tag (note: torch 2.10 images are built with Python 3.12).
|
||||||
|
|
||||||
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
|
**Q: FA2 2.8.0 `undefined symbol` runtime error on CUDA 12.4**
|
||||||
|
|
||||||
|
|||||||
@@ -15,64 +15,30 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
|
|
||||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||||
- Python ≥3.11
|
- Python ≥3.11
|
||||||
- PyTorch ≥2.6.0
|
- PyTorch ≥2.9.1
|
||||||
|
|
||||||
## Installation Methods {#sec-installation-methods}
|
## Installation {#sec-installation}
|
||||||
|
|
||||||
::: {.callout-important}
|
|
||||||
Please make sure to have Pytorch installed before installing Axolotl in your local environment.
|
|
||||||
|
|
||||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
|
||||||
:::
|
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
For Blackwell GPUs, please use Pytorch 2.9.1 and CUDA 12.8.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
### PyPI Installation (Recommended) {#sec-pypi}
|
### Quick Install {#sec-uv}
|
||||||
|
|
||||||
```{.bash}
|
Axolotl uses [uv](https://docs.astral.sh/uv/) as its package manager. uv is a fast, reliable Python package installer and resolver built in Rust.
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
|
||||||
```
|
|
||||||
|
|
||||||
We use `--no-build-isolation` in order to detect the installed PyTorch version (if
|
Install uv if not already installed:
|
||||||
installed) in order not to clobber it, and so that we set the correct version of
|
|
||||||
dependencies that are specific to the PyTorch version or other installed
|
|
||||||
co-dependencies.
|
|
||||||
|
|
||||||
### uv Installation {#sec-uv}
|
|
||||||
|
|
||||||
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
|
|
||||||
|
|
||||||
Install uv if not already installed
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
source $HOME/.local/bin/env
|
source $HOME/.local/bin/env
|
||||||
```
|
```
|
||||||
|
|
||||||
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
|
Choose your CUDA version (e.g. `cu128`, `cu130`), create a venv, and install:
|
||||||
then create the venv and activate
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
export UV_TORCH_BACKEND=cu126
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
uv venv --no-project --relocatable
|
uv venv
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
```
|
uv pip install --no-build-isolation axolotl[deepspeed]
|
||||||
|
|
||||||
Install PyTorch
|
|
||||||
- PyTorch 2.6.0 recommended
|
|
||||||
```{.bash}
|
|
||||||
uv pip install packaging setuptools wheel
|
|
||||||
uv pip install torch==2.6.0
|
|
||||||
uv pip install awscli pydantic
|
|
||||||
```
|
|
||||||
|
|
||||||
Install axolotl from PyPi
|
|
||||||
```{.bash}
|
|
||||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
|
|
||||||
|
|
||||||
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
|
|
||||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Edge/Development Build {#sec-edge-build}
|
### Edge/Development Build {#sec-edge-build}
|
||||||
@@ -82,14 +48,16 @@ For the latest features between releases:
|
|||||||
```{.bash}
|
```{.bash}
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install --no-build-isolation -e '.[deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
### Docker {#sec-docker}
|
### Docker {#sec-docker}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
docker run --gpus '"all"' --rm -it --ipc=host axolotlai/axolotl-uv:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
For development with Docker:
|
For development with Docker:
|
||||||
@@ -106,12 +74,12 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
|||||||
--ulimit memlock=-1 --ulimit stack=67108864 \
|
--ulimit memlock=-1 --ulimit stack=67108864 \
|
||||||
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
|
--mount type=bind,src="${PWD}",target=/workspace/axolotl \
|
||||||
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
|
-v ${HOME}/.cache/huggingface:/root/.cache/huggingface \
|
||||||
axolotlai/axolotl:main-latest
|
axolotlai/axolotl-uv:main-latest
|
||||||
```
|
```
|
||||||
:::
|
:::
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.9.1`.
|
For Blackwell GPUs, please use `axolotlai/axolotl-uv:main-py3.11-cu128-2.9.1` or the cloud variant `axolotlai/axolotl-cloud-uv:main-py3.11-cu128-2.9.1`.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||||
@@ -122,7 +90,7 @@ Please refer to the [Docker documentation](docker.qmd) for more information on t
|
|||||||
|
|
||||||
For providers supporting Docker:
|
For providers supporting Docker:
|
||||||
|
|
||||||
- Use `axolotlai/axolotl-cloud:main-latest`
|
- Use `axolotlai/axolotl-cloud-uv:main-latest`
|
||||||
- Available on:
|
- Available on:
|
||||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||||
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
|
- [Vast.ai](https://cloud.vast.ai?ref_id=62897&template_id=bdd4a49fa8bce926defc99471864cace&utm_source=axolotl&utm_medium=partner&utm_campaign=template_launch_july2025&utm_content=docs_link)
|
||||||
@@ -141,7 +109,7 @@ For providers supporting Docker:
|
|||||||
### macOS {#sec-macos}
|
### macOS {#sec-macos}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
pip3 install --no-build-isolation -e '.'
|
uv pip install --no-build-isolation -e '.'
|
||||||
```
|
```
|
||||||
|
|
||||||
See @sec-troubleshooting for Mac-specific issues.
|
See @sec-troubleshooting for Mac-specific issues.
|
||||||
@@ -152,21 +120,44 @@ See @sec-troubleshooting for Mac-specific issues.
|
|||||||
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Environment Managers {#sec-env-managers}
|
## Migrating from pip to uv {#sec-migrating}
|
||||||
|
|
||||||
### Conda/Pip venv {#sec-conda}
|
If you have an existing pip-based Axolotl installation, you can migrate to uv:
|
||||||
|
|
||||||
1. Install Python ≥3.11
|
```{.bash}
|
||||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
# Install uv
|
||||||
3. Install Axolotl:
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
```{.bash}
|
source $HOME/.local/bin/env
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
# Create a fresh venv (recommended for a clean start)
|
||||||
```
|
export UV_TORCH_BACKEND=cu128 # or cu130
|
||||||
4. (Optional) Login to Hugging Face:
|
uv venv
|
||||||
```{.bash}
|
source .venv/bin/activate
|
||||||
hf auth login
|
|
||||||
```
|
# Reinstall axolotl
|
||||||
|
uv pip install --no-build-isolation axolotl[deepspeed]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using pip (Alternative) {#sec-pip}
|
||||||
|
|
||||||
|
If you are unable to install uv, you can still use pip directly.
|
||||||
|
|
||||||
|
::: {.callout-important}
|
||||||
|
Please make sure to have PyTorch installed before installing Axolotl with pip.
|
||||||
|
|
||||||
|
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||||
|
:::
|
||||||
|
|
||||||
|
```{.bash}
|
||||||
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
|
pip3 install --no-build-isolation axolotl[deepspeed]
|
||||||
|
```
|
||||||
|
|
||||||
|
For editable/development installs:
|
||||||
|
```{.bash}
|
||||||
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
|
pip3 install --no-build-isolation -e '.[deepspeed]'
|
||||||
|
```
|
||||||
|
|
||||||
## Troubleshooting {#sec-troubleshooting}
|
## Troubleshooting {#sec-troubleshooting}
|
||||||
|
|
||||||
|
|||||||
@@ -320,8 +320,10 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: ipo
|
rl: dpo
|
||||||
|
dpo_loss_type: ["ipo"]
|
||||||
```
|
```
|
||||||
|
*Note:* Passing `rl: ipo` directly is still supported, but will soon be deprecated.
|
||||||
|
|
||||||
### ORPO
|
### ORPO
|
||||||
|
|
||||||
|
|||||||
@@ -1,53 +0,0 @@
|
|||||||
---
|
|
||||||
title: "Unsloth"
|
|
||||||
description: "Hyper-optimized QLoRA finetuning for single GPUs"
|
|
||||||
---
|
|
||||||
|
|
||||||
### Overview
|
|
||||||
|
|
||||||
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
|
|
||||||
standard industry baselines.
|
|
||||||
|
|
||||||
::: {.callout-important}
|
|
||||||
Due to breaking changes in transformers `v4.48.0`, users will need to downgrade to `<=v4.47.1` to use this patch.
|
|
||||||
|
|
||||||
This will later be deprecated in favor of [LoRA Optimizations](lora_optims.qmd).
|
|
||||||
:::
|
|
||||||
|
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
The following will install the correct unsloth and extras from source.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/unsloth_install.py | sh
|
|
||||||
```
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
|
|
||||||
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
|
|
||||||
|
|
||||||
Our unsloth integration is currently limited to the following model architectures:
|
|
||||||
- llama
|
|
||||||
|
|
||||||
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
|
|
||||||
```yaml
|
|
||||||
unsloth_lora_mlp: true
|
|
||||||
unsloth_lora_qkv: true
|
|
||||||
unsloth_lora_o: true
|
|
||||||
```
|
|
||||||
|
|
||||||
These options are composable and can be used with multi-gpu finetuning
|
|
||||||
```yaml
|
|
||||||
unsloth_cross_entropy_loss: true
|
|
||||||
unsloth_rms_norm: true
|
|
||||||
unsloth_rope: true
|
|
||||||
```
|
|
||||||
|
|
||||||
### Limitations
|
|
||||||
|
|
||||||
- Single GPU only; e.g. no multi-gpu support
|
|
||||||
- No deepspeed or FSDP support (requires multi-gpu)
|
|
||||||
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
|
|
||||||
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
|
|
||||||
- No MoE support.
|
|
||||||
@@ -15,8 +15,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have a compatible version of Pytorch installed
|
# Ensure you have a compatible version of Pytorch installed
|
||||||
pip3 install packaging setuptools wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Run one of the finetuning examples below.
|
2. Run one of the finetuning examples below.
|
||||||
@@ -35,7 +34,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
|||||||
|
|
||||||
**LFM2-MoE**
|
**LFM2-MoE**
|
||||||
```bash
|
```bash
|
||||||
pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
uv pip install git+https://github.com/huggingface/transformers.git@0c9a72e4576fe4c84077f066e585129c97bfd4e6
|
||||||
|
|
||||||
# LoRA SFT (1x48GB @ 16.2GiB)
|
# LoRA SFT (1x48GB @ 16.2GiB)
|
||||||
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
|
axolotl train examples/LiquidAI/lfm2-8b-a1b-lora.yaml
|
||||||
@@ -45,7 +44,7 @@ Thanks to the team at LiquidAI for giving us early access to prepare for these r
|
|||||||
|
|
||||||
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
- **Installation Error**: If you encounter `ImportError: ... undefined symbol ...` or `ModuleNotFoundError: No module named 'causal_conv1d_cuda'`, the `causal-conv1d` package may have been installed incorrectly. Try uninstalling it:
|
||||||
```bash
|
```bash
|
||||||
pip uninstall -y causal-conv1d
|
uv pip uninstall causal-conv1d
|
||||||
```
|
```
|
||||||
|
|
||||||
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
- **Dataset Loading**: Read more on how to load your own dataset in our [documentation](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||||
|
|||||||
@@ -11,12 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation -e '.'
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
@@ -31,7 +30,7 @@ python scripts/cutcrossentropy_install.py | sh
|
|||||||
# For those using our Docker image, use the below path.
|
# For those using our Docker image, use the below path.
|
||||||
export CUDA_HOME=/usr/local/cuda
|
export CUDA_HOME=/usr/local/cuda
|
||||||
|
|
||||||
pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||||
```
|
```
|
||||||
|
|
||||||
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
For any installation errors, see [XIELU Installation Issues](#xielu-installation-issues)
|
||||||
@@ -67,7 +66,7 @@ If those didn't help, please try the below solutions:
|
|||||||
1. Pass env for CMAKE and try install again:
|
1. Pass env for CMAKE and try install again:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
Python_EXECUTABLE=$(which python) pip3 install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
Python_EXECUTABLE=$(which python) uv pip install git+https://github.com/nickjbrowning/XIELU@59d6031 --no-build-isolation --no-deps
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Git clone the repo and manually hardcode python path:
|
2. Git clone the repo and manually hardcode python path:
|
||||||
@@ -92,7 +91,7 @@ If those didn't help, please try the below solutions:
|
|||||||
```
|
```
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install . --no-build-isolation --no-deps
|
uv pip install . --no-build-isolation --no-deps
|
||||||
```
|
```
|
||||||
|
|
||||||
## Optimization Guides
|
## Optimization Guides
|
||||||
|
|||||||
@@ -13,12 +13,11 @@ Thanks to the team at Arcee.ai for using Axolotl in supervised fine-tuning the A
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation -e '.'
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -36,12 +36,7 @@
|
|||||||
"id": "msOCO4NRmRLa"
|
"id": "msOCO4NRmRLa"
|
||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": "%%capture\n# This step can take ~5-10 minutes to install dependencies\n!pip install --no-build-isolation \"axolotl>=0.16.1\"\n!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
|
||||||
"%%capture\n",
|
|
||||||
"# This step can take ~5-10 minutes to install dependencies\n",
|
|
||||||
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
|
|
||||||
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@fec1a88\""
|
|
||||||
]
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
|
|||||||
@@ -15,9 +15,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||||
|
|||||||
@@ -9,18 +9,17 @@ Gemma-3n is a family of multimodal models from Google found on [HuggingFace](htt
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
2. In addition to Axolotl's requirements, Gemma-3n requires:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install timm==1.0.17
|
uv pip install timm==1.0.17
|
||||||
|
|
||||||
# for loading audio data
|
# for loading audio data
|
||||||
pip3 install librosa==0.11.0
|
uv pip install librosa==0.11.0
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Download sample dataset files
|
3. Download sample dataset files
|
||||||
|
|||||||
@@ -13,9 +13,8 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
|
2. Choose one of the following configs below for training the 20B model. (for 120B, see [below](#training-120b))
|
||||||
@@ -87,7 +86,7 @@ for more information about using a special vllm-openai docker image for inferenc
|
|||||||
Optionally, vLLM can be installed from nightly:
|
Optionally, vLLM can be installed from nightly:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
|
uv pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
|
||||||
```
|
```
|
||||||
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
|
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -11,12 +11,11 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.7.1 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation -e '.'
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -9,12 +9,11 @@ Tencent released a family of opensource models called HunYuan with varying param
|
|||||||
Here is an example of how to install from main for pip:
|
Here is an example of how to install from main for pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation -e '.'
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn]'
|
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ This guide shows how to fine-tune it with Axolotl.
|
|||||||
2. Install `timm` for vision model support:
|
2. Install `timm` for vision model support:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install timm==1.0.19
|
uv pip install timm==1.0.19
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
3. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage.
|
||||||
|
|||||||
@@ -13,9 +13,8 @@ Thanks to the team at MistralAI for giving us early access to prepare for these
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.7.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
2. Install [Cut Cross Entropy](https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy) to reduce training VRAM usage
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
|||||||
|
|
||||||
1. Install the required vision lib:
|
1. Install the required vision lib:
|
||||||
```bash
|
```bash
|
||||||
pip install 'mistral-common[opencv]==1.8.5'
|
uv pip install 'mistral-common[opencv]==1.8.5'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Download the example dataset image:
|
2. Download the example dataset image:
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ Note: This is still experimental given it is based on transformers v5 RC.
|
|||||||
git checkout transformers-v5
|
git checkout transformers-v5
|
||||||
|
|
||||||
# Install packages for transformers v5
|
# Install packages for transformers v5
|
||||||
pip install -e .
|
uv pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Run the fine-tuning:
|
4. Run the fine-tuning:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
|||||||
|
|
||||||
1. Install the required vision lib:
|
1. Install the required vision lib:
|
||||||
```bash
|
```bash
|
||||||
pip install 'mistral-common[opencv]==1.8.6'
|
uv pip install 'mistral-common[opencv]==1.8.6'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Download the example dataset image:
|
2. Download the example dataset image:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ Before starting, ensure you have:
|
|||||||
|
|
||||||
1. Install the required vision lib:
|
1. Install the required vision lib:
|
||||||
```bash
|
```bash
|
||||||
pip install 'mistral-common[opencv]==1.8.5'
|
uv pip install 'mistral-common[opencv]==1.8.5'
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Download the example dataset image:
|
2. Download the example dataset image:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
3. Install transformers from main
|
3. Install transformers from main
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install git+https://github.com/huggingface/transformers.git
|
uv pip install git+https://github.com/huggingface/transformers.git
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Run one of the example configs:
|
4. Run one of the example configs:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
|
|
||||||
3. Install FLA for improved performance
|
3. Install FLA for improved performance
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Run the finetuning example:
|
4. Run the finetuning example:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
3. Install FLA for sample packing support with the Gated DeltaNet linear attention layers:
|
||||||
```bash
|
```bash
|
||||||
pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1
|
uv pip uninstall causal-conv1d && uv pip install flash-linear-attention==0.4.1
|
||||||
```
|
```
|
||||||
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
> FLA is required when `sample_packing: true`. Without it, training raises a `RuntimeError` on packed sequences. Vision configs use `sample_packing: false` so FLA is optional there.
|
||||||
|
|
||||||
|
|||||||
@@ -11,8 +11,7 @@ This guide shows how to fine-tune it with Axolotl with multi-turn conversations
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have a compatible version of Pytorch installed
|
# Ensure you have a compatible version of Pytorch installed
|
||||||
pip3 install packaging setuptools wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
|
|
||||||
# Install Cut Cross Entropy
|
# Install Cut Cross Entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
@@ -13,14 +13,13 @@ This guide shows how to fine-tune SmolVLM2 models with Axolotl.
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have a compatible version of Pytorch installed
|
# Ensure you have a compatible version of Pytorch installed
|
||||||
pip3 install packaging setuptools wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Install an extra dependency:
|
2. Install an extra dependency:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install num2words==0.5.14
|
uv pip install num2words==0.5.14
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Run the finetuning example:
|
3. Run the finetuning example:
|
||||||
|
|||||||
@@ -11,17 +11,16 @@ Thanks to the team at MistralAI for giving us early access to prepare for this r
|
|||||||
Here is an example of how to install from pip:
|
Here is an example of how to install from pip:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 min)
|
# Ensure you have Pytorch installed (Pytorch 2.9.1 min)
|
||||||
pip3 install packaging==26.0 setuptools==75.8.0 wheel ninja
|
uv pip install --no-build-isolation 'axolotl>=0.16.1'
|
||||||
pip3 install --no-build-isolation 'axolotl[flash-attn]>=0.12.0'
|
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Please install the below.
|
2. Please install the below.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# audio
|
# audio
|
||||||
pip3 install librosa==0.11.0
|
uv pip install librosa==0.11.0
|
||||||
pip3 install 'mistral_common[audio]==1.8.3'
|
uv pip install 'mistral_common[audio]==1.8.3'
|
||||||
|
|
||||||
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
# Install CCE https://docs.axolotl.ai/docs/custom_integrations.html#cut-cross-entropy
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
203
pyproject.toml
203
pyproject.toml
@@ -1,15 +1,165 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==26.0"]
|
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "axolotl"
|
name = "axolotl"
|
||||||
dynamic = ["version", "dependencies", "optional-dependencies"]
|
dynamic = ["version"]
|
||||||
description = "LLM Trainer"
|
description = "LLM Trainer"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
# license = "Apache-2.0"
|
# license = "Apache-2.0"
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
# Core ML stack
|
||||||
|
"torch>=2.9.1",
|
||||||
|
"packaging==26.0",
|
||||||
|
"huggingface_hub>=1.1.7",
|
||||||
|
"peft>=0.19.1,<0.20.0",
|
||||||
|
"tokenizers>=0.22.1",
|
||||||
|
"transformers==5.5.4",
|
||||||
|
"accelerate==1.13.0",
|
||||||
|
"datasets>=4.8.4,<4.9.0",
|
||||||
|
"trl==1.1.0",
|
||||||
|
"hf_xet==1.4.3",
|
||||||
|
"kernels==0.13.0",
|
||||||
|
"trackio>=0.16.1",
|
||||||
|
"typing-extensions>=4.15.0",
|
||||||
|
"optimum==1.16.2",
|
||||||
|
"hf_transfer",
|
||||||
|
"sentencepiece",
|
||||||
|
"gradio>=6.2.0,<7.0",
|
||||||
|
"modal==1.3.0.post1",
|
||||||
|
"pydantic>=2.10.6",
|
||||||
|
"addict",
|
||||||
|
"fire",
|
||||||
|
"PyYAML>=6.0",
|
||||||
|
"requests",
|
||||||
|
"wandb",
|
||||||
|
"einops",
|
||||||
|
"colorama",
|
||||||
|
"numba>=0.61.2",
|
||||||
|
"numpy>=2.2.6",
|
||||||
|
|
||||||
|
# Evaluation & metrics
|
||||||
|
"evaluate==0.4.1",
|
||||||
|
"scipy",
|
||||||
|
"nvidia-ml-py==12.560.30",
|
||||||
|
"art",
|
||||||
|
"tensorboard",
|
||||||
|
"python-dotenv==1.0.1",
|
||||||
|
|
||||||
|
# Remote filesystems
|
||||||
|
"s3fs>=2024.5.0",
|
||||||
|
"gcsfs>=2025.3.0",
|
||||||
|
"adlfs>=2024.5.0",
|
||||||
|
"ocifs==1.3.2",
|
||||||
|
|
||||||
|
"zstandard==0.22.0",
|
||||||
|
"fastcore",
|
||||||
|
|
||||||
|
# lm eval harness
|
||||||
|
"lm_eval==0.4.11",
|
||||||
|
"langdetect==1.0.9",
|
||||||
|
"immutabledict==4.2.0",
|
||||||
|
"antlr4-python3-runtime==4.13.2",
|
||||||
|
|
||||||
|
"schedulefree==1.4.1",
|
||||||
|
"openenv-core==0.1.0",
|
||||||
|
|
||||||
|
# Axolotl contribs
|
||||||
|
"axolotl-contribs-lgpl==0.0.7",
|
||||||
|
"axolotl-contribs-mit==0.0.6",
|
||||||
|
|
||||||
|
# Telemetry
|
||||||
|
"posthog==6.7.11",
|
||||||
|
|
||||||
|
"mistral-common==1.11.0",
|
||||||
|
|
||||||
|
# Platform-specific (Linux only)
|
||||||
|
"bitsandbytes==0.49.1 ; sys_platform != 'darwin'",
|
||||||
|
"triton>=3.4.0 ; sys_platform != 'darwin'",
|
||||||
|
"xformers>=0.0.33.post2 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
|
||||||
|
"liger-kernel==0.7.0 ; sys_platform != 'darwin'",
|
||||||
|
"torchao==0.17.0 ; sys_platform != 'darwin' and platform_machine != 'aarch64'",
|
||||||
|
|
||||||
|
# Architecture-specific
|
||||||
|
"fla-core==0.4.1 ; platform_machine != 'aarch64'",
|
||||||
|
"flash-linear-attention==0.4.1 ; platform_machine != 'aarch64'",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
flash-attn = ["flash-attn==2.8.3"]
|
||||||
|
ring-flash-attn = [
|
||||||
|
"flash-attn==2.8.3",
|
||||||
|
"ring-flash-attn>=0.1.7",
|
||||||
|
]
|
||||||
|
deepspeed = [
|
||||||
|
"deepspeed>=0.18.6,<0.19.0",
|
||||||
|
"deepspeed-kernels",
|
||||||
|
]
|
||||||
|
mamba-ssm = [
|
||||||
|
"mamba-ssm==1.2.0.post1",
|
||||||
|
"causal_conv1d",
|
||||||
|
]
|
||||||
|
auto-gptq = [
|
||||||
|
"auto-gptq==0.5.1",
|
||||||
|
]
|
||||||
|
mlflow = [
|
||||||
|
"mlflow",
|
||||||
|
]
|
||||||
|
galore = [
|
||||||
|
"galore_torch",
|
||||||
|
]
|
||||||
|
apollo = [
|
||||||
|
"apollo-torch",
|
||||||
|
]
|
||||||
|
optimizers = [
|
||||||
|
"galore_torch",
|
||||||
|
"apollo-torch",
|
||||||
|
"lomo-optim==0.1.1",
|
||||||
|
"torch-optimi==0.2.1",
|
||||||
|
"came_pytorch==0.1.3",
|
||||||
|
]
|
||||||
|
ray = [
|
||||||
|
"ray[train]>=2.52.1",
|
||||||
|
]
|
||||||
|
vllm = [
|
||||||
|
"vllm>=0.15.0",
|
||||||
|
]
|
||||||
|
llmcompressor = [
|
||||||
|
"llmcompressor>=0.10.0",
|
||||||
|
]
|
||||||
|
fbgemm-gpu = ["fbgemm-gpu-genai>=1.3.0"]
|
||||||
|
opentelemetry = [
|
||||||
|
"opentelemetry-api",
|
||||||
|
"opentelemetry-sdk",
|
||||||
|
"opentelemetry-exporter-prometheus",
|
||||||
|
"prometheus-client",
|
||||||
|
]
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"black",
|
||||||
|
"mypy",
|
||||||
|
"pre-commit",
|
||||||
|
"types-requests",
|
||||||
|
"quartodoc",
|
||||||
|
"jupyter",
|
||||||
|
"blobfile",
|
||||||
|
"tiktoken",
|
||||||
|
]
|
||||||
|
test = [
|
||||||
|
"codecov",
|
||||||
|
"codecov-cli",
|
||||||
|
"pytest",
|
||||||
|
"pytest-cov",
|
||||||
|
"pytest-retry",
|
||||||
|
"pytest-sugar",
|
||||||
|
"pytest-xdist",
|
||||||
|
"tbparse",
|
||||||
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
axolotl = "axolotl.cli.main:main"
|
axolotl = "axolotl.cli.main:main"
|
||||||
|
|
||||||
@@ -18,18 +168,15 @@ Homepage = "https://axolotl.ai/"
|
|||||||
Documentation = "https://docs.axolotl.ai/"
|
Documentation = "https://docs.axolotl.ai/"
|
||||||
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"
|
||||||
|
|
||||||
[tool.setuptools_scm]
|
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
py-modules = ["setuptools_axolotl_dynamic_dependencies"]
|
|
||||||
include-package-data = true
|
include-package-data = true
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
where = ["src"]
|
||||||
|
|
||||||
[tool.setuptools.dynamic]
|
[tool.setuptools.dynamic]
|
||||||
version = { file = "VERSION" }
|
version = { file = "VERSION" }
|
||||||
|
|
||||||
[tool.setuptools.cmdclass]
|
|
||||||
build_py = "setuptools_axolotl_dynamic_dependencies.BuildPyCommand"
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 88
|
line-length = 88
|
||||||
target-version = "py310"
|
target-version = "py310"
|
||||||
@@ -67,5 +214,43 @@ markers = [
|
|||||||
"slow: marks tests as slow",
|
"slow: marks tests as slow",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# UV specific configuration
|
||||||
|
[tool.uv]
|
||||||
|
prerelease = "allow"
|
||||||
|
conflicts = [
|
||||||
|
[
|
||||||
|
{ package = "axolotl" },
|
||||||
|
{ extra = "vllm" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ package = "axolotl" },
|
||||||
|
{ extra = "flash-attn" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ package = "axolotl" },
|
||||||
|
{ extra = "ring-flash-attn" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ package = "axolotl" },
|
||||||
|
{ extra = "mamba-ssm" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ package = "axolotl" },
|
||||||
|
{ extra = "auto-gptq" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ package = "axolotl" },
|
||||||
|
{ extra = "fbgemm-gpu" },
|
||||||
|
],
|
||||||
|
[
|
||||||
|
{ package = "axolotl" },
|
||||||
|
{ extra = "llmcompressor" },
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv.extra-build-dependencies]
|
[tool.uv.extra-build-dependencies]
|
||||||
axolotl = ["huggingface_hub"]
|
mamba-ssm = [{ requirement = "torch", match-runtime = true }]
|
||||||
|
causal-conv1d = [{ requirement = "torch", match-runtime = true }]
|
||||||
|
flash-attn = [{ requirement = "torch", match-runtime = true }]
|
||||||
|
deepspeed = [{ requirement = "torch", match-runtime = true }]
|
||||||
|
auto-gptq = [{ requirement = "torch", match-runtime = true }]
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
black
|
|
||||||
mypy
|
|
||||||
pre-commit
|
|
||||||
types-requests
|
|
||||||
quartodoc
|
|
||||||
jupyter
|
|
||||||
blobfile
|
|
||||||
tiktoken
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
codecov
|
|
||||||
codecov-cli
|
|
||||||
pytest
|
|
||||||
pytest-cov
|
|
||||||
pytest-retry
|
|
||||||
pytest-sugar
|
|
||||||
pytest-xdist
|
|
||||||
tbparse
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
|
||||||
bitsandbytes==0.49.1
|
|
||||||
triton>=3.4.0
|
|
||||||
mamba-ssm==1.2.0.post1
|
|
||||||
xformers>=0.0.23.post1
|
|
||||||
liger-kernel==0.7.0
|
|
||||||
# END section
|
|
||||||
|
|
||||||
packaging==26.0
|
|
||||||
huggingface_hub>=1.1.7
|
|
||||||
peft>=0.19.0,<0.20.0
|
|
||||||
tokenizers>=0.22.1
|
|
||||||
transformers==5.5.4
|
|
||||||
accelerate==1.13.0
|
|
||||||
datasets>=4.8.4,<4.9.0
|
|
||||||
deepspeed>=0.18.6,<0.19.0
|
|
||||||
trl==1.1.0
|
|
||||||
hf_xet==1.4.3
|
|
||||||
kernels==0.13.0
|
|
||||||
|
|
||||||
fla-core==0.4.1
|
|
||||||
flash-linear-attention==0.4.1
|
|
||||||
|
|
||||||
trackio>=0.16.1
|
|
||||||
typing-extensions>=4.15.0
|
|
||||||
|
|
||||||
optimum==1.16.2
|
|
||||||
hf_transfer
|
|
||||||
sentencepiece
|
|
||||||
gradio>=6.2.0,<7.0
|
|
||||||
|
|
||||||
modal==1.3.0.post1
|
|
||||||
pydantic>=2.10.6
|
|
||||||
addict
|
|
||||||
fire
|
|
||||||
PyYAML>=6.0
|
|
||||||
requests
|
|
||||||
wandb
|
|
||||||
einops
|
|
||||||
colorama
|
|
||||||
numba>=0.61.2
|
|
||||||
numpy>=2.2.6
|
|
||||||
|
|
||||||
# qlora things
|
|
||||||
evaluate==0.4.1
|
|
||||||
scipy
|
|
||||||
nvidia-ml-py==12.560.30
|
|
||||||
art
|
|
||||||
tensorboard
|
|
||||||
python-dotenv==1.0.1
|
|
||||||
|
|
||||||
# remote filesystems
|
|
||||||
s3fs>=2024.5.0
|
|
||||||
gcsfs>=2025.3.0
|
|
||||||
adlfs>=2024.5.0
|
|
||||||
ocifs==1.3.2
|
|
||||||
|
|
||||||
zstandard==0.22.0
|
|
||||||
fastcore
|
|
||||||
|
|
||||||
# lm eval harness
|
|
||||||
lm_eval==0.4.11
|
|
||||||
langdetect==1.0.9
|
|
||||||
immutabledict==4.2.0
|
|
||||||
antlr4-python3-runtime==4.13.2
|
|
||||||
|
|
||||||
torchao==0.17.0
|
|
||||||
openenv-core==0.1.0
|
|
||||||
schedulefree==1.4.1
|
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.7
|
|
||||||
axolotl-contribs-mit==0.0.6
|
|
||||||
# telemetry
|
|
||||||
posthog==6.7.11
|
|
||||||
|
|
||||||
mistral-common==1.11.0
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
# noqa
|
|
||||||
import sys
|
|
||||||
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
except ImportError as error:
|
|
||||||
raise ImportError("Install torch via `pip install torch`") from error
|
|
||||||
from packaging.version import Version as V
|
|
||||||
|
|
||||||
use_uv = "--uv" in sys.argv[1:]
|
|
||||||
|
|
||||||
v = V(torch.__version__)
|
|
||||||
cuda = str(torch.version.cuda)
|
|
||||||
try:
|
|
||||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
|
||||||
except RuntimeError:
|
|
||||||
is_ampere = False
|
|
||||||
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
|
||||||
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
|
||||||
if v <= V("2.1.0"):
|
|
||||||
raise RuntimeError(f"Torch = {v} too old!")
|
|
||||||
elif v <= V("2.1.1"):
|
|
||||||
x = "cu{}{}-torch211"
|
|
||||||
elif v <= V("2.1.2"):
|
|
||||||
x = "cu{}{}-torch212"
|
|
||||||
elif v < V("2.3.0"):
|
|
||||||
x = "cu{}{}-torch220"
|
|
||||||
elif v < V("2.4.0"):
|
|
||||||
x = "cu{}{}-torch230"
|
|
||||||
elif v < V("2.5.0"):
|
|
||||||
x = "cu{}{}-torch240"
|
|
||||||
elif v < V("2.6.0"):
|
|
||||||
x = "cu{}{}-torch250"
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Torch = {v} too new!")
|
|
||||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
|
||||||
uv_prefix = "uv " if use_uv else ""
|
|
||||||
print(
|
|
||||||
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
|
|
||||||
)
|
|
||||||
230
setup.py
230
setup.py
@@ -1,230 +0,0 @@
|
|||||||
"""setup.py for axolotl"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import re
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
|
||||||
|
|
||||||
|
|
||||||
def parse_requirements(extras_require_map):
|
|
||||||
_install_requires = []
|
|
||||||
_dependency_links = []
|
|
||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
|
||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
|
||||||
for line in lines:
|
|
||||||
is_extras = "deepspeed" in line or "mamba-ssm" in line
|
|
||||||
if line.startswith("--extra-index-url"):
|
|
||||||
# Handle custom index URLs
|
|
||||||
_, url = line.split()
|
|
||||||
_dependency_links.append(url)
|
|
||||||
elif not is_extras and line and line[0] != "#":
|
|
||||||
# Handle standard packages
|
|
||||||
_install_requires.append(line)
|
|
||||||
try:
|
|
||||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
|
||||||
install_xformers = platform.machine() != "aarch64"
|
|
||||||
if platform.machine() == "aarch64":
|
|
||||||
# skip on ARM64
|
|
||||||
skip_packages = [
|
|
||||||
"torchao",
|
|
||||||
"fla-core",
|
|
||||||
"flash-linear-attention",
|
|
||||||
]
|
|
||||||
_install_requires = [
|
|
||||||
req
|
|
||||||
for req in _install_requires
|
|
||||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
|
||||||
]
|
|
||||||
if "Darwin" in platform.system():
|
|
||||||
# skip packages not compatible with OSX
|
|
||||||
skip_packages = [
|
|
||||||
"bitsandbytes",
|
|
||||||
"triton",
|
|
||||||
"mamba-ssm",
|
|
||||||
"xformers",
|
|
||||||
"liger-kernel",
|
|
||||||
]
|
|
||||||
_install_requires = [
|
|
||||||
req
|
|
||||||
for req in _install_requires
|
|
||||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
|
||||||
]
|
|
||||||
print(
|
|
||||||
_install_requires, [req in skip_packages for req in _install_requires]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# detect the version of torch already installed
|
|
||||||
# and set it so dependencies don't clobber the torch version
|
|
||||||
try:
|
|
||||||
torch_version = version("torch")
|
|
||||||
except PackageNotFoundError:
|
|
||||||
torch_version = "2.8.0" # default to torch 2.8.0
|
|
||||||
_install_requires.append(f"torch=={torch_version}")
|
|
||||||
|
|
||||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
|
||||||
if version_match:
|
|
||||||
major, minor, patch = version_match.groups()
|
|
||||||
major, minor = int(major), int(minor)
|
|
||||||
patch = (
|
|
||||||
int(patch) if patch is not None else 0
|
|
||||||
) # Default patch to 0 if not present
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid version format")
|
|
||||||
|
|
||||||
torch_parts = torch_version.split("+")
|
|
||||||
if len(torch_parts) == 2:
|
|
||||||
torch_cuda_version = torch_parts[1]
|
|
||||||
_dependency_links.append(
|
|
||||||
f"https://download.pytorch.org/whl/{torch_cuda_version}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (major, minor) >= (2, 10):
|
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
|
||||||
extras_require_map["fbgemm-gpu"] = [
|
|
||||||
"fbgemm-gpu==1.5.0",
|
|
||||||
"fbgemm-gpu-genai==1.5.0",
|
|
||||||
]
|
|
||||||
if not install_xformers:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
extras_require_map["vllm"] = ["vllm>=0.19.0"]
|
|
||||||
elif (major, minor) >= (2, 9):
|
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
|
||||||
extras_require_map["fbgemm-gpu"] = [
|
|
||||||
"fbgemm-gpu==1.4.0",
|
|
||||||
"fbgemm-gpu-genai==1.4.2",
|
|
||||||
]
|
|
||||||
if not install_xformers:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
if patch == 0:
|
|
||||||
extras_require_map["vllm"] = ["vllm==0.13.0"]
|
|
||||||
else:
|
|
||||||
extras_require_map["vllm"] = ["vllm==0.14.0"]
|
|
||||||
elif (major, minor) >= (2, 8):
|
|
||||||
extras_require_map.pop("fbgemm-gpu")
|
|
||||||
extras_require_map["fbgemm-gpu"] = ["fbgemm-gpu-genai==1.3.0"]
|
|
||||||
extras_require_map["vllm"] = ["vllm==0.11.0"]
|
|
||||||
if not install_xformers:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
elif (major, minor) >= (2, 7):
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
if patch == 0:
|
|
||||||
if install_xformers:
|
|
||||||
_install_requires.append("xformers==0.0.30")
|
|
||||||
# vllm 0.9.x is incompatible with latest transformers
|
|
||||||
extras_require_map.pop("vllm")
|
|
||||||
else:
|
|
||||||
if install_xformers:
|
|
||||||
_install_requires.append("xformers==0.0.31")
|
|
||||||
extras_require_map["vllm"] = ["vllm==0.10.1"]
|
|
||||||
elif (major, minor) >= (2, 6):
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
if install_xformers:
|
|
||||||
_install_requires.append("xformers==0.0.29.post3")
|
|
||||||
# since we only support 2.6.0+cu126
|
|
||||||
_dependency_links.append("https://download.pytorch.org/whl/cu126")
|
|
||||||
extras_require_map.pop("vllm")
|
|
||||||
elif (major, minor) >= (2, 5):
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
if install_xformers:
|
|
||||||
if patch == 0:
|
|
||||||
_install_requires.append("xformers==0.0.28.post2")
|
|
||||||
else:
|
|
||||||
_install_requires.append("xformers>=0.0.28.post3")
|
|
||||||
extras_require_map.pop("vllm")
|
|
||||||
elif (major, minor) >= (2, 4):
|
|
||||||
extras_require_map.pop("vllm")
|
|
||||||
if install_xformers:
|
|
||||||
if patch == 0:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers>=0.0.27")
|
|
||||||
else:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers==0.0.28.post1")
|
|
||||||
else:
|
|
||||||
raise ValueError("axolotl requires torch>=2.4")
|
|
||||||
|
|
||||||
except PackageNotFoundError:
|
|
||||||
pass
|
|
||||||
return _install_requires, _dependency_links, extras_require_map
|
|
||||||
|
|
||||||
|
|
||||||
def get_package_version():
|
|
||||||
with open(
|
|
||||||
Path(os.path.dirname(os.path.abspath(__file__))) / "VERSION",
|
|
||||||
"r",
|
|
||||||
encoding="utf-8",
|
|
||||||
) as fin:
|
|
||||||
version_ = fin.read().strip()
|
|
||||||
return version_
|
|
||||||
|
|
||||||
|
|
||||||
extras_require = {
|
|
||||||
"flash-attn": ["flash-attn==2.8.3"],
|
|
||||||
"ring-flash-attn": [
|
|
||||||
"flash-attn==2.8.3",
|
|
||||||
"ring-flash-attn>=0.1.7",
|
|
||||||
],
|
|
||||||
"deepspeed": [
|
|
||||||
"deepspeed==0.18.2",
|
|
||||||
"deepspeed-kernels",
|
|
||||||
],
|
|
||||||
"mamba-ssm": [
|
|
||||||
"mamba-ssm==1.2.0.post1",
|
|
||||||
"causal_conv1d",
|
|
||||||
],
|
|
||||||
"auto-gptq": [
|
|
||||||
"auto-gptq==0.5.1",
|
|
||||||
],
|
|
||||||
"mlflow": [
|
|
||||||
"mlflow",
|
|
||||||
],
|
|
||||||
"galore": [
|
|
||||||
"galore_torch",
|
|
||||||
],
|
|
||||||
"apollo": [
|
|
||||||
"apollo-torch",
|
|
||||||
],
|
|
||||||
"optimizers": [
|
|
||||||
"galore_torch",
|
|
||||||
"apollo-torch",
|
|
||||||
"lomo-optim==0.1.1",
|
|
||||||
"torch-optimi==0.2.1",
|
|
||||||
"came_pytorch==0.1.3",
|
|
||||||
],
|
|
||||||
"ray": [
|
|
||||||
"ray[train]>=2.52.1",
|
|
||||||
],
|
|
||||||
"vllm": [
|
|
||||||
"vllm==0.10.0",
|
|
||||||
],
|
|
||||||
"llmcompressor": [
|
|
||||||
"llmcompressor==0.5.1",
|
|
||||||
],
|
|
||||||
"fbgemm-gpu": ["fbgemm-gpu-genai==1.3.0"],
|
|
||||||
"opentelemetry": [
|
|
||||||
"opentelemetry-api",
|
|
||||||
"opentelemetry-sdk",
|
|
||||||
"opentelemetry-exporter-prometheus",
|
|
||||||
"prometheus-client",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
|
||||||
extras_require
|
|
||||||
)
|
|
||||||
|
|
||||||
setup(
|
|
||||||
version=get_package_version(),
|
|
||||||
package_dir={"": "src"},
|
|
||||||
packages=find_packages("src"),
|
|
||||||
install_requires=install_requires,
|
|
||||||
dependency_links=dependency_links,
|
|
||||||
entry_points={
|
|
||||||
"console_scripts": [
|
|
||||||
"axolotl=axolotl.cli.main:main",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
extras_require=extras_require_build,
|
|
||||||
)
|
|
||||||
@@ -339,7 +339,11 @@ def _build_peft_layer_and_get_delta(
|
|||||||
)
|
)
|
||||||
layer.lora_A[adapter_name].weight.data = lora_a
|
layer.lora_A[adapter_name].weight.data = lora_a
|
||||||
layer.lora_B[adapter_name].weight.data = lora_b
|
layer.lora_B[adapter_name].weight.data = lora_b
|
||||||
return layer.get_delta_weight(adapter_name)
|
delta = layer.get_delta_weight(adapter_name)
|
||||||
|
# peft >=0.19.1 may return delta with transposed dims for 3D params
|
||||||
|
if delta.shape != base_tensor.shape and delta.ndim == 3:
|
||||||
|
delta = delta.transpose(1, 2).contiguous()
|
||||||
|
return delta
|
||||||
elif (
|
elif (
|
||||||
layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
|
layer_type and "Conv" in layer_type or (layer_type is None and lora_a.ndim > 2)
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -370,7 +370,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
"padding": True, # True/"longest" is the default
|
"padding": True, # True/"longest" is the default
|
||||||
}
|
}
|
||||||
multiple = 64
|
multiple = getattr(self.cfg, "pad_to_multiple_of", None) or 64
|
||||||
if self.cfg.pad_to_sequence_len:
|
if self.cfg.pad_to_sequence_len:
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||||
self.cfg.sequence_len / multiple
|
self.cfg.sequence_len / multiple
|
||||||
|
|||||||
@@ -228,9 +228,47 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
return training_args, trainer_kwargs
|
return training_args, trainer_kwargs
|
||||||
|
|
||||||
|
def build_collator(self, **kwargs):
|
||||||
|
"""Build a data collator for preference-tuning trainers.
|
||||||
|
|
||||||
|
Returns None for RL types that provide their own collator (e.g. GRPO,
|
||||||
|
KTO), letting the trainer construct its default. For DPO/IPO/ORPO/SIMPO
|
||||||
|
returns an ``AxolotlDPODataCollatorWithPadding`` when
|
||||||
|
``pad_to_multiple_of`` is set, otherwise None (so the trainer
|
||||||
|
falls back to the TRL default).
|
||||||
|
"""
|
||||||
|
if self.cfg.rl not in (
|
||||||
|
RLType.DPO,
|
||||||
|
RLType.IPO,
|
||||||
|
RLType.ORPO,
|
||||||
|
RLType.SIMPO,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
pad_to_multiple_of = getattr(self.cfg, "pad_to_multiple_of", None)
|
||||||
|
if not pad_to_multiple_of:
|
||||||
|
return None
|
||||||
|
|
||||||
|
from axolotl.utils.collators.dpo import AxolotlDPODataCollatorWithPadding
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"Using AxolotlDPODataCollatorWithPadding with pad_to_multiple_of="
|
||||||
|
f"{pad_to_multiple_of}"
|
||||||
|
)
|
||||||
|
is_enc_dec = getattr(self.model.config, "is_encoder_decoder", False)
|
||||||
|
return AxolotlDPODataCollatorWithPadding(
|
||||||
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
|
is_encoder_decoder=is_enc_dec,
|
||||||
|
pad_to_multiple_of=pad_to_multiple_of,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
|
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
|
||||||
|
|
||||||
|
if (data_collator := self.build_collator()) is not None:
|
||||||
|
trainer_kwargs["data_collator"] = data_collator
|
||||||
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -20,8 +20,16 @@ class DPOStrategy:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def set_training_args_kwargs(cls, cfg):
|
def set_training_args_kwargs(cls, cfg):
|
||||||
training_args_kwargs = {}
|
training_args_kwargs = {}
|
||||||
|
if cfg.rl is RLType.DPO:
|
||||||
|
if cfg.dpo_loss_type is not None:
|
||||||
|
training_args_kwargs["loss_type"] = cfg.dpo_loss_type
|
||||||
|
|
||||||
|
if cfg.dpo_loss_weights is not None:
|
||||||
|
training_args_kwargs["loss_weights"] = cfg.dpo_loss_weights
|
||||||
|
|
||||||
if cfg.rl is RLType.IPO:
|
if cfg.rl is RLType.IPO:
|
||||||
training_args_kwargs["loss_type"] = ["ipo"]
|
training_args_kwargs["loss_type"] = ["ipo"]
|
||||||
|
|
||||||
# Label smoothing is not compatible with IPO
|
# Label smoothing is not compatible with IPO
|
||||||
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
if cfg.rl is RLType.DPO and cfg.dpo_label_smoothing:
|
||||||
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
training_args_kwargs["label_smoothing"] = cfg.dpo_label_smoothing
|
||||||
|
|||||||
27
src/axolotl/integrations/hatchery/__init__.py
Normal file
27
src/axolotl/integrations/hatchery/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""Hatchery/Tinker remote training integration for Axolotl.
|
||||||
|
|
||||||
|
Routes axolotl's preprocessed data to a remote training API (Tinker or
|
||||||
|
Hatchery) instead of running forward/backward locally. The remote
|
||||||
|
service handles model weights, LoRA adapters, and gradient updates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .args import HatcheryArgs, HatcheryConfig
|
||||||
|
from .plugin import HatcheryPlugin
|
||||||
|
|
||||||
|
__all__ = ["HatcheryArgs", "HatcheryConfig", "HatcheryPlugin"]
|
||||||
|
|
||||||
|
# Usage:
|
||||||
|
# plugins:
|
||||||
|
# - axolotl.integrations.hatchery.HatcheryPlugin
|
||||||
|
#
|
||||||
|
# hatchery:
|
||||||
|
# backend: tinker # or "hatchery"
|
||||||
|
# lora_rank: 32
|
||||||
|
# loss_fn: cross_entropy # SFT
|
||||||
|
# # loss_fn: ppo # RL (auto-selects HatcheryRLTrainer)
|
||||||
|
#
|
||||||
|
# learning_rate: 1e-4 # top-level, not under hatchery:
|
||||||
62
src/axolotl/integrations/hatchery/args.py
Normal file
62
src/axolotl/integrations/hatchery/args.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""Pydantic config schema for the Hatchery integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class HatcheryConfig(BaseModel):
|
||||||
|
"""Nested config under `hatchery:` in the axolotl YAML.
|
||||||
|
|
||||||
|
Only contains hatchery-specific settings. Standard training params
|
||||||
|
(learning_rate, weight_decay, adam_beta1/2, max_grad_norm,
|
||||||
|
gradient_accumulation_steps) are read from axolotl's top-level config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Backend & connection
|
||||||
|
backend: Literal["tinker", "hatchery"] = "tinker"
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
|
||||||
|
# LoRA config sent to remote
|
||||||
|
lora_rank: int = Field(32, ge=1, le=256)
|
||||||
|
train_attn: bool = True
|
||||||
|
train_mlp: bool = True
|
||||||
|
train_unembed: bool = True
|
||||||
|
|
||||||
|
# Loss function
|
||||||
|
loss_fn: Literal["cross_entropy", "importance_sampling", "ppo", "cispo", "dro"] = (
|
||||||
|
"cross_entropy"
|
||||||
|
)
|
||||||
|
loss_fn_config: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
|
# Pipelining: submit next batch before awaiting previous result
|
||||||
|
pipeline: bool = True
|
||||||
|
|
||||||
|
# Sampling params (for RL flows)
|
||||||
|
max_sample_tokens: int = 256
|
||||||
|
sample_temperature: float = 1.0
|
||||||
|
num_samples: int = 4
|
||||||
|
|
||||||
|
# Reward functions (for RL) — list of fully qualified names
|
||||||
|
reward_funcs: Optional[list[str]] = None
|
||||||
|
|
||||||
|
# Checkpointing
|
||||||
|
save_steps: Optional[int] = None
|
||||||
|
save_name_prefix: str = "checkpoint"
|
||||||
|
|
||||||
|
# Timeout per future (seconds)
|
||||||
|
future_timeout: float = 600.0
|
||||||
|
|
||||||
|
|
||||||
|
class HatcheryArgs(BaseModel):
|
||||||
|
"""Top-level mixin that adds the nested `hatchery:` field."""
|
||||||
|
|
||||||
|
hatchery: Optional[HatcheryConfig] = None
|
||||||
160
src/axolotl/integrations/hatchery/data.py
Normal file
160
src/axolotl/integrations/hatchery/data.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""Convert axolotl batch tensors to Tinker/Hatchery Datum format.
|
||||||
|
|
||||||
|
Both Tinker and Hatchery expect the client to apply the causal LM shift:
|
||||||
|
|
||||||
|
Original tokens: [t0, t1, t2, ..., t_{L-1}]
|
||||||
|
model_input: [t0, t1, ..., t_{L-2}] (last token dropped)
|
||||||
|
target_tokens: [t1, t2, ..., t_{L-1}] (first token dropped)
|
||||||
|
weights: [w1, w2, ..., w_{L-1}] (aligned to targets)
|
||||||
|
|
||||||
|
At position i, the model sees t_i and predicts target_tokens[i] = t_{i+1}.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _tensor_to_wire(t: torch.Tensor) -> dict[str, Any]:
|
||||||
|
"""Serialize a tensor to the TensorData wire dict."""
|
||||||
|
flat = t.detach().cpu().flatten()
|
||||||
|
dtype_map = {
|
||||||
|
torch.float32: "float32",
|
||||||
|
torch.float16: "float16",
|
||||||
|
torch.bfloat16: "bfloat16",
|
||||||
|
torch.int64: "int64",
|
||||||
|
torch.int32: "int32",
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"dtype": dtype_map.get(flat.dtype, "float32"),
|
||||||
|
"shape": list(t.shape),
|
||||||
|
"data": flat.tolist(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_datum(
|
||||||
|
tokens: list[int],
|
||||||
|
loss_fn_inputs: dict[str, torch.Tensor],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build a Datum as a plain dict (wire-compatible with both Tinker and Hatchery)."""
|
||||||
|
return {
|
||||||
|
"model_input": {
|
||||||
|
"chunks": [{"type": "encoded_text", "tokens": tokens}],
|
||||||
|
},
|
||||||
|
"loss_fn_inputs": {
|
||||||
|
key: _tensor_to_wire(tensor) for key, tensor in loss_fn_inputs.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def datums_to_tinker(datums: list[dict[str, Any]]):
|
||||||
|
"""Wrap plain-dict datums into tinker.types.Datum objects.
|
||||||
|
|
||||||
|
Both the Tinker SDK and updated Hatchery client accept these.
|
||||||
|
"""
|
||||||
|
import tinker.types as tt
|
||||||
|
|
||||||
|
result = []
|
||||||
|
for d in datums:
|
||||||
|
tokens = d["model_input"]["chunks"][0]["tokens"]
|
||||||
|
tinker_inputs = {}
|
||||||
|
for key, wire in d["loss_fn_inputs"].items():
|
||||||
|
tinker_inputs[key] = tt.TensorData(
|
||||||
|
data=wire["data"],
|
||||||
|
dtype=wire["dtype"],
|
||||||
|
shape=wire["shape"],
|
||||||
|
)
|
||||||
|
result.append(
|
||||||
|
tt.Datum(
|
||||||
|
model_input=tt.ModelInput.from_ints(tokens),
|
||||||
|
loss_fn_inputs=tinker_inputs,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def batch_to_datums_sft(
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Convert an axolotl SFT batch to Datum dicts with causal shift."""
|
||||||
|
batch_size = input_ids.size(0)
|
||||||
|
datums = []
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
ids = input_ids[i]
|
||||||
|
lbl = labels[i]
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
seq_len = int(attention_mask[i].sum().item())
|
||||||
|
ids = ids[:seq_len]
|
||||||
|
lbl = lbl[:seq_len]
|
||||||
|
|
||||||
|
model_tokens = ids[:-1].tolist()
|
||||||
|
shifted_labels = lbl[1:]
|
||||||
|
|
||||||
|
target_tokens = shifted_labels.clone()
|
||||||
|
weights = (shifted_labels != -100).float()
|
||||||
|
target_tokens[target_tokens == -100] = 0
|
||||||
|
|
||||||
|
datums.append(
|
||||||
|
_make_datum(
|
||||||
|
model_tokens,
|
||||||
|
{
|
||||||
|
"target_tokens": target_tokens,
|
||||||
|
"weights": weights,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return datums
|
||||||
|
|
||||||
|
|
||||||
|
def batch_to_datums_rl(
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
labels: torch.Tensor,
|
||||||
|
logprobs: torch.Tensor,
|
||||||
|
advantages: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Convert an RL batch to importance_sampling/ppo Datum dicts with causal shift."""
|
||||||
|
batch_size = input_ids.size(0)
|
||||||
|
datums = []
|
||||||
|
|
||||||
|
for i in range(batch_size):
|
||||||
|
ids = input_ids[i]
|
||||||
|
lbl = labels[i]
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
seq_len = int(attention_mask[i].sum().item())
|
||||||
|
else:
|
||||||
|
seq_len = ids.size(0)
|
||||||
|
ids = ids[:seq_len]
|
||||||
|
lbl = lbl[:seq_len]
|
||||||
|
lp = logprobs[i, :seq_len]
|
||||||
|
adv = advantages[i, :seq_len]
|
||||||
|
|
||||||
|
model_tokens = ids[:-1].tolist()
|
||||||
|
|
||||||
|
target_tokens = lbl[1:].clone()
|
||||||
|
target_tokens[target_tokens == -100] = 0
|
||||||
|
|
||||||
|
datums.append(
|
||||||
|
_make_datum(
|
||||||
|
model_tokens,
|
||||||
|
{
|
||||||
|
"target_tokens": target_tokens,
|
||||||
|
"logprobs": lp[1:],
|
||||||
|
"advantages": adv[1:],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return datums
|
||||||
87
src/axolotl/integrations/hatchery/examples/prep_math_rl.py
Normal file
87
src/axolotl/integrations/hatchery/examples/prep_math_rl.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""Prepare hendrycks_math for RL training with Hatchery/Tinker.
|
||||||
|
|
||||||
|
Creates a dataset with chat-formatted prompts that include
|
||||||
|
a hidden gold answer tag for the reward function.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
from datasets import Dataset, load_dataset
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def extract_boxed(text: str) -> str:
|
||||||
|
match = re.search(r"\\boxed\{", text)
|
||||||
|
if not match:
|
||||||
|
return ""
|
||||||
|
start = match.end()
|
||||||
|
depth = 1
|
||||||
|
i = start
|
||||||
|
while i < len(text) and depth > 0:
|
||||||
|
if text[i] == "{":
|
||||||
|
depth += 1
|
||||||
|
elif text[i] == "}":
|
||||||
|
depth -= 1
|
||||||
|
i += 1
|
||||||
|
return text[start : i - 1] if depth == 0 else ""
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", trust_remote_code=True)
|
||||||
|
|
||||||
|
ds = load_dataset("EleutherAI/hendrycks_math", "algebra", split="test")
|
||||||
|
level = os.environ.get("MATH_LEVEL", "Level 1")
|
||||||
|
filtered_rows = [x for x in ds if x["level"] == level]
|
||||||
|
print(f"{level} algebra: {len(filtered_rows)} problems")
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for prob in filtered_rows:
|
||||||
|
gold = extract_boxed(prob["solution"])
|
||||||
|
if not gold:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Format as chat prompt with hidden gold tag
|
||||||
|
prompt = (
|
||||||
|
f"Solve the following math problem. "
|
||||||
|
f"Show your work and put your final answer in \\boxed{{}}.\n\n"
|
||||||
|
f"{prob['problem']}"
|
||||||
|
f"<|gold|>{gold}<|/gold|>"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tokenize the prompt
|
||||||
|
text = tokenizer.apply_chat_template(
|
||||||
|
[{"role": "user", "content": prompt}],
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
)
|
||||||
|
prompt_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"input_ids": prompt_ids,
|
||||||
|
"labels": [-100] * len(prompt_ids),
|
||||||
|
"attention_mask": [1] * len(prompt_ids),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
out = Dataset.from_list(rows)
|
||||||
|
out_dir = f"./data/math_rl_{level.lower().replace(' ', '')}"
|
||||||
|
out.save_to_disk(out_dir)
|
||||||
|
print(f"Saved {len(out)} examples to {out_dir}")
|
||||||
|
if rows:
|
||||||
|
print(
|
||||||
|
f"Prompt length range: {min(len(r['input_ids']) for r in rows)}"
|
||||||
|
f"-{max(len(r['input_ids']) for r in rows)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
47
src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
Normal file
47
src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# RL (GRPO): hendrycks_math Level 1 via Tinker with Qwen3-8B
|
||||||
|
#
|
||||||
|
# Prep:
|
||||||
|
# python src/axolotl/integrations/hatchery/examples/prep_math_rl.py
|
||||||
|
#
|
||||||
|
# Run:
|
||||||
|
# export TINKER_API_KEY="your-key"
|
||||||
|
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_rl.yaml
|
||||||
|
|
||||||
|
base_model: Qwen/Qwen3-8B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.hatchery.HatcheryPlugin
|
||||||
|
|
||||||
|
hatchery:
|
||||||
|
backend: tinker
|
||||||
|
lora_rank: 16
|
||||||
|
loss_fn: importance_sampling
|
||||||
|
max_sample_tokens: 2048
|
||||||
|
sample_temperature: 0.7
|
||||||
|
num_samples: 4
|
||||||
|
pipeline: true
|
||||||
|
save_steps: 5
|
||||||
|
reward_funcs:
|
||||||
|
- axolotl.integrations.hatchery.rewards.math_reward.math_reward
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: ./data/math_rl_level1
|
||||||
|
ds_type: arrow
|
||||||
|
type: completion
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
|
||||||
|
learning_rate: 5.0e-5
|
||||||
|
optimizer: adamw_torch
|
||||||
|
adam_beta1: 0.9
|
||||||
|
adam_beta2: 0.95
|
||||||
|
weight_decay: 0.01
|
||||||
|
max_grad_norm: 1.0
|
||||||
|
|
||||||
|
max_steps: 10
|
||||||
|
num_epochs: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
logging_steps: 1
|
||||||
|
|
||||||
|
output_dir: ./outputs/tinker-rl-math
|
||||||
42
src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
Normal file
42
src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# SFT: KIMI-K2 thinking data via Tinker remote API with Qwen3-8B
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# export TINKER_API_KEY="your-key"
|
||||||
|
# axolotl train src/axolotl/integrations/hatchery/examples/tinker_sft.yaml
|
||||||
|
|
||||||
|
base_model: Qwen/Qwen3-8B
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.hatchery.HatcheryPlugin
|
||||||
|
|
||||||
|
hatchery:
|
||||||
|
backend: tinker
|
||||||
|
lora_rank: 16
|
||||||
|
loss_fn: cross_entropy
|
||||||
|
pipeline: true
|
||||||
|
save_steps: 10
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: TeichAI/kimi-k2-thinking-1000x
|
||||||
|
split: train[:50]
|
||||||
|
type: chat_template
|
||||||
|
chat_template: qwen3
|
||||||
|
split_thinking: true
|
||||||
|
|
||||||
|
chat_template: qwen3
|
||||||
|
sequence_len: 2048
|
||||||
|
|
||||||
|
learning_rate: 3.0e-4
|
||||||
|
optimizer: adamw_torch
|
||||||
|
adam_beta1: 0.9
|
||||||
|
adam_beta2: 0.95
|
||||||
|
weight_decay: 0.01
|
||||||
|
max_grad_norm: 1.0
|
||||||
|
|
||||||
|
num_epochs: 1
|
||||||
|
max_steps: 20
|
||||||
|
micro_batch_size: 2
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
logging_steps: 1
|
||||||
|
|
||||||
|
output_dir: ./outputs/tinker-sft
|
||||||
147
src/axolotl/integrations/hatchery/plugin.py
Normal file
147
src/axolotl/integrations/hatchery/plugin.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""Axolotl plugin that routes training to a remote Hatchery/Tinker API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from peft import PeftModel
|
||||||
|
from transformers import AutoConfig, PreTrainedModel, Trainer
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class HatcheryPlugin(BasePlugin):
|
||||||
|
"""Plugin that replaces local training with remote API calls.
|
||||||
|
|
||||||
|
Activated by adding to the axolotl YAML:
|
||||||
|
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.hatchery.HatcheryPlugin
|
||||||
|
|
||||||
|
hatchery:
|
||||||
|
backend: tinker # or "hatchery"
|
||||||
|
lora_rank: 32
|
||||||
|
loss_fn: cross_entropy
|
||||||
|
# ... see HatcheryConfig for full options
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_input_args(self) -> str:
|
||||||
|
return "axolotl.integrations.hatchery.args.HatcheryArgs"
|
||||||
|
|
||||||
|
def register(self, cfg: dict):
|
||||||
|
"""Auto-set config values needed for remote training."""
|
||||||
|
if cfg.get("remove_unused_columns") is None:
|
||||||
|
cfg["remove_unused_columns"] = False
|
||||||
|
|
||||||
|
def pre_model_load(self, cfg: DictDefault):
|
||||||
|
"""Replace model loading with a tiny stub."""
|
||||||
|
hcfg = cfg.hatchery or {}
|
||||||
|
backend = (
|
||||||
|
hcfg.get("backend", "tinker")
|
||||||
|
if isinstance(hcfg, dict)
|
||||||
|
else getattr(hcfg, "backend", "tinker")
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
f"Hatchery plugin active: training dispatched to remote "
|
||||||
|
f"{backend} API. Skipping local model weight loading."
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.loaders import ModelLoader
|
||||||
|
|
||||||
|
def _stub_build_model(loader_self) -> bool:
|
||||||
|
base_model = loader_self.cfg.base_model
|
||||||
|
LOG.info(f"Skipping model weight loading for: {base_model}")
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
trust_remote_code=loader_self.cfg.get("trust_remote_code", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
class _Stub(PreTrainedModel):
|
||||||
|
config_class = type(config)
|
||||||
|
_no_split_modules: list[str] = []
|
||||||
|
supports_gradient_checkpointing = False
|
||||||
|
|
||||||
|
def __init__(self, cfg):
|
||||||
|
super().__init__(cfg)
|
||||||
|
vocab_size = getattr(cfg, "vocab_size", 32000)
|
||||||
|
self.embed_tokens = torch.nn.Embedding(vocab_size, 1)
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.embed_tokens
|
||||||
|
|
||||||
|
def set_input_embeddings(self, value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
loader_self.model = _Stub(config)
|
||||||
|
return True
|
||||||
|
|
||||||
|
ModelLoader._build_model = _stub_build_model # type: ignore[method-assign,assignment]
|
||||||
|
|
||||||
|
def get_trainer_cls(self, cfg: DictDefault) -> type[Trainer] | None:
|
||||||
|
"""Return the appropriate remote trainer class."""
|
||||||
|
hcfg = cfg.hatchery
|
||||||
|
loss_fn = getattr(hcfg, "loss_fn", "cross_entropy") if hcfg else "cross_entropy"
|
||||||
|
|
||||||
|
if loss_fn in ("importance_sampling", "ppo", "cispo", "dro"):
|
||||||
|
from .rl_trainer import HatcheryRLTrainer
|
||||||
|
|
||||||
|
return HatcheryRLTrainer
|
||||||
|
|
||||||
|
from .trainer import HatcheryTrainer
|
||||||
|
|
||||||
|
return HatcheryTrainer
|
||||||
|
|
||||||
|
def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||||
|
model._hatchery_remote = True
|
||||||
|
|
||||||
|
def post_train(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
|
||||||
|
LOG.info(
|
||||||
|
"Hatchery: skipping local model save (weights are on remote API). "
|
||||||
|
"Use `tinker checkpoint download` or hatchery CLI to retrieve."
|
||||||
|
)
|
||||||
|
|
||||||
|
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
|
||||||
|
"""Inject hatchery config + axolotl training params into the trainer."""
|
||||||
|
from .args import HatcheryConfig
|
||||||
|
from .rl_trainer import HatcheryRLTrainer
|
||||||
|
from .trainer import HatcheryTrainer
|
||||||
|
|
||||||
|
if not isinstance(trainer, (HatcheryTrainer, HatcheryRLTrainer)):
|
||||||
|
return
|
||||||
|
|
||||||
|
hcfg = cfg.hatchery
|
||||||
|
if isinstance(hcfg, dict):
|
||||||
|
hatchery_config = HatcheryConfig(**hcfg)
|
||||||
|
elif hcfg is None:
|
||||||
|
hatchery_config = HatcheryConfig()
|
||||||
|
else:
|
||||||
|
hatchery_config = hcfg
|
||||||
|
|
||||||
|
trainer.hatchery_args = hatchery_config
|
||||||
|
trainer._base_model_name = cfg.base_model
|
||||||
|
|
||||||
|
# Pull standard training params from axolotl config so they
|
||||||
|
# don't need to be duplicated under hatchery:
|
||||||
|
trainer._optim_params = {
|
||||||
|
"learning_rate": cfg.learning_rate
|
||||||
|
if cfg.learning_rate is not None
|
||||||
|
else 1e-4,
|
||||||
|
"beta1": cfg.adam_beta1 if cfg.adam_beta1 is not None else 0.9,
|
||||||
|
"beta2": cfg.adam_beta2 if cfg.adam_beta2 is not None else 0.95,
|
||||||
|
"eps": cfg.adam_epsilon if cfg.adam_epsilon is not None else 1e-12,
|
||||||
|
"weight_decay": cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||||
|
"grad_clip_norm": cfg.max_grad_norm
|
||||||
|
if cfg.max_grad_norm is not None
|
||||||
|
else 0.0,
|
||||||
|
}
|
||||||
3
src/axolotl/integrations/hatchery/rewards/__init__.py
Normal file
3
src/axolotl/integrations/hatchery/rewards/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
78
src/axolotl/integrations/hatchery/rewards/math_reward.py
Normal file
78
src/axolotl/integrations/hatchery/rewards/math_reward.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""Math reward function for hendrycks_math GRPO training.
|
||||||
|
|
||||||
|
Uses math_verify for robust answer comparison. Falls back to
|
||||||
|
exact string match of \\boxed{} content only when math_verify
|
||||||
|
is unavailable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_boxed(text: str) -> str | None:
|
||||||
|
"""Extract \\boxed{...} answer handling nested braces."""
|
||||||
|
match = re.search(r"\\boxed\{", text)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
start = match.end()
|
||||||
|
depth = 1
|
||||||
|
i = start
|
||||||
|
while i < len(text) and depth > 0:
|
||||||
|
if text[i] == "{":
|
||||||
|
depth += 1
|
||||||
|
elif text[i] == "}":
|
||||||
|
depth -= 1
|
||||||
|
i += 1
|
||||||
|
return text[start : i - 1] if depth == 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
def math_reward(prompts: list[str], completions: list[str], **kwargs) -> list[float]:
|
||||||
|
"""Score completions by checking if \\boxed{} answer matches the gold answer.
|
||||||
|
|
||||||
|
The gold answer is extracted from the prompt (appended as a hidden
|
||||||
|
tag by the dataset preprocessing). Format:
|
||||||
|
... <|gold|>ANSWER<|/gold|>
|
||||||
|
"""
|
||||||
|
rewards = []
|
||||||
|
for prompt, completion in zip(prompts, completions, strict=True):
|
||||||
|
gold_match = re.search(r"<\|gold\|>(.*?)<\|/gold\|>", prompt)
|
||||||
|
if not gold_match:
|
||||||
|
rewards.append(0.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
gold_answer = gold_match.group(1).strip()
|
||||||
|
pred_answer = extract_boxed(completion)
|
||||||
|
|
||||||
|
if pred_answer is None:
|
||||||
|
rewards.append(0.0)
|
||||||
|
continue
|
||||||
|
|
||||||
|
verified = None
|
||||||
|
try:
|
||||||
|
from math_verify import parse, verify
|
||||||
|
|
||||||
|
gold_parsed = parse(gold_answer)
|
||||||
|
pred_parsed = parse(pred_answer)
|
||||||
|
verified = verify(gold_parsed, pred_parsed)
|
||||||
|
except Exception:
|
||||||
|
LOG.debug(
|
||||||
|
"math_verify unavailable or failed, using string fallback",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if verified is not None:
|
||||||
|
rewards.append(1.0 if verified else 0.0)
|
||||||
|
elif pred_answer.strip() == gold_answer.strip():
|
||||||
|
rewards.append(1.0)
|
||||||
|
else:
|
||||||
|
rewards.append(0.0)
|
||||||
|
|
||||||
|
return rewards
|
||||||
409
src/axolotl/integrations/hatchery/rl_trainer.py
Normal file
409
src/axolotl/integrations/hatchery/rl_trainer.py
Normal file
@@ -0,0 +1,409 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""Remote RL trainer (GRPO/PPO) using Tinker or Hatchery API.
|
||||||
|
|
||||||
|
Full RL loop per step:
|
||||||
|
1. Extract prompts from dataset batch
|
||||||
|
2. Sample N completions per prompt via remote SamplingClient
|
||||||
|
3. Score completions with local reward functions
|
||||||
|
4. Compute GRPO-style advantages (per-group normalization)
|
||||||
|
5. Send (prompt+completion, logprobs, advantages) as forward_backward
|
||||||
|
6. Optimizer step
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.trainer_utils import TrainOutput
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
from .args import HatcheryConfig
|
||||||
|
from .data import batch_to_datums_rl, datums_to_tinker
|
||||||
|
from .trainer import _create_training_client
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_reward_func(fqn: str) -> Callable:
|
||||||
|
"""Load a reward function from a fully qualified name like 'module.func'."""
|
||||||
|
module_path = ".".join(fqn.split(".")[:-1])
|
||||||
|
func_name = fqn.split(".")[-1]
|
||||||
|
mod = importlib.import_module(module_path)
|
||||||
|
func = getattr(mod, func_name)
|
||||||
|
if len(inspect.signature(func).parameters) < 2:
|
||||||
|
raise ValueError(f"Reward function {fqn} must accept (prompts, completions)")
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
class HatcheryRLTrainer(AxolotlTrainer):
|
||||||
|
"""Remote RL trainer using Tinker/Hatchery for sampling and training."""
|
||||||
|
|
||||||
|
hatchery_args: Optional[HatcheryConfig]
|
||||||
|
_base_model_name: Optional[str]
|
||||||
|
_training_client: Any
|
||||||
|
_reward_functions: list[Callable]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.hatchery_args = None
|
||||||
|
self._base_model_name = None
|
||||||
|
self._training_client = None
|
||||||
|
self._reward_functions = []
|
||||||
|
|
||||||
|
def _ensure_reward_functions(self):
|
||||||
|
if self._reward_functions:
|
||||||
|
return
|
||||||
|
args = self.hatchery_args
|
||||||
|
if not args or not args.reward_funcs:
|
||||||
|
raise ValueError(
|
||||||
|
"No reward functions configured. Set hatchery.reward_funcs "
|
||||||
|
"in YAML, e.g. reward_funcs: ['my_module.my_reward']"
|
||||||
|
)
|
||||||
|
for fqn in args.reward_funcs:
|
||||||
|
self._reward_functions.append(_load_reward_func(fqn))
|
||||||
|
LOG.info(f"Loaded {len(self._reward_functions)} reward function(s)")
|
||||||
|
|
||||||
|
def _get_training_client(self):
|
||||||
|
if self._training_client is not None:
|
||||||
|
return self._training_client
|
||||||
|
|
||||||
|
self._training_client = _create_training_client(
|
||||||
|
self.hatchery_args, self._base_model_name
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
f"Remote RL session created: backend={self.hatchery_args.backend}, "
|
||||||
|
f"model={self._base_model_name}, rank={self.hatchery_args.lora_rank}"
|
||||||
|
)
|
||||||
|
return self._training_client
|
||||||
|
|
||||||
|
def _sample_completions(self, prompt_ids_list: list[list[int]]):
|
||||||
|
"""Sample completions for prompts via remote API."""
|
||||||
|
import tinker.types as tt
|
||||||
|
|
||||||
|
tc = self._get_training_client()
|
||||||
|
args = self.hatchery_args
|
||||||
|
assert args is not None # validated by _get_training_client
|
||||||
|
results = []
|
||||||
|
|
||||||
|
sc = tc.save_weights_and_get_sampling_client()
|
||||||
|
|
||||||
|
for prompt_ids in prompt_ids_list:
|
||||||
|
if hasattr(sc, "sampling_session_id"):
|
||||||
|
sample_result = sc.sample(
|
||||||
|
prompt_ids,
|
||||||
|
max_tokens=args.max_sample_tokens,
|
||||||
|
temperature=args.sample_temperature,
|
||||||
|
n=args.num_samples,
|
||||||
|
).result(timeout=args.future_timeout)
|
||||||
|
else:
|
||||||
|
mi = tt.ModelInput.from_ints(prompt_ids)
|
||||||
|
sp = tt.SamplingParams(
|
||||||
|
max_tokens=args.max_sample_tokens,
|
||||||
|
temperature=args.sample_temperature,
|
||||||
|
top_p=0.95,
|
||||||
|
top_k=-1,
|
||||||
|
)
|
||||||
|
sample_result = sc.sample(
|
||||||
|
prompt=mi,
|
||||||
|
num_samples=args.num_samples,
|
||||||
|
sampling_params=sp,
|
||||||
|
).result(timeout=args.future_timeout)
|
||||||
|
|
||||||
|
sequences = (
|
||||||
|
sample_result.sequences
|
||||||
|
if hasattr(sample_result, "sequences")
|
||||||
|
else sample_result.get("sequences", [])
|
||||||
|
)
|
||||||
|
for seq in sequences:
|
||||||
|
tokens = (
|
||||||
|
list(seq.tokens)
|
||||||
|
if hasattr(seq, "tokens")
|
||||||
|
else seq.get("tokens", [])
|
||||||
|
)
|
||||||
|
logprobs = (
|
||||||
|
list(seq.logprobs)
|
||||||
|
if hasattr(seq, "logprobs") and seq.logprobs
|
||||||
|
else seq.get("logprobs", [])
|
||||||
|
)
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"tokens": list(prompt_ids) + tokens,
|
||||||
|
"completion_tokens": tokens,
|
||||||
|
"logprobs": logprobs,
|
||||||
|
"prompt_len": len(prompt_ids),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _compute_rewards(
|
||||||
|
self, prompts: list[str], completions: list[str]
|
||||||
|
) -> list[float]:
|
||||||
|
total_rewards = [0.0] * len(completions)
|
||||||
|
for reward_fn in self._reward_functions:
|
||||||
|
rewards = reward_fn(prompts, completions)
|
||||||
|
for i, r in enumerate(rewards):
|
||||||
|
total_rewards[i] += r
|
||||||
|
return total_rewards
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_advantages(rewards: list[float], group_size: int) -> list[float]:
|
||||||
|
advantages = []
|
||||||
|
for i in range(0, len(rewards), group_size):
|
||||||
|
group = rewards[i : i + group_size]
|
||||||
|
mean = sum(group) / len(group)
|
||||||
|
var = sum((r - mean) ** 2 for r in group) / max(len(group), 1)
|
||||||
|
std = var**0.5 if var > 1e-8 else 1.0
|
||||||
|
advantages.extend([(r - mean) / std for r in group])
|
||||||
|
return advantages
|
||||||
|
|
||||||
|
def _do_optim_step(self):
|
||||||
|
import tinker.types as tt
|
||||||
|
|
||||||
|
tc = self._get_training_client()
|
||||||
|
return tc.optim_step(tt.AdamParams(**self._optim_params))
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self,
|
||||||
|
resume_from_checkpoint: Optional[str] = None,
|
||||||
|
trial: Any = None,
|
||||||
|
ignore_keys_for_eval: Optional[list[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> TrainOutput:
|
||||||
|
args = self.hatchery_args
|
||||||
|
if args is None:
|
||||||
|
raise RuntimeError("hatchery_args not configured")
|
||||||
|
|
||||||
|
self._ensure_reward_functions()
|
||||||
|
|
||||||
|
train_dataloader = self.get_train_dataloader()
|
||||||
|
num_train_epochs = int(self.args.num_train_epochs)
|
||||||
|
max_steps = self.args.max_steps if self.args.max_steps > 0 else 1000
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"Remote RL training: max_steps={max_steps}, "
|
||||||
|
f"loss_fn={args.loss_fn}, samples/prompt={args.num_samples}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.state.max_steps = max_steps
|
||||||
|
self.state.num_train_epochs = num_train_epochs
|
||||||
|
self.state.is_local_process_zero = True
|
||||||
|
self.state.is_world_process_zero = True
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_train_begin(
|
||||||
|
self.args,
|
||||||
|
self.state,
|
||||||
|
self.control, # type: ignore[has-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
tokenizer = self.processing_class
|
||||||
|
global_step = 0
|
||||||
|
total_loss = 0.0
|
||||||
|
total_reward = 0.0
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for _epoch in range(num_train_epochs):
|
||||||
|
if global_step >= max_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
for batch in train_dataloader:
|
||||||
|
if global_step >= max_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_step_begin(
|
||||||
|
self.args, self.state, self.control
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_ids_batch = batch["input_ids"]
|
||||||
|
# Full prompt text (with gold tag) for reward scoring
|
||||||
|
prompt_texts = tokenizer.batch_decode(
|
||||||
|
prompt_ids_batch, skip_special_tokens=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Strip <|gold|>...<|/gold|> from token ids before
|
||||||
|
# sending to the model for sampling — the gold answer
|
||||||
|
# must only be visible to the local reward function.
|
||||||
|
sampling_prompts = []
|
||||||
|
for prompt_text in prompt_texts:
|
||||||
|
clean = re.sub(r"<\|gold\|>.*?<\|/gold\|>", "", prompt_text)
|
||||||
|
clean_ids = tokenizer.encode(clean, add_special_tokens=False)
|
||||||
|
sampling_prompts.append(clean_ids)
|
||||||
|
|
||||||
|
# 1. Sample completions (without gold answer)
|
||||||
|
t0 = time.time()
|
||||||
|
samples = self._sample_completions(sampling_prompts)
|
||||||
|
t_sample = time.time() - t0
|
||||||
|
|
||||||
|
if not samples:
|
||||||
|
LOG.warning("No samples generated, skipping step")
|
||||||
|
continue
|
||||||
|
LOG.info(
|
||||||
|
f"Sampled {len(samples)} completions, "
|
||||||
|
f"avg_len={sum(len(s['completion_tokens']) for s in samples) / len(samples):.0f}tok"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Decode and score
|
||||||
|
completion_texts = [
|
||||||
|
tokenizer.decode(s["completion_tokens"], skip_special_tokens=False)
|
||||||
|
for s in samples
|
||||||
|
]
|
||||||
|
sample_prompts = []
|
||||||
|
for prompt_text in prompt_texts:
|
||||||
|
sample_prompts.extend([prompt_text] * args.num_samples)
|
||||||
|
|
||||||
|
rewards = self._compute_rewards(sample_prompts, completion_texts)
|
||||||
|
|
||||||
|
# 3. GRPO advantages
|
||||||
|
advantages_list = self._compute_advantages(
|
||||||
|
rewards, group_size=args.num_samples
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Build training data
|
||||||
|
all_datums = []
|
||||||
|
for i, sample in enumerate(samples):
|
||||||
|
full_tokens = sample["tokens"]
|
||||||
|
prompt_len = sample["prompt_len"]
|
||||||
|
seq_len = len(full_tokens)
|
||||||
|
|
||||||
|
input_ids = torch.tensor([full_tokens], dtype=torch.long)
|
||||||
|
labels = torch.full((1, seq_len), -100, dtype=torch.long)
|
||||||
|
labels[0, prompt_len:] = torch.tensor(full_tokens[prompt_len:])
|
||||||
|
|
||||||
|
logprobs_t = torch.zeros(1, seq_len)
|
||||||
|
if sample["logprobs"]:
|
||||||
|
lp = sample["logprobs"][: seq_len - prompt_len]
|
||||||
|
logprobs_t[0, prompt_len : prompt_len + len(lp)] = torch.tensor(
|
||||||
|
lp
|
||||||
|
)
|
||||||
|
|
||||||
|
adv_t = torch.zeros(1, seq_len)
|
||||||
|
adv_t[0, prompt_len:] = advantages_list[i]
|
||||||
|
|
||||||
|
all_datums.extend(
|
||||||
|
batch_to_datums_rl(input_ids, labels, logprobs_t, adv_t)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Forward backward (one datum at a time for memory) + optim
|
||||||
|
t0 = time.time()
|
||||||
|
tc = self._get_training_client()
|
||||||
|
step_loss = 0.0
|
||||||
|
for datum in all_datums:
|
||||||
|
fb_future = tc.forward_backward(
|
||||||
|
datums_to_tinker([datum]),
|
||||||
|
loss_fn=args.loss_fn,
|
||||||
|
loss_fn_config=args.loss_fn_config,
|
||||||
|
)
|
||||||
|
fb_result = fb_future.result(timeout=args.future_timeout)
|
||||||
|
if hasattr(fb_result, "metrics"):
|
||||||
|
step_loss += float(
|
||||||
|
(fb_result.metrics or {}).get("loss:sum", 0.0)
|
||||||
|
)
|
||||||
|
elif isinstance(fb_result, dict):
|
||||||
|
step_loss += float(
|
||||||
|
fb_result.get("metrics", {}).get("loss:sum", 0.0)
|
||||||
|
)
|
||||||
|
optim_future = self._do_optim_step()
|
||||||
|
if not args.pipeline:
|
||||||
|
optim_future.result(timeout=args.future_timeout)
|
||||||
|
t_train = time.time() - t0
|
||||||
|
|
||||||
|
mean_reward = sum(rewards) / len(rewards)
|
||||||
|
accuracy = sum(1 for r in rewards if r > 0) / len(rewards)
|
||||||
|
mean_adv = sum(abs(a) for a in advantages_list) / len(advantages_list)
|
||||||
|
global_step += 1
|
||||||
|
total_loss += step_loss
|
||||||
|
total_reward += mean_reward
|
||||||
|
self.state.global_step = global_step
|
||||||
|
|
||||||
|
log_interval = self.args.logging_steps or 1
|
||||||
|
if global_step % log_interval == 0:
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
LOG.info(
|
||||||
|
f"[step {global_step}/{max_steps}] "
|
||||||
|
f"acc={accuracy:.2f} reward={mean_reward:.3f} "
|
||||||
|
f"|adv|={mean_adv:.3f} loss:sum={step_loss:.1f} "
|
||||||
|
f"sample={t_sample:.1f}s train={t_train:.1f}s "
|
||||||
|
f"{elapsed / global_step:.1f}s/step"
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
{
|
||||||
|
"loss": step_loss,
|
||||||
|
"reward": mean_reward,
|
||||||
|
"accuracy": accuracy,
|
||||||
|
"mean_abs_advantage": mean_adv,
|
||||||
|
"learning_rate": self._optim_params["learning_rate"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.save_steps and global_step % args.save_steps == 0:
|
||||||
|
self._save_remote_checkpoint(global_step)
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_step_end(
|
||||||
|
self.args, self.state, self.control
|
||||||
|
)
|
||||||
|
if self.control.should_training_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.control.should_training_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
if global_step > 0:
|
||||||
|
self._save_remote_checkpoint(global_step, name="final")
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
avg_loss = total_loss / max(global_step, 1)
|
||||||
|
avg_reward = total_reward / max(global_step, 1)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"RL training complete: {global_step} steps, {elapsed:.1f}s, "
|
||||||
|
f"avg_reward={avg_reward:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_train_end(
|
||||||
|
self.args, self.state, self.control
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainOutput(
|
||||||
|
global_step=global_step,
|
||||||
|
training_loss=avg_loss,
|
||||||
|
metrics={
|
||||||
|
"train_loss": avg_loss,
|
||||||
|
"train_reward": avg_reward,
|
||||||
|
"train_runtime": elapsed,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
|
||||||
|
tc = self._get_training_client()
|
||||||
|
args = self.hatchery_args
|
||||||
|
assert args is not None # validated by _get_training_client
|
||||||
|
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
|
||||||
|
try:
|
||||||
|
future = tc.save_state(ckpt_name)
|
||||||
|
future.result(timeout=args.future_timeout)
|
||||||
|
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
|
||||||
|
if name == "final":
|
||||||
|
raise
|
||||||
|
|
||||||
|
def save_model(self, output_dir=None, _internal_call=False):
|
||||||
|
self._save_remote_checkpoint(
|
||||||
|
step=self.state.global_step,
|
||||||
|
name=output_dir or "hf-save",
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"HatcheryRLTrainer uses remote API; compute_loss not called locally."
|
||||||
|
)
|
||||||
327
src/axolotl/integrations/hatchery/trainer.py
Normal file
327
src/axolotl/integrations/hatchery/trainer.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) Axolotl AI
|
||||||
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
|
"""Remote trainer that dispatches to Tinker or Hatchery API."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers.trainer_utils import TrainOutput
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
from .args import HatcheryConfig
|
||||||
|
from .data import batch_to_datums_sft, datums_to_tinker
|
||||||
|
|
||||||
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_loss(result) -> float:
|
||||||
|
"""Extract loss:sum from a forward_backward result.
|
||||||
|
|
||||||
|
Tinker's cross_entropy (and other losses) return the SUM of per-token
|
||||||
|
losses, not the mean. This is by design — it lets users control
|
||||||
|
normalization via the weights tensor. The trainer logs this raw sum;
|
||||||
|
users who want per-token loss should divide by number of active tokens.
|
||||||
|
"""
|
||||||
|
if hasattr(result, "metrics"):
|
||||||
|
metrics = result.metrics or {}
|
||||||
|
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
|
||||||
|
if isinstance(result, dict):
|
||||||
|
metrics = result.get("metrics", {})
|
||||||
|
return float(metrics.get("loss:sum", metrics.get("loss", 0.0)))
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _create_training_client(args: HatcheryConfig, base_model: str):
|
||||||
|
"""Create a training client for either Tinker or Hatchery backend."""
|
||||||
|
if args.backend == "tinker":
|
||||||
|
import tinker
|
||||||
|
|
||||||
|
api_key = args.api_key or os.environ.get("TINKER_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError(
|
||||||
|
"Tinker API key required. Set `hatchery.api_key` in config "
|
||||||
|
"or TINKER_API_KEY env var."
|
||||||
|
)
|
||||||
|
os.environ["TINKER_API_KEY"] = api_key
|
||||||
|
|
||||||
|
service = tinker.ServiceClient(project_id=args.project_id)
|
||||||
|
return service.create_lora_training_client(
|
||||||
|
base_model=base_model,
|
||||||
|
rank=args.lora_rank,
|
||||||
|
train_mlp=args.train_mlp,
|
||||||
|
train_attn=args.train_attn,
|
||||||
|
train_unembed=args.train_unembed,
|
||||||
|
)
|
||||||
|
|
||||||
|
from hatchery.core.client import HatcheryClient
|
||||||
|
|
||||||
|
base_url = args.base_url or os.environ.get("HATCHERY_URL", "http://127.0.0.1:8420")
|
||||||
|
token = args.api_key or os.environ.get("HATCHERY_API_KEY", "dev")
|
||||||
|
|
||||||
|
client = HatcheryClient(base_url=base_url, token=token, timeout=args.future_timeout)
|
||||||
|
return client.create_lora_training_client(
|
||||||
|
base_model=base_model,
|
||||||
|
rank=args.lora_rank,
|
||||||
|
train_attn=args.train_attn,
|
||||||
|
train_mlp=args.train_mlp,
|
||||||
|
train_unembed=args.train_unembed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HatcheryTrainer(AxolotlTrainer):
|
||||||
|
"""Trainer that sends preprocessed batches to a remote training API.
|
||||||
|
|
||||||
|
Replaces local forward/backward with remote API calls to Tinker or
|
||||||
|
Hatchery. Uses axolotl's full data preprocessing pipeline (tokenization,
|
||||||
|
chat templates, packing, etc.) but offloads compute to remote GPUs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
hatchery_args: Optional[HatcheryConfig]
|
||||||
|
_base_model_name: Optional[str]
|
||||||
|
_training_client: Any
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.hatchery_args = None
|
||||||
|
self._base_model_name = None
|
||||||
|
self._training_client = None
|
||||||
|
|
||||||
|
def _get_training_client(self):
|
||||||
|
"""Lazily create the remote training session."""
|
||||||
|
if self._training_client is not None:
|
||||||
|
return self._training_client
|
||||||
|
|
||||||
|
args = self.hatchery_args
|
||||||
|
if args is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"HatcheryTrainer.hatchery_args not set. "
|
||||||
|
"Ensure the HatcheryPlugin is registered."
|
||||||
|
)
|
||||||
|
|
||||||
|
base_model = self._base_model_name
|
||||||
|
if not base_model:
|
||||||
|
raise RuntimeError("HatcheryTrainer._base_model_name not set.")
|
||||||
|
|
||||||
|
self._training_client = _create_training_client(args, base_model)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"Remote training session created: backend={args.backend}, "
|
||||||
|
f"model={base_model}, rank={args.lora_rank}"
|
||||||
|
)
|
||||||
|
return self._training_client
|
||||||
|
|
||||||
|
def _send_batch(self, batch: dict[str, torch.Tensor]):
|
||||||
|
"""Convert batch to datums and send forward_backward to remote.
|
||||||
|
|
||||||
|
Returns (future, n_active_tokens) where n_active_tokens counts
|
||||||
|
the completion tokens in this batch (for loss normalization).
|
||||||
|
"""
|
||||||
|
input_ids = batch["input_ids"]
|
||||||
|
labels = batch["labels"]
|
||||||
|
attention_mask = batch.get("attention_mask")
|
||||||
|
|
||||||
|
n_active = int((labels[:, 1:] != -100).sum().item())
|
||||||
|
datums = batch_to_datums_sft(input_ids, labels, attention_mask)
|
||||||
|
|
||||||
|
tc = self._get_training_client()
|
||||||
|
args = self.hatchery_args
|
||||||
|
assert args is not None # validated by _get_training_client
|
||||||
|
send_datums = datums_to_tinker(datums)
|
||||||
|
|
||||||
|
future = tc.forward_backward(
|
||||||
|
send_datums,
|
||||||
|
loss_fn=args.loss_fn,
|
||||||
|
loss_fn_config=args.loss_fn_config,
|
||||||
|
)
|
||||||
|
return future, n_active
|
||||||
|
|
||||||
|
def _do_optim_step(self):
|
||||||
|
"""Send optimizer step to remote using axolotl's training params."""
|
||||||
|
import tinker.types as tt
|
||||||
|
|
||||||
|
tc = self._get_training_client()
|
||||||
|
return tc.optim_step(tt.AdamParams(**self._optim_params))
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self,
|
||||||
|
resume_from_checkpoint: Optional[str] = None,
|
||||||
|
trial: Any = None,
|
||||||
|
ignore_keys_for_eval: Optional[list[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> TrainOutput:
|
||||||
|
"""Main training loop — sends batches to remote API."""
|
||||||
|
args = self.hatchery_args
|
||||||
|
if args is None:
|
||||||
|
raise RuntimeError("hatchery_args not configured")
|
||||||
|
|
||||||
|
train_dataloader = self.get_train_dataloader()
|
||||||
|
num_batches = len(train_dataloader)
|
||||||
|
|
||||||
|
grad_accum = self.args.gradient_accumulation_steps
|
||||||
|
num_train_epochs = int(self.args.num_train_epochs)
|
||||||
|
steps_per_epoch = max(num_batches // grad_accum, 1)
|
||||||
|
max_steps = (
|
||||||
|
self.args.max_steps
|
||||||
|
if self.args.max_steps > 0
|
||||||
|
else steps_per_epoch * num_train_epochs
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"Remote training: {num_batches} batches/epoch, "
|
||||||
|
f"{grad_accum} grad_accum, {max_steps} max steps, "
|
||||||
|
f"{num_train_epochs} epochs"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.state.max_steps = max_steps
|
||||||
|
self.state.num_train_epochs = num_train_epochs
|
||||||
|
self.state.is_local_process_zero = True
|
||||||
|
self.state.is_world_process_zero = True
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_train_begin(
|
||||||
|
self.args,
|
||||||
|
self.state,
|
||||||
|
self.control, # type: ignore[has-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
global_step = 0
|
||||||
|
total_loss = 0.0
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for _epoch in range(num_train_epochs):
|
||||||
|
if global_step >= max_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_epoch_begin(
|
||||||
|
self.args, self.state, self.control
|
||||||
|
)
|
||||||
|
|
||||||
|
pending_fb_futures = []
|
||||||
|
accum_count = 0
|
||||||
|
|
||||||
|
for batch_idx, batch in enumerate(train_dataloader):
|
||||||
|
if global_step >= max_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_step_begin(
|
||||||
|
self.args, self.state, self.control
|
||||||
|
)
|
||||||
|
|
||||||
|
fb_future, n_active = self._send_batch(batch)
|
||||||
|
pending_fb_futures.append((fb_future, n_active))
|
||||||
|
accum_count += 1
|
||||||
|
|
||||||
|
if accum_count >= grad_accum:
|
||||||
|
step_loss_sum = 0.0
|
||||||
|
step_active = 0
|
||||||
|
for fut, n_act in pending_fb_futures:
|
||||||
|
result = fut.result(timeout=args.future_timeout)
|
||||||
|
step_loss_sum += _extract_loss(result)
|
||||||
|
step_active += n_act
|
||||||
|
|
||||||
|
optim_future = self._do_optim_step()
|
||||||
|
if not args.pipeline:
|
||||||
|
optim_future.result(timeout=args.future_timeout)
|
||||||
|
|
||||||
|
step_loss = (
|
||||||
|
step_loss_sum / step_active
|
||||||
|
if step_active > 0
|
||||||
|
else step_loss_sum
|
||||||
|
)
|
||||||
|
|
||||||
|
global_step += 1
|
||||||
|
total_loss += step_loss
|
||||||
|
self.state.global_step = global_step
|
||||||
|
self.state.epoch = _epoch + (batch_idx + 1) / num_batches
|
||||||
|
|
||||||
|
log_interval = self.args.logging_steps or 1
|
||||||
|
if global_step % log_interval == 0:
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
avg_loss = total_loss / global_step
|
||||||
|
LOG.info(
|
||||||
|
f"[step {global_step}/{max_steps}] "
|
||||||
|
f"loss/tok={step_loss:.4f} avg={avg_loss:.4f} "
|
||||||
|
f"active={step_active} "
|
||||||
|
f"{elapsed / global_step:.2f}s/step"
|
||||||
|
)
|
||||||
|
self.log(
|
||||||
|
{
|
||||||
|
"loss": step_loss,
|
||||||
|
"learning_rate": self._optim_params["learning_rate"],
|
||||||
|
"epoch": self.state.epoch,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.save_steps and global_step % args.save_steps == 0:
|
||||||
|
self._save_remote_checkpoint(global_step)
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_step_end(
|
||||||
|
self.args, self.state, self.control
|
||||||
|
)
|
||||||
|
|
||||||
|
pending_fb_futures = []
|
||||||
|
accum_count = 0
|
||||||
|
|
||||||
|
if self.control.should_training_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_epoch_end(
|
||||||
|
self.args, self.state, self.control
|
||||||
|
)
|
||||||
|
if self.control.should_training_stop:
|
||||||
|
break
|
||||||
|
|
||||||
|
if global_step > 0:
|
||||||
|
self._save_remote_checkpoint(global_step, name="final")
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
avg_loss = total_loss / max(global_step, 1)
|
||||||
|
|
||||||
|
LOG.info(
|
||||||
|
f"Training complete: {global_step} steps, {elapsed:.1f}s total, "
|
||||||
|
f"{elapsed / max(global_step, 1):.2f}s/step, avg_loss={avg_loss:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.control = self.callback_handler.on_train_end(
|
||||||
|
self.args, self.state, self.control
|
||||||
|
)
|
||||||
|
|
||||||
|
return TrainOutput(
|
||||||
|
global_step=global_step,
|
||||||
|
training_loss=avg_loss,
|
||||||
|
metrics={"train_loss": avg_loss, "train_runtime": elapsed},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _save_remote_checkpoint(self, step: int, name: Optional[str] = None):
|
||||||
|
"""Save a checkpoint on the remote service."""
|
||||||
|
tc = self._get_training_client()
|
||||||
|
args = self.hatchery_args
|
||||||
|
assert args is not None # validated by _get_training_client
|
||||||
|
ckpt_name = name or f"{args.save_name_prefix}-{step:06d}"
|
||||||
|
try:
|
||||||
|
future = tc.save_state(ckpt_name)
|
||||||
|
future.result(timeout=args.future_timeout)
|
||||||
|
LOG.info(f"Remote checkpoint saved: {ckpt_name}")
|
||||||
|
except Exception:
|
||||||
|
LOG.exception(f"Failed to save checkpoint {ckpt_name}")
|
||||||
|
if name == "final":
|
||||||
|
raise
|
||||||
|
|
||||||
|
def save_model(self, output_dir=None, _internal_call=False):
|
||||||
|
"""Delegate to remote checkpoint save so HF callbacks create checkpoints."""
|
||||||
|
self._save_remote_checkpoint(
|
||||||
|
step=self.state.global_step,
|
||||||
|
name=output_dir or "hf-save",
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"HatcheryTrainer uses remote API; compute_loss should not be called."
|
||||||
|
)
|
||||||
@@ -11,7 +11,7 @@ kd_ce_alpha: 0.1
|
|||||||
kd_alpha: 0.9
|
kd_alpha: 0.9
|
||||||
kd_temperature: 1.0
|
kd_temperature: 1.0
|
||||||
|
|
||||||
torch_compile: True # torch>=2.6.0, recommended to reduce vram
|
torch_compile: True # recommended to reduce vram
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: ...
|
- path: ...
|
||||||
|
|||||||
@@ -2,35 +2,17 @@
|
|||||||
# Copyright (c) Axolotl AI
|
# Copyright (c) Axolotl AI
|
||||||
# Licensed under the Apache License, Version 2.0
|
# Licensed under the Apache License, Version 2.0
|
||||||
|
|
||||||
from .lora_layout import (
|
from . import layers
|
||||||
peft_down_proj_lora_to_scattermoe,
|
from .lora_ops import ParallelExperts
|
||||||
peft_lora_B_to_scattermoe,
|
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||||
peft_lora_to_scattermoe,
|
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
||||||
validate_scattermoe_lora_shapes,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"peft_down_proj_lora_to_scattermoe",
|
"layers",
|
||||||
"peft_lora_B_to_scattermoe",
|
"ParallelExperts",
|
||||||
"peft_lora_to_scattermoe",
|
"flatten_sort_count",
|
||||||
"validate_scattermoe_lora_shapes",
|
"parallel_linear",
|
||||||
|
"ScatterMoELoRA",
|
||||||
|
"parallel_linear_lora",
|
||||||
|
"lora_ops",
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
|
||||||
from . import layers
|
|
||||||
from .lora_ops import ParallelExperts
|
|
||||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
|
||||||
from .parallel_linear_lora import ScatterMoELoRA, parallel_linear_lora
|
|
||||||
except ModuleNotFoundError as exc:
|
|
||||||
if exc.name != "triton":
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
__all__ += [
|
|
||||||
"layers",
|
|
||||||
"ParallelExperts",
|
|
||||||
"flatten_sort_count",
|
|
||||||
"parallel_linear",
|
|
||||||
"ScatterMoELoRA",
|
|
||||||
"parallel_linear_lora",
|
|
||||||
"lora_ops",
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -35,19 +35,46 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from .lora_layout import (
|
|
||||||
peft_down_proj_lora_to_scattermoe,
|
|
||||||
peft_lora_B_to_scattermoe,
|
|
||||||
peft_lora_to_scattermoe,
|
|
||||||
)
|
|
||||||
from .parallel_experts import flatten_sort_count, parallel_linear
|
from .parallel_experts import flatten_sort_count, parallel_linear
|
||||||
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
|
from .parallel_linear_lora import get_lora_params_from_wrapper, parallel_linear_lora
|
||||||
|
|
||||||
__all__ = [
|
# =============================================================================
|
||||||
"peft_down_proj_lora_to_scattermoe",
|
# LoRA layout conversion utilities (peft <-> scattermoe)
|
||||||
"peft_lora_B_to_scattermoe",
|
# =============================================================================
|
||||||
"peft_lora_to_scattermoe",
|
|
||||||
]
|
|
||||||
|
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
||||||
|
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
|
||||||
|
expert-major ``[N, r*E]``.
|
||||||
|
|
||||||
|
peft reshapes B to ``[out, r, E]`` (rank-major).
|
||||||
|
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
|
||||||
|
"""
|
||||||
|
N = peft_B.shape[0]
|
||||||
|
return (
|
||||||
|
peft_B.reshape(N, rank, num_experts)
|
||||||
|
.permute(0, 2, 1)
|
||||||
|
.contiguous()
|
||||||
|
.reshape(N, num_experts * rank)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||||
|
"""Convert peft LoRA weights to scattermoe layout.
|
||||||
|
|
||||||
|
peft >=0.19.1 assigns in/out features for 3D params such that
|
||||||
|
A and B already align with scattermoe's convention (no A<->B swap).
|
||||||
|
Only B needs rank-major → expert-major layout conversion.
|
||||||
|
"""
|
||||||
|
smoe_A = peft_A
|
||||||
|
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||||
|
return smoe_A, smoe_B
|
||||||
|
|
||||||
|
|
||||||
|
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||||
|
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
|
||||||
|
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# ParamWrapper unwrapping
|
# ParamWrapper unwrapping
|
||||||
@@ -137,7 +164,7 @@ def _unwrap_experts_lora(experts_module):
|
|||||||
if gup is not None:
|
if gup is not None:
|
||||||
num_experts = gup.shape[0]
|
num_experts = gup.shape[0]
|
||||||
|
|
||||||
# Extract gate_up_proj LoRA
|
# Extract gate_up_proj LoRA (needs A<->B swap due to transposition)
|
||||||
gup_lora = None
|
gup_lora = None
|
||||||
gup_wrapper = wrappers.get("gate_up_proj")
|
gup_wrapper = wrappers.get("gate_up_proj")
|
||||||
if gup_wrapper is not None:
|
if gup_wrapper is not None:
|
||||||
@@ -146,7 +173,7 @@ def _unwrap_experts_lora(experts_module):
|
|||||||
rank = lora_A.shape[0] // num_experts
|
rank = lora_A.shape[0] // num_experts
|
||||||
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
|
gup_lora = _convert_smoe_lora(lora_A, lora_B, num_experts, rank, scaling)
|
||||||
|
|
||||||
# Extract down_proj LoRA
|
# Extract down_proj LoRA (needs A<->B swap due to transposition)
|
||||||
down_lora = None
|
down_lora = None
|
||||||
down_wrapper = wrappers.get("down_proj")
|
down_wrapper = wrappers.get("down_proj")
|
||||||
if down_wrapper is not None:
|
if down_wrapper is not None:
|
||||||
|
|||||||
@@ -1,78 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# Copyright (c) Axolotl AI
|
|
||||||
# Licensed under the Apache License, Version 2.0
|
|
||||||
|
|
||||||
"""Pure tensor layout helpers for ScatterMoE LoRA weights."""
|
|
||||||
|
|
||||||
|
|
||||||
def peft_lora_B_to_scattermoe(peft_B, num_experts, rank):
|
|
||||||
"""Convert peft rank-major lora_B ``[out, E*r]`` to scattermoe
|
|
||||||
expert-major ``[N, r*E]``.
|
|
||||||
|
|
||||||
peft reshapes B to ``[out, r, E]`` (rank-major).
|
|
||||||
scattermoe slices B as ``[:, e*r:(e+1)*r]`` (expert-major).
|
|
||||||
"""
|
|
||||||
N = peft_B.shape[0]
|
|
||||||
return (
|
|
||||||
peft_B.reshape(N, rank, num_experts)
|
|
||||||
.permute(0, 2, 1)
|
|
||||||
.contiguous()
|
|
||||||
.reshape(N, num_experts * rank)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
|
||||||
"""Convert peft LoRA weights to scattermoe layout.
|
|
||||||
|
|
||||||
peft operates on the parameter in its native storage layout ``[E, dim1, dim2]``
|
|
||||||
where ``out_features=dim1, in_features=dim2``. ScatterMoE transposes the
|
|
||||||
parameter (``W = param.transpose(2, 1)``), giving ``[E, dim2, dim1]`` with
|
|
||||||
``K=dim2, N=dim1``.
|
|
||||||
|
|
||||||
peft gives:
|
|
||||||
lora_A ``[r*E, dim2]``, lora_B ``[dim1, r*E]``
|
|
||||||
|
|
||||||
scattermoe needs:
|
|
||||||
lora_A ``[r*E, K=dim2]``, lora_B ``[N=dim1, r*E]``
|
|
||||||
|
|
||||||
peft's A already matches ScatterMoE's A shape. Only B needs conversion from
|
|
||||||
peft's rank-major layout to ScatterMoE's expert-major layout.
|
|
||||||
"""
|
|
||||||
smoe_A = peft_A
|
|
||||||
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
|
||||||
|
|
||||||
return smoe_A, smoe_B
|
|
||||||
|
|
||||||
|
|
||||||
def peft_down_proj_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
|
||||||
"""Deprecated alias for :func:`peft_lora_to_scattermoe`."""
|
|
||||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B):
|
|
||||||
"""Validate LoRA tensor layout before dispatching ScatterMoE kernels."""
|
|
||||||
E, K, N = expert_weights.shape
|
|
||||||
if lora_A.dim() != 2 or lora_B.dim() != 2:
|
|
||||||
raise ValueError(
|
|
||||||
"ScatterMoE LoRA expects 2D lora_A and lora_B tensors, got "
|
|
||||||
f"lora_A={tuple(lora_A.shape)} and lora_B={tuple(lora_B.shape)}."
|
|
||||||
)
|
|
||||||
|
|
||||||
if lora_A.size(0) % E != 0:
|
|
||||||
raise ValueError(
|
|
||||||
"ScatterMoE LoRA expects lora_A rows to be divisible by the number "
|
|
||||||
f"of experts ({E}), got lora_A={tuple(lora_A.shape)}."
|
|
||||||
)
|
|
||||||
|
|
||||||
rank = lora_A.size(0) // E
|
|
||||||
expected_A = (E * rank, K)
|
|
||||||
expected_B = (N, E * rank)
|
|
||||||
if tuple(lora_A.shape) != expected_A or tuple(lora_B.shape) != expected_B:
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid ScatterMoE LoRA layout for expert_weights "
|
|
||||||
f"{tuple(expert_weights.shape)}. Expected lora_A={expected_A} and "
|
|
||||||
f"lora_B={expected_B}, got lora_A={tuple(lora_A.shape)} and "
|
|
||||||
f"lora_B={tuple(lora_B.shape)}. For PEFT target_parameters, keep "
|
|
||||||
"lora_A as [E*r, K] and only convert lora_B from rank-major to "
|
|
||||||
"expert-major layout."
|
|
||||||
)
|
|
||||||
@@ -34,7 +34,6 @@ from .kernels.lora_ops import (
|
|||||||
scatter2scatter_lora,
|
scatter2scatter_lora,
|
||||||
scatter2scatter_lora_dX,
|
scatter2scatter_lora_dX,
|
||||||
)
|
)
|
||||||
from .lora_layout import validate_scattermoe_lora_shapes
|
|
||||||
|
|
||||||
|
|
||||||
class ScatterMoELoRA(torch.autograd.Function):
|
class ScatterMoELoRA(torch.autograd.Function):
|
||||||
@@ -423,6 +422,11 @@ def get_lora_params_from_wrapper(module) -> tuple:
|
|||||||
return lora_A, lora_B, scaling
|
return lora_A, lora_B, scaling
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Drop-in replacement for parallel_linear
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def parallel_linear_lora(
|
def parallel_linear_lora(
|
||||||
inputs: torch.Tensor,
|
inputs: torch.Tensor,
|
||||||
expert_weights: torch.Tensor,
|
expert_weights: torch.Tensor,
|
||||||
@@ -447,7 +451,6 @@ def parallel_linear_lora(
|
|||||||
Otherwise falls back to standard scatter2scatter.
|
Otherwise falls back to standard scatter2scatter.
|
||||||
"""
|
"""
|
||||||
if lora_A is not None and lora_B is not None:
|
if lora_A is not None and lora_B is not None:
|
||||||
validate_scattermoe_lora_shapes(expert_weights, lora_A, lora_B)
|
|
||||||
return ScatterMoELoRA.apply(
|
return ScatterMoELoRA.apply(
|
||||||
inputs,
|
inputs,
|
||||||
expert_weights,
|
expert_weights,
|
||||||
|
|||||||
@@ -170,7 +170,6 @@ class PatchManager:
|
|||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
self._apply_llama_flash_attn_patches(model)
|
self._apply_llama_flash_attn_patches(model)
|
||||||
self._apply_unsloth_patches(model)
|
|
||||||
self._apply_lora_kernel_patch(model)
|
self._apply_lora_kernel_patch(model)
|
||||||
self._apply_scaling_softmax_patch(model)
|
self._apply_scaling_softmax_patch(model)
|
||||||
|
|
||||||
@@ -423,7 +422,16 @@ class PatchManager:
|
|||||||
patch_gemma4_fused_attn,
|
patch_gemma4_fused_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
patch_gemma4_fused_attn()
|
# Shared-KV side channel when activation checkpointing (PR #3611).
|
||||||
|
fsdp_cfg = self.cfg.fsdp_config
|
||||||
|
needs_shared_kv_workaround = (not self.inference) and bool(
|
||||||
|
self.cfg.gradient_checkpointing
|
||||||
|
or self.cfg.activation_offloading
|
||||||
|
or (fsdp_cfg is not None and fsdp_cfg.activation_checkpointing)
|
||||||
|
)
|
||||||
|
patch_gemma4_fused_attn(
|
||||||
|
install_shared_kv_workaround=needs_shared_kv_workaround
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fix_nemotron_h_conversion_mapping():
|
def _fix_nemotron_h_conversion_mapping():
|
||||||
@@ -701,24 +709,10 @@ class PatchManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
patch_fa_llama_cross_entropy()
|
patch_fa_llama_cross_entropy()
|
||||||
elif self.cfg.unsloth_cross_entropy_loss:
|
|
||||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
|
||||||
|
|
||||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
|
||||||
|
|
||||||
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
|
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm
|
from axolotl.monkeypatch.llama_attn_hijack_flash import patch_llama_rms_norm
|
||||||
|
|
||||||
patch_llama_rms_norm()
|
patch_llama_rms_norm()
|
||||||
elif self.cfg.unsloth_rms_norm:
|
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
|
||||||
|
|
||||||
patch_unsloth_layernorm()
|
|
||||||
|
|
||||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
|
||||||
|
|
||||||
patch_self_attn_lora()
|
|
||||||
|
|
||||||
def _patch_llama_flash_attention(self):
|
def _patch_llama_flash_attention(self):
|
||||||
"""Apply Flash Attention patches for LLaMA models."""
|
"""Apply Flash Attention patches for LLaMA models."""
|
||||||
@@ -785,23 +779,6 @@ class PatchManager:
|
|||||||
LOG.info("Patching with SwiGLU...")
|
LOG.info("Patching with SwiGLU...")
|
||||||
replace_llama_mlp_with_swiglu(model)
|
replace_llama_mlp_with_swiglu(model)
|
||||||
|
|
||||||
def _apply_unsloth_patches(self, model):
|
|
||||||
"""Apply unsloth optimization patches."""
|
|
||||||
if self.cfg.unsloth_lora_mlp:
|
|
||||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
|
|
||||||
|
|
||||||
integrate_lora_mlp_patch(peft_model=model)
|
|
||||||
|
|
||||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
|
||||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
|
|
||||||
|
|
||||||
integrate_lora_patch(peft_model=model, cfg=self.cfg)
|
|
||||||
|
|
||||||
if self.cfg.unsloth_rope:
|
|
||||||
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
|
|
||||||
|
|
||||||
integrate_rope_embeddings()
|
|
||||||
|
|
||||||
def _apply_lora_kernel_patch(self, model):
|
def _apply_lora_kernel_patch(self, model):
|
||||||
"""Apply LoRA kernel patches."""
|
"""Apply LoRA kernel patches."""
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
|
|||||||
processor_kwargs = {}
|
processor_kwargs = {}
|
||||||
if cfg.revision_of_model:
|
if cfg.revision_of_model:
|
||||||
processor_kwargs["revision"] = cfg.revision_of_model
|
processor_kwargs["revision"] = cfg.revision_of_model
|
||||||
|
if cfg.processor_kwargs:
|
||||||
|
processor_kwargs.update(cfg.processor_kwargs)
|
||||||
|
|
||||||
if cfg.tokenizer_use_mistral_common:
|
if cfg.tokenizer_use_mistral_common:
|
||||||
|
|
||||||
|
|||||||
@@ -6,15 +6,29 @@ kernels, eliminating intermediate tensor allocations from rotate_half / apply_ro
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
|
from axolotl.monkeypatch.models.gemma4.fused_attn import patch_gemma4_fused_attn
|
||||||
patch_gemma4_fused_attn()
|
# Pass install_shared_kv_workaround=True when activation checkpointing is enabled.
|
||||||
|
patch_gemma4_fused_attn(install_shared_kv_workaround=True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Module-level dict used as a side channel for shared KV states avoiding kwarg and TLS
|
||||||
|
# to prevent memory leak on gradient checkpoint enabled training (PR #3611)
|
||||||
|
_GEMMA4_SHARED_KV_STORE: dict = {"store": None}
|
||||||
|
|
||||||
|
|
||||||
|
def _set_shared_kv_states(store):
|
||||||
|
_GEMMA4_SHARED_KV_STORE["store"] = store
|
||||||
|
|
||||||
|
|
||||||
|
def _get_shared_kv_states():
|
||||||
|
return _GEMMA4_SHARED_KV_STORE["store"]
|
||||||
|
|
||||||
|
|
||||||
def _make_fused_forward(original_forward):
|
def _make_fused_forward(original_forward):
|
||||||
@@ -30,7 +44,7 @@ def _make_fused_forward(original_forward):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
position_embeddings: torch.Tensor,
|
position_embeddings: torch.Tensor,
|
||||||
attention_mask: torch.Tensor | None,
|
attention_mask: torch.Tensor | None,
|
||||||
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]],
|
shared_kv_states: dict[int, tuple[torch.Tensor, torch.Tensor]] | None = None,
|
||||||
past_key_values=None,
|
past_key_values=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
@@ -39,6 +53,10 @@ def _make_fused_forward(original_forward):
|
|||||||
eager_attention_forward,
|
eager_attention_forward,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
store = _get_shared_kv_states()
|
||||||
|
if store is not None:
|
||||||
|
shared_kv_states = store
|
||||||
|
|
||||||
input_shape = hidden_states.shape[:-1]
|
input_shape = hidden_states.shape[:-1]
|
||||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||||
eps = self.config.rms_norm_eps
|
eps = self.config.rms_norm_eps
|
||||||
@@ -133,15 +151,44 @@ def _make_fused_forward(original_forward):
|
|||||||
return fused_forward
|
return fused_forward
|
||||||
|
|
||||||
|
|
||||||
def patch_gemma4_fused_attn():
|
def _patch_decoder_layer_call():
|
||||||
|
"""Strip `shared_kv_states` from decoder-layer kwargs and route via the
|
||||||
|
module-level side channel so the checkpoint partial cannot pin it (PR #3611).
|
||||||
"""
|
"""
|
||||||
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels.
|
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextDecoderLayer
|
||||||
|
|
||||||
|
if getattr(Gemma4TextDecoderLayer, "_axolotl_shared_kv_patched", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
original_call = Gemma4TextDecoderLayer.__call__
|
||||||
|
|
||||||
|
def patched_call(self, *args, **kwargs):
|
||||||
|
shared_kv = kwargs.pop("shared_kv_states", None)
|
||||||
|
# Overwrite unconditionally (including with None) so a previous step's
|
||||||
|
# dict cannot leak into a later call without shared_kv_states (PR #3611).
|
||||||
|
_set_shared_kv_states(shared_kv)
|
||||||
|
return original_call(self, *args, **kwargs)
|
||||||
|
|
||||||
|
Gemma4TextDecoderLayer.__call__ = patched_call
|
||||||
|
Gemma4TextDecoderLayer._axolotl_shared_kv_patched = True
|
||||||
|
|
||||||
|
|
||||||
|
def patch_gemma4_fused_attn(install_shared_kv_workaround: bool = False):
|
||||||
|
"""
|
||||||
|
Monkeypatch Gemma4TextAttention.forward to use fused RMSNorm+RoPE kernels,
|
||||||
|
and optionally route `shared_kv_states` via a module-level side channel to
|
||||||
|
avoid a VRAM leak under activation checkpointing (PR #3611).
|
||||||
"""
|
"""
|
||||||
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
|
from transformers.models.gemma4.modeling_gemma4 import Gemma4TextAttention
|
||||||
|
|
||||||
original_forward = Gemma4TextAttention.forward
|
original_forward = Gemma4TextAttention.forward
|
||||||
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
|
Gemma4TextAttention.forward = _make_fused_forward(original_forward)
|
||||||
|
|
||||||
|
if install_shared_kv_workaround:
|
||||||
|
_patch_decoder_layer_call()
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
|
"Patched Gemma4TextAttention.forward with fused RMSNorm+RoPE Triton kernels"
|
||||||
)
|
)
|
||||||
|
if install_shared_kv_workaround:
|
||||||
|
logger.info("Installed Gemma4 shared_kv_states side channel (PR #3611)")
|
||||||
|
|||||||
@@ -407,7 +407,10 @@ def selective_log_softmax(logits, index) -> torch.Tensor:
|
|||||||
K = index.shape[-1]
|
K = index.shape[-1]
|
||||||
original_index_shape = index.shape
|
original_index_shape = index.shape
|
||||||
|
|
||||||
flat_logits = logits.reshape(-1, V).contiguous()
|
try:
|
||||||
|
flat_logits = logits.view(-1, V)
|
||||||
|
except RuntimeError:
|
||||||
|
flat_logits = logits.reshape(-1, V).contiguous()
|
||||||
flat_index = index.reshape(-1, K).contiguous()
|
flat_index = index.reshape(-1, K).contiguous()
|
||||||
|
|
||||||
BLOCK_V = 4096
|
BLOCK_V = 4096
|
||||||
|
|||||||
@@ -1,252 +0,0 @@
|
|||||||
"""module for patching with unsloth optimizations"""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
import types
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from peft import PeftModelForCausalLM
|
|
||||||
from torch import nn
|
|
||||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import detab_code
|
|
||||||
from axolotl.utils.logging import get_logger
|
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
|
||||||
|
|
||||||
ORIGINAL_QKV_CODE = """
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
""".lstrip("\n")
|
|
||||||
|
|
||||||
PATCHED_QKV_CODE = """
|
|
||||||
query_states, key_states, value_states = self.apply_qkv(self, hidden_states)
|
|
||||||
""".lstrip("\n")
|
|
||||||
|
|
||||||
ORIGINAL_O_CODE = """
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
""".lstrip("\n")
|
|
||||||
|
|
||||||
PATCHED_O_CODE = """
|
|
||||||
attn_output = self.apply_o(self, attn_output)
|
|
||||||
""".lstrip("\n")
|
|
||||||
|
|
||||||
|
|
||||||
def original_apply_qkv(self, hidden_states):
|
|
||||||
query_states = self.q_proj(hidden_states)
|
|
||||||
key_states = self.k_proj(hidden_states)
|
|
||||||
value_states = self.v_proj(hidden_states)
|
|
||||||
return query_states, key_states, value_states
|
|
||||||
|
|
||||||
|
|
||||||
def original_apply_o(self, hidden_states):
|
|
||||||
attn_output = self.o_proj(hidden_states)
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
|
|
||||||
def get_self_attn_code() -> str:
|
|
||||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
|
||||||
return forward
|
|
||||||
|
|
||||||
|
|
||||||
def check_self_attn_is_patchable() -> bool:
|
|
||||||
qkv = get_self_attn_code()
|
|
||||||
qkv, _ = detab_code(qkv)
|
|
||||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
|
||||||
|
|
||||||
|
|
||||||
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
|
||||||
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
|
|
||||||
|
|
||||||
def UnslothForCausalLMLoss(
|
|
||||||
logits,
|
|
||||||
labels,
|
|
||||||
vocab_size: int,
|
|
||||||
num_items_in_batch: int = None,
|
|
||||||
ignore_index: int = -100,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
|
||||||
logits = logits.float()
|
|
||||||
# Shift so that tokens < n predict n
|
|
||||||
shift_logits = logits[..., :-1, :].contiguous()
|
|
||||||
shift_labels = labels[..., 1:].contiguous()
|
|
||||||
|
|
||||||
loss = fast_cross_entropy_loss(
|
|
||||||
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
|
|
||||||
)
|
|
||||||
return loss
|
|
||||||
|
|
||||||
if model_type == "llama":
|
|
||||||
from transformers.loss import loss_utils
|
|
||||||
|
|
||||||
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported model type")
|
|
||||||
|
|
||||||
|
|
||||||
self_attn_lora_patched = False
|
|
||||||
|
|
||||||
|
|
||||||
def patch_self_attn_lora():
|
|
||||||
global self_attn_lora_patched
|
|
||||||
if self_attn_lora_patched:
|
|
||||||
# prevent patching multiple times
|
|
||||||
return
|
|
||||||
self_attn_forward = get_self_attn_code()
|
|
||||||
LlamaFlashAttention2._original_forward = self_attn_forward
|
|
||||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
|
||||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original qkv code not found"
|
|
||||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original o code not found"
|
|
||||||
|
|
||||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
|
||||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
|
||||||
self_attn_forward = self_attn_forward.replace(
|
|
||||||
"def forward(",
|
|
||||||
"def unsloth_attn_forward(",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load imports necessary
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
|
|
||||||
items_to_import = []
|
|
||||||
for item in dir(transformers.models.llama.modeling_llama):
|
|
||||||
if item in self_attn_forward:
|
|
||||||
items_to_import.append(item)
|
|
||||||
|
|
||||||
exec(
|
|
||||||
"from transformers.models.llama.modeling_llama import ("
|
|
||||||
+ ", ".join(x for x in items_to_import)
|
|
||||||
+ ")",
|
|
||||||
globals(),
|
|
||||||
)
|
|
||||||
exec(self_attn_forward, globals())
|
|
||||||
self_attn_lora_patched = True
|
|
||||||
LOG.info("patching unsloth attn lora")
|
|
||||||
LlamaFlashAttention2.forward = unsloth_attn_forward
|
|
||||||
|
|
||||||
|
|
||||||
def integrate_rope_embeddings():
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
from unsloth.kernels.rope_embedding import fast_rope_embedding
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
cos,
|
|
||||||
sin,
|
|
||||||
position_ids=None,
|
|
||||||
unsqueeze_dim=1,
|
|
||||||
):
|
|
||||||
return fast_rope_embedding(q, k, cos, sin)
|
|
||||||
|
|
||||||
LOG.info("patching unsloth RoPE embeddings")
|
|
||||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
|
||||||
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
|
|
||||||
from unsloth.kernels import apply_lora_mlp_swiglu
|
|
||||||
|
|
||||||
apply_lora_mlp = apply_lora_mlp_swiglu
|
|
||||||
elif peft_model.base_model.config.model_type == "gemma":
|
|
||||||
from unsloth.kernels import apply_lora_mlp_geglu_approx
|
|
||||||
|
|
||||||
apply_lora_mlp = apply_lora_mlp_geglu_approx
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Model type {peft_model.base_model.config.model_type} not supported"
|
|
||||||
)
|
|
||||||
|
|
||||||
for idx, layer in enumerate(peft_model.model.model.layers):
|
|
||||||
layer_modules = [
|
|
||||||
getattr(layer.mlp, linear_proj)
|
|
||||||
for linear_proj in ["gate_proj", "up_proj", "down_proj"]
|
|
||||||
]
|
|
||||||
is_mlp_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
|
||||||
mlp_no_bias = all(
|
|
||||||
getattr(module, "base_layer", module).bias is None
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
mlp_not_dora = all(
|
|
||||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
|
||||||
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
|
||||||
else:
|
|
||||||
LOG.warning(f"unable to apply unsloth lora mlp patch to layer {idx}")
|
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|
||||||
from unsloth.kernels import apply_lora_o, apply_lora_qkv
|
|
||||||
|
|
||||||
for idx, layer in enumerate(peft_model.model.model.layers):
|
|
||||||
if cfg.unsloth_lora_qkv:
|
|
||||||
layer_modules = [
|
|
||||||
getattr(layer.self_attn, linear_proj)
|
|
||||||
for linear_proj in ["q_proj", "k_proj", "v_proj"]
|
|
||||||
]
|
|
||||||
is_qkv_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
|
||||||
qkv_no_bias = all(
|
|
||||||
getattr(module, "base_layer", module).bias is None
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
qkv_not_dora = all(
|
|
||||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_qkv_lora and qkv_no_bias and qkv_not_dora:
|
|
||||||
layer.self_attn.apply_qkv = apply_lora_qkv
|
|
||||||
else:
|
|
||||||
layer.self_attn.apply_qkv = original_apply_qkv
|
|
||||||
LOG.warning(f"unable to apply unsloth lora qkv patch to layer {idx}")
|
|
||||||
if cfg.unsloth_lora_o:
|
|
||||||
layer_modules = [
|
|
||||||
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
|
||||||
]
|
|
||||||
is_o_lora = all(hasattr(module, "lora_A") for module in layer_modules)
|
|
||||||
o_no_bias = all(
|
|
||||||
getattr(module, "base_layer", module).bias is None
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
o_not_dora = all(
|
|
||||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
|
||||||
for module in layer_modules
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_o_lora and o_no_bias and o_not_dora:
|
|
||||||
layer.self_attn.apply_o = apply_lora_o
|
|
||||||
else:
|
|
||||||
layer.self_attn.apply_o = original_apply_o
|
|
||||||
LOG.warning(f"unable to apply unsloth lora o_proj patch to layer {idx}")
|
|
||||||
|
|
||||||
|
|
||||||
def patch_unsloth_layernorm():
|
|
||||||
try:
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
|
|
||||||
|
|
||||||
class LlamaRMSNorm(nn.Module):
|
|
||||||
"""LlamaRMSNorm"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
"""
|
|
||||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
return Fast_RMS_Layernorm.apply(
|
|
||||||
hidden_states, self.weight, self.variance_epsilon, False
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching with unsloth.kernels.rms_layernorm")
|
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
|
||||||
except ImportError:
|
|
||||||
LOG.warning("missing unsloth library")
|
|
||||||
@@ -394,8 +394,8 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
|
||||||
try:
|
try:
|
||||||
return all(isinstance(v, list) for v in prompt.values()) and all(
|
return all(isinstance(v, (str, list)) for v in prompt.values()) and all(
|
||||||
isinstance(v, list) for v in prompt[self.prompter.field_messages]
|
isinstance(v, (str, list)) for v in prompt[self.prompter.field_messages]
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return False
|
return False
|
||||||
@@ -1004,6 +1004,13 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
if tools is None:
|
if tools is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Some datasets have tools set to str
|
||||||
|
if isinstance(tools, str):
|
||||||
|
try:
|
||||||
|
tools = json.loads(tools)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
LOG.error(f"Error parsing tool parameters as JSON. Error: {e}")
|
||||||
|
raise
|
||||||
if isinstance(tools, list):
|
if isinstance(tools, list):
|
||||||
# Process each tool to handle JSON string parameters
|
# Process each tool to handle JSON string parameters
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
@@ -1034,6 +1041,22 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
if messages is None:
|
if messages is None:
|
||||||
raise ValueError("Messages is null. Please check `field_messages`.")
|
raise ValueError("Messages is null. Please check `field_messages`.")
|
||||||
|
|
||||||
|
if isinstance(messages, str):
|
||||||
|
try:
|
||||||
|
messages = json.loads(messages)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
LOG.error(f"Error parsing messages as JSON. Error: {e}")
|
||||||
|
raise
|
||||||
|
assert isinstance(messages, list), (
|
||||||
|
f"For SFT datasets that are stored in `str` format, the turns must be saved in a list of dictionaries, got {type(message)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extra check here to make sure decoded json is a list of dicts.
|
||||||
|
for i, message in enumerate(messages):
|
||||||
|
assert isinstance(message, dict), (
|
||||||
|
f"For SFT datasets that are stored in `str` format, each turns must be saved in a dictionary, got {type(message)} for the turn {i}"
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(messages, list):
|
if isinstance(messages, list):
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from .batching import (
|
|||||||
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
PretrainingBatchSamplerDataCollatorForSeq2Seq,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
|
from .dpo import AxolotlDPODataCollatorWithPadding
|
||||||
from .mamba import MambaDataCollator
|
from .mamba import MambaDataCollator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -13,5 +14,6 @@ __all__ = [
|
|||||||
"BatchSamplerDataCollatorForSeq2Seq",
|
"BatchSamplerDataCollatorForSeq2Seq",
|
||||||
"V2BatchSamplerDataCollatorForSeq2Seq",
|
"V2BatchSamplerDataCollatorForSeq2Seq",
|
||||||
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
|
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
|
||||||
|
"AxolotlDPODataCollatorWithPadding",
|
||||||
"MambaDataCollator",
|
"MambaDataCollator",
|
||||||
]
|
]
|
||||||
|
|||||||
128
src/axolotl/utils/collators/dpo.py
Normal file
128
src/axolotl/utils/collators/dpo.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""DPO/ORPO/IPO/KTO data collator with pad_to_multiple_of support.
|
||||||
|
|
||||||
|
Extends TRL's DPODataCollatorWithPadding to round padded sequence lengths
|
||||||
|
up to a fixed multiple. This stabilizes Triton autotune caches for kernels
|
||||||
|
that key on sequence length (e.g. fla's linear attention kernels used by
|
||||||
|
Qwen3.5), which otherwise re-autotune on every distinct batch length.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from trl.experimental.utils import DPODataCollatorWithPadding
|
||||||
|
from trl.trainer.utils import pad
|
||||||
|
|
||||||
|
|
||||||
|
def _round_up(length: int, multiple: int) -> int:
|
||||||
|
return ((length + multiple - 1) // multiple) * multiple
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlDPODataCollatorWithPadding(DPODataCollatorWithPadding):
|
||||||
|
"""DPO data collator that pads to a multiple of ``pad_to_multiple_of``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pad_token_id: Tokenizer pad token id (inherited).
|
||||||
|
is_encoder_decoder: Whether the model is encoder-decoder (inherited).
|
||||||
|
pad_to_multiple_of: If set, padded lengths are rounded up to this
|
||||||
|
multiple. Helps stabilize Triton autotune caches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pad_to_multiple_of: int | None = None
|
||||||
|
|
||||||
|
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
|
||||||
|
pad_to_mult = self.pad_to_multiple_of
|
||||||
|
|
||||||
|
padded_batch: dict[str, Any] = {}
|
||||||
|
for k in features[0].keys():
|
||||||
|
if k.endswith(
|
||||||
|
("_input_ids", "_attention_mask", "_labels", "_pixel_values")
|
||||||
|
):
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
if k.endswith("_pixel_values"):
|
||||||
|
to_pad = [
|
||||||
|
torch.tensor(ex[k], dtype=torch.float32) for ex in features
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
to_pad = [torch.LongTensor(ex[k]) for ex in features]
|
||||||
|
|
||||||
|
if k.startswith("prompt") and k.endswith("input_ids"):
|
||||||
|
if self.pad_token_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Padding is enabled, but the tokenizer is not configured with a padding token."
|
||||||
|
)
|
||||||
|
padding_value = self.pad_token_id
|
||||||
|
elif k.endswith("_attention_mask"):
|
||||||
|
padding_value = 0
|
||||||
|
elif k.endswith("_pixel_values"):
|
||||||
|
padding_value = 0
|
||||||
|
elif (
|
||||||
|
k.startswith(("chosen", "rejected", "completion"))
|
||||||
|
or "decoder" in k
|
||||||
|
):
|
||||||
|
padding_value = -100
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected key in batch '{k}'")
|
||||||
|
|
||||||
|
padded = pad_sequence(
|
||||||
|
to_pad, batch_first=True, padding_value=padding_value
|
||||||
|
)
|
||||||
|
if pad_to_mult:
|
||||||
|
cur = padded.shape[1]
|
||||||
|
target = _round_up(cur, pad_to_mult)
|
||||||
|
if target > cur:
|
||||||
|
extra = target - cur
|
||||||
|
pad_shape = list(padded.shape)
|
||||||
|
pad_shape[1] = extra
|
||||||
|
filler = torch.full(
|
||||||
|
pad_shape,
|
||||||
|
padding_value,
|
||||||
|
dtype=padded.dtype,
|
||||||
|
device=padded.device,
|
||||||
|
)
|
||||||
|
padded = torch.cat([padded, filler], dim=1)
|
||||||
|
padded_batch[k] = padded
|
||||||
|
else:
|
||||||
|
if k.endswith("_input_ids"):
|
||||||
|
if self.pad_token_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Padding is enabled, but the tokenizer is not configured with a padding token."
|
||||||
|
)
|
||||||
|
padding_value = self.pad_token_id
|
||||||
|
elif k.endswith("_labels"):
|
||||||
|
padding_value = -100
|
||||||
|
elif k.endswith("_attention_mask"):
|
||||||
|
padding_value = 0
|
||||||
|
elif k.endswith("_pixel_values"):
|
||||||
|
padding_value = 0
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected key in batch '{k}'")
|
||||||
|
|
||||||
|
padding_side = (
|
||||||
|
"left"
|
||||||
|
if k in ("prompt_input_ids", "prompt_attention_mask")
|
||||||
|
else "right"
|
||||||
|
)
|
||||||
|
|
||||||
|
dtype = (
|
||||||
|
torch.float32 if k.endswith("_pixel_values") else torch.int64
|
||||||
|
)
|
||||||
|
to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features]
|
||||||
|
|
||||||
|
# trl.pad() natively supports pad_to_multiple_of
|
||||||
|
padded_batch[k] = pad(
|
||||||
|
to_pad,
|
||||||
|
padding_value=padding_value,
|
||||||
|
padding_side=padding_side,
|
||||||
|
pad_to_multiple_of=pad_to_mult,
|
||||||
|
)
|
||||||
|
elif k.endswith("_logps"):
|
||||||
|
padded_batch[k] = torch.tensor([ex[k] for ex in features])
|
||||||
|
else:
|
||||||
|
padded_batch[k] = [ex[k] for ex in features]
|
||||||
|
|
||||||
|
return padded_batch
|
||||||
@@ -309,6 +309,16 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
dpo_padding_free: bool | None = None
|
dpo_padding_free: bool | None = None
|
||||||
|
|
||||||
|
dpo_loss_type: Annotated[list[str], MinLen(1)] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "List of DPO losses to use."},
|
||||||
|
)
|
||||||
|
|
||||||
|
dpo_loss_weights: Annotated[list[float], MinLen(1)] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={"description": "Weights for each DPO loss."},
|
||||||
|
)
|
||||||
|
|
||||||
datasets: (
|
datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
list[
|
list[
|
||||||
@@ -663,6 +673,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
|
"description": "Pad inputs so each step uses constant sized buffers. This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently. Defaults to True if `sample_packing` enabled"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
pad_to_multiple_of: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": ("Pad each batch to a multiple of this value.")
|
||||||
|
},
|
||||||
|
)
|
||||||
curriculum_sampling: bool | None = Field(
|
curriculum_sampling: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -823,13 +839,6 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
unsloth_cross_entropy_loss: bool | None = None
|
|
||||||
unsloth_lora_mlp: bool | None = None
|
|
||||||
unsloth_lora_qkv: bool | None = None
|
|
||||||
unsloth_lora_o: bool | None = None
|
|
||||||
unsloth_rms_norm: bool | None = None
|
|
||||||
unsloth_rope: bool | None = None
|
|
||||||
|
|
||||||
lora_mlp_kernel: bool | None = Field(
|
lora_mlp_kernel: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -1007,7 +1016,7 @@ class AxolotlInputConfig(
|
|||||||
torch_compile: Literal["auto"] | bool | None = Field(
|
torch_compile: Literal["auto"] | bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "Whether to use torch.compile and which backend to use. setting to `auto` will enable torch compile when torch>=2.6.0"
|
"description": "Whether to use torch.compile and which backend to use."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
torch_compile_backend: str | None = Field(
|
torch_compile_backend: str | None = Field(
|
||||||
@@ -1469,21 +1478,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_multigpu_unsloth(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("unsloth_lora_mlp")
|
|
||||||
or data.get("unsloth_lora_qkv")
|
|
||||||
or data.get("unsloth_lora_o")
|
|
||||||
):
|
|
||||||
capabilities = data.get("capabilities")
|
|
||||||
if capabilities and capabilities.get("n_gpu", 0) > 1:
|
|
||||||
raise ValueError(
|
|
||||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_multigpu_lora_kernels(cls, data):
|
def check_multigpu_lora_kernels(cls, data):
|
||||||
@@ -1537,8 +1531,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
# RL trainers not tested so don't enable kernels by default
|
# RL trainers not tested so don't enable kernels by default
|
||||||
return data
|
return data
|
||||||
if data.get("adapter") in ["lora", "qlora"]:
|
if data.get("adapter") in ["lora", "qlora"]:
|
||||||
# Skip if already set, using unsloth optimizations, or using 8-bit
|
# Skip if already set or using 8-bit
|
||||||
unsloth_fields = ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
|
||||||
kernel_fields = [
|
kernel_fields = [
|
||||||
"lora_mlp_kernel",
|
"lora_mlp_kernel",
|
||||||
"lora_qkv_kernel",
|
"lora_qkv_kernel",
|
||||||
@@ -1547,7 +1540,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
]
|
]
|
||||||
if (
|
if (
|
||||||
any(data.get(k) is not None for k in kernel_fields)
|
any(data.get(k) is not None for k in kernel_fields)
|
||||||
or any(data.get(k) for k in unsloth_fields)
|
|
||||||
or data.get("adapter") == "lora"
|
or data.get("adapter") == "lora"
|
||||||
and data.get("load_in_8bit")
|
and data.get("load_in_8bit")
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -64,6 +64,12 @@ class ModelInputConfig(BaseModel):
|
|||||||
processor_type: str | None = Field(
|
processor_type: str | None = Field(
|
||||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||||
)
|
)
|
||||||
|
processor_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "kwargs forwarded to the processor's from_pretrained(), overriding processor config (e.g. image_seq_length, min_pixels, etc.)."
|
||||||
|
},
|
||||||
|
)
|
||||||
tokenizer_save_jinja_files: bool | None = Field(
|
tokenizer_save_jinja_files: bool | None = Field(
|
||||||
default=True, # match the default behavior from transformers
|
default=True, # match the default behavior from transformers
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -107,6 +113,22 @@ class ModelInputConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
return trust_remote_code
|
return trust_remote_code
|
||||||
|
|
||||||
|
@field_validator("processor_kwargs")
|
||||||
|
@classmethod
|
||||||
|
def reject_reserved_processor_kwargs(cls, processor_kwargs):
|
||||||
|
if not processor_kwargs:
|
||||||
|
return processor_kwargs
|
||||||
|
reserved = {"revision", "trust_remote_code"}
|
||||||
|
conflicts = reserved.intersection(processor_kwargs)
|
||||||
|
if conflicts:
|
||||||
|
raise ValueError(
|
||||||
|
"Do not set reserved keys "
|
||||||
|
f"{sorted(conflicts)} inside `processor_kwargs`; "
|
||||||
|
"use the top-level `revision_of_model` / `trust_remote_code` "
|
||||||
|
"config keys instead."
|
||||||
|
)
|
||||||
|
return processor_kwargs
|
||||||
|
|
||||||
|
|
||||||
class ModelOutputConfig(BaseModel):
|
class ModelOutputConfig(BaseModel):
|
||||||
"""model save configuration subset"""
|
"""model save configuration subset"""
|
||||||
|
|||||||
@@ -52,6 +52,26 @@ class DatasetValidationMixin:
|
|||||||
|
|
||||||
return datasets
|
return datasets
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_deprecated_unsloth_fields(cls, data):
|
||||||
|
deprecated_fields = [
|
||||||
|
"unsloth_cross_entropy_loss",
|
||||||
|
"unsloth_lora_mlp",
|
||||||
|
"unsloth_lora_qkv",
|
||||||
|
"unsloth_lora_o",
|
||||||
|
"unsloth_rms_norm",
|
||||||
|
"unsloth_rope",
|
||||||
|
]
|
||||||
|
found = [f for f in deprecated_fields if data.get(f)]
|
||||||
|
if found:
|
||||||
|
raise ValueError(
|
||||||
|
f"`{'`, `'.join(found)}` {'has' if len(found) == 1 else 'have'} been removed. "
|
||||||
|
"Please use `lora_mlp_kernel`, `lora_qkv_kernel`, `lora_o_kernel` instead. "
|
||||||
|
"See: https://docs.axolotl.ai/docs/lora_optims.html"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_dataset_or_pretraining_dataset(cls, data):
|
def check_dataset_or_pretraining_dataset(cls, data):
|
||||||
@@ -558,6 +578,11 @@ class TrainingValidationMixin:
|
|||||||
"Setting chat_template is not supported with mistral-common tokenizer"
|
"Setting chat_template is not supported with mistral-common tokenizer"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if data.get("processor_kwargs"):
|
||||||
|
raise ValueError(
|
||||||
|
"processor_kwargs is not supported with mistral-common tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -607,36 +632,6 @@ class LoRAValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_qlora_unsloth(cls, data):
|
|
||||||
if (
|
|
||||||
data.get("unsloth_lora_mlp")
|
|
||||||
or data.get("unsloth_lora_qkv")
|
|
||||||
or data.get("unsloth_lora_o")
|
|
||||||
):
|
|
||||||
if data.get("adapter") == "lora" and data.get("load_in_8bit"):
|
|
||||||
raise ValueError(
|
|
||||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_lora_axolotl_unsloth(cls, data):
|
|
||||||
is_lora_kernel = any(
|
|
||||||
data.get(k) for k in ["lora_mlp_kernel", "lora_qkv_kernel", "lora_o_kernel"]
|
|
||||||
)
|
|
||||||
is_unsloth_lora = any(
|
|
||||||
data.get(k)
|
|
||||||
for k in ["unsloth_lora_mlp", "unsloth_lora_qkv", "unsloth_lora_o"]
|
|
||||||
)
|
|
||||||
if is_lora_kernel and is_unsloth_lora:
|
|
||||||
raise ValueError(
|
|
||||||
"both lora_mlp_kernel and unsloth_lora_mlp cannot be true (similarly for lora_qkv_kernel, lora_o_kernel)"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_fused_lora(self):
|
def check_fused_lora(self):
|
||||||
if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp:
|
if self.adapter in ["lora", "qlora"] and self.flash_attn_fuse_mlp:
|
||||||
@@ -770,6 +765,40 @@ class RLValidationMixin:
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_dpo(cls, data):
|
||||||
|
dpo_loss_type = data.get("dpo_loss_type")
|
||||||
|
dpo_loss_weights = data.get("dpo_loss_weights")
|
||||||
|
rl = data.get("rl")
|
||||||
|
|
||||||
|
if rl == "ipo":
|
||||||
|
LOG.warning(
|
||||||
|
"rl: ipo will soon be deprecated. Use `rl: dpo` with `dpo_loss_type: ['ipo']` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
if rl == "dpo":
|
||||||
|
if dpo_loss_weights is not None and dpo_loss_type is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`dpo_loss_weights` requires `dpo_loss_type` to be set"
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
dpo_loss_type is not None
|
||||||
|
and dpo_loss_weights is not None
|
||||||
|
and len(dpo_loss_type) != len(dpo_loss_weights)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"`dpo_loss_type` and `dpo_loss_weights` must be the same length, "
|
||||||
|
f"but got {len(dpo_loss_type)} losses and {len(dpo_loss_weights)} weights"
|
||||||
|
)
|
||||||
|
elif dpo_loss_type is not None or dpo_loss_weights is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"`dpo_loss_type` and `dpo_loss_weights` are for DPO only,"
|
||||||
|
f"but got {rl=}, {dpo_loss_type=} and {dpo_loss_weights=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_grpo_batch_size_divisibility(cls, data):
|
def check_grpo_batch_size_divisibility(cls, data):
|
||||||
@@ -942,17 +971,6 @@ class OptimizationValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_xentropy_patch_conflicts(cls, data):
|
|
||||||
if data.get("flash_attn_cross_entropy") and data.get(
|
|
||||||
"unsloth_cross_entropy_loss"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_cross_entropy_conflicts(cls, data):
|
def check_cross_entropy_conflicts(cls, data):
|
||||||
|
|||||||
@@ -1,102 +0,0 @@
|
|||||||
"""
|
|
||||||
dynamic requirements for axolotl
|
|
||||||
"""
|
|
||||||
|
|
||||||
import platform
|
|
||||||
import re
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
|
||||||
|
|
||||||
from setuptools.command.build_py import build_py as _build_py
|
|
||||||
|
|
||||||
|
|
||||||
def parse_requirements():
|
|
||||||
_install_requires = []
|
|
||||||
_dependency_links = []
|
|
||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
|
||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
|
||||||
for line in lines:
|
|
||||||
is_extras = (
|
|
||||||
"flash-attn" in line
|
|
||||||
or "flash-attention" in line
|
|
||||||
or "deepspeed" in line
|
|
||||||
or "mamba-ssm" in line
|
|
||||||
or "lion-pytorch" in line
|
|
||||||
)
|
|
||||||
if line.startswith("--extra-index-url"):
|
|
||||||
# Handle custom index URLs
|
|
||||||
_, url = line.split()
|
|
||||||
_dependency_links.append(url)
|
|
||||||
elif not is_extras and line and line[0] != "#":
|
|
||||||
# Handle standard packages
|
|
||||||
_install_requires.append(line)
|
|
||||||
|
|
||||||
try:
|
|
||||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
|
||||||
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
|
||||||
|
|
||||||
if "Darwin" in platform.system():
|
|
||||||
# don't install xformers on MacOS
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
else:
|
|
||||||
# detect the version of torch already installed
|
|
||||||
# and set it so dependencies don't clobber the torch version
|
|
||||||
try:
|
|
||||||
torch_version = version("torch")
|
|
||||||
except PackageNotFoundError:
|
|
||||||
torch_version = "2.5.1"
|
|
||||||
_install_requires.append(f"torch=={torch_version}")
|
|
||||||
|
|
||||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
|
||||||
if version_match:
|
|
||||||
major, minor, patch = version_match.groups()
|
|
||||||
major, minor = int(major), int(minor)
|
|
||||||
patch = (
|
|
||||||
int(patch) if patch is not None else 0
|
|
||||||
) # Default patch to 0 if not present
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid version format")
|
|
||||||
|
|
||||||
if (major, minor) >= (2, 5):
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
if patch == 0:
|
|
||||||
_install_requires.append("xformers==0.0.28.post2")
|
|
||||||
else:
|
|
||||||
_install_requires.append("xformers==0.0.28.post3")
|
|
||||||
elif (major, minor) >= (2, 4):
|
|
||||||
if patch == 0:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers>=0.0.27")
|
|
||||||
else:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers==0.0.28.post1")
|
|
||||||
elif (major, minor) >= (2, 3):
|
|
||||||
_install_requires.pop(_install_requires.index(torchao_version))
|
|
||||||
if patch == 0:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers>=0.0.26.post1")
|
|
||||||
else:
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers>=0.0.27")
|
|
||||||
elif (major, minor) >= (2, 2):
|
|
||||||
_install_requires.pop(_install_requires.index(torchao_version))
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers>=0.0.25.post1")
|
|
||||||
else:
|
|
||||||
_install_requires.pop(_install_requires.index(torchao_version))
|
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
|
||||||
_install_requires.append("xformers>=0.0.23.post1")
|
|
||||||
|
|
||||||
except PackageNotFoundError:
|
|
||||||
pass
|
|
||||||
return _install_requires, _dependency_links
|
|
||||||
|
|
||||||
|
|
||||||
class BuildPyCommand(_build_py):
|
|
||||||
"""
|
|
||||||
custom build_py command to parse dynamic requirements
|
|
||||||
"""
|
|
||||||
|
|
||||||
def finalize_options(self):
|
|
||||||
super().finalize_options()
|
|
||||||
install_requires, _ = parse_requirements()
|
|
||||||
self.distribution.install_requires = install_requires
|
|
||||||
@@ -325,10 +325,10 @@ def download_phi_4_reasoning_model_fixture():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_phi_3_medium_model_fixture():
|
def download_phi_3_mini_model_fixture():
|
||||||
# download the tokenizer only
|
# download the tokenizer only
|
||||||
snapshot_download_w_retry(
|
snapshot_download_w_retry(
|
||||||
"microsoft/Phi-3-medium-128k-instruct",
|
"microsoft/Phi-3-mini-4k-instruct",
|
||||||
repo_type="model",
|
repo_type="model",
|
||||||
allow_patterns=["*token*", "config.json"],
|
allow_patterns=["*token*", "config.json"],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -96,6 +96,8 @@ def fixture_dpo_cfg(base_cfg):
|
|||||||
"dpo_use_weighting": True,
|
"dpo_use_weighting": True,
|
||||||
"dpo_label_smoothing": 0.1,
|
"dpo_label_smoothing": 0.1,
|
||||||
"beta": 0.1, # DPO beta
|
"beta": 0.1, # DPO beta
|
||||||
|
"dpo_loss_type": ["sigmoid", "sft"],
|
||||||
|
"dpo_loss_weights": [1.0, 0.5],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return cfg
|
return cfg
|
||||||
@@ -164,7 +166,8 @@ def fixture_ipo_cfg(base_cfg):
|
|||||||
cfg = base_cfg.copy()
|
cfg = base_cfg.copy()
|
||||||
cfg.update(
|
cfg.update(
|
||||||
{
|
{
|
||||||
"rl": RLType.IPO,
|
"rl": RLType.DPO,
|
||||||
|
"dpo_loss_type": ["ipo"],
|
||||||
"dpo_label_smoothing": 0,
|
"dpo_label_smoothing": 0,
|
||||||
"beta": 0.1,
|
"beta": 0.1,
|
||||||
}
|
}
|
||||||
@@ -300,6 +303,8 @@ class TestHFRLTrainerBuilder:
|
|||||||
assert training_arguments.use_weighting is True
|
assert training_arguments.use_weighting is True
|
||||||
assert training_arguments.label_smoothing == 0.1
|
assert training_arguments.label_smoothing == 0.1
|
||||||
assert training_arguments.precompute_ref_log_probs is True
|
assert training_arguments.precompute_ref_log_probs is True
|
||||||
|
assert training_arguments.loss_type == ["sigmoid", "sft"]
|
||||||
|
assert training_arguments.loss_weights == [1.0, 0.5]
|
||||||
|
|
||||||
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
|
def test_orpo_training_arguments(self, orpo_cfg, model, tokenizer):
|
||||||
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
|
builder = HFRLTrainerBuilder(orpo_cfg, model, tokenizer)
|
||||||
|
|||||||
@@ -54,7 +54,9 @@ except (ImportError, ModuleNotFoundError):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
def peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||||
return peft_A, peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
smoe_A = peft_A
|
||||||
|
smoe_B = peft_lora_B_to_scattermoe(peft_B, num_experts, rank)
|
||||||
|
return smoe_A, smoe_B
|
||||||
|
|
||||||
def _unwrap_experts_lora(experts_module):
|
def _unwrap_experts_lora(experts_module):
|
||||||
return experts_module, None, None
|
return experts_module, None, None
|
||||||
@@ -127,7 +129,11 @@ def scattermoe_lora_B_to_peft(smoe_B, num_experts, rank):
|
|||||||
|
|
||||||
|
|
||||||
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
def peft_gate_up_lora_to_scattermoe(peft_A, peft_B, num_experts, rank):
|
||||||
"""Convert peft LoRA for gate_up_proj to scattermoe layout."""
|
"""Convert peft LoRA for gate_up_proj to scattermoe layout.
|
||||||
|
|
||||||
|
Both gate_up_proj and down_proj need the A<->B swap because
|
||||||
|
scattermoe transposes the parameter (W = param.T).
|
||||||
|
"""
|
||||||
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
return peft_lora_to_scattermoe(peft_A, peft_B, num_experts, rank)
|
||||||
|
|
||||||
|
|
||||||
@@ -300,6 +306,8 @@ class TestLoRABLayoutConversion:
|
|||||||
hidden, inter = 32, 16
|
hidden, inter = 32, 16
|
||||||
scaling = 2.0
|
scaling = 2.0
|
||||||
|
|
||||||
|
# peft >=0.19.1 for down_proj [E, hidden, inter]:
|
||||||
|
# swaps in/out, lora_A [r*E, inter], lora_B [hidden, r*E]
|
||||||
peft_A = torch.randn(E * r, inter)
|
peft_A = torch.randn(E * r, inter)
|
||||||
peft_B = torch.randn(hidden, E * r)
|
peft_B = torch.randn(hidden, E * r)
|
||||||
|
|
||||||
@@ -308,8 +316,6 @@ class TestLoRABLayoutConversion:
|
|||||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||||
|
|
||||||
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||||
assert smoe_A.shape == (E * r, inter)
|
|
||||||
assert smoe_B.shape == (hidden, E * r)
|
|
||||||
for e in range(E):
|
for e in range(E):
|
||||||
A_e = smoe_A[e * r : (e + 1) * r, :]
|
A_e = smoe_A[e * r : (e + 1) * r, :]
|
||||||
B_e = smoe_B[:, e * r : (e + 1) * r]
|
B_e = smoe_B[:, e * r : (e + 1) * r]
|
||||||
@@ -319,30 +325,22 @@ class TestLoRABLayoutConversion:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_gate_up_proj_conversion(self):
|
def test_gate_up_proj_conversion(self):
|
||||||
"""Verify gate_up_proj LoRA conversion with non-square dims (Qwen3-like).
|
"""Verify gate_up_proj LoRA conversion with non-square dims.
|
||||||
|
|
||||||
gate_up_proj param: [E, 2*inter, hidden].
|
gate_up_proj param: [E, 2*inter, hidden].
|
||||||
peft: in_features=hidden, out_features=2*inter.
|
peft swaps in/out for 3D: lora_A [r*E, hidden], lora_B [2*inter, r*E].
|
||||||
peft lora_A: [r*E, hidden], lora_B: [2*inter, r*E].
|
|
||||||
|
|
||||||
scattermoe W = param.T = [E, hidden, 2*inter], K=hidden, N=2*inter.
|
|
||||||
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
|
scattermoe needs: lora_A [r*E, K=hidden], lora_B [N=2*inter, r*E].
|
||||||
|
|
||||||
Uses non-square dims (hidden=32 != 2*inter=24) to catch layout bugs.
|
|
||||||
"""
|
"""
|
||||||
E, r = 4, 2
|
E, r = 4, 2
|
||||||
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
|
hidden, inter = 32, 12 # 2*inter=24 != hidden=32
|
||||||
scaling = 2.0
|
scaling = 2.0
|
||||||
|
|
||||||
# peft assigns: in_features=hidden, out_features=2*inter
|
peft_A = torch.randn(E * r, hidden) # [r*E, in=hidden]
|
||||||
peft_A = torch.randn(E * r, hidden) # [r*E, in_features=hidden]
|
peft_B = torch.randn(2 * inter, E * r) # [out=2*inter, r*E]
|
||||||
peft_B = torch.randn(2 * inter, E * r) # [out_features=2*inter, r*E]
|
|
||||||
|
|
||||||
A_r = peft_A.reshape(E, r, hidden)
|
A_r = peft_A.reshape(E, r, hidden)
|
||||||
B_r = peft_B.reshape(2 * inter, r, E)
|
B_r = peft_B.reshape(2 * inter, r, E)
|
||||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
||||||
# delta_peft[e] has shape [out_features, in_features] = [2*inter, hidden]
|
|
||||||
# = param[e] shape [2*inter, hidden]
|
|
||||||
|
|
||||||
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
|
smoe_A, smoe_B = peft_gate_up_lora_to_scattermoe(peft_A, peft_B, E, r)
|
||||||
# smoe_A should be [r*E, K=hidden], smoe_B should be [N=2*inter, r*E]
|
# smoe_A should be [r*E, K=hidden], smoe_B should be [N=2*inter, r*E]
|
||||||
@@ -400,8 +398,7 @@ class TestPeftLoRAWeightExtraction:
|
|||||||
r,
|
r,
|
||||||
)
|
)
|
||||||
|
|
||||||
# gate_up_proj [E, 2*inter, hidden]
|
# gate_up_proj [E, 2*inter, hidden] — peft swaps in/out for 3D
|
||||||
# peft: in_features=hidden (last dim), out_features=2*inter (middle dim)
|
|
||||||
assert trainable[
|
assert trainable[
|
||||||
"base_model.model.moe.experts.base_layer.lora_A.default.weight"
|
"base_model.model.moe.experts.base_layer.lora_A.default.weight"
|
||||||
].shape == (E * r, config.hidden_size)
|
].shape == (E * r, config.hidden_size)
|
||||||
@@ -409,8 +406,7 @@ class TestPeftLoRAWeightExtraction:
|
|||||||
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
|
"base_model.model.moe.experts.base_layer.lora_B.default.weight"
|
||||||
].shape == (2 * config.intermediate_size, E * r)
|
].shape == (2 * config.intermediate_size, E * r)
|
||||||
|
|
||||||
# down_proj [E, hidden, inter]
|
# down_proj [E, hidden, inter] — peft swaps in/out for 3D
|
||||||
# peft: in_features=inter (last dim), out_features=hidden (middle dim)
|
|
||||||
assert trainable[
|
assert trainable[
|
||||||
"base_model.model.moe.experts.lora_A.default.weight"
|
"base_model.model.moe.experts.lora_A.default.weight"
|
||||||
].shape == (E * r, config.intermediate_size)
|
].shape == (E * r, config.intermediate_size)
|
||||||
@@ -467,29 +463,26 @@ class TestPeftLoRAWeightExtraction:
|
|||||||
assert gup_lora is not None, "gate_up_proj LoRA not detected"
|
assert gup_lora is not None, "gate_up_proj LoRA not detected"
|
||||||
assert down_lora is not None, "down_proj LoRA not detected"
|
assert down_lora is not None, "down_proj LoRA not detected"
|
||||||
|
|
||||||
# Check shapes after peft->scattermoe conversion.
|
# gate_up_proj: K=hidden, N=2*inter
|
||||||
# gate_up_proj: peft A [E*r, hidden] / B [2*inter, E*r]
|
|
||||||
# scattermoe: smoe_A [E*r, hidden], smoe_B [2*inter, E*r]
|
|
||||||
E, r = config.num_experts, 4
|
E, r = config.num_experts, 4
|
||||||
gup_A, gup_B, gup_s = gup_lora
|
gup_A, gup_B, gup_s = gup_lora
|
||||||
assert gup_A.shape == (E * r, config.hidden_size), (
|
assert gup_A.shape == (E * r, config.hidden_size), (
|
||||||
f"gate_up_proj smoe_A: expected [r*E, hidden]={(E * r, config.hidden_size)}, "
|
f"gate_up_proj smoe_A: expected [r*E, K=hidden]={(E * r, config.hidden_size)}, "
|
||||||
f"got {gup_A.shape}"
|
f"got {gup_A.shape}"
|
||||||
)
|
)
|
||||||
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
|
assert gup_B.shape == (2 * config.intermediate_size, E * r), (
|
||||||
f"gate_up_proj smoe_B: expected [2*inter, r*E]="
|
f"gate_up_proj smoe_B: expected [N=2*inter, r*E]="
|
||||||
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
|
f"{(2 * config.intermediate_size, E * r)}, got {gup_B.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# down_proj: peft A [E*r, inter] / B [hidden, E*r]
|
# down_proj: K=inter, N=hidden
|
||||||
# scattermoe: smoe_A [E*r, inter], smoe_B [hidden, E*r]
|
|
||||||
down_A, down_B, down_s = down_lora
|
down_A, down_B, down_s = down_lora
|
||||||
assert down_A.shape == (E * r, config.intermediate_size), (
|
assert down_A.shape == (E * r, config.intermediate_size), (
|
||||||
f"down_proj smoe_A: expected [r*E, inter]={(E * r, config.intermediate_size)}, "
|
f"down_proj smoe_A: expected [r*E, K=inter]={(E * r, config.intermediate_size)}, "
|
||||||
f"got {down_A.shape}"
|
f"got {down_A.shape}"
|
||||||
)
|
)
|
||||||
assert down_B.shape == (config.hidden_size, E * r), (
|
assert down_B.shape == (config.hidden_size, E * r), (
|
||||||
f"down_proj smoe_B: expected [hidden, r*E]={(config.hidden_size, E * r)}, "
|
f"down_proj smoe_B: expected [N=hidden, r*E]={(config.hidden_size, E * r)}, "
|
||||||
f"got {down_B.shape}"
|
f"got {down_B.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
|
||||||
reason="Unsloth integration will be broken going into latest transformers"
|
|
||||||
)
|
|
||||||
class TestUnslothIntegration(unittest.TestCase):
|
|
||||||
"""Unsloth monkeypatch integration tests."""
|
|
||||||
|
|
||||||
def test_is_self_attn_patchable(self):
|
|
||||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
|
||||||
|
|
||||||
# ensures the current version of transformers has loss code that matches our patching code
|
|
||||||
self.assertTrue(
|
|
||||||
check_self_attn_is_patchable(),
|
|
||||||
"HF transformers self attention code has changed and isn't patchable",
|
|
||||||
)
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
"""
|
|
||||||
e2e tests for unsloth qlora
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.common.datasets import load_datasets
|
|
||||||
from axolotl.train import train
|
|
||||||
from axolotl.utils.config import normalize_config, validate_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
from ..utils import check_model_output_exists, check_tensorboard
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
|
||||||
reason="Unsloth integration will be broken going into latest transformers"
|
|
||||||
)
|
|
||||||
class TestUnslothQLoRA:
|
|
||||||
"""
|
|
||||||
Test class for Unsloth QLoRA Llama models
|
|
||||||
"""
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"sample_packing",
|
|
||||||
[True, False],
|
|
||||||
)
|
|
||||||
def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"sample_packing": sample_packing,
|
|
||||||
"flash_attention": True,
|
|
||||||
"unsloth_lora_mlp": True,
|
|
||||||
"unsloth_lora_qkv": True,
|
|
||||||
"unsloth_lora_o": True,
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"adapter": "qlora",
|
|
||||||
"lora_r": 16,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.05,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_steps": 10,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"use_tensorboard": True,
|
|
||||||
"bf16": "auto",
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
|
|
||||||
check_tensorboard(
|
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"unsloth_lora_mlp": True,
|
|
||||||
"unsloth_lora_qkv": True,
|
|
||||||
"unsloth_lora_o": True,
|
|
||||||
"sample_packing": False,
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"adapter": "qlora",
|
|
||||||
"lora_r": 16,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.05,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_steps": 10,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"use_tensorboard": True,
|
|
||||||
"bf16": "auto",
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
|
|
||||||
check_tensorboard(
|
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"sdp_attention",
|
|
||||||
[True, False],
|
|
||||||
)
|
|
||||||
def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"unsloth_lora_mlp": True,
|
|
||||||
"unsloth_lora_qkv": True,
|
|
||||||
"unsloth_lora_o": True,
|
|
||||||
"sample_packing": False,
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"adapter": "qlora",
|
|
||||||
"lora_r": 16,
|
|
||||||
"lora_alpha": 16,
|
|
||||||
"lora_dropout": 0.05,
|
|
||||||
"lora_target_linear": True,
|
|
||||||
"val_set_size": 0.05,
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|endoftext|>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 1,
|
|
||||||
"max_steps": 5,
|
|
||||||
"save_steps": 10,
|
|
||||||
"micro_batch_size": 4,
|
|
||||||
"gradient_accumulation_steps": 2,
|
|
||||||
"sdp_attention": sdp_attention,
|
|
||||||
"output_dir": temp_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_8bit",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
"use_tensorboard": True,
|
|
||||||
"fp16": True,
|
|
||||||
"save_first_step": False,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = validate_config(cfg)
|
|
||||||
normalize_config(cfg)
|
|
||||||
dataset_meta = load_datasets(cfg=cfg)
|
|
||||||
|
|
||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
|
||||||
check_model_output_exists(temp_dir, cfg)
|
|
||||||
|
|
||||||
check_tensorboard(
|
|
||||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss (%s) is too high"
|
|
||||||
)
|
|
||||||
@@ -116,6 +116,58 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_rpo(self, temp_dir):
|
||||||
|
# For TRL >= 0.29, loss_type=["sigmoid", "sft"], loss_weights=[1, alpha]
|
||||||
|
# replaces loss_type="rpo", rpo_alpha=alpha.
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"tokenizer_type": "AutoTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 64,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
"lora_dropout": 0.1,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"rl": "dpo",
|
||||||
|
"dpo_loss_type": ["sigmoid", "sft"],
|
||||||
|
"dpo_loss_weights": [1.0, 1.0],
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||||
|
"type": "chatml.ultra",
|
||||||
|
"split": "train",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "paged_adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"warmup_steps": 5,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
|
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||||
|
"save_first_step": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = validate_config(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||||
|
|
||||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||||
@with_temp_dir
|
@with_temp_dir
|
||||||
def test_kto_pair_lora(self, temp_dir):
|
def test_kto_pair_lora(self, temp_dir):
|
||||||
@@ -181,7 +233,8 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|endoftext|>",
|
"pad_token": "<|endoftext|>",
|
||||||
},
|
},
|
||||||
"rl": "ipo",
|
"rl": "dpo",
|
||||||
|
"dpo_loss_type": ["ipo"],
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||||
|
|||||||
@@ -21,51 +21,6 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class TestPeftScatterMoELoRALayout:
|
|
||||||
"""CPU-only guards for PEFT target_parameters layout conversion."""
|
|
||||||
|
|
||||||
def test_peft_layout_keeps_a_and_reorders_b(self):
|
|
||||||
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
|
|
||||||
peft_lora_to_scattermoe,
|
|
||||||
)
|
|
||||||
|
|
||||||
E, r, K, N = 3, 2, 5, 7
|
|
||||||
scaling = 2.0
|
|
||||||
peft_A = torch.randn(E * r, K)
|
|
||||||
peft_B = torch.randn(N, E * r)
|
|
||||||
|
|
||||||
smoe_A, smoe_B = peft_lora_to_scattermoe(peft_A, peft_B, E, r)
|
|
||||||
|
|
||||||
assert smoe_A is peft_A
|
|
||||||
assert smoe_A.shape == (E * r, K)
|
|
||||||
assert smoe_B.shape == (N, E * r)
|
|
||||||
|
|
||||||
A_r = peft_A.reshape(E, r, K)
|
|
||||||
B_r = peft_B.reshape(N, r, E)
|
|
||||||
delta_peft = torch.einsum("o r e, e r i -> e o i", B_r, A_r) * scaling
|
|
||||||
|
|
||||||
for e in range(E):
|
|
||||||
A_e = smoe_A[e * r : (e + 1) * r, :]
|
|
||||||
B_e = smoe_B[:, e * r : (e + 1) * r]
|
|
||||||
torch.testing.assert_close(scaling * (B_e @ A_e), delta_peft[e])
|
|
||||||
|
|
||||||
def test_swapped_layout_fails_before_kernel_dispatch(self):
|
|
||||||
from axolotl.integrations.kernels.libs.scattermoe_lora.lora_layout import (
|
|
||||||
validate_scattermoe_lora_shapes,
|
|
||||||
)
|
|
||||||
|
|
||||||
E, r, K, N = 3, 2, 5, 7
|
|
||||||
expert_weights = torch.empty(E, K, N)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Invalid ScatterMoE LoRA layout"):
|
|
||||||
validate_scattermoe_lora_shapes(
|
|
||||||
expert_weights=expert_weights,
|
|
||||||
lora_A=torch.empty(E * r, N),
|
|
||||||
lora_B=torch.empty(K, E * r),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 1. KernelsArgs: disable_mlp_kernel validator
|
# 1. KernelsArgs: disable_mlp_kernel validator
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
171
tests/monkeypatch/test_gemma4_fused_attn_patch.py
Normal file
171
tests/monkeypatch/test_gemma4_fused_attn_patch.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""Unit tests for the Gemma4 fused-attention shared_kv_states routing patch."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
gemma4_modeling = pytest.importorskip("transformers.models.gemma4.modeling_gemma4")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def clean_decoder_layer_patch_slate():
|
||||||
|
"""Save and restore Gemma4TextDecoderLayer.__call__ and the sentinel."""
|
||||||
|
from axolotl.monkeypatch.models.gemma4 import fused_attn
|
||||||
|
|
||||||
|
cls = gemma4_modeling.Gemma4TextDecoderLayer
|
||||||
|
original_call = cls.__call__
|
||||||
|
had_sentinel = getattr(cls, "_axolotl_shared_kv_patched", False)
|
||||||
|
|
||||||
|
if had_sentinel:
|
||||||
|
del cls._axolotl_shared_kv_patched
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield cls, fused_attn
|
||||||
|
finally:
|
||||||
|
cls.__call__ = original_call
|
||||||
|
if had_sentinel:
|
||||||
|
cls._axolotl_shared_kv_patched = True
|
||||||
|
elif hasattr(cls, "_axolotl_shared_kv_patched"):
|
||||||
|
del cls._axolotl_shared_kv_patched
|
||||||
|
fused_attn._set_shared_kv_states(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatchedDecoderLayerCall:
|
||||||
|
def test_pops_shared_kv_states_and_populates_store(
|
||||||
|
self, clean_decoder_layer_patch_slate
|
||||||
|
):
|
||||||
|
cls, fused_attn = clean_decoder_layer_patch_slate
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def spy(self, *args, **kwargs):
|
||||||
|
captured["args"] = args
|
||||||
|
captured["kwargs"] = dict(kwargs)
|
||||||
|
return "spy_return"
|
||||||
|
|
||||||
|
cls.__call__ = spy
|
||||||
|
fused_attn._patch_decoder_layer_call()
|
||||||
|
|
||||||
|
assert getattr(cls, "_axolotl_shared_kv_patched", False) is True
|
||||||
|
assert cls.__call__ is not spy
|
||||||
|
|
||||||
|
shared_kv = {"layer_0": ("k", "v")}
|
||||||
|
result = cls.__call__(
|
||||||
|
object(),
|
||||||
|
"positional_arg",
|
||||||
|
shared_kv_states=shared_kv,
|
||||||
|
other_kwarg="keep_me",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "spy_return"
|
||||||
|
assert captured["args"] == ("positional_arg",)
|
||||||
|
assert "shared_kv_states" not in captured["kwargs"]
|
||||||
|
assert captured["kwargs"] == {"other_kwarg": "keep_me"}
|
||||||
|
assert fused_attn._get_shared_kv_states() is shared_kv
|
||||||
|
|
||||||
|
def test_clears_store_when_kwarg_absent(self, clean_decoder_layer_patch_slate):
|
||||||
|
"""Regression for commit 251021e1: a prior step's dict must not leak
|
||||||
|
into a later call that omits `shared_kv_states`."""
|
||||||
|
cls, fused_attn = clean_decoder_layer_patch_slate
|
||||||
|
|
||||||
|
def spy(self, *args, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
cls.__call__ = spy
|
||||||
|
fused_attn._patch_decoder_layer_call()
|
||||||
|
|
||||||
|
stale = {"stale_step": True}
|
||||||
|
fused_attn._set_shared_kv_states(stale)
|
||||||
|
assert fused_attn._get_shared_kv_states() is stale
|
||||||
|
|
||||||
|
cls.__call__(object())
|
||||||
|
|
||||||
|
assert fused_attn._get_shared_kv_states() is None
|
||||||
|
|
||||||
|
def test_store_visible_across_threads(self):
|
||||||
|
"""Regression for commit e3669b2c: the store must be readable from
|
||||||
|
threads other than the one that set it. `threading.local()` failed
|
||||||
|
this invariant, crashing with 'NoneType' object is not subscriptable'
|
||||||
|
on MoE Gemma4 variants when autograd worker threads ran backward
|
||||||
|
recompute under HF-Trainer gradient_checkpointing."""
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.models.gemma4 import fused_attn
|
||||||
|
|
||||||
|
sentinel = {"layer_0": ("k", "v")}
|
||||||
|
try:
|
||||||
|
fused_attn._set_shared_kv_states(sentinel)
|
||||||
|
|
||||||
|
seen = {}
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
seen["value"] = fused_attn._get_shared_kv_states()
|
||||||
|
|
||||||
|
t = threading.Thread(target=worker)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert seen["value"] is sentinel
|
||||||
|
finally:
|
||||||
|
fused_attn._set_shared_kv_states(None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def clean_entry_point_patch_slate():
|
||||||
|
"""Save and restore Gemma4TextAttention.forward and Gemma4TextDecoderLayer.__call__."""
|
||||||
|
from axolotl.monkeypatch.models.gemma4 import fused_attn
|
||||||
|
|
||||||
|
decoder_cls = gemma4_modeling.Gemma4TextDecoderLayer
|
||||||
|
attn_cls = gemma4_modeling.Gemma4TextAttention
|
||||||
|
|
||||||
|
original_call = decoder_cls.__call__
|
||||||
|
original_forward = attn_cls.forward
|
||||||
|
had_sentinel = getattr(decoder_cls, "_axolotl_shared_kv_patched", False)
|
||||||
|
|
||||||
|
if had_sentinel:
|
||||||
|
del decoder_cls._axolotl_shared_kv_patched
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield decoder_cls, attn_cls, original_call, original_forward, fused_attn
|
||||||
|
finally:
|
||||||
|
decoder_cls.__call__ = original_call
|
||||||
|
attn_cls.forward = original_forward
|
||||||
|
if had_sentinel:
|
||||||
|
decoder_cls._axolotl_shared_kv_patched = True
|
||||||
|
elif hasattr(decoder_cls, "_axolotl_shared_kv_patched"):
|
||||||
|
del decoder_cls._axolotl_shared_kv_patched
|
||||||
|
fused_attn._set_shared_kv_states(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatchGemma4FusedAttnEntryPoint:
|
||||||
|
def test_default_flag_swaps_only_attention_forward(
|
||||||
|
self, clean_entry_point_patch_slate
|
||||||
|
):
|
||||||
|
(
|
||||||
|
decoder_cls,
|
||||||
|
attn_cls,
|
||||||
|
original_call,
|
||||||
|
original_forward,
|
||||||
|
fused_attn,
|
||||||
|
) = clean_entry_point_patch_slate
|
||||||
|
|
||||||
|
fused_attn.patch_gemma4_fused_attn()
|
||||||
|
|
||||||
|
assert attn_cls.forward is not original_forward
|
||||||
|
assert decoder_cls.__call__ is original_call
|
||||||
|
assert not getattr(decoder_cls, "_axolotl_shared_kv_patched", False)
|
||||||
|
|
||||||
|
def test_workaround_flag_installs_decoder_layer_patch(
|
||||||
|
self, clean_entry_point_patch_slate
|
||||||
|
):
|
||||||
|
(
|
||||||
|
decoder_cls,
|
||||||
|
attn_cls,
|
||||||
|
original_call,
|
||||||
|
original_forward,
|
||||||
|
fused_attn,
|
||||||
|
) = clean_entry_point_patch_slate
|
||||||
|
|
||||||
|
fused_attn.patch_gemma4_fused_attn(install_shared_kv_workaround=True)
|
||||||
|
|
||||||
|
assert attn_cls.forward is not original_forward
|
||||||
|
assert decoder_cls.__call__ is not original_call
|
||||||
|
assert getattr(decoder_cls, "_axolotl_shared_kv_patched", False) is True
|
||||||
@@ -111,7 +111,7 @@ def fixture_argilla_chat_dataset():
|
|||||||
@pytest.fixture(name="phi3_tokenizer")
|
@pytest.fixture(name="phi3_tokenizer")
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def fixture_phi3_tokenizer():
|
def fixture_phi3_tokenizer():
|
||||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct")
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
@@ -214,8 +214,8 @@ class TestAssistantDPOChatTemplatePhi3:
|
|||||||
+ "<|user|>\ngoodbye<|end|>\n"
|
+ "<|user|>\ngoodbye<|end|>\n"
|
||||||
+ "<|assistant|>\n"
|
+ "<|assistant|>\n"
|
||||||
)
|
)
|
||||||
assert result["chosen"] == "goodbye<|end|>"
|
assert result["chosen"] == "goodbye<|end|>\n<|endoftext|>"
|
||||||
assert result["rejected"] == "party on<|end|>"
|
assert result["rejected"] == "party on<|end|>\n<|endoftext|>"
|
||||||
|
|
||||||
|
|
||||||
class TestAssistantDPOChatTemplateGemma:
|
class TestAssistantDPOChatTemplateGemma:
|
||||||
@@ -290,8 +290,8 @@ class TestArgillaChatDPOChatTemplate:
|
|||||||
)
|
)
|
||||||
result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)
|
result = transform_fn(argilla_chat_dataset[0], tokenizer=phi3_tokenizer)
|
||||||
assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
|
assert result["prompt"] == "<|user|>\nhello<|end|>\n" + "<|assistant|>\n"
|
||||||
assert result["chosen"] == "goodbye<|end|>"
|
assert result["chosen"] == "goodbye<|end|>\n<|endoftext|>"
|
||||||
assert result["rejected"] == "party on<|end|>"
|
assert result["rejected"] == "party on<|end|>\n<|endoftext|>"
|
||||||
|
|
||||||
|
|
||||||
class TestDPOChatTemplateToolRole:
|
class TestDPOChatTemplateToolRole:
|
||||||
|
|||||||
@@ -487,3 +487,70 @@ class TestDatasetPreparation:
|
|||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
@enable_hf_offline
|
||||||
|
def test_load_dataset_with_str_json_data(self, tokenizer):
|
||||||
|
"""
|
||||||
|
Test loading datasets where data is stored as str JSON instead of list of dicts.
|
||||||
|
see: https://github.com/axolotl-ai-cloud/axolotl/pull/3607 for more details.
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
import json
|
||||||
|
|
||||||
|
str_json_ds = Dataset.from_list(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"messages": json.dumps(
|
||||||
|
[
|
||||||
|
{"role": "user", "content": "Hello how are you?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "I am doing good thanks",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"messages": json.dumps(
|
||||||
|
[
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "2+2 equals 4."},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
tmp_ds_path = Path(tmp_dir) / "str_json_dataset.parquet"
|
||||||
|
str_json_ds.to_parquet(tmp_ds_path)
|
||||||
|
|
||||||
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 512,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": str(tmp_ds_path),
|
||||||
|
"name": "test_str_json",
|
||||||
|
"type": "chat_template",
|
||||||
|
"field_messages": "messages",
|
||||||
|
"message_field_role": "role",
|
||||||
|
"message_field_content": "content",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"dataset_num_proc": 4,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"axolotl.common.const.DEFAULT_DATASET_PREPARED_PATH", str(prepared_path)
|
||||||
|
):
|
||||||
|
dataset, _ = _load_tokenized_prepared_datasets(tokenizer, cfg)
|
||||||
|
|
||||||
|
assert len(dataset) == 2
|
||||||
|
assert "input_ids" in dataset.features
|
||||||
|
assert "attention_mask" in dataset.features
|
||||||
|
assert "labels" in dataset.features
|
||||||
|
|
||||||
|
assert len(dataset[0]["input_ids"]) > 0
|
||||||
|
|||||||
@@ -133,3 +133,108 @@ class TestRevisionParameter:
|
|||||||
|
|
||||||
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||||
assert "revision" not in call_kwargs.kwargs
|
assert "revision" not in call_kwargs.kwargs
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.processor.AutoProcessor")
|
||||||
|
def test_load_processor_forwards_processor_kwargs(self, mock_auto_processor):
|
||||||
|
mock_processor = MagicMock()
|
||||||
|
mock_processor.size = {}
|
||||||
|
mock_auto_processor.from_pretrained.return_value = mock_processor
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"processor_config": "some-model",
|
||||||
|
"trust_remote_code": False,
|
||||||
|
"processor_kwargs": {
|
||||||
|
"image_seq_length": 1120,
|
||||||
|
"max_soft_tokens": 1120,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
||||||
|
|
||||||
|
from axolotl.loaders.processor import load_processor
|
||||||
|
|
||||||
|
load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
|
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||||
|
assert call_kwargs.kwargs.get("image_seq_length") == 1120
|
||||||
|
assert call_kwargs.kwargs.get("max_soft_tokens") == 1120
|
||||||
|
|
||||||
|
@patch("axolotl.loaders.processor.AutoProcessor")
|
||||||
|
def test_load_processor_omits_processor_kwargs_when_unset(
|
||||||
|
self, mock_auto_processor
|
||||||
|
):
|
||||||
|
mock_processor = MagicMock()
|
||||||
|
mock_processor.size = {}
|
||||||
|
mock_auto_processor.from_pretrained.return_value = mock_processor
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"processor_config": "some-model",
|
||||||
|
"trust_remote_code": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
||||||
|
|
||||||
|
from axolotl.loaders.processor import load_processor
|
||||||
|
|
||||||
|
load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
|
call_kwargs = mock_auto_processor.from_pretrained.call_args
|
||||||
|
assert "image_seq_length" not in call_kwargs.kwargs
|
||||||
|
assert "max_soft_tokens" not in call_kwargs.kwargs
|
||||||
|
|
||||||
|
def test_processor_kwargs_schema_rejects_revision(self):
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.model import ModelInputConfig
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="revision"):
|
||||||
|
ModelInputConfig(
|
||||||
|
base_model="some-model",
|
||||||
|
processor_kwargs={"revision": "abc123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_processor_kwargs_schema_rejects_trust_remote_code(self):
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.model import ModelInputConfig
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="trust_remote_code"):
|
||||||
|
ModelInputConfig(
|
||||||
|
base_model="some-model",
|
||||||
|
processor_kwargs={"trust_remote_code": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_processor_kwargs_schema_accepts_valid_keys(self):
|
||||||
|
from axolotl.utils.schemas.model import ModelInputConfig
|
||||||
|
|
||||||
|
cfg = ModelInputConfig(
|
||||||
|
base_model="some-model",
|
||||||
|
processor_kwargs={"image_seq_length": 1120, "max_soft_tokens": 1120},
|
||||||
|
)
|
||||||
|
assert cfg.processor_kwargs == {
|
||||||
|
"image_seq_length": 1120,
|
||||||
|
"max_soft_tokens": 1120,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_processor_kwargs_schema_accepts_none_and_empty(self):
|
||||||
|
from axolotl.utils.schemas.model import ModelInputConfig
|
||||||
|
|
||||||
|
assert ModelInputConfig(base_model="x").processor_kwargs is None
|
||||||
|
assert (
|
||||||
|
ModelInputConfig(base_model="x", processor_kwargs={}).processor_kwargs == {}
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_processor_kwargs_incompatible_with_mistral_common(self, min_base_cfg):
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.config import validate_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
cfg = min_base_cfg | DictDefault(
|
||||||
|
tokenizer_use_mistral_common=True,
|
||||||
|
processor_kwargs={"image_seq_length": 1120},
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="processor_kwargs"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user