Compare commits
7 Commits
v0.10.0
...
runpod-sls
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b9a2dde4b | ||
|
|
388e950016 | ||
|
|
fb4adbb311 | ||
|
|
5e8abca54f | ||
|
|
168ec339e5 | ||
|
|
cb7185998b | ||
|
|
c2fc35f520 |
88
.github/workflows/base.yml
vendored
88
.github/workflows/base.yml
vendored
@@ -16,63 +16,60 @@ on:
|
||||
jobs:
|
||||
build-base:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
timeout-minutes: 480
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: ubuntu-latest-m
|
||||
runs-on: axolotl-gpu-runner
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: nightly
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-base-nightly"
|
||||
# # "next" is for release candidates of pytorch
|
||||
# - cuda: "128"
|
||||
# cuda_version: 12.8.1
|
||||
# cudnn_version: ""
|
||||
# python_version: "3.11"
|
||||
# pytorch: next
|
||||
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
# dockerfile: "Dockerfile-base-next"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: next
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -94,60 +91,7 @@ jobs:
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
build-args: |
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
CUDNN_VERSION=${{ matrix.cudnn_version }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTHON_VERSION=${{ matrix.python_version }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
||||
build-base-uv:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
timeout-minutes: 480
|
||||
runs-on: ubuntu-latest-m
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "126"
|
||||
cuda_version: 12.6.3
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
dockerfile: "Dockerfile-uv-base"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
axolotlai/axolotl-base-uv
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/${{ matrix.dockerfile }}
|
||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || matrix.pytorch == 'next' && './docker/Dockerfile-base-next' || './docker/Dockerfile-base' }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
1
.github/workflows/lint.yml
vendored
1
.github/workflows/lint.yml
vendored
@@ -9,7 +9,6 @@ on:
|
||||
- '.github/workflows/*.yml'
|
||||
- "*.[q]md"
|
||||
- "examples/**/*.y[a]?ml"
|
||||
- ".pre-commit-config.yaml"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
29
.github/workflows/main.yml
vendored
29
.github/workflows/main.yml
vendored
@@ -18,8 +18,13 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras: vllm
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
@@ -29,13 +34,8 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras: vllm
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -67,7 +67,6 @@ jobs:
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
@@ -83,6 +82,11 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
@@ -97,12 +101,7 @@ jobs:
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
|
||||
16
.github/workflows/multi-gpu-e2e.yml
vendored
16
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -3,13 +3,12 @@ name: docker-multigpu-tests-biweekly
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'tests/e2e/multigpu/**.py'
|
||||
- 'tests/e2e/multigpu/*.py'
|
||||
- 'requirements.txt'
|
||||
- 'setup.py'
|
||||
- 'pyproject.toml'
|
||||
- '.github/workflows/multi-gpu-e2e.yml'
|
||||
- 'src/axolotl/core/trainers/mixins/sequence_parallel.py'
|
||||
- 'src/axolotl/utils/distributed.py'
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||
@@ -33,17 +32,24 @@ jobs:
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras: # no vllm support for 2.4.1
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
@@ -59,7 +65,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
|
||||
10
.github/workflows/nightlies.yml
vendored
10
.github/workflows/nightlies.yml
vendored
@@ -12,6 +12,11 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
@@ -65,6 +70,11 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
|
||||
9
.github/workflows/precommit-autoupdate.yml
vendored
9
.github/workflows/precommit-autoupdate.yml
vendored
@@ -25,6 +25,7 @@ jobs:
|
||||
pre-commit autoupdate
|
||||
if [[ -n $(git status --porcelain) ]]; then
|
||||
echo "changes=true" >> $GITHUB_OUTPUT
|
||||
git diff .pre-commit-config.yaml > pre-commit-update.diff
|
||||
fi
|
||||
|
||||
- name: Create Pull Request
|
||||
@@ -38,3 +39,11 @@ jobs:
|
||||
commit-message: "chore: update pre-commit hooks"
|
||||
body: |
|
||||
Automated PR to update pre-commit hooks to their latest versions.
|
||||
|
||||
<details>
|
||||
<summary>Changes:</summary>
|
||||
|
||||
```diff
|
||||
${{ steps.update.outputs.diff }}
|
||||
```
|
||||
</details>
|
||||
|
||||
61
.github/workflows/preview-docs.yml
vendored
61
.github/workflows/preview-docs.yml
vendored
@@ -1,61 +0,0 @@
|
||||
name: Preview
|
||||
on:
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
|
||||
# Run the workflow only when one of these files changes
|
||||
paths:
|
||||
- '**/*.md' # any Markdown file
|
||||
- '**/*.qmd' # any Quarto file
|
||||
- '_quarto.yaml'
|
||||
|
||||
permissions:
|
||||
checks: write
|
||||
contents: write
|
||||
deployments: write
|
||||
issues: write
|
||||
discussions: write
|
||||
pages: write
|
||||
pull-requests: write
|
||||
statuses: write
|
||||
|
||||
jobs:
|
||||
preview:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Quarto
|
||||
uses: quarto-dev/quarto-actions/setup@v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python3 -m pip install jupyter quartodoc
|
||||
python3 -m pip install -e . --no-deps
|
||||
|
||||
- name: Build autodoc
|
||||
run: quartodoc build
|
||||
|
||||
- name: Quarto render
|
||||
run: quarto render
|
||||
|
||||
- name: Netlify Publish
|
||||
uses: nwtgck/actions-netlify@v3.0
|
||||
with:
|
||||
publish-dir: './_site'
|
||||
enable-pull-request-comment: true
|
||||
enable-github-deployment: true
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
deploy-message: "Deployed On Netlify"
|
||||
github-deployment-environment: 'preview'
|
||||
github-deployment-description: 'Preview Deployment'
|
||||
env:
|
||||
NETLIFY_AUTH_TOKEN: ${{ secrets.NETLIFY_AUTH_TOKEN }}
|
||||
NETLIFY_SITE_ID: ${{ secrets.NETLIFY_SITE_ID }}
|
||||
96
.github/workflows/tests-nightly.yml
vendored
96
.github/workflows/tests-nightly.yml
vendored
@@ -18,102 +18,15 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
preload-cache:
|
||||
name: Preload HF cache
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
env:
|
||||
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v tests/conftest.py
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage.xml
|
||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||
fail_ci_if_error: false
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
needs: [preload-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
|
||||
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
@@ -193,6 +106,13 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
|
||||
128
.github/workflows/tests.yml
vendored
128
.github/workflows/tests.yml
vendored
@@ -27,9 +27,6 @@ concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
env:
|
||||
TRANSFORMERS_IS_CI: "yes"
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
@@ -47,23 +44,26 @@ jobs:
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
# needs: [preload-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
||||
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -118,25 +118,38 @@ jobs:
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
pytest-sdist:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 1
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.5.1", "2.6.0", "2.7.1"]
|
||||
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore Cache from S3
|
||||
id: hf-cache-restore-s3
|
||||
run: |
|
||||
mkdir -p /home/runner/.cache/huggingface/hub
|
||||
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
@@ -183,12 +196,20 @@ jobs:
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
# Run this job first as a gate for running the remainder of the test matrix
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
|
||||
strategy:
|
||||
@@ -201,13 +222,6 @@ jobs:
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras: vllm
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
dockerfile: "Dockerfile-uv.jinja"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -218,7 +232,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -229,7 +243,6 @@ jobs:
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
@@ -238,9 +251,7 @@ jobs:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 120
|
||||
# Only run the remainder of the matrix if the first e2e check passed;
|
||||
# this is to save on wasted compute costs for known failures that get caught in the first run
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
|
||||
strategy:
|
||||
@@ -250,25 +261,19 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras: llmcompressor
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
axolotl_extras: vllm
|
||||
- cuda: 126
|
||||
cuda_version: 12.6.3
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
- cuda: 128
|
||||
cuda_version: 12.8.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.7.1
|
||||
pytorch: 2.7.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
@@ -281,7 +286,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -292,47 +297,6 @@ jobs:
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.e2e_tests
|
||||
|
||||
docker-e2e-cleanup:
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [docker-e2e-tests]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
num_gpus: 1
|
||||
axolotl_extras: vllm
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==1.0.2 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.cleanup
|
||||
|
||||
@@ -19,15 +19,15 @@ repos:
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.2.0
|
||||
rev: 7.1.2
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/pylint-dev/pylint
|
||||
rev: v3.3.7
|
||||
rev: v3.3.6
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.16.0
|
||||
rev: v1.15.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
FROM axolotlai/axolotl-cloud:main-py3.11-cu124-2.6.0
|
||||
FROM runpod/pytorch:3.10-2.0.0-117
|
||||
|
||||
COPY .runpod/requirements.txt /requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install --upgrade pip && \
|
||||
python3 -m pip install --upgrade -r /requirements.txt
|
||||
|
||||
|
||||
# Environment settings
|
||||
ARG BASE_VOLUME="/runpod-volume"
|
||||
ENV BASE_VOLUME=$BASE_VOLUME
|
||||
@@ -14,5 +15,4 @@ ENV TRANSFORMERS_CACHE="${BASE_VOLUME}/huggingface-cache/hub"
|
||||
|
||||
COPY .runpod/src /src
|
||||
|
||||
WORKDIR /src
|
||||
CMD ["python3", "/src/handler.py"]
|
||||
|
||||
@@ -5,3 +5,11 @@
|
||||
# git+https://github.com/runpod/runpod-python.git
|
||||
# To learn more, see https://pip.pypa.io/en/stable/reference/requirements-file-format/
|
||||
runpod~=1.7.0
|
||||
huggingface_hub
|
||||
typing-extensions
|
||||
pydantic
|
||||
pydantic-settings
|
||||
hf-transfer
|
||||
setuptools
|
||||
numpy==2.0.0
|
||||
axolotl[flash-attn,deepspeed]
|
||||
|
||||
@@ -242,12 +242,16 @@
|
||||
# early_stopping_patience: 3
|
||||
|
||||
# # Specify a scheduler and kwargs to use with the optimizer
|
||||
# lr_scheduler: # 'one_cycle' | empty for cosine
|
||||
# lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
||||
# lr_scheduler_kwargs:
|
||||
|
||||
# # For one_cycle optim
|
||||
# lr_div_factor: # Learning rate div factor
|
||||
|
||||
# # For log_sweep optim
|
||||
# log_sweep_min_lr:
|
||||
# log_sweep_max_lr:
|
||||
|
||||
# # Specify optimizer
|
||||
# # Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||
|
||||
@@ -57,10 +57,8 @@ async def handler(job):
|
||||
logger.info("Training Complete.")
|
||||
|
||||
# Cleanup
|
||||
if "WANDB_API_KEY" in os.environ:
|
||||
del os.environ["WANDB_API_KEY"]
|
||||
if "HF_TOKEN" in os.environ:
|
||||
del os.environ["HF_TOKEN"]
|
||||
del os.environ["WANDB_API_KEY"]
|
||||
del os.environ["HF_TOKEN"]
|
||||
|
||||
|
||||
runpod.serverless.start({"handler": handler, "return_aggregate_stream": True})
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
{
|
||||
"input": {
|
||||
"name": "quick_smoke_test_sft",
|
||||
"user_id": "user",
|
||||
"model_id": "llama-test",
|
||||
"run_id": "llama-test",
|
||||
"credentials": {
|
||||
"wandb_api_key": "",
|
||||
"hf_token": ""
|
||||
},
|
||||
"args": {
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"load_in_4bit": true,
|
||||
"strict": false,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]"
|
||||
}
|
||||
],
|
||||
"val_set_size": 0.02,
|
||||
"output_dir": "./outputs/lora-out",
|
||||
"sequence_len": 4096,
|
||||
"sample_packing": true,
|
||||
"eval_sample_packing": false,
|
||||
"pad_to_sequence_len": true,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": true,
|
||||
"lora_modules_to_save": [
|
||||
"embed_tokens",
|
||||
"lm_head"
|
||||
],
|
||||
"gradient_accumulation_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"num_epochs": 1,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"lr_scheduler": "cosine",
|
||||
"learning_rate": 0.0002,
|
||||
"train_on_inputs": false,
|
||||
"group_by_length": false,
|
||||
"bf16": "auto",
|
||||
"tf32": true,
|
||||
"gradient_checkpointing": true,
|
||||
"logging_steps": 1,
|
||||
"flash_attention": true,
|
||||
"warmup_steps": 1,
|
||||
"evals_per_epoch": 1,
|
||||
"eval_max_new_tokens": 128,
|
||||
"saves_per_epoch": 1,
|
||||
"weight_decay": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>"
|
||||
},
|
||||
"max_steps": 20
|
||||
},
|
||||
"timeout": 100000
|
||||
},
|
||||
"config": {
|
||||
"gpuTypeId": "NVIDIA GeForce RTX 4090",
|
||||
"gpuCount": 1,
|
||||
"containerDiskInGb": 200,
|
||||
"env": [
|
||||
{
|
||||
"key": "TOKENIZER",
|
||||
"value": ""
|
||||
},
|
||||
{
|
||||
"key": "DISABLE_LOG_STATS",
|
||||
"value": "true"
|
||||
}
|
||||
],
|
||||
"allowedCudaVersions": [
|
||||
"12.8",
|
||||
"12.7",
|
||||
"12.6",
|
||||
"12.5",
|
||||
"12.4"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -11,43 +11,43 @@
|
||||
"hf_token": ""
|
||||
},
|
||||
"args": {
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"base_model": "NousResearch/Meta-Llama-3-8B",
|
||||
"model_type": "LlamaForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"load_in_4bit": true,
|
||||
"load_in_8bit": true,
|
||||
"load_in_4bit": false,
|
||||
"strict": false,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
"split": "train[:10%]"
|
||||
"type": "alpaca"
|
||||
}
|
||||
],
|
||||
"val_set_size": 0.02,
|
||||
"val_set_size": 0.05,
|
||||
"output_dir": "./outputs/lora-out",
|
||||
"sequence_len": 4096,
|
||||
"sample_packing": true,
|
||||
"eval_sample_packing": false,
|
||||
"pad_to_sequence_len": true,
|
||||
"adapter": "qlora",
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": true,
|
||||
"lora_modules_to_save": [
|
||||
"embed_tokens",
|
||||
"lm_head"
|
||||
],
|
||||
"gradient_accumulation_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"micro_batch_size": 2,
|
||||
"num_epochs": 1,
|
||||
"optimizer": "adamw_torch_fused",
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"learning_rate": 0.0002,
|
||||
"train_on_inputs": false,
|
||||
"group_by_length": false,
|
||||
"bf16": "auto",
|
||||
"tf32": true,
|
||||
"tf32": false,
|
||||
"gradient_checkpointing": true,
|
||||
"logging_steps": 1,
|
||||
"flash_attention": true,
|
||||
@@ -57,9 +57,8 @@
|
||||
"saves_per_epoch": 1,
|
||||
"weight_decay": 0.0,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>"
|
||||
},
|
||||
"max_steps": 20
|
||||
"pad_token": "<|end_of_text|>"
|
||||
}
|
||||
}
|
||||
},
|
||||
"timeout": 100000
|
||||
|
||||
76
README.md
76
README.md
@@ -22,32 +22,28 @@
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
||||
</p>
|
||||
|
||||
|
||||
## 🎉 Latest Updates
|
||||
|
||||
- 2025/06: Magistral with mistral-common tokenizer support has been added to Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/magistral) to start training your own Magistral models with Axolotl!
|
||||
- 2025/05: Quantization Aware Training (QAT) support has been added to Axolotl. Explore the [docs](https://docs.axolotl.ai/docs/qat.html) to learn more!
|
||||
- 2025/04: Llama 4 support has been added in Axolotl. See [examples](https://github.com/axolotl-ai-cloud/axolotl/tree/main/examples/llama-4) to start training your own Llama 4 models with Axolotl's linearized version!
|
||||
- 2025/03: Axolotl has implemented Sequence Parallelism (SP) support. Read the [blog](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl) and [docs](https://docs.axolotl.ai/docs/sequence_parallelism.html) to learn how to scale your context length when fine-tuning.
|
||||
- 2025/03: (Beta) Fine-tuning Multimodal models is now supported in Axolotl. Check out the [docs](https://docs.axolotl.ai/docs/multimodal.html) to fine-tune your own!
|
||||
- 2025/02: Axolotl has added LoRA optimizations to reduce memory usage and improve training speed for LoRA and QLoRA in single GPU and multi-GPU training (DDP and DeepSpeed). Jump into the [docs](https://docs.axolotl.ai/docs/lora_optims.html) to give it a try.
|
||||
- 2025/02: Axolotl has added GRPO support. Dive into our [blog](https://huggingface.co/blog/axolotl-ai-co/training-llms-w-interpreter-feedback-wasm) and [GRPO example](https://github.com/axolotl-ai-cloud/grpo_code) and have some fun!
|
||||
- 2025/01: Axolotl has added Reward Modelling / Process Reward Modelling fine-tuning support. See [docs](https://docs.axolotl.ai/docs/reward_modelling.html).
|
||||
|
||||
## ✨ Overview
|
||||
|
||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||
Post-training refers to any modifications or additional training performed on
|
||||
pre-trained models - including full model fine-tuning, parameter-efficient tuning (like
|
||||
LoRA and QLoRA), supervised fine-tuning (SFT), instruction tuning, and alignment
|
||||
techniques. With support for multiple model architectures and training configurations,
|
||||
Axolotl makes it easy to get started with these techniques.
|
||||
|
||||
Axolotl is designed to work with YAML config files that contain everything you need to
|
||||
preprocess a dataset, train or fine-tune a model, run model inference or evaluation,
|
||||
and much more.
|
||||
|
||||
Features:
|
||||
|
||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
|
||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||
|
||||
|
||||
- Train various Huggingface models such as llama, pythia, falcon, mpt
|
||||
- Supports fullfinetune, lora, qlora, relora, and gptq
|
||||
- Customize configurations using a simple yaml file or CLI overwrite
|
||||
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
|
||||
- Integrated with [xformers](https://github.com/facebookresearch/xformers), flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||
- Easily run with Docker locally or on the cloud
|
||||
- Log results and optionally checkpoints to wandb, mlflow or Comet
|
||||
- And more!
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
@@ -55,7 +51,7 @@ Features:
|
||||
|
||||
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python 3.11
|
||||
- PyTorch ≥2.5.1
|
||||
- PyTorch ≥2.4.1
|
||||
|
||||
### Installation
|
||||
|
||||
@@ -85,12 +81,19 @@ axolotl train examples/llama-3/lora-1b.yml
|
||||
|
||||
That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/getting-started.html) for a more detailed walkthrough.
|
||||
|
||||
## ✨ Key Features
|
||||
|
||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more
|
||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, and more
|
||||
- **Easy Configuration**: Simple YAML files to control your training setup
|
||||
- **Performance Optimizations**: Flash Attention, xformers, multi-GPU training
|
||||
- **Flexible Dataset Handling**: Use various formats and custom datasets
|
||||
- **Cloud Ready**: Run on cloud platforms or local hardware
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- [Installation Options](https://docs.axolotl.ai/docs/installation.html) - Detailed setup instructions for different environments
|
||||
- [Configuration Guide](https://docs.axolotl.ai/docs/config.html) - Full configuration options and examples
|
||||
- [Dataset Loading](https://docs.axolotl.ai/docs/dataset_loading.html) - Loading datasets from various sources
|
||||
- [Dataset Guide](https://docs.axolotl.ai/docs/dataset-formats/) - Supported formats and how to use them
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
@@ -109,6 +112,31 @@ That's it! Check out our [Getting Started Guide](https://docs.axolotl.ai/docs/ge
|
||||
|
||||
Contributions are welcome! Please see our [Contributing Guide](https://github.com/axolotl-ai-cloud/axolotl/blob/main/.github/CONTRIBUTING.md) for details.
|
||||
|
||||
## Supported Models
|
||||
|
||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Mixtral8X22 | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
|
||||
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
|
||||
|
||||
✅: supported
|
||||
❌: not supported
|
||||
❓: untested
|
||||
|
||||
## ❤️ Sponsors
|
||||
|
||||
Thank you to our sponsors who help make Axolotl possible:
|
||||
|
||||
43
_quarto.yml
43
_quarto.yml
@@ -17,9 +17,7 @@ quartodoc:
|
||||
- convert
|
||||
- prompt_tokenizers
|
||||
- logging_config
|
||||
- core.builders.base
|
||||
- core.builders.causal
|
||||
- core.builders.rl
|
||||
- core.trainer_builder
|
||||
- core.training_args
|
||||
- core.chat.messages
|
||||
- core.chat.format.chatml
|
||||
@@ -45,37 +43,13 @@ quartodoc:
|
||||
- cli.vllm_serve
|
||||
- cli.cloud.base
|
||||
- cli.cloud.modal_
|
||||
- cli.quantize
|
||||
- title: Trainers
|
||||
desc: Training implementations
|
||||
contents:
|
||||
- core.trainers.base
|
||||
- core.trainers.trl
|
||||
- core.trainers.mamba
|
||||
- core.trainers.relora
|
||||
- core.trainers.dpo.trainer
|
||||
- core.trainers.grpo.trainer
|
||||
- core.trainers.grpo.sampler
|
||||
- core.trainers.utils
|
||||
- title: Model Loading
|
||||
desc: Functionality for loading and patching models, tokenizers, etc.
|
||||
contents:
|
||||
- loaders.model
|
||||
- loaders.tokenizer
|
||||
- loaders.processor
|
||||
- loaders.adapter
|
||||
- loaders.patch_manager
|
||||
- loaders.constants
|
||||
- title: Mixins
|
||||
desc: Mixin classes for augmenting trainers
|
||||
contents:
|
||||
- core.trainers.mixins.optimizer
|
||||
- core.trainers.mixins.rng_state_loader
|
||||
- core.trainers.mixins.scheduler
|
||||
- title: Context Managers
|
||||
desc: Context managers for altering trainer behaviors
|
||||
contents:
|
||||
- utils.ctx_managers.sequence_parallel
|
||||
- title: Prompt Strategies
|
||||
desc: Prompt formatting strategies
|
||||
contents:
|
||||
@@ -112,7 +86,7 @@ quartodoc:
|
||||
- kernels.swiglu
|
||||
- kernels.quantize
|
||||
- kernels.utils
|
||||
- title: Monkey Patches
|
||||
- title: MonkeyPatches
|
||||
desc: Runtime patches for model optimizations
|
||||
contents:
|
||||
- monkeypatch.llama_attn_hijack_flash
|
||||
@@ -129,16 +103,17 @@ quartodoc:
|
||||
- monkeypatch.trainer_fsdp_optim
|
||||
- monkeypatch.transformers_fa_utils
|
||||
- monkeypatch.unsloth_
|
||||
- monkeypatch.attention.mllama
|
||||
- monkeypatch.data.batch_dataset_fetcher
|
||||
- monkeypatch.mixtral
|
||||
- monkeypatch.gradient_checkpointing.offload_cpu
|
||||
- monkeypatch.gradient_checkpointing.offload_disk
|
||||
- title: Utils
|
||||
desc: Utility functions
|
||||
contents:
|
||||
- utils.models
|
||||
- utils.tokenization
|
||||
- utils.chat_templates
|
||||
- utils.lora
|
||||
- utils.lora_embeddings
|
||||
- utils.model_shard_quant
|
||||
- utils.bench
|
||||
- utils.freeze
|
||||
@@ -149,7 +124,7 @@ quartodoc:
|
||||
- utils.optimizers.adopt
|
||||
- utils.data.pretraining
|
||||
- utils.data.sft
|
||||
- utils.quantization
|
||||
- utils.gradient_checkpointing.unsloth
|
||||
- title: Schemas
|
||||
desc: Pydantic data models for Axolotl config
|
||||
contents:
|
||||
@@ -199,14 +174,12 @@ quartodoc:
|
||||
- utils.callbacks.lisa
|
||||
- utils.callbacks.mlflow_
|
||||
- utils.callbacks.comet_
|
||||
- utils.callbacks.qat
|
||||
|
||||
website:
|
||||
title: "Axolotl"
|
||||
description: "We make fine-tuning accessible, scalable, and fun"
|
||||
favicon: favicon.jpg
|
||||
|
||||
google-analytics: "G-9KYCVJBNMQ"
|
||||
|
||||
navbar:
|
||||
logo: image/axolotl_logo_digital_white.svg
|
||||
title: false
|
||||
@@ -259,8 +232,6 @@ website:
|
||||
- docs/lr_groups.qmd
|
||||
- docs/lora_optims.qmd
|
||||
- docs/dataset_loading.qmd
|
||||
- docs/qat.qmd
|
||||
- docs/quantize.qmd
|
||||
|
||||
- section: "Core Concepts"
|
||||
contents:
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
FROM axolotlai/axolotl-base-uv:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||
ENV CUDA="{{ CUDA }}"
|
||||
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
||||
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
RUN git fetch origin +$GITHUB_REF && \
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN uv pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
uv pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||
fi
|
||||
|
||||
RUN python scripts/unsloth_install.py --uv | sh
|
||||
RUN python scripts/cutcrossentropy_install.py --uv | sh
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN uv pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
# helper for huggingface-login cli
|
||||
RUN git config --global credential.helper store
|
||||
@@ -18,7 +18,7 @@ pytest -v --durations=10 \
|
||||
--cov-append
|
||||
|
||||
# Run patched tests excluding lora kernels with coverage append
|
||||
pytest --full-trace -vvv --durations=10 \
|
||||
pytest -v --durations=10 \
|
||||
--ignore=tests/e2e/patched/lora_kernels \
|
||||
/workspace/axolotl/tests/e2e/patched \
|
||||
--cov=axolotl \
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
"""Modal app to run axolotl GPU cleanup"""
|
||||
|
||||
from .single_gpu import VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
def cleanup():
|
||||
run_cmd("./cicd/cleanup.sh", "/workspace/axolotl")
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
def main():
|
||||
cleanup.remote()
|
||||
@@ -1,6 +0,0 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
# cleanup old cache files for datasets processing and intermediate mappings
|
||||
find /workspace/data/huggingface-cache/hub/datasets -name "cache-*" -type f -mtime +1 -exec rm {} \;
|
||||
find /workspace/data/huggingface-cache/hub/datasets -name "*.lock" -type f -mtime +1 -exec rm {} \;
|
||||
@@ -1,12 +1,75 @@
|
||||
"""Modal app to run axolotl GPU tests"""
|
||||
|
||||
from .single_gpu import GPU_CONFIG, VOLUME_CONFIG, app, cicd_image, run_cmd
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
import jinja2
|
||||
import modal
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App, Image
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
df_template = template_env.get_template("Dockerfile.jinja")
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
|
||||
"CUDA": os.environ.get("CUDA", "121"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
f.write(dockerfile_contents)
|
||||
|
||||
cicd_image = Image.from_dockerfile(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
context_mount=None,
|
||||
force_build=True,
|
||||
gpu="A10G",
|
||||
).env(df_args)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
hf_cache_volume = modal.Volume.from_name(
|
||||
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
||||
)
|
||||
VOLUME_CONFIG = {
|
||||
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
|
||||
@app.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=90 * 60, # 90 min
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
volumes=VOLUME_CONFIG,
|
||||
|
||||
@@ -24,9 +24,9 @@ df_template = template_env.get_template("Dockerfile.jinja")
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
||||
"CUDA": os.environ.get("CUDA", "124"),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.4.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.4.1"),
|
||||
"CUDA": os.environ.get("CUDA", "121"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
@@ -55,7 +55,7 @@ VOLUME_CONFIG = {
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 2))
|
||||
GPU_CONFIG = f"H100:{N_GPUS}"
|
||||
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
@@ -70,7 +70,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=90 * 60,
|
||||
cpu=16.0,
|
||||
cpu=8.0,
|
||||
memory=131072 * N_GPUS,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
|
||||
@@ -20,4 +20,4 @@ pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
|
||||
--cov-report=xml:multigpu-coverage.xml
|
||||
|
||||
# Upload coverage to Codecov
|
||||
codecov upload-process -t "${CODECOV_TOKEN}" -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION} || true
|
||||
codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""Modal app to run axolotl GPU tests"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
|
||||
import jinja2
|
||||
import modal
|
||||
import modal.experimental
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
|
||||
template_env = jinja2.Environment(
|
||||
loader=template_loader, autoescape=select_autoescape()
|
||||
)
|
||||
dockerfile = os.environ.get("E2E_DOCKERFILE", "Dockerfile.jinja")
|
||||
df_template = template_env.get_template(dockerfile)
|
||||
|
||||
df_args = {
|
||||
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
|
||||
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
|
||||
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.5.1"),
|
||||
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu124-2.5.1"),
|
||||
"CUDA": os.environ.get("CUDA", "124"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
||||
f.write(dockerfile_contents)
|
||||
|
||||
cicd_image = modal.experimental.raw_dockerfile_image(
|
||||
pathlib.Path(temp_dir) / "Dockerfile",
|
||||
# context_mount=None,
|
||||
force_build=True,
|
||||
# gpu="A10G",
|
||||
).env(df_args)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
hf_cache_volume = modal.Volume.from_name(
|
||||
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
||||
)
|
||||
VOLUME_CONFIG = {
|
||||
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
GPU_CONFIG = f"L40S:{N_GPUS}"
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str):
|
||||
import subprocess # nosec
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
@@ -19,7 +19,7 @@ coverage:
|
||||
if_no_uploads: error
|
||||
if_not_found: success
|
||||
if_ci_failed: error
|
||||
only_pulls: true
|
||||
only_pulls: false
|
||||
flags: null
|
||||
paths: null
|
||||
patch:
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
{
|
||||
"compile": {
|
||||
"disable": false,
|
||||
"backend": "inductor"
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu"
|
||||
},
|
||||
"contiguous_gradients": true,
|
||||
"overlap_comm": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": "auto"
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -38,6 +38,6 @@ RUN git lfs install --skip-repo && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
|
||||
pip3 install flash-attn==2.7.4.post1; \
|
||||
fi
|
||||
|
||||
@@ -29,7 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==2.7.1 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir -U torch==2.7.0 --extra-index-url https://download.pytorch.org/whl/test/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
ARG CUDA_VERSION="12.6.3"
|
||||
ARG CUDNN_VERSION=""
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="2.6.0"
|
||||
ARG CUDA="126"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
ENV UV_TORCH_BACKEND="cu${CUDA}"
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config curl && rm -rf /var/lib/apt/lists/* \
|
||||
&& git lfs install --skip-repo \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
ENV PATH="/root/.local/bin:${PATH}"
|
||||
|
||||
RUN uv python install ${PYTHON_VERSION}
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN uv venv --no-project --relocatable axolotl-venv
|
||||
|
||||
ENV PATH="/workspace/axolotl-venv/bin:${PATH}"
|
||||
|
||||
RUN uv pip install packaging setuptools wheel psutil \
|
||||
&& uv pip install torch==${PYTORCH_VERSION} \
|
||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||
&& uv pip install awscli pydantic
|
||||
|
||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
||||
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
|
||||
fi
|
||||
10
docs/cli.qmd
10
docs/cli.qmd
@@ -209,16 +209,6 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
|
||||
|
||||
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
|
||||
|
||||
### quantize
|
||||
|
||||
Quantizes a model using the quantization configuration specified in your YAML file.
|
||||
|
||||
```bash
|
||||
axolotl quantize config.yml
|
||||
```
|
||||
|
||||
See [Quantization](./quantize.qmd) for more details.
|
||||
|
||||
|
||||
## Legacy CLI Usage
|
||||
|
||||
|
||||
111
docs/config.qmd
111
docs/config.qmd
@@ -27,15 +27,11 @@ trust_remote_code:
|
||||
tokenizer_use_fast:
|
||||
# Whether to use the legacy tokenizer setting, defaults to True
|
||||
tokenizer_legacy:
|
||||
# Whether to use mistral-common tokenizer. If set to True, it will use the mistral-common tokenizer.
|
||||
tokenizer_use_mistral_common:
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||
shrink_embeddings:
|
||||
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
|
||||
embeddings_skip_upcast:
|
||||
# Whether to load the model with randomly initialized weights. Useful for
|
||||
# pre-training a model from scratch or debugging purposes.
|
||||
random_init_weights:
|
||||
@@ -67,20 +63,6 @@ bnb_config_kwargs:
|
||||
bnb_4bit_quant_type: nf4
|
||||
bnb_4bit_use_double_quant: true
|
||||
|
||||
# quantization aware training
|
||||
qat:
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
|
||||
|
||||
# post-training quantization
|
||||
quantization:
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
|
||||
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
@@ -91,12 +73,11 @@ load_in_8bit: true
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
|
||||
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
|
||||
# Use CUDA fp16
|
||||
fp16: true
|
||||
# Use CUDA tf32
|
||||
tf32: true # require >=ampere
|
||||
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting
|
||||
|
||||
# No AMP (automatic mixed precision)
|
||||
bfloat16: true # require >=ampere
|
||||
@@ -114,10 +95,8 @@ plugins:
|
||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
# See https://docs.axolotl.ai/docs/dataset_loading.html for guide on loading datasets
|
||||
# See https://docs.axolotl.ai/docs/dataset-formats/ for guide on dataset formats
|
||||
datasets:
|
||||
# HuggingFace dataset repo | s3:// | gs:// | path to local file or directory
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
@@ -175,14 +154,6 @@ datasets:
|
||||
# Key containing the messages (default: "messages")
|
||||
field_messages: messages
|
||||
|
||||
# Key containing the tools (default: "tools")
|
||||
# Must be a list[dict] and follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
field_tools: tools
|
||||
|
||||
# Key containing the system message (default: "system")
|
||||
# If the system message is not present in the dataset sample, it will be loaded from the field_system property.
|
||||
field_system: system
|
||||
|
||||
# Mapping of properties from the input dataset to the chat template.
|
||||
# (default: message_property_mappings={'role':'role', 'content':'content'})
|
||||
# If a property exists in the template but not in this mapping, the system will attempt
|
||||
@@ -209,14 +180,10 @@ datasets:
|
||||
# adding a system turn with empty content.
|
||||
drop_system_message:
|
||||
|
||||
# Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags
|
||||
# See example at `docs/dataset-formats/conversation.qmd`
|
||||
split_thinking:
|
||||
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||
# Note: If the below 5 fields are empty, defaults to training only on the last message.
|
||||
# Note: If the below 4 fields are set to empty, defaults to training only on the last message.
|
||||
|
||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||
roles_to_train: ["assistant"] # default
|
||||
@@ -225,13 +192,7 @@ datasets:
|
||||
# - turn (default): train on the EOS token at the end of each trainable turn
|
||||
# - last: train on the last EOS token in the conversation
|
||||
# TIP: Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
train_on_eos: turn
|
||||
# Optional[str]. Which EOT (End-of-Turn) tokens to train on in the conversation. Possible values are:
|
||||
# - all: train on all EOT tokens
|
||||
# - turn: train on the EOT token at the end of each trainable turn
|
||||
# - last: train on the last EOT token in the conversation
|
||||
# If not specified, defaults to the value of train_on_eos for backward compatibility.
|
||||
train_on_eot:
|
||||
train_on_eos: last
|
||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||
message_field_training: training
|
||||
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
|
||||
@@ -243,7 +204,7 @@ datasets:
|
||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
||||
shuffle_merged_datasets: true
|
||||
|
||||
# Deduplicates datasets and test_datasets with identical entries.
|
||||
Deduplicates datasets and test_datasets with identical entries.
|
||||
dataset_exact_deduplication: true
|
||||
|
||||
# A list of one or more datasets to eval the model with.
|
||||
@@ -292,25 +253,10 @@ trl:
|
||||
|
||||
num_generations: # Optional[int]. Number of generations to sample.
|
||||
log_completions: # Optional[bool]. Whether to log completions.
|
||||
num_completions_to_print: # Optional[int]. Number of completions to print when log_completions is True.
|
||||
|
||||
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
|
||||
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
|
||||
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
|
||||
scale_rewards: # Optional[bool]. Whether to scale rewards by their standard deviation.
|
||||
|
||||
temperature: # Optional[float]. Sampling temperature for the GRPO policy.
|
||||
top_p: # Optional[float]. Top-p sampling probability for the generation policy.
|
||||
top_k: # Optional[int]. Top-k sampling for the generation policy.
|
||||
min_p: # Optional[float]. Minimum probability for the generation policy.
|
||||
repetition_penalty: # Optional[float]. Penalty for tokens that appear in prompt and generated text.
|
||||
|
||||
num_iterations: # Optional[int]. Number of iterations per batch (μ) for GRPO.
|
||||
epsilon: # Optional[float]. Epsilon value for clipping in the GRPO algorithm.
|
||||
epsilon_high: # Optional[float]. Upper-bound epsilon value for clipping in the GRPO algorithm.
|
||||
use_liger_loss: # Optional[bool]. Whether to use Liger loss for GRPO.
|
||||
loss_type: # Optional[str]. Loss formulation to use. Supported values: grpo, bnpo, dr_grpo.
|
||||
mask_truncated_completions: # Optional[bool]. Whether to exclude truncated completions from loss calculation.
|
||||
|
||||
|
||||
# reward modelling: `True` or `False`
|
||||
@@ -329,17 +275,8 @@ process_reward_model:
|
||||
chat_template: tokenizer_default
|
||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||
chat_template_jinja: null
|
||||
# Optional[List[str]]. Custom EOT (End-of-Turn) tokens to mask/unmask during training.
|
||||
# These tokens mark the boundaries between conversation turns.
|
||||
# For example: ["/INST", "</s>", "[/SYSTEM_PROMPT]"]
|
||||
# If not specified, defaults to just the model's eos_token.
|
||||
# This is useful for templates that use multiple delimiter tokens.
|
||||
eot_tokens:
|
||||
# - "</s>"
|
||||
# - "[/INST]"
|
||||
# - "[/SYSTEM_PROMPT]"
|
||||
# Changes the default system message
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
||||
# Changes the default system message. Currently only supports chatml.
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
@@ -520,7 +457,6 @@ output_dir: ./completed-model
|
||||
# setting to `auto` will enable torch compile when torch>=2.5.1
|
||||
torch_compile: # Optional[Union[Literal["auto"], bool]]
|
||||
torch_compile_backend: # Optional[str]
|
||||
torch_compile_mode: # 'default' | 'reduce-overhead' | 'max-autotune'
|
||||
|
||||
# Training hyperparameters
|
||||
|
||||
@@ -543,7 +479,6 @@ save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of eac
|
||||
save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
|
||||
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
save_only_model: # Save only the model weights, skipping the optimizer. Using this means you can't resume from checkpoints.
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
@@ -567,7 +502,7 @@ profiler_steps: # enable the pytorch profiler to capture the first N steps of tr
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
# Save model as safetensors (require safetensors package). Default True
|
||||
# Save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# Whether to mask out or include the human's prompt from the training labels
|
||||
@@ -577,7 +512,7 @@ train_on_inputs: false
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload", "offload_disk".
|
||||
# Whether to use gradient checkpointing. Available options are: true, false, "offload".
|
||||
# https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
# additional kwargs to pass to the trainer for gradient checkpointing
|
||||
@@ -589,24 +524,7 @@ gradient_checkpointing: false
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
# Valid values are driven by the Transformers SchedulerType class, see:
|
||||
# https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/trainer_utils.py#L420
|
||||
# Valid values include
|
||||
# - 'linear'
|
||||
# - 'cosine' (default)
|
||||
# - 'cosine_with_restarts'
|
||||
# - 'polynomial'
|
||||
# - 'constant'
|
||||
# - 'constant_with_warmup'
|
||||
# - 'inverse_sqrt'
|
||||
# - 'reduce_lr_on_plateau'
|
||||
# - 'cosine_with_min_lr'
|
||||
# - 'warmup_stable_decay'
|
||||
|
||||
# Additional schedulers include:
|
||||
# - 'one_cycle'
|
||||
# - 'rex'
|
||||
lr_scheduler:
|
||||
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
||||
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
||||
@@ -624,7 +542,7 @@ lr_div_factor: # Learning rate div factor
|
||||
#
|
||||
# Valid values for 'optimizer' include:
|
||||
# - adamw_torch
|
||||
# - adamw_torch_fused (default)
|
||||
# - adamw_torch_fused
|
||||
# - adamw_torch_xla
|
||||
# - adamw_torch_npu_fused
|
||||
# - adamw_apex_fused
|
||||
@@ -668,7 +586,6 @@ lr_div_factor: # Learning rate div factor
|
||||
# - optimi_adamw
|
||||
# - ao_adamw_8bit
|
||||
# - ao_adamw_fp8
|
||||
# - came_pytorch
|
||||
optimizer:
|
||||
# Dictionary of arguments to pass to the optimizer
|
||||
optim_args:
|
||||
@@ -688,9 +605,7 @@ weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
adam_beta2:
|
||||
adam_beta3: # only used for CAME Optimizer
|
||||
adam_epsilon:
|
||||
adam_epsilon2: # only used for CAME Optimizer
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
@@ -746,10 +661,8 @@ special_tokens:
|
||||
# unk_token: "<unk>"
|
||||
# pad_token: "[PAD]"
|
||||
|
||||
# Optional[list[str]]. Add extra tokens to the tokenizer.
|
||||
# Add extra tokens.
|
||||
tokens:
|
||||
# - "<|startoftext|>"
|
||||
# - "<|endoftext|>"
|
||||
|
||||
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
||||
|
||||
@@ -49,8 +49,7 @@ sections = [
|
||||
("Knowledge Distillation (KD)", "kd"),
|
||||
("Liger Kernels", "liger"),
|
||||
("Language Model Evaluation Harness (LM Eval)", "lm_eval"),
|
||||
("Spectrum", "spectrum"),
|
||||
("LLMCompressor", "llm_compressor")
|
||||
("Spectrum", "spectrum")
|
||||
]
|
||||
|
||||
for section_name, folder_name in sections:
|
||||
|
||||
@@ -4,6 +4,18 @@ description: Conversation format for supervised fine-tuning.
|
||||
order: 3
|
||||
---
|
||||
|
||||
## sharegpt
|
||||
|
||||
::: {.callout-important}
|
||||
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section below.
|
||||
:::
|
||||
|
||||
## pygmalion
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
## chat_template
|
||||
|
||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||
@@ -52,9 +64,7 @@ We recommend checking the below examples for other usecases.
|
||||
|
||||
### Examples
|
||||
|
||||
#### Training on last message
|
||||
|
||||
(Legacy) Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
@@ -68,9 +78,7 @@ datasets:
|
||||
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
||||
:::
|
||||
|
||||
#### Overriding default chat template
|
||||
|
||||
Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: gemma # this overwrites the tokenizer's chat_template
|
||||
@@ -80,13 +88,7 @@ datasets:
|
||||
roles_to_train: ["assistant"] # default value
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
If you want to use built-in chat_template, use `chat_template: tokenizer_default` (this is set by default).
|
||||
:::
|
||||
|
||||
#### Using default chat template with fallback
|
||||
|
||||
Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
|
||||
@@ -95,9 +97,7 @@ datasets:
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
#### Custom Jinja template
|
||||
|
||||
Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
|
||||
@@ -109,123 +109,10 @@ datasets:
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS (End-of-Sequence) token in template. Otherwise, set `eos_token` under `special_tokens: `.
|
||||
Please make sure that your `tokenizer.eos_token` is same as EOS/EOT token in template. Otherwise, set `eos_token` under `special_tokens`.
|
||||
:::
|
||||
|
||||
#### Using template with different token for EOT and EOS
|
||||
|
||||
- If you are using a template that has a different EOT (End-of-Turn) token from EOS token or multiple EOT tokens (like Mistral V7 Tekken), set the `eot_tokens: ` config. The handling of EOT tokens follows `train_on_eos: ` which defaults to turn.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
- "[/INST]"
|
||||
# - "[/SYSTEM_PROMPT]"
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
# optional
|
||||
train_on_eot: turn # defaults read from train_on_eos (which defaults to turn)
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
See [config documentation](../config.qmd) for detailed explanations of "turn", "last", and "all" options for training on tokens.
|
||||
:::
|
||||
|
||||
::: {.callout-note}
|
||||
Using `eot_tokens` requires each token that exists in `chat_template` to be a single token in the tokenizer. Otherwise, the tokenizer will split the token and cause unexpected behavior.
|
||||
|
||||
You can add those tokens as new tokens under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `. See [config](../config.qmd) for more details.
|
||||
:::
|
||||
|
||||
- Continuing from the previous example, if you want to train on all EOT token trainable turns but only last EOS token, set `train_on_eos: last`.
|
||||
|
||||
```yaml
|
||||
eot_tokens:
|
||||
- "[/INST]"
|
||||
# ...
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
train_on_eos: last
|
||||
train_on_eot: turn
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
If EOS token only appears at the end of a prompt, `train_on_eos: last` is equivalent to `train_on_eos: turn`. Therefore, generally, you can leave them to their defaults and omit them.
|
||||
:::
|
||||
|
||||
|
||||
#### Using tool use
|
||||
|
||||
Instead of passing `tools` via the system prompt, an alternative method would be to have the `tools` in a separate column and loaded via `chat_template` to let the template dynamically build it.
|
||||
|
||||
```json
|
||||
{
|
||||
"tools": [
|
||||
{
|
||||
"type": "...",
|
||||
"function": {
|
||||
"name": "...",
|
||||
"description": "...",
|
||||
"parameters": {
|
||||
"type": "...",
|
||||
"properties": {
|
||||
// ...
|
||||
},
|
||||
"required": ["..."],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"messages": [
|
||||
// ...
|
||||
{
|
||||
"role": "assistant", // call the function via assistant
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "...",
|
||||
"arguments": {
|
||||
"...": "...",
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"name": "...",
|
||||
"content": "..."
|
||||
},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Tools need to follow [JSON schema](https://json-schema.org/learn/getting-started-step-by-step).
|
||||
:::
|
||||
|
||||
```yaml
|
||||
chat_template: llama4
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
# field_tools: tools # default is `tools`
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
Look into the `chat_template` you are using to see if it supports `tools` and what the expected role is for the tool answer. In the example above, the tool answer is expected to be in the `tool` or `ipython` role for `llama4` template.
|
||||
:::
|
||||
|
||||
|
||||
#### Using fine-grained control over token masking
|
||||
|
||||
(Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
For a data sample that looks like:
|
||||
|
||||
@@ -275,45 +162,3 @@ datasets:
|
||||
::: {.callout-tip}
|
||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||
:::
|
||||
|
||||
#### Reasoning split
|
||||
|
||||
(For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
chat_template: qwen3
|
||||
split_thinking: true
|
||||
```
|
||||
|
||||
For example, a content can look like:
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "<think>Some thinking outputs</think>Output after thinking."
|
||||
}
|
||||
```
|
||||
|
||||
After split, it will look like:
|
||||
|
||||
```json
|
||||
{
|
||||
"reasoning_content": "Some thinking outputs",
|
||||
"content": "Output after thinking..."
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## sharegpt
|
||||
|
||||
::: {.callout-important}
|
||||
ShareGPT is deprecated!. Please see [chat_template](#chat_template) section.
|
||||
:::
|
||||
|
||||
## pygmalion
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
@@ -36,6 +36,10 @@ It is typically recommended to save your dataset as `.jsonl` due to its flexibil
|
||||
|
||||
Axolotl supports loading from a Hugging Face hub repo or from local files.
|
||||
|
||||
::: {.callout-important}
|
||||
For pre-training only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts.
|
||||
:::
|
||||
|
||||
### Pre-training from Hugging Face hub datasets
|
||||
|
||||
As an example, to train using a Hugging Face dataset `hf_org/name`, you can pass the following config:
|
||||
@@ -73,21 +77,18 @@ datasets:
|
||||
type: completion
|
||||
```
|
||||
|
||||
From local files:
|
||||
From local files (either example works):
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: A.jsonl
|
||||
type: completion
|
||||
|
||||
- path: B.jsonl
|
||||
- path: json
|
||||
data_files: ["A.jsonl", "B.jsonl", "C.jsonl"]
|
||||
type: completion
|
||||
```
|
||||
|
||||
::: {.callout-important}
|
||||
For `completion` only, Axolotl would split texts if it exceeds the context length into multiple smaller prompts. If you are interested in having this for `pretraining_dataset` too, please let us know or help make a PR!
|
||||
:::
|
||||
|
||||
### Pre-training dataset configuration tips
|
||||
|
||||
#### Setting max_steps
|
||||
|
||||
@@ -54,7 +54,7 @@ datasets:
|
||||
|
||||
#### Files
|
||||
|
||||
To load a JSON file, you would do something like this:
|
||||
Usually, to load a JSON file, you would do something like this:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
@@ -66,11 +66,19 @@ Which translates to the following config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: data.json
|
||||
ds_type: json
|
||||
- path: json
|
||||
data_files: /path/to/your/file.jsonl
|
||||
```
|
||||
|
||||
In the example above, it can be seen that we can just point the `path` to the file or directory along with the `ds_type` to load the dataset.
|
||||
However, to make things easier, we have added a few shortcuts for loading local dataset files.
|
||||
|
||||
You can just point the `path` to the file or directory along with the `ds_type` to load the dataset. The below example shows for a JSON file:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: /path/to/your/file.jsonl
|
||||
ds_type: json
|
||||
```
|
||||
|
||||
This works for CSV, JSON, Parquet, and Arrow files.
|
||||
|
||||
|
||||
@@ -8,10 +8,6 @@ format:
|
||||
|
||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
## Base
|
||||
|
||||
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||
@@ -32,10 +28,11 @@ main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||
|
||||
Tags examples:
|
||||
|
||||
- `main-base-py3.11-cu128-2.7.1`
|
||||
- `main-base-py3.11-cu126-2.7.1`
|
||||
- `main-base-py3.11-cu128-2.7.0`
|
||||
- `main-base-py3.11-cu126-2.7.0`
|
||||
- `main-base-py3.11-cu124-2.6.0`
|
||||
- `main-base-py3.11-cu124-2.5.1`
|
||||
- `main-base-py3.11-cu124-2.4.1`
|
||||
|
||||
## Main
|
||||
|
||||
@@ -76,10 +73,12 @@ Tags examples:
|
||||
- `main-py3.11-cu126-2.7.0`
|
||||
- `main-py3.11-cu124-2.6.0`
|
||||
- `main-py3.11-cu124-2.5.1`
|
||||
- `main-py3.11-cu124-2.4.1`
|
||||
- `main-latest`
|
||||
- `main-20250303-py3.11-cu124-2.6.0`
|
||||
- `main-20250303-py3.11-cu124-2.5.1`
|
||||
- `0.9.2`
|
||||
- `main-20250303-py3.11-cu124-2.4.1`
|
||||
- `0.7.1`
|
||||
|
||||
## Cloud
|
||||
|
||||
|
||||
48
docs/faq.qmd
48
docs/faq.qmd
@@ -73,54 +73,10 @@ description: Frequently asked questions
|
||||
|
||||
> A: This is likely an empty turn.
|
||||
|
||||
**Q: The EOS token is incorrectly being masked or not being masked / `EOS token __ not found in chat template`.**
|
||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||
|
||||
> A: There can be two reasons:
|
||||
|
||||
> 1. This is because of the mismatch between `tokenizer.eos_token` and EOS token in template. Please make sure to set `eos_token: ` under `special_tokens: ` to the same EOS token as in template.
|
||||
|
||||
> 2. The EOS token is not in the template. Please check if your template is correct. As an example, `phi_35` template does not use its dedicated EOS token `<|endoftext|>` at the end.
|
||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||
|
||||
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
||||
|
||||
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
||||
|
||||
**Q: The EOT token(s) are incorrectly being masked or not being masked / `EOT token __ not found in chat template`.**
|
||||
|
||||
> A: There can be two reasons:
|
||||
|
||||
> 1. The EOT token is different from the EOS token and was not specified under `eot_tokens: `. Please set `eot_tokens: ` to the same EOT token(s) as in template.
|
||||
|
||||
> 2. There is more than one EOT token per turn in the template. Please raise an issue with examples as we recognize this as an edge case.
|
||||
|
||||
**Q: `EOT token encoding failed. Please check if the token is valid and can be encoded.`**
|
||||
|
||||
> A: There could be some issue with the tokenizer or unicode encoding. Please raise an issue with examples with the EOT token & tokenizer causing the issue.
|
||||
|
||||
**Q: `EOT token __ is encoded as multiple tokens.`**
|
||||
|
||||
> A: This is because the EOT token is encoded as multiple tokens which can cause unexpected behavior. Please add it under `tokens: ` or (recommended) override unused added_tokens via `added_tokens_overrides: `.
|
||||
|
||||
**Q: `Conflict between train_on_eos and train_on_eot. eos_token is in eot_tokens and train_on_eos != train_on_eot`**
|
||||
|
||||
> A: This is because the EOS token is in the `eot_tokens: ` while mismatch between `train_on_eos: ` and `train_on_eot: `. This will cause one to override the other. Please ensure that `train_on_eos: ` and `train_on_eot: ` are the same or remove the EOS token from `eot_tokens: `.
|
||||
|
||||
**Q: If `eot_tokens: ` is not provided, what happens?**
|
||||
|
||||
> A: If `eot_tokens: ` is not provided, the default behavior is the same as before. EOS tokens used to delimit turns are masked/unmasked depending on whether the turn is trainable.
|
||||
|
||||
> Internally, `eot_tokens: tokenizer.eos_token` and `train_on_eot: train_on_eos` (which defaults to `turn`). This transition helps clarify the naming and behavior of EOT/EOS tokens.
|
||||
|
||||
**Q: `Data processing error: CAS service error`**
|
||||
|
||||
> A: Try disabling XET with `export HF_HUB_DISABLE_XET=1`
|
||||
|
||||
**Q: `torch._inductor.exc.LoweringException: NoValidChoicesError: No choices to select, please consider adding ATEN into max_autotune_gemm_backends config (defined in torch/_inductor/config.py) to allow at least one choice. `**
|
||||
|
||||
> A: Depending on the version of torch, you may need to include this in your YAML:
|
||||
|
||||
> ```yaml
|
||||
> flex_attn_compile_kwargs:
|
||||
> dynamic: false
|
||||
> mode: max-autotune-no-cudagraphs
|
||||
> ```
|
||||
|
||||
@@ -104,7 +104,7 @@ the `alpaca` dataset format, which has the following format:
|
||||
Please see our [Dataset Formats](dataset-formats) for more dataset formats and how to
|
||||
format them.
|
||||
|
||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca`
|
||||
2. Prepare your JSONL data in the specified format (in this case, the expected `alpaca
|
||||
format):
|
||||
|
||||
```json
|
||||
@@ -120,12 +120,6 @@ axolotl train my_training.yml
|
||||
|
||||
## Common Tasks {#sec-common-tasks}
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
The same yaml file is used for training, inference, and merging.
|
||||
|
||||
:::
|
||||
|
||||
### Testing Your Model {#sec-testing}
|
||||
|
||||
After training, test your model:
|
||||
@@ -134,16 +128,6 @@ After training, test your model:
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
```
|
||||
|
||||
More details can be found in [Inference](inference.qmd).
|
||||
|
||||
### Using a UI {#sec-ui}
|
||||
|
||||
Launch a Gradio interface:
|
||||
|
||||
```bash
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
### Preprocessing Data {#sec-preprocessing}
|
||||
|
||||
For large datasets, preprocess first:
|
||||
@@ -152,22 +136,14 @@ For large datasets, preprocess first:
|
||||
axolotl preprocess my_training.yml
|
||||
```
|
||||
|
||||
Please make sure to set `dataset_prepared_path: ` in your config to set the path to save the prepared dataset.
|
||||
### Using a UI {#sec-ui}
|
||||
|
||||
More details can be found in [Dataset Preprocessing](dataset_preprocessing.qmd).
|
||||
|
||||
### Merging LoRA weights {#sec-merging-lora}
|
||||
|
||||
To merge the LoRA weights back into the base model, run:
|
||||
Launch a Gradio interface:
|
||||
|
||||
```bash
|
||||
axolotl merge-lora my_training.yml --lora-model-dir="./outputs/lora-out"
|
||||
axolotl inference my_training.yml --lora-model-dir="./outputs/lora-out" --gradio
|
||||
```
|
||||
|
||||
The merged model will be saved in the `{output_dir}/merged` directory.
|
||||
|
||||
More details can be found in [Merging LoRA weights](inference.qmd#sec-merging).
|
||||
|
||||
## Next Steps {#sec-next-steps}
|
||||
|
||||
Now that you have the basics, you might want to:
|
||||
@@ -180,7 +156,6 @@ Now that you have the basics, you might want to:
|
||||
Check our other guides for details on these topics:
|
||||
|
||||
- [Configuration Guide](config.qmd) - Full configuration options
|
||||
- [Dataset Loading](dataset_loading.qmd) - Loading datasets from various sources
|
||||
- [Dataset Formats](dataset-formats) - Working with different data formats
|
||||
- [Multi-GPU Training](multi-gpu.qmd)
|
||||
- [Multi-Node Training](multi-node.qmd)
|
||||
|
||||
@@ -15,7 +15,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
||||
|
||||
- NVIDIA GPU (Ampere architecture or newer for `bf16` and Flash Attention) or AMD GPU
|
||||
- Python ≥3.10
|
||||
- PyTorch ≥2.5.1
|
||||
- PyTorch ≥2.4.1
|
||||
|
||||
## Installation Methods {#sec-installation-methods}
|
||||
|
||||
@@ -25,10 +25,6 @@ Please make sure to have Pytorch installed before installing Axolotl in your loc
|
||||
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use Pytorch 2.7.0 and CUDA 12.8.
|
||||
:::
|
||||
|
||||
### PyPI Installation (Recommended) {#sec-pypi}
|
||||
|
||||
```{.bash}
|
||||
@@ -41,40 +37,6 @@ installed) in order not to clobber it, and so that we set the correct version of
|
||||
dependencies that are specific to the PyTorch version or other installed
|
||||
co-dependencies.
|
||||
|
||||
### uv Installation {#sec-uv}
|
||||
|
||||
uv is a fast, reliable Python package installer and resolver built in Rust. It offers significant performance improvements over pip and provides better dependency resolution, making it an excellent choice for complex environments.
|
||||
|
||||
Install uv if not already installed
|
||||
```{.bash}
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
source $HOME/.local/bin/env
|
||||
```
|
||||
|
||||
Choose your CUDA version to use with PyTorch; e.g. `cu124`, `cu126`, `cu128`,
|
||||
then create the venv and activate
|
||||
```{.bash}
|
||||
export UV_TORCH_BACKEND=cu126
|
||||
uv venv --no-project --relocatable
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
Install PyTorch
|
||||
- PyTorch 2.6.0 recommended
|
||||
```{.bash}
|
||||
uv pip install packaging setuptools wheel
|
||||
uv pip install torch==2.6.0
|
||||
uv pip install awscli pydantic
|
||||
```
|
||||
|
||||
Install axolotl from PyPi
|
||||
```{.bash}
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn]
|
||||
|
||||
# optionally install with vLLM if you're using torch==2.6.0 and want to train w/ GRPO
|
||||
uv pip install --no-build-isolation axolotl[deepspeed,flash-attn,vllm]
|
||||
```
|
||||
|
||||
### Edge/Development Build {#sec-edge-build}
|
||||
|
||||
For the latest features between releases:
|
||||
@@ -110,10 +72,6 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
||||
```
|
||||
:::
|
||||
|
||||
::: {.callout-important}
|
||||
For Blackwell GPUs, please use `axolotlai/axolotl:main-py3.11-cu128-2.7.0` or the cloud variant `axolotlai/axolotl-cloud:main-py3.11-cu128-2.7.0`.
|
||||
:::
|
||||
|
||||
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||
|
||||
## Cloud Environments {#sec-cloud}
|
||||
|
||||
@@ -84,10 +84,6 @@ lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Currently, LoRA kernels are not supported for RLHF training, only SFT.
|
||||
:::
|
||||
|
||||
## Requirements
|
||||
|
||||
- One or more NVIDIA or AMD GPUs (in order to use the Triton kernels)
|
||||
|
||||
@@ -87,7 +87,20 @@ We support sequence parallelism (SP) via the
|
||||
allows one to split up sequences across GPUs, which is useful in the event that a
|
||||
single sequence causes OOM errors during model training.
|
||||
|
||||
See our [dedicated guide](sequence_parallelism.qmd) for more information.
|
||||
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
|
||||
or from source with `pip install .[ring-flash-attn]`.
|
||||
|
||||
Your Axolotl YAML config should contain the following lines:
|
||||
|
||||
```{.yaml}
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # Required with sequence parallelism
|
||||
|
||||
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
|
||||
heads_k_stride: 1
|
||||
```
|
||||
|
||||
See our [dedicated guide](sequence_parallelism.qmd) for more details.
|
||||
|
||||
### FSDP + QLoRA {#sec-fsdp-qlora}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ datasets:
|
||||
# leave the vision model and vision tower frozen
|
||||
# load_in_8bit: true
|
||||
adapter: lora
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
# (optional) if you want to resize images to a set size
|
||||
image_size: 512
|
||||
@@ -164,7 +164,7 @@ Here is an example of a multi-modal dataset:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
|
||||
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
|
||||
{"type": "text", "text": "Describe this image in detail."}
|
||||
]
|
||||
},
|
||||
|
||||
32
docs/qat.qmd
32
docs/qat.qmd
@@ -1,32 +0,0 @@
|
||||
---
|
||||
title: "Quantization Aware Training (QAT)"
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-expand: 2
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of models which are quantized
|
||||
by applying "fake" quantizations to the model's weights (and optionally, activations) during training. This fake
|
||||
quantization allows for the model to adjust for noise introduced by the quantization, so when the model is eventually
|
||||
quantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide
|
||||
support for QAT and post-training quantization (PTQ) in axolotl.
|
||||
|
||||
We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model),
|
||||
and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat), for more details.
|
||||
|
||||
## Configuring QAT in Axolotl
|
||||
|
||||
To enable QAT in axolotl, add the following to your configuration file:
|
||||
|
||||
```yaml
|
||||
qat:
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
|
||||
```
|
||||
|
||||
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize`](./quantize.qmd) command to do this.
|
||||
@@ -1,53 +0,0 @@
|
||||
---
|
||||
title: "Quantization with torchao"
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-expand: 2
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
Quantization is a technique to lower the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library. Quantization is supported for both post-training quantization (PTQ) and quantization-aware training (QAT).
|
||||
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
We do not currently support quantization techniques such as GGUF/GPTQ,EXL2 at the moment.
|
||||
|
||||
:::
|
||||
|
||||
## Configuring Quantization in Axolotl
|
||||
|
||||
Quantization is configured using the `quantization` key in your configuration file.
|
||||
|
||||
```yaml
|
||||
base_model: # The path to the model to quantize.
|
||||
quantization:
|
||||
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
|
||||
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
|
||||
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
|
||||
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
|
||||
|
||||
output_dir: # The path to the output directory.
|
||||
```
|
||||
|
||||
Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
|
||||
|
||||
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which
|
||||
you used to train the model:
|
||||
|
||||
```yaml
|
||||
# qat.yml
|
||||
qat:
|
||||
activation_dtype: int8
|
||||
weight_dtype: int8
|
||||
group_size: 256
|
||||
quantize_embedding: true
|
||||
|
||||
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
|
||||
```
|
||||
|
||||
```bash
|
||||
axolotl quantize qat.yml
|
||||
```
|
||||
|
||||
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.
|
||||
@@ -16,8 +16,7 @@ feedback. Various methods include, but not limited to:
|
||||
- [Identity Preference Optimization (IPO)](#ipo)
|
||||
- [Kahneman-Tversky Optimization (KTO)](#kto)
|
||||
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
|
||||
- [Group Relative Policy Optimization (GRPO)](#grpo)
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl, if you're interested in contributing, please reach out!)
|
||||
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
||||
|
||||
|
||||
## RLHF using Axolotl
|
||||
@@ -500,10 +499,12 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
### GRPO
|
||||
|
||||
::: {.callout-tip}
|
||||
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/grpo_code).
|
||||
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
||||
:::
|
||||
|
||||
In the latest GRPO implementation, `vLLM` is used to significantly speedup trajectory generation during training. In this example, we're using 4 GPUs - 2 for training, and 2 for vLLM:
|
||||
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
|
||||
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
|
||||
using 4 GPUs - 2 for training, and 2 for vLLM:
|
||||
|
||||
::: {.callout-important}
|
||||
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
|
||||
@@ -538,10 +539,6 @@ Your `vLLM` instance will now attempt to spin up, and it's time to kick off trai
|
||||
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
|
||||
```
|
||||
|
||||
::: {.callout-note}
|
||||
Due to TRL's implementation with vLLM, the vLLM instance must use the last N GPUs instead of the first N GPUs. This is why in the example above, we use `CUDA_VISIBLE_DEVICES=2,3` for the vLLM instance.
|
||||
:::
|
||||
|
||||
#### Reward functions
|
||||
|
||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||
@@ -583,20 +580,7 @@ datasets:
|
||||
|
||||
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
||||
|
||||
To see all configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/v0.9.2/src/axolotl/utils/schemas/trl.py).
|
||||
|
||||
#### GRPO with DAPO/Dr. GRPO loss
|
||||
|
||||
The DAPO paper and subsequently Dr. GRPO paper proposed an alternative loss function for GRPO to remediate the penalty in longer responses.
|
||||
|
||||
```yaml
|
||||
trl:
|
||||
loss_type: dr_grpo
|
||||
# Normalizes loss based on max completion length (default: 256)
|
||||
max_completion_length:
|
||||
```
|
||||
|
||||
For more information, see [GRPO docs](https://huggingface.co/docs/trl/v0.17.0/en/grpo_trainer#loss-types).
|
||||
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
||||
|
||||
### SimPO
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ title: Sequence Parallelism
|
||||
description: Train with long sequences split across multiple GPUs.
|
||||
---
|
||||
|
||||
# Sequence Parallelism
|
||||
|
||||
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
||||
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
||||
GPU processes a different portion of the sequence, and the results are aggregated
|
||||
@@ -25,7 +27,7 @@ To enable sequence parallelism, add the following to your configuration file:
|
||||
sequence_parallel_degree: 4 # Split sequences across 4 GPUs
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
# Optional; one of "varlen_llama3", "batch_ring", "batch_zigzag", "batch_stripe". Defaults to
|
||||
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
|
||||
ring_attn_func:
|
||||
```
|
||||
@@ -41,7 +43,7 @@ When sequence parallelism is enabled:
|
||||
|
||||
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
||||
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
||||
3. Position IDs are adjusted to maintain proper relative positions
|
||||
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
|
||||
4. The trainer uses special ring communication patterns for attention operations
|
||||
|
||||
## Requirements
|
||||
@@ -67,11 +69,9 @@ sequence_len: 8192
|
||||
...
|
||||
|
||||
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||
flash_attention: true # Required with sequence parallelism
|
||||
# Optional; strides across the key dimension. Larger values use more memory but should make training faster.
|
||||
heads_k_stride: 1
|
||||
# Optional; one of "varlen_llama3" or "batch_ring". Defaults to
|
||||
# "varlen_llama3" when `sample_packing: true`, and "batch_ring" otherwise.
|
||||
ring_attn_func:
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
@@ -28,7 +28,7 @@ pad_to_sequence_len: true
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -30,7 +30,7 @@ pad_to_sequence_len: false
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -29,7 +29,7 @@ pad_to_sequence_len: false
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
base_model: meta-llama/Llama-3.2-3B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
datasets:
|
||||
- path: yahma/alpaca-cleaned
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/qat_out/
|
||||
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
sequence_len: 512
|
||||
|
||||
flex_attention: true
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
qat:
|
||||
activation_dtype: int8
|
||||
weight_dtype: int4
|
||||
group_size: 32
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 16
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
|
||||
cosine_constant_lr_ratio: 0
|
||||
cosine_min_lr_ratio: 1.0
|
||||
learning_rate: 2e-5
|
||||
save_only_model: true
|
||||
bf16: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_steps: 10
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
@@ -5,10 +5,6 @@ tokenizer_type: AutoTokenizer
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
special_tokens:
|
||||
pad_token: <|finetune_right_pad_id|>
|
||||
eos_token: <|eot_id|>
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ base_model: NousResearch/Llama-3.2-1B
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
@@ -38,7 +38,6 @@ wandb_log_model:
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
|
||||
optimizer: adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
base_model: neuralmagic/Sparse-Llama-3.1-8B-2of4
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.llm_compressor.LLMCompressorPlugin
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
eval_sample_packing: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: paged_adamw_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 100
|
||||
evals_per_epoch: 2
|
||||
eval_table_size:
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: <|end_of_text|>
|
||||
|
||||
llmcompressor:
|
||||
recipe:
|
||||
finetuning_stage:
|
||||
finetuning_modifiers:
|
||||
ConstantPruningModifier:
|
||||
targets: [
|
||||
're:.*q_proj.weight',
|
||||
're:.*k_proj.weight',
|
||||
're:.*v_proj.weight',
|
||||
're:.*o_proj.weight',
|
||||
're:.*gate_proj.weight',
|
||||
're:.*up_proj.weight',
|
||||
're:.*down_proj.weight',
|
||||
]
|
||||
start: 0
|
||||
save_compressed: true
|
||||
@@ -34,5 +34,3 @@ We provide a script to delinearize Llama 4 linearized models into regular Huggin
|
||||
```bash
|
||||
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
|
||||
```
|
||||
|
||||
Note: This only works with the non-quantized linearized model. If you have an adapter, merge it with the *non-quantized linearized* model before delinearizing.
|
||||
|
||||
@@ -25,7 +25,7 @@ pad_to_sequence_len: false
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
# Finetune Magistral Small with Axolotl
|
||||
|
||||
Magistral Small is a 24B parameter opensource model from MistralAI found on [HuggingFace](https://huggingface.co/mistralai/Magistral-Small-2506). This guide shows how to fine-tune it with Axolotl with multi-turn conversations with proper masking.
|
||||
|
||||
MistralAI has also released a proprietary medium-sized version called Magistral Medium.
|
||||
|
||||
Thanks to the team at MistralAI for giving us early access to prepare for this release.
|
||||
|
||||
## Getting started
|
||||
|
||||
1. Install Axolotl following the [installation guide](https://docs.axolotl.ai/docs/installation.html). You need to install from main as Magistral is only on nightly or use our latest [Docker images](https://docs.axolotl.ai/docs/docker.html).
|
||||
|
||||
Here is an example of how to install from main for pip:
|
||||
|
||||
```bash
|
||||
# Ensure you have Pytorch installed (Pytorch 2.6.0 recommended)
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,mistral]'
|
||||
```
|
||||
|
||||
2. Download the example config:
|
||||
|
||||
```bash
|
||||
axolotl fetch examples
|
||||
```
|
||||
|
||||
3. Run the finetuning example:
|
||||
|
||||
```bash
|
||||
axolotl train examples/magistral/magistral-small-qlora.yaml
|
||||
```
|
||||
|
||||
This config uses about 24GB VRAM.
|
||||
|
||||
Let us know how it goes. Happy finetuning! 🚀
|
||||
|
||||
### TIPS
|
||||
|
||||
- For inference, the official MistralAI team recommends `top_p: 0.95` and `temperature: 0.7` with `max_tokens: 40960`.
|
||||
- You can run a full finetuning by removing the `adapter: qlora` and `load_in_4bit: true` from the config.
|
||||
- Read more on how to load your own dataset at [docs](https://docs.axolotl.ai/docs/dataset_loading.html).
|
||||
- The dataset format is the OpenAI Messages format as seen [here](https://docs.axolotl.ai/docs/dataset-formats/conversation.html#chat_template).
|
||||
|
||||
## Optimization Guides
|
||||
|
||||
- [Multi-GPU Training](https://docs.axolotl.ai/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://docs.axolotl.ai/docs/multi-node.html)
|
||||
- [LoRA Optimizations](https://docs.axolotl.ai/docs/lora_optims.html)
|
||||
|
||||
## Limitations
|
||||
|
||||
We only support the `mistral-common` tokenizer for Supervised Fine-tuning at the moment and for `type: chat_template` only.
|
||||
|
||||
The tokenizer does not work with `dataset.map` with multiprocessing, so we had to disable it. In addition, we do not support overriding tokens yet.
|
||||
|
||||
## Related Resources
|
||||
|
||||
- [MistralAI Magistral Blog](https://mistral.ai/news/magistral/)
|
||||
- [Axolotl Docs](https://docs.axolotl.ai)
|
||||
- [Axolotl Website](https://axolotl.ai)
|
||||
- [Axolotl GitHub](https://github.com/axolotl-ai-cloud/axolotl)
|
||||
- [Axolotl Discord](https://discord.gg/7m9sfhzaf3)
|
||||
|
||||
|
||||
## Future Work
|
||||
|
||||
- Add parity to Preference Tuning, RL, Multi-modal, etc.
|
||||
- Add parity to other tokenizer configs like overriding tokens.
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: mistralai/Magistral-Small-2506
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing:
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_transformer_layer_cls_to_wrap: MistralDecoderLayer
|
||||
fsdp_activation_checkpointing: true
|
||||
@@ -1,63 +0,0 @@
|
||||
base_model: mistralai/Magistral-Small-2506
|
||||
|
||||
# Enable to use mistral-common tokenizer
|
||||
tokenizer_use_mistral_common: true
|
||||
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
@@ -27,7 +27,7 @@ pad_to_sequence_len: false
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -1,341 +0,0 @@
|
||||
# Finetuning LLMs to output audio
|
||||
|
||||
In this example, we finetune Orpcanopylabs/orpheus-tts-0.1-pretrained (a LLaMA 3.2 3b model) to output audio.
|
||||
|
||||
The `finetune.yml` withe current settings will run on any Nvidia GPU with 45GB VRAM or more. If you adjust the batch size it can easily run on any GPU under 24GB.
|
||||
|
||||
## Dataset pre-processing for pre-training
|
||||
If you are adding another voice in English, please jump ahead to finetuning pre-processing.
|
||||
|
||||
For this to work, we need to preprocess our dataset. Since we are expecting to output audio, we will need to add tokens to the tokenizer.
|
||||
|
||||
Using this code, it will download the SNAC model and add the correct tokens and upload the final dataset.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from snac import SNAC
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from datasets import load_dataset
|
||||
import random
|
||||
import torchaudio.transforms as T
|
||||
from transformers import AutoTokenizer
|
||||
import os
|
||||
|
||||
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
|
||||
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
|
||||
|
||||
dsn = my_original_dataset_name
|
||||
|
||||
snapshot_download(
|
||||
repo_id=dsn,
|
||||
repo_type="dataset",
|
||||
revision="main",
|
||||
max_workers=64,
|
||||
)
|
||||
|
||||
|
||||
ds = load_dataset(dsn, split="train")
|
||||
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
|
||||
|
||||
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
|
||||
model = model.to("mps")
|
||||
|
||||
def tokenise_audio(waveform):
|
||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||
waveform = waveform.to(dtype=torch.float32)
|
||||
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
|
||||
waveform = resample_transform(waveform)
|
||||
|
||||
waveform = waveform.unsqueeze(0).to("cuda")
|
||||
|
||||
#generate the codes from snac
|
||||
with torch.inference_mode():
|
||||
codes = model.encode(waveform)
|
||||
|
||||
all_codes = []
|
||||
for i in range(codes[0].shape[1]):
|
||||
all_codes.append(codes[0][0][i].item()+128266)
|
||||
all_codes.append(codes[1][0][2*i].item()+128266+4096)
|
||||
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
|
||||
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
|
||||
|
||||
|
||||
return all_codes
|
||||
|
||||
def add_codes(example):
|
||||
# Always initialize codes_list to None
|
||||
codes_list = None
|
||||
|
||||
try:
|
||||
answer_audio = example.get("audio")
|
||||
# If there's a valid audio array, tokenise it
|
||||
if answer_audio and "array" in answer_audio:
|
||||
audio_array = answer_audio["array"]
|
||||
codes_list = tokenise_audio(audio_array)
|
||||
except Exception as e:
|
||||
print(f"Skipping row due to error: {e}")
|
||||
# Keep codes_list as None if we fail
|
||||
example["codes_list"] = codes_list
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(add_codes, remove_columns=["audio"])
|
||||
|
||||
#@title Load Tokenizer
|
||||
tokeniser_length = 128256
|
||||
start_of_text = 128000
|
||||
end_of_text = 128009
|
||||
|
||||
start_of_speech = tokeniser_length + 1
|
||||
end_of_speech = tokeniser_length + 2
|
||||
|
||||
start_of_human = tokeniser_length + 3
|
||||
end_of_human = tokeniser_length + 4
|
||||
|
||||
start_of_ai = tokeniser_length + 5
|
||||
end_of_ai = tokeniser_length + 6
|
||||
pad_token = tokeniser_length + 7
|
||||
|
||||
audio_tokens_start = tokeniser_length + 10
|
||||
|
||||
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
num_proc = os.cpu_count() - 2
|
||||
|
||||
ds = ds.filter(lambda x: x["codes_list"] is not None)
|
||||
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
|
||||
|
||||
#@title Create Input Ids
|
||||
def remove_duplicate_frames(example):
|
||||
vals = example["codes_list"]
|
||||
if len(vals) % 7 != 0:
|
||||
raise ValueError("Input list length must be divisible by 7")
|
||||
|
||||
result = vals[:7]
|
||||
|
||||
removed_frames = 0
|
||||
|
||||
for i in range(7, len(vals), 7):
|
||||
current_first = vals[i]
|
||||
previous_first = result[-7]
|
||||
|
||||
if current_first != previous_first:
|
||||
result.extend(vals[i:i+7])
|
||||
else:
|
||||
removed_frames += 1
|
||||
|
||||
example["codes_list"] = result
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
|
||||
|
||||
|
||||
def create_input_ids(example):
|
||||
text_ids = tokenizer.encode({example['text']}, add_special_tokens=True)
|
||||
text_ids.append(end_of_text)
|
||||
example["text_tokens"] = text_ids
|
||||
input_ids = (
|
||||
[start_of_human]
|
||||
+ example["text_tokens"]
|
||||
+ [end_of_human]
|
||||
+ [start_of_ai]
|
||||
+ [start_of_speech]
|
||||
+ example["codes_list"]
|
||||
+ [end_of_speech]
|
||||
+ [end_of_ai]
|
||||
)
|
||||
example["input_ids"] = input_ids
|
||||
example["labels"] = input_ids
|
||||
example["attention_mask"] = [1] * len(input_ids)
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
|
||||
|
||||
#@title Remove unnecessary columns
|
||||
columns_to_keep = ["input_ids", "labels", "attention_mask"]
|
||||
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
|
||||
|
||||
ds = ds.remove_columns(columns_to_remove)
|
||||
|
||||
ds.push_to_hub(name_to_push_dataset_to)
|
||||
```
|
||||
|
||||
|
||||
## Finetune pre-processing
|
||||
Use this code to add a new voice.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from snac import SNAC
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from datasets import load_dataset
|
||||
import random
|
||||
import torchaudio.transforms as T
|
||||
from transformers import AutoTokenizer
|
||||
import os
|
||||
|
||||
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
|
||||
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
|
||||
|
||||
dsn = my_original_dataset_name
|
||||
|
||||
snapshot_download(
|
||||
repo_id=dsn,
|
||||
repo_type="dataset",
|
||||
revision="main",
|
||||
max_workers=64,
|
||||
)
|
||||
|
||||
|
||||
ds = load_dataset(dsn, split="train")
|
||||
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
|
||||
|
||||
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
|
||||
model = model.to("mps")
|
||||
|
||||
def tokenise_audio(waveform):
|
||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||
waveform = waveform.to(dtype=torch.float32)
|
||||
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
|
||||
waveform = resample_transform(waveform)
|
||||
|
||||
waveform = waveform.unsqueeze(0).to("cuda")
|
||||
|
||||
#generate the codes from snac
|
||||
with torch.inference_mode():
|
||||
codes = model.encode(waveform)
|
||||
|
||||
all_codes = []
|
||||
for i in range(codes[0].shape[1]):
|
||||
all_codes.append(codes[0][0][i].item()+128266)
|
||||
all_codes.append(codes[1][0][2*i].item()+128266+4096)
|
||||
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
|
||||
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
|
||||
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
|
||||
|
||||
|
||||
return all_codes
|
||||
|
||||
def add_codes(example):
|
||||
# Always initialize codes_list to None
|
||||
codes_list = None
|
||||
|
||||
try:
|
||||
answer_audio = example.get("audio")
|
||||
# If there's a valid audio array, tokenise it
|
||||
if answer_audio and "array" in answer_audio:
|
||||
audio_array = answer_audio["array"]
|
||||
codes_list = tokenise_audio(audio_array)
|
||||
except Exception as e:
|
||||
print(f"Skipping row due to error: {e}")
|
||||
# Keep codes_list as None if we fail
|
||||
example["codes_list"] = codes_list
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(add_codes, remove_columns=["audio"])
|
||||
|
||||
#@title Load Tokenizer
|
||||
tokeniser_length = 128256
|
||||
start_of_text = 128000
|
||||
end_of_text = 128009
|
||||
|
||||
start_of_speech = tokeniser_length + 1
|
||||
end_of_speech = tokeniser_length + 2
|
||||
|
||||
start_of_human = tokeniser_length + 3
|
||||
end_of_human = tokeniser_length + 4
|
||||
|
||||
start_of_ai = tokeniser_length + 5
|
||||
end_of_ai = tokeniser_length + 6
|
||||
pad_token = tokeniser_length + 7
|
||||
|
||||
audio_tokens_start = tokeniser_length + 10
|
||||
|
||||
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
num_proc = os.cpu_count() - 2
|
||||
|
||||
ds = ds.filter(lambda x: x["codes_list"] is not None)
|
||||
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
|
||||
|
||||
#@title Create Input Ids
|
||||
def remove_duplicate_frames(example):
|
||||
vals = example["codes_list"]
|
||||
if len(vals) % 7 != 0:
|
||||
raise ValueError("Input list length must be divisible by 7")
|
||||
|
||||
result = vals[:7]
|
||||
|
||||
removed_frames = 0
|
||||
|
||||
for i in range(7, len(vals), 7):
|
||||
current_first = vals[i]
|
||||
previous_first = result[-7]
|
||||
|
||||
if current_first != previous_first:
|
||||
result.extend(vals[i:i+7])
|
||||
else:
|
||||
removed_frames += 1
|
||||
|
||||
example["codes_list"] = result
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
|
||||
|
||||
tok_info = '''*** HERE you can modify the text prompt
|
||||
i.e. if you wanted a multispeaker model like canopylabs/orpheus-3b-0.1-ft, you can pass:
|
||||
f"{example["source"]}: {example["text"]}", as is passed.
|
||||
'''
|
||||
print(tok_info)
|
||||
|
||||
def create_input_ids(example):
|
||||
text_ids = tokenizer.encode(f"{example['speaker_id']}: {example['text']}", add_special_tokens=True)
|
||||
text_ids.append(end_of_text)
|
||||
example["text_tokens"] = text_ids
|
||||
input_ids = (
|
||||
[start_of_human]
|
||||
+ example["text_tokens"]
|
||||
+ [end_of_human]
|
||||
+ [start_of_ai]
|
||||
+ [start_of_speech]
|
||||
+ example["codes_list"]
|
||||
+ [end_of_speech]
|
||||
+ [end_of_ai]
|
||||
)
|
||||
example["input_ids"] = input_ids
|
||||
example["labels"] = input_ids
|
||||
example["attention_mask"] = [1] * len(input_ids)
|
||||
|
||||
return example
|
||||
|
||||
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
|
||||
|
||||
#@title Remove unnecessary columns
|
||||
columns_to_keep = ["input_ids", "labels", "attention_mask"]
|
||||
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
|
||||
|
||||
ds = ds.remove_columns(columns_to_remove)
|
||||
|
||||
ds.push_to_hub(name_to_push_dataset_to)
|
||||
```
|
||||
|
||||
## Training
|
||||
After preprocessing is done, fill out the blanks in finetune.yml and simply run `axolotl train finetune.yml`
|
||||
|
||||
## Inference
|
||||
For inference, please refer to the original [orpheus github](https://github.com/canopyai/Orpheus-TTS/tree/main).
|
||||
@@ -1,52 +0,0 @@
|
||||
base_model: canopylabs/orpheus-3b-0.1-pretrained
|
||||
|
||||
hub_model_id: <your-hub-model-id>
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
datasets:
|
||||
- path: <your-hf-dataset-id>
|
||||
type: # leave empty to load pre-tokenized
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 8
|
||||
micro_batch_size: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: auto
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 20
|
||||
evals_per_epoch: 5
|
||||
saves_per_epoch: 5
|
||||
weight_decay: 0.05
|
||||
|
||||
special_tokens:
|
||||
pad_token: <custom_token_7>
|
||||
@@ -25,7 +25,7 @@ pad_to_sequence_len: false
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -25,7 +25,7 @@ pad_to_sequence_len: false
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
|
||||
@@ -2,6 +2,7 @@ base_model: Qwen/Qwen2.5-0.5B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
|
||||
chat_template: qwen_25
|
||||
rl: dpo
|
||||
datasets:
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
base_model: Qwen/Qwen3-32B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
strict: false
|
||||
|
||||
chat_template: qwen3
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_property_mappings:
|
||||
role: from
|
||||
content: value
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
dataset_prepared_path: last_run_prepared
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
load_in_4bit: true
|
||||
adapter: qlora
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- k_proj
|
||||
- v_proj
|
||||
- o_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
lora_mlp_kernel: true
|
||||
lora_qkv_kernel: true
|
||||
lora_o_kernel: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_4bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: offload
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
special_tokens:
|
||||
@@ -1,78 +0,0 @@
|
||||
base_model: Qwen/Qwen3-8B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
|
||||
output_dir: ./outputs/qat_out/
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
flex_attention: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
flex_attn_compile_kwargs:
|
||||
dynamic: false
|
||||
mode: max-autotune-no-cudagraphs
|
||||
|
||||
qat:
|
||||
activation_dtype: int8
|
||||
weight_dtype: int4
|
||||
group_size: 256
|
||||
fake_quant_after_n_steps: 1000
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
max_steps: 2000
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 2e-5
|
||||
|
||||
bf16: true
|
||||
tf32: true
|
||||
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
|
||||
evals_per_epoch: 1
|
||||
saves_per_epoch: 1
|
||||
|
||||
warmup_steps: 10
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_activation_checkpointing: true
|
||||
|
||||
special_tokens:
|
||||
@@ -1,68 +0,0 @@
|
||||
base_model: Qwen/Qwen3-8B
|
||||
# Automatically upload checkpoint and final model to HF
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: tatsu-lab/alpaca
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 64
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch_fused
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
bf16: auto
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
resume_from_checkpoint:
|
||||
logging_steps: 1
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
saves_per_epoch: 1
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
- full_shard
|
||||
- auto_wrap
|
||||
fsdp_config:
|
||||
fsdp_limit_all_gathers: true
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_offload_params: true
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
special_tokens:
|
||||
@@ -6,20 +6,20 @@ triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.10
|
||||
liger-kernel==0.5.8
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub==0.32.2
|
||||
peft==0.15.2
|
||||
transformers==4.52.3
|
||||
peft==0.15.1
|
||||
transformers==4.51.3
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.7.0
|
||||
datasets==3.6.0
|
||||
deepspeed>=0.17.0
|
||||
trl==0.18.1
|
||||
hf_xet==1.1.2
|
||||
accelerate==1.6.0
|
||||
datasets==3.5.0
|
||||
deepspeed>=0.15.4
|
||||
trl==0.16.1
|
||||
hf_xet==1.0.0
|
||||
hqq==0.2.5
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
@@ -62,10 +62,8 @@ langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.10.0
|
||||
torchao==0.9.0
|
||||
schedulefree==1.4.1
|
||||
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
mistral-common==1.6.0
|
||||
|
||||
@@ -9,8 +9,6 @@ except ImportError as exc:
|
||||
raise ImportError("Install torch via `pip install torch`") from exc
|
||||
from packaging.version import Version as V
|
||||
|
||||
USE_UV = "--uv" in sys.argv[1:]
|
||||
|
||||
v = V(torch.__version__)
|
||||
|
||||
# no cut-cross-entropy support for torch < 2.4.0
|
||||
@@ -25,9 +23,7 @@ if cce_spec:
|
||||
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
|
||||
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
||||
|
||||
UV_PREFIX = "uv " if USE_UV else ""
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
|
||||
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"'
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
||||
@@@@ @@@@@@@@@@@@@@@@
|
||||
|
||||
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
|
||||
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
|
||||
|
||||
```
|
||||
cd /workspace
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
# noqa
|
||||
# pylint: skip-file
|
||||
import sys
|
||||
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("Install torch via `pip install torch`")
|
||||
from packaging.version import Version as V
|
||||
|
||||
use_uv = "--uv" in sys.argv[1:]
|
||||
|
||||
v = V(torch.__version__)
|
||||
cuda = str(torch.version.cuda)
|
||||
try:
|
||||
@@ -35,7 +31,6 @@ elif v < V("2.6.0"):
|
||||
else:
|
||||
raise RuntimeError(f"Torch = {v} too new!")
|
||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||
uv_prefix = "uv " if use_uv else ""
|
||||
print(
|
||||
f'{uv_prefix}pip install unsloth-zoo==2024.12.1 && {uv_prefix}pip install --no-deps "unsloth[{x}]==2024.12.4"'
|
||||
f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"'
|
||||
)
|
||||
|
||||
10
setup.py
10
setup.py
@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
|
||||
if (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.3"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append(
|
||||
"xformers==0.0.29.post2"
|
||||
) # vllm needs post2 w torch 2.6
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.3"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
@@ -118,7 +118,7 @@ extras_require = {
|
||||
"yunchang==0.6.0",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.17.0",
|
||||
"deepspeed==0.15.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
@@ -142,7 +142,6 @@ extras_require = {
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
"came_pytorch==0.1.3",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
@@ -150,9 +149,6 @@ extras_require = {
|
||||
"vllm": [
|
||||
"vllm==0.7.2",
|
||||
],
|
||||
"llmcompressor": [
|
||||
"llmcompressor==0.5.1",
|
||||
],
|
||||
}
|
||||
|
||||
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||
|
||||
@@ -4,4 +4,4 @@ import pkgutil
|
||||
|
||||
__path__ = pkgutil.extend_path(__path__, __name__) # Make this a namespace package
|
||||
|
||||
__version__ = "0.10.0"
|
||||
__version__ = "0.8.0"
|
||||
|
||||
@@ -2,7 +2,4 @@
|
||||
|
||||
import os
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
configure_logging()
|
||||
|
||||
@@ -28,6 +28,7 @@ class TrainerCliArgs:
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=0)
|
||||
merge_lora: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
main_process_port: Optional[int] = field(default=None)
|
||||
@@ -81,32 +82,6 @@ class VllmServeCliArgs:
|
||||
"hardware support this feature."
|
||||
},
|
||||
)
|
||||
serve_module: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Module to serve. If not set, the default module will be used."
|
||||
},
|
||||
)
|
||||
|
||||
enable_reasoning: Optional[bool] = field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
reasoning_parser: Optional[str] = field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizeCliArgs:
|
||||
"""Dataclass with CLI arguments for `axolotl quantize` command."""
|
||||
|
||||
base_model: Optional[str] = field(default=None)
|
||||
weight_dtype: Optional[str] = field(default=None)
|
||||
activation_dtype: Optional[str] = field(default=None)
|
||||
quantize_embedding: Optional[bool] = field(default=None)
|
||||
group_size: Optional[int] = field(default=None)
|
||||
output_dir: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -16,15 +16,8 @@ AXOLOTL_LOGO = """
|
||||
@@@@ @@@@@@@@@@@@@@@@
|
||||
"""
|
||||
|
||||
HAS_PRINTED_LOGO = False
|
||||
|
||||
|
||||
def print_axolotl_text_art():
|
||||
"""Prints axolotl ASCII art."""
|
||||
|
||||
global HAS_PRINTED_LOGO # pylint: disable=global-statement
|
||||
if HAS_PRINTED_LOGO:
|
||||
return
|
||||
if is_main_process():
|
||||
HAS_PRINTED_LOGO = True
|
||||
print(AXOLOTL_LOGO)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Various checks for Axolotl CLI."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -7,9 +8,10 @@ from accelerate.commands.config import config_args
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.logging_config import configure_logging
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
configure_logging()
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_accelerate_default_config() -> None:
|
||||
|
||||
@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
|
||||
return res
|
||||
|
||||
def get_image(self):
|
||||
docker_tag = "main-py3.11-cu124-2.6.0"
|
||||
docker_tag = "main-py3.11-cu124-2.5.1"
|
||||
if self.config.docker_tag:
|
||||
docker_tag = self.config.docker_tag
|
||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Configuration loading and processing."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -21,12 +21,11 @@ from axolotl.utils.config import (
|
||||
validate_config,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__, use_environ=True)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
|
||||
@@ -119,12 +118,12 @@ def choose_config(path: Path) -> str:
|
||||
)
|
||||
|
||||
if len(yaml_files) == 1:
|
||||
LOG.info(f"Using default YAML file '{yaml_files[0]}'")
|
||||
print(f"Using default YAML file '{yaml_files[0]}'")
|
||||
return str(yaml_files[0])
|
||||
|
||||
LOG.info("Choose a YAML file:")
|
||||
print("Choose a YAML file:")
|
||||
for idx, file in enumerate(yaml_files):
|
||||
LOG.info(f"{idx + 1}. {file}")
|
||||
print(f"{idx + 1}. {file}")
|
||||
|
||||
chosen_file = None
|
||||
while chosen_file is None:
|
||||
@@ -133,9 +132,9 @@ def choose_config(path: Path) -> str:
|
||||
if 1 <= choice <= len(yaml_files):
|
||||
chosen_file = str(yaml_files[choice - 1])
|
||||
else:
|
||||
LOG.info("Invalid choice. Please choose a number from the list.")
|
||||
print("Invalid choice. Please choose a number from the list.")
|
||||
except ValueError:
|
||||
LOG.info("Invalid input. Please enter a number.")
|
||||
print("Invalid input. Please enter a number.")
|
||||
|
||||
return chosen_file
|
||||
|
||||
@@ -153,15 +152,7 @@ def prepare_plugins(cfg: DictDefault):
|
||||
plugin_manager.register(plugin_name)
|
||||
|
||||
|
||||
def plugin_set_cfg(cfg: DictDefault):
|
||||
if cfg.get("plugins"):
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.cfg = cfg
|
||||
|
||||
|
||||
def load_cfg(
|
||||
config: str | Path | DictDefault = Path("examples/"), **kwargs
|
||||
) -> DictDefault:
|
||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
|
||||
"""
|
||||
Loads the `axolotl` configuration stored at `config`, validates it, and performs
|
||||
various setup.
|
||||
@@ -173,24 +164,13 @@ def load_cfg(
|
||||
Returns:
|
||||
`DictDefault` mapping configuration keys to values.
|
||||
"""
|
||||
if isinstance(config, (str, Path)):
|
||||
config = check_remote_config(config)
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(Path(config))
|
||||
config = check_remote_config(config)
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(Path(config))
|
||||
|
||||
# Load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
|
||||
cfg.axolotl_config_path = config
|
||||
else:
|
||||
cfg = config
|
||||
with NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||
) as temp_file:
|
||||
temp_file.write(yaml.dump(config.to_dict()))
|
||||
temp_file.close()
|
||||
cfg.axolotl_config_path = temp_file.name
|
||||
# Load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
@@ -204,6 +184,8 @@ def load_cfg(
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
cfg.axolotl_config_path = config
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||
@@ -231,6 +213,5 @@ def load_cfg(
|
||||
setup_wandb_env_vars(cfg)
|
||||
setup_mlflow_env_vars(cfg)
|
||||
setup_comet_env_vars(cfg)
|
||||
plugin_set_cfg(cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""CLI to run evaluation on a model."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -14,11 +14,9 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.evaluate import evaluate
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
@@ -31,14 +29,10 @@ def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
check_user_token()
|
||||
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""CLI to run inference on a trained model."""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
@@ -21,9 +22,8 @@ from axolotl.utils.chat_templates import (
|
||||
get_chat_template_from_config,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_multi_line_input() -> str:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
import tempfile
|
||||
@@ -16,7 +17,6 @@ import axolotl
|
||||
from axolotl.cli.args import (
|
||||
EvaluateCliArgs,
|
||||
PreprocessCliArgs,
|
||||
QuantizeCliArgs,
|
||||
TrainerCliArgs,
|
||||
VllmServeCliArgs,
|
||||
)
|
||||
@@ -28,13 +28,11 @@ from axolotl.cli.utils import (
|
||||
fetch_from_github,
|
||||
filter_none_kwargs,
|
||||
)
|
||||
from axolotl.cli.vllm_serve import do_vllm_serve
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
|
||||
@@ -58,8 +56,6 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
|
||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||
config options.
|
||||
"""
|
||||
patch_optimized_env()
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_preprocess
|
||||
|
||||
@@ -105,7 +101,7 @@ def train(
|
||||
config options.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
if "use_ray" in kwargs and kwargs["use_ray"]:
|
||||
accelerate = False
|
||||
@@ -179,7 +175,7 @@ def train(
|
||||
|
||||
do_cli(config=cfg_file, **kwargs)
|
||||
except subprocess.CalledProcessError as exc:
|
||||
LOG.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
|
||||
if not sweep:
|
||||
raise exc
|
||||
|
||||
@@ -331,21 +327,9 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
||||
@add_options_from_dataclass(VllmServeCliArgs)
|
||||
@filter_none_kwargs
|
||||
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
|
||||
from axolotl.cli.vllm_serve import do_vllm_serve
|
||||
|
||||
do_vllm_serve(config, cli_args)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@add_options_from_dataclass(QuantizeCliArgs)
|
||||
@filter_none_kwargs
|
||||
def quantize(config: str, **cli_args: QuantizeCliArgs):
|
||||
from axolotl.cli.quantize import do_quantize
|
||||
|
||||
do_quantize(config, cli_args)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("model", type=click.Path(exists=True, path_type=str))
|
||||
@click.argument("output", type=click.Path(exists=False, path_type=str))
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
"""CLI to merge a trained LoRA into a base model."""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_merge_lora(*, cfg: DictDefault) -> None:
|
||||
@@ -66,6 +68,12 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||
Raises:
|
||||
ValueError: If target directory for LoRA merged model does not exist.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
parser = transformers.HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.merge_lora = True
|
||||
|
||||
parsed_cfg = load_cfg(
|
||||
config,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
@@ -10,6 +11,7 @@ import fire
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
|
||||
import transformers
|
||||
from accelerate.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
@@ -22,11 +24,11 @@ from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file as safe_save_file
|
||||
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
|
||||
@@ -195,6 +197,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parser = transformers.HfArgumentParser(TrainerCliArgs)
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.merge_lora = True
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""CLI to run preprocessing of a dataset."""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
@@ -17,12 +18,10 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
@@ -48,10 +47,7 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
with disable_datasets_caching():
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
if plugin_manager.load_datasets(cfg, preprocess=True):
|
||||
pass
|
||||
elif cfg.rl:
|
||||
if cfg.rl:
|
||||
load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
"""
|
||||
CLI to post-training quantize a model using torchao
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from axolotl.cli.art import print_axolotl_text_art
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def do_quantize(
|
||||
config: Union[Path, str],
|
||||
cli_args: dict,
|
||||
):
|
||||
"""
|
||||
Quantizes a model's model's weights
|
||||
|
||||
Args:
|
||||
config (Union[Path, str]): The path to the config file
|
||||
cli_args (dict): Additional command-line arguments
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
|
||||
cfg = load_cfg(config)
|
||||
|
||||
if cfg.qat and cfg.quantization:
|
||||
raise ValueError(
|
||||
"QAT and quantization cannot be used together. Please specify only one of qat or quantization in your config file."
|
||||
)
|
||||
|
||||
if cfg.qat:
|
||||
quantize_cfg = cfg.qat
|
||||
elif cfg.quantization:
|
||||
quantize_cfg = cfg.quantization
|
||||
else:
|
||||
raise ValueError(
|
||||
"No quantization configuration found. Please specify either qat or quantization in your config file."
|
||||
)
|
||||
|
||||
model_path = cli_args.get("model_path") or cfg.output_dir
|
||||
if weight_dtype := cli_args.get("weight_dtype"):
|
||||
weight_dtype = TorchIntDType[weight_dtype]
|
||||
else:
|
||||
weight_dtype = quantize_cfg.weight_dtype
|
||||
if activation_dtype := cli_args.get("activation_dtype"):
|
||||
activation_dtype = TorchIntDType[activation_dtype]
|
||||
else:
|
||||
activation_dtype = quantize_cfg.activation_dtype
|
||||
group_size = cli_args.get("group_size") or quantize_cfg.group_size
|
||||
quantize_embedding = (
|
||||
cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
|
||||
)
|
||||
output_dir = cli_args.get("output_dir") or cfg.output_dir
|
||||
|
||||
LOG.info(f"Loading model from {model_path}...")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
|
||||
|
||||
LOG.info(
|
||||
f"Quantizing model with configuration: \n"
|
||||
f"\tweight_dtype: {weight_dtype}\n"
|
||||
f"\tactivation_dtype: {activation_dtype}\n"
|
||||
f"\tgroup_size: {group_size}\n"
|
||||
f"\tquantize_embedding: {quantize_embedding}"
|
||||
)
|
||||
|
||||
quantize_model_for_ptq(
|
||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
||||
)
|
||||
|
||||
LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
|
||||
model.save_pretrained(
|
||||
str(Path(output_dir) / "quantized"),
|
||||
safe_serialization=False,
|
||||
progressbar=True,
|
||||
)
|
||||
tokenizer.save_pretrained(
|
||||
str(Path(output_dir) / "quantized"),
|
||||
safe_serialization=False,
|
||||
progressbar=True,
|
||||
)
|
||||
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")
|
||||
@@ -1,6 +1,6 @@
|
||||
"""CLI to run training on a model."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
@@ -17,10 +17,12 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.train import train
|
||||
from axolotl.utils import patch_optimized_env
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
"""
|
||||
@@ -33,27 +35,21 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
cli_args: Training-specific CLI arguments.
|
||||
"""
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
patch_optimized_env()
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
||||
if not dataset_meta:
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
del model, tokenizer, trainer
|
||||
|
||||
gc.collect()
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import concurrent.futures
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import NoneType
|
||||
@@ -19,12 +20,12 @@ from transformers import (
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
configure_logging()
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def strip_optional_type(field_type: type | str | None):
|
||||
@@ -319,8 +320,7 @@ def load_model_and_tokenizer(
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
LOG.info("loading model...")
|
||||
model_loader = ModelLoader(cfg, tokenizer, inference=inference)
|
||||
model, _ = model_loader.load()
|
||||
model, _ = load_model(cfg, tokenizer, inference=inference)
|
||||
|
||||
processor = None
|
||||
if cfg.is_multimodal:
|
||||
|
||||
@@ -2,27 +2,15 @@
|
||||
CLI to start the vllm server for online RL
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import trl
|
||||
from trl.scripts.vllm_serve import ScriptArguments
|
||||
from trl.scripts.vllm_serve import main as vllm_serve_main
|
||||
|
||||
from axolotl.cli.config import load_cfg
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlScriptArguments(ScriptArguments):
|
||||
"""
|
||||
Additional arguments for the VLLM server
|
||||
"""
|
||||
|
||||
reasoning_parser: str = field(default="", kw_only=True)
|
||||
enable_reasoning: bool | None = field(default=None, kw_only=True)
|
||||
|
||||
|
||||
def do_vllm_serve(
|
||||
config: Union[Path, str],
|
||||
cli_args: dict,
|
||||
@@ -37,13 +25,9 @@ def do_vllm_serve(
|
||||
Returns:
|
||||
process_id: the process id of the started VLLM server
|
||||
"""
|
||||
patch_vllm_worker()
|
||||
cfg = load_cfg(config)
|
||||
model = cfg.base_model
|
||||
|
||||
serve_module = cli_args.get("serve_module", "trl.scripts.vllm_serve")
|
||||
vllm_serve_main = getattr(__import__(serve_module, fromlist=["main"]), "main")
|
||||
|
||||
tensor_parallel_size = (
|
||||
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||
)
|
||||
@@ -57,16 +41,9 @@ def do_vllm_serve(
|
||||
enable_prefix_caching = (
|
||||
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
|
||||
)
|
||||
reasoning_parser = (
|
||||
cli_args.get("reasoning_parser") or cfg.vllm.reasoning_parser or ""
|
||||
)
|
||||
enable_reasoning = (
|
||||
cli_args.get("enable_reasoning") or cfg.vllm.enable_reasoning or False
|
||||
)
|
||||
|
||||
# pylint: disable=unexpected-keyword-arg
|
||||
vllm_script_args = AxolotlScriptArguments(
|
||||
model=model,
|
||||
vllm_script_args = ScriptArguments(
|
||||
model,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -74,67 +51,5 @@ def do_vllm_serve(
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
reasoning_parser=reasoning_parser,
|
||||
enable_reasoning=enable_reasoning,
|
||||
)
|
||||
vllm_serve_main(vllm_script_args)
|
||||
|
||||
|
||||
def patch_vllm_worker():
|
||||
from multiprocessing.connection import Connection
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
def llm_worker(
|
||||
script_args: AxolotlScriptArguments,
|
||||
data_parallel_rank: int,
|
||||
master_port: int,
|
||||
connection: Connection,
|
||||
) -> None:
|
||||
# Set required environment variables for DP to work with vLLM
|
||||
os.environ["VLLM_DP_RANK"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank)
|
||||
os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size)
|
||||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||||
|
||||
llm = LLM(
|
||||
model=script_args.model,
|
||||
revision=script_args.revision,
|
||||
tensor_parallel_size=script_args.tensor_parallel_size,
|
||||
gpu_memory_utilization=script_args.gpu_memory_utilization,
|
||||
enforce_eager=script_args.enforce_eager,
|
||||
dtype=script_args.dtype,
|
||||
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
||||
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
||||
# This is particularly useful here because we generate completions from the same prompts.
|
||||
enable_prefix_caching=script_args.enable_prefix_caching,
|
||||
kv_cache_dtype=script_args.kv_cache_dtype,
|
||||
max_model_len=script_args.max_model_len,
|
||||
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
|
||||
enable_reasoning=script_args.enable_reasoning,
|
||||
reasoning_parser=script_args.reasoning_parser,
|
||||
)
|
||||
|
||||
# Send ready signal to parent process
|
||||
connection.send({"status": "ready"})
|
||||
|
||||
while True:
|
||||
# Wait for commands from the parent process
|
||||
try:
|
||||
command = connection.recv()
|
||||
except KeyboardInterrupt:
|
||||
llm.collective_rpc(method="close_communicator")
|
||||
break
|
||||
|
||||
# Handle commands
|
||||
if command["type"] in ["call", "fire_and_forget"]:
|
||||
method_name = command["method"]
|
||||
args, kwargs = command.get("args", ()), command.get("kwargs", {})
|
||||
method = getattr(llm, method_name)
|
||||
result = method(*args, **kwargs)
|
||||
if command["type"] == "call":
|
||||
connection.send(result)
|
||||
elif command["type"] == "shutdown":
|
||||
break
|
||||
|
||||
trl.scripts.vllm_serve.llm_worker = llm_worker
|
||||
|
||||
@@ -11,6 +11,5 @@ MOE_ARCH_BLOCK = {
|
||||
],
|
||||
"mixtral": "MixtralSparseMoeBlock",
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Various shared constants"""
|
||||
"""
|
||||
Various shared constants
|
||||
"""
|
||||
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
"""Dataset loading utilities."""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
from datasets import Dataset
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
from axolotl.utils.models import load_processor, load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -28,7 +29,16 @@ class TrainDatasetMeta:
|
||||
|
||||
|
||||
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||
"""Randomly sample `num_samples` samples with replacement from `dataset`."""
|
||||
"""
|
||||
Randomly sample `num_samples` samples from `dataset`.
|
||||
|
||||
Args:
|
||||
dataset: Dataset.
|
||||
num_samples: Number of samples to return.
|
||||
|
||||
Returns:
|
||||
Random sample (with replacement) of examples in `dataset`.
|
||||
"""
|
||||
return dataset.select(
|
||||
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
|
||||
)
|
||||
@@ -37,27 +47,29 @@ def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||
def load_datasets(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
|
||||
debug: bool = False,
|
||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||
) -> TrainDatasetMeta:
|
||||
"""Loads one or more training or evaluation datasets, calling
|
||||
`axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.
|
||||
"""
|
||||
Loads one or more training or evaluation datasets, calling
|
||||
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
cli_args: Command-specific CLI arguments.
|
||||
debug: Whether to print out tokenization of sample. This is duplicated in
|
||||
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.
|
||||
|
||||
Returns:
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
`total_num_steps`.
|
||||
`total_num_steps`.
|
||||
"""
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||
preprocess_iterable = getattr(cli_args, "iterable", False)
|
||||
preprocess_iterable = (
|
||||
hasattr(cli_args, "iterable")
|
||||
and cli_args.iterable is not None
|
||||
and cli_args.iterable
|
||||
)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
cfg,
|
||||
tokenizer,
|
||||
processor=processor,
|
||||
@@ -65,22 +77,19 @@ def load_datasets(
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.debug
|
||||
or getattr(cli_args, "debug", False)
|
||||
or getattr(cli_args, "debug_text_only", False)
|
||||
or getattr(cli_args, "debug_num_examples", 0) > 0
|
||||
or debug
|
||||
cli_args.debug
|
||||
or cfg.debug
|
||||
or cli_args.debug_text_only
|
||||
or int(cli_args.debug_num_examples) > 0
|
||||
):
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||
text_only = cli_args.debug_text_only if cli_args else False
|
||||
train_samples = sample_dataset(train_dataset, num_examples)
|
||||
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
||||
check_dataset_labels(
|
||||
train_samples,
|
||||
tokenizer,
|
||||
num_examples=num_examples,
|
||||
text_only=text_only,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
)
|
||||
|
||||
LOG.info("printing prompters...")
|
||||
@@ -95,10 +104,13 @@ def load_datasets(
|
||||
|
||||
|
||||
def load_preference_datasets(
|
||||
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs | None = None
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
|
||||
) -> TrainDatasetMeta:
|
||||
"""Loads one or more training or evaluation datasets for RL training using paired
|
||||
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
|
||||
"""
|
||||
Loads one or more training or evaluation datasets for RL training using paired
|
||||
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
|
||||
Optionally, logs out debug information.
|
||||
|
||||
Args:
|
||||
@@ -109,28 +121,23 @@ def load_preference_datasets(
|
||||
Dataclass with fields for training and evaluation datasets and the computed
|
||||
`total_num_steps`.
|
||||
"""
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)
|
||||
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
|
||||
total_num_steps: Optional[int] = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
if cfg.rl == "grpo":
|
||||
total_num_steps = None
|
||||
|
||||
total_num_steps: int | None = None
|
||||
if cfg.rl is not RLType.GRPO:
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
|
||||
if (cli_args and cli_args.debug) or cfg.debug:
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
|
||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||
text_only = cli_args.debug_text_only if cli_args else False
|
||||
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
train_samples = sample_dataset(train_dataset, num_examples)
|
||||
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
|
||||
check_dataset_labels(
|
||||
dataset=train_samples,
|
||||
tokenizer=tokenizer,
|
||||
num_examples=num_examples,
|
||||
text_only=text_only,
|
||||
train_samples,
|
||||
tokenizer,
|
||||
num_examples=cli_args.debug_num_examples,
|
||||
text_only=cli_args.debug_text_only,
|
||||
rl_mode=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Trainer builder classes"""
|
||||
|
||||
from .causal import HFCausalTrainerBuilder
|
||||
from .rl import HFRLTrainerBuilder
|
||||
|
||||
__all__ = ["HFCausalTrainerBuilder", "HFRLTrainerBuilder"]
|
||||
@@ -1,508 +0,0 @@
|
||||
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Base class for trainer builder"""
|
||||
|
||||
import abc
|
||||
import importlib
|
||||
import logging
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
TrainerCallback,
|
||||
)
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.trainer.lr import patch_trainer_get_lr
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
GCCallback,
|
||||
GPUStatsCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
)
|
||||
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
with suppress(ImportError):
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""Base class for trainer builder."""
|
||||
|
||||
def __init__(self, cfg, model, tokenizer, processor=None):
|
||||
self.cfg = cfg
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.processor = processor
|
||||
|
||||
self._train_dataset = None
|
||||
self._eval_dataset = None
|
||||
self._model_ref = None
|
||||
self._peft_config = None
|
||||
|
||||
# If the model supports tagging, add the axolotl tag.
|
||||
# This makes sure the tag is correctly pushed even if a user calls
|
||||
# model.push_to_hub instead of trainer.push_to_hub.
|
||||
if hasattr(model, "add_model_tags"):
|
||||
model.add_model_tags(["axolotl"])
|
||||
|
||||
patch_trainer_get_lr()
|
||||
|
||||
@property
|
||||
def model_ref(self):
|
||||
return self._model_ref
|
||||
|
||||
@model_ref.setter
|
||||
def model_ref(self, model):
|
||||
self._model_ref = model
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self._train_dataset
|
||||
|
||||
@train_dataset.setter
|
||||
def train_dataset(self, dataset):
|
||||
self._train_dataset = dataset
|
||||
|
||||
@property
|
||||
def eval_dataset(self):
|
||||
return self._eval_dataset
|
||||
|
||||
@eval_dataset.setter
|
||||
def eval_dataset(self, dataset):
|
||||
self._eval_dataset = dataset
|
||||
|
||||
@property
|
||||
def peft_config(self):
|
||||
return self._peft_config
|
||||
|
||||
@peft_config.setter
|
||||
def peft_config(self, peft_config):
|
||||
self._peft_config = peft_config
|
||||
|
||||
@abstractmethod
|
||||
def build(self, total_num_steps):
|
||||
pass
|
||||
|
||||
def get_callbacks(self) -> list[TrainerCallback]:
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.profiler_steps:
|
||||
callbacks.append(
|
||||
PytorchProfilerCallback(
|
||||
steps_to_profile=self.cfg.profiler_steps,
|
||||
)
|
||||
)
|
||||
|
||||
if self.cfg.gc_steps:
|
||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
|
||||
callbacks.extend(
|
||||
[
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
|
||||
]
|
||||
)
|
||||
if self.cfg.use_comet and is_comet_available():
|
||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
callbacks.append(GPUStatsCallback(cfg=self.cfg))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
"""
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
callbacks = []
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
[
|
||||
cb
|
||||
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||
self.cfg, trainer
|
||||
)
|
||||
if cb
|
||||
]
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
# TODO
|
||||
return training_arguments_kwargs
|
||||
|
||||
def hook_post_create_training_args(self, training_arguments):
|
||||
# TODO
|
||||
return training_arguments
|
||||
|
||||
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
|
||||
# TODO
|
||||
return trainer_kwargs, trainer_cls
|
||||
|
||||
def hook_post_create_trainer(self, trainer):
|
||||
# TODO
|
||||
return trainer
|
||||
|
||||
def _configure_warmup_and_logging(
|
||||
self, total_num_steps: int, training_args_kwargs: dict
|
||||
):
|
||||
warmup_steps = 0
|
||||
warmup_ratio = 0.0
|
||||
if self.cfg.warmup_steps:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio:
|
||||
if total_num_steps:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
warmup_ratio = self.cfg.warmup_ratio
|
||||
elif total_num_steps:
|
||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||
else:
|
||||
warmup_ratio = 0.03
|
||||
|
||||
if warmup_steps == 1:
|
||||
warmup_steps = 2
|
||||
|
||||
if self.cfg.logging_steps is not None:
|
||||
training_args_kwargs["logging_steps"] = self.cfg.logging_steps
|
||||
else:
|
||||
training_args_kwargs["logging_steps"] = (
|
||||
500 # transformers defaults to 500
|
||||
if not total_num_steps
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
training_args_kwargs["warmup_ratio"] = warmup_ratio
|
||||
training_args_kwargs["warmup_steps"] = warmup_steps
|
||||
|
||||
def _configure_precision_settings(self, training_args_kwargs: dict):
|
||||
training_args_kwargs["fp16"] = (self.cfg.fp16 and not self.cfg.bf16) or False
|
||||
training_args_kwargs["tf32"] = self.cfg.tf32
|
||||
if self.cfg.bf16 == "full":
|
||||
training_args_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
||||
|
||||
def _configure_scheduler(self, training_args_kwargs: dict):
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
|
||||
training_args_kwargs["lr_scheduler_type"] = "cosine"
|
||||
training_args_kwargs["alternate_lr_scheduler_type"] = self.cfg.lr_scheduler
|
||||
else:
|
||||
training_args_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
||||
)
|
||||
training_args_kwargs["lr_scheduler_kwargs"] = (
|
||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||
)
|
||||
|
||||
def _configure_optimizer(self, training_args_kwargs: dict, trainer_kwargs: dict):
|
||||
def _configure_custom_optimizer(
|
||||
training_args_kwargs: dict, trainer_kwargs: dict
|
||||
):
|
||||
# Common optimizer kwargs
|
||||
optimizer_kwargs = {
|
||||
"lr": training_args_kwargs["learning_rate"],
|
||||
"weight_decay": training_args_kwargs["weight_decay"],
|
||||
}
|
||||
|
||||
# Adam-specific kwargs
|
||||
adam_kwargs: dict = {}
|
||||
if training_args_kwargs.get("adam_beta1") and training_args_kwargs.get(
|
||||
"adam_beta2"
|
||||
):
|
||||
adam_kwargs["betas"] = (
|
||||
training_args_kwargs.get("adam_beta1"),
|
||||
training_args_kwargs.get("adam_beta2"),
|
||||
)
|
||||
if training_args_kwargs.get("adam_epsilon"):
|
||||
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
|
||||
|
||||
if self.cfg.optimizer == "muon":
|
||||
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
|
||||
optimizer_cls = MuonOptimizerFactory
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "optimi_adamw":
|
||||
from optimi import AdamW
|
||||
|
||||
optimizer_kwargs["foreach"] = False
|
||||
optimizer_cls = AdamW
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||
# TODO remove 20250401
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||
|
||||
optimizer_cls = AdamW4bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
LOG.warning(
|
||||
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||
)
|
||||
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||
|
||||
optimizer_cls = AdamW8bit
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||
|
||||
optimizer_cls = AdamWFp8
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "adopt_adamw":
|
||||
from axolotl.utils.optimizers.adopt import ADOPT
|
||||
|
||||
optimizer_cls = ADOPT
|
||||
adam_kwargs["decouple"] = True
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "came_pytorch":
|
||||
from came_pytorch import CAME
|
||||
|
||||
optimizer_cls = CAME
|
||||
|
||||
beta1 = training_args_kwargs.get("adam_beta1", 0.9)
|
||||
beta2 = training_args_kwargs.get("adam_beta2", 0.999)
|
||||
beta3 = training_args_kwargs.get("adam_beta3", 0.9999)
|
||||
eps1 = training_args_kwargs.get("adam_epsilon", 1e-30)
|
||||
eps2 = training_args_kwargs.get("adam_epsilon2", 1e-16)
|
||||
adam_kwargs["betas"] = (beta1, beta2, beta3)
|
||||
adam_kwargs["eps"] = (eps1, eps2)
|
||||
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unhandled optimizer: {self.cfg.optimizer}. Please raise an Issue."
|
||||
)
|
||||
|
||||
# Parse any additional optimizer args from config
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optimizer_kwargs.update(self.cfg.optim_args)
|
||||
else:
|
||||
# Parse string format "key1=value1,key2=value2"
|
||||
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
|
||||
key, value = mapping.split("=")
|
||||
optimizer_kwargs[key] = value
|
||||
|
||||
# Note: This is not used in training_args_kwargs, but in trainer_kwargs
|
||||
trainer_kwargs["optimizer_cls_and_kwargs"] = (
|
||||
optimizer_cls,
|
||||
optimizer_kwargs,
|
||||
)
|
||||
|
||||
# Handle custom optimizer
|
||||
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
||||
if self.cfg.optimizer in custom_supported_optimizers:
|
||||
_configure_custom_optimizer(training_args_kwargs, trainer_kwargs)
|
||||
else:
|
||||
# Use transformers' optimizer
|
||||
training_args_kwargs["optim"] = self.cfg.optimizer
|
||||
|
||||
# Parse any additional optimizer args from config
|
||||
if self.cfg.optim_args:
|
||||
if isinstance(self.cfg.optim_args, dict):
|
||||
optim_args = ",".join(
|
||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||
)
|
||||
else:
|
||||
optim_args = self.cfg.optim_args
|
||||
training_args_kwargs["optim_args"] = optim_args
|
||||
|
||||
if (
|
||||
self.cfg.optimizer == "adamw_anyprecision"
|
||||
and Path(self.cfg.torchdistx_path).exists()
|
||||
):
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
def _configure_hub_parameters(self, training_args_kwargs: dict):
|
||||
if self.cfg.hub_model_id:
|
||||
training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||
training_args_kwargs["push_to_hub"] = True
|
||||
training_args_kwargs["hub_private_repo"] = True
|
||||
training_args_kwargs["hub_always_push"] = True
|
||||
|
||||
if self.cfg.hub_strategy:
|
||||
training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
def _configure_save_and_eval_strategy(self, training_args_kwargs: dict):
|
||||
# save_strategy and save_steps
|
||||
if self.cfg.save_steps:
|
||||
training_args_kwargs["save_strategy"] = "steps"
|
||||
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["save_total_limit"] = (
|
||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||
)
|
||||
|
||||
# eval_strategy and eval_steps
|
||||
if not self.eval_dataset and self.cfg.val_set_size == 0:
|
||||
# do not eval if no eval_dataset and val_set_size=0
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
training_args_kwargs["eval_on_start"] = True
|
||||
elif self.cfg.eval_strategy:
|
||||
training_args_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
training_args_kwargs["eval_on_start"] = True
|
||||
|
||||
def _configure_reporting(self, training_args_kwargs: dict):
|
||||
report_to = []
|
||||
if self.cfg.use_wandb:
|
||||
report_to.append("wandb")
|
||||
if self.cfg.use_mlflow:
|
||||
report_to.append("mlflow")
|
||||
if self.cfg.use_tensorboard:
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
|
||||
training_args_kwargs["report_to"] = report_to
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
elif self.cfg.use_mlflow:
|
||||
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
else:
|
||||
training_args_kwargs["run_name"] = None
|
||||
|
||||
def _configure_torch_compile(self, training_args_kwargs: dict):
|
||||
if self.cfg.torch_compile and getattr(torch, "_dynamo", None):
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
training_args_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_args_kwargs["torch_compile_backend"] = (
|
||||
self.cfg.torch_compile_backend
|
||||
)
|
||||
if self.cfg.torch_compile_mode:
|
||||
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
|
||||
|
||||
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||
training_args_kwargs["gradient_checkpointing_kwargs"] = (
|
||||
self.cfg.gradient_checkpointing_kwargs
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
||||
"use_reentrant": False
|
||||
}
|
||||
|
||||
def _set_base_training_args(
|
||||
self, total_num_steps
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
training_args_kwargs: dict[str, Any] = {}
|
||||
trainer_kwargs: dict[str, Any] = {}
|
||||
|
||||
self._configure_warmup_and_logging(total_num_steps, training_args_kwargs)
|
||||
self._configure_precision_settings(training_args_kwargs)
|
||||
self._configure_save_and_eval_strategy(training_args_kwargs)
|
||||
self._configure_gradient_checkpointing(training_args_kwargs)
|
||||
|
||||
# set arg into trainer_args_kwargs with same name if value not None
|
||||
for arg in [
|
||||
# optim/scheduler
|
||||
"adam_beta1",
|
||||
"adam_beta2",
|
||||
"adam_beta3",
|
||||
"adam_epsilon",
|
||||
"adam_epsilon2",
|
||||
"cosine_min_lr_ratio",
|
||||
"cosine_constant_lr_ratio",
|
||||
"optim_target_modules",
|
||||
# trainer
|
||||
"max_grad_norm",
|
||||
"dataloader_num_workers",
|
||||
"dataloader_pin_memory",
|
||||
"dataloader_prefetch_factor",
|
||||
"gradient_accumulation_steps",
|
||||
"learning_rate",
|
||||
"embedding_lr",
|
||||
"embedding_lr_scale",
|
||||
"lr_groups",
|
||||
"loraplus_lr_ratio",
|
||||
"loraplus_lr_embedding",
|
||||
"output_dir",
|
||||
"save_safetensors",
|
||||
"save_only_model",
|
||||
"include_tokens_per_second",
|
||||
"weight_decay",
|
||||
"seed",
|
||||
]:
|
||||
if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
|
||||
training_args_kwargs[arg] = getattr(self.cfg, arg)
|
||||
|
||||
training_args_kwargs["per_device_train_batch_size"] = self.cfg.micro_batch_size
|
||||
|
||||
if self.cfg.eval_batch_size:
|
||||
training_args_kwargs["per_device_eval_batch_size"] = (
|
||||
self.cfg.eval_batch_size
|
||||
)
|
||||
|
||||
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
|
||||
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
|
||||
if self.cfg.dataset_processes:
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
# max_length is not used in CausalTrainer
|
||||
if self.cfg.reward_model or self.cfg.rl:
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
self._configure_reporting(training_args_kwargs)
|
||||
self._configure_hub_parameters(training_args_kwargs)
|
||||
self._configure_scheduler(training_args_kwargs)
|
||||
self._configure_optimizer(training_args_kwargs, trainer_kwargs)
|
||||
self._configure_torch_compile(training_args_kwargs)
|
||||
|
||||
return training_args_kwargs, trainer_kwargs
|
||||
@@ -1,488 +0,0 @@
|
||||
"""Builder for causal trainers"""
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Type, Union
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
DataCollatorWithFlattening,
|
||||
EarlyStoppingCallback,
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||
|
||||
from axolotl.core.builders.base import TrainerBuilderBase
|
||||
from axolotl.core.trainers import (
|
||||
AxolotlMambaTrainer,
|
||||
AxolotlPRMTrainer,
|
||||
AxolotlRewardTrainer,
|
||||
AxolotlTrainer,
|
||||
ReLoRATrainer,
|
||||
)
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||
from axolotl.processing_strategies import get_processing_strategy
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
LossWatchDogCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
colab_inference_post_train_callback,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
MambaDataCollator,
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Build the HuggingFace training args/trainer for causal models and reward modeling
|
||||
using TRL.
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
callbacks.append(ReLoRACallback(self.cfg))
|
||||
|
||||
if (
|
||||
hasattr(self.model, "use_bettertransformer")
|
||||
and self.model.use_bettertransformer is True
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback())
|
||||
|
||||
# TODO: check if can move to base class
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
if self.cfg.qat:
|
||||
callbacks.append(QATCallback(self.cfg.qat))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "wandb"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
if (
|
||||
self.cfg.use_mlflow
|
||||
and is_mlflow_available()
|
||||
and self.cfg.eval_table_size > 0
|
||||
):
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "mlflow"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "comet_ml"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||
if self.cfg.do_causal_lm_eval:
|
||||
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
|
||||
trainer, self.tokenizer
|
||||
)
|
||||
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
|
||||
|
||||
if self.cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
self.cfg.early_stopping_patience,
|
||||
)
|
||||
callbacks.append(early_stop_cb)
|
||||
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
callbacks.append(lisa_callback_factory(trainer))
|
||||
|
||||
if any("COLAB_" in key for key in os.environ):
|
||||
ColabCallback = colab_inference_post_train_callback(trainer)
|
||||
callbacks.append(ColabCallback(self.cfg))
|
||||
|
||||
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
"""
|
||||
Gets the trainer class for the given configuration.
|
||||
"""
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
if trainer_cls:
|
||||
return trainer_cls
|
||||
if self.cfg.relora_steps:
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
if self.cfg.reward_model:
|
||||
return AxolotlRewardTrainer
|
||||
if self.cfg.process_reward_model:
|
||||
return AxolotlPRMTrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlPRMConfig,
|
||||
AxolotlRewardConfig,
|
||||
AxolotlTrainingArguments,
|
||||
)
|
||||
|
||||
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||
total_num_steps
|
||||
)
|
||||
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = {
|
||||
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
|
||||
}
|
||||
|
||||
if self.cfg.adapter == "qlora":
|
||||
training_arguments_kwargs["qlora"] = True
|
||||
|
||||
# deepspeed
|
||||
if self.cfg.deepspeed:
|
||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||
|
||||
if self.cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs["lr_quadratic_warmup"] = (
|
||||
self.cfg.lr_quadratic_warmup
|
||||
)
|
||||
|
||||
if self.cfg.dataloader_drop_last is not None:
|
||||
training_arguments_kwargs["dataloader_drop_last"] = (
|
||||
self.cfg.dataloader_drop_last
|
||||
)
|
||||
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
||||
training_arguments_kwargs["dataloader_drop_last"] = True
|
||||
|
||||
if self.cfg.remove_unused_columns is not None:
|
||||
training_arguments_kwargs["remove_unused_columns"] = (
|
||||
self.cfg.remove_unused_columns
|
||||
)
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||
if self.cfg.bench_dataset:
|
||||
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
||||
if self.cfg.do_causal_lm_eval:
|
||||
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
|
||||
if self.cfg.metric_for_best_model:
|
||||
training_arguments_kwargs["metric_for_best_model"] = (
|
||||
self.cfg.metric_for_best_model
|
||||
)
|
||||
if self.cfg.greater_is_better:
|
||||
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||
|
||||
# DDP Config
|
||||
if self.cfg.ddp_timeout:
|
||||
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
|
||||
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
||||
if self.cfg.ddp_bucket_cap_mb:
|
||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
||||
if self.cfg.ddp_broadcast_buffers is not None:
|
||||
training_arguments_kwargs["ddp_broadcast_buffers"] = (
|
||||
self.cfg.ddp_broadcast_buffers
|
||||
)
|
||||
|
||||
# these are all the "standard" kwargs that are def used
|
||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||
|
||||
if self.cfg.auto_find_batch_size is not None:
|
||||
training_arguments_kwargs["auto_find_batch_size"] = (
|
||||
self.cfg.auto_find_batch_size
|
||||
)
|
||||
|
||||
training_arguments_kwargs["eval_accumulation_steps"] = (
|
||||
self.cfg.gradient_accumulation_steps
|
||||
)
|
||||
|
||||
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||
(
|
||||
self.cfg.load_best_model_at_end is not False
|
||||
or self.cfg.early_stopping_patience
|
||||
)
|
||||
and (
|
||||
(not self.cfg.test_datasets and self.cfg.val_set_size > 0)
|
||||
or (self.cfg.test_datasets and self.cfg.val_set_size == 0)
|
||||
)
|
||||
and self.cfg.save_steps
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
|
||||
# handle ddp
|
||||
ddp_find_unused_parameters = None
|
||||
if self.cfg.ddp:
|
||||
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
ddp_find_unused_parameters
|
||||
)
|
||||
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||
|
||||
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
|
||||
training_arguments_kwargs["multipack_real_batches"] = (
|
||||
self.cfg.multipack_real_batches
|
||||
if self.cfg.multipack_real_batches is not None
|
||||
else not self.cfg.flash_attention
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||
self.cfg.eval_sample_packing
|
||||
)
|
||||
if self.cfg.sample_packing_bin_size is not None:
|
||||
training_arguments_kwargs["sample_packing_bin_size"] = (
|
||||
self.cfg.sample_packing_bin_size
|
||||
)
|
||||
if self.cfg.sample_packing_group_size is not None:
|
||||
training_arguments_kwargs["sample_packing_group_size"] = (
|
||||
self.cfg.sample_packing_group_size
|
||||
)
|
||||
if self.cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs["sample_packing_efficiency"] = (
|
||||
self.cfg.sample_packing_eff_est
|
||||
)
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = (
|
||||
self.cfg.relora_warmup_steps
|
||||
)
|
||||
if self.cfg.relora_anneal_steps:
|
||||
training_arguments_kwargs["relora_anneal_steps"] = (
|
||||
self.cfg.relora_anneal_steps
|
||||
)
|
||||
if self.cfg.relora_prune_ratio:
|
||||
training_arguments_kwargs["relora_prune_ratio"] = (
|
||||
self.cfg.relora_prune_ratio
|
||||
)
|
||||
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
||||
training_arguments_kwargs["lisa_step_interval"] = (
|
||||
self.cfg.lisa_step_interval
|
||||
)
|
||||
training_arguments_kwargs["lisa_layers_attribute"] = (
|
||||
self.cfg.lisa_layers_attribute
|
||||
)
|
||||
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
)
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||
if self.cfg.chat_template:
|
||||
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
|
||||
cfg=self.cfg,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
if self.cfg.neftune_noise_alpha is not None:
|
||||
training_arguments_kwargs["neftune_noise_alpha"] = (
|
||||
self.cfg.neftune_noise_alpha
|
||||
)
|
||||
|
||||
if self.cfg.accelerator_config:
|
||||
training_arguments_kwargs["accelerator_config"] = (
|
||||
self.cfg.accelerator_config
|
||||
)
|
||||
|
||||
if self.cfg.image_size:
|
||||
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
||||
if self.cfg.image_resize_algorithm:
|
||||
training_arguments_kwargs["image_resize_algorithm"] = (
|
||||
self.cfg.image_resize_algorithm
|
||||
)
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||
if plugin_training_args:
|
||||
training_arguments_kwargs.update(plugin_training_args)
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
elif self.cfg.process_reward_model:
|
||||
training_args_cls = AxolotlPRMConfig
|
||||
else:
|
||||
training_args_cls = AxolotlTrainingArguments
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
|
||||
# unset run_name so wandb sets up experiment names
|
||||
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
|
||||
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
|
||||
None
|
||||
)
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
}
|
||||
multiple = 64
|
||||
if self.cfg.pad_to_sequence_len:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple * math.ceil(
|
||||
self.cfg.sequence_len / multiple
|
||||
)
|
||||
else:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = multiple
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
if eval_data_collator := self.build_collator(
|
||||
training_args, is_eval=True, **data_collator_kwargs
|
||||
):
|
||||
if not (self.cfg.reward_model or self.cfg.process_reward_model):
|
||||
trainer_kwargs["eval_data_collator"] = eval_data_collator
|
||||
if not (self.cfg.reward_model or self.cfg.process_reward_model):
|
||||
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
)
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters:
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
elif "tokenizer" in sig.parameters:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
if (
|
||||
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
|
||||
and self.cfg.datasets is not None
|
||||
):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
||||
callbacks=self.get_callbacks(),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
if self.cfg.deepspeed and self.cfg.sample_packing:
|
||||
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||
"train_micro_batch_size_per_gpu"
|
||||
] = self.cfg.micro_batch_size
|
||||
|
||||
return trainer
|
||||
|
||||
def build_collator(
|
||||
self,
|
||||
training_args, # type: "AxolotlTrainingArguments" # type: ignore
|
||||
is_eval=False,
|
||||
**kwargs,
|
||||
):
|
||||
if training_args.pretraining:
|
||||
if (
|
||||
self.cfg.pretraining_sample_concatenation is False
|
||||
or self.cfg.micro_batch_size > 1
|
||||
):
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
return None
|
||||
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||
|
||||
use_batch_sampler_collator = False
|
||||
if is_eval is False and training_args.sample_packing:
|
||||
use_batch_sampler_collator = True
|
||||
if is_eval and training_args.eval_sample_packing:
|
||||
use_batch_sampler_collator = True
|
||||
|
||||
collator: Type[
|
||||
Union[
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
DataCollatorWithFlattening,
|
||||
RewardDataCollatorWithPadding,
|
||||
]
|
||||
]
|
||||
collator_args = [self.tokenizer]
|
||||
|
||||
collator_cls_and_kwargs = None
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
|
||||
self.cfg, is_eval=is_eval
|
||||
)
|
||||
|
||||
if collator_cls_and_kwargs:
|
||||
collator = collator_cls_and_kwargs[0]
|
||||
if kwargs and isinstance(kwargs, dict):
|
||||
kwargs.update(collator_cls_and_kwargs[1])
|
||||
elif self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
elif use_batch_sampler_collator:
|
||||
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
|
||||
# supported multipack models, or non-flash-attention llama
|
||||
if (
|
||||
self.cfg.flex_attention
|
||||
or self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
or (
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
and self.cfg.flash_attention is not True
|
||||
)
|
||||
):
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
if self.cfg.processor_type and self.processor:
|
||||
collator = MultiModalChatDataCollator
|
||||
kwargs["processing_strategy"] = get_processing_strategy(
|
||||
self.processor,
|
||||
training_args.chat_template,
|
||||
self.cfg.chat_template,
|
||||
image_size=training_args.image_size,
|
||||
image_resize_algorithm=training_args.image_resize_algorithm,
|
||||
)
|
||||
elif self.cfg.batch_flattening:
|
||||
collator = DataCollatorWithFlattening
|
||||
collator_args.pop(0)
|
||||
kwargs.pop("pad_to_multiple_of", None)
|
||||
kwargs.pop("padding", None)
|
||||
else:
|
||||
collator = DataCollatorForSeq2Seq
|
||||
|
||||
kwargs["return_tensors"] = "pt"
|
||||
|
||||
return collator(
|
||||
*collator_args,
|
||||
**kwargs,
|
||||
)
|
||||
@@ -1,238 +0,0 @@
|
||||
"""Builder for RLHF trainers"""
|
||||
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.core.builders.base import TrainerBuilderBase
|
||||
from axolotl.core.trainers import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
)
|
||||
from axolotl.core.trainers.dpo import DPOStrategy
|
||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import ensure_dtype
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
|
||||
if self.cfg.qat:
|
||||
callbacks.append(QATCallback(self.cfg.qat))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self, trainer_kwargs: dict):
|
||||
"""
|
||||
Returns trainer_cls and trainer_cls_args
|
||||
"""
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
trainer_cls_args = [] # type: ignore
|
||||
|
||||
if trainer_cls is not None:
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
trainer_cls = None
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args.append(self.model_ref)
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
trainer_cls = AxolotlKTOTrainer
|
||||
elif self.cfg.rl is RLType.SIMPO:
|
||||
trainer_cls = AxolotlCPOTrainer
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
def _build_training_arguments(self, total_num_steps):
|
||||
"""
|
||||
Returns training_args and trainer_kwargs
|
||||
"""
|
||||
from axolotl.core.training_args import (
|
||||
AxolotlCPOConfig,
|
||||
AxolotlKTOConfig,
|
||||
AxolotlORPOConfig,
|
||||
)
|
||||
|
||||
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||
total_num_steps=total_num_steps
|
||||
)
|
||||
|
||||
if self.cfg.remove_unused_columns is not None:
|
||||
training_args_kwargs["remove_unused_columns"] = (
|
||||
self.cfg.remove_unused_columns
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["remove_unused_columns"] = False
|
||||
|
||||
if self.cfg.trl and self.cfg.trl.beta is not None:
|
||||
training_args_kwargs["beta"] = self.cfg.trl.beta
|
||||
elif self.cfg.rl_beta is not None:
|
||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
||||
elif self.cfg.orpo_alpha is not None:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||
|
||||
if self.cfg.rpo_alpha is not None:
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
training_args_kwargs["run_name"] = self.cfg.wandb_name
|
||||
|
||||
training_args_cls = None
|
||||
blocklist_args_kwargs = []
|
||||
if self.cfg.rl is RLType.SIMPO:
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
|
||||
training_args_kwargs["desirable_weight"] = (
|
||||
self.cfg.kto_desirable_weight or 1.0
|
||||
)
|
||||
training_args_kwargs["undesirable_weight"] = (
|
||||
self.cfg.kto_undesirable_weight or 1.0
|
||||
)
|
||||
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl is RLType.GRPO:
|
||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
training_args_kwargs.update(DPOStrategy.set_training_args_kwargs(self.cfg))
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
for blocklist_key in blocklist_args_kwargs:
|
||||
if blocklist_key in training_args_kwargs:
|
||||
del training_args_kwargs[blocklist_key]
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||
if plugin_training_args:
|
||||
training_args_kwargs.update(plugin_training_args)
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
logging_first_step=True,
|
||||
**training_args_kwargs,
|
||||
)
|
||||
|
||||
# unset run_name so wandb sets up experiment names
|
||||
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
|
||||
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
|
||||
None
|
||||
)
|
||||
|
||||
return training_args, trainer_kwargs
|
||||
|
||||
def build(self, total_num_steps):
|
||||
training_args, trainer_kwargs = self._build_training_arguments(total_num_steps)
|
||||
|
||||
if self.eval_dataset:
|
||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
|
||||
trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
self.cfg.precompute_ref_log_probs
|
||||
)
|
||||
|
||||
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "tokenizer" in sig.parameters:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
else:
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
|
||||
if self.cfg.datasets is not None and (
|
||||
trainer_cls is DPOStrategy.get_trainer_class()
|
||||
):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
|
||||
trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
callbacks=self.get_callbacks(),
|
||||
**trainer_kwargs,
|
||||
)
|
||||
if self.cfg.fsdp:
|
||||
ensure_dtype(trainer.model, dtype=self.cfg.torch_dtype)
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO] and trainer.ref_model:
|
||||
ensure_dtype(trainer.ref_model, dtype=self.cfg.torch_dtype)
|
||||
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
HF Factory class for PPO Trainer
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
# TODO: build PPOConfig
|
||||
raise NotImplementedError("PPO trainer builder is not implemented yet.")
|
||||
@@ -156,6 +156,7 @@ class Messages(BaseModel):
|
||||
len(input_ids) : len(input_ids) + len(pending_input_ids)
|
||||
]
|
||||
if new_pending_inputs != pending_input_ids:
|
||||
# logging.warning("tokenization mismatch from concatenation.")
|
||||
pending_input_ids = new_pending_inputs
|
||||
input_ids.extend(pending_input_ids)
|
||||
if pending_weight:
|
||||
|
||||
1212
src/axolotl/core/trainer_builder.py
Executable file
1212
src/axolotl/core/trainer_builder.py
Executable file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user