Compare commits
67 Commits
tp_support
...
sp-rl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f30d3d33a | ||
|
|
ce07081d6c | ||
|
|
3ce43b6db9 | ||
|
|
7d0eb66b54 | ||
|
|
df119e3724 | ||
|
|
f4ae8816bb | ||
|
|
9b95e06cbb | ||
|
|
e0aba74dd0 | ||
|
|
328d598114 | ||
|
|
4d36ecc724 | ||
|
|
7acf93b59f | ||
|
|
b6fc46ada8 | ||
|
|
b35992262e | ||
|
|
ef6eb77cc8 | ||
|
|
5410195e0b | ||
|
|
cf0c79d52e | ||
|
|
4ba80a0e5a | ||
|
|
c49682132b | ||
|
|
e46239f8d3 | ||
|
|
05f03b541a | ||
|
|
a4e430e7c4 | ||
|
|
6cdcb8ddd5 | ||
|
|
a7811ad4a0 | ||
|
|
e2da821e67 | ||
|
|
2c34a4634e | ||
|
|
a9b0733f2c | ||
|
|
9f00465a5c | ||
|
|
86bac48d14 | ||
|
|
e44953d50c | ||
|
|
23f0c51d88 | ||
|
|
113e9cd193 | ||
|
|
61825a464a | ||
|
|
c907ac173e | ||
|
|
187227d837 | ||
|
|
f8de8bb4f2 | ||
|
|
8e604848a4 | ||
|
|
aae4337f40 | ||
|
|
38df5a36ea | ||
|
|
4d92a68a96 | ||
|
|
85147ec430 | ||
|
|
51cd409488 | ||
|
|
7235123d44 | ||
|
|
4f5eb42a73 | ||
|
|
fbe54be6b8 | ||
|
|
04f6324833 | ||
|
|
f0072f3b9d | ||
|
|
59899b9817 | ||
|
|
4a736986fa | ||
|
|
5d0f110a3b | ||
|
|
83f8698b8a | ||
|
|
60a11a6410 | ||
|
|
46a045e528 | ||
|
|
3b477e08a0 | ||
|
|
16dc6ee68d | ||
|
|
fa7c79b3b9 | ||
|
|
ae66374156 | ||
|
|
5e21b1a9da | ||
|
|
575e5f28ec | ||
|
|
0134093acc | ||
|
|
d4de93a7bb | ||
|
|
c8191394e9 | ||
|
|
f18231c653 | ||
|
|
9ed4f6b3aa | ||
|
|
05dddfc41d | ||
|
|
8e30917440 | ||
|
|
d883b11b6f | ||
|
|
f4910dd2ea |
14
.github/workflows/base.yml
vendored
14
.github/workflows/base.yml
vendored
@@ -40,6 +40,18 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "126"
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: nightly
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -61,7 +73,7 @@ jobs:
|
|||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v4
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./docker/Dockerfile-base
|
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|||||||
7
.github/workflows/docs.yml
vendored
7
.github/workflows/docs.yml
vendored
@@ -20,9 +20,12 @@ jobs:
|
|||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.11'
|
python-version: '3.11'
|
||||||
- name: install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install jupyter
|
python3 -m pip install jupyter quartodoc
|
||||||
|
python3 -m pip install -e . --no-deps
|
||||||
|
- name: Build autodoc
|
||||||
|
run: quartodoc build
|
||||||
- name: Publish to GitHub Pages (and render)
|
- name: Publish to GitHub Pages (and render)
|
||||||
uses: quarto-dev/quarto-actions/publish@v2
|
uses: quarto-dev/quarto-actions/publish@v2
|
||||||
with:
|
with:
|
||||||
|
|||||||
7
.github/workflows/main.yml
vendored
7
.github/workflows/main.yml
vendored
@@ -25,12 +25,12 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
is_latest: true
|
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -87,6 +87,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
3
.github/workflows/multi-gpu-e2e.yml
vendored
3
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -42,8 +42,7 @@ jobs:
|
|||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
# awaiting vllm#12721
|
axolotl_extras: vllm
|
||||||
axolotl_extras:
|
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
|
|||||||
5
.github/workflows/nightlies.yml
vendored
5
.github/workflows/nightlies.yml
vendored
@@ -80,6 +80,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
- cuda: 124
|
||||||
|
cuda_version: 12.4.1
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
49
.github/workflows/precommit-autoupdate.yml
vendored
Normal file
49
.github/workflows/precommit-autoupdate.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
name: Pre-commit auto-update
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: '0 0 * * 0' # Run weekly
|
||||||
|
workflow_dispatch: # Manual kickoff
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
auto-update:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
pull-requests: write
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.11'
|
||||||
|
|
||||||
|
- name: Update pre-commit hooks
|
||||||
|
id: update
|
||||||
|
run: |
|
||||||
|
pip install pre-commit
|
||||||
|
pre-commit autoupdate
|
||||||
|
if [[ -n $(git status --porcelain) ]]; then
|
||||||
|
echo "changes=true" >> $GITHUB_OUTPUT
|
||||||
|
git diff .pre-commit-config.yaml > pre-commit-update.diff
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Create Pull Request
|
||||||
|
if: steps.update.outputs.changes == 'true'
|
||||||
|
uses: peter-evans/create-pull-request@v6
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
branch: update/pre-commit-hooks
|
||||||
|
delete-branch: true
|
||||||
|
title: "chore: update pre-commit hooks"
|
||||||
|
commit-message: "chore: update pre-commit hooks"
|
||||||
|
body: |
|
||||||
|
Automated PR to update pre-commit hooks to their latest versions.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Changes:</summary>
|
||||||
|
|
||||||
|
```diff
|
||||||
|
${{ steps.update.outputs.diff }}
|
||||||
|
```
|
||||||
|
</details>
|
||||||
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install wheel packaging
|
pip3 install wheel packaging==23.2
|
||||||
pip3 install --no-build-isolation -e .
|
pip3 install --no-build-isolation -e .
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
|||||||
27
.github/workflows/tests-nightly.yml
vendored
27
.github/workflows/tests-nightly.yml
vendored
@@ -33,6 +33,15 @@ jobs:
|
|||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Restore HF cache
|
||||||
|
id: hf-cache-restore
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
@@ -42,11 +51,11 @@ jobs:
|
|||||||
- name: upgrade pip
|
- name: upgrade pip
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging setuptools wheel
|
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
pip3 install torch==${{ matrix.pytorch_version }}
|
||||||
|
|
||||||
- name: Update requirements.txt
|
- name: Update requirements.txt
|
||||||
run: |
|
run: |
|
||||||
@@ -58,8 +67,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 show torch
|
||||||
pip3 install --upgrade packaging
|
|
||||||
pip3 install --no-build-isolation -U -e .
|
pip3 install --no-build-isolation -U -e .
|
||||||
python scripts/unsloth_install.py | sh
|
python scripts/unsloth_install.py | sh
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
@@ -73,10 +81,15 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Pre-Download dataset fixture
|
||||||
|
run: |
|
||||||
|
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -136,4 +149,4 @@ jobs:
|
|||||||
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.e2e_tests
|
||||||
|
|||||||
27
.github/workflows/tests.yml
vendored
27
.github/workflows/tests.yml
vendored
@@ -63,7 +63,7 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -74,7 +74,7 @@ jobs:
|
|||||||
- name: upgrade pip
|
- name: upgrade pip
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging setuptools wheel
|
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
@@ -96,10 +96,15 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Pre-Download dataset fixture
|
||||||
|
run: |
|
||||||
|
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -136,7 +141,7 @@ jobs:
|
|||||||
path: |
|
path: |
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -147,7 +152,7 @@ jobs:
|
|||||||
- name: upgrade pip
|
- name: upgrade pip
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
|
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
@@ -170,10 +175,14 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Show HF cache
|
||||||
|
run: huggingface-cli scan-cache
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest -v tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -227,7 +236,7 @@ jobs:
|
|||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.e2e_tests
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
@@ -251,7 +260,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras:
|
axolotl_extras: vllm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -274,4 +283,4 @@ jobs:
|
|||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
- name: Run tests job on Modal
|
||||||
run: |
|
run: |
|
||||||
modal run cicd.tests
|
modal run cicd.e2e_tests
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -181,6 +181,10 @@ prepared-datasets/
|
|||||||
submit.sh
|
submit.sh
|
||||||
*.out*
|
*.out*
|
||||||
|
|
||||||
|
# Quartodoc generated files
|
||||||
|
objects.json
|
||||||
|
site_libs/
|
||||||
|
|
||||||
typings/
|
typings/
|
||||||
out/
|
out/
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
[settings]
|
[settings]
|
||||||
profile=black
|
profile=black
|
||||||
known_third_party=wandb,comet_ml
|
known_third_party=wandb,comet_ml
|
||||||
|
known_local_folder=src,tests
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ default_language_version:
|
|||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.4.0
|
rev: v5.0.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@@ -11,23 +11,23 @@ repos:
|
|||||||
- id: no-commit-to-branch
|
- id: no-commit-to-branch
|
||||||
args: ['--branch', 'main']
|
args: ['--branch', 'main']
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: 23.3.0
|
rev: 25.1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
rev: 5.12.0
|
rev: 6.0.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
rev: 6.1.0
|
rev: 7.1.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/PyCQA/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
rev: v3.3.0
|
rev: v3.3.6
|
||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.3.0
|
rev: v1.15.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
@@ -36,7 +36,7 @@ repos:
|
|||||||
'pydantic>=2.5.3',
|
'pydantic>=2.5.3',
|
||||||
]
|
]
|
||||||
- repo: https://github.com/PyCQA/bandit
|
- repo: https://github.com/PyCQA/bandit
|
||||||
rev: 1.7.5
|
rev: 1.8.3
|
||||||
hooks:
|
hooks:
|
||||||
- id: bandit
|
- id: bandit
|
||||||
args: [
|
args: [
|
||||||
|
|||||||
@@ -19,9 +19,6 @@
|
|||||||
<br/>
|
<br/>
|
||||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
|
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
|
||||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
||||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
|
||||||
<img alt="phorm.ai" src="https://img.shields.io/badge/Phorm-Ask_AI-%23F2777A.svg?&logo=data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNSIgaGVpZ2h0PSI0IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgogIDxwYXRoIGQ9Ik00LjQzIDEuODgyYTEuNDQgMS40NCAwIDAgMS0uMDk4LjQyNmMtLjA1LjEyMy0uMTE1LjIzLS4xOTIuMzIyLS4wNzUuMDktLjE2LjE2NS0uMjU1LjIyNmExLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxMmMtLjA5OS4wMTItLjE5Mi4wMTQtLjI3OS4wMDZsLTEuNTkzLS4xNHYtLjQwNmgxLjY1OGMuMDkuMDAxLjE3LS4xNjkuMjQ2LS4xOTFhLjYwMy42MDMgMCAwIDAgLjItLjEwNi41MjkuNTI5IDAgMCAwIC4xMzgtLjE3LjY1NC42NTQgMCAwIDAgLjA2NS0uMjRsLjAyOC0uMzJhLjkzLjkzIDAgMCAwLS4wMzYtLjI0OS41NjcuNTY3IDAgMCAwLS4xMDMtLjIuNTAyLjUwMiAwIDAgMC0uMTY4LS4xMzguNjA4LjYwOCAwIDAgMC0uMjQtLjA2N0wyLjQzNy43MjkgMS42MjUuNjcxYS4zMjIuMzIyIDAgMCAwLS4yMzIuMDU4LjM3NS4zNzUgMCAwIDAtLjExNi4yMzJsLS4xMTYgMS40NS0uMDU4LjY5Ny0uMDU4Ljc1NEwuNzA1IDRsLS4zNTctLjA3OUwuNjAyLjkwNkMuNjE3LjcyNi42NjMuNTc0LjczOS40NTRhLjk1OC45NTggMCAwIDEgLjI3NC0uMjg1Ljk3MS45NzEgMCAwIDEgLjMzNy0uMTRjLjExOS0uMDI2LjIyNy0uMDM0LjMyNS0uMDI2TDMuMjMyLjE2Yy4xNTkuMDE0LjMzNi4wMy40NTkuMDgyYTEuMTczIDEuMTczIDAgMCAxIC41NDUuNDQ3Yy4wNi4wOTQuMTA5LjE5Mi4xNDQuMjkzYTEuMzkyIDEuMzkyIDAgMCAxIC4wNzguNThsLS4wMjkuMzJaIiBmaWxsPSIjRjI3NzdBIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+CiAgPHBhdGggZD0iTTQuMDgyIDIuMDA3YTEuNDU1IDEuNDU1IDAgMCAxLS4wOTguNDI3Yy0uMDUuMTI0LS4xMTQuMjMyLS4xOTIuMzI0YTEuMTMgMS4xMyAwIDAgMS0uMjU0LjIyNyAxLjM1MyAxLjM1MyAwIDAgMS0uNTk1LjIxNGMtLjEuMDEyLS4xOTMuMDE0LS4yOC4wMDZsLTEuNTYtLjEwOC4wMzQtLjQwNi4wMy0uMzQ4IDEuNTU5LjE1NGMuMDkgMCAuMTczLS4wMS4yNDgtLjAzM2EuNjAzLjYwMyAwIDAgMCAuMi0uMTA2LjUzMi41MzIgMCAwIDAgLjEzOS0uMTcyLjY2LjY2IDAgMCAwIC4wNjQtLjI0MWwuMDI5LS4zMjFhLjk0Ljk0IDAgMCAwLS4wMzYtLjI1LjU3LjU3IDAgMCAwLS4xMDMtLjIwMi41MDIuNTAyIDAgMCAwLS4xNjgtLjEzOC42MDUuNjA1IDAgMCAwLS4yNC0uMDY3TDEuMjczLjgyN2MtLjA5NC0uMDA4LS4xNjguMDEtLjIyMS4wNTUtLjA1My4wNDUtLjA4NC4xMTQtLjA5Mi4yMDZMLjcwNSA0IDAgMy45MzhsLjI1NS0yLjkxMUExLjAxIDEuMDEgMCAwIDEgLjM5My41NzIuOTYyLjk2MiAwIDAgMSAuNjY2LjI4NmEuOTcuOTcgMCAwIDEgLjMzOC0uMTRDMS4xMjIuMTIgMS4yMy4xMSAxLjMyOC4xMTlsMS41OTMuMTRjLjE2LjAxNC4zLjA0Ny40MjMuMWExLjE3IDEuMTcgMCAwIDEgLjU0NS40NDhjLjA2MS4wOTUuMTA5LjE5My4xNDQuMjk1YTEuNDA2IDEuNDA2IDAgMCAxIC4wNzcuNTgzbC0uMDI4LjMyMloiIGZpbGw9IndoaXRlIi8+Cjwvc3ZnPgo=">
|
|
||||||
</a>
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
Axolotl is a tool designed to streamline post-training for various AI models.
|
Axolotl is a tool designed to streamline post-training for various AI models.
|
||||||
@@ -58,6 +55,7 @@ Features:
|
|||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
|
|
||||||
# Download example axolotl configs, deepspeed configs
|
# Download example axolotl configs, deepspeed configs
|
||||||
@@ -99,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
|
|||||||
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
||||||
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
||||||
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
||||||
|
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
|
||||||
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
||||||
|
|
||||||
## 🤝 Getting Help
|
## 🤝 Getting Help
|
||||||
|
|||||||
204
_quarto.yml
204
_quarto.yml
@@ -1,6 +1,180 @@
|
|||||||
project:
|
project:
|
||||||
type: website
|
type: website
|
||||||
|
|
||||||
|
quartodoc:
|
||||||
|
dir: docs/api
|
||||||
|
package: axolotl
|
||||||
|
title: API Reference
|
||||||
|
parser: google
|
||||||
|
|
||||||
|
sections:
|
||||||
|
- title: Core
|
||||||
|
desc: Core functionality for training
|
||||||
|
contents:
|
||||||
|
- train
|
||||||
|
- evaluate
|
||||||
|
- datasets
|
||||||
|
- convert
|
||||||
|
- prompt_tokenizers
|
||||||
|
- logging_config
|
||||||
|
- core.trainer_builder
|
||||||
|
- core.training_args
|
||||||
|
- core.chat.messages
|
||||||
|
- core.chat.format.chatml
|
||||||
|
- core.chat.format.llama3x
|
||||||
|
- core.chat.format.shared
|
||||||
|
- core.datasets.chat
|
||||||
|
- core.datasets.transforms.chat_builder
|
||||||
|
- title: CLI
|
||||||
|
desc: Command-line interface
|
||||||
|
contents:
|
||||||
|
- cli.main
|
||||||
|
- cli.train
|
||||||
|
- cli.evaluate
|
||||||
|
- cli.args
|
||||||
|
- cli.checks
|
||||||
|
- cli.config
|
||||||
|
- cli.inference
|
||||||
|
- cli.merge_lora
|
||||||
|
- cli.merge_sharded_fsdp_weights
|
||||||
|
- cli.preprocess
|
||||||
|
- cli.sweeps
|
||||||
|
- cli.utils
|
||||||
|
- cli.vllm_serve
|
||||||
|
- cli.cloud.base
|
||||||
|
- cli.cloud.modal_
|
||||||
|
- title: Trainers
|
||||||
|
desc: Training implementations
|
||||||
|
contents:
|
||||||
|
- core.trainers.base
|
||||||
|
- core.trainers.trl
|
||||||
|
- core.trainers.dpo.trainer
|
||||||
|
- core.trainers.grpo.trainer
|
||||||
|
- title: Prompt Strategies
|
||||||
|
desc: Prompt formatting strategies
|
||||||
|
contents:
|
||||||
|
- prompt_strategies.base
|
||||||
|
- prompt_strategies.chat_template
|
||||||
|
- prompt_strategies.alpaca_chat
|
||||||
|
- prompt_strategies.alpaca_instruct
|
||||||
|
- prompt_strategies.alpaca_w_system
|
||||||
|
- prompt_strategies.user_defined
|
||||||
|
- prompt_strategies.llama2_chat
|
||||||
|
- prompt_strategies.completion
|
||||||
|
- prompt_strategies.input_output
|
||||||
|
- prompt_strategies.stepwise_supervised
|
||||||
|
- prompt_strategies.metharme
|
||||||
|
- prompt_strategies.orcamini
|
||||||
|
- prompt_strategies.pygmalion
|
||||||
|
- prompt_strategies.messages.chat
|
||||||
|
- prompt_strategies.dpo.chat_template
|
||||||
|
- prompt_strategies.dpo.llama3
|
||||||
|
- prompt_strategies.dpo.chatml
|
||||||
|
- prompt_strategies.dpo.zephyr
|
||||||
|
- prompt_strategies.dpo.user_defined
|
||||||
|
- prompt_strategies.dpo.passthrough
|
||||||
|
- prompt_strategies.kto.llama3
|
||||||
|
- prompt_strategies.kto.chatml
|
||||||
|
- prompt_strategies.kto.user_defined
|
||||||
|
- prompt_strategies.orpo.chat_template
|
||||||
|
- prompt_strategies.bradley_terry.llama3
|
||||||
|
- title: Kernels
|
||||||
|
desc: Low-level performance optimizations
|
||||||
|
contents:
|
||||||
|
- kernels.lora
|
||||||
|
- kernels.geglu
|
||||||
|
- kernels.swiglu
|
||||||
|
- kernels.quantize
|
||||||
|
- kernels.utils
|
||||||
|
- title: MonkeyPatches
|
||||||
|
desc: Runtime patches for model optimizations
|
||||||
|
contents:
|
||||||
|
- monkeypatch.llama_attn_hijack_flash
|
||||||
|
- monkeypatch.llama_attn_hijack_xformers
|
||||||
|
- monkeypatch.mistral_attn_hijack_flash
|
||||||
|
- monkeypatch.multipack
|
||||||
|
- monkeypatch.relora
|
||||||
|
- monkeypatch.llama_expand_mask
|
||||||
|
- monkeypatch.lora_kernels
|
||||||
|
- monkeypatch.utils
|
||||||
|
- monkeypatch.btlm_attn_hijack_flash
|
||||||
|
- monkeypatch.llama_patch_multipack
|
||||||
|
- monkeypatch.stablelm_attn_hijack_flash
|
||||||
|
- monkeypatch.trainer_fsdp_optim
|
||||||
|
- monkeypatch.transformers_fa_utils
|
||||||
|
- monkeypatch.unsloth_
|
||||||
|
- monkeypatch.attention.mllama
|
||||||
|
- monkeypatch.data.batch_dataset_fetcher
|
||||||
|
- monkeypatch.mixtral
|
||||||
|
- title: Utils
|
||||||
|
desc: Utility functions
|
||||||
|
contents:
|
||||||
|
- utils.models
|
||||||
|
- utils.tokenization
|
||||||
|
- utils.chat_templates
|
||||||
|
- utils.lora
|
||||||
|
- utils.lora_embeddings
|
||||||
|
- utils.model_shard_quant
|
||||||
|
- utils.bench
|
||||||
|
- utils.freeze
|
||||||
|
- utils.trainer
|
||||||
|
- utils.schedulers
|
||||||
|
- utils.distributed
|
||||||
|
- utils.dict
|
||||||
|
- utils.optimizers.adopt
|
||||||
|
- utils.data.pretraining
|
||||||
|
- utils.data.sft
|
||||||
|
- utils.gradient_checkpointing.unsloth
|
||||||
|
- title: Schemas
|
||||||
|
desc: Pydantic data models for Axolotl config
|
||||||
|
contents:
|
||||||
|
- utils.schemas.config
|
||||||
|
- utils.schemas.model
|
||||||
|
- utils.schemas.training
|
||||||
|
- utils.schemas.datasets
|
||||||
|
- utils.schemas.peft
|
||||||
|
- utils.schemas.trl
|
||||||
|
- utils.schemas.multimodal
|
||||||
|
- utils.schemas.integrations
|
||||||
|
- utils.schemas.enums
|
||||||
|
- utils.schemas.utils
|
||||||
|
- title: Integrations
|
||||||
|
desc: Third-party integrations and extensions
|
||||||
|
contents:
|
||||||
|
- integrations.base
|
||||||
|
- integrations.cut_cross_entropy.args
|
||||||
|
- integrations.grokfast.optimizer
|
||||||
|
- integrations.kd.trainer
|
||||||
|
- integrations.liger.args
|
||||||
|
- integrations.lm_eval.args
|
||||||
|
- integrations.spectrum.args
|
||||||
|
- title: Common
|
||||||
|
desc: Common utilities and shared functionality
|
||||||
|
contents:
|
||||||
|
- common.architectures
|
||||||
|
- common.const
|
||||||
|
- common.datasets
|
||||||
|
- title: Models
|
||||||
|
desc: Custom model implementations
|
||||||
|
contents:
|
||||||
|
- models.mamba.modeling_mamba
|
||||||
|
- title: Data Processing
|
||||||
|
desc: Data processing utilities
|
||||||
|
contents:
|
||||||
|
- utils.collators.core
|
||||||
|
- utils.collators.batching
|
||||||
|
- utils.collators.mamba
|
||||||
|
- utils.collators.mm_chat
|
||||||
|
- utils.samplers.multipack
|
||||||
|
- title: Callbacks
|
||||||
|
desc: Training callbacks
|
||||||
|
contents:
|
||||||
|
- utils.callbacks.perplexity
|
||||||
|
- utils.callbacks.profiler
|
||||||
|
- utils.callbacks.lisa
|
||||||
|
- utils.callbacks.mlflow_
|
||||||
|
- utils.callbacks.comet_
|
||||||
|
|
||||||
website:
|
website:
|
||||||
title: "Axolotl"
|
title: "Axolotl"
|
||||||
description: "We make fine-tuning accessible, scalable, and fun"
|
description: "We make fine-tuning accessible, scalable, and fun"
|
||||||
@@ -32,14 +206,18 @@ website:
|
|||||||
contents:
|
contents:
|
||||||
- docs/getting-started.qmd
|
- docs/getting-started.qmd
|
||||||
- docs/installation.qmd
|
- docs/installation.qmd
|
||||||
- docs/cli.qmd
|
|
||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
|
- docs/cli.qmd
|
||||||
|
- docs/config.qmd
|
||||||
|
- text: "API Reference"
|
||||||
|
href: docs/api
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
|
|
||||||
- section: "Deployments"
|
- section: "Deployments"
|
||||||
contents:
|
contents:
|
||||||
|
- docs/docker.qmd
|
||||||
- docs/multi-gpu.qmd
|
- docs/multi-gpu.qmd
|
||||||
- docs/multi-node.qmd
|
- docs/multi-node.qmd
|
||||||
- docs/ray-integration.qmd
|
- docs/ray-integration.qmd
|
||||||
@@ -66,6 +244,7 @@ website:
|
|||||||
- docs/unsloth.qmd
|
- docs/unsloth.qmd
|
||||||
- docs/torchao.qmd
|
- docs/torchao.qmd
|
||||||
- docs/custom_integrations.qmd
|
- docs/custom_integrations.qmd
|
||||||
|
- docs/sequence_parallelism.qmd
|
||||||
|
|
||||||
- section: "Troubleshooting"
|
- section: "Troubleshooting"
|
||||||
contents:
|
contents:
|
||||||
@@ -73,12 +252,27 @@ website:
|
|||||||
- docs/debugging.qmd
|
- docs/debugging.qmd
|
||||||
- docs/nccl.qmd
|
- docs/nccl.qmd
|
||||||
|
|
||||||
- section: "Reference"
|
|
||||||
contents:
|
|
||||||
- docs/config.qmd
|
|
||||||
|
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
theme: darkly
|
theme: darkly
|
||||||
css: styles.css
|
css: styles.css
|
||||||
toc: true
|
toc: true
|
||||||
|
# Enable better handling of line breaks in markdown
|
||||||
|
preserve-tabs: true
|
||||||
|
html-math-method: mathjax
|
||||||
|
# Improved markdown processing options
|
||||||
|
md-extensions:
|
||||||
|
- markdown_it
|
||||||
|
- def_list
|
||||||
|
- attr_list
|
||||||
|
- fenced_divs
|
||||||
|
- tables
|
||||||
|
- html_admonition
|
||||||
|
- lineblocks
|
||||||
|
- fancy_lists
|
||||||
|
# Control whitespace handling
|
||||||
|
whitespace: preserve
|
||||||
|
# Process newlines in paragraphs
|
||||||
|
wrap: preserve
|
||||||
|
# Better line break handling
|
||||||
|
preserve-linebreaks: true
|
||||||
|
|||||||
@@ -31,10 +31,11 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,ring-flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN python scripts/unsloth_install.py | sh
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
|||||||
@@ -3,9 +3,10 @@ set -e
|
|||||||
|
|
||||||
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
|
||||||
|
|
||||||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli /workspace/axolotl/tests/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/lora_kernels # running these with the other patches causes a failure
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
pytest -v --durations=10 --ignore=tests/e2e/patched/lora_kernels /workspace/axolotl/tests/e2e/patched
|
||||||
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
|
||||||
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
|
||||||
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
pytest -v --durations=10 /workspace/axolotl/tests/cli
|
||||||
|
pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ --ignore=tests/cli /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Modal app to run axolotl GPU tests"""
|
"""Modal app to run axolotl GPU tests"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
modal application to run axolotl gpu tests in Modal
|
modal application to run axolotl gpu tests in Modal
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|||||||
@@ -2,4 +2,5 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
# only run one test at a time so as not to OOM the GPU
|
# only run one test at a time so as not to OOM the GPU
|
||||||
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
|
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
|
||||||
|
pytest -v -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||||
|
|||||||
39
docker/Dockerfile-base-nightly
Normal file
39
docker/Dockerfile-base-nightly
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
ARG CUDA_VERSION="12.8.1"
|
||||||
|
ARG CUDNN_VERSION="8"
|
||||||
|
ARG UBUNTU_VERSION="22.04"
|
||||||
|
ARG MAX_JOBS=4
|
||||||
|
|
||||||
|
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||||
|
|
||||||
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
|
ARG PYTHON_VERSION="3.11"
|
||||||
|
ARG PYTORCH_VERSION="nightly"
|
||||||
|
ARG CUDA="128"
|
||||||
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
|
||||||
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
|
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||||
|
|
||||||
|
RUN apt-get update \
|
||||||
|
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||||
|
&& wget \
|
||||||
|
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||||
|
&& mkdir /root/.conda \
|
||||||
|
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||||
|
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||||
|
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||||
|
|
||||||
|
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||||
|
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
|
||||||
|
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||||
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||||
|
|
||||||
|
RUN git lfs install --skip-repo && \
|
||||||
|
pip3 install awscli && \
|
||||||
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
|
|||||||
|
|
||||||
RUN pip install jupyterlab notebook ipywidgets && \
|
RUN pip install jupyterlab notebook ipywidgets && \
|
||||||
jupyter lab clean
|
jupyter lab clean
|
||||||
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||||
mkdir -p ~/.ssh && \
|
mkdir -p ~/.ssh && \
|
||||||
chmod 700 ~/.ssh && \
|
chmod 700 ~/.ssh && \
|
||||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||||
|
|||||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -1,2 +1,4 @@
|
|||||||
/.quarto/
|
/.quarto/
|
||||||
_site/
|
_site/
|
||||||
|
/api/*.qmd
|
||||||
|
/api/*.html
|
||||||
|
|||||||
42
docs/cli.qmd
42
docs/cli.qmd
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "CLI Reference"
|
title: "Command Line Interface (CLI)"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
@@ -170,7 +170,7 @@ axolotl merge-sharded-fsdp-weights config.yml
|
|||||||
|
|
||||||
### evaluate
|
### evaluate
|
||||||
|
|
||||||
Evaluates a model's performance using metrics specified in the config.
|
Evaluates a model's performance (loss etc) on the train and eval datasets.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Basic evaluation
|
# Basic evaluation
|
||||||
@@ -197,6 +197,8 @@ lm_eval_batch_size: # Batch size for evaluation
|
|||||||
output_dir: # Directory to save evaluation results
|
output_dir: # Directory to save evaluation results
|
||||||
```
|
```
|
||||||
|
|
||||||
|
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
|
||||||
|
|
||||||
## Legacy CLI Usage
|
## Legacy CLI Usage
|
||||||
|
|
||||||
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
|
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:
|
||||||
@@ -235,7 +237,7 @@ Create a cloud config YAML with your Modal settings:
|
|||||||
```yaml
|
```yaml
|
||||||
# cloud_config.yml
|
# cloud_config.yml
|
||||||
provider: modal
|
provider: modal
|
||||||
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
|
gpu: a100 # Supported: l40s, a100-40gb, a100-80gb, a10g, h100, t4, l4
|
||||||
gpu_count: 1 # Number of GPUs to use
|
gpu_count: 1 # Number of GPUs to use
|
||||||
timeout: 86400 # Maximum runtime in seconds (24 hours)
|
timeout: 86400 # Maximum runtime in seconds (24 hours)
|
||||||
branch: main # Git branch to use (optional)
|
branch: main # Git branch to use (optional)
|
||||||
@@ -248,7 +250,7 @@ volumes: # Persistent storage volumes
|
|||||||
- name: axolotl-artifacts
|
- name: axolotl-artifacts
|
||||||
mount: /workspace/artifacts
|
mount: /workspace/artifacts
|
||||||
|
|
||||||
env: # Environment variables
|
secrets: # Secrets to inject
|
||||||
- WANDB_API_KEY
|
- WANDB_API_KEY
|
||||||
- HF_TOKEN
|
- HF_TOKEN
|
||||||
```
|
```
|
||||||
@@ -274,15 +276,27 @@ axolotl lm-eval config.yml --cloud cloud_config.yml
|
|||||||
### Cloud Configuration Options
|
### Cloud Configuration Options
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
provider: # compute provider, currently only `modal` is supported
|
provider: # compute provider, currently only `modal` is supported
|
||||||
gpu: # GPU type to use
|
gpu: # GPU type to use
|
||||||
gpu_count: # Number of GPUs (default: 1)
|
gpu_count: # Number of GPUs (default: 1)
|
||||||
memory: # RAM in GB (default: 128)
|
memory: # RAM in GB (default: 128)
|
||||||
timeout: # Maximum runtime in seconds
|
timeout: # Maximum runtime in seconds
|
||||||
timeout_preprocess: # Preprocessing timeout
|
timeout_preprocess: # Preprocessing timeout
|
||||||
branch: # Git branch to use
|
branch: # Git branch to use
|
||||||
docker_tag: # Custom Docker image tag
|
docker_tag: # Custom Docker image tag
|
||||||
volumes: # List of persistent storage volumes
|
volumes: # List of persistent storage volumes
|
||||||
env: # Environment variables to pass
|
|
||||||
secrets: # Secrets to inject
|
# Environment variables to pass. Can be specified in two ways:
|
||||||
|
# 1. As a string: Will load the value from the host computer's environment variables
|
||||||
|
# 2. As a key-value pair: Will use the specified value directly
|
||||||
|
# Example:
|
||||||
|
# env:
|
||||||
|
# - CUSTOM_VAR # Loads from host's $CUSTOM_VAR
|
||||||
|
# - {CUSTOM_VAR: "value"} # Uses "value" directly
|
||||||
|
env:
|
||||||
|
|
||||||
|
# Secrets to inject. Same input format as `env` but for sensitive data.
|
||||||
|
secrets:
|
||||||
|
# - HF_TOKEN
|
||||||
|
# - WANDB_API_KEY
|
||||||
```
|
```
|
||||||
|
|||||||
185
docs/config.qmd
185
docs/config.qmd
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: Config options
|
title: Config Reference
|
||||||
description: A complete list of all configuration options.
|
description: A complete list of all configuration options.
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -30,6 +30,11 @@ tokenizer_legacy:
|
|||||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||||
# This is reported to improve training speed on some models
|
# This is reported to improve training speed on some models
|
||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
|
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||||
|
shrink_embeddings:
|
||||||
|
# Whether to load the model with randomly initialized weights. Useful for
|
||||||
|
# pre-training a model from scratch or debugging purposes.
|
||||||
|
random_init_weights:
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify which the model is based on
|
# Used to identify which the model is based on
|
||||||
@@ -78,14 +83,17 @@ tf32: true # require >=ampere
|
|||||||
bfloat16: true # require >=ampere
|
bfloat16: true # require >=ampere
|
||||||
float16: true
|
float16: true
|
||||||
|
|
||||||
# Use Tensor parallel
|
|
||||||
tensor_parallel: true # require multi-gGPU
|
|
||||||
|
|
||||||
# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
|
# Limit the memory for all available GPUs to this amount (if an integer, expressed in gigabytes); default: unset
|
||||||
gpu_memory_limit: 20GiB
|
gpu_memory_limit: 20GiB
|
||||||
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
|
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
|
||||||
lora_on_cpu: true
|
lora_on_cpu: true
|
||||||
|
|
||||||
|
# List[str]. Add plugins to extend the pipeline.
|
||||||
|
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
|
||||||
|
# https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
|
||||||
|
plugins:
|
||||||
|
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
# A list of one or more datasets to finetune the model with
|
# A list of one or more datasets to finetune the model with
|
||||||
datasets:
|
datasets:
|
||||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||||
@@ -157,8 +165,6 @@ datasets:
|
|||||||
content: value
|
content: value
|
||||||
# ...
|
# ...
|
||||||
|
|
||||||
message_property_mappings:
|
|
||||||
|
|
||||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||||
roles:
|
roles:
|
||||||
user: ["human", "user"]
|
user: ["human", "user"]
|
||||||
@@ -166,6 +172,12 @@ datasets:
|
|||||||
system: ["system"]
|
system: ["system"]
|
||||||
tool: ["tool"]
|
tool: ["tool"]
|
||||||
|
|
||||||
|
# Optional[bool]. Whether to drop the system turn from the dataset. Only works with chat_template.
|
||||||
|
# This does not drop the default system message from chat_template if it exists. If you wish to,
|
||||||
|
# we recommend using a custom jinja template with the default system message removed or
|
||||||
|
# adding a system turn with empty content.
|
||||||
|
drop_system_message:
|
||||||
|
|
||||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||||
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
# Priority order: message_field_training > message_field_training_detail > train_on_inputs or role in roles_to_train
|
||||||
# See examples at `docs/dataset-formats/conversation.qmd`
|
# See examples at `docs/dataset-formats/conversation.qmd`
|
||||||
@@ -204,10 +216,46 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto'
|
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
|
||||||
rl:
|
rl:
|
||||||
# whether to perform weighting if doing DPO training. Boolean.
|
rl_beta: # Optional[float]. The beta parameter for the RL training.
|
||||||
dpo_use_weighting:
|
|
||||||
|
# dpo
|
||||||
|
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
|
||||||
|
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
|
||||||
|
|
||||||
|
# orpo
|
||||||
|
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
|
||||||
|
|
||||||
|
# kto
|
||||||
|
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
|
||||||
|
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
|
||||||
|
|
||||||
|
# simpo
|
||||||
|
cpo_alpha: 1.0 # Weight of the BC regularizer
|
||||||
|
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
||||||
|
|
||||||
|
# grpo
|
||||||
|
trl:
|
||||||
|
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
||||||
|
vllm_server_host: # Optional[str]. Host of the vLLM server to connect to.
|
||||||
|
vllm_server_port: # Optional[int]. Port of the vLLM server to connect to.
|
||||||
|
vllm_server_timeout: # Optional[int]. Total timeout (in seconds) to wait for the vLLM server to respond.
|
||||||
|
vllm_guided_decoding_regex: # Optional[str]. Regex for vLLM guided decoding.
|
||||||
|
|
||||||
|
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
||||||
|
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
||||||
|
|
||||||
|
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
|
||||||
|
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
|
||||||
|
|
||||||
|
num_generations: # Optional[int]. Number of generations to sample.
|
||||||
|
log_completions: # Optional[bool]. Whether to log completions.
|
||||||
|
|
||||||
|
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
|
||||||
|
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
|
||||||
|
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
|
||||||
|
|
||||||
|
|
||||||
# reward modelling: `True` or `False`
|
# reward modelling: `True` or `False`
|
||||||
reward_model:
|
reward_model:
|
||||||
@@ -225,13 +273,13 @@ process_reward_model:
|
|||||||
chat_template: tokenizer_default
|
chat_template: tokenizer_default
|
||||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||||
chat_template_jinja: null
|
chat_template_jinja: null
|
||||||
# Changes the default system message
|
# Changes the default system message. Currently only supports chatml.
|
||||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
default_system_message: You are a helpful assistant. Please give a long and detailed answer.
|
||||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
# Push prepared dataset to hub
|
# Push prepared dataset to hub
|
||||||
push_dataset_to_hub: # repo path
|
push_dataset_to_hub: # Optional[str] repo_org/repo_name
|
||||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||||
# if not set.
|
# if not set.
|
||||||
dataset_processes: # defaults to os.cpu_count() if not set
|
dataset_processes: # defaults to os.cpu_count() if not set
|
||||||
@@ -272,9 +320,13 @@ total_num_tokens:
|
|||||||
sample_packing_group_size: 100000
|
sample_packing_group_size: 100000
|
||||||
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
|
||||||
sample_packing_bin_size: 200
|
sample_packing_bin_size: 200
|
||||||
|
sample_pack_sequentially: # Optional[bool]. Whether to pack samples sequentially.
|
||||||
|
|
||||||
# whether to concatenate samples during pretraining
|
# whether to concatenate samples during pretraining
|
||||||
pretraining_sample_concatenation:
|
pretraining_sample_concatenation:
|
||||||
|
|
||||||
|
curriculum_sampling: # Optional[bool]. Whether to use sequential sampling for curriculum learning
|
||||||
|
|
||||||
# Use batch flattening for speedups when not using sample_packing
|
# Use batch flattening for speedups when not using sample_packing
|
||||||
batch_flattening:
|
batch_flattening:
|
||||||
|
|
||||||
@@ -306,7 +358,27 @@ lora_target_modules:
|
|||||||
# - down_proj
|
# - down_proj
|
||||||
# - up_proj
|
# - up_proj
|
||||||
lora_target_linear: # If true, will target all linear modules
|
lora_target_linear: # If true, will target all linear modules
|
||||||
peft_layers_to_transform: # The layer indices to transform, otherwise, apply to all layers
|
|
||||||
|
# List[int] | int. # The layer indices to transform, otherwise, apply to all layers
|
||||||
|
# https://huggingface.co/docs/peft/v0.15.0/en/package_reference/lora#peft.LoraConfig.layers_to_transform
|
||||||
|
peft_layers_to_transform:
|
||||||
|
|
||||||
|
# Optional[bool]. Whether to use DoRA.
|
||||||
|
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#weight-decomposed-low-rank-adaptation-dora
|
||||||
|
peft_use_dora:
|
||||||
|
|
||||||
|
# Optional[bool]. Whether to use RSLoRA.
|
||||||
|
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#rank-stabilized-lora
|
||||||
|
peft_use_rslora:
|
||||||
|
|
||||||
|
# Optional[list[tuple[int, int]]]. List of layer indices to replicate.
|
||||||
|
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#memory-efficient-layer-replication-with-lora
|
||||||
|
peft_layer_replication:
|
||||||
|
|
||||||
|
# bool | Literal["gaussian", "eva", "olora", "pissa", "pissa_niter_[number of iters]", "corda", "loftq"]
|
||||||
|
# How to initialize LoRA weights. Default to True which is MS original implementation.
|
||||||
|
# https://huggingface.co/docs/peft/v0.15.0/en/developer_guides/lora#initialization
|
||||||
|
peft_init_lora_weights:
|
||||||
|
|
||||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||||
@@ -418,6 +490,7 @@ auto_find_batch_size: # Optional[bool]
|
|||||||
|
|
||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
do_causal_lm_eval: # Whether to run causal language model evaluation for metrics in `eval_causal_lm_metrics`.
|
||||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||||
|
|
||||||
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
profiler_steps: # enable the pytorch profiler to capture the first N steps of training to the output_dir.
|
||||||
@@ -448,7 +521,7 @@ gradient_checkpointing: false
|
|||||||
early_stopping_patience: 3
|
early_stopping_patience: 3
|
||||||
|
|
||||||
# Specify a scheduler and kwargs to use with the optimizer
|
# Specify a scheduler and kwargs to use with the optimizer
|
||||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
|
||||||
lr_scheduler_kwargs:
|
lr_scheduler_kwargs:
|
||||||
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
|
||||||
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
|
||||||
@@ -458,36 +531,58 @@ lr_div_factor: # Learning rate div factor
|
|||||||
|
|
||||||
# Specify optimizer
|
# Specify optimizer
|
||||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
# https://github.com/huggingface/transformers/blob/cbf924b76c03828101a34069a96d209314114fd5/src/transformers/training_args.py#L144-L189
|
||||||
#
|
#
|
||||||
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||||
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||||
# in the examples/ for your model and fine-tuning use case.
|
# in the examples/ for your model and fine-tuning use case.
|
||||||
#
|
#
|
||||||
# Valid values for 'optimizer' include:
|
# Valid values for 'optimizer' include:
|
||||||
# - adamw_hf
|
|
||||||
# - adamw_torch
|
# - adamw_torch
|
||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
|
# - adamw_torch_npu_fused
|
||||||
# - adamw_apex_fused
|
# - adamw_apex_fused
|
||||||
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||||
# - adafactor
|
# - adafactor
|
||||||
# - adamw_anyprecision
|
# - adamw_anyprecision
|
||||||
|
# - adamw_torch_4bit
|
||||||
|
# - ademamix
|
||||||
# - sgd
|
# - sgd
|
||||||
# - adagrad
|
# - adagrad
|
||||||
# - adamw_bnb_8bit
|
# - adamw_bnb_8bit
|
||||||
|
# - adamw_8bit # alias for adamw_bnb_8bit
|
||||||
|
# - ademamix_8bit
|
||||||
# - lion_8bit
|
# - lion_8bit
|
||||||
# - lion_32bit
|
# - lion_32bit
|
||||||
# - paged_adamw_32bit
|
# - paged_adamw_32bit
|
||||||
# - paged_adamw_8bit
|
# - paged_adamw_8bit
|
||||||
|
# - paged_ademamix_32bit
|
||||||
|
# - paged_ademamix_8bit
|
||||||
# - paged_lion_32bit
|
# - paged_lion_32bit
|
||||||
# - paged_lion_8bit
|
# - paged_lion_8bit
|
||||||
|
# - rmsprop
|
||||||
|
# - rmsprop_bnb
|
||||||
|
# - rmsprop_bnb_8bit
|
||||||
|
# - rmsprop_bnb_32bit
|
||||||
# - galore_adamw
|
# - galore_adamw
|
||||||
# - galore_adamw_8bit
|
# - galore_adamw_8bit
|
||||||
# - galore_adafactor
|
# - galore_adafactor
|
||||||
# - galore_adamw_layerwise
|
# - galore_adamw_layerwise
|
||||||
# - galore_adamw_8bit_layerwise
|
# - galore_adamw_8bit_layerwise
|
||||||
# - galore_adafactor_layerwise
|
# - galore_adafactor_layerwise
|
||||||
|
# - lomo
|
||||||
|
# - adalomo
|
||||||
|
# - grokadamw
|
||||||
|
# - schedule_free_adamw
|
||||||
|
# - schedule_free_sgd
|
||||||
|
# - apollo_adamw
|
||||||
|
# - apollo_adamw_layerwise
|
||||||
|
#
|
||||||
|
# Additional custom optimizers include:
|
||||||
|
# - optimi_adamw
|
||||||
|
# - ao_adamw_8bit
|
||||||
|
# - ao_adamw_fp8
|
||||||
optimizer:
|
optimizer:
|
||||||
# Dictionary of arguments to pass to the optimizer
|
# Dictionary of arguments to pass to the optimizer
|
||||||
optim_args:
|
optim_args:
|
||||||
@@ -516,27 +611,42 @@ max_grad_norm:
|
|||||||
# currently only supported on Llama and Mistral
|
# currently only supported on Llama and Mistral
|
||||||
neftune_noise_alpha:
|
neftune_noise_alpha:
|
||||||
|
|
||||||
# Whether to bettertransformers
|
# Optional[bool]. Whether to bettertransformers
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
|
||||||
|
# Note: Only one of the following attention patches can be used at a time.
|
||||||
|
# For example, if you set `xformers_attention` to `true`, do not set `flash_attention` to `true`.
|
||||||
|
|
||||||
|
# Optional[bool]. Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||||
xformers_attention:
|
xformers_attention:
|
||||||
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
# Optional[bool]. Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||||
flash_attention:
|
flash_attention:
|
||||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
flash_attn_cross_entropy: # Optional[bool]. Whether to use flash-attention cross entropy implementation - advanced use only
|
||||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
flash_attn_rms_norm: # Optional[bool]. Whether to use flash-attention rms norm implementation - advanced use only
|
||||||
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
flash_attn_fuse_qkv: # Optional[bool]. Whether to fuse QKV into a single operation
|
||||||
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
flash_attn_fuse_mlp: # Optional[bool]. Whether to fuse part of the MLP into a single operation
|
||||||
# Whether to use scaled-dot-product attention
|
# Optional[bool]. Whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
# Optional[bool]. Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
|
||||||
s2_attention:
|
s2_attention:
|
||||||
# Resume from a specific checkpoint dir
|
|
||||||
|
# Optional[bool]. Whether to use low_cpu_mem_usage
|
||||||
|
low_cpu_mem_usage:
|
||||||
|
# Optional[str]. Resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
# Optional[bool]. If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||||
# Be careful with this being turned on between different models.
|
# Be careful with this being turned on between different models.
|
||||||
auto_resume_from_checkpoints: false
|
auto_resume_from_checkpoints: false
|
||||||
|
|
||||||
|
## Multimodal section
|
||||||
|
# int | tuple[int, int] | None . Size to resize images to, width x height.
|
||||||
|
# Will read from model/processor config if not set.
|
||||||
|
image_size:
|
||||||
|
# str. Algorithm to use for image resizing. "bilinear", "bicubic", "lanczos". Default is "bilinear".
|
||||||
|
image_resize_algorithm: 'bilinear'
|
||||||
|
## End of multimodal section
|
||||||
|
|
||||||
# Don't mess with this, it's here for accelerate and torchrun
|
# Don't mess with this, it's here for accelerate and torchrun
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|
||||||
@@ -551,6 +661,13 @@ special_tokens:
|
|||||||
# Add extra tokens.
|
# Add extra tokens.
|
||||||
tokens:
|
tokens:
|
||||||
|
|
||||||
|
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||||
|
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
||||||
|
# Can be checked if they exist in tokenizer.json added_tokens.
|
||||||
|
added_tokens_overrides: # Dict[int, str]
|
||||||
|
# 128041: "<|im_start|>"
|
||||||
|
# 128042: "<|im_end|>"
|
||||||
|
|
||||||
# FSDP
|
# FSDP
|
||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
@@ -563,6 +680,18 @@ ddp_timeout:
|
|||||||
ddp_bucket_cap_mb:
|
ddp_bucket_cap_mb:
|
||||||
ddp_broadcast_buffers:
|
ddp_broadcast_buffers:
|
||||||
|
|
||||||
|
# Sequence parallelism
|
||||||
|
# Set to a divisor of the number of GPUs available to split sequences into chunks of equal size.
|
||||||
|
# Use in long context training to prevent OOM when sequences cannot fit into a single GPU's VRAM.
|
||||||
|
# E.g., if 4 GPUs are available, set this value to 2 to split each sequence into two equal-sized
|
||||||
|
# subsequences, or set to 4 to split into four equal-sized subsequences.
|
||||||
|
# See https://axolotl-ai-cloud.github.io/axolotl/docs/sequence_parallelism.html for more details.
|
||||||
|
sequence_parallel_degree: 4 # Set to the number of GPUs to split sequences across
|
||||||
|
flash_attention: true # SP requires flash attention
|
||||||
|
micro_batch_size: 1 # SP requires this is set to 1
|
||||||
|
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||||
|
heads_k_stride: 1
|
||||||
|
|
||||||
# Path to torch distx for optim 'adamw_anyprecision'
|
# Path to torch distx for optim 'adamw_anyprecision'
|
||||||
torchdistx_path:
|
torchdistx_path:
|
||||||
|
|
||||||
|
|||||||
@@ -55,3 +55,47 @@ sections = [
|
|||||||
for section_name, folder_name in sections:
|
for section_name, folder_name in sections:
|
||||||
print(print_section(section_name, folder_name))
|
print(print_section(section_name, folder_name))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Adding a new integration
|
||||||
|
|
||||||
|
Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
|
||||||
|
|
||||||
|
To add a new integration, please follow these steps:
|
||||||
|
|
||||||
|
1. Create a new folder in the `src/axolotl/integrations` directory.
|
||||||
|
2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
|
||||||
|
3. Add `__init__.py` and `args.py` files to the new folder.
|
||||||
|
- `__init__.py` should import the integration and hook into the appropriate functions.
|
||||||
|
- `args.py` should define the arguments for the integration.
|
||||||
|
4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
|
||||||
|
See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
|
||||||
|
If you could not load your integration, please ensure you are pip installing in editable mode.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
and correctly spelled the integration name in the config file.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.your_integration_name.YourIntegrationPlugin
|
||||||
|
```
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
|
||||||
|
It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed in a package in your python env.
|
||||||
|
|
||||||
|
See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
|
||||||
|
|
||||||
|
:::
|
||||||
|
|||||||
@@ -74,6 +74,10 @@ datasets:
|
|||||||
train_on_eos:
|
train_on_eos:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
||||||
|
:::
|
||||||
|
|
||||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -129,6 +129,7 @@ You can mix and match within each approach or across approaches to train a model
|
|||||||
We suggest this approach when you want to bring your own tokenized dataset.
|
We suggest this approach when you want to bring your own tokenized dataset.
|
||||||
|
|
||||||
Axolotl expects the dataset to have three keys:
|
Axolotl expects the dataset to have three keys:
|
||||||
|
|
||||||
- `input_ids`: from tokenizing formatted prompt
|
- `input_ids`: from tokenizing formatted prompt
|
||||||
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
|
- `attention_mask`: for masking padding. If you don't add padding, it would be equal to `len(input_ids) * [1]`
|
||||||
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.
|
- `labels`: this is the same as `input_ids`, however, if you want to mask certain tokens, you would set those indices to `-100`.
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ description: How datasets are processed
|
|||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||||
the [dataset format](docs/dataset-formats) and prompt strategies to:
|
the [dataset format](dataset-formats) and prompt strategies to:
|
||||||
|
|
||||||
- parse the dataset based on the *dataset format*
|
- parse the dataset based on the *dataset format*
|
||||||
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||||
|
|||||||
139
docs/docker.qmd
Normal file
139
docs/docker.qmd
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
---
|
||||||
|
title: "Docker"
|
||||||
|
format:
|
||||||
|
html:
|
||||||
|
toc: true
|
||||||
|
toc-depth: 4
|
||||||
|
---
|
||||||
|
|
||||||
|
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||||
|
|
||||||
|
## Base
|
||||||
|
|
||||||
|
The base image is the most minimal image that can install Axolotl. It is based on the `nvidia/cuda` image. It includes python, torch, git, git-lfs, awscli, pydantic, and more.
|
||||||
|
|
||||||
|
#### Image
|
||||||
|
|
||||||
|
```
|
||||||
|
axolotlai/axolotl-base
|
||||||
|
```
|
||||||
|
|
||||||
|
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-base)
|
||||||
|
|
||||||
|
#### Tags format
|
||||||
|
|
||||||
|
```bash
|
||||||
|
main-base-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||||
|
```
|
||||||
|
|
||||||
|
Tags examples:
|
||||||
|
|
||||||
|
- `main-base-py3.11-cu124-2.6.0`
|
||||||
|
- `main-base-py3.11-cu124-2.5.1`
|
||||||
|
- `main-base-py3.11-cu124-2.4.1`
|
||||||
|
|
||||||
|
## Main
|
||||||
|
|
||||||
|
The main image is the image that is used to run Axolotl. It is based on the `axolotlai/axolotl-base` image and includes the Axolotl codebase, dependencies, and more.
|
||||||
|
|
||||||
|
#### Image
|
||||||
|
|
||||||
|
```
|
||||||
|
axolotlai/axolotl
|
||||||
|
```
|
||||||
|
|
||||||
|
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl)
|
||||||
|
|
||||||
|
#### Tags format {#sec-main-tags}
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# on push to main
|
||||||
|
main-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||||
|
|
||||||
|
# latest main (currently torch 2.5.1, python 3.11, cuda 12.4)
|
||||||
|
main-latest
|
||||||
|
|
||||||
|
# nightly build
|
||||||
|
{branch}-{date_in_YYYYMMDD}-py{python_version}-cu{cuda_version}-{pytorch_version}
|
||||||
|
|
||||||
|
# tagged release
|
||||||
|
{version}
|
||||||
|
```
|
||||||
|
|
||||||
|
:::{.callout-tip}
|
||||||
|
|
||||||
|
There may be some extra tags appended to the image, like `-vllm` which installs those packages.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
Tags examples:
|
||||||
|
|
||||||
|
- `main-py3.11-cu124-2.6.0`
|
||||||
|
- `main-py3.11-cu124-2.5.1`
|
||||||
|
- `main-py3.11-cu124-2.4.1`
|
||||||
|
- `main-latest`
|
||||||
|
- `main-20250303-py3.11-cu124-2.6.0`
|
||||||
|
- `main-20250303-py3.11-cu124-2.5.1`
|
||||||
|
- `main-20250303-py3.11-cu124-2.4.1`
|
||||||
|
- `0.7.1`
|
||||||
|
|
||||||
|
## Cloud
|
||||||
|
|
||||||
|
The cloud image is the image that is used to run Axolotl in the cloud. It is based on the `axolotlai/axolotl` image and sets ENV variables like HuggingFace cache directories for volume mounts, tmux, and more for different cloud providers.
|
||||||
|
|
||||||
|
:::{.callout-tip}
|
||||||
|
|
||||||
|
Jupyter lab is run by default. Set `JUPYTER_DISABLE=1` in the environment variables to disable it.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
#### Image
|
||||||
|
|
||||||
|
```
|
||||||
|
axolotlai/axolotl-cloud
|
||||||
|
```
|
||||||
|
|
||||||
|
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud)
|
||||||
|
|
||||||
|
#### Tags format
|
||||||
|
|
||||||
|
This uses the same tags as the [`main` image](#sec-main-tags).
|
||||||
|
|
||||||
|
#### Environment variables
|
||||||
|
|
||||||
|
- `JUPYTER_DISABLE`: Disable Jupyter lab.
|
||||||
|
- `JUPYTER_PASSWORD`: Set a password for the Jupyter lab.
|
||||||
|
- `PUBLIC_KEY` / `SSH_KEY`: Add a public key for the SSH service.
|
||||||
|
|
||||||
|
#### Volume mounts
|
||||||
|
|
||||||
|
:::{.callout-tip}
|
||||||
|
|
||||||
|
We recommend mounting volumes to `/workspace/data` for data persistence. `/workspace/axolotl` contains the source code and is ephemeral.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
- `/workspace/data/axolotl-artifacts`: Directory to store Axolotl artifacts.
|
||||||
|
- `/workspace/data/huggingface-cache`: Directory to store HuggingFace cache.
|
||||||
|
|
||||||
|
## Cloud-no-tmux
|
||||||
|
|
||||||
|
This is the same as the [`cloud` image](#sec-cloud) but without tmux.
|
||||||
|
|
||||||
|
#### Image
|
||||||
|
|
||||||
|
```
|
||||||
|
axolotlai/axolotl-cloud-term
|
||||||
|
```
|
||||||
|
|
||||||
|
Link: [Docker Hub](https://hub.docker.com/r/axolotlai/axolotl-cloud-term)
|
||||||
|
|
||||||
|
:::{.callout-note}
|
||||||
|
|
||||||
|
The naming may be a bit confusing as it has `-term` appended to the end.
|
||||||
|
|
||||||
|
:::
|
||||||
|
|
||||||
|
#### Tags format
|
||||||
|
|
||||||
|
This uses the same tags as the [`cloud` image](#sec-cloud-tags).
|
||||||
32
docs/faq.qmd
32
docs/faq.qmd
@@ -19,12 +19,38 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
|
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
|
||||||
|
|
||||||
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
**Q: ModuleNotFoundError: No module named 'mpi4py' using single GPU with deepspeed**
|
||||||
|
|
||||||
|
> A: You may be using deepspeed with single gpu. Please remove the `deepspeed:` section in the yaml file or `--deepspeed` CLI flag.
|
||||||
|
|
||||||
**Q: The codes is stuck on saving preprocessed datasets.**
|
**Q: The codes is stuck on saving preprocessed datasets.**
|
||||||
|
|
||||||
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
||||||
|
|
||||||
|
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
|
||||||
|
|
||||||
|
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
|
||||||
|
|
||||||
|
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
|
||||||
|
|
||||||
|
**Q: How to call Axolotl via custom python scripts?**
|
||||||
|
|
||||||
|
> A: Since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||||
|
|
||||||
|
**Q: How to know the value to use for `fsdp_transformer_layer_cls_to_wrap`?**
|
||||||
|
|
||||||
|
> A: This is the class name of the transformer layer to wrap with FSDP. For example, for `LlamaForCausalLM`, the value is `LlamaDecoderLayer`. To find this for a specific model, check the model's `PreTrainedModel` definition and look for `_no_split_modules` variable in the `modeling_<model_name>.py` file within `transformers` library.
|
||||||
|
|
||||||
|
**Q: ValueError: Asking to pad but the tokenizer does not have a padding token. Please select a token to use as pad_token**
|
||||||
|
|
||||||
|
> A: This is because the tokenizer does not have a padding token. Please add a padding token to the tokenizer via:
|
||||||
|
|
||||||
|
> ```yaml
|
||||||
|
> special_tokens:
|
||||||
|
> # str. If you're not sure, set to same as `eos_token`.
|
||||||
|
> pad_token: "..."
|
||||||
|
> ```
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
@@ -50,3 +76,7 @@ description: Frequently asked questions
|
|||||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||||
|
|
||||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||||
|
|
||||||
|
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
||||||
|
|
||||||
|
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
||||||
|
|||||||
@@ -36,7 +36,9 @@ The YAML configuration file controls everything about your training. Here's what
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: NousResearch/Llama-3.2-1B
|
base_model: NousResearch/Llama-3.2-1B
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
load_in_8bit: true
|
||||||
|
adapter: lora
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
@@ -44,11 +46,15 @@ datasets:
|
|||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
|
||||||
|
|
||||||
|
- To perform Full finetuning, remove these two lines.
|
||||||
|
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
|
||||||
|
:::
|
||||||
|
|
||||||
See our [Config options](config.qmd) for more details.
|
See our [Config options](config.qmd) for more details.
|
||||||
|
|
||||||
### Training {#sec-training}
|
### Training {#sec-training}
|
||||||
@@ -56,7 +62,7 @@ See our [Config options](config.qmd) for more details.
|
|||||||
When you run `axolotl train`, Axolotl:
|
When you run `axolotl train`, Axolotl:
|
||||||
|
|
||||||
1. Downloads the base model
|
1. Downloads the base model
|
||||||
2. (If specified) applies LoRA adapter layers
|
2. (If specified) applies QLoRA/LoRA adapter layers
|
||||||
3. Loads and processes the dataset
|
3. Loads and processes the dataset
|
||||||
4. Runs the training loop
|
4. Runs the training loop
|
||||||
5. Saves the trained model and / or LoRA weights
|
5. Saves the trained model and / or LoRA weights
|
||||||
@@ -69,6 +75,8 @@ Let's modify the example for your own data:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: NousResearch/Nous-Hermes-llama-1b-v1
|
base_model: NousResearch/Nous-Hermes-llama-1b-v1
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|
||||||
# Training settings
|
# Training settings
|
||||||
@@ -104,8 +112,6 @@ format):
|
|||||||
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
|
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
|
||||||
```
|
```
|
||||||
|
|
||||||
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
|
|
||||||
|
|
||||||
3. Run the training:
|
3. Run the training:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "Inference"
|
title: "Inference and Merging"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
@@ -9,10 +9,14 @@ execute:
|
|||||||
enabled: false
|
enabled: false
|
||||||
---
|
---
|
||||||
|
|
||||||
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
|
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
|
||||||
|
|
||||||
## Quick Start {#sec-quickstart}
|
## Quick Start {#sec-quickstart}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Use the same config used for training on inference/merging.
|
||||||
|
:::
|
||||||
|
|
||||||
### Basic Inference {#sec-basic}
|
### Basic Inference {#sec-basic}
|
||||||
|
|
||||||
::: {.panel-tabset}
|
::: {.panel-tabset}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
### PyPI Installation (Recommended) {#sec-pypi}
|
### PyPI Installation (Recommended) {#sec-pypi}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -37,7 +38,7 @@ For the latest features between releases:
|
|||||||
```{.bash}
|
```{.bash}
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip3 install packaging ninja
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -65,6 +66,8 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it \
|
|||||||
```
|
```
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
Please refer to the [Docker documentation](docker.qmd) for more information on the different Docker images that are available.
|
||||||
|
|
||||||
## Cloud Environments {#sec-cloud}
|
## Cloud Environments {#sec-cloud}
|
||||||
|
|
||||||
### Cloud GPU Providers {#sec-cloud-gpu}
|
### Cloud GPU Providers {#sec-cloud-gpu}
|
||||||
@@ -76,6 +79,7 @@ For providers supporting Docker:
|
|||||||
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||||
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
|
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
|
||||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||||
|
- [Novita](https://novita.ai/gpus-console?templateId=311)
|
||||||
|
|
||||||
### Google Colab {#sec-colab}
|
### Google Colab {#sec-colab}
|
||||||
|
|
||||||
@@ -105,7 +109,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
|||||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||||
3. Install Axolotl:
|
3. Install Axolotl:
|
||||||
```{.bash}
|
```{.bash}
|
||||||
pip3 install packaging
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
4. (Optional) Login to Hugging Face:
|
4. (Optional) Login to Hugging Face:
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ logic to be compatible with more of them.
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
|
||||||
|
:::
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
These optimizations can be enabled in your Axolotl config YAML file. The
|
These optimizations can be enabled in your Axolotl config YAML file. The
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ Axolotl supports several methods for multi-GPU training:
|
|||||||
|
|
||||||
- DeepSpeed (recommended)
|
- DeepSpeed (recommended)
|
||||||
- FSDP (Fully Sharded Data Parallel)
|
- FSDP (Fully Sharded Data Parallel)
|
||||||
|
- Sequence parallelism
|
||||||
- FSDP + QLoRA
|
- FSDP + QLoRA
|
||||||
|
|
||||||
## DeepSpeed {#sec-deepspeed}
|
## DeepSpeed {#sec-deepspeed}
|
||||||
@@ -66,6 +67,28 @@ fsdp_config:
|
|||||||
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Sequence parallelism {#sec-sequence-parallelism}
|
||||||
|
|
||||||
|
We support sequence parallelism (SP) via the
|
||||||
|
[ring-flash-attention](https://github.com/zhuzilin/ring-flash-attention) project. This
|
||||||
|
allows one to split up sequences across GPUs, which is useful in the event that a
|
||||||
|
single sequence causes OOM errors during model training.
|
||||||
|
|
||||||
|
First, install `ring-flash-attn`, recommended via `pip install axolotl[ring-flash-attn]`,
|
||||||
|
or from source with `pip install .[ring-flash-attn]`.
|
||||||
|
|
||||||
|
Your Axolotl YAML config should contain the following lines:
|
||||||
|
|
||||||
|
```{.yaml}
|
||||||
|
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
|
flash_attention: true # Required with sequence parallelism
|
||||||
|
|
||||||
|
# Optional; strides across the key dimension. Larger values use more memory but will make training faster.
|
||||||
|
heads_k_stride: 1
|
||||||
|
```
|
||||||
|
|
||||||
|
See our [dedicated guide](sequence_parallelism.qmd) for more details.
|
||||||
|
|
||||||
### FSDP + QLoRA {#sec-fsdp-qlora}
|
### FSDP + QLoRA {#sec-fsdp-qlora}
|
||||||
|
|
||||||
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
|
For combining FSDP with QLoRA, see our [dedicated guide](fsdp_qlora.qmd).
|
||||||
|
|||||||
@@ -1,28 +1,171 @@
|
|||||||
# MultiModal / Vision Language Models (BETA)
|
---
|
||||||
|
title: MultiModal / Vision Language Models (BETA)
|
||||||
|
format:
|
||||||
|
html:
|
||||||
|
toc: true
|
||||||
|
toc-depth: 3
|
||||||
|
---
|
||||||
|
|
||||||
### Supported Models
|
## Supported Models
|
||||||
|
|
||||||
- Mllama, i.e. llama with vision models
|
- [Mllama](#sec-mllama)
|
||||||
|
- [Pixtral](#sec-pixtral)
|
||||||
|
- [Llava-1.5](#sec-llava-15)
|
||||||
|
- [Mistral-Small-3.1](#sec-mistral-small-31)
|
||||||
|
- [Gemma-3](#sec-gemma-3)
|
||||||
|
- [Qwen2-VL](#sec-qwen2-vl)
|
||||||
|
- [Qwen2.5-VL](#sec-qwen25-vl)
|
||||||
|
|
||||||
### Usage
|
## Usage
|
||||||
|
|
||||||
Currently multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama w/ LoRA,
|
Multimodal support is limited and doesn't have full feature parity.
|
||||||
you'll need to use the following in YAML in combination with the rest of the required hyperparams.
|
|
||||||
|
Here are the hyperparams you'll need to use to finetune a multimodal model.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
|
|
||||||
processor_type: AutoProcessor
|
processor_type: AutoProcessor
|
||||||
skip_prepare_dataset: true
|
|
||||||
|
|
||||||
chat_template: llama3_2_vision
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false # leave columns in place as they are needed to handle image embeddings during training
|
||||||
|
sample_packing: false # not yet supported with multimodal
|
||||||
|
|
||||||
|
chat_template: # see in next section
|
||||||
|
|
||||||
|
# example dataset
|
||||||
datasets:
|
datasets:
|
||||||
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
type: chat_template
|
type: chat_template
|
||||||
split: train[:1%]
|
split: train[:1%]
|
||||||
field_messages: messages
|
field_messages: messages
|
||||||
remove_unused_columns: false
|
|
||||||
sample_packing: false
|
|
||||||
|
|
||||||
# only finetune the Language model, leave the vision model and vision tower frozen
|
# (optional) if doing lora, only finetune the Language model,
|
||||||
|
# leave the vision model and vision tower frozen
|
||||||
|
# load_in_8bit: true
|
||||||
|
adapter: lora
|
||||||
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
# (optional) if you want to resize images to a set size
|
||||||
|
image_size: 512
|
||||||
|
image_resize_algorithm: bilinear
|
||||||
|
```
|
||||||
|
|
||||||
|
Please see [examples](https://github.com/axolotl-ai/axolotl/tree/main/examples) folder for full configs.
|
||||||
|
|
||||||
|
::: {.callout-warning}
|
||||||
|
Some of our chat_templates have been extended to support broader dataset types. This should not break any existing configs.
|
||||||
|
:::
|
||||||
|
|
||||||
|
### Mllama {#sec-mllama}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||||
|
|
||||||
|
chat_template: llama3_2_vision
|
||||||
|
```
|
||||||
|
|
||||||
|
### Pixtral {#sec-pixtral}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: mistralai/Pixtral-12B-2409
|
||||||
|
|
||||||
|
chat_template: pixtral
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llava-1.5 {#sec-llava-15}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: llava-hf/llava-1.5-7b-hf
|
||||||
|
|
||||||
|
chat_template: llava
|
||||||
|
```
|
||||||
|
|
||||||
|
### Mistral-Small-3.1 {#sec-mistral-small-31}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||||
|
|
||||||
|
chat_template: mistral_v7_tekken
|
||||||
|
```
|
||||||
|
|
||||||
|
### Gemma-3 {#sec-gemma-3}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
The Gemma3-1B model is a text-only model, so please train as regular text model.
|
||||||
|
:::
|
||||||
|
|
||||||
|
For multi-modal 4B/12B/27B models, use the following config:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: google/gemma-3-4b-it
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
```
|
||||||
|
|
||||||
|
### Qwen2-VL {#sec-qwen2-vl}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
|
||||||
|
chat_template: qwen2_vl
|
||||||
|
```
|
||||||
|
|
||||||
|
### Qwen2.5-VL {#sec-qwen25-vl}
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: Qwen/Qwen2.5-VL-7B-Instruct
|
||||||
|
|
||||||
|
chat_template: qwen2_vl # same as qwen2-vl
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dataset Format
|
||||||
|
|
||||||
|
For multi-modal datasets, we adopt an extended `chat_template` format similar to OpenAI's Message format.
|
||||||
|
|
||||||
|
- A message is a list of `role` and `content`.
|
||||||
|
- `role` can be `system`, `user`, `assistant`, etc.
|
||||||
|
- `content` is a list of `type` and (`text` or `image` or `path` or `url` or `base64`).
|
||||||
|
|
||||||
|
::: {.callout-note}
|
||||||
|
For backwards compatibility:
|
||||||
|
|
||||||
|
- If the dataset has a `images` or `image` column of `list[Image]`, it will be appended to the first `content` list as `{"type": "image", "image": ...}`. However, if the content already has a `{"type": "image"}` but no `image` key, it will be set the `image` key.
|
||||||
|
- If `content` is a string, it will be converted to a list with `type` as `text`.
|
||||||
|
:::
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
For image loading, you can use the following keys within `content` alongside `"type": "image"`:
|
||||||
|
|
||||||
|
- `"path": "/path/to/image.jpg"`
|
||||||
|
- `"url": "https://example.com/image.jpg"`
|
||||||
|
- `"base64": "..."`
|
||||||
|
- `"image": PIL.Image`
|
||||||
|
:::
|
||||||
|
|
||||||
|
Here is an example of a multi-modal dataset:
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "You are a helpful assistant."}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image", "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"},
|
||||||
|
{"type": "text", "text": "Describe this image in detail."}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "The image is a bee."}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -28,8 +28,23 @@ val_set_size: 0.1
|
|||||||
eval_steps: 100
|
eval_steps: 100
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Bradley-Terry chat templates expect single-turn conversations in the following format:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"system": "...", // optional
|
||||||
|
"input": "...",
|
||||||
|
"chosen": "...",
|
||||||
|
"rejected": "..."
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
### Process Reward Models (PRM)
|
### Process Reward Models (PRM)
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
|
||||||
|
:::
|
||||||
|
|
||||||
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
||||||
```yaml
|
```yaml
|
||||||
base_model: Qwen/Qwen2.5-3B
|
base_model: Qwen/Qwen2.5-3B
|
||||||
@@ -45,3 +60,5 @@ datasets:
|
|||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
eval_steps: 100
|
eval_steps: 100
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
|
|||||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||||
back-to-top-navigation: true
|
back-to-top-navigation: true
|
||||||
toc: true
|
toc: true
|
||||||
|
toc-expand: 2
|
||||||
toc-depth: 4
|
toc-depth: 4
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -297,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
|
|
||||||
### IPO
|
### IPO
|
||||||
|
|
||||||
As IPO is just DPO with a different loss function, all supported options for DPO works here.
|
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: ipo
|
rl: ipo
|
||||||
@@ -343,8 +344,9 @@ ORPO supports the following types with the following dataset format:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: kto
|
rl: kto
|
||||||
rl_beta: 0.5
|
rl_beta: 0.1 # default
|
||||||
kto_desirable_weight: 0.2
|
kto_desirable_weight: 1.0 # default
|
||||||
|
kto_undesirable_weight: 1.0 # default
|
||||||
|
|
||||||
remove_unused_columns: false
|
remove_unused_columns: false
|
||||||
|
|
||||||
@@ -496,9 +498,52 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
|
|
||||||
### GRPO
|
### GRPO
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
||||||
|
:::
|
||||||
|
|
||||||
|
If you have multiple GPUs available, we reccomend using `vLLM` with the `GRPOTrainer` to significantly speedup trajectory generation during training.
|
||||||
|
First, launch a `vLLM` server using `trl vllm-serve` - you may use a config file or CLI overrides to configure your vLLM server. In this example, we're
|
||||||
|
using 4 GPUs - 2 for training, and 2 for vLLM:
|
||||||
|
|
||||||
|
::: {.callout-important}
|
||||||
|
Make sure you've installed the correct version of vLLM by including it as an extra when installing axolotl, e.g. `pip install axolotl[vllm]`.
|
||||||
|
:::
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: Qwen/Qwen2.5-1.5B-Instruct
|
||||||
|
|
||||||
|
vllm:
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 8000
|
||||||
|
tensor_parallel_size: 2
|
||||||
|
gpu_memory_utilization: 0.85
|
||||||
|
dtype: auto
|
||||||
|
# max_model_len: # you may find it useful to set the vLLM model context length if you know this beforehand
|
||||||
|
|
||||||
|
rl: grpo
|
||||||
|
trl:
|
||||||
|
use_vllm: true
|
||||||
|
vllm_server_host: 0.0.0.0
|
||||||
|
vllm_server_port: 8000
|
||||||
|
vllm_server_timeout: 300
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=2,3 axolotl vllm_serve grpo.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
Your `vLLM` instance will now attempt to spin up, and it's time to kick off training utilizing our remaining two GPUs. In another terminal, execute:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1 axolotl train grpo.yaml --num-processes 2
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Reward functions
|
||||||
|
|
||||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||||
|
|
||||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
For example, to load OpenAI's GSM8K and use a random reward for completions:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# rewards.py
|
# rewards.py
|
||||||
@@ -524,10 +569,9 @@ trl:
|
|||||||
beta: 0.001
|
beta: 0.001
|
||||||
max_completion_length: 256
|
max_completion_length: 256
|
||||||
use_vllm: True
|
use_vllm: True
|
||||||
vllm_device: auto
|
|
||||||
vllm_gpu_memory_utilization: 0.15
|
|
||||||
num_generations: 4
|
num_generations: 4
|
||||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||||
|
reward_weights: [1.0]
|
||||||
datasets:
|
datasets:
|
||||||
- path: openai/gsm8k
|
- path: openai/gsm8k
|
||||||
name: main
|
name: main
|
||||||
@@ -536,6 +580,21 @@ datasets:
|
|||||||
|
|
||||||
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
||||||
|
|
||||||
|
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
||||||
|
|
||||||
|
### SimPO
|
||||||
|
|
||||||
|
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
rl: simpo
|
||||||
|
rl_beta: 0.1 # default in CPOTrainer
|
||||||
|
cpo_alpha: 1.0 # default in CPOTrainer
|
||||||
|
simpo_gamma: 0.5 # default in CPOTrainer
|
||||||
|
```
|
||||||
|
|
||||||
|
This method uses the same dataset format as [DPO](#dpo).
|
||||||
|
|
||||||
### Using local dataset files
|
### Using local dataset files
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
101
docs/sequence_parallelism.qmd
Normal file
101
docs/sequence_parallelism.qmd
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
---
|
||||||
|
title: Sequence Parallelism
|
||||||
|
description: Train with long sequences split across multiple GPUs.
|
||||||
|
---
|
||||||
|
|
||||||
|
# Sequence Parallelism
|
||||||
|
|
||||||
|
Sequence parallelism is a technique that splits sequences across multiple GPUs,
|
||||||
|
allowing you to train with very long sequences that wouldn't fit on a single GPU. Each
|
||||||
|
GPU processes a different portion of the sequence, and the results are aggregated
|
||||||
|
through a ring communication pattern.
|
||||||
|
|
||||||
|
## When to Use Sequence Parallelism
|
||||||
|
|
||||||
|
Use sequence parallelism when:
|
||||||
|
|
||||||
|
- You need to train with sequence lengths that don't fit into a single GPU's memory
|
||||||
|
- You have multiple GPUs available
|
||||||
|
- You're experiencing OOM (Out Of Memory) errors with long sequences
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
To enable sequence parallelism, add the following to your configuration file:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
|
flash_attention: true # SP requires flash attention
|
||||||
|
micro_batch_size: 1 # SP requires this is set to 1
|
||||||
|
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||||
|
heads_k_stride: 1
|
||||||
|
```
|
||||||
|
|
||||||
|
The `sequence_parallel_degree` should be a divisor of the total number of GPUs. For example:
|
||||||
|
|
||||||
|
- With 8 GPUs, valid values would be 2, 4, or 8
|
||||||
|
- With 4 GPUs, valid values would be 2 or 4
|
||||||
|
|
||||||
|
## Implementation Details
|
||||||
|
|
||||||
|
When sequence parallelism is enabled:
|
||||||
|
|
||||||
|
1. Each sequence is divided into equal chunks across the GPUs in a sequence parallel group
|
||||||
|
2. The data collator handles the chunking of input_ids, attention_mask, labels, and position_ids
|
||||||
|
3. Position IDs are adjusted to maintain proper relative positions, especially for packed sequences
|
||||||
|
4. The trainer uses special ring communication patterns for attention operations
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
To use sequence parallelism, you need:
|
||||||
|
|
||||||
|
- Multiple GPUs (at least 2)
|
||||||
|
- The `ring-flash-attn` package. Install with:
|
||||||
|
- `pip install axolotl[ring-flash-attn]` (preferred)
|
||||||
|
- `pip install ring-flash-attn>=0.1.4`
|
||||||
|
|
||||||
|
## Limitations
|
||||||
|
|
||||||
|
- Flash attention must be enabled for this to work (`flash_attention: true` in config YAML)
|
||||||
|
- May have a small performance overhead due to communication between GPUs
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
base_model: meta-llama/Llama-3-8B-Instruct
|
||||||
|
sequence_len: 8192
|
||||||
|
|
||||||
|
...
|
||||||
|
|
||||||
|
sequence_parallel_degree: 4 # Split each sequence into 4 parts, one per GPU
|
||||||
|
flash_attention: true # SP requires flash attention
|
||||||
|
micro_batch_size: 1 # SP requires this is set to 1
|
||||||
|
# (optional) strides across the key dimension; larger values use more memory but should make training a bit faster
|
||||||
|
heads_k_stride: 1
|
||||||
|
|
||||||
|
...
|
||||||
|
```
|
||||||
|
|
||||||
|
This will train the Llama 3 8B model with 8192 context length, with each sequence split
|
||||||
|
into 4 subsequences of length 2048 across 4 GPUs.
|
||||||
|
|
||||||
|
## Sample Packing with Sequence Parallelism
|
||||||
|
|
||||||
|
Sequence parallelism is compatible with Axolotl's sample packing functionality. When using both features together:
|
||||||
|
|
||||||
|
1. Samples are first packed together
|
||||||
|
2. The packed sequences are then divided across GPUs in the sequence parallel group
|
||||||
|
3. Position IDs are automatically adjusted to maintain proper relative positions
|
||||||
|
|
||||||
|
## Effect on Batch Size
|
||||||
|
|
||||||
|
First, note that sequence parallelism supports only the case where `micro_batch_size: 1`.
|
||||||
|
|
||||||
|
When using sequence parallelism, your effective global batch size is **divided** by the `sequence_parallel_degree`. This happens because:
|
||||||
|
|
||||||
|
- Each group of `sequence_parallel_degree` GPUs works on the same batch (just different parts of each sequence)
|
||||||
|
- The number of batches processed per step decreases
|
||||||
|
|
||||||
|
For example:
|
||||||
|
- With 8 GPUs and no sequence parallelism: 8 different batches are processed per step
|
||||||
|
- With 8 GPUs and `sequence_parallel_degree=4`: Only 2 different batches processed per step (each split across 4 GPUs)
|
||||||
|
- If your per-GPU `micro_batch_size` is 1, the global batch size decreases from 8 to 2
|
||||||
71
examples/cohere/command-r-7b-qlora.yml
Normal file
71
examples/cohere/command-r-7b-qlora.yml
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
base_model: CohereForAI/c4ai-command-r7b-12-2024
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: cohere
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
79
examples/gemma3/gemma-3-1b-qlora.yml
Normal file
79
examples/gemma3/gemma-3-1b-qlora.yml
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
base_model: google/gemma-3-1b-it
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
# gemma3 doesn't seem to play nice with ddp
|
||||||
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# huggingface repo
|
||||||
|
chat_template: gemma3
|
||||||
|
datasets:
|
||||||
|
- path: cgato/SlimOrcaDedupCleaned
|
||||||
|
type: chat_template
|
||||||
|
field_messages: conversations
|
||||||
|
message_property_mappings:
|
||||||
|
role: from
|
||||||
|
content: value
|
||||||
|
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch:
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
68
examples/gemma3/gemma-3-4b-lora.yml
Normal file
68
examples/gemma3/gemma-3-4b-lora.yml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
base_model: google/gemma-3-4b-it
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
# gemma3 doesn't seem to play nice with ddp
|
||||||
|
ddp_find_unused_parameters: true
|
||||||
|
|
||||||
|
chat_template: gemma3
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -82,3 +82,6 @@ deepspeed:
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
|
||||||
|
special_tokens:
|
||||||
|
pad_token: "<|end_of_text|>"
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ val_set_size: 0.0
|
|||||||
output_dir: ./outputs/lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
dataset_exact_deduplication: true
|
dataset_exact_deduplication: true
|
||||||
test_value: true
|
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
80
examples/llama-3/lora-1b-sample-packing-sequentially.yml
Normal file
80
examples/llama-3/lora-1b-sample-packing-sequentially.yml
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
base_model: meta-llama/Llama-3.2-1B
|
||||||
|
# optionally might have model_type or tokenizer_type
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
test_value: true
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
sample_packing_sequentially: true
|
||||||
|
curriculum_sampling: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_modules_to_save:
|
||||||
|
- embed_tokens
|
||||||
|
- lm_head
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <|end_of_text|>
|
||||||
@@ -55,7 +55,7 @@ tf32: true
|
|||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: true
|
use_reentrant: false
|
||||||
early_stopping_patience:
|
early_stopping_patience:
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|||||||
63
examples/llava/lora-7b.yaml
Normal file
63
examples/llava/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: llava-hf/llava-1.5-7b-hf
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: llava
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
66
examples/mistral/mistral-small-3.1-24B-lora.yml
Normal file
66
examples/mistral/mistral-small-3.1-24B-lora.yml
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
base_model: mistralai/Mistral-Small-3.1-24B-Instruct-2503
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: mistral_v7_tekken
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.01
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet.
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
65
examples/pixtral/lora-12b.yml
Normal file
65
examples/pixtral/lora-12b.yml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
base_model: mistral-community/pixtral-12b
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: pixtral
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
pad_token: <pad>
|
||||||
63
examples/qwen2-vl/lora-7b.yaml
Normal file
63
examples/qwen2-vl/lora-7b.yaml
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
base_model: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
processor_type: AutoProcessor
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
# these 3 lines are needed for now to handle vision chat templates w images
|
||||||
|
skip_prepare_dataset: true
|
||||||
|
remove_unused_columns: false
|
||||||
|
sample_packing: false
|
||||||
|
|
||||||
|
chat_template: qwen2_vl
|
||||||
|
datasets:
|
||||||
|
- path: HuggingFaceH4/llava-instruct-mix-vsft
|
||||||
|
type: chat_template
|
||||||
|
split: train[:1%]
|
||||||
|
field_messages: messages
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 8192
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16:
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
eager_attention:
|
||||||
|
|
||||||
|
warmup_ratio: 0.1
|
||||||
|
evals_per_epoch: 1
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
|
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
@@ -8,6 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
|
|||||||
description = "LLM Trainer"
|
description = "LLM Trainer"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
# license = "Apache-2.0"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
axolotl = "axolotl.cli.main:main"
|
axolotl = "axolotl.cli.main:main"
|
||||||
|
|||||||
@@ -2,3 +2,5 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
|
quartodoc
|
||||||
|
jupyter
|
||||||
|
|||||||
@@ -1,24 +1,23 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.45.2
|
bitsandbytes==0.45.4
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
flash-attn==2.7.4.post1
|
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
autoawq==0.2.7.post3
|
autoawq==0.2.7.post3
|
||||||
liger-kernel==0.5.3
|
liger-kernel==0.5.5
|
||||||
# END section
|
# END section
|
||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.15.0
|
||||||
transformers==4.49.0
|
transformers==4.50.3
|
||||||
tokenizers>=0.21.0
|
tokenizers>=0.21.1
|
||||||
accelerate==1.3.0
|
accelerate==1.5.2
|
||||||
datasets==3.2.0
|
datasets==3.5.0
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.4
|
||||||
trl==0.15.1
|
trl==0.16.0
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
hf_transfer
|
hf_transfer
|
||||||
@@ -36,6 +35,7 @@ einops
|
|||||||
colorama
|
colorama
|
||||||
numba
|
numba
|
||||||
numpy>=1.24.4,<=2.0.1
|
numpy>=1.24.4,<=2.0.1
|
||||||
|
|
||||||
# qlora things
|
# qlora things
|
||||||
evaluate==0.4.1
|
evaluate==0.4.1
|
||||||
scipy
|
scipy
|
||||||
@@ -62,4 +62,5 @@ antlr4-python3-runtime==4.13.2
|
|||||||
torchao==0.7.0
|
torchao==0.7.0
|
||||||
schedulefree==1.3.0
|
schedulefree==1.3.0
|
||||||
|
|
||||||
axolotl-contribs-lgpl==0.0.3
|
axolotl-contribs-lgpl==0.0.6
|
||||||
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|||||||
@@ -1,315 +0,0 @@
|
|||||||
accelerate==0.34.1
|
|
||||||
addict==2.4.0
|
|
||||||
aiofiles==23.2.1
|
|
||||||
aiohttp==3.9.0
|
|
||||||
aiosignal==1.3.1
|
|
||||||
aiostream==0.5.2
|
|
||||||
alembic==1.13.1
|
|
||||||
annotated-types==0.6.0
|
|
||||||
annoy==1.17.3
|
|
||||||
ansible==6.7.0
|
|
||||||
ansible-core==2.13.13
|
|
||||||
ansible-vault==2.1.0
|
|
||||||
anyio==3.7.1
|
|
||||||
appdirs==1.4.4
|
|
||||||
art==6.0
|
|
||||||
asgiref==3.7.2
|
|
||||||
async-timeout==4.0.2
|
|
||||||
attrdict==2.0.1
|
|
||||||
attrs==22.2.0
|
|
||||||
awscli==1.32.75
|
|
||||||
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
|
|
||||||
backoff==2.2.1
|
|
||||||
base58==2.1.1
|
|
||||||
beartype==0.17.2
|
|
||||||
bitnet==0.2.1
|
|
||||||
bitsandbytes==0.42.0
|
|
||||||
bittensor==6.7.0
|
|
||||||
black==23.7.0
|
|
||||||
blinker==1.7.0
|
|
||||||
boto3==1.34.75
|
|
||||||
botocore==1.34.75
|
|
||||||
cachetools==5.3.3
|
|
||||||
cachy==0.1.1
|
|
||||||
certifi==2023.7.22
|
|
||||||
cffi==1.16.0
|
|
||||||
cfgv==3.3.1
|
|
||||||
chai-guanaco==1.2.4
|
|
||||||
charset-normalizer==3.2.0
|
|
||||||
cleo==0.6.8
|
|
||||||
click==8.1.7
|
|
||||||
cloudpickle==2.0.0
|
|
||||||
cohere==4.11.2
|
|
||||||
colorama==0.4.4
|
|
||||||
coloredlogs==15.0.1
|
|
||||||
CoLT5-attention==0.10.20
|
|
||||||
contextlib2==21.6.0
|
|
||||||
contourpy==1.2.0
|
|
||||||
cryptography==41.0.3
|
|
||||||
cycler==0.12.1
|
|
||||||
cytoolz==0.12.3
|
|
||||||
databricks-cli==0.18.0
|
|
||||||
dataclasses-json==0.5.7
|
|
||||||
datasets==2.11.0
|
|
||||||
ddt==1.6.0
|
|
||||||
decorator==5.1.1
|
|
||||||
deepspeed==0.15.0
|
|
||||||
# Editable Git install with no remote (dialogpt==0.1)
|
|
||||||
-e /Users/wing/Projects/ml/dialogpt/src
|
|
||||||
dill==0.3.6
|
|
||||||
distlib==0.3.6
|
|
||||||
docker==7.0.0
|
|
||||||
docker-pycreds==0.4.0
|
|
||||||
docstring-parser==0.15
|
|
||||||
docutils==0.16
|
|
||||||
ecdsa==0.18.0
|
|
||||||
einops==0.7.0
|
|
||||||
einops-exts==0.0.4
|
|
||||||
einx==0.1.3
|
|
||||||
entrypoints==0.4
|
|
||||||
eth-hash==0.6.0
|
|
||||||
eth-keys==0.5.0
|
|
||||||
eth-typing==4.0.0
|
|
||||||
eth-utils==2.3.1
|
|
||||||
evaluate==0.4.0
|
|
||||||
exceptiongroup==1.1.1
|
|
||||||
fastapi==0.109.2
|
|
||||||
fastcore==1.5.29
|
|
||||||
ffmpy==0.4.0
|
|
||||||
filelock==3.12.2
|
|
||||||
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
|
|
||||||
fire==0.5.0
|
|
||||||
first==2.0.2
|
|
||||||
flake8==7.0.0
|
|
||||||
Flask==3.0.1
|
|
||||||
fonttools==4.47.2
|
|
||||||
frozendict==2.4.1
|
|
||||||
frozenlist==1.3.3
|
|
||||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
|
||||||
fsspec==2023.6.0
|
|
||||||
fuzzywuzzy==0.18.0
|
|
||||||
gitdb==4.0.10
|
|
||||||
GitPython==3.1.31
|
|
||||||
google-pasta==0.2.0
|
|
||||||
gradio==4.42.0
|
|
||||||
gradio_client==1.3.0
|
|
||||||
greenlet==2.0.2
|
|
||||||
grpclib==0.4.7
|
|
||||||
gunicorn==21.2.0
|
|
||||||
h11==0.14.0
|
|
||||||
h2==4.1.0
|
|
||||||
hpack==4.0.0
|
|
||||||
httpcore==0.17.3
|
|
||||||
httpx==0.24.1
|
|
||||||
huggingface-hub==0.23.4
|
|
||||||
humanfriendly==10.0
|
|
||||||
hyperframe==6.0.1
|
|
||||||
identify==2.5.24
|
|
||||||
idna==3.4
|
|
||||||
immutables==0.20
|
|
||||||
importlib-metadata==6.7.0
|
|
||||||
importlib-resources==6.1.1
|
|
||||||
inflection==0.5.1
|
|
||||||
iniconfig==2.0.0
|
|
||||||
itsdangerous==2.1.2
|
|
||||||
Jinja2==3.1.2
|
|
||||||
jmespath==1.0.1
|
|
||||||
joblib==1.3.2
|
|
||||||
jsonlines==3.1.0
|
|
||||||
jsonschema==2.6.0
|
|
||||||
kiwisolver==1.4.5
|
|
||||||
langchain==0.0.144
|
|
||||||
Levenshtein==0.24.0
|
|
||||||
libcst==1.1.0
|
|
||||||
liger-kernel==0.0.0
|
|
||||||
lion-pytorch==0.1.2
|
|
||||||
llama-cpp-python==0.1.36
|
|
||||||
llvmlite==0.40.1
|
|
||||||
local-attention==1.9.0
|
|
||||||
loguru==0.7.0
|
|
||||||
Mako==1.3.2
|
|
||||||
Markdown==3.5.2
|
|
||||||
markdown-it-py==3.0.0
|
|
||||||
markdown2==2.4.10
|
|
||||||
MarkupSafe==2.1.2
|
|
||||||
marshmallow==3.19.0
|
|
||||||
marshmallow-enum==1.5.1
|
|
||||||
matplotlib==3.8.2
|
|
||||||
mccabe==0.7.0
|
|
||||||
mdurl==0.1.2
|
|
||||||
MEGABYTE-pytorch==0.0.7
|
|
||||||
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
|
|
||||||
mlflow==2.10.0
|
|
||||||
modal==0.62.77
|
|
||||||
more-itertools==10.2.0
|
|
||||||
mpmath==1.2.1
|
|
||||||
msgpack==1.0.7
|
|
||||||
msgpack-numpy-opentensor==0.5.0
|
|
||||||
multidict==6.0.4
|
|
||||||
multiprocess==0.70.14
|
|
||||||
munch==2.5.0
|
|
||||||
mypy==1.3.0
|
|
||||||
mypy-extensions==1.0.0
|
|
||||||
nest-asyncio==1.6.0
|
|
||||||
netaddr==0.10.1
|
|
||||||
networkx==3.0rc1
|
|
||||||
nh3==0.2.14
|
|
||||||
nodeenv==1.8.0
|
|
||||||
nomic==2.0.2
|
|
||||||
numba==0.57.1
|
|
||||||
numexpr==2.8.4
|
|
||||||
numpy==1.24.4
|
|
||||||
oauthlib==3.2.2
|
|
||||||
openai==0.27.4
|
|
||||||
openapi==1.1.0
|
|
||||||
openapi-schema-pydantic==1.2.4
|
|
||||||
optimum==1.8.6
|
|
||||||
orjson==3.10.7
|
|
||||||
packaging==23.1
|
|
||||||
pandas==2.0.0
|
|
||||||
parameterized==0.9.0
|
|
||||||
password-strength==0.0.3.post2
|
|
||||||
pastel==0.1.1
|
|
||||||
pathos==0.3.0
|
|
||||||
pathspec==0.11.1
|
|
||||||
pathtools==0.1.2
|
|
||||||
peft==0.11.1
|
|
||||||
pendulum==3.0.0
|
|
||||||
Pillow==9.5.0
|
|
||||||
pip-tools==1.11.0
|
|
||||||
platformdirs==3.2.0
|
|
||||||
pluggy==1.4.0
|
|
||||||
poetry==0.7.1
|
|
||||||
pox==0.3.2
|
|
||||||
ppft==1.7.6.6
|
|
||||||
pre-commit==3.3.2
|
|
||||||
prettytable==3.10.0
|
|
||||||
prompt-toolkit==3.0.39
|
|
||||||
protobuf==3.20.2
|
|
||||||
protobuf3-to-dict==0.1.5
|
|
||||||
psutil==5.9.5
|
|
||||||
psycopg==3.1.18
|
|
||||||
PuLP==2.8.0
|
|
||||||
py==1.11.0
|
|
||||||
py-bip39-bindings==0.1.11
|
|
||||||
py-cpuinfo==9.0.0
|
|
||||||
py-ed25519-zebra-bindings==1.0.1
|
|
||||||
py-sr25519-bindings==0.2.0
|
|
||||||
pyarrow==11.0.0
|
|
||||||
pyasn1==0.6.0
|
|
||||||
pycodestyle==2.11.1
|
|
||||||
pycparser==2.21
|
|
||||||
pycryptodome==3.20.0
|
|
||||||
pydantic==2.5.3
|
|
||||||
pydantic_core==2.14.6
|
|
||||||
pydub==0.25.1
|
|
||||||
pyfiglet==0.8.post1
|
|
||||||
pyflakes==3.2.0
|
|
||||||
Pygments==2.15.1
|
|
||||||
PyJWT==2.8.0
|
|
||||||
pylev==1.4.0
|
|
||||||
PyNaCl==1.5.0
|
|
||||||
pynvml==11.5.0
|
|
||||||
pyparsing==2.4.7
|
|
||||||
pyrsistent==0.14.11
|
|
||||||
pytest==8.0.2
|
|
||||||
pytest-asyncio==0.23.4
|
|
||||||
python-dateutil==2.8.2
|
|
||||||
python-dotenv==1.0.1
|
|
||||||
python-Levenshtein==0.24.0
|
|
||||||
python-multipart==0.0.9
|
|
||||||
pytz==2023.3
|
|
||||||
PyYAML==6.0.1
|
|
||||||
querystring-parser==1.2.4
|
|
||||||
rapidfuzz==3.6.1
|
|
||||||
regex==2023.6.3
|
|
||||||
requests==2.31.0
|
|
||||||
requests-toolbelt==0.8.0
|
|
||||||
resolvelib==0.8.1
|
|
||||||
responses==0.18.0
|
|
||||||
retry==0.9.2
|
|
||||||
rich==13.7.0
|
|
||||||
rsa==4.7.2
|
|
||||||
ruff==0.6.3
|
|
||||||
s3transfer==0.10.1
|
|
||||||
safetensors==0.4.5
|
|
||||||
sagemaker==2.148.0
|
|
||||||
scalecodec==1.2.7
|
|
||||||
schedulefree==1.2.1
|
|
||||||
schema==0.7.5
|
|
||||||
scikit-learn==1.4.0
|
|
||||||
scipy==1.9.3
|
|
||||||
seaborn==0.13.2
|
|
||||||
semantic-version==2.10.0
|
|
||||||
sentencepiece==0.2.0
|
|
||||||
sentry-sdk==1.19.1
|
|
||||||
setproctitle==1.3.2
|
|
||||||
shellingham==1.5.4
|
|
||||||
shortuuid==1.0.11
|
|
||||||
shtab==1.6.5
|
|
||||||
sigtools==4.0.1
|
|
||||||
six==1.16.0
|
|
||||||
skypilot==0.4.1
|
|
||||||
smdebug-rulesconfig==1.0.1
|
|
||||||
smmap==5.0.0
|
|
||||||
sniffio==1.3.0
|
|
||||||
SQLAlchemy==1.4.47
|
|
||||||
sqlparse==0.4.4
|
|
||||||
starlette==0.36.3
|
|
||||||
substrate-interface==1.5.2
|
|
||||||
svgwrite==1.4.3
|
|
||||||
sympy==1.11.1
|
|
||||||
synchronicity==0.6.7
|
|
||||||
tabulate==0.9.0
|
|
||||||
tblib==1.7.0
|
|
||||||
tenacity==8.2.2
|
|
||||||
tensor-parallel==2.0.0
|
|
||||||
termcolor==2.2.0
|
|
||||||
text2art==0.2.0
|
|
||||||
threadpoolctl==3.2.0
|
|
||||||
tiktoken==0.6.0
|
|
||||||
time-machine==2.14.1
|
|
||||||
timm==0.9.16
|
|
||||||
tokenizers==0.19.1
|
|
||||||
tokenmonster==1.1.12
|
|
||||||
toml==0.9.6
|
|
||||||
tomli==2.0.1
|
|
||||||
tomlkit==0.12.0
|
|
||||||
toolz==0.12.1
|
|
||||||
torch==2.2.0
|
|
||||||
torchdata==0.6.1
|
|
||||||
torchdiffeq==0.2.3
|
|
||||||
TorchFix==0.4.0
|
|
||||||
torchtext==0.15.2
|
|
||||||
torchvision==0.17.0
|
|
||||||
tqdm==4.66.2
|
|
||||||
transformers==4.44.2
|
|
||||||
trl==0.9.6
|
|
||||||
typer==0.12.5
|
|
||||||
types-certifi==2021.10.8.3
|
|
||||||
types-requests==2.31.0.20240125
|
|
||||||
types-setuptools==69.0.0.20240125
|
|
||||||
types-toml==0.10.8.7
|
|
||||||
typing==3.7.4.3
|
|
||||||
typing-inspect==0.8.0
|
|
||||||
typing_extensions==4.9.0
|
|
||||||
tyro==0.5.18
|
|
||||||
tzdata==2023.3
|
|
||||||
unique-names-generator==1.0.2
|
|
||||||
urllib3==2.2.2
|
|
||||||
uvicorn==0.22.0
|
|
||||||
vector_quantize_pytorch==1.14.1
|
|
||||||
virtualenv==20.23.0
|
|
||||||
voyager==2.0.2
|
|
||||||
wandb==0.16.2
|
|
||||||
watchfiles==0.21.0
|
|
||||||
wavedrom==2.0.3.post3
|
|
||||||
wcwidth==0.2.6
|
|
||||||
websocket-client==1.7.0
|
|
||||||
websockets==12.0
|
|
||||||
Werkzeug==3.0.1
|
|
||||||
wonderwords==2.2.0
|
|
||||||
xxhash==3.2.0
|
|
||||||
yarl==1.8.2
|
|
||||||
zetascale==2.2.7
|
|
||||||
zipp==3.15.0
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
helper script to parse chat datasets into a usable yaml
|
helper script to parse chat datasets into a usable yaml
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import yaml
|
import yaml
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Script to output the correct installation command for cut-cross-entropy."""
|
"""Script to output the correct installation command for cut-cross-entropy."""
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -24,5 +25,5 @@ if cce_spec:
|
|||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
UNINSTALL_PREFIX
|
||||||
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
||||||
)
|
)
|
||||||
|
|||||||
97
setup.py
97
setup.py
@@ -10,19 +10,13 @@ from pathlib import Path
|
|||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
def parse_requirements():
|
def parse_requirements(extras_require_map):
|
||||||
_install_requires = []
|
_install_requires = []
|
||||||
_dependency_links = []
|
_dependency_links = []
|
||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
lines = [r.strip() for r in requirements_file.readlines()]
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
for line in lines:
|
for line in lines:
|
||||||
is_extras = (
|
is_extras = "deepspeed" in line or "mamba-ssm" in line
|
||||||
"flash-attn" in line
|
|
||||||
or "flash-attention" in line
|
|
||||||
or "deepspeed" in line
|
|
||||||
or "mamba-ssm" in line
|
|
||||||
or "lion-pytorch" in line
|
|
||||||
)
|
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
_, url = line.split()
|
_, url = line.split()
|
||||||
@@ -39,7 +33,6 @@ def parse_requirements():
|
|||||||
"bitsandbytes",
|
"bitsandbytes",
|
||||||
"triton",
|
"triton",
|
||||||
"mamba-ssm",
|
"mamba-ssm",
|
||||||
"flash-attn",
|
|
||||||
"xformers",
|
"xformers",
|
||||||
"autoawq",
|
"autoawq",
|
||||||
"liger-kernel",
|
"liger-kernel",
|
||||||
@@ -74,6 +67,7 @@ def parse_requirements():
|
|||||||
if (major, minor) >= (2, 6):
|
if (major, minor) >= (2, 6):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
_install_requires.append("xformers==0.0.29.post2")
|
_install_requires.append("xformers==0.0.29.post2")
|
||||||
|
extras_require_map["vllm"] = ["vllm==0.8.1"]
|
||||||
elif (major, minor) >= (2, 5):
|
elif (major, minor) >= (2, 5):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
@@ -93,7 +87,7 @@ def parse_requirements():
|
|||||||
|
|
||||||
except PackageNotFoundError:
|
except PackageNotFoundError:
|
||||||
pass
|
pass
|
||||||
return _install_requires, _dependency_links
|
return _install_requires, _dependency_links, extras_require_map
|
||||||
|
|
||||||
|
|
||||||
def get_package_version():
|
def get_package_version():
|
||||||
@@ -110,7 +104,50 @@ def get_package_version():
|
|||||||
return version_
|
return version_
|
||||||
|
|
||||||
|
|
||||||
install_requires, dependency_links = parse_requirements()
|
extras_require = {
|
||||||
|
"flash-attn": ["flash-attn==2.7.4.post1"],
|
||||||
|
"ring-flash-attn": [
|
||||||
|
"flash-attn==2.7.4.post1",
|
||||||
|
"ring-flash-attn>=0.1.4",
|
||||||
|
"yunchang==0.6.0",
|
||||||
|
],
|
||||||
|
"deepspeed": [
|
||||||
|
"deepspeed==0.16.4",
|
||||||
|
"deepspeed-kernels",
|
||||||
|
],
|
||||||
|
"mamba-ssm": [
|
||||||
|
"mamba-ssm==1.2.0.post1",
|
||||||
|
"causal_conv1d",
|
||||||
|
],
|
||||||
|
"auto-gptq": [
|
||||||
|
"auto-gptq==0.5.1",
|
||||||
|
],
|
||||||
|
"mlflow": [
|
||||||
|
"mlflow",
|
||||||
|
],
|
||||||
|
"galore": [
|
||||||
|
"galore_torch",
|
||||||
|
],
|
||||||
|
"apollo": [
|
||||||
|
"apollo-torch",
|
||||||
|
],
|
||||||
|
"optimizers": [
|
||||||
|
"galore_torch",
|
||||||
|
"apollo-torch",
|
||||||
|
"lomo-optim==0.1.1",
|
||||||
|
"torch-optimi==0.2.1",
|
||||||
|
],
|
||||||
|
"ray": [
|
||||||
|
"ray[train]",
|
||||||
|
],
|
||||||
|
"vllm": [
|
||||||
|
"vllm==0.7.2",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
install_requires, dependency_links, extras_require_build = parse_requirements(
|
||||||
|
extras_require
|
||||||
|
)
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
version=get_package_version(),
|
version=get_package_version(),
|
||||||
@@ -123,41 +160,5 @@ setup(
|
|||||||
"axolotl=axolotl.cli.main:main",
|
"axolotl=axolotl.cli.main:main",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
extras_require={
|
extras_require=extras_require_build,
|
||||||
"flash-attn": [
|
|
||||||
"flash-attn==2.7.4.post1",
|
|
||||||
],
|
|
||||||
"deepspeed": [
|
|
||||||
"deepspeed==0.16.1",
|
|
||||||
"deepspeed-kernels",
|
|
||||||
],
|
|
||||||
"mamba-ssm": [
|
|
||||||
"mamba-ssm==1.2.0.post1",
|
|
||||||
"causal_conv1d",
|
|
||||||
],
|
|
||||||
"auto-gptq": [
|
|
||||||
"auto-gptq==0.5.1",
|
|
||||||
],
|
|
||||||
"mlflow": [
|
|
||||||
"mlflow",
|
|
||||||
],
|
|
||||||
"lion-pytorch": [
|
|
||||||
"lion-pytorch==0.1.2",
|
|
||||||
],
|
|
||||||
"galore": [
|
|
||||||
"galore_torch",
|
|
||||||
],
|
|
||||||
"optimizers": [
|
|
||||||
"galore_torch",
|
|
||||||
"lion-pytorch==0.1.2",
|
|
||||||
"lomo-optim==0.1.1",
|
|
||||||
"torch-optimi==0.2.1",
|
|
||||||
],
|
|
||||||
"ray": [
|
|
||||||
"ray[train]",
|
|
||||||
],
|
|
||||||
"vllm": [
|
|
||||||
"vllm==0.7.2",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,6 +35,55 @@ class TrainerCliArgs:
|
|||||||
num_processes: Optional[int] = field(default=None)
|
num_processes: Optional[int] = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VllmServeCliArgs:
|
||||||
|
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
|
||||||
|
|
||||||
|
tensor_parallel_size: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Number of tensor parallel workers to use."},
|
||||||
|
)
|
||||||
|
host: str = field(
|
||||||
|
default="0.0.0.0", # nosec B104
|
||||||
|
metadata={"help": "Host address to run the server on."},
|
||||||
|
)
|
||||||
|
port: int = field(
|
||||||
|
default=8000,
|
||||||
|
metadata={"help": "Port to run the server on."},
|
||||||
|
)
|
||||||
|
gpu_memory_utilization: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
||||||
|
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
|
||||||
|
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
|
||||||
|
"out-of-memory (OOM) errors during initialization."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dtype: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
||||||
|
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_model_len: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
|
||||||
|
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
|
||||||
|
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
enable_prefix_caching: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
|
||||||
|
"hardware support this feature."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EvaluateCliArgs:
|
class EvaluateCliArgs:
|
||||||
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
launch axolotl in supported cloud platforms
|
launch axolotl in supported cloud platforms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
base class for cloud platforms from cli
|
base class for cloud platforms from cli
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Modal Cloud support from CLI
|
Modal Cloud support from CLI
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -113,7 +114,7 @@ class ModalCloud(Cloud):
|
|||||||
[
|
[
|
||||||
# Random id for cache busting of branch commits
|
# Random id for cache busting of branch commits
|
||||||
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
||||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
|
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -270,6 +271,7 @@ def _preprocess(config_yaml: str, volumes=None):
|
|||||||
|
|
||||||
|
|
||||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||||
|
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||||
f_out.write(config_yaml)
|
f_out.write(config_yaml)
|
||||||
run_folder = "/workspace/mounts"
|
run_folder = "/workspace/mounts"
|
||||||
@@ -288,6 +290,7 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def _lm_eval(config_yaml: str, volumes=None):
|
def _lm_eval(config_yaml: str, volumes=None):
|
||||||
|
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||||
f_out.write(config_yaml)
|
f_out.write(config_yaml)
|
||||||
run_folder = "/workspace/mounts"
|
run_folder = "/workspace/mounts"
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def do_inference(
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Inference-specific CLI arguments.
|
cli_args: Inference-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
@@ -151,7 +151,7 @@ def do_inference_gradio(
|
|||||||
"""
|
"""
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
|
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg, inference=True)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Click CLI definitions for various axolotl commands."""
|
"""Click CLI definitions for various axolotl commands."""
|
||||||
|
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -13,7 +14,12 @@ import yaml
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
import axolotl
|
import axolotl
|
||||||
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.cli.args import (
|
||||||
|
EvaluateCliArgs,
|
||||||
|
PreprocessCliArgs,
|
||||||
|
TrainerCliArgs,
|
||||||
|
VllmServeCliArgs,
|
||||||
|
)
|
||||||
from axolotl.cli.sweeps import generate_sweep_configs
|
from axolotl.cli.sweeps import generate_sweep_configs
|
||||||
from axolotl.cli.utils import (
|
from axolotl.cli.utils import (
|
||||||
add_options_from_config,
|
add_options_from_config,
|
||||||
@@ -22,9 +28,10 @@ from axolotl.cli.utils import (
|
|||||||
fetch_from_github,
|
fetch_from_github,
|
||||||
filter_none_kwargs,
|
filter_none_kwargs,
|
||||||
)
|
)
|
||||||
|
from axolotl.cli.vllm_serve import do_vllm_serve
|
||||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@@ -315,6 +322,14 @@ def fetch(directory: str, dest: Optional[str]) -> None:
|
|||||||
fetch_from_github(f"{directory}/", dest)
|
fetch_from_github(f"{directory}/", dest)
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
|
@add_options_from_dataclass(VllmServeCliArgs)
|
||||||
|
@filter_none_kwargs
|
||||||
|
def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
|
||||||
|
do_vllm_serve(config, cli_args)
|
||||||
|
|
||||||
|
|
||||||
cli.add_command(lm_eval)
|
cli.add_command(lm_eval)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
"""
|
"""
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
|
model, tokenizer, processor = load_model_and_tokenizer(cfg=cfg)
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
LOG.info("Running merge of LoRA with base model...")
|
LOG.info("Running merge of LoRA with base model...")
|
||||||
@@ -44,6 +44,9 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
|
|||||||
)
|
)
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|
||||||
|
if processor:
|
||||||
|
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""CLI to run training on a model."""
|
"""CLI to run training on a model."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -16,13 +17,14 @@ from axolotl.cli.config import load_cfg
|
|||||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config import normalize_config, resolve_dtype
|
from axolotl.utils.config import normalize_config, resolve_dtype
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||||
"""
|
"""
|
||||||
Trains a `transformers` model by first loading the dataset(s) specified in the
|
Trains a `transformers` model by first loading the dataset(s) specified in the
|
||||||
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
|
||||||
@@ -32,25 +34,27 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
cli_args: Training-specific CLI arguments.
|
cli_args: Training-specific CLI arguments.
|
||||||
"""
|
"""
|
||||||
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
|
set_pytorch_cuda_alloc_conf()
|
||||||
|
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||||
|
check_user_token()
|
||||||
|
|
||||||
if cfg.rl:
|
if cfg.rl:
|
||||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
else:
|
else:
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
|
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||||
|
del model, tokenizer, trainer
|
||||||
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
|
||||||
del model
|
|
||||||
del tokenizer
|
|
||||||
|
|
||||||
plugin_manager.post_train_unload(cfg)
|
plugin_manager.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, CLI args, and calls `do_train`.
|
Parses `axolotl` config, CLI args, and calls `do_train`.
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import dataclasses
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import typing
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
@@ -14,17 +13,22 @@ from typing import Any, Callable, Type, Union, get_args, get_origin
|
|||||||
import click
|
import click
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
|
from transformers import (
|
||||||
|
PreTrainedModel,
|
||||||
|
PreTrainedTokenizer,
|
||||||
|
PreTrainedTokenizerFast,
|
||||||
|
ProcessorMixin,
|
||||||
|
)
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
|
|
||||||
configure_logging()
|
configure_logging()
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def strip_optional_type(field_type: type | typing._SpecialForm | None):
|
def strip_optional_type(field_type: type | str | None):
|
||||||
"""
|
"""
|
||||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||||
|
|
||||||
@@ -296,9 +300,13 @@ def load_model_and_tokenizer(
|
|||||||
*,
|
*,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
inference: bool = False,
|
inference: bool = False,
|
||||||
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
|
) -> tuple[
|
||||||
|
PreTrainedModel,
|
||||||
|
PreTrainedTokenizer | PreTrainedTokenizerFast | Any,
|
||||||
|
ProcessorMixin | None,
|
||||||
|
]:
|
||||||
"""
|
"""
|
||||||
Helper function for loading a model and tokenizer specified in the given `axolotl`
|
Helper function for loading a model, tokenizer, and processor specified in the given `axolotl`
|
||||||
config.
|
config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -306,7 +314,7 @@ def load_model_and_tokenizer(
|
|||||||
inference: Boolean denoting inference mode.
|
inference: Boolean denoting inference mode.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`transformers` model and tokenizer.
|
Tuple of (PreTrainedModel, PreTrainedTokenizer, ProcessorMixin).
|
||||||
"""
|
"""
|
||||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
@@ -314,4 +322,9 @@ def load_model_and_tokenizer(
|
|||||||
LOG.info("loading model...")
|
LOG.info("loading model...")
|
||||||
model, _ = load_model(cfg, tokenizer, inference=inference)
|
model, _ = load_model(cfg, tokenizer, inference=inference)
|
||||||
|
|
||||||
return model, tokenizer
|
processor = None
|
||||||
|
if cfg.is_multimodal:
|
||||||
|
LOG.info("loading processor...")
|
||||||
|
processor = load_processor(cfg, tokenizer)
|
||||||
|
|
||||||
|
return model, tokenizer, processor
|
||||||
|
|||||||
55
src/axolotl/cli/vllm_serve.py
Normal file
55
src/axolotl/cli/vllm_serve.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""
|
||||||
|
CLI to start the vllm server for online RL
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from trl.scripts.vllm_serve import ScriptArguments
|
||||||
|
from trl.scripts.vllm_serve import main as vllm_serve_main
|
||||||
|
|
||||||
|
from axolotl.cli.config import load_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def do_vllm_serve(
|
||||||
|
config: Union[Path, str],
|
||||||
|
cli_args: dict,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Starts the VLLM server for serving LLM models used for online RL
|
||||||
|
|
||||||
|
Args
|
||||||
|
:param cfg: Parsed doct of the YAML config
|
||||||
|
:param cli_args: dict of additional command-line arguments of type VllmServeCliArgs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
process_id: the process id of the started VLLM server
|
||||||
|
"""
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
model = cfg.base_model
|
||||||
|
|
||||||
|
tensor_parallel_size = (
|
||||||
|
cli_args.get("tensor_parallel_size") or cfg.vllm.tensor_parallel_size
|
||||||
|
)
|
||||||
|
host = cli_args.get("host") or cfg.vllm.host
|
||||||
|
port = cli_args.get("port") or cfg.vllm.port
|
||||||
|
gpu_memory_utilization = (
|
||||||
|
cli_args.get("gpu_memory_utilization") or cfg.vllm.gpu_memory_utilization
|
||||||
|
)
|
||||||
|
dtype = cli_args.get("dtype") or cfg.vllm.dtype
|
||||||
|
max_model_len = cli_args.get("max_model_len") or cfg.vllm.max_model_len
|
||||||
|
enable_prefix_caching = (
|
||||||
|
cli_args.get("enable_prefix_caching") or cfg.vllm.enable_prefix_caching
|
||||||
|
)
|
||||||
|
|
||||||
|
vllm_script_args = ScriptArguments(
|
||||||
|
model,
|
||||||
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
gpu_memory_utilization=gpu_memory_utilization,
|
||||||
|
dtype=dtype,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
enable_prefix_caching=enable_prefix_caching,
|
||||||
|
)
|
||||||
|
vllm_serve_main(vllm_script_args)
|
||||||
@@ -24,8 +24,8 @@ class TrainDatasetMeta:
|
|||||||
"""Dataclass with fields for training and validation datasets and metadata."""
|
"""Dataclass with fields for training and validation datasets and metadata."""
|
||||||
|
|
||||||
train_dataset: Dataset
|
train_dataset: Dataset
|
||||||
eval_dataset: Optional[Dataset] = None
|
eval_dataset: Dataset | None = None
|
||||||
total_num_steps: Optional[int] = None
|
total_num_steps: int | None = None
|
||||||
|
|
||||||
|
|
||||||
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
|
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
|
||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
ChatML transformation functions for MessageContents
|
ChatML transformation functions for MessageContents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..messages import MessageContents, Messages
|
from ..messages import MessageContents, Messages
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Llama 3.x chat formatting functions for MessageContents
|
Llama 3.x chat formatting functions for MessageContents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from ..messages import MessageContents, Messages
|
from ..messages import MessageContents, Messages
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
shared functions for format transforms
|
shared functions for format transforms
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from axolotl.core.chat.messages import MessageContents, Messages
|
from axolotl.core.chat.messages import MessageContents, Messages
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
internal message representations of chat messages
|
internal message representations of chat messages
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, List, Optional, Union
|
from typing import Any, Callable, List, Optional, Union
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
chat dataset module
|
chat dataset module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
@@ -43,7 +44,7 @@ class TokenizedChatDataset(Dataset):
|
|||||||
process_or_cpu_count: int = (
|
process_or_cpu_count: int = (
|
||||||
process_count or os.cpu_count() # type: ignore[assignment]
|
process_count or os.cpu_count() # type: ignore[assignment]
|
||||||
)
|
)
|
||||||
num_proc = min(64, process_or_cpu_count)
|
num_proc = min(32, process_or_cpu_count)
|
||||||
features = data.features.keys()
|
features = data.features.keys()
|
||||||
tokenized_data = data.map(
|
tokenized_data = data.map(
|
||||||
map_fn,
|
map_fn,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
|
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Mapping, Union
|
from typing import Any, Mapping, Union
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,9 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
"""
|
"""Builder for the training args and trainer"""
|
||||||
Builder for the training args and trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
import importlib
|
import importlib
|
||||||
@@ -35,9 +33,10 @@ from transformers import (
|
|||||||
EarlyStoppingCallback,
|
EarlyStoppingCallback,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
)
|
)
|
||||||
|
from transformers.training_args import OptimizerNames
|
||||||
from trl.trainer.utils import RewardDataCollatorWithPadding
|
from trl.trainer.utils import RewardDataCollatorWithPadding
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlMambaTrainer,
|
AxolotlMambaTrainer,
|
||||||
@@ -61,6 +60,7 @@ from axolotl.core.training_args import (
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback
|
from axolotl.monkeypatch.relora import ReLoRACallback
|
||||||
|
from axolotl.processing_strategies import get_processing_strategy
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
@@ -69,7 +69,6 @@ from axolotl.utils.callbacks import (
|
|||||||
LossWatchDogCallback,
|
LossWatchDogCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
SaveAxolotlConfigtoWandBCallback,
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
SaveModelCallback,
|
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
causal_lm_bench_eval_callback_factory,
|
causal_lm_bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
@@ -85,19 +84,18 @@ from axolotl.utils.collators import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
from axolotl.utils.models import ensure_dtype
|
from axolotl.utils.models import ensure_dtype
|
||||||
|
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""Base class for trainer builder."""
|
||||||
Base class for trainer builder
|
|
||||||
"""
|
|
||||||
|
|
||||||
_train_dataset = None
|
_train_dataset = None
|
||||||
_eval_dataset = None
|
_eval_dataset = None
|
||||||
@@ -110,9 +108,9 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.processor = processor
|
self.processor = processor
|
||||||
|
|
||||||
# in case the model supports tagging, add the axolotl tag.
|
# If the model supports tagging, add the axolotl tag.
|
||||||
# This makes sure the tag is correctly pushed even if a user calls
|
# This makes sure the tag is correctly pushed even if a user calls
|
||||||
# model.push_to_hub instad of trainer.push_to_hub.
|
# model.push_to_hub instead of trainer.push_to_hub.
|
||||||
if hasattr(model, "add_model_tags"):
|
if hasattr(model, "add_model_tags"):
|
||||||
model.add_model_tags(["axolotl"])
|
model.add_model_tags(["axolotl"])
|
||||||
|
|
||||||
@@ -227,8 +225,8 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
|
|
||||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||||
"""
|
"""
|
||||||
Build the HuggingFace training args/trainer for causal models
|
Build the HuggingFace training args/trainer for causal models and reward modeling
|
||||||
and reward modelling using TRL.
|
using TRL.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
@@ -250,7 +248,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.gc_steps:
|
if self.cfg.gc_steps:
|
||||||
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
|
||||||
callbacks.append(SaveModelCallback())
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -332,9 +329,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs = {}
|
training_arguments_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.include_tokens_per_second is not None:
|
if self.cfg.include_tokens_per_second is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["include_tokens_per_second"] = (
|
||||||
"include_tokens_per_second"
|
self.cfg.include_tokens_per_second
|
||||||
] = self.cfg.include_tokens_per_second
|
)
|
||||||
|
|
||||||
if self.cfg.bf16 == "full":
|
if self.cfg.bf16 == "full":
|
||||||
training_arguments_kwargs["bf16_full_eval"] = True
|
training_arguments_kwargs["bf16_full_eval"] = True
|
||||||
@@ -351,13 +348,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||||
|
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["gradient_checkpointing"] = (
|
||||||
"gradient_checkpointing"
|
self.cfg.gradient_checkpointing
|
||||||
] = self.cfg.gradient_checkpointing
|
)
|
||||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
|
||||||
"gradient_checkpointing_kwargs"
|
self.cfg.gradient_checkpointing_kwargs
|
||||||
] = self.cfg.gradient_checkpointing_kwargs
|
)
|
||||||
if self.cfg.fsdp:
|
if self.cfg.fsdp:
|
||||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||||
if self.cfg.fsdp_config:
|
if self.cfg.fsdp_config:
|
||||||
@@ -373,9 +370,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||||
|
|
||||||
if self.cfg.lr_quadratic_warmup is not None:
|
if self.cfg.lr_quadratic_warmup is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["lr_quadratic_warmup"] = (
|
||||||
"lr_quadratic_warmup"
|
self.cfg.lr_quadratic_warmup
|
||||||
] = self.cfg.lr_quadratic_warmup
|
)
|
||||||
|
|
||||||
if self.cfg.adam_beta1:
|
if self.cfg.adam_beta1:
|
||||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||||
@@ -399,28 +396,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
if self.cfg.dataloader_pin_memory is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["dataloader_pin_memory"] = (
|
||||||
"dataloader_pin_memory"
|
self.cfg.dataloader_pin_memory
|
||||||
] = self.cfg.dataloader_pin_memory
|
)
|
||||||
if self.cfg.dataloader_num_workers is not None:
|
if self.cfg.dataloader_num_workers is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["dataloader_num_workers"] = (
|
||||||
"dataloader_num_workers"
|
self.cfg.dataloader_num_workers
|
||||||
] = self.cfg.dataloader_num_workers
|
)
|
||||||
if self.cfg.dataloader_prefetch_factor is not None:
|
if self.cfg.dataloader_prefetch_factor is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["dataloader_prefetch_factor"] = (
|
||||||
"dataloader_prefetch_factor"
|
self.cfg.dataloader_prefetch_factor
|
||||||
] = self.cfg.dataloader_prefetch_factor
|
)
|
||||||
if self.cfg.dataloader_drop_last is not None:
|
if self.cfg.dataloader_drop_last is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["dataloader_drop_last"] = (
|
||||||
"dataloader_drop_last"
|
self.cfg.dataloader_drop_last
|
||||||
] = self.cfg.dataloader_drop_last
|
)
|
||||||
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
||||||
training_arguments_kwargs["dataloader_drop_last"] = True
|
training_arguments_kwargs["dataloader_drop_last"] = True
|
||||||
|
|
||||||
if self.cfg.remove_unused_columns is not None:
|
if self.cfg.remove_unused_columns is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["remove_unused_columns"] = (
|
||||||
"remove_unused_columns"
|
self.cfg.remove_unused_columns
|
||||||
] = self.cfg.remove_unused_columns
|
)
|
||||||
|
|
||||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||||
# no eval set, so don't eval
|
# no eval set, so don't eval
|
||||||
@@ -452,9 +449,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.do_causal_lm_eval:
|
if self.cfg.do_causal_lm_eval:
|
||||||
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
|
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
|
||||||
if self.cfg.metric_for_best_model:
|
if self.cfg.metric_for_best_model:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["metric_for_best_model"] = (
|
||||||
"metric_for_best_model"
|
self.cfg.metric_for_best_model
|
||||||
] = self.cfg.metric_for_best_model
|
)
|
||||||
if self.cfg.greater_is_better:
|
if self.cfg.greater_is_better:
|
||||||
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||||
|
|
||||||
@@ -467,13 +464,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||||
if self.cfg.torch_compile_backend:
|
if self.cfg.torch_compile_backend:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["torch_compile_backend"] = (
|
||||||
"torch_compile_backend"
|
self.cfg.torch_compile_backend
|
||||||
] = self.cfg.torch_compile_backend
|
)
|
||||||
if self.cfg.torch_compile_mode:
|
if self.cfg.torch_compile_mode:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["torch_compile_mode"] = (
|
||||||
"torch_compile_mode"
|
self.cfg.torch_compile_mode
|
||||||
] = self.cfg.torch_compile_mode
|
)
|
||||||
|
|
||||||
# DDP Config
|
# DDP Config
|
||||||
if self.cfg.ddp_timeout:
|
if self.cfg.ddp_timeout:
|
||||||
@@ -482,32 +479,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.ddp_bucket_cap_mb:
|
if self.cfg.ddp_bucket_cap_mb:
|
||||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
||||||
if self.cfg.ddp_broadcast_buffers is not None:
|
if self.cfg.ddp_broadcast_buffers is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["ddp_broadcast_buffers"] = (
|
||||||
"ddp_broadcast_buffers"
|
self.cfg.ddp_broadcast_buffers
|
||||||
] = self.cfg.ddp_broadcast_buffers
|
)
|
||||||
|
|
||||||
# these are all the "standard" kwargs that are def used
|
# these are all the "standard" kwargs that are def used
|
||||||
training_arguments_kwargs["max_steps"] = (
|
training_arguments_kwargs["max_steps"] = (
|
||||||
total_num_steps if self.cfg.max_steps else -1
|
total_num_steps if self.cfg.max_steps else -1
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["per_device_train_batch_size"] = (
|
||||||
"per_device_train_batch_size"
|
self.cfg.micro_batch_size
|
||||||
] = self.cfg.micro_batch_size
|
)
|
||||||
if self.cfg.eval_batch_size:
|
if self.cfg.eval_batch_size:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["per_device_eval_batch_size"] = (
|
||||||
"per_device_eval_batch_size"
|
self.cfg.eval_batch_size
|
||||||
] = self.cfg.eval_batch_size
|
)
|
||||||
if self.cfg.auto_find_batch_size is not None:
|
if self.cfg.auto_find_batch_size is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["auto_find_batch_size"] = (
|
||||||
"auto_find_batch_size"
|
self.cfg.auto_find_batch_size
|
||||||
] = self.cfg.auto_find_batch_size
|
)
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["gradient_accumulation_steps"] = (
|
||||||
"gradient_accumulation_steps"
|
self.cfg.gradient_accumulation_steps
|
||||||
] = self.cfg.gradient_accumulation_steps
|
)
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["eval_accumulation_steps"] = (
|
||||||
"eval_accumulation_steps"
|
self.cfg.gradient_accumulation_steps
|
||||||
] = self.cfg.gradient_accumulation_steps
|
)
|
||||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||||
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||||
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
||||||
@@ -527,9 +524,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
and self.cfg.eval_steps
|
and self.cfg.eval_steps
|
||||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||||
) or False
|
) or False
|
||||||
|
|
||||||
|
# handle ddp
|
||||||
|
ddp_find_unused_parameters = None
|
||||||
|
if self.cfg.ddp:
|
||||||
|
ddp_find_unused_parameters = bool(self.cfg.ddp_find_unused_parameters)
|
||||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||||
False if self.cfg.ddp else None
|
ddp_find_unused_parameters
|
||||||
)
|
)
|
||||||
|
|
||||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||||
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
training_arguments_kwargs["curriculum_sampling"] = self.cfg.curriculum_sampling
|
||||||
report_to = []
|
report_to = []
|
||||||
@@ -551,34 +554,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||||
else:
|
else:
|
||||||
training_arguments_kwargs["run_name"] = None
|
training_arguments_kwargs["run_name"] = None
|
||||||
training_arguments_kwargs["optim"] = (
|
|
||||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
|
||||||
)
|
|
||||||
if self.cfg.optim_args:
|
|
||||||
if isinstance(self.cfg.optim_args, dict):
|
|
||||||
optim_args = ",".join(
|
|
||||||
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
optim_args = self.cfg.optim_args
|
|
||||||
training_arguments_kwargs["optim_args"] = optim_args
|
|
||||||
if self.cfg.optim_target_modules:
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"optim_target_modules"
|
|
||||||
] = self.cfg.optim_target_modules
|
|
||||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
|
||||||
training_arguments_kwargs[
|
|
||||||
"loraplus_lr_embedding"
|
|
||||||
] = self.cfg.loraplus_lr_embedding
|
|
||||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
|
||||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
|
||||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
|
||||||
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["alternate_lr_scheduler_type"] = (
|
||||||
"alternate_lr_scheduler_type"
|
self.cfg.lr_scheduler
|
||||||
] = self.cfg.lr_scheduler
|
)
|
||||||
else:
|
else:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||||
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
||||||
@@ -587,9 +568,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
)
|
)
|
||||||
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["cosine_constant_lr_ratio"] = (
|
||||||
"cosine_constant_lr_ratio"
|
self.cfg.cosine_constant_lr_ratio
|
||||||
] = self.cfg.cosine_constant_lr_ratio
|
)
|
||||||
training_arguments_kwargs["weight_decay"] = (
|
training_arguments_kwargs["weight_decay"] = (
|
||||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||||
)
|
)
|
||||||
@@ -602,40 +583,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.eval_sample_packing
|
self.cfg.eval_sample_packing
|
||||||
)
|
)
|
||||||
if self.cfg.sample_packing_bin_size is not None:
|
if self.cfg.sample_packing_bin_size is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["sample_packing_bin_size"] = (
|
||||||
"sample_packing_bin_size"
|
self.cfg.sample_packing_bin_size
|
||||||
] = self.cfg.sample_packing_bin_size
|
)
|
||||||
if self.cfg.sample_packing_group_size is not None:
|
if self.cfg.sample_packing_group_size is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["sample_packing_group_size"] = (
|
||||||
"sample_packing_group_size"
|
self.cfg.sample_packing_group_size
|
||||||
] = self.cfg.sample_packing_group_size
|
)
|
||||||
if self.cfg.sample_packing_eff_est:
|
if self.cfg.sample_packing_eff_est:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["sample_packing_efficiency"] = (
|
||||||
"sample_packing_efficiency"
|
self.cfg.sample_packing_eff_est
|
||||||
] = self.cfg.sample_packing_eff_est
|
)
|
||||||
|
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora_steps:
|
||||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["relora_warmup_steps"] = (
|
||||||
"relora_warmup_steps"
|
self.cfg.relora_warmup_steps
|
||||||
] = self.cfg.relora_warmup_steps
|
)
|
||||||
if self.cfg.relora_anneal_steps:
|
if self.cfg.relora_anneal_steps:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["relora_anneal_steps"] = (
|
||||||
"relora_anneal_steps"
|
self.cfg.relora_anneal_steps
|
||||||
] = self.cfg.relora_anneal_steps
|
)
|
||||||
if self.cfg.relora_prune_ratio:
|
if self.cfg.relora_prune_ratio:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["relora_prune_ratio"] = (
|
||||||
"relora_prune_ratio"
|
self.cfg.relora_prune_ratio
|
||||||
] = self.cfg.relora_prune_ratio
|
)
|
||||||
|
|
||||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||||
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["lisa_step_interval"] = (
|
||||||
"lisa_step_interval"
|
self.cfg.lisa_step_interval
|
||||||
] = self.cfg.lisa_step_interval
|
)
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["lisa_layers_attribute"] = (
|
||||||
"lisa_layers_attribute"
|
self.cfg.lisa_layers_attribute
|
||||||
] = self.cfg.lisa_layers_attribute
|
)
|
||||||
|
|
||||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||||
training_arguments_kwargs
|
training_arguments_kwargs
|
||||||
@@ -649,63 +630,134 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.neftune_noise_alpha is not None:
|
if self.cfg.neftune_noise_alpha is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["neftune_noise_alpha"] = (
|
||||||
"neftune_noise_alpha"
|
self.cfg.neftune_noise_alpha
|
||||||
] = self.cfg.neftune_noise_alpha
|
)
|
||||||
|
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
training_arguments_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# Handle custom optimizer
|
||||||
if self.cfg.optimizer in [
|
custom_supported_optimizers = [opt.value for opt in CustomSupportedOptimizers]
|
||||||
"optimi_adamw",
|
if self.cfg.optimizer in custom_supported_optimizers:
|
||||||
"ao_adamw_4bit",
|
# Common optimizer kwargs
|
||||||
"ao_adamw_8bit",
|
optimizer_kwargs = {
|
||||||
"ao_adamw_fp8",
|
"lr": training_arguments_kwargs.get("learning_rate"),
|
||||||
"adopt_adamw",
|
"weight_decay": training_arguments_kwargs.get("weight_decay"),
|
||||||
]:
|
}
|
||||||
# Set default so transformers doesn't throw
|
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
|
||||||
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
|
|
||||||
|
|
||||||
if self.cfg.optimizer == "lion_pytorch":
|
# Adam-specific kwargs
|
||||||
from lion_pytorch import Lion
|
adam_kwargs = {}
|
||||||
|
if training_arguments_kwargs.get(
|
||||||
|
"adam_beta1"
|
||||||
|
) and training_arguments_kwargs.get("adam_beta2"):
|
||||||
|
adam_kwargs["betas"] = (
|
||||||
|
training_arguments_kwargs.get("adam_beta1"),
|
||||||
|
training_arguments_kwargs.get("adam_beta2"),
|
||||||
|
)
|
||||||
|
if training_arguments_kwargs.get("adam_epsilon"):
|
||||||
|
adam_kwargs["eps"] = training_arguments_kwargs.get("adam_epsilon")
|
||||||
|
|
||||||
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
|
if self.cfg.optimizer == "muon":
|
||||||
if "weight_decay" in training_arguments_kwargs:
|
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||||
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
|
MuonOptimizerFactory,
|
||||||
|
|
||||||
if (
|
|
||||||
"adam_beta1" in training_arguments_kwargs
|
|
||||||
and "adam_beta2" in training_arguments_kwargs
|
|
||||||
):
|
|
||||||
lion_kwargs["betas"] = (
|
|
||||||
training_arguments_kwargs["adam_beta1"],
|
|
||||||
training_arguments_kwargs["adam_beta2"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer_kwargs["optimizers"] = (
|
optimizer_cls = MuonOptimizerFactory
|
||||||
Lion(params=self.model.parameters(), **lion_kwargs),
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
None,
|
elif self.cfg.optimizer == "optimi_adamw":
|
||||||
|
from optimi import AdamW
|
||||||
|
|
||||||
|
optimizer_kwargs["foreach"] = False
|
||||||
|
optimizer_cls = AdamW
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_4bit":
|
||||||
|
# TODO remove 20250401
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW4bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW4bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
LOG.warning(
|
||||||
|
f"`ao_adamw_4bit` will be deprecated soon. Please use `{OptimizerNames.ADAMW_TORCH_4BIT}` instead."
|
||||||
|
)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_8bit":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamW8bit
|
||||||
|
|
||||||
|
optimizer_cls = AdamW8bit
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "ao_adamw_fp8":
|
||||||
|
from torchao.prototype.low_bit_optim import AdamWFp8
|
||||||
|
|
||||||
|
optimizer_cls = AdamWFp8
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
elif self.cfg.optimizer == "adopt_adamw":
|
||||||
|
from axolotl.utils.optimizers.adopt import ADOPT
|
||||||
|
|
||||||
|
optimizer_cls = ADOPT
|
||||||
|
adam_kwargs["decouple"] = True
|
||||||
|
optimizer_kwargs.update(adam_kwargs)
|
||||||
|
|
||||||
|
# Parse any additional optimizer args from config
|
||||||
|
if self.cfg.optim_args:
|
||||||
|
if isinstance(self.cfg.optim_args, dict):
|
||||||
|
optimizer_kwargs.update(self.cfg.optim_args)
|
||||||
|
else:
|
||||||
|
# Parse string format "key1=value1,key2=value2"
|
||||||
|
for mapping in self.cfg.optim_args.replace(" ", "").split(","):
|
||||||
|
key, value = mapping.split("=")
|
||||||
|
optimizer_kwargs[key] = value
|
||||||
|
|
||||||
|
trainer_kwargs["optimizer_cls_and_kwargs"] = (
|
||||||
|
optimizer_cls,
|
||||||
|
optimizer_kwargs,
|
||||||
)
|
)
|
||||||
# Set default so transformers doesn't throw
|
else:
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
# Use transformers' optimizer
|
||||||
|
training_arguments_kwargs["optim"] = self.cfg.optimizer
|
||||||
|
|
||||||
|
# Parse any additional optimizer args from config
|
||||||
|
if self.cfg.optim_args:
|
||||||
|
if isinstance(self.cfg.optim_args, dict):
|
||||||
|
optim_args = ",".join(
|
||||||
|
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optim_args = self.cfg.optim_args
|
||||||
|
training_arguments_kwargs["optim_args"] = optim_args
|
||||||
|
|
||||||
if self.cfg.optimizer == "adamw_anyprecision":
|
if self.cfg.optimizer == "adamw_anyprecision":
|
||||||
if Path(self.cfg.torchdistx_path).exists():
|
if Path(self.cfg.torchdistx_path).exists():
|
||||||
sys.path.append(self.cfg.torchdistx_path)
|
sys.path.append(self.cfg.torchdistx_path)
|
||||||
importlib.import_module("torchdistx")
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
|
if self.cfg.optim_target_modules:
|
||||||
|
training_arguments_kwargs["optim_target_modules"] = (
|
||||||
|
self.cfg.optim_target_modules
|
||||||
|
)
|
||||||
|
|
||||||
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
|
|
||||||
|
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||||
|
training_arguments_kwargs["loraplus_lr_embedding"] = (
|
||||||
|
self.cfg.loraplus_lr_embedding
|
||||||
|
)
|
||||||
|
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||||
|
|
||||||
if self.cfg.accelerator_config:
|
if self.cfg.accelerator_config:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["accelerator_config"] = (
|
||||||
"accelerator_config"
|
self.cfg.accelerator_config
|
||||||
] = self.cfg.accelerator_config
|
)
|
||||||
|
|
||||||
if self.cfg.tensor_parallel:
|
|
||||||
training_arguments_kwargs["tp_size"] = torch.cuda.device_count()
|
|
||||||
|
|
||||||
|
if self.cfg.image_size:
|
||||||
|
training_arguments_kwargs["image_size"] = self.cfg.image_size
|
||||||
|
if self.cfg.image_resize_algorithm:
|
||||||
|
training_arguments_kwargs["image_resize_algorithm"] = (
|
||||||
|
self.cfg.image_resize_algorithm
|
||||||
|
)
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
if self.cfg.kd_alpha is not None:
|
if self.cfg.kd_alpha is not None:
|
||||||
@@ -713,13 +765,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.kd_temperature is not None:
|
if self.cfg.kd_temperature is not None:
|
||||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||||
if self.cfg.kd_zscore_base_temp is not None:
|
if self.cfg.kd_zscore_base_temp is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
||||||
"kd_zscore_base_temp"
|
self.cfg.kd_zscore_base_temp
|
||||||
] = self.cfg.kd_zscore_base_temp
|
)
|
||||||
if self.cfg.kd_top_k_before_softmax is not None:
|
if self.cfg.kd_top_k_before_softmax is not None:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
||||||
"kd_top_k_before_softmax"
|
self.cfg.kd_top_k_before_softmax
|
||||||
] = self.cfg.kd_top_k_before_softmax
|
)
|
||||||
|
|
||||||
|
training_arguments_kwargs["sequence_parallel_degree"] = (
|
||||||
|
self.cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
@@ -804,9 +860,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if self.cfg.pretraining_sample_concatenation is False:
|
if (
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
self.cfg.pretraining_sample_concatenation is False
|
||||||
if self.cfg.micro_batch_size > 1:
|
or self.cfg.micro_batch_size > 1
|
||||||
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -834,9 +891,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if "max_length" in kwargs:
|
if "max_length" in kwargs:
|
||||||
kwargs.pop("max_length")
|
kwargs.pop("max_length")
|
||||||
elif use_batch_sampler_collator:
|
elif use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
|
||||||
elif (
|
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
and self.cfg.flash_attention is not True
|
and self.cfg.flash_attention is not True
|
||||||
):
|
):
|
||||||
@@ -846,8 +901,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
if self.cfg.processor_type and self.processor:
|
if self.cfg.processor_type and self.processor:
|
||||||
collator = MultiModalChatDataCollator
|
collator = MultiModalChatDataCollator
|
||||||
kwargs["processor"] = self.processor
|
kwargs["processing_strategy"] = get_processing_strategy(
|
||||||
kwargs["chat_template"] = training_args.chat_template
|
self.processor,
|
||||||
|
training_args.chat_template,
|
||||||
|
self.cfg.chat_template,
|
||||||
|
image_size=training_args.image_size,
|
||||||
|
image_resize_algorithm=training_args.image_resize_algorithm,
|
||||||
|
)
|
||||||
elif self.cfg.batch_flattening:
|
elif self.cfg.batch_flattening:
|
||||||
collator = DataCollatorWithFlattening
|
collator = DataCollatorWithFlattening
|
||||||
collator_args.pop(0)
|
collator_args.pop(0)
|
||||||
@@ -867,6 +927,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|
||||||
kwargs["return_tensors"] = "pt"
|
kwargs["return_tensors"] = "pt"
|
||||||
|
if issubclass(collator, DataCollatorForSeq2Seq):
|
||||||
|
kwargs["sequence_parallel_degree"] = training_args.sequence_parallel_degree
|
||||||
|
|
||||||
return collator(
|
return collator(
|
||||||
*collator_args,
|
*collator_args,
|
||||||
@@ -875,13 +937,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
|
|
||||||
class HFRLTrainerBuilder(TrainerBuilderBase):
|
class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||||
"""
|
"""Trainer factory class for TRL-based RLHF trainers (e.g. DPO)"""
|
||||||
Trainer factory class for TRL-based RLHF trainers (e.g. DPO)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_callbacks(self):
|
def get_callbacks(self):
|
||||||
callbacks = super().get_callbacks()
|
callbacks = super().get_callbacks()
|
||||||
callbacks.append(SaveModelCallback())
|
|
||||||
|
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
@@ -931,32 +990,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
)
|
)
|
||||||
if self.cfg.remove_unused_columns is not None:
|
if self.cfg.remove_unused_columns is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["remove_unused_columns"] = (
|
||||||
"remove_unused_columns"
|
self.cfg.remove_unused_columns
|
||||||
] = self.cfg.remove_unused_columns
|
)
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["remove_unused_columns"] = False
|
training_args_kwargs["remove_unused_columns"] = False
|
||||||
|
|
||||||
if self.cfg.dataloader_pin_memory is not None:
|
if self.cfg.dataloader_pin_memory is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["dataloader_pin_memory"] = (
|
||||||
"dataloader_pin_memory"
|
self.cfg.dataloader_pin_memory
|
||||||
] = self.cfg.dataloader_pin_memory
|
)
|
||||||
if self.cfg.dataloader_num_workers is not None:
|
if self.cfg.dataloader_num_workers is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["dataloader_num_workers"] = (
|
||||||
"dataloader_num_workers"
|
self.cfg.dataloader_num_workers
|
||||||
] = self.cfg.dataloader_num_workers
|
)
|
||||||
if self.cfg.dataloader_prefetch_factor is not None:
|
if self.cfg.dataloader_prefetch_factor is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["dataloader_prefetch_factor"] = (
|
||||||
"dataloader_prefetch_factor"
|
self.cfg.dataloader_prefetch_factor
|
||||||
] = self.cfg.dataloader_prefetch_factor
|
)
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
training_args_kwargs[
|
training_args_kwargs["gradient_checkpointing"] = (
|
||||||
"gradient_checkpointing"
|
self.cfg.gradient_checkpointing
|
||||||
] = self.cfg.gradient_checkpointing
|
)
|
||||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["gradient_checkpointing_kwargs"] = (
|
||||||
"gradient_checkpointing_kwargs"
|
self.cfg.gradient_checkpointing_kwargs
|
||||||
] = self.cfg.gradient_checkpointing_kwargs
|
)
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
||||||
"use_reentrant": False
|
"use_reentrant": False
|
||||||
@@ -984,6 +1043,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rpo_alpha is not None:
|
if self.cfg.rpo_alpha is not None:
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
|
training_args_kwargs["sequence_parallel_degree"] = (
|
||||||
|
self.cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
blocklist_args_kwargs = []
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
@@ -1030,9 +1093,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||||
training_args_kwargs[
|
training_args_kwargs["use_logits_to_keep"] = (
|
||||||
"use_logits_to_keep"
|
self.cfg.dpo_use_logits_to_keep
|
||||||
] = self.cfg.dpo_use_logits_to_keep
|
)
|
||||||
|
|
||||||
for blocklist_key in blocklist_args_kwargs:
|
for blocklist_key in blocklist_args_kwargs:
|
||||||
if blocklist_key in training_args_kwargs:
|
if blocklist_key in training_args_kwargs:
|
||||||
@@ -1067,9 +1130,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
|
||||||
"precompute_ref_log_probs"
|
self.cfg.precompute_ref_log_probs
|
||||||
] = self.cfg.precompute_ref_log_probs
|
)
|
||||||
if self.cfg.rl == "grpo":
|
if self.cfg.rl == "grpo":
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
@@ -1102,6 +1165,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["dataset_tags"] = [
|
dpo_trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
*trainer_cls_args,
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
@@ -1119,21 +1183,3 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer.add_callback(callback)
|
dpo_trainer.add_callback(callback)
|
||||||
|
|
||||||
return dpo_trainer
|
return dpo_trainer
|
||||||
|
|
||||||
|
|
||||||
class HFPPOTrainerBuilder(TrainerBuilderBase):
|
|
||||||
"""
|
|
||||||
HF Factory class for PPO Trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_callbacks(self):
|
|
||||||
callbacks = super().get_callbacks()
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def get_post_trainer_create_callbacks(self, trainer):
|
|
||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
|
||||||
# build PPOConfig
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
"""Init for axolotl.core.trainers"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
from axolotl.core.trainers.dpo import AxolotlDPOTrainer
|
||||||
|
from axolotl.core.trainers.grpo import AxolotlGRPOTrainer
|
||||||
|
from axolotl.core.trainers.mamba import AxolotlMambaTrainer
|
||||||
|
from axolotl.core.trainers.relora import ReLoRATrainer
|
||||||
|
from axolotl.core.trainers.trl import (
|
||||||
|
AxolotlCPOTrainer,
|
||||||
|
AxolotlKTOTrainer,
|
||||||
|
AxolotlORPOTrainer,
|
||||||
|
AxolotlPPOTrainer,
|
||||||
|
AxolotlPRMTrainer,
|
||||||
|
AxolotlRewardTrainer,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,163 +1,44 @@
|
|||||||
"""
|
"""Module for customized trainers"""
|
||||||
module for customized trainers
|
|
||||||
"""
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Dict, Literal, Optional
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from torch.utils.data import (
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
BatchSampler,
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
DataLoader,
|
||||||
|
RandomSampler,
|
||||||
|
Sampler,
|
||||||
|
SequentialSampler,
|
||||||
|
)
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
|
||||||
from transformers.utils import is_sagemaker_mp_enabled
|
|
||||||
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
|
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
from axolotl.utils.schedulers import (
|
from axolotl.core.trainers.utils import (
|
||||||
get_cosine_schedule_with_min_lr,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
get_cosine_schedule_with_quadratic_warmup,
|
sanitize_kwargs_for_tagging,
|
||||||
get_cosine_schedule_with_warmup_decay_constant,
|
|
||||||
)
|
)
|
||||||
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
LOG = logging.getLogger(__name__)
|
||||||
import smdistributed.modelparallel.torch as smp
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
class AxolotlTrainer(TrainerMixins, Trainer):
|
||||||
if isinstance(tag_names, str):
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
tag_names = [tag_names]
|
|
||||||
|
|
||||||
if kwargs is not None:
|
|
||||||
if "tags" not in kwargs:
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
|
||||||
kwargs["tags"].extend(tag_names)
|
|
||||||
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
|
||||||
tag_names.append(kwargs["tags"])
|
|
||||||
kwargs["tags"] = tag_names
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
|
||||||
if isinstance(dataset_tags, str):
|
|
||||||
dataset_tags = [dataset_tags]
|
|
||||||
|
|
||||||
if (dataset_tags is not None) and (kwargs is not None):
|
|
||||||
if "dataset_tags" not in kwargs:
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
|
||||||
kwargs["dataset_tags"].extend(dataset_tags)
|
|
||||||
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
|
||||||
dataset_tags.append(kwargs["dataset_tags"])
|
|
||||||
kwargs["dataset_tags"] = dataset_tags
|
|
||||||
|
|
||||||
return kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerMixin(Trainer):
|
|
||||||
"""
|
|
||||||
Mixin class for scheduler setup in CausalTrainer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
|
||||||
passed as an argument.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_training_steps (int): The number of training steps to do.
|
|
||||||
optimizer (torch.optim.Optimizer): The training optimizer
|
|
||||||
"""
|
|
||||||
use_cosine_quadratic = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.lr_quadratic_warmup is True
|
|
||||||
)
|
|
||||||
|
|
||||||
use_cosine_min_lr = (
|
|
||||||
self.args.lr_scheduler_type == "cosine"
|
|
||||||
and self.args.cosine_min_lr_ratio is not None
|
|
||||||
)
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
|
||||||
# fmt: on
|
|
||||||
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
|
||||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
|
||||||
pct_start = num_warmup_steps / num_training_steps
|
|
||||||
extra_lr_kwargs = {}
|
|
||||||
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["pct_start"] = pct_start
|
|
||||||
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
|
||||||
extra_lr_kwargs["anneal_strategy"] = "cos"
|
|
||||||
|
|
||||||
self.lr_scheduler = OneCycleLR(
|
|
||||||
optimizer,
|
|
||||||
max_lr=self.args.learning_rate,
|
|
||||||
total_steps=num_training_steps,
|
|
||||||
**extra_lr_kwargs,
|
|
||||||
**self.args.lr_scheduler_kwargs,
|
|
||||||
)
|
|
||||||
elif use_cosine_quadratic:
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
|
||||||
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
|
||||||
)
|
|
||||||
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
|
||||||
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
|
||||||
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer,
|
|
||||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
|
||||||
num_training_steps=num_training_steps,
|
|
||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
|
||||||
else:
|
|
||||||
if use_cosine_quadratic:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
if use_cosine_min_lr:
|
|
||||||
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, Trainer):
|
|
||||||
"""
|
|
||||||
Extend the base Trainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
tag_names = ["axolotl"]
|
tag_names = ["axolotl"]
|
||||||
@@ -174,12 +55,16 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self._signature_columns = None # workaround for pylint
|
self._signature_columns = None # workaround for pylint
|
||||||
|
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
|
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
if self.args.orpo_alpha:
|
if self.args.orpo_alpha:
|
||||||
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
|
||||||
|
|
||||||
|
self.sequence_parallel_handler = SequenceParallelHandler(self.args)
|
||||||
|
|
||||||
def _wrap_model(self, model, training=True, dataloader=None):
|
def _wrap_model(self, model, training=True, dataloader=None):
|
||||||
if self.args.torch_compile:
|
if self.args.torch_compile:
|
||||||
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
|
||||||
@@ -192,316 +77,251 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
)
|
)
|
||||||
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
return super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||||
|
|
||||||
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
|
def _create_multipack_sampler(
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
self, base_sampler: Sampler, dataset: Dataset
|
||||||
|
) -> MultipackBatchSampler:
|
||||||
|
"""
|
||||||
|
Helper method to create a `MultipackBatchSampler` for multipacking sequences
|
||||||
|
for training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_sampler: Sampler to wrap with `MultipackBatchSampler`.
|
||||||
|
dataset: Dataset to sample from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Multipack (sample packing) batch sampler.
|
||||||
|
"""
|
||||||
|
if self.args.multipack_real_batches:
|
||||||
|
batch_size = self.args.per_device_train_batch_size
|
||||||
|
batch_max_len = self.args.max_seq_length
|
||||||
|
else:
|
||||||
|
batch_size = 1
|
||||||
|
train_batch_size = (
|
||||||
|
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||||
|
)
|
||||||
|
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||||
|
|
||||||
|
return MultipackBatchSampler(
|
||||||
|
base_sampler,
|
||||||
|
lengths=get_dataset_lengths(dataset),
|
||||||
|
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||||
|
batch_max_len=batch_max_len,
|
||||||
|
batch_size=batch_size,
|
||||||
|
sequential=self.args.sample_packing_sequentially,
|
||||||
|
drop_last=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_train_sampler(self) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Helper method to get the sampler for training. Handles cases for sequence
|
||||||
|
parallelism, sample packing, and curriculum sampling (sequential).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
|
depends on the passed training args.
|
||||||
|
"""
|
||||||
|
use_sample_packing = self.args.sample_packing and not self.args.pretraining
|
||||||
|
|
||||||
|
# Determine the base sampler first
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
base_sampler = self.sequence_parallel_handler._get_train_sampler(self.train_dataset)
|
||||||
|
elif self.args.curriculum_sampling:
|
||||||
|
base_sampler = SequentialSampler(self.train_dataset)
|
||||||
|
elif use_sample_packing:
|
||||||
|
base_sampler = RandomSampler(self.train_dataset)
|
||||||
|
else:
|
||||||
|
# Default to parent class implementation for standard random sampling
|
||||||
|
return super()._get_train_sampler()
|
||||||
|
|
||||||
|
# Apply multipack wrapper if needed
|
||||||
|
if use_sample_packing:
|
||||||
|
return self._create_multipack_sampler(
|
||||||
|
base_sampler=base_sampler,
|
||||||
|
dataset=self.train_dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
return base_sampler
|
||||||
|
|
||||||
|
def _get_eval_sampler(self, eval_dataset: Dataset | None = None) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Helper method to get the sampler for evaluation. Handles sequence parallelism
|
||||||
|
and sample packing cases.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the dataset is non-empty, a sampler is returned, the type of which
|
||||||
|
depends on the passed training args.
|
||||||
|
"""
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
|
# Multipacking enabled if training is enabled and eval is not explicitly disabled
|
||||||
|
use_multipack = (
|
||||||
|
self.args.sample_packing and self.args.eval_sample_packing is not False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine the base sampler
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
base_sampler = self.sequence_parallel_handler._get_eval_sampler(eval_dataset)
|
||||||
|
elif use_multipack:
|
||||||
|
base_sampler = SequentialSampler(eval_dataset)
|
||||||
|
else:
|
||||||
|
return super()._get_eval_sampler(eval_dataset)
|
||||||
|
|
||||||
|
# Apply multipack wrapper if needed
|
||||||
|
if use_multipack:
|
||||||
|
return self._create_multipack_sampler(
|
||||||
|
base_sampler=base_sampler,
|
||||||
|
dataset=eval_dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
return base_sampler
|
||||||
|
|
||||||
|
def _create_dataloader_params(self, is_eval=False, custom_batch_size=None):
|
||||||
|
"""Create common dataloader parameters for train or eval."""
|
||||||
|
batch_size = custom_batch_size or (
|
||||||
|
self.args.eval_batch_size if is_eval else self._train_batch_size
|
||||||
|
)
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"to_weight_decay": {}, # LayerNorm and bias
|
"batch_size": batch_size,
|
||||||
"embeddings": {}, # lm_head, embed_tokens,
|
"collate_fn": self.data_collator,
|
||||||
"no_weight_decay": {},
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
}
|
}
|
||||||
lr_groups_lookup = {}
|
|
||||||
lr_groups_learning_rates = {}
|
|
||||||
if self.args.lr_groups:
|
|
||||||
for lr_group in self.args.lr_groups:
|
|
||||||
group_name = lr_group["name"]
|
|
||||||
group_modules = lr_group["modules"]
|
|
||||||
for module in group_modules:
|
|
||||||
lr_groups_lookup[module] = group_name
|
|
||||||
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
|
||||||
params[f"to_weight_decay_{group_name}"] = {}
|
|
||||||
|
|
||||||
for name, param in opt_model.named_parameters():
|
# Add persistent workers only for training
|
||||||
if not param.requires_grad:
|
if not is_eval and hasattr(self.args, "dataloader_persistent_workers"):
|
||||||
continue
|
params["persistent_workers"] = self.args.dataloader_persistent_workers
|
||||||
if name.endswith("modules_to_save.default.weight") or any(
|
|
||||||
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
|
||||||
):
|
|
||||||
params["embeddings"][name] = param
|
|
||||||
elif name in decay_parameters:
|
|
||||||
lr_group_modules = [
|
|
||||||
group_modules
|
|
||||||
for group_modules in lr_groups_lookup
|
|
||||||
if group_modules in name
|
|
||||||
]
|
|
||||||
if lr_groups_lookup and any(lr_group_modules):
|
|
||||||
lr_group_module = lr_group_modules[0]
|
|
||||||
group_name = lr_groups_lookup[lr_group_module]
|
|
||||||
params[f"to_weight_decay_{group_name}"][name] = param
|
|
||||||
else:
|
|
||||||
params["to_weight_decay"][name] = param
|
|
||||||
else:
|
|
||||||
params["no_weight_decay"][name] = param
|
|
||||||
optimizer_grouped_parameters = []
|
|
||||||
if params["to_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["to_weight_decay"].values()),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["embeddings"]:
|
|
||||||
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
|
||||||
if self.args.embedding_lr_scale:
|
|
||||||
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
|
||||||
elif self.args.embedding_lr:
|
|
||||||
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["embeddings"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
if params["no_weight_decay"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(params["no_weight_decay"].values()),
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
"lr": optimizer_kwargs["lr"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for group_name, group_lr in lr_groups_learning_rates.items():
|
|
||||||
if params[f"to_weight_decay_{group_name}"]:
|
|
||||||
optimizer_grouped_parameters.append(
|
|
||||||
{
|
|
||||||
"params": list(
|
|
||||||
params[f"to_weight_decay_{group_name}"].values()
|
|
||||||
),
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
"lr": group_lr,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return optimizer_grouped_parameters
|
# Add prefetch factor if specified
|
||||||
|
if self.args.dataloader_prefetch_factor:
|
||||||
|
params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
def create_optimizer(self):
|
return params
|
||||||
if (
|
|
||||||
self.args.loraplus_lr_ratio is None
|
|
||||||
and self.args.embedding_lr_scale is None
|
|
||||||
and self.args.embedding_lr is None
|
|
||||||
and self.args.lr_groups is None
|
|
||||||
and self.args.alternate_optimizer
|
|
||||||
not in [
|
|
||||||
"optimi_adamw",
|
|
||||||
"ao_adamw_8bit",
|
|
||||||
"ao_adamw_4bit",
|
|
||||||
"ao_adamw_fp8",
|
|
||||||
"adopt_adamw",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
return super().create_optimizer()
|
|
||||||
|
|
||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
def _prepare_dataloader(
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
self, dataset, sampler, is_eval=False, custom_batch_size=None
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
):
|
||||||
self.args,
|
"""Prepare a dataloader with the given dataset and sampler."""
|
||||||
opt_model,
|
# Get base parameters
|
||||||
)
|
dataloader_params = self._create_dataloader_params(is_eval, custom_batch_size)
|
||||||
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
|
||||||
opt_model, optimizer_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
# Add sampler configuration
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
if not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||||
loraplus_lr_embedding = getattr(
|
|
||||||
self.args, "loraplus_lr_embedding", 1e-6
|
|
||||||
)
|
|
||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
opt_model,
|
|
||||||
optimizer_cls,
|
|
||||||
loraplus_lr_ratio=loraplus_lr_ratio,
|
|
||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
self.args.embedding_lr_scale is not None
|
|
||||||
or self.args.embedding_lr is not None
|
|
||||||
or self.args.lr_groups is not None
|
|
||||||
):
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
|
||||||
from optimi import AdamW
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
AdamW(
|
|
||||||
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "ao_adamw_4bit":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW4bit
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "ao_adamw_8bit":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamW8bit
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "ao_adamw_fp8":
|
|
||||||
from torchao.prototype.low_bit_optim import AdamWFp8
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
|
||||||
)
|
|
||||||
elif self.args.alternate_optimizer == "adopt_adamw":
|
|
||||||
from axolotl.utils.optimizers.adopt import ADOPT
|
|
||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
|
||||||
ADOPT(
|
|
||||||
optimizer_grouped_parameters,
|
|
||||||
decouple=True,
|
|
||||||
**optimizer_kwargs,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
|
||||||
self.optimizer
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.optimizer
|
|
||||||
|
|
||||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
|
||||||
if self.args.multipack_real_batches:
|
|
||||||
batch_size = self.args.per_device_train_batch_size
|
|
||||||
batch_max_len = self.args.max_seq_length
|
|
||||||
else:
|
|
||||||
batch_size = 1
|
|
||||||
train_batch_size = (
|
|
||||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
|
||||||
)
|
|
||||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
|
||||||
|
|
||||||
if self.args.curriculum_sampling:
|
|
||||||
sampler = SequentialSampler(self.train_dataset)
|
|
||||||
else:
|
|
||||||
sampler = RandomSampler(self.train_dataset)
|
|
||||||
|
|
||||||
return MultipackBatchSampler(
|
|
||||||
sampler,
|
|
||||||
lengths=get_dataset_lengths(self.train_dataset),
|
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
|
||||||
batch_max_len=batch_max_len,
|
|
||||||
batch_size=batch_size,
|
|
||||||
group_size=self.args.sample_packing_group_size,
|
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
if self.args.curriculum_sampling:
|
|
||||||
return SequentialSampler(self.train_dataset)
|
|
||||||
return super()._get_train_sampler()
|
|
||||||
|
|
||||||
def _get_eval_sampler(
|
|
||||||
self, eval_dataset: Dataset
|
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
|
||||||
if self.args.multipack_real_batches:
|
|
||||||
batch_size = self.args.per_device_eval_batch_size
|
|
||||||
batch_max_len = self.args.max_seq_length
|
|
||||||
else:
|
|
||||||
batch_size = 1
|
|
||||||
batch_max_len = (
|
|
||||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
|
||||||
)
|
|
||||||
return MultipackBatchSampler(
|
|
||||||
SequentialSampler(eval_dataset),
|
|
||||||
lengths=get_dataset_lengths(self.eval_dataset),
|
|
||||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
|
||||||
batch_max_len=batch_max_len,
|
|
||||||
batch_size=batch_size,
|
|
||||||
group_size=self.args.sample_packing_group_size,
|
|
||||||
bin_size=self.args.sample_packing_bin_size,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
return super()._get_eval_sampler(eval_dataset)
|
|
||||||
|
|
||||||
def get_train_dataloader(self) -> DataLoader:
|
|
||||||
if self.args.sample_packing and not self.args.pretraining:
|
|
||||||
train_dataset = self.train_dataset
|
|
||||||
if "length" in train_dataset.features.keys():
|
|
||||||
train_dataset = train_dataset.remove_columns(["length"])
|
|
||||||
data_collator = self.data_collator
|
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self._train_batch_size,
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
}
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params[
|
|
||||||
"prefetch_factor"
|
|
||||||
] = self.args.dataloader_prefetch_factor
|
|
||||||
|
|
||||||
sampler = self._get_train_sampler()
|
|
||||||
if isinstance(sampler, BatchSampler):
|
if isinstance(sampler, BatchSampler):
|
||||||
|
# batch_size and batch_sampler are mutually exclusive
|
||||||
dataloader_params["batch_sampler"] = sampler
|
dataloader_params["batch_sampler"] = sampler
|
||||||
del dataloader_params["batch_size"]
|
del dataloader_params["batch_size"]
|
||||||
else:
|
else:
|
||||||
dataloader_params["sampler"] = sampler
|
dataloader_params["sampler"] = sampler
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
dataloader_params["worker_init_fn"] = seed_worker
|
|
||||||
|
|
||||||
|
if not is_eval:
|
||||||
|
dataloader_params["worker_init_fn"] = seed_worker
|
||||||
|
|
||||||
|
# Create the dataloader
|
||||||
|
dataloader = DataLoader(dataset, **dataloader_params)
|
||||||
|
|
||||||
|
if self.args.sample_packing and (
|
||||||
|
(not is_eval and not self.args.pretraining)
|
||||||
|
or (is_eval and self.args.eval_sample_packing is not False)
|
||||||
|
):
|
||||||
self.accelerator.even_batches = False
|
self.accelerator.even_batches = False
|
||||||
return self.accelerator.prepare_data_loader(
|
|
||||||
DataLoader(train_dataset, **dataloader_params)
|
|
||||||
)
|
|
||||||
return super().get_train_dataloader()
|
|
||||||
|
|
||||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
# Return unprepared dataloader if using sequence parallelism
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
# Otherwise prepare with accelerator
|
||||||
|
dataloader = self.accelerator.prepare_data_loader(dataloader)
|
||||||
|
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_dataloader(self) -> DataLoader:
|
||||||
|
"""Get dataloader for training"""
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
data_collator = self.data_collator # type: ignore
|
||||||
|
|
||||||
|
# Handle dataset preprocessing
|
||||||
|
if isinstance(train_dataset, datasets.Dataset):
|
||||||
|
if self.args.sample_packing and not self.args.pretraining:
|
||||||
|
train_dataset = train_dataset.remove_columns(["length"])
|
||||||
|
if not self.args.sample_packing or self.args.pretraining:
|
||||||
|
train_dataset = self._remove_unused_columns(
|
||||||
|
train_dataset, description="training"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
|
data_collator,
|
||||||
|
description="training",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get sampler and create dataloader
|
||||||
|
sampler = self._get_train_sampler()
|
||||||
|
return self._prepare_dataloader(train_dataset, sampler, is_eval=False)
|
||||||
|
|
||||||
|
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
||||||
|
"""Get dataloader for evaluation"""
|
||||||
|
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||||
|
|
||||||
|
# Handle special case: sample packing is enabled but eval_sample_packing is False
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.eval_data_collator
|
self.eval_data_collator
|
||||||
)
|
)
|
||||||
if eval_dataset:
|
if "length" in eval_dataset.column_names:
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
self.train_data_collator
|
self.train_data_collator
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
||||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
# Handle sample packing or sequence parallelism
|
||||||
eval_dataset = (
|
if (
|
||||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
self.args.sample_packing
|
||||||
|
and self.args.eval_sample_packing is not False
|
||||||
|
or self.args.sequence_parallel_degree > 1
|
||||||
|
):
|
||||||
|
# Get appropriate data collator
|
||||||
|
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.eval_data_collator
|
||||||
|
if hasattr(self, "eval_data_collator") and self.eval_data_collator
|
||||||
|
else self.data_collator
|
||||||
|
)
|
||||||
|
if "length" in eval_dataset.column_names:
|
||||||
|
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||||
|
|
||||||
|
# Handle dataset preprocessing for SP
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
if isinstance(eval_dataset, datasets.Dataset):
|
||||||
|
eval_dataset = self._remove_unused_columns(
|
||||||
|
eval_dataset, description="evaluation"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data_collator = self._get_collator_with_removed_columns( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.data_collator, description="evaluation"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use eval_batch_size for sample packing, per_device_eval_batch_size otherwise
|
||||||
|
batch_size = (
|
||||||
|
self.args.eval_batch_size
|
||||||
|
if self.args.sample_packing
|
||||||
|
else self.args.per_device_eval_batch_size
|
||||||
|
)
|
||||||
|
sampler = self._get_eval_sampler(eval_dataset)
|
||||||
|
dataloader = self._prepare_dataloader(
|
||||||
|
eval_dataset, sampler, is_eval=True, custom_batch_size=batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
return dataloader
|
||||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
|
||||||
data_collator = self.data_collator
|
|
||||||
dataloader_params = {
|
|
||||||
"batch_size": self.args.eval_batch_size,
|
|
||||||
"collate_fn": data_collator,
|
|
||||||
"num_workers": self.args.dataloader_num_workers,
|
|
||||||
"pin_memory": self.args.dataloader_pin_memory,
|
|
||||||
}
|
|
||||||
if self.args.dataloader_prefetch_factor:
|
|
||||||
dataloader_params[
|
|
||||||
"prefetch_factor"
|
|
||||||
] = self.args.dataloader_prefetch_factor
|
|
||||||
|
|
||||||
if isinstance(eval_sampler, BatchSampler):
|
|
||||||
dataloader_params["batch_sampler"] = eval_sampler
|
|
||||||
del dataloader_params["batch_size"]
|
|
||||||
else:
|
|
||||||
dataloader_params["sampler"] = eval_sampler
|
|
||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
|
||||||
|
|
||||||
self.accelerator.even_batches = False
|
|
||||||
return self.accelerator.prepare_data_loader(
|
|
||||||
DataLoader(eval_dataset, **dataloader_params)
|
|
||||||
)
|
|
||||||
|
|
||||||
return super().get_eval_dataloader(eval_dataset)
|
return super().get_eval_dataloader(eval_dataset)
|
||||||
|
|
||||||
def _get_bench_sampler(
|
def _get_bench_sampler(
|
||||||
self, bench_dataset: Dataset
|
self, bench_dataset: Dataset
|
||||||
) -> Optional[torch.utils.data.Sampler]:
|
) -> torch.utils.data.Sampler | None:
|
||||||
if self.args.world_size <= 1:
|
if self.args.world_size <= 1:
|
||||||
return SequentialSampler(bench_dataset)
|
return SequentialSampler(bench_dataset)
|
||||||
return None
|
return None
|
||||||
@@ -524,8 +344,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||||
|
|
||||||
return DataLoader(bench_dataset, **dataloader_params)
|
return DataLoader(bench_dataset, **dataloader_params)
|
||||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
|
||||||
|
|
||||||
|
def training_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
|
num_items_in_batch: int | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Perform a training step on a batch of inputs. Overrides the
|
||||||
|
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||||
|
enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model to perform training step for.
|
||||||
|
inputs: Dictionary mapping of inputs.
|
||||||
|
num_items_in_batch: The number of items in the batch.
|
||||||
|
"""
|
||||||
|
# Set up sequence parallelism for this step if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
|
# Proceed with normal training step
|
||||||
|
return super().training_step(model, inputs, num_items_in_batch) # type: ignore
|
||||||
|
|
||||||
|
def prediction_step(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
inputs: dict[str, torch.Tensor | Any],
|
||||||
|
prediction_loss_only: bool,
|
||||||
|
ignore_keys: list[str] | None = None,
|
||||||
|
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
|
||||||
|
"""
|
||||||
|
Perform a prediction step on a batch of inputs. Overrides the
|
||||||
|
`transformers.trainer.Trainer` method to handle sequence parallelism if
|
||||||
|
enabled.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model to perform prediction step for.
|
||||||
|
inputs: Dictionary mapping of inputs.
|
||||||
|
prediction_loss_only: Whether to return only the loss.
|
||||||
|
ignore_keys: Keys to ignore in the inputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (loss, logits, labels).
|
||||||
|
"""
|
||||||
|
# Set up sequence parallelism for this prediction step if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self.sequence_parallel_handler._update_ring_flash_attn_params(inputs)
|
||||||
|
|
||||||
|
# Proceed with normal prediction step
|
||||||
|
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys) # type: ignore
|
||||||
|
|
||||||
|
@override
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||||
):
|
):
|
||||||
@@ -542,6 +413,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
return_outputs=return_outputs,
|
return_outputs=return_outputs,
|
||||||
num_items_in_batch=num_items_in_batch,
|
num_items_in_batch=num_items_in_batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
return super().compute_loss(
|
return super().compute_loss(
|
||||||
model,
|
model,
|
||||||
inputs,
|
inputs,
|
||||||
@@ -716,10 +588,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@@ -736,15 +608,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logs (`Dict[str, float]`):
|
logs: The values to log.
|
||||||
The values to log.
|
start_time: The start of training.
|
||||||
start_time (`Optional[float]`):
|
|
||||||
The start of training.
|
|
||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
@@ -756,7 +626,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
return super().log(logs, start_time)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
) -> None:
|
) -> None:
|
||||||
for key, value in metrics.items():
|
for key, value in metrics.items():
|
||||||
self._stored_metrics[train_eval][key].append(value)
|
self._stored_metrics[train_eval][key].append(value)
|
||||||
@@ -768,111 +638,3 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, **kwargs)
|
return super()._save_checkpoint(model, trial, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
|
||||||
"""
|
|
||||||
Mamba specific trainer to handle loss calculation
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "mamba"]
|
|
||||||
|
|
||||||
def compute_loss(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
inputs,
|
|
||||||
return_outputs=False, # pylint: disable=unused-argument
|
|
||||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
input_ids = inputs.pop("input_ids")
|
|
||||||
lm_logits = model(input_ids).logits
|
|
||||||
|
|
||||||
labels = input_ids.to(lm_logits.device)
|
|
||||||
shift_logits = lm_logits[:, :-1, :].contiguous()
|
|
||||||
labels = labels[:, 1:].contiguous()
|
|
||||||
|
|
||||||
loss_fct = torch.nn.CrossEntropyLoss()
|
|
||||||
lm_loss = loss_fct(
|
|
||||||
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
return lm_loss
|
|
||||||
|
|
||||||
|
|
||||||
class ReLoRATrainer(AxolotlTrainer):
|
|
||||||
"""
|
|
||||||
Trainer subclass that uses the OneCycleLR scheduler
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "relora"]
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.lr_scheduler = None
|
|
||||||
|
|
||||||
def create_scheduler(
|
|
||||||
self,
|
|
||||||
num_training_steps: int,
|
|
||||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
|
||||||
):
|
|
||||||
optimizer = self.optimizer if optimizer is None else optimizer
|
|
||||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
|
||||||
|
|
||||||
if self.args.relora_steps:
|
|
||||||
warmup_steps = (
|
|
||||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
|
||||||
)
|
|
||||||
anneal_steps = (
|
|
||||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
|
||||||
)
|
|
||||||
self.lr_scheduler = ReLoRAScheduler(
|
|
||||||
optimizer,
|
|
||||||
lr_scheduler,
|
|
||||||
self.args.relora_steps,
|
|
||||||
anneal_steps,
|
|
||||||
warmup_steps,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.lr_scheduler = lr_scheduler
|
|
||||||
|
|
||||||
return self.lr_scheduler
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
|
||||||
"""
|
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "prm"]
|
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
"""
|
"""DPO Specific Strategy for training"""
|
||||||
DPO Specific Strategy for training
|
|
||||||
"""
|
|
||||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||||
|
|
||||||
|
|
||||||
class DPOStrategy:
|
class DPOStrategy:
|
||||||
"""
|
"""Strategy for DPO training"""
|
||||||
Strategy for DPO training
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_trainer_class(cls):
|
def get_trainer_class(cls):
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""
|
"""Axolotl specific DPO args"""
|
||||||
Axolotl specific DPO args
|
|
||||||
"""
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from trl import DPOConfig
|
from trl import DPOConfig
|
||||||
@@ -10,6 +9,4 @@ from axolotl.core.training_args import AxolotlTrainingMixins
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
|
||||||
"""
|
"""DPO config for DPO training"""
|
||||||
DPO config for DPO training
|
|
||||||
"""
|
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
"""
|
"""DPO trainer for axolotl"""
|
||||||
DPO trainer for axolotl
|
|
||||||
"""
|
|
||||||
import gc
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Dict, Union
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from peft.optimizers import create_loraplus_optimizer
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
@@ -12,28 +10,29 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.base import (
|
from axolotl.core.trainers.handlers import SequenceParallelHandler
|
||||||
SchedulerMixin,
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
_sanitize_kwargs_for_ds_tagging,
|
from axolotl.core.trainers.utils import (
|
||||||
_sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
|
sanitize_kwargs_for_tagging,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
class AxolotlDPOTrainer(TrainerMixins, DPOTrainer):
|
||||||
"""
|
"""Extend the base DPOTrainer for axolotl helpers"""
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
def __init__(self, *args, dataset_tags=None, **kwargs):
|
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.dataset_tags = dataset_tags
|
self.dataset_tags = dataset_tags
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
self.model_accepts_loss_kwargs = False
|
self.model_accepts_loss_kwargs = False
|
||||||
|
self.sequence_parallel_handler = SequenceParallelHandler(args=self.args)
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -73,10 +72,10 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
kwargs = _sanitize_kwargs_for_ds_tagging(
|
kwargs = sanitize_kwargs_for_ds_tagging(
|
||||||
dataset_tags=self.dataset_tags, kwargs=kwargs
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
)
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
@@ -87,7 +86,7 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
max_prompt_length,
|
max_prompt_length,
|
||||||
max_completion_length,
|
max_completion_length,
|
||||||
add_special_tokens,
|
add_special_tokens,
|
||||||
) -> Dict:
|
) -> dict:
|
||||||
res = DPOTrainer.tokenize_row(
|
res = DPOTrainer.tokenize_row(
|
||||||
features,
|
features,
|
||||||
processing_class,
|
processing_class,
|
||||||
@@ -116,10 +115,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
def training_step(
|
def training_step(
|
||||||
self,
|
self,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
inputs: dict[str, torch.Tensor | Any | None],
|
||||||
num_items_in_batch=None,
|
num_items_in_batch=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
self.sequence_parallel_handler.prepare_for_training_step(self, inputs)
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
return super().training_step(model, inputs, num_items_in_batch)
|
||||||
return loss
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
from axolotl.utils.schemas.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -31,30 +32,60 @@ class GRPOStrategy:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def set_training_args_kwargs(cls, cfg):
|
def set_training_args_kwargs(cls, cfg):
|
||||||
grpo_args_kwargs = {}
|
grpo_args_kwargs = {}
|
||||||
if cfg.trl and cfg.trl.use_vllm:
|
|
||||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
if not hasattr(cfg, "trl") or not cfg.trl:
|
||||||
if cfg.trl and cfg.trl.vllm_device:
|
return grpo_args_kwargs
|
||||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
|
||||||
else:
|
trl: TRLConfig = cfg.trl # type: ignore
|
||||||
grpo_args_kwargs["vllm_device"] = "auto"
|
|
||||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
if trl.use_vllm:
|
||||||
grpo_args_kwargs[
|
grpo_args_kwargs["use_vllm"] = trl.use_vllm
|
||||||
"vllm_gpu_memory_utilization"
|
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
|
||||||
] = cfg.trl.vllm_gpu_memory_utilization
|
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
|
||||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
if trl.vllm_server_timeout:
|
||||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
|
||||||
if cfg.trl and cfg.trl.num_generations:
|
if trl.vllm_guided_decoding_regex:
|
||||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
grpo_args_kwargs["vllm_guided_decoding_regex"] = (
|
||||||
if cfg.trl and cfg.trl.sync_ref_model:
|
trl.vllm_guided_decoding_regex
|
||||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
)
|
||||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
|
||||||
grpo_args_kwargs[
|
if trl.num_generations:
|
||||||
"ref_model_mixup_alpha"
|
grpo_args_kwargs["num_generations"] = trl.num_generations
|
||||||
] = cfg.trl.ref_model_mixup_alpha
|
|
||||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
if trl.sync_ref_model:
|
||||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
grpo_args_kwargs["sync_ref_model"] = trl.sync_ref_model
|
||||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
|
||||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
if trl.ref_model_mixup_alpha:
|
||||||
|
grpo_args_kwargs["ref_model_mixup_alpha"] = trl.ref_model_mixup_alpha
|
||||||
|
|
||||||
|
if trl.ref_model_sync_steps:
|
||||||
|
grpo_args_kwargs["ref_model_sync_steps"] = trl.ref_model_sync_steps
|
||||||
|
|
||||||
|
grpo_args_kwargs["max_completion_length"] = trl.max_completion_length
|
||||||
|
grpo_args_kwargs["log_completions"] = trl.log_completions
|
||||||
|
|
||||||
|
if trl.reward_weights:
|
||||||
|
grpo_args_kwargs["reward_weights"] = trl.reward_weights
|
||||||
|
|
||||||
|
if trl.scale_rewards is not None:
|
||||||
|
grpo_args_kwargs["scale_rewards"] = trl.scale_rewards
|
||||||
|
|
||||||
|
if trl.temperature is not None:
|
||||||
|
grpo_args_kwargs["temperature"] = trl.temperature
|
||||||
|
if trl.top_p is not None:
|
||||||
|
grpo_args_kwargs["top_p"] = trl.top_p
|
||||||
|
if trl.top_k is not None:
|
||||||
|
grpo_args_kwargs["top_k"] = trl.top_k
|
||||||
|
if trl.min_p is not None:
|
||||||
|
grpo_args_kwargs["min_p"] = trl.min_p
|
||||||
|
if trl.repetition_penalty is not None:
|
||||||
|
grpo_args_kwargs["repetition_penalty"] = trl.repetition_penalty
|
||||||
|
|
||||||
|
if trl.num_iterations is not None:
|
||||||
|
grpo_args_kwargs["num_iterations"] = trl.num_iterations
|
||||||
|
if trl.epsilon is not None:
|
||||||
|
grpo_args_kwargs["epsilon"] = trl.epsilon
|
||||||
|
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -71,9 +102,9 @@ class GRPOStrategy:
|
|||||||
def set_trainer_kwargs(cls, cfg):
|
def set_trainer_kwargs(cls, cfg):
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||||
trainer_kwargs[
|
trainer_kwargs["reward_processing_classes"] = (
|
||||||
"reward_processing_classes"
|
cfg.trl.reward_processing_classes
|
||||||
] = cfg.trl.reward_processing_classes
|
)
|
||||||
return trainer_kwargs
|
return trainer_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
Axolotl Specific Training Args
|
Axolotl Specific Training Args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from trl import GRPOConfig
|
from trl import GRPOConfig
|
||||||
|
|||||||
@@ -1,108 +1,65 @@
|
|||||||
"""
|
"""Axolotl GRPO trainer"""
|
||||||
Axolotl GRPO trainer
|
|
||||||
"""
|
|
||||||
from accelerate.utils import is_peft_model
|
|
||||||
from accelerate.utils.other import is_compiled_module
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
from trl import GRPOConfig, GRPOTrainer
|
|
||||||
from trl.models import unwrap_model_for_generation
|
|
||||||
|
|
||||||
from axolotl.core.trainers.base import SchedulerMixin
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||||
|
from trl import GRPOTrainer
|
||||||
|
from trl.extras.profiling import profiling_decorator
|
||||||
|
|
||||||
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
|
|
||||||
|
if is_deepspeed_available():
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
|
|
||||||
# mypy: ignore-errors
|
class AxolotlGRPOTrainer(TrainerMixins, GRPOTrainer):
|
||||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||||
"""
|
|
||||||
Extend the base GRPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
@profiling_decorator
|
||||||
super().__init__(*args, **kwargs)
|
def _move_model_to_vllm(self):
|
||||||
|
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
|
||||||
# pylint: disable=access-member-before-definition
|
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||||
# Enable gradient checkpointing if requested
|
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
|
||||||
if kwargs["args"].gradient_checkpointing:
|
gather_if_zero3 = (
|
||||||
# Ensure use_cache is disabled
|
deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
|
||||||
if hasattr(self.model, "config"):
|
|
||||||
self.model.config.use_cache = False
|
|
||||||
|
|
||||||
# Enable gradient checkpointing on the base model for PEFT
|
|
||||||
if is_peft_model(self.model) and hasattr(
|
|
||||||
self.model.base_model, "gradient_checkpointing_enable"
|
|
||||||
):
|
|
||||||
self.model.base_model.gradient_checkpointing_enable()
|
|
||||||
# Enable gradient checkpointing for non-PEFT models
|
|
||||||
elif hasattr(self.model, "gradient_checkpointing_enable"):
|
|
||||||
self.model.gradient_checkpointing_enable()
|
|
||||||
self.model = self._enable_gradient_checkpointing(self.model, kwargs["args"])
|
|
||||||
# pylint: enable=access-member-before-definition
|
|
||||||
|
|
||||||
def _enable_gradient_checkpointing(
|
|
||||||
self, model: PreTrainedModel, args: GRPOConfig
|
|
||||||
) -> PreTrainedModel:
|
|
||||||
"""Enables gradient checkpointing for the model."""
|
|
||||||
# pylint: disable=unused-argument,redefined-builtin
|
|
||||||
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
|
||||||
use_reentrant = (
|
|
||||||
"use_reentrant" not in gradient_checkpointing_kwargs
|
|
||||||
or gradient_checkpointing_kwargs["use_reentrant"]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_reentrant:
|
if is_peft_model(self.model):
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
|
||||||
model.enable_input_require_grads()
|
# adapters in a sharded manner is not supported.
|
||||||
else:
|
with gather_if_zero3(list(self.model.parameters())):
|
||||||
|
self.model.merge_adapter()
|
||||||
|
|
||||||
def make_inputs_require_grad(module, input, output):
|
# Update vLLM weights while parameters are gathered
|
||||||
output.requires_grad_(True)
|
for name, param in self.model.named_parameters():
|
||||||
|
# When using PEFT, we need to recover the original parameter name and discard some parameters
|
||||||
|
name = (
|
||||||
|
name.removeprefix("base_model.model.")
|
||||||
|
.removeprefix("base_model.model.")
|
||||||
|
.replace(".base_layer", "")
|
||||||
|
)
|
||||||
|
if self.model.prefix in name:
|
||||||
|
continue
|
||||||
|
# When module to save, remove its prefix and discard the original module
|
||||||
|
if "original_module" in name:
|
||||||
|
continue
|
||||||
|
name = name.replace("modules_to_save.default.", "")
|
||||||
|
|
||||||
model.get_input_embeddings().register_forward_hook(
|
if self.accelerator.is_main_process:
|
||||||
make_inputs_require_grad
|
self.vllm_client.update_named_param(name, param.data)
|
||||||
)
|
|
||||||
|
|
||||||
return model
|
# Unmerge adapters while parameters are still gathered
|
||||||
# pylint: enable=unused-argument,redefined-builtin
|
self.model.unmerge_adapter()
|
||||||
|
# Parameters will automatically be repartitioned when exiting the context
|
||||||
|
else:
|
||||||
|
# For non-PEFT models, simply gather and update each parameter individually.
|
||||||
|
for name, param in self.model.named_parameters():
|
||||||
|
with gather_if_zero3([param]):
|
||||||
|
if self.accelerator.is_main_process:
|
||||||
|
self.vllm_client.update_named_param(name, param.data)
|
||||||
|
|
||||||
def _move_model_to_vllm(self):
|
# Reset cache on main process
|
||||||
with unwrap_model_for_generation(
|
if self.accelerator.is_main_process:
|
||||||
self.model,
|
self.vllm_client.reset_prefix_cache()
|
||||||
self.accelerator,
|
|
||||||
gather_deepspeed3_params=self.args.ds3_gather_for_generation,
|
|
||||||
) as unwrapped_model:
|
|
||||||
if is_compiled_module(unwrapped_model):
|
|
||||||
unwrapped_model = (
|
|
||||||
unwrapped_model._orig_mod # pylint: disable=protected-access
|
|
||||||
)
|
|
||||||
if is_peft_model(unwrapped_model):
|
|
||||||
unwrapped_model.merge_adapter()
|
|
||||||
state_dict = unwrapped_model.state_dict()
|
|
||||||
# Remove base_model and base_layer prefixes
|
|
||||||
state_dict = {
|
|
||||||
k.removeprefix("base_model.model.")
|
|
||||||
.removeprefix("base_model.model.")
|
|
||||||
.replace(".base_layer", ""): v
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
}
|
|
||||||
# Remove values with adapter prefix (example: "_lora")
|
|
||||||
state_dict = {
|
|
||||||
k: v
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
if unwrapped_model.prefix not in k
|
|
||||||
}
|
|
||||||
# When module to save, remove its prefix and discard the original module
|
|
||||||
state_dict = {
|
|
||||||
k.replace("modules_to_save.default.", ""): v
|
|
||||||
for k, v in state_dict.items()
|
|
||||||
if "original_module" not in k
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
state_dict = unwrapped_model.state_dict()
|
|
||||||
if self.accelerator.is_main_process:
|
|
||||||
llm_model = (
|
|
||||||
self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
|
||||||
)
|
|
||||||
llm_model.load_weights(state_dict.items())
|
|
||||||
if is_peft_model(unwrapped_model):
|
|
||||||
unwrapped_model.unmerge_adapter()
|
|
||||||
|
|||||||
3
src/axolotl/core/trainers/handlers/__init__.py
Normal file
3
src/axolotl/core/trainers/handlers/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""Init for trainer handlers"""
|
||||||
|
|
||||||
|
from axolotl.core.trainers.handlers.sequence_parallel import SequenceParallelHandler
|
||||||
123
src/axolotl/core/trainers/handlers/sequence_parallel.py
Normal file
123
src/axolotl/core/trainers/handlers/sequence_parallel.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""Handler class for sequence parallel trainer logic"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DistributedSampler
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelHandler:
|
||||||
|
"""
|
||||||
|
Handler class that encapsulates sequence parallelism functionality.
|
||||||
|
This replaces the SequenceParallelMixin with a composition-based approach.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, args=None):
|
||||||
|
"""
|
||||||
|
Initialize the sequence parallel handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: The arguments object containing sequence parallelism settings.
|
||||||
|
"""
|
||||||
|
self.args = args
|
||||||
|
self.ring_attn_group = None
|
||||||
|
|
||||||
|
# Set up sequence parallelism if enabled
|
||||||
|
if self.args.sequence_parallel_degree > 1:
|
||||||
|
self._setup_sequence_parallel()
|
||||||
|
|
||||||
|
def _setup_sequence_parallel(self):
|
||||||
|
"""Set up sequence parallelism environment."""
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
self.update_ring_flash_attn_params = update_ring_flash_attn_params
|
||||||
|
self.ring_attn_group = get_ring_attn_group()
|
||||||
|
|
||||||
|
def create_sequence_parallel_sampler(
|
||||||
|
self,
|
||||||
|
dataset,
|
||||||
|
shuffle=True,
|
||||||
|
is_eval=False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Helper method to create sampler for sequence parallelism (SP).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Dataset to sample from.
|
||||||
|
shuffle: Whether to shuffle the dataset.
|
||||||
|
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Distributed sampler.
|
||||||
|
"""
|
||||||
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||||
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
return DistributedSampler(
|
||||||
|
dataset,
|
||||||
|
num_replicas=num_sp_groups,
|
||||||
|
rank=sp_group_id,
|
||||||
|
seed=self.args.seed if shuffle else None,
|
||||||
|
shuffle=shuffle,
|
||||||
|
drop_last=not is_eval,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_train_sampler(self, dataset):
|
||||||
|
"""
|
||||||
|
Get a training sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The training dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self.create_sequence_parallel_sampler(
|
||||||
|
dataset,
|
||||||
|
shuffle=not self.args.curriculum_sampling,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_eval_sampler(self, eval_dataset):
|
||||||
|
"""
|
||||||
|
Get an evaluation sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_dataset: The evaluation dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self.create_sequence_parallel_sampler(
|
||||||
|
eval_dataset, shuffle=False, is_eval=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_ring_flash_attn_params(self, inputs):
|
||||||
|
"""
|
||||||
|
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||||
|
the substituted ring_flash_attn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Current batch of inputs.
|
||||||
|
"""
|
||||||
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
|
# parallel data collator
|
||||||
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
packed_seq_lens = [seq_len] * batch_size
|
||||||
|
|
||||||
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
|
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
cu_seqlens = torch.cumsum(
|
||||||
|
torch.tensor(
|
||||||
|
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(
|
||||||
|
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||||
32
src/axolotl/core/trainers/mamba.py
Normal file
32
src/axolotl/core/trainers/mamba.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Module for mamba trainer"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
|
"""Mamba specific trainer to handle loss calculation"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "mamba"]
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
|
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
input_ids = inputs.pop("input_ids")
|
||||||
|
lm_logits = model(input_ids).logits
|
||||||
|
|
||||||
|
labels = input_ids.to(lm_logits.device)
|
||||||
|
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss()
|
||||||
|
lm_loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return lm_loss
|
||||||
14
src/axolotl/core/trainers/mixins/__init__.py
Normal file
14
src/axolotl/core/trainers/mixins/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""Init for axolotl.core.trainers.mixins"""
|
||||||
|
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from axolotl.core.trainers.mixins.optimizer import OptimizerMixin
|
||||||
|
from axolotl.core.trainers.mixins.rng_state_loader import RngLoaderMixin
|
||||||
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerMixins(
|
||||||
|
OptimizerMixin, RngLoaderMixin, SchedulerMixin
|
||||||
|
):
|
||||||
|
"""Stub class combining all mixins for Axolotl trainers."""
|
||||||
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
201
src/axolotl/core/trainers/mixins/optimizer.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
"""Module for Axolotl trainer optimizer mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
|
from torch import nn
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BaseOptimizerFactory
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerMixin(Trainer):
|
||||||
|
"""Mixin class for shared handling of building custom optimizers"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_optimizer_grouped_parameters(
|
||||||
|
self, opt_model, optimizer_kwargs
|
||||||
|
) -> list[dict]:
|
||||||
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
|
params: dict = {
|
||||||
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
|
"no_weight_decay": {},
|
||||||
|
}
|
||||||
|
lr_groups_lookup = {}
|
||||||
|
lr_groups_learning_rates = {}
|
||||||
|
if self.args.lr_groups:
|
||||||
|
for lr_group in self.args.lr_groups:
|
||||||
|
group_name = lr_group["name"]
|
||||||
|
group_modules = lr_group["modules"]
|
||||||
|
for module in group_modules:
|
||||||
|
lr_groups_lookup[module] = group_name
|
||||||
|
lr_groups_learning_rates[group_name] = lr_group["lr"]
|
||||||
|
params[f"to_weight_decay_{group_name}"] = {}
|
||||||
|
|
||||||
|
for name, param in opt_model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
lr_group_modules = [
|
||||||
|
group_modules
|
||||||
|
for group_modules in lr_groups_lookup
|
||||||
|
if group_modules in name
|
||||||
|
]
|
||||||
|
if lr_groups_lookup and any(lr_group_modules):
|
||||||
|
lr_group_module = lr_group_modules[0]
|
||||||
|
group_name = lr_groups_lookup[lr_group_module]
|
||||||
|
params[f"to_weight_decay_{group_name}"][name] = param
|
||||||
|
else:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for group_name, group_lr in lr_groups_learning_rates.items():
|
||||||
|
if params[f"to_weight_decay_{group_name}"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(
|
||||||
|
params[f"to_weight_decay_{group_name}"].values()
|
||||||
|
),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": group_lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return optimizer_grouped_parameters
|
||||||
|
|
||||||
|
def create_optimizer(self):
|
||||||
|
if (
|
||||||
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.embedding_lr_scale is None
|
||||||
|
and self.args.embedding_lr is None
|
||||||
|
and self.args.lr_groups is None
|
||||||
|
and self.optimizer_cls_and_kwargs is None
|
||||||
|
):
|
||||||
|
return super().create_optimizer()
|
||||||
|
|
||||||
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
|
|
||||||
|
if (
|
||||||
|
not self.optimizer
|
||||||
|
and self.optimizer_cls_and_kwargs is not None
|
||||||
|
and issubclass(self.optimizer_cls_and_kwargs[0], BaseOptimizerFactory)
|
||||||
|
):
|
||||||
|
optimizer_factory_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
self.optimizer = optimizer_factory_cls()(
|
||||||
|
opt_model, self.args, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.optimizer:
|
||||||
|
if self.optimizer_cls_and_kwargs is not None:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
|
||||||
|
else:
|
||||||
|
optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(
|
||||||
|
self.args, opt_model
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
|
||||||
|
opt_model, optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
|
loraplus_lr_embedding = getattr(
|
||||||
|
self.args, "loraplus_lr_embedding", 1e-6
|
||||||
|
)
|
||||||
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
opt_model,
|
||||||
|
optimizer_cls,
|
||||||
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
|
**optimizer_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for GaLore optimizer.
|
||||||
|
if "params" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("params")
|
||||||
|
|
||||||
|
# Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
|
||||||
|
# e.g. for LOMO optimizer.
|
||||||
|
if "model" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop("model")
|
||||||
|
|
||||||
|
# For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
|
||||||
|
# to avoid arguments conflicts.
|
||||||
|
if "optimizer_dict" in optimizer_kwargs:
|
||||||
|
optimizer_grouped_parameters = optimizer_kwargs.pop(
|
||||||
|
"optimizer_dict"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer = optimizer_cls(
|
||||||
|
optimizer_grouped_parameters, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if optimizer_cls.__name__ == "Adam8bit":
|
||||||
|
import bitsandbytes
|
||||||
|
|
||||||
|
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
|
||||||
|
|
||||||
|
skipped = 0
|
||||||
|
for module in opt_model.modules():
|
||||||
|
if isinstance(module, nn.Embedding):
|
||||||
|
skipped += sum(
|
||||||
|
{
|
||||||
|
p.data_ptr(): p.numel() for p in module.parameters()
|
||||||
|
}.values()
|
||||||
|
)
|
||||||
|
LOG.info(f"skipped {module}: {skipped/2**20}M params")
|
||||||
|
manager.register_module_override(
|
||||||
|
module, "weight", {"optim_bits": 32}
|
||||||
|
)
|
||||||
|
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
|
||||||
|
LOG.info(f"skipped: {skipped/2**20}M params")
|
||||||
|
|
||||||
|
if is_sagemaker_mp_enabled():
|
||||||
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
|
self.optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.optimizer
|
||||||
65
src/axolotl/core/trainers/mixins/rng_state_loader.py
Normal file
65
src/axolotl/core/trainers/mixins/rng_state_loader.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""
|
||||||
|
Temporary fix/override for bug in resume from checkpoint
|
||||||
|
|
||||||
|
See https://github.com/huggingface/transformers/pull/37162
|
||||||
|
|
||||||
|
TODO: Remove when upstream added PR to release
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import Trainer, is_torch_npu_available
|
||||||
|
from transformers.trainer import safe_globals
|
||||||
|
from transformers.trainer_pt_utils import set_rng_state_for_device
|
||||||
|
from transformers.training_args import ParallelMode
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RngLoaderMixin(Trainer):
|
||||||
|
"""Mixin for method override to load RNG states from a checkpoint"""
|
||||||
|
|
||||||
|
def _load_rng_state(self, checkpoint):
|
||||||
|
# Load RNG states from `checkpoint`
|
||||||
|
if checkpoint is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.args.world_size > 1:
|
||||||
|
process_index = self.args.process_index
|
||||||
|
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
|
||||||
|
if not os.path.isfile(rng_file):
|
||||||
|
LOG.info(
|
||||||
|
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
|
||||||
|
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
rng_file = os.path.join(checkpoint, "rng_state.pth")
|
||||||
|
if not os.path.isfile(rng_file):
|
||||||
|
LOG.info(
|
||||||
|
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
|
||||||
|
"fashion, reproducibility is not guaranteed."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
|
||||||
|
# which requires allowlisted classes when loading with weights_only=True.
|
||||||
|
with safe_globals():
|
||||||
|
checkpoint_rng_state = torch.load(rng_file) # nosec B614
|
||||||
|
random.setstate(checkpoint_rng_state["python"])
|
||||||
|
np.random.set_state(checkpoint_rng_state["numpy"])
|
||||||
|
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
|
||||||
|
|
||||||
|
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
set_rng_state_for_device(
|
||||||
|
"CUDA", torch.cuda, checkpoint_rng_state, is_distributed
|
||||||
|
)
|
||||||
|
if is_torch_npu_available():
|
||||||
|
set_rng_state_for_device(
|
||||||
|
"NPU", torch.npu, checkpoint_rng_state, is_distributed
|
||||||
|
)
|
||||||
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
113
src/axolotl/core/trainers/mixins/scheduler.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""Module for Axolotl trainer scheduler mixin"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
from axolotl.utils.schedulers import (
|
||||||
|
RexLR,
|
||||||
|
get_cosine_schedule_with_min_lr,
|
||||||
|
get_cosine_schedule_with_quadratic_warmup,
|
||||||
|
get_cosine_schedule_with_warmup_decay_constant,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerMixin(Trainer):
|
||||||
|
"""
|
||||||
|
Mixin class for scheduler setup in CausalTrainer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||||
|
passed as an argument.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_training_steps (int): The number of training steps to do.
|
||||||
|
optimizer (torch.optim.Optimizer): The training optimizer
|
||||||
|
"""
|
||||||
|
use_cosine_quadratic = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.lr_quadratic_warmup is True
|
||||||
|
)
|
||||||
|
|
||||||
|
use_cosine_min_lr = (
|
||||||
|
self.args.lr_scheduler_type == "cosine"
|
||||||
|
and self.args.cosine_min_lr_ratio is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||||
|
# fmt: on
|
||||||
|
if self.args.alternate_lr_scheduler_type == "one_cycle":
|
||||||
|
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||||
|
pct_start = num_warmup_steps / num_training_steps
|
||||||
|
extra_lr_kwargs = {}
|
||||||
|
if "pct_start" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["pct_start"] = pct_start
|
||||||
|
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
|
||||||
|
extra_lr_kwargs["anneal_strategy"] = "cos"
|
||||||
|
|
||||||
|
self.lr_scheduler = OneCycleLR(
|
||||||
|
optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
**extra_lr_kwargs,
|
||||||
|
**self.args.lr_scheduler_kwargs,
|
||||||
|
)
|
||||||
|
elif self.args.alternate_lr_scheduler_type == "rex":
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
|
||||||
|
self.lr_scheduler = RexLR(
|
||||||
|
optimizer=optimizer,
|
||||||
|
max_lr=self.args.learning_rate,
|
||||||
|
min_lr=0 if not use_cosine_min_lr else (self.args.learning_rate * self.args.cosine_min_lr_ratio),
|
||||||
|
total_steps=num_training_steps,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
)
|
||||||
|
elif use_cosine_quadratic:
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
|
||||||
|
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
|
||||||
|
)
|
||||||
|
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
|
||||||
|
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
|
||||||
|
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer,
|
||||||
|
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||||
|
num_training_steps=num_training_steps,
|
||||||
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
|
else:
|
||||||
|
if use_cosine_quadratic:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
if use_cosine_min_lr:
|
||||||
|
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
132
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
132
src/axolotl/core/trainers/mixins/sequence_parallel.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
"""Module for Axolotl trainer sequence parallelism mixin"""
|
||||||
|
# TODO(Dan): remove
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from datasets import Dataset
|
||||||
|
from torch.utils.data import DistributedSampler, Sampler
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import get_ring_attn_group
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ring_flash_attn import update_ring_flash_attn_params
|
||||||
|
except ImportError:
|
||||||
|
# We pass silently here, but raise an ImportError in our Axolotl config validation
|
||||||
|
# if cfg.sequence_parallel_degree > 1 and `ring-flash-attn` is not installed.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceParallelMixin:
|
||||||
|
"""
|
||||||
|
Mixin class for sequence parallelism support in trainers.
|
||||||
|
|
||||||
|
This mixin provides functionality for handling sequence parallelism,
|
||||||
|
including creating appropriate samplers, managing data partitioning,
|
||||||
|
and updating ring flash attention parameters during training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|
||||||
|
def _setup_sequence_parallel(self):
|
||||||
|
"""Set up sequence parallelism environment."""
|
||||||
|
self.ring_attn_group = get_ring_attn_group()
|
||||||
|
|
||||||
|
def _create_sequence_parallel_sampler(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
shuffle: bool = True,
|
||||||
|
is_eval: bool = False,
|
||||||
|
) -> DistributedSampler:
|
||||||
|
"""
|
||||||
|
Helper method to create sampler for sequence parallelism (SP).
|
||||||
|
|
||||||
|
We create a distributed sampler with rank equal to the SP group ID, which
|
||||||
|
means that all ranks in the SP group receive the same sample / set of samples
|
||||||
|
per training step. We also set the number of replicas equal to the number of
|
||||||
|
SP groups, which is a bit of a hack / unintended use, but works!
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: Dataset to sample from.
|
||||||
|
shuffle: Whether to shuffle the dataset.
|
||||||
|
is_eval: Whether we are creating a sampler for evaluation or training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Distributed sampler.
|
||||||
|
"""
|
||||||
|
num_sp_groups = self.args.world_size // self.args.sequence_parallel_degree
|
||||||
|
sp_group_id = dist.get_rank() // self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
return DistributedSampler(
|
||||||
|
dataset,
|
||||||
|
num_replicas=num_sp_groups,
|
||||||
|
rank=sp_group_id,
|
||||||
|
seed=self.args.seed if shuffle else None,
|
||||||
|
shuffle=shuffle,
|
||||||
|
drop_last=not is_eval,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_train_sampler(self, dataset) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Get a training sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: The training dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self._create_sequence_parallel_sampler(
|
||||||
|
dataset,
|
||||||
|
shuffle=not self.args.curriculum_sampling,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_eval_sampler(self, eval_dataset) -> Sampler | None:
|
||||||
|
"""
|
||||||
|
Get an evaluation sampler configured for sequence parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eval_dataset: The evaluation dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured sequence parallel sampler.
|
||||||
|
"""
|
||||||
|
return self._create_sequence_parallel_sampler(
|
||||||
|
eval_dataset, shuffle=False, is_eval=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_ring_flash_attn_params(self, inputs: dict[str, torch.Tensor | Any]):
|
||||||
|
"""
|
||||||
|
Calculate the cu_seqlens for the current forward pass and pass the value to
|
||||||
|
the substituted ring_flash_attn. This is accomplished by using the passed
|
||||||
|
`input_ids`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Current batch of inputs.
|
||||||
|
"""
|
||||||
|
# At this point, inputs should already be partitioned by the sequence
|
||||||
|
# parallel data collator
|
||||||
|
batch_size = inputs["input_ids"].shape[0]
|
||||||
|
seq_len = inputs["input_ids"].shape[1]
|
||||||
|
packed_seq_lens = [seq_len] * batch_size
|
||||||
|
|
||||||
|
# Calculate the full sequence length across all GPUs in this SP group
|
||||||
|
total_seq_len = seq_len * self.args.sequence_parallel_degree
|
||||||
|
|
||||||
|
cu_seqlens = torch.cumsum(
|
||||||
|
torch.tensor(
|
||||||
|
packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
cu_seqlens = F.pad(
|
||||||
|
F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len
|
||||||
|
)
|
||||||
|
|
||||||
|
update_ring_flash_attn_params(cu_seqlens, self.ring_attn_group)
|
||||||
43
src/axolotl/core/trainers/relora.py
Normal file
43
src/axolotl/core/trainers/relora.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""Module for ReLoRA trainer"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
from axolotl.monkeypatch.relora import ReLoRAScheduler
|
||||||
|
|
||||||
|
|
||||||
|
class ReLoRATrainer(AxolotlTrainer):
|
||||||
|
"""Trainer subclass that uses the `OneCycleLR` scheduler"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "relora"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.lr_scheduler = None
|
||||||
|
|
||||||
|
def create_scheduler(
|
||||||
|
self,
|
||||||
|
num_training_steps: int,
|
||||||
|
optimizer: torch.optim.Optimizer | None = None,
|
||||||
|
):
|
||||||
|
optimizer = self.optimizer if optimizer is None else optimizer
|
||||||
|
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||||
|
|
||||||
|
if self.args.relora_steps:
|
||||||
|
warmup_steps = (
|
||||||
|
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||||
|
)
|
||||||
|
anneal_steps = (
|
||||||
|
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||||
|
)
|
||||||
|
self.lr_scheduler = ReLoRAScheduler(
|
||||||
|
optimizer,
|
||||||
|
lr_scheduler,
|
||||||
|
self.args.relora_steps,
|
||||||
|
anneal_steps,
|
||||||
|
warmup_steps,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.lr_scheduler = lr_scheduler
|
||||||
|
|
||||||
|
return self.lr_scheduler
|
||||||
@@ -1,15 +1,25 @@
|
|||||||
"""
|
"""Module for TRL PPO trainer"""
|
||||||
module for TRL PPO training
|
|
||||||
"""
|
from typing import Literal, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from trl import PPOTrainer
|
from trl import (
|
||||||
|
CPOTrainer,
|
||||||
|
KTOTrainer,
|
||||||
|
ORPOTrainer,
|
||||||
|
PPOTrainer,
|
||||||
|
PRMTrainer,
|
||||||
|
RewardTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.core.trainers.mixins import TrainerMixins
|
||||||
|
|
||||||
|
|
||||||
class TRLPPOTrainer(PPOTrainer):
|
class AxolotlPPOTrainer(TrainerMixins, PPOTrainer):
|
||||||
"""
|
"""Wrapper for TRL PPO trainer to handle customizations"""
|
||||||
wrapper for ppo trainer to handle customizations
|
|
||||||
"""
|
tag_names = ["axolotl", "ppo"]
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
@@ -30,9 +40,7 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
"batch_size": 16,
|
"batch_size": 16,
|
||||||
}
|
}
|
||||||
|
|
||||||
for epoch, batch in tqdm( # pylint: disable=unused-variable
|
for _, batch in tqdm(enumerate(self.dataloader)):
|
||||||
enumerate(self.dataloader)
|
|
||||||
):
|
|
||||||
query_tensors = batch["input_ids"]
|
query_tensors = batch["input_ids"]
|
||||||
|
|
||||||
# generate model response
|
# generate model response
|
||||||
@@ -64,3 +72,179 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
rewards,
|
rewards,
|
||||||
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
columns_to_log=["query", "response", "ref_response", "ref_rewards"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlORPOTrainer(TrainerMixins, ORPOTrainer):
|
||||||
|
"""Extend the base ORPOTrainer for axolotl helpers"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
def get_batch_loss_metrics(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
batch: dict[str, Union[list, torch.LongTensor]],
|
||||||
|
train_eval: Literal["train", "eval"] = "train",
|
||||||
|
):
|
||||||
|
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
||||||
|
|
||||||
|
# TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
|
||||||
|
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
forward_output = self.concatenated_forward(model, batch)
|
||||||
|
(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits,
|
||||||
|
policy_rejected_logits,
|
||||||
|
policy_nll_loss,
|
||||||
|
) = forward_output[:5]
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
aux_loss = forward_output[5]
|
||||||
|
|
||||||
|
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
|
||||||
|
self.odds_ratio_loss(policy_chosen_logps, policy_rejected_logps)
|
||||||
|
)
|
||||||
|
# full ORPO loss
|
||||||
|
loss = policy_nll_loss - losses.mean()
|
||||||
|
|
||||||
|
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||||
|
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(
|
||||||
|
chosen_rewards
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(
|
||||||
|
rejected_rewards
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(
|
||||||
|
reward_accuracies
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics(
|
||||||
|
chosen_rewards - rejected_rewards
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}logps/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logps/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics(
|
||||||
|
policy_rejected_logits.detach().mean()
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics(
|
||||||
|
policy_chosen_logits.detach().mean()
|
||||||
|
).mean()
|
||||||
|
metrics[f"{prefix}nll_loss"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}log_odds_ratio"] = (
|
||||||
|
self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}log_odds_chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean()
|
||||||
|
)
|
||||||
|
for k, v in metrics.items():
|
||||||
|
metrics[k] = v.item()
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
loss += self.aux_loss_coef * aux_loss
|
||||||
|
|
||||||
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKTOTrainer(TrainerMixins, KTOTrainer):
|
||||||
|
"""Extend the base KTOTrainer for axolotl helpers"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlCPOTrainer(TrainerMixins, CPOTrainer):
|
||||||
|
"""Extend the base CPOTrainer for axolotl helpers"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
def get_batch_loss_metrics(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
batch: dict[str, Union[list, torch.LongTensor]],
|
||||||
|
train_eval: Literal["train", "eval"] = "train",
|
||||||
|
):
|
||||||
|
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
||||||
|
metrics = {}
|
||||||
|
|
||||||
|
forward_output = self.concatenated_forward(model, batch)
|
||||||
|
(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
policy_chosen_logits,
|
||||||
|
policy_rejected_logits,
|
||||||
|
policy_nll_loss,
|
||||||
|
) = forward_output[:5]
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
aux_loss = forward_output[5]
|
||||||
|
|
||||||
|
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
||||||
|
policy_chosen_logps,
|
||||||
|
policy_rejected_logps,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = losses.mean() + self.cpo_alpha * policy_nll_loss
|
||||||
|
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||||
|
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
metrics[f"{prefix}rewards/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(chosen_rewards).mean().item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}rewards/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(rejected_rewards).mean().item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}rewards/accuracies"] = (
|
||||||
|
self.accelerator.gather_for_metrics(reward_accuracies).mean().item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}rewards/margins"] = (
|
||||||
|
self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards)
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logps/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_rejected_logps)
|
||||||
|
.detach()
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logps/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_chosen_logps)
|
||||||
|
.detach()
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logits/rejected"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean())
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}logits/chosen"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean())
|
||||||
|
.mean()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
metrics[f"{prefix}nll_loss"] = (
|
||||||
|
self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item()
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.aux_loss_enabled:
|
||||||
|
loss += self.aux_loss_coef * aux_loss
|
||||||
|
|
||||||
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlRewardTrainer(TrainerMixins, RewardTrainer):
|
||||||
|
"""Extend the base RewardTrainer for axolotl helpers"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlPRMTrainer(TrainerMixins, PRMTrainer):
|
||||||
|
"""Extend the base trl.PRMTrainer for axolotl helpers"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "prm"]
|
||||||
|
|||||||
33
src/axolotl/core/trainers/utils.py
Normal file
33
src/axolotl/core/trainers/utils.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Utils for Axolotl trainers"""
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
||||||
|
if isinstance(tag_names, str):
|
||||||
|
tag_names = [tag_names]
|
||||||
|
|
||||||
|
if kwargs is not None:
|
||||||
|
if "tags" not in kwargs:
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||||
|
kwargs["tags"].extend(tag_names)
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||||
|
tag_names.append(kwargs["tags"])
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||||
|
if isinstance(dataset_tags, str):
|
||||||
|
dataset_tags = [dataset_tags]
|
||||||
|
|
||||||
|
if (dataset_tags is not None) and (kwargs is not None):
|
||||||
|
if "dataset_tags" not in kwargs:
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||||
|
kwargs["dataset_tags"].extend(dataset_tags)
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||||
|
dataset_tags.append(kwargs["dataset_tags"])
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
|
return kwargs
|
||||||
@@ -1,18 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
extra axolotl specific training args
|
extra axolotl specific training args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL.Image import Resampling
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingMixins:
|
class AxolotlTrainingMixins:
|
||||||
"""
|
"""Mixin class for the Axolotl training args."""
|
||||||
Mixin class for the Axolotl training args.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
model_type: Optional[str] = field(
|
model_type: Optional[str] = field(
|
||||||
@@ -32,6 +32,12 @@ class AxolotlTrainingMixins:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use sample packing for efficient training."},
|
metadata={"help": "Use sample packing for efficient training."},
|
||||||
)
|
)
|
||||||
|
sample_packing_sequentially: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
multipack_real_batches: bool = field(
|
multipack_real_batches: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
@@ -206,14 +212,33 @@ class AxolotlTrainingMixins:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sequence_parallel_degree: Optional[int] = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "The number of workers to use in sequence parallelism"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# multi-modal section
|
||||||
|
|
||||||
|
image_size: int | tuple[int, int] | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The size of the image to resize to"},
|
||||||
|
)
|
||||||
|
|
||||||
|
image_resize_algorithm: Resampling | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The algorithm to use for image resizing"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# end of multi-modal section
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
|
||||||
"""
|
"""
|
||||||
Training arguments for Causal trainer
|
Training arguments for Causal trainer
|
||||||
|
|
||||||
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
|
This code is duplicated due to HF TrainingArguments not setting output_dir with a
|
||||||
so it can't be used as a mixin.
|
default value so it can't be used as a mixin.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -8,11 +8,14 @@ from typing import Dict, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
|
from datasets import Dataset
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.distributed import cleanup_distributed
|
||||||
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
from axolotl.utils.models import load_model, load_processor, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
@@ -25,18 +28,18 @@ LOG = get_logger("axolotl.evaluate")
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_dataset(
|
def evaluate_dataset(
|
||||||
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
|
||||||
) -> Optional[Dict[str, float]]:
|
) -> Optional[Dict[str, float]]:
|
||||||
"""Helper function to evaluate a single dataset safely.
|
"""Helper function to evaluate a single dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
trainer: The trainer instance
|
trainer: The trainer instance.
|
||||||
dataset: Dataset to evaluate
|
dataset: Dataset to evaluate.
|
||||||
dataset_type: Type of dataset ('train' or 'eval')
|
dataset_type: Type of dataset ('train' or 'eval').
|
||||||
flash_optimum: Whether to use flash optimum
|
flash_optimum: Whether to use flash optimum.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of metrics or None if dataset is None
|
Dictionary of metrics or None if dataset is None.
|
||||||
"""
|
"""
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return None
|
return None
|
||||||
@@ -63,17 +66,14 @@ def evaluate_dataset(
|
|||||||
|
|
||||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Evaluate a model on training and validation datasets
|
Evaluate a model on training and validation datasets.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
Dictionary mapping metric names to their values.
|
||||||
- The model (either PeftModel or PreTrainedModel)
|
|
||||||
- The tokenizer
|
|
||||||
- Dictionary of evaluation metrics
|
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
@@ -160,4 +160,6 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
|
|||||||
del model
|
del model
|
||||||
del tokenizer
|
del tokenizer
|
||||||
|
|
||||||
|
cleanup_distributed()
|
||||||
|
|
||||||
return all_metrics
|
return all_metrics
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import importlib
|
|||||||
import logging
|
import logging
|
||||||
from typing import OrderedDict
|
from typing import OrderedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin:
|
class BasePlugin:
|
||||||
"""
|
"""
|
||||||
@@ -469,3 +471,14 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
plugin.post_train_unload(cfg)
|
plugin.post_train_unload(cfg)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseOptimizerFactory:
|
||||||
|
"""
|
||||||
|
Base class for factories to create custom optimizers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, opt_model, training_args, **optimizer_kwargs
|
||||||
|
) -> "torch.optim.Optimizer":
|
||||||
|
pass
|
||||||
|
|||||||
@@ -11,19 +11,17 @@
|
|||||||
# the License.
|
# the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
module to handle merging the plugins' input arguments with the base configurations.
|
Module to handle merging the plugins' input arguments with the base configurations.
|
||||||
|
|
||||||
this was moved here to prevent circular imports
|
This was moved here to prevent circular imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def merge_input_args():
|
def merge_input_args():
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user