Compare commits
41 Commits
sageattent
...
docker-bas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3afc91fba9 | ||
|
|
0689419d25 | ||
|
|
e64c32c0bd | ||
|
|
ec819dde3b | ||
|
|
fdf4bb5087 | ||
|
|
f67d16268c | ||
|
|
684b543aa1 | ||
|
|
5bef19064b | ||
|
|
743ba62bd5 | ||
|
|
f9a7748bd8 | ||
|
|
5e9fa33f3d | ||
|
|
08fa133177 | ||
|
|
6b3058b2dc | ||
|
|
5726141c4e | ||
|
|
2f3ebbc44f | ||
|
|
fc973f4322 | ||
|
|
e399ba533e | ||
|
|
4baf8e5e96 | ||
|
|
d7d2fd366e | ||
|
|
e2882dd749 | ||
|
|
a1790f2652 | ||
|
|
418ad2b586 | ||
|
|
d87df2c776 | ||
|
|
1ef70312ba | ||
|
|
81ef3e45f7 | ||
|
|
bd8436bc6e | ||
|
|
fc6188cd76 | ||
|
|
b9bb02406a | ||
|
|
ff4794cd8e | ||
|
|
822c904092 | ||
|
|
d5f58b6509 | ||
|
|
9f6d0b5587 | ||
|
|
53963c792c | ||
|
|
a4f4a56d77 | ||
|
|
ce5bcff750 | ||
|
|
b620ed94d0 | ||
|
|
5f1d98e8fc | ||
|
|
1cf7075d18 | ||
|
|
f4cabc2351 | ||
|
|
6e0fb4a6b2 | ||
|
|
724b660d56 |
67
.github/workflows/base.yml
vendored
67
.github/workflows/base.yml
vendored
@@ -1,6 +1,16 @@
|
|||||||
name: ci-cd-base
|
name: ci-cd-base
|
||||||
|
|
||||||
on:
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- "main"
|
||||||
|
paths:
|
||||||
|
- 'Dockerfile-base'
|
||||||
|
- '.github/workflows/base.yml'
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'Dockerfile-base'
|
||||||
|
- '.github/workflows/base.yml'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
@@ -12,36 +22,38 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: "121"
|
# - cuda: "121"
|
||||||
cuda_version: 12.1.1
|
# cuda_version: 12.1.1
|
||||||
cudnn_version: 8
|
# cudnn_version: 8
|
||||||
python_version: "3.10"
|
# python_version: "3.10"
|
||||||
pytorch: 2.3.1
|
# pytorch: 2.3.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "121"
|
# from_base_img: ""
|
||||||
cuda_version: 12.1.1
|
# from_base_tag: ""
|
||||||
cudnn_version: 8
|
# - cuda: "121"
|
||||||
python_version: "3.11"
|
# cuda_version: 12.1.1
|
||||||
pytorch: 2.3.1
|
# cudnn_version: 8
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# python_version: "3.11"
|
||||||
- cuda: "124"
|
# pytorch: 2.3.1
|
||||||
cuda_version: 12.4.1
|
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
cudnn_version: ""
|
# from_base_img: ""
|
||||||
python_version: "3.10"
|
# from_base_tag: ""
|
||||||
pytorch: 2.4.1
|
# - cuda: "124"
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# cuda_version: 12.4.1
|
||||||
- cuda: "124"
|
# cudnn_version: ""
|
||||||
cuda_version: 12.4.1
|
# python_version: "3.11"
|
||||||
cudnn_version: ""
|
# pytorch: 2.4.1
|
||||||
python_version: "3.11"
|
# torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
pytorch: 2.4.1
|
# from_base_img: ""
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
# from_base_tag: ""
|
||||||
- cuda: "124"
|
- cuda: "124"
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
from_base_img: nvcr.io/nvidia/pytorch
|
||||||
|
from_base_tag: 24.10-py3
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -51,7 +63,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
images: |
|
images: |
|
||||||
winglian/axolotl-base
|
winglian/axolotl-base
|
||||||
axolotlai/axolotl-base
|
# axolotlai/axolotl-base
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
@@ -64,7 +76,8 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./docker/Dockerfile-base
|
file: ./docker/Dockerfile-base
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: true
|
||||||
|
# push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
build-args: |
|
build-args: |
|
||||||
@@ -74,3 +87,5 @@ jobs:
|
|||||||
PYTHON_VERSION=${{ matrix.python_version }}
|
PYTHON_VERSION=${{ matrix.python_version }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}
|
||||||
|
BASE_IMAGE=${{ matrix.from_base_img || '' }}
|
||||||
|
BASE_TAG=${{ matrix.from_base_tag || '' }}
|
||||||
|
|||||||
15
.github/workflows/tests-nightly.yml
vendored
15
.github/workflows/tests-nightly.yml
vendored
@@ -23,9 +23,15 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
max-parallel: 2
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
|
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
|
||||||
|
exclude:
|
||||||
|
- python_version: "3.10"
|
||||||
|
pytorch_version: "2.4.1"
|
||||||
|
- python_version: "3.10"
|
||||||
|
pytorch_version: "2.5.1"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -55,11 +61,18 @@ jobs:
|
|||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging
|
pip3 install --upgrade packaging
|
||||||
pip3 install -U -e .
|
pip3 install -U -e .
|
||||||
|
python scripts/unsloth_install.py | sh
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
- name: Ensure axolotl CLI was installed
|
||||||
|
run: |
|
||||||
|
axolotl --help
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest --ignore=tests/e2e/ tests/
|
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||||
|
pytest tests/patched/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
126
.github/workflows/tests.yml
vendored
126
.github/workflows/tests.yml
vendored
@@ -8,11 +8,17 @@ on:
|
|||||||
- '**.py'
|
- '**.py'
|
||||||
- 'requirements.txt'
|
- 'requirements.txt'
|
||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
|
- 'requirements-tests.txt'
|
||||||
|
- 'cicd/cicd.sh'
|
||||||
|
- 'cicd/Dockerfile.jinja'
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- '**.py'
|
- '**.py'
|
||||||
- 'requirements.txt'
|
- 'requirements.txt'
|
||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
|
- 'requirements-tests.txt'
|
||||||
|
- 'cicd/cicd.sh'
|
||||||
|
- 'cicd/Dockerfile.jinja'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
# Cancel jobs on the same ref if a new one is triggered
|
# Cancel jobs on the same ref if a new one is triggered
|
||||||
@@ -39,9 +45,15 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
max-parallel: 2
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.10", "3.11"]
|
python_version: ["3.10", "3.11"]
|
||||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
|
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
|
||||||
|
exclude:
|
||||||
|
- python_version: "3.10"
|
||||||
|
pytorch_version: "2.4.1"
|
||||||
|
- python_version: "3.10"
|
||||||
|
pytorch_version: "2.5.1"
|
||||||
timeout-minutes: 20
|
timeout-minutes: 20
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@@ -67,11 +79,18 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip3 show torch
|
pip3 show torch
|
||||||
pip3 install -U -e .
|
pip3 install -U -e .
|
||||||
|
python scripts/unsloth_install.py | sh
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
- name: Ensure axolotl CLI was installed
|
||||||
|
run: |
|
||||||
|
axolotl --help
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -n8 --ignore=tests/e2e/ tests/
|
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||||
|
pytest tests/patched/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
@@ -82,6 +101,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
max-parallel: 1
|
||||||
matrix:
|
matrix:
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
pytorch_version: ["2.4.1", "2.5.1"]
|
pytorch_version: ["2.4.1", "2.5.1"]
|
||||||
@@ -111,73 +131,81 @@ jobs:
|
|||||||
pip3 show torch
|
pip3 show torch
|
||||||
python3 setup.py sdist
|
python3 setup.py sdist
|
||||||
pip3 install dist/axolotl*.tar.gz
|
pip3 install dist/axolotl*.tar.gz
|
||||||
|
python scripts/unsloth_install.py | sh
|
||||||
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
- name: Ensure axolotl CLI was installed
|
||||||
|
run: |
|
||||||
|
axolotl --help
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -n8 --ignore=tests/e2e/ tests/
|
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
||||||
|
pytest tests/patched/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||||
|
|
||||||
docker-e2e-tests-1st:
|
# docker-e2e-tests-1st:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
# if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# # this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
# runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
# timeout-minutes: 90
|
||||||
needs: [pre-commit, pytest, pytest-sdist]
|
# needs: [pre-commit, pytest, pytest-sdist]
|
||||||
|
#
|
||||||
strategy:
|
# strategy:
|
||||||
fail-fast: false
|
# fail-fast: false
|
||||||
matrix:
|
# matrix:
|
||||||
include:
|
# include:
|
||||||
- cuda: 124
|
# - cuda: 124
|
||||||
cuda_version: 12.4.1
|
# cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
# python_version: "3.11"
|
||||||
pytorch: 2.4.1
|
# pytorch: 2.4.1
|
||||||
num_gpus: 1
|
# num_gpus: 1
|
||||||
axolotl_extras:
|
# axolotl_extras:
|
||||||
steps:
|
# steps:
|
||||||
- name: Checkout
|
# - name: Checkout
|
||||||
uses: actions/checkout@v4
|
# uses: actions/checkout@v4
|
||||||
- name: Install Python
|
# - name: Install Python
|
||||||
uses: actions/setup-python@v5
|
# uses: actions/setup-python@v5
|
||||||
with:
|
# with:
|
||||||
python-version: "3.10"
|
# python-version: "3.10"
|
||||||
- name: Install Modal
|
# - name: Install Modal
|
||||||
run: |
|
# run: |
|
||||||
python -m pip install --upgrade pip
|
# python -m pip install --upgrade pip
|
||||||
pip install modal==0.63.64 jinja2
|
# pip install modal==0.63.64 jinja2
|
||||||
- name: Update env vars
|
# - name: Update env vars
|
||||||
run: |
|
# run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
# echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
# echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
# echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
# echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
# echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
# echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||||
- name: Run tests job on Modal
|
# - name: Run tests job on Modal
|
||||||
run: |
|
# run: |
|
||||||
modal run cicd.tests
|
# modal run cicd.tests
|
||||||
|
|
||||||
docker-e2e-tests:
|
docker-e2e-tests:
|
||||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||||
# this job needs to be run on self-hosted GPU runners...
|
# this job needs to be run on self-hosted GPU runners...
|
||||||
runs-on: [self-hosted, modal]
|
runs-on: [self-hosted, modal]
|
||||||
timeout-minutes: 90
|
timeout-minutes: 90
|
||||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
# needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||||
|
needs: [pre-commit, pytest]
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 121
|
# - cuda: 121
|
||||||
cuda_version: 12.1.1
|
# cuda_version: 12.1.1
|
||||||
python_version: "3.10"
|
# python_version: "3.10"
|
||||||
pytorch: 2.3.1
|
# pytorch: 2.3.1
|
||||||
num_gpus: 1
|
# num_gpus: 1
|
||||||
axolotl_extras: mamba-ssm
|
# axolotl_extras: mamba-ssm
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -197,7 +225,7 @@ jobs:
|
|||||||
pip install modal==0.63.64 jinja2
|
pip install modal==0.63.64 jinja2
|
||||||
- name: Update env vars
|
- name: Update env vars
|
||||||
run: |
|
run: |
|
||||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
echo "BASE_TAG=pr-2139-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||||
|
|||||||
253
README.md
253
README.md
@@ -41,9 +41,12 @@ Features:
|
|||||||
## Table of Contents
|
## Table of Contents
|
||||||
- [Axolotl](#axolotl)
|
- [Axolotl](#axolotl)
|
||||||
- [Table of Contents](#table-of-contents)
|
- [Table of Contents](#table-of-contents)
|
||||||
- [Axolotl supports](#axolotl-supports)
|
|
||||||
- [Quickstart ⚡](#quickstart-)
|
- [Quickstart ⚡](#quickstart-)
|
||||||
- [Usage](#usage)
|
- [Usage](#usage)
|
||||||
|
- [Badge ❤🏷️](#badge-️)
|
||||||
|
- [Contributing 🤝](#contributing-)
|
||||||
|
- [Sponsors 🤝❤](#sponsors-)
|
||||||
|
- [Axolotl supports](#axolotl-supports)
|
||||||
- [Advanced Setup](#advanced-setup)
|
- [Advanced Setup](#advanced-setup)
|
||||||
- [Environment](#environment)
|
- [Environment](#environment)
|
||||||
- [Docker](#docker)
|
- [Docker](#docker)
|
||||||
@@ -75,14 +78,6 @@ Features:
|
|||||||
- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
|
- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
|
||||||
- [Debugging Axolotl](#debugging-axolotl)
|
- [Debugging Axolotl](#debugging-axolotl)
|
||||||
- [Need help? 🙋](#need-help-)
|
- [Need help? 🙋](#need-help-)
|
||||||
- [Badge ❤🏷️](#badge-️)
|
|
||||||
- [Community Showcase](#community-showcase)
|
|
||||||
- [Contributing 🤝](#contributing-)
|
|
||||||
- [Sponsors 🤝❤](#sponsors-)
|
|
||||||
- [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly)
|
|
||||||
- [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo)
|
|
||||||
- [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo)
|
|
||||||
- [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo)
|
|
||||||
|
|
||||||
</td>
|
</td>
|
||||||
<td>
|
<td>
|
||||||
@@ -105,6 +100,127 @@ Features:
|
|||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
|
## Quickstart ⚡
|
||||||
|
|
||||||
|
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
|
||||||
|
|
||||||
|
**Requirements**: *Nvidia* GPU (Ampere architecture or newer for `bf16` and Flash Attention) or *AMD* GPU, Python >=3.10 and PyTorch >=2.3.1.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||||
|
cd axolotl
|
||||||
|
|
||||||
|
pip3 install packaging ninja
|
||||||
|
pip3 install -e '.[flash-attn,deepspeed]'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
```bash
|
||||||
|
# preprocess datasets - optional but recommended
|
||||||
|
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
|
||||||
|
|
||||||
|
# finetune lora
|
||||||
|
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||||
|
|
||||||
|
# inference
|
||||||
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
|
--lora_model_dir="./outputs/lora-out"
|
||||||
|
|
||||||
|
# gradio
|
||||||
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
|
--lora_model_dir="./outputs/lora-out" --gradio
|
||||||
|
|
||||||
|
# remote yaml files - the yaml config can be hosted on a public URL
|
||||||
|
# Note: the yaml config must directly link to the **raw** yaml
|
||||||
|
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Axolotl CLI
|
||||||
|
|
||||||
|
If you've installed this package using `pip` from source, we now support a new, more
|
||||||
|
streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
|
||||||
|
the above commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# preprocess datasets - optional but recommended
|
||||||
|
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml
|
||||||
|
|
||||||
|
# finetune lora
|
||||||
|
axolotl train examples/openllama-3b/lora.yml
|
||||||
|
|
||||||
|
# inference
|
||||||
|
axolotl inference examples/openllama-3b/lora.yml \
|
||||||
|
--lora-model-dir="./outputs/lora-out"
|
||||||
|
|
||||||
|
# gradio
|
||||||
|
axolotl inference examples/openllama-3b/lora.yml \
|
||||||
|
--lora-model-dir="./outputs/lora-out" --gradio
|
||||||
|
|
||||||
|
# remote yaml files - the yaml config can be hosted on a public URL
|
||||||
|
# Note: the yaml config must directly link to the **raw** yaml
|
||||||
|
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
We've also added a new command for fetching `examples` and `deepspeed_configs` to your
|
||||||
|
local machine. This will come in handy when installing `axolotl` from PyPI.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Fetch example YAML files (stores in "examples/" folder)
|
||||||
|
axolotl fetch examples
|
||||||
|
|
||||||
|
# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
|
||||||
|
axolotl fetch deepspeed_configs
|
||||||
|
|
||||||
|
# Optionally, specify a destination folder
|
||||||
|
axolotl fetch examples --dest path/to/folder
|
||||||
|
```
|
||||||
|
|
||||||
|
## Badge ❤🏷️
|
||||||
|
|
||||||
|
Building something cool with Axolotl? Consider adding a badge to your model card.
|
||||||
|
|
||||||
|
```markdown
|
||||||
|
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
```
|
||||||
|
|
||||||
|
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
|
||||||
|
|
||||||
|
## Sponsors 🤝❤
|
||||||
|
|
||||||
|
If you love axolotl, consider sponsoring the project by reaching out directly to [wing@axolotl.ai](mailto:wing@axolotl.ai).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Contributing 🤝
|
||||||
|
|
||||||
|
Please read the [contributing guide](./.github/CONTRIBUTING.md)
|
||||||
|
|
||||||
|
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) first; if yours is not listed, create a new issue.
|
||||||
|
|
||||||
|
PRs are **greatly welcome**!
|
||||||
|
|
||||||
|
Please run the quickstart instructions, followed by the commands below, to set up your environment:
|
||||||
|
```bash
|
||||||
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
pre-commit install
|
||||||
|
|
||||||
|
# test
|
||||||
|
pytest tests/
|
||||||
|
|
||||||
|
# optional: run against all files
|
||||||
|
pre-commit run --all-files
|
||||||
|
```
|
||||||
|
|
||||||
|
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
||||||
|
|
||||||
|
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
|
||||||
|
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
||||||
|
</a>
|
||||||
|
|
||||||
## Axolotl supports
|
## Axolotl supports
|
||||||
|
|
||||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||||
@@ -130,41 +246,6 @@ Features:
|
|||||||
❌: not supported
|
❌: not supported
|
||||||
❓: untested
|
❓: untested
|
||||||
|
|
||||||
## Quickstart ⚡
|
|
||||||
|
|
||||||
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
|
|
||||||
|
|
||||||
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
|
||||||
cd axolotl
|
|
||||||
|
|
||||||
pip3 install packaging ninja
|
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
|
||||||
```
|
|
||||||
|
|
||||||
### Usage
|
|
||||||
```bash
|
|
||||||
# preprocess datasets - optional but recommended
|
|
||||||
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
|
|
||||||
|
|
||||||
# finetune lora
|
|
||||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
|
||||||
|
|
||||||
# inference
|
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|
||||||
--lora_model_dir="./outputs/lora-out"
|
|
||||||
|
|
||||||
# gradio
|
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|
||||||
--lora_model_dir="./outputs/lora-out" --gradio
|
|
||||||
|
|
||||||
# remote yaml files - the yaml config can be hosted on a public URL
|
|
||||||
# Note: the yaml config must directly link to the **raw** yaml
|
|
||||||
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
|
|
||||||
```
|
|
||||||
|
|
||||||
## Advanced Setup
|
## Advanced Setup
|
||||||
|
|
||||||
### Environment
|
### Environment
|
||||||
@@ -682,86 +763,6 @@ See [this debugging guide](docs/debugging.qmd) for tips on debugging Axolotl, al
|
|||||||
|
|
||||||
## Need help? 🙋
|
## Need help? 🙋
|
||||||
|
|
||||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we our community members can help you.
|
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where our community members can help you.
|
||||||
|
|
||||||
Need dedicated support? Please contact us at [✉️wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
|
Need dedicated support? Please contact us at [✉️wing@axolotl.ai](mailto:wing@axolotl.ai) for dedicated support options.
|
||||||
|
|
||||||
## Badge ❤🏷️
|
|
||||||
|
|
||||||
Building something cool with Axolotl? Consider adding a badge to your model card.
|
|
||||||
|
|
||||||
```markdown
|
|
||||||
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
```
|
|
||||||
|
|
||||||
[<img src="https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/axolotl-ai-cloud/axolotl)
|
|
||||||
|
|
||||||
## Community Showcase
|
|
||||||
|
|
||||||
Check out some of the projects and models that have been built using Axolotl! Have a model you'd like to add to our Community Showcase? Open a PR with your model.
|
|
||||||
|
|
||||||
Open Access AI Collective
|
|
||||||
- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b-fixed)
|
|
||||||
- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
|
|
||||||
- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
|
|
||||||
|
|
||||||
PocketDoc Labs
|
|
||||||
- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
|
|
||||||
|
|
||||||
## Contributing 🤝
|
|
||||||
|
|
||||||
Please read the [contributing guide](./.github/CONTRIBUTING.md)
|
|
||||||
|
|
||||||
Bugs? Please check the [open issues](https://github.com/axolotl-ai-cloud/axolotl/issues/bug) else create a new Issue.
|
|
||||||
|
|
||||||
PRs are **greatly welcome**!
|
|
||||||
|
|
||||||
Please run the quickstart instructions followed by the below to setup env:
|
|
||||||
```bash
|
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
|
||||||
pre-commit install
|
|
||||||
|
|
||||||
# test
|
|
||||||
pytest tests/
|
|
||||||
|
|
||||||
# optional: run against all files
|
|
||||||
pre-commit run --all-files
|
|
||||||
```
|
|
||||||
|
|
||||||
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
|
||||||
|
|
||||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/graphs/contributors">
|
|
||||||
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
|
||||||
</a>
|
|
||||||
|
|
||||||
## Sponsors 🤝❤
|
|
||||||
|
|
||||||
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
|
|
||||||
[NanoCode012](https://github.com/NanoCode012), [tmm1](https://github.com/tmm1),
|
|
||||||
[mhenrichsen](https://github.com/mhenrichsen), [casper-hansen](https://github.com/casper-hansen),
|
|
||||||
[hamelsmu](https://github.com/hamelsmu) and many more who help us accelerate forward by fixing bugs, answering
|
|
||||||
community questions and implementing new features. Axolotl needs donations from sponsors for the compute needed to
|
|
||||||
run our unit & integration tests, troubleshooting community issues, and providing bounties. If you love axolotl,
|
|
||||||
consider sponsoring the project via [GitHub Sponsors](https://github.com/sponsors/OpenAccess-AI-Collective),
|
|
||||||
[Ko-fi](https://ko-fi.com/axolotl_ai) or reach out directly to
|
|
||||||
[wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org).
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
#### 💎 Diamond Sponsors - [Contact directly](mailto:wing@openaccessaicollective.org)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
#### 🥇 Gold Sponsors - $5000/mo
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
#### 🥈 Silver Sponsors - $1000/mo
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
#### 🥉 Bronze Sponsors - $500/mo
|
|
||||||
|
|
||||||
- [JarvisLabs.ai](https://jarvislabs.ai)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
FROM winglian/axolotl-base:{{ BASE_TAG }}
|
||||||
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||||
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
|
||||||
ENV CUDA="{{ CUDA }}"
|
ENV CUDA="{{ CUDA }}"
|
||||||
ENV BNB_CUDA_VERSION="{{ CUDA }}"
|
|
||||||
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
||||||
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||||
@@ -37,6 +36,9 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|||||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
RUN python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
pytest -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
|
||||||
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
|
||||||
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||||
|
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
|
|||||||
cicd_image = (
|
cicd_image = (
|
||||||
Image.from_dockerfile(
|
Image.from_dockerfile(
|
||||||
pathlib.Path(temp_dir) / "Dockerfile",
|
pathlib.Path(temp_dir) / "Dockerfile",
|
||||||
|
context_mount=None,
|
||||||
force_build=True,
|
force_build=True,
|
||||||
gpu="A10G",
|
gpu="A10G",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
|
ARG BASE_IMAGE=axolotlai/axolotl-base
|
||||||
ARG BASE_TAG=main-base
|
ARG BASE_TAG=main-base
|
||||||
FROM axolotlai/axolotl-base:$BASE_TAG
|
FROM $BASE_IMAGE:$BASE_TAG
|
||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
ARG AXOLOTL_ARGS=""
|
ARG AXOLOTL_ARGS=""
|
||||||
ARG CUDA="118"
|
ARG CUDA="118"
|
||||||
ENV BNB_CUDA_VERSION=$CUDA
|
|
||||||
ARG PYTORCH_VERSION="2.1.2"
|
ARG PYTORCH_VERSION="2.1.2"
|
||||||
|
|
||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||||
@@ -26,6 +26,9 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
|||||||
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
RUN python scripts/unsloth_install.py | sh
|
||||||
|
RUN python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|
||||||
# So we can test the Docker image
|
# So we can test the Docker image
|
||||||
RUN pip install pytest
|
RUN pip install pytest
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ ARG CUDNN_VERSION="8"
|
|||||||
ARG UBUNTU_VERSION="22.04"
|
ARG UBUNTU_VERSION="22.04"
|
||||||
ARG MAX_JOBS=4
|
ARG MAX_JOBS=4
|
||||||
|
|
||||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
ARG BASE_IMAGE=nvidia/cuda
|
||||||
|
ARG DEFAULT_TAG=${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION}
|
||||||
|
ARG BASE_TAG=""
|
||||||
|
FROM ${BASE_IMAGE:-nvidia/cuda}:${BASE_TAG:-${DEFAULT_TAG}} AS base-builder
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
@@ -29,7 +32,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
|||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||||
|
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||||
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||||
|
|
||||||
RUN git lfs install --skip-repo && \
|
RUN git lfs install --skip-repo && \
|
||||||
pip3 install awscli && \
|
pip3 install awscli && \
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
|
|||||||
FROM axolotlai/axolotl:$BASE_TAG
|
FROM axolotlai/axolotl:$BASE_TAG
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ ARG BASE_TAG=main
|
|||||||
FROM axolotlai/axolotl:$BASE_TAG
|
FROM axolotlai/axolotl:$BASE_TAG
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HF_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
ENV HF_HUB_ENABLE_HF_TRANSFER="1"
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
|||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
ARG AXOLOTL_ARGS=""
|
ARG AXOLOTL_ARGS=""
|
||||||
ARG CUDA="118"
|
ARG CUDA="118"
|
||||||
ENV BNB_CUDA_VERSION=$CUDA
|
|
||||||
ARG PYTORCH_VERSION="2.1.2"
|
ARG PYTORCH_VERSION="2.1.2"
|
||||||
ARG GITHUB_REF="main"
|
ARG GITHUB_REF="main"
|
||||||
|
|
||||||
|
|||||||
@@ -162,6 +162,9 @@ datasets:
|
|||||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
||||||
shuffle_merged_datasets: true
|
shuffle_merged_datasets: true
|
||||||
|
|
||||||
|
# Deduplicates datasets and test_datasets with identical entries.
|
||||||
|
dataset_exact_deduplication: true
|
||||||
|
|
||||||
# A list of one or more datasets to eval the model with.
|
# A list of one or more datasets to eval the model with.
|
||||||
# You can use either test_datasets, or val_set_size, but not both.
|
# You can use either test_datasets, or val_set_size, but not both.
|
||||||
test_datasets:
|
test_datasets:
|
||||||
@@ -406,7 +409,7 @@ lr_div_factor: # Learning rate div factor
|
|||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
# - adamw_apex_fused
|
# - adamw_apex_fused
|
||||||
# - adopt_adamw (only for torch version >= 2.5.1)
|
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
|
||||||
# - adafactor
|
# - adafactor
|
||||||
# - adamw_anyprecision
|
# - adamw_anyprecision
|
||||||
# - sgd
|
# - sgd
|
||||||
|
|||||||
95
examples/llama-3/lora-1b-deduplicate-dpo.yml
Normal file
95
examples/llama-3/lora-1b-deduplicate-dpo.yml
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
# LoRA + DPO example config for Llama-3.2-1B with exact dataset
# deduplication switched on.

# --- Model -----------------------------------------------------------------
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

# --- Data ------------------------------------------------------------------
chat_template: llama3
rl: dpo
# NOTE(review): the same dataset is listed twice; presumably on purpose so
# that `dataset_exact_deduplication` has duplicate rows to drop -- confirm.
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant

dataset_exact_deduplication: true
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/lora-out

# --- Sequence handling -----------------------------------------------------
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true

# --- LoRA adapter ----------------------------------------------------------
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

# --- Weights & Biases (left unset) -----------------------------------------
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# --- Optimisation ----------------------------------------------------------
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

# --- Runtime ---------------------------------------------------------------
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

# --- Schedule / evaluation / checkpoints -----------------------------------
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
|
||||||
76
examples/llama-3/lora-1b-deduplicate-sft.yml
Normal file
76
examples/llama-3/lora-1b-deduplicate-sft.yml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# LoRA SFT example config for Llama-3.2-1B with exact dataset
# deduplication switched on.

# --- Model -----------------------------------------------------------------
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: true
load_in_4bit: false
strict: false

# --- Data ------------------------------------------------------------------
# NOTE(review): the same dataset is listed twice; presumably on purpose so
# that `dataset_exact_deduplication` has duplicate rows to drop -- confirm.
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out

dataset_exact_deduplication: true
# NOTE(review): `test_value` does not look like a documented option and may
# be a leftover from development -- confirm before relying on it.
test_value: true

# --- Sequence handling -----------------------------------------------------
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

# --- LoRA adapter ----------------------------------------------------------
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
  - embed_tokens
  - lm_head

# --- Weights & Biases (left unset) -----------------------------------------
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# --- Optimisation ----------------------------------------------------------
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

# --- Runtime ---------------------------------------------------------------
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

# --- Schedule / evaluation / checkpoints -----------------------------------
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"
|
||||||
19
pyproject.toml
Normal file
19
pyproject.toml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Build backend: setuptools, with setuptools_scm available at build time.
[build-system]
requires = [
    "setuptools>=64",
    "wheel",
    "setuptools_scm>=8",
]
build-backend = "setuptools.build_meta"

# Static project metadata; the fields listed in `dynamic` are declared as
# provided by the build backend rather than by this table.
[project]
name = "axolotl"
dynamic = [
    "version",
    "dependencies",
    "optional-dependencies",
]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"

# Console entry point: `axolotl` -> axolotl.cli.main:main
[project.scripts]
axolotl = "axolotl.cli.main:main"

[project.urls]
Homepage = "https://axolotl-ai-cloud.github.io/axolotl/"
Repository = "https://github.com/axolotl-ai-cloud/axolotl.git"

# Empty table: enables setuptools_scm with its default settings.
[tool.setuptools_scm]
|
||||||
@@ -2,4 +2,3 @@ pre-commit
|
|||||||
black
|
black
|
||||||
mypy
|
mypy
|
||||||
types-requests
|
types-requests
|
||||||
tbparse
|
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
pytest
|
pytest
|
||||||
pytest-xdist
|
pytest-xdist
|
||||||
pytest-retry
|
pytest-retry
|
||||||
|
pytest-sugar
|
||||||
|
tbparse
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.13.2
|
peft==0.14.0
|
||||||
transformers==4.46.3
|
transformers==4.47.0
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.44.1
|
bitsandbytes==0.45.0
|
||||||
accelerate==1.1.0
|
accelerate==1.2.0
|
||||||
datasets==3.1.0
|
datasets==3.1.0
|
||||||
deepspeed==0.15.4
|
deepspeed==0.15.4
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
@@ -26,7 +26,7 @@ numpy>=1.24.4,<=2.0.1
|
|||||||
evaluate==0.4.1
|
evaluate==0.4.1
|
||||||
scipy
|
scipy
|
||||||
scikit-learn==1.4.2
|
scikit-learn==1.4.2
|
||||||
pynvml
|
nvidia-ml-py==12.560.30
|
||||||
art
|
art
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
@@ -42,7 +42,7 @@ s3fs>=2024.5.0
|
|||||||
gcsfs>=2024.5.0
|
gcsfs>=2024.5.0
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl==0.12.0
|
trl==0.12.1
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|
||||||
|
|||||||
28
scripts/cutcrossentropy_install.py
Normal file
28
scripts/cutcrossentropy_install.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys

try:
    import torch
except ImportError as exc:
    raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V

# The printed text is meant to be consumed as a shell command
# (e.g. `python scripts/cutcrossentropy_install.py | sh`).
v = V(torch.__version__)

# no cut-cross-entropy support for torch < 2.4.0
if v < V("2.4.0"):
    # Emit an empty command so the consumer has nothing to run.
    print("")
    sys.exit(0)

# An existing install that lacks the `transformers` submodule is removed
# first; the second lookup only runs when the package itself is present.
needs_uninstall = bool(
    importlib.util.find_spec("cut_cross_entropy")
) and not importlib.util.find_spec("cut_cross_entropy.transformers")

UNINSTALL_PREFIX = (
    "pip uninstall -y cut-cross-entropy && " if needs_uninstall else ""
)

print(
    UNINSTALL_PREFIX
    + 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
)
|
||||||
@@ -8,7 +8,10 @@ from packaging.version import Version as V
|
|||||||
|
|
||||||
v = V(torch.__version__)
|
v = V(torch.__version__)
|
||||||
cuda = str(torch.version.cuda)
|
cuda = str(torch.version.cuda)
|
||||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
try:
|
||||||
|
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||||
|
except RuntimeError:
|
||||||
|
is_ampere = False
|
||||||
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
||||||
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
||||||
if v <= V("2.1.0"):
|
if v <= V("2.1.0"):
|
||||||
@@ -29,5 +32,5 @@ else:
|
|||||||
raise RuntimeError(f"Torch = {v} too new!")
|
raise RuntimeError(f"Torch = {v} too new!")
|
||||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||||
print(
|
print(
|
||||||
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
|
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
|
||||||
)
|
)
|
||||||
|
|||||||
11
setup.py
11
setup.py
@@ -1,5 +1,4 @@
|
|||||||
"""setup.py for axolotl"""
|
"""setup.py for axolotl"""
|
||||||
|
|
||||||
import platform
|
import platform
|
||||||
import re
|
import re
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
@@ -93,16 +92,16 @@ def parse_requirements():
|
|||||||
|
|
||||||
install_requires, dependency_links = parse_requirements()
|
install_requires, dependency_links = parse_requirements()
|
||||||
|
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="axolotl",
|
|
||||||
version="0.5.2",
|
|
||||||
description="LLM Trainer",
|
|
||||||
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
packages=find_packages("src"),
|
packages=find_packages("src"),
|
||||||
install_requires=install_requires,
|
install_requires=install_requires,
|
||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
|
entry_points={
|
||||||
|
"console_scripts": [
|
||||||
|
"axolotl=axolotl.cli.main:main",
|
||||||
|
],
|
||||||
|
},
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.7.0.post2",
|
"flash-attn==2.7.0.post2",
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
"""Axolotl - Train and fine-tune large language models"""

# Resolve the installed distribution's version at import time.
# `importlib.metadata.PackageNotFoundError` derives from ImportError, so the
# single handler below covers both a missing `importlib.metadata` module and
# an "axolotl" distribution that is not installed.
try:
    from importlib.metadata import version

    __version__ = version("axolotl")
except ImportError:
    __version__ = "unknown"
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
from transformers.utils.import_utils import _is_package_available
|
from transformers.utils.import_utils import _is_package_available
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||||
from axolotl.integrations.base import PluginManager
|
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.train import TrainDatasetMeta
|
from axolotl.train import TrainDatasetMeta
|
||||||
from axolotl.utils.chat_templates import (
|
from axolotl.utils.chat_templates import (
|
||||||
@@ -38,6 +37,7 @@ from axolotl.utils.comet_ import setup_comet_env_vars
|
|||||||
from axolotl.utils.config import (
|
from axolotl.utils.config import (
|
||||||
normalize_cfg_datasets,
|
normalize_cfg_datasets,
|
||||||
normalize_config,
|
normalize_config,
|
||||||
|
prepare_plugins,
|
||||||
validate_config,
|
validate_config,
|
||||||
)
|
)
|
||||||
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
|
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
|
||||||
@@ -100,8 +100,8 @@ def print_dep_versions():
|
|||||||
print("*" * 40)
|
print("*" * 40)
|
||||||
print("**** Axolotl Dependency Versions *****")
|
print("**** Axolotl Dependency Versions *****")
|
||||||
for pkg in packages:
|
for pkg in packages:
|
||||||
version = _is_package_available(pkg, return_version=True)
|
pkg_version = _is_package_available(pkg, return_version=True)
|
||||||
print(f"{pkg: >{max_len}}: {version[1]: <15}")
|
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
|
||||||
print("*" * 40)
|
print("*" * 40)
|
||||||
|
|
||||||
|
|
||||||
@@ -139,7 +139,7 @@ def check_remote_config(config: Union[str, Path]):
|
|||||||
with open(output_path, "wb") as file:
|
with open(output_path, "wb") as file:
|
||||||
file.write(content)
|
file.write(content)
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
|
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
|
||||||
)
|
)
|
||||||
return output_path
|
return output_path
|
||||||
|
|
||||||
@@ -380,7 +380,7 @@ def choose_config(path: Path):
|
|||||||
|
|
||||||
if len(yaml_files) == 1:
|
if len(yaml_files) == 1:
|
||||||
print(f"Using default YAML file '{yaml_files[0]}'")
|
print(f"Using default YAML file '{yaml_files[0]}'")
|
||||||
return yaml_files[0]
|
return str(yaml_files[0])
|
||||||
|
|
||||||
print("Choose a YAML file:")
|
print("Choose a YAML file:")
|
||||||
for idx, file in enumerate(yaml_files):
|
for idx, file in enumerate(yaml_files):
|
||||||
@@ -391,7 +391,7 @@ def choose_config(path: Path):
|
|||||||
try:
|
try:
|
||||||
choice = int(input("Enter the number of your choice: "))
|
choice = int(input("Enter the number of your choice: "))
|
||||||
if 1 <= choice <= len(yaml_files):
|
if 1 <= choice <= len(yaml_files):
|
||||||
chosen_file = yaml_files[choice - 1]
|
chosen_file = str(yaml_files[choice - 1])
|
||||||
else:
|
else:
|
||||||
print("Invalid choice. Please choose a number from the list.")
|
print("Invalid choice. Please choose a number from the list.")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -426,17 +426,14 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
|
|
||||||
cfg.axolotl_config_path = config
|
cfg.axolotl_config_path = config
|
||||||
|
|
||||||
if cfg.get("plugins"):
|
|
||||||
plugin_manager = PluginManager.get_instance()
|
|
||||||
for plugin_name in cfg["plugins"]:
|
|
||||||
plugin_manager.register(plugin_name)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
device_props = torch.cuda.get_device_properties("cuda")
|
device_props = torch.cuda.get_device_properties("cuda")
|
||||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||||
except: # pylint: disable=bare-except # noqa: E722
|
except: # pylint: disable=bare-except # noqa: E722
|
||||||
gpu_version = None
|
gpu_version = None
|
||||||
|
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
|
||||||
cfg = validate_config(
|
cfg = validate_config(
|
||||||
cfg,
|
cfg,
|
||||||
capabilities={
|
capabilities={
|
||||||
@@ -444,6 +441,9 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
|||||||
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
|
||||||
"compute_capability": gpu_version,
|
"compute_capability": gpu_version,
|
||||||
},
|
},
|
||||||
|
env_capabilities={
|
||||||
|
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
prepare_optim_env(cfg)
|
prepare_optim_env(cfg)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
CLI to run inference on a trained model
|
CLI to run inference on a trained model
|
||||||
"""
|
"""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
@@ -16,10 +17,10 @@ from axolotl.cli import (
|
|||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, inference=True, **kwargs)
|
||||||
parsed_cfg.sample_packing = False
|
parsed_cfg.sample_packing = False
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
|||||||
231
src/axolotl/cli/main.py
Normal file
231
src/axolotl/cli/main.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
"""CLI definition for various axolotl commands."""
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
import subprocess # nosec B404
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from axolotl.cli.utils import (
|
||||||
|
add_options_from_config,
|
||||||
|
add_options_from_dataclass,
|
||||||
|
build_command,
|
||||||
|
fetch_from_github,
|
||||||
|
)
|
||||||
|
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
|
|
||||||
|
# Root click group; every sub-command below is registered on it via
# `@cli.command()`. The docstring doubles as the CLI help text.
@click.group()
def cli():
    """Axolotl CLI - Train and fine-tune large language models"""
|
||||||
|
|
||||||
|
|
||||||
|
# `axolotl preprocess CONFIG [options]`
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
    """Preprocess datasets before training."""
    # Imported lazily so the heavy module is only loaded for this command.
    from axolotl.cli.preprocess import do_cli

    # Forward only the options the user actually supplied (non-None).
    overrides = {
        name: value for name, value in kwargs.items() if value is not None
    }

    do_cli(config=config, **overrides)
|
||||||
|
|
||||||
|
|
||||||
|
# `axolotl train CONFIG [--accelerate/--no-accelerate] [options]`
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--accelerate/--no-accelerate",
    default=True,
    help="Use accelerate launch for multi-GPU training",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
    """Train or fine-tune a model."""
    # Forward only the options the user actually supplied (non-None).
    overrides = {
        name: value for name, value in kwargs.items() if value is not None
    }

    if not accelerate:
        # In-process path: call the training entry point directly.
        from axolotl.cli.train import do_cli

        do_cli(config=config, **overrides)
        return

    # Subprocess path: re-run the training module under `accelerate launch`.
    launch_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
    if config:
        launch_cmd.append(config)
    full_cmd = build_command(launch_cmd, overrides)
    subprocess.run(full_cmd, check=True)  # nosec B603
|
||||||
|
|
||||||
|
|
||||||
|
# `axolotl inference CONFIG [--accelerate/--no-accelerate] [options]`
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--accelerate/--no-accelerate",
    default=True,
    help="Use accelerate launch for multi-GPU inference",
)
@click.option(
    "--lora-model-dir",
    type=click.Path(exists=True, path_type=str),
    help="Directory containing LoRA model",
)
@click.option(
    "--base-model",
    type=click.Path(exists=True, path_type=str),
    help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def inference(
    config: str,
    accelerate: bool,
    lora_model_dir: Optional[str] = None,
    base_model: Optional[str] = None,
    **kwargs,
):
    """Run inference with a trained model."""
    # Forward only the options the user actually supplied (non-None).
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    # `inference` interferes with inference.do_cli, so drop it. Use pop with a
    # default instead of `del`: the key is absent whenever the option was not
    # registered or was filtered out above as None, and `del` would then raise
    # KeyError before any work is done.
    kwargs.pop("inference", None)

    if lora_model_dir:
        kwargs["lora_model_dir"] = lora_model_dir
    if base_model:
        # NOTE(review): --base-model is forwarded as `output_dir`, not
        # `base_model`; kept as-is -- confirm this mapping is intended.
        kwargs["output_dir"] = base_model

    if accelerate:
        # Subprocess path: re-run the inference module under accelerate.
        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
        if config:
            base_cmd.append(config)
        cmd = build_command(base_cmd, kwargs)
        subprocess.run(cmd, check=True)  # nosec B603
    else:
        # In-process path: call the inference entry point directly.
        from axolotl.cli.inference import do_cli

        do_cli(config=config, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# `axolotl shard CONFIG [--accelerate/--no-accelerate] [options]`
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--accelerate/--no-accelerate",
    default=False,
    help="Use accelerate launch for multi-GPU operations",
)
@click.option(
    "--model-dir",
    type=click.Path(exists=True, path_type=str),
    help="Directory containing model weights to shard",
)
@click.option(
    "--save-dir",
    type=click.Path(path_type=str),
    help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
    """Shard model weights."""
    # Forward only the options the user actually supplied (non-None).
    overrides = {
        name: value for name, value in kwargs.items() if value is not None
    }

    if not accelerate:
        # In-process path (the default for this command).
        from axolotl.cli.shard import do_cli

        do_cli(config=config, **overrides)
        return

    # Subprocess path: re-run the shard module under `accelerate launch`.
    launch_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
    if config:
        launch_cmd.append(config)
    full_cmd = build_command(launch_cmd, overrides)
    subprocess.run(full_cmd, check=True)  # nosec B603
|
||||||
|
|
||||||
|
|
||||||
|
# `axolotl merge-sharded-fsdp-weights CONFIG [options]`
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--accelerate/--no-accelerate",
    default=True,
    help="Use accelerate launch for weight merging",
)
@click.option(
    "--model-dir",
    type=click.Path(exists=True, path_type=str),
    help="Directory containing sharded weights",
)
@click.option(
    "--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
    """Merge sharded FSDP model weights."""
    # Forward only the options the user actually supplied (non-None).
    overrides = {
        name: value for name, value in kwargs.items() if value is not None
    }

    if not accelerate:
        # In-process path: call the merge entry point directly.
        from axolotl.cli.merge_sharded_fsdp_weights import do_cli

        do_cli(config=config, **overrides)
        return

    # Subprocess path: re-run the merge module under `accelerate launch`.
    launch_cmd = [
        "accelerate",
        "launch",
        "-m",
        "axolotl.cli.merge_sharded_fsdp_weights",
    ]
    if config:
        launch_cmd.append(config)
    full_cmd = build_command(launch_cmd, overrides)
    subprocess.run(full_cmd, check=True)  # nosec B603
|
||||||
|
|
||||||
|
|
||||||
|
# `axolotl merge-lora CONFIG [--lora-model-dir DIR] [--output-dir DIR]`
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
    "--lora-model-dir",
    type=click.Path(exists=True, path_type=str),
    help="Directory containing the LoRA model to merge",
)
@click.option(
    "--output-dir",
    type=click.Path(path_type=str),
    help="Directory to save the merged model",
)
def merge_lora(
    config: str,
    lora_model_dir: Optional[str] = None,
    output_dir: Optional[str] = None,
):
    """Merge a trained LoRA into a base model"""
    from axolotl.cli.merge_lora import do_cli

    # Only pass through the directories that were actually given (truthy).
    candidates = {"lora_model_dir": lora_model_dir, "output_dir": output_dir}
    overrides = {name: value for name, value in candidates.items() if value}

    do_cli(config=config, **overrides)
|
||||||
|
|
||||||
|
|
||||||
|
# `axolotl fetch {examples,deepspeed_configs} [--dest DIR]`
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
    """
    Fetch example configs or other resources.

    Available directories:
    - examples: Example configuration files
    - deepspeed_configs: DeepSpeed configuration files
    """
    # The helper is given the directory name with a trailing slash.
    remote_prefix = directory + "/"
    fetch_from_github(remote_prefix, dest)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Entry point: invoke this module's top-level `cli` command object."""
    cli()


# Running the file directly behaves the same as calling the entry point.
if __name__ == "__main__":
    main()
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
CLI to merge a trained LoRA into a base model
|
CLI to merge a trained LoRA into a base model
|
||||||
"""
|
"""
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
@@ -11,7 +12,7 @@ from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
|
|||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ def merge_fsdp_weights(
|
|||||||
state.wait_for_everyone()
|
state.wait_for_everyone()
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
|||||||
218
src/axolotl/cli/utils.py
Normal file
218
src/axolotl/cli/utils.py
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
"""Utility methods for axolotl CLI."""
|
||||||
|
import concurrent.futures
|
||||||
|
import dataclasses
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from types import NoneType
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
|
||||||
|
|
||||||
|
import click
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.cli.utils")
|
||||||
|
|
||||||
|
|
||||||
|
def add_options_from_dataclass(config_class: Type[Any]):
    """
    Create Click options from the fields of a dataclass.

    Each field becomes one option named after the field with underscores
    replaced by dashes. Fields typed `bool` become `--name/--no-name` flag
    pairs; every other field becomes a `--name` option whose `type` is the
    field's annotation. `Optional[X]` annotations are unwrapped to `X` first.

    Args:
        config_class: Dataclass type whose fields describe the options.

    Returns:
        A decorator that wraps a function with one `click.option` per field.
    """
    none_type = type(None)

    def decorator(function):
        # Process dataclass fields in reverse order for correct option ordering
        for field in reversed(dataclasses.fields(config_class)):
            field_type = field.type

            # Unwrap Optional[X] / Union[..., None] to its first non-None member.
            # The members are type objects, so they must be compared by
            # identity against NoneType; an isinstance() check never matches
            # and would pick NoneType itself for Union[None, X].
            if get_origin(field_type) is Union and none_type in get_args(field_type):
                field_type = next(
                    arg for arg in get_args(field_type) if arg is not none_type
                )

            # Fields without a plain default carry the dataclasses.MISSING
            # sentinel; never hand that sentinel to click as an option default.
            if field.default is not dataclasses.MISSING:
                default = field.default
            elif field.default_factory is not dataclasses.MISSING:
                default = field.default_factory()
            else:
                default = None

            option_stem = field.name.replace("_", "-")
            help_text = field.metadata.get("description")

            if field_type == bool:
                function = click.option(
                    f"--{option_stem}/--no-{option_stem}",
                    default=default,
                    help=help_text,
                )(function)
            else:
                function = click.option(
                    f"--{option_stem}",
                    type=field_type,
                    default=default,
                    help=help_text,
                )(function)
        return function

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def add_options_from_config(config_class: Type[BaseModel]):
    """
    Build a decorator that adds one Click option per Pydantic model field.

    Fields annotated exactly as `bool` become `--name/--no-name` flag pairs;
    all other fields become plain `--name` options. Every option defaults to
    None and uses the field's description as its help text.
    """

    def decorator(function):
        # Walk the model fields back-to-front for correct option ordering
        for name, field in reversed(config_class.model_fields.items()):
            dashed = name.replace("_", "-")
            if field.annotation == bool:
                declaration = f"--{dashed}/--no-{dashed}"
            else:
                declaration = f"--{dashed}"
            attach = click.option(declaration, default=None, help=field.description)
            function = attach(function)
        return function

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
    """
    Return a copy of `base_cmd` extended with command-line flags for `options`.

    Option names have underscores replaced by dashes. None values are skipped,
    True booleans contribute a bare `--name` flag, False booleans contribute
    nothing, and any other value contributes `--name` followed by `str(value)`.
    The input list is not modified.
    """
    cmd = base_cmd.copy()

    for option_name, option_value in options.items():
        if option_value is None:
            continue

        flag = "--" + option_name.replace("_", "-")

        if isinstance(option_value, bool):
            # Booleans are presence flags: emitted only when switched on.
            if option_value:
                cmd.append(flag)
            continue

        cmd += [flag, str(option_value)]

    return cmd
|
||||||
|
|
||||||
|
|
||||||
|
def download_file(
    file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> Tuple[str, str]:
    """
    Download a single file and return its processing status.

    An existing local copy is hashed the way git hashes a blob and compared
    with the remote SHA; the download is skipped when they match.

    Args:
        file_info: Tuple of (file_path, remote_sha)
        raw_base_url: Base URL for raw GitHub content
        dest_path: Local destination directory
        dir_prefix: Directory prefix to filter files; it is stripped from the
            front of file_path to form the path below dest_path

    Returns:
        Tuple of (file_path, status) where status is 'new', 'updated',
        'unchanged', or 'error' (the request or the file write failed)
    """
    file_path, remote_sha = file_info
    raw_url = f"{raw_base_url}/{file_path}"

    # Strip only the leading prefix. Splitting on the prefix and keeping the
    # last piece would flatten paths that repeat it further down
    # (e.g. "examples/x/examples/y.yml").
    if file_path.startswith(dir_prefix):
        relative_path = file_path[len(dir_prefix) :]
    else:
        relative_path = file_path
    dest_file = dest_path / relative_path

    # Check if file exists and needs updating
    if dest_file.exists():
        with open(dest_file, "rb") as file:
            content = file.read()

        # Calculate git blob SHA: sha1 over "blob <size>\0<content>"
        blob = b"blob " + str(len(content)).encode() + b"\0" + content
        local_sha = hashlib.sha1(blob, usedforsecurity=False).hexdigest()

        if local_sha == remote_sha:
            print(f"Skipping {file_path} (unchanged)")
            return file_path, "unchanged"

        print(f"Updating {file_path}")
        # An existing file whose content differs is an update, not a new file.
        status = "updated"
    else:
        print(f"Downloading {file_path}")
        status = "new"

    # Create directories if needed
    dest_file.parent.mkdir(parents=True, exist_ok=True)

    # Download and save file
    try:
        response = requests.get(raw_url, timeout=30)
        response.raise_for_status()

        with open(dest_file, "wb") as file:
            file.write(response.content)

        return file_path, status
    except (requests.RequestException, IOError) as request_error:
        print(f"Error downloading {file_path}: {str(request_error)}")
        return file_path, "error"
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_from_github(
    dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
) -> None:
    """
    Mirror one directory of the GitHub repository into a local directory.

    The repository tree is listed once, every blob under `dir_prefix` is handed
    to `download_file` on a thread pool, and a per-status summary is logged.

    Args:
        dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
        dest_dir: Local destination directory; when not given, the prefix
            without its trailing slash is used
        max_workers: Maximum number of concurrent downloads

    Raises:
        click.ClickException: If no file in the tree starts with `dir_prefix`.
    """
    api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
    raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"

    # One request (with a timeout) lists the whole repository tree.
    tree_response = requests.get(api_url, timeout=30)
    tree_response.raise_for_status()
    tree = json.loads(tree_response.text)

    # Keep only file entries (blobs) under the prefix, mapped path -> SHA.
    files = {}
    for entry in tree["tree"]:
        if entry["type"] == "blob" and entry["path"].startswith(dir_prefix):
            files[entry["path"]] = entry["sha"]

    if not files:
        raise click.ClickException(f"No files found in {dir_prefix}")

    # Without an explicit destination, fall back to the prefix itself.
    dest_path = Path(dest_dir) if dest_dir else Path(dir_prefix.rstrip("/"))

    # Paths bucketed by outcome, used for the summary below.
    files_processed: Dict[str, List[str]] = {
        outcome: [] for outcome in ("new", "updated", "unchanged", "error")
    }

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {}
        for remote_path, remote_sha in files.items():
            job = pool.submit(
                download_file,
                (remote_path, remote_sha),
                raw_base_url,
                dest_path,
                dir_prefix,
            )
            pending[job] = remote_path

        # Record each result as soon as its download finishes.
        for job in concurrent.futures.as_completed(pending):
            file_path = pending[job]
            try:
                file_path, status = job.result()
                files_processed[status].append(file_path)
            except (requests.RequestException, IOError) as request_error:
                print(f"Error processing {file_path}: {str(request_error)}")
                files_processed["error"].append(file_path)

    # Log summary
    LOG.info("\nSync Summary:")
    LOG.info(f"New files: {len(files_processed['new'])}")
    LOG.info(f"Updated files: {len(files_processed['updated'])}")
    LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
    if files_processed["error"]:
        LOG.info(f"Failed files: {len(files_processed['error'])}")
|
||||||
@@ -3,36 +3,88 @@ helper functions for fixing the embeddings/tokenizer
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
||||||
|
# GNU LESSER GENERAL PUBLIC LICENSE
|
||||||
|
# Version 3, 29 June 2007
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||||
# you may not use this file except in compliance with the License.
|
# Everyone is permitted to copy and distribute verbatim copies
|
||||||
# You may obtain a copy of the License at
|
# of this license document, but changing it is not allowed.
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
import itertools
|
import itertools
|
||||||
|
import logging
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
import datasets
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.core.tokenizer_utils")
|
||||||
|
|
||||||
@torch.inference_mode
|
|
||||||
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
@torch.inference_mode()
|
||||||
|
def fix_untrained_tokens( # pylint: disable=too-many-return-statements
|
||||||
|
model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Many of the newer models have reserved tokens that are not trained.
|
Llama-3 for eg has untrained vectors in the base model.
|
||||||
|
These include <|eot_id|>, <|start_header_id|>, <|end_header_id|>
|
||||||
|
We reset them to the mean of the rest of the tokens
|
||||||
"""
|
"""
|
||||||
|
# Code licensed under LGPL
|
||||||
embedding_matrix = model.get_input_embeddings().weight
|
embedding_matrix = model.get_input_embeddings().weight
|
||||||
lm_head_matrix = model.get_output_embeddings().weight
|
lm_head_matrix = model.get_output_embeddings().weight
|
||||||
|
chat_template = getattr(tokenizer, "chat_template", None)
|
||||||
|
tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
|
||||||
|
|
||||||
|
# Ignore some model checks for now
|
||||||
|
if not ignored_tokenizer_names:
|
||||||
|
ignored_tokenizer_names = []
|
||||||
|
if (
|
||||||
|
model.config._name_or_path # pylint: disable=protected-access
|
||||||
|
in ignored_tokenizer_names
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Sometimes the sizes can be different like in vision models
|
||||||
|
# Ie <image> is in input, but not in output
|
||||||
|
min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
|
||||||
|
embedding_matrix = embedding_matrix[:, :min_size]
|
||||||
|
lm_head_matrix = lm_head_matrix[:, :min_size]
|
||||||
|
|
||||||
# Get untrained tokens
|
# Get untrained tokens
|
||||||
indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
|
indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
|
||||||
|
# Check lm_head as well
|
||||||
|
|
||||||
|
# Does NOT work for Llama 3.1!!
|
||||||
|
indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
|
||||||
|
|
||||||
|
# We instead check for repeated vectors
|
||||||
|
lm_head_where = torch.where(indicator_untrained1)[0]
|
||||||
|
lm_head_bad = lm_head_matrix[lm_head_where]
|
||||||
|
lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
|
||||||
|
counter = Counter()
|
||||||
|
for row in lm_head_bad:
|
||||||
|
counter[hash(row.data.tobytes())] += 1
|
||||||
|
counter = Counter({k: c for k, c in counter.items() if c >= 2})
|
||||||
|
|
||||||
|
lm_head_where = lm_head_where.cpu().numpy()
|
||||||
|
final_bad_lm_head = []
|
||||||
|
for j, row in enumerate(lm_head_bad):
|
||||||
|
if hash(row.data.tobytes()) in counter:
|
||||||
|
final_bad_lm_head.append(lm_head_where[j])
|
||||||
|
indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
|
||||||
|
indicator_untrained2[final_bad_lm_head] = True
|
||||||
|
|
||||||
|
# Combine both checks
|
||||||
|
indicator_untrained = indicator_untrained1 & indicator_untrained2
|
||||||
|
|
||||||
|
# Remove pad token possibility
|
||||||
|
if hasattr(tokenizer, "pad_token_id"):
|
||||||
|
pad_token_id = tokenizer.pad_token_id
|
||||||
|
if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
|
||||||
|
indicator_untrained[pad_token_id] = False
|
||||||
|
|
||||||
where_untrained = torch.where(indicator_untrained)[0]
|
where_untrained = torch.where(indicator_untrained)[0]
|
||||||
n_untrained = where_untrained.shape[0]
|
n_untrained = where_untrained.shape[0]
|
||||||
n_trained = embedding_matrix.shape[0] - n_untrained
|
n_trained = embedding_matrix.shape[0] - n_untrained
|
||||||
@@ -40,10 +92,9 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
# Get set and actual tokens
|
# Get set and actual tokens
|
||||||
where_untrained = where_untrained.tolist()
|
where_untrained = where_untrained.tolist()
|
||||||
if len(where_untrained) == 0:
|
if len(where_untrained) == 0:
|
||||||
return False
|
return
|
||||||
|
|
||||||
# Remove untrained indices where it's longer
|
# Remove untrained indices where it's longer
|
||||||
|
|
||||||
where_untrained_set = frozenset(where_untrained)
|
where_untrained_set = frozenset(where_untrained)
|
||||||
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
|
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
|
||||||
# Remove None items in actual_bad_tokens
|
# Remove None items in actual_bad_tokens
|
||||||
@@ -53,10 +104,14 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
if_bad_first = False
|
if_bad_first = False
|
||||||
if_bad_second = False
|
if_bad_second = False
|
||||||
# Check tokenizer's chat template for any untrained tokens
|
# Check tokenizer's chat template for any untrained tokens
|
||||||
chat_template = getattr(tokenizer, "chat_template", None)
|
|
||||||
if chat_template is not None:
|
if chat_template is not None:
|
||||||
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
|
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
|
||||||
|
|
||||||
|
if isinstance(train_dataset, datasets.IterableDataset):
|
||||||
|
# Skip the check, since the code below assumes
|
||||||
|
# an indexable dataset
|
||||||
|
return
|
||||||
|
|
||||||
# Check the first 250, last 250 input_ids
|
# Check the first 250, last 250 input_ids
|
||||||
size_dataset = len(train_dataset)
|
size_dataset = len(train_dataset)
|
||||||
size = min(size_dataset, 250)
|
size = min(size_dataset, 250)
|
||||||
@@ -83,7 +138,69 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
|
|
||||||
# Check if bad tokens exists!
|
# Check if bad tokens exists!
|
||||||
if not if_bad_first and not if_bad_second:
|
if not if_bad_first and not if_bad_second:
|
||||||
return False
|
return
|
||||||
|
|
||||||
|
# Check if lm_head / embed_token are trainable!
|
||||||
|
bad_not_trainable = False
|
||||||
|
if not embedding_matrix.requires_grad:
|
||||||
|
bad_not_trainable = True
|
||||||
|
if not lm_head_matrix.requires_grad:
|
||||||
|
bad_not_trainable = True
|
||||||
|
|
||||||
|
if bad_not_trainable: # pylint: disable=too-many-nested-blocks
|
||||||
|
final_bad_items = []
|
||||||
|
|
||||||
|
# Re-check the first 250, last 250 input_ids
|
||||||
|
size_dataset = len(train_dataset)
|
||||||
|
size = min(size_dataset, 250)
|
||||||
|
for j in range(size):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
for item in input_ids:
|
||||||
|
if item in where_untrained_set:
|
||||||
|
final_bad_items.append(item)
|
||||||
|
|
||||||
|
# Re-check last 250
|
||||||
|
left = max(size_dataset - 250, 0)
|
||||||
|
for j in range(left, size_dataset):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
for item in input_ids:
|
||||||
|
if item in where_untrained_set:
|
||||||
|
final_bad_items.append(item)
|
||||||
|
|
||||||
|
# If no bad tokens, possibly chat template itself has issues?
|
||||||
|
if len(final_bad_items) == 0:
|
||||||
|
# Recheck 2000 and last 2000 items
|
||||||
|
size_dataset = len(train_dataset)
|
||||||
|
size = min(size_dataset, 2000)
|
||||||
|
for j in range(size):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
for item in input_ids:
|
||||||
|
if item in where_untrained_set:
|
||||||
|
final_bad_items.append(item)
|
||||||
|
|
||||||
|
# Re-check last 2000
|
||||||
|
left = max(size_dataset - 2000, 0)
|
||||||
|
for j in range(left, size_dataset):
|
||||||
|
input_ids = train_dataset[j]
|
||||||
|
if "input_ids" in input_ids:
|
||||||
|
input_ids = input_ids["input_ids"]
|
||||||
|
for item in input_ids:
|
||||||
|
if item in where_untrained_set:
|
||||||
|
final_bad_items.append(item)
|
||||||
|
|
||||||
|
# Most likely false signal!
|
||||||
|
if len(final_bad_items) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Untrained tokens of [{list(set(final_bad_items))}] found, but embed_tokens & lm_head not trainable, causing NaNs. "
|
||||||
|
)
|
||||||
|
|
||||||
# Count all the possible bad tokens
|
# Count all the possible bad tokens
|
||||||
final_counts = np.zeros(
|
final_counts = np.zeros(
|
||||||
@@ -97,6 +214,23 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
|
|
||||||
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
|
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
|
||||||
|
|
||||||
|
# Get counts for untrained tokens
|
||||||
|
counts_untrained = final_counts[where_untrained]
|
||||||
|
# Identify untrained tokens seen in train_dataset
|
||||||
|
indices_seen_in_train = np.where(counts_untrained > 0)[0]
|
||||||
|
tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]
|
||||||
|
|
||||||
|
if len(tokens_to_update) == 0:
|
||||||
|
LOG.info(
|
||||||
|
"No untrained tokens found in train_dataset. No embeddings were modified."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Log the token IDs that are being rescaled
|
||||||
|
LOG.info(
|
||||||
|
f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
|
||||||
|
)
|
||||||
|
|
||||||
# Get sum of all items
|
# Get sum of all items
|
||||||
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
|
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
|
||||||
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
|
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
|
||||||
@@ -113,38 +247,26 @@ def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
|
|||||||
mean_embedding = sum_embedding / n_trained
|
mean_embedding = sum_embedding / n_trained
|
||||||
mean_lm_head = sum_lm_head / n_trained
|
mean_lm_head = sum_lm_head / n_trained
|
||||||
|
|
||||||
# Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
|
# Compute scaling for tokens to update
|
||||||
scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
|
scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
|
||||||
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
|
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
|
||||||
mean_embedding = (
|
|
||||||
mean_embedding.repeat(
|
|
||||||
(
|
|
||||||
n_untrained,
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
* scaling
|
|
||||||
)
|
|
||||||
mean_lm_head = (
|
|
||||||
mean_lm_head.repeat(
|
|
||||||
(
|
|
||||||
n_untrained,
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
* scaling
|
|
||||||
)
|
|
||||||
where_null = scaling.ravel() == 0
|
|
||||||
mean_embedding[where_null] = 0
|
|
||||||
mean_lm_head[where_null] = 0
|
|
||||||
|
|
||||||
# Set them to the mean
|
# Prepare mean embeddings for tokens to update
|
||||||
embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
|
mean_embedding_repeated = (
|
||||||
lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
|
mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
|
||||||
|
)
|
||||||
|
mean_lm_head_repeated = (
|
||||||
|
mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update embeddings only for tokens seen in train_dataset
|
||||||
|
embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
|
||||||
|
embedding_matrix.dtype
|
||||||
|
)
|
||||||
|
lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
return
|
||||||
return True
|
|
||||||
|
|||||||
@@ -107,6 +107,22 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
|||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
|
||||||
|
if isinstance(dataset_tags, str):
|
||||||
|
dataset_tags = [dataset_tags]
|
||||||
|
|
||||||
|
if (dataset_tags is not None) and (kwargs is not None):
|
||||||
|
if "dataset_tags" not in kwargs:
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
|
||||||
|
kwargs["dataset_tags"].extend(dataset_tags)
|
||||||
|
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
|
||||||
|
dataset_tags.append(kwargs["dataset_tags"])
|
||||||
|
kwargs["dataset_tags"] = dataset_tags
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AxolotlTrainingMixins:
|
class AxolotlTrainingMixins:
|
||||||
"""
|
"""
|
||||||
@@ -220,6 +236,14 @@ class AxolotlTrainingMixins:
|
|||||||
default=1e-6,
|
default=1e-6,
|
||||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||||
)
|
)
|
||||||
|
embedding_lr_scale: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
|
embedding_lr: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
qlora: bool = field(
|
qlora: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "whether this is a qlora training"},
|
metadata={"help": "whether this is a qlora training"},
|
||||||
@@ -386,7 +410,7 @@ class SchedulerMixin(Trainer):
|
|||||||
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
min_lr_ratio=self.args.cosine_min_lr_ratio,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return super().create_scheduler(num_training_steps, optimizer)
|
return super().create_scheduler(num_training_steps, optimizer=optimizer)
|
||||||
else:
|
else:
|
||||||
if use_cosine_quadratic:
|
if use_cosine_quadratic:
|
||||||
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
|
||||||
@@ -410,10 +434,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
*_args,
|
*_args,
|
||||||
bench_data_collator=None,
|
bench_data_collator=None,
|
||||||
eval_data_collator=None,
|
eval_data_collator=None,
|
||||||
|
dataset_tags=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.bench_data_collator = bench_data_collator
|
self.bench_data_collator = bench_data_collator
|
||||||
self.eval_data_collator = eval_data_collator
|
self.eval_data_collator = eval_data_collator
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
super().__init__(*_args, **kwargs)
|
super().__init__(*_args, **kwargs)
|
||||||
self.train_data_collator = self.data_collator
|
self.train_data_collator = self.data_collator
|
||||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||||
@@ -435,6 +461,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
|
and self.args.embedding_lr_scale is None
|
||||||
|
and self.args.embedding_lr is None
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in [
|
not in [
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
@@ -449,30 +477,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
|
||||||
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
||||||
decay_parameters = self.get_decay_parameter_names(opt_model)
|
decay_parameters = self.get_decay_parameter_names(opt_model)
|
||||||
optimizer_grouped_parameters = [
|
params = {
|
||||||
{
|
"to_weight_decay": {}, # LayerNorm and bias
|
||||||
"params": [
|
"embeddings": {}, # lm_head, embed_tokens,
|
||||||
p
|
"no_weight_decay": {},
|
||||||
for n, p in opt_model.named_parameters()
|
}
|
||||||
if (n in decay_parameters and p.requires_grad)
|
|
||||||
],
|
|
||||||
"weight_decay": self.args.weight_decay,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p
|
|
||||||
for n, p in opt_model.named_parameters()
|
|
||||||
if (n not in decay_parameters and p.requires_grad)
|
|
||||||
],
|
|
||||||
"weight_decay": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
||||||
self.args,
|
self.args,
|
||||||
opt_model,
|
opt_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for name, param in opt_model.named_parameters():
|
||||||
|
if not param.requires_grad:
|
||||||
|
continue
|
||||||
|
if name.endswith("modules_to_save.default.weight") or any(
|
||||||
|
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
|
||||||
|
):
|
||||||
|
params["embeddings"][name] = param
|
||||||
|
elif name in decay_parameters:
|
||||||
|
params["to_weight_decay"][name] = param
|
||||||
|
else:
|
||||||
|
params["no_weight_decay"][name] = param
|
||||||
|
optimizer_grouped_parameters = []
|
||||||
|
if params["to_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["to_weight_decay"].values()),
|
||||||
|
"weight_decay": self.args.weight_decay,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["embeddings"]:
|
||||||
|
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
|
||||||
|
if self.args.embedding_lr_scale:
|
||||||
|
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
|
||||||
|
elif self.args.embedding_lr:
|
||||||
|
lr = self.args.embedding_lr # pylint: disable=invalid-name
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["embeddings"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": lr,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if params["no_weight_decay"]:
|
||||||
|
optimizer_grouped_parameters.append(
|
||||||
|
{
|
||||||
|
"params": list(params["no_weight_decay"].values()),
|
||||||
|
"weight_decay": 0.0,
|
||||||
|
"lr": optimizer_kwargs["lr"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if self.args.loraplus_lr_ratio is not None:
|
if self.args.loraplus_lr_ratio is not None:
|
||||||
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
||||||
loraplus_lr_embedding = getattr(
|
loraplus_lr_embedding = getattr(
|
||||||
@@ -485,6 +542,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
loraplus_lr_embedding=loraplus_lr_embedding,
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
**optimizer_kwargs,
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
|
elif (
|
||||||
|
self.args.embedding_lr_scale is not None
|
||||||
|
or self.args.embedding_lr is not None
|
||||||
|
):
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
|
)
|
||||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
@@ -516,7 +580,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
ADOPT(
|
ADOPT(
|
||||||
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
|
optimizer_grouped_parameters,
|
||||||
|
decouple=True,
|
||||||
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -871,6 +937,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
@@ -888,13 +957,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def log(self, logs: Dict[str, float]) -> None:
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Log `logs` on the various objects watching training, including stored metrics.
|
Log `logs` on the various objects watching training, including stored metrics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logs (`Dict[str, float]`):
|
logs (`Dict[str, float]`):
|
||||||
The values to log.
|
The values to log.
|
||||||
|
start_time (`Optional[float]`):
|
||||||
|
The start of training.
|
||||||
"""
|
"""
|
||||||
# logs either has 'loss' or 'eval_loss'
|
# logs either has 'loss' or 'eval_loss'
|
||||||
train_eval = "train" if "loss" in logs else "eval"
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
@@ -902,7 +973,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
for key, metrics in self._stored_metrics[train_eval].items():
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
logs[key] = torch.tensor(metrics).mean().item()
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
del self._stored_metrics[train_eval]
|
del self._stored_metrics[train_eval]
|
||||||
return super().log(logs)
|
return super().log(logs, start_time)
|
||||||
|
|
||||||
def store_metrics(
|
def store_metrics(
|
||||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||||
@@ -994,8 +1065,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "dpo"]
|
tag_names = ["axolotl", "dpo"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, dataset_tags=None, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.dataset_tags = dataset_tags
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
|
|
||||||
def create_optimizer(self):
|
def create_optimizer(self):
|
||||||
@@ -1034,6 +1106,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
"""
|
"""
|
||||||
|
kwargs = _sanitize_kwargs_for_ds_tagging(
|
||||||
|
dataset_tags=self.dataset_tags, kwargs=kwargs
|
||||||
|
)
|
||||||
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
|
||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
@@ -1082,6 +1157,18 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the update to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1090,6 +1177,18 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "orpo"]
|
tag_names = ["axolotl", "orpo"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the update to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1098,6 +1197,45 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the update to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# train metrics should have no prefix, eval should have 'eval_'
|
||||||
|
prefix = "eval_" if train_eval == "eval" else ""
|
||||||
|
# accumulate average metrics from sums and lengths
|
||||||
|
for split in ["chosen", "rejected"]:
|
||||||
|
if f"count/{split}" in self._stored_metrics[train_eval]:
|
||||||
|
count_sum = (
|
||||||
|
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
for metric in ["rewards", "logps", "logits"]:
|
||||||
|
logs[f"{prefix}{metric}/{split}"] = (
|
||||||
|
torch.Tensor(
|
||||||
|
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||||
|
)
|
||||||
|
.sum()
|
||||||
|
.item()
|
||||||
|
/ count_sum
|
||||||
|
)
|
||||||
|
# delete obsolete metric
|
||||||
|
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||||
|
del self._stored_metrics[train_eval][f"count/{split}"]
|
||||||
|
# calculate reward margin
|
||||||
|
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
||||||
|
logs[f"{prefix}rewards/margins"] = (
|
||||||
|
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
||||||
|
)
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1106,6 +1244,18 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "cpo"]
|
tag_names = ["axolotl", "cpo"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the update to the Trainer.log method
|
||||||
|
# logs either has 'loss' or 'eval_loss'
|
||||||
|
train_eval = "train" if "loss" in logs else "eval"
|
||||||
|
# Add averaged stored metrics to logs
|
||||||
|
for key, metrics in self._stored_metrics[train_eval].items():
|
||||||
|
logs[key] = torch.tensor(metrics).mean().item()
|
||||||
|
del self._stored_metrics[train_eval]
|
||||||
|
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||||
"""
|
"""
|
||||||
@@ -1114,6 +1264,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|||||||
|
|
||||||
tag_names = ["axolotl", "reward"]
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||||
|
# TODO remove once trl supports the update to the Trainer.log method
|
||||||
|
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
|
||||||
|
logs, start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TrainerBuilderBase(abc.ABC):
|
class TrainerBuilderBase(abc.ABC):
|
||||||
"""
|
"""
|
||||||
@@ -1571,6 +1727,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"loraplus_lr_embedding"
|
"loraplus_lr_embedding"
|
||||||
] = self.cfg.loraplus_lr_embedding
|
] = self.cfg.loraplus_lr_embedding
|
||||||
|
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||||
|
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||||
|
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
@@ -1755,6 +1914,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|
||||||
|
if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
|
||||||
|
trainer_kwargs["dataset_tags"] = [
|
||||||
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
|
]
|
||||||
trainer = trainer_cls(
|
trainer = trainer_cls(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
@@ -2028,6 +2191,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||||
|
|
||||||
|
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||||
|
dpo_trainer_kwargs["dataset_tags"] = [
|
||||||
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
|
]
|
||||||
dpo_trainer = trainer_cls(
|
dpo_trainer = trainer_cls(
|
||||||
*trainer_cls_args,
|
*trainer_cls_args,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
query_tensors,
|
query_tensors,
|
||||||
return_prompt=False,
|
return_prompt=False,
|
||||||
generate_ref_response=True,
|
generate_ref_response=True,
|
||||||
**generation_kwargs
|
**generation_kwargs,
|
||||||
)
|
)
|
||||||
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
batch["response"] = self.tokenizer.batch_decode(response_tensors)
|
||||||
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
|
||||||
|
|||||||
325
src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.md
Normal file
325
src/axolotl/integrations/cut_cross_entropy/ACKNOWLEDGEMENTS.md
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
Acknowledgements
|
||||||
|
|
||||||
|
Portions of this Cut Cross Entropy Software may utilize the following copyrighted
|
||||||
|
material, the use of which is hereby acknowledged.
|
||||||
|
|
||||||
|
|
||||||
|
------
|
||||||
|
|
||||||
|
|
||||||
|
PyTorch
|
||||||
|
|
||||||
|
From PyTorch:
|
||||||
|
|
||||||
|
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
||||||
|
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
||||||
|
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
||||||
|
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
||||||
|
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
||||||
|
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
||||||
|
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
||||||
|
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
||||||
|
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
||||||
|
|
||||||
|
From Caffe2:
|
||||||
|
|
||||||
|
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||||
|
|
||||||
|
All contributions by Facebook:
|
||||||
|
Copyright (c) 2016 Facebook Inc.
|
||||||
|
|
||||||
|
All contributions by Google:
|
||||||
|
Copyright (c) 2015 Google Inc.
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
All contributions by Yangqing Jia:
|
||||||
|
Copyright (c) 2015 Yangqing Jia
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
All contributions by Kakao Brain:
|
||||||
|
Copyright 2019-2020 Kakao Brain
|
||||||
|
|
||||||
|
All contributions by Cruise LLC:
|
||||||
|
Copyright (c) 2022 Cruise LLC.
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
All contributions by Arm:
|
||||||
|
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
|
||||||
|
|
||||||
|
All contributions from Caffe:
|
||||||
|
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
All other contributions:
|
||||||
|
Copyright(c) 2015, 2016 the respective contributors
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||||
|
copyright over their contributions to Caffe2. The project versioning records
|
||||||
|
all such contribution and copyright details. If a contributor wants to further
|
||||||
|
mark their specific copyright on a particular contribution, they should
|
||||||
|
indicate their copyright solely in the commit message of the change when it is
|
||||||
|
committed.
|
||||||
|
|
||||||
|
All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
1. Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
2. Redistributions in binary form must reproduce the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer in the
|
||||||
|
documentation and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
||||||
|
and IDIAP Research Institute nor the names of its contributors may be
|
||||||
|
used to endorse or promote products derived from this software without
|
||||||
|
specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||||
|
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
|
||||||
|
Triton
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Copyright 2018-2020 Philippe Tillet
|
||||||
|
* Copyright 2020-2022 OpenAI
|
||||||
|
*
|
||||||
|
* Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
* a copy of this software and associated documentation files
|
||||||
|
* (the "Software"), to deal in the Software without restriction,
|
||||||
|
* including without limitation the rights to use, copy, modify, merge,
|
||||||
|
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||||
|
* and to permit persons to whom the Software is furnished to do so,
|
||||||
|
* subject to the following conditions:
|
||||||
|
*
|
||||||
|
* The above copyright notice and this permission notice shall be
|
||||||
|
* included in all copies or substantial portions of the Software.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||||
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||||
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||||
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||||
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
Transformers
|
||||||
|
|
||||||
|
Copyright 2018- The Hugging Face team. All rights reserved.
|
||||||
|
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
47
src/axolotl/integrations/cut_cross_entropy/LICENSE
Normal file
47
src/axolotl/integrations/cut_cross_entropy/LICENSE
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
||||||
|
|
||||||
|
IMPORTANT: This Apple software is supplied to you by Apple
|
||||||
|
Inc. ("Apple") in consideration of your agreement to the following
|
||||||
|
terms, and your use, installation, modification or redistribution of
|
||||||
|
this Apple software constitutes acceptance of these terms. If you do
|
||||||
|
not agree with these terms, please do not use, install, modify or
|
||||||
|
redistribute this Apple software.
|
||||||
|
|
||||||
|
In consideration of your agreement to abide by the following terms, and
|
||||||
|
subject to these terms, Apple grants you a personal, non-exclusive
|
||||||
|
license, under Apple's copyrights in this original Apple software (the
|
||||||
|
"Apple Software"), to use, reproduce, modify and redistribute the Apple
|
||||||
|
Software, with or without modifications, in source and/or binary forms;
|
||||||
|
provided that if you redistribute the Apple Software in its entirety and
|
||||||
|
without modifications, you must retain this notice and the following
|
||||||
|
text and disclaimers in all such redistributions of the Apple Software.
|
||||||
|
Neither the name, trademarks, service marks or logos of Apple Inc. may
|
||||||
|
be used to endorse or promote products derived from the Apple Software
|
||||||
|
without specific prior written permission from Apple. Except as
|
||||||
|
expressly stated in this notice, no other rights or licenses, express or
|
||||||
|
implied, are granted by Apple herein, including but not limited to any
|
||||||
|
patent rights that may be infringed by your derivative works or by other
|
||||||
|
works in which the Apple Software may be incorporated.
|
||||||
|
|
||||||
|
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
|
||||||
|
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
|
||||||
|
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
|
||||||
|
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
|
||||||
|
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
|
||||||
|
|
||||||
|
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
|
||||||
|
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
|
||||||
|
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
|
||||||
|
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
|
||||||
|
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
|
||||||
|
POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
|
||||||
|
|
||||||
|
-------------------------------------------------------------------------------
|
||||||
|
SOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY:
|
||||||
|
|
||||||
|
The Cut Cross Entropy software includes a number of subcomponents with separate
|
||||||
|
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md.
|
||||||
|
-------------------------------------------------------------------------------
|
||||||
10
src/axolotl/integrations/cut_cross_entropy/README.md
Normal file
10
src/axolotl/integrations/cut_cross_entropy/README.md
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# Cut Cross Entropy
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||||
|
|
||||||
|
cut_cross_entropy: true
|
||||||
|
```
|
||||||
83
src/axolotl/integrations/cut_cross_entropy/__init__.py
Normal file
83
src/axolotl/integrations/cut_cross_entropy/__init__.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Module for the Plugin for Cut Cross Entropy integration with Axolotl.
|
||||||
|
|
||||||
|
Cut Cross Entropy is an optimized implementation of cross entropy loss
|
||||||
|
from Apple's ML team.
|
||||||
|
"""
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.utils import get_pytorch_version
|
||||||
|
|
||||||
|
from ...utils.distributed import zero_only
|
||||||
|
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||||
|
|
||||||
|
_CCE_INSTALL_MESSAGE = (
|
||||||
|
"Please install cut_cross_entropy with transformers support using "
|
||||||
|
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CutCrossEntropyPlugin(BasePlugin):
    """
    Plugin for Cut Cross Entropy integration with Axolotl.

    When ``cut_cross_entropy`` is truthy in the config, ``pre_model_load``
    verifies the runtime requirements and then calls ``cce_patch`` with the
    configured model type.
    """

    def get_input_args(self):
        """Return the dotted import path of the args model for this plugin."""
        return "axolotl.integrations.cut_cross_entropy.CutCrossEntropyArgs"

    def _check_requirements(self):
        """
        Check if all requirements are met.

        Raises:
            ImportError: if PyTorch is older than 2.4, or if ``cut_cross_entropy``
                or its ``transformers`` submodule cannot be located.
        """
        # A bare `import importlib` does not guarantee that the `importlib.util`
        # submodule has been loaded, so import it explicitly before using it.
        import importlib.util  # pylint: disable=import-outside-toplevel

        # Check PyTorch version
        major, minor, _ = get_pytorch_version()
        if (major, minor) < (2, 4):
            raise ImportError(
                "Cut Cross Entropy requires PyTorch >= 2.4.0. "
                f"Current version: {torch.__version__}"
            )

        # Check that cut_cross_entropy and its transformers extra are installed.
        # The parent package is probed first, so the submodule lookup never runs
        # against a missing parent.
        for module_name in ("cut_cross_entropy", "cut_cross_entropy.transformers"):
            if importlib.util.find_spec(module_name) is None:
                raise ImportError(_CCE_INSTALL_MESSAGE)

    def pre_model_load(self, cfg):
        """
        Apply cut cross entropy before model loading if enabled.

        Args:
            cfg: config object; reads ``cut_cross_entropy`` and ``model_config_type``.

        Raises:
            ImportError: propagated from ``_check_requirements``.
        """
        if cfg.cut_cross_entropy:
            self._check_requirements()

            # Imported lazily so the plugin module itself loads without the
            # optional dependency installed.
            from cut_cross_entropy.transformers import cce_patch

            with zero_only():
                LOG.info(
                    f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
                )

            # The patch checks model_type internally
            cce_patch(cfg.model_config_type)
|
||||||
42
src/axolotl/integrations/cut_cross_entropy/args.py
Normal file
42
src/axolotl/integrations/cut_cross_entropy/args.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# Copyright 2024 Axolotl AI. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Module for handling Cut Cross Entropy input arguments.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
|
||||||
|
|
||||||
|
|
||||||
|
class CutCrossEntropyArgs(BaseModel):
    """
    Config fields contributed by the Cut Cross Entropy plugin.
    """

    # Tri-state flag: None means the option was not set in the config.
    cut_cross_entropy: Optional[bool] = None

    @model_validator(mode="before")
    @classmethod
    def check_dtype_is_half(cls, data):
        """Reject configs that enable cut cross entropy without ``bf16`` or ``fp16``."""
        if not data.get("cut_cross_entropy"):
            return data

        if data.get("bf16") or data.get("fp16"):
            return data

        raise ValueError(
            "Cut Cross Entropy requires fp16/bf16 training for backward pass. "
            "Please set `bf16` or `fp16` to `True`."
        )
|
||||||
@@ -1,361 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright (c) 2024 by SageAttention team.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.autograd import Function
|
|
||||||
|
|
||||||
from .triton.attn_qk_int8_per_block_causal_varlen import (
|
|
||||||
backward as sageattn_varlen_backward,
|
|
||||||
)
|
|
||||||
from .triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen
|
|
||||||
from .triton.quant_per_block_varlen import (
|
|
||||||
per_block_int8 as per_block_int8_varlen_triton,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_arch_versions():
    """Return one ``sm<major><minor>`` tag per CUDA device reported by torch, in device order."""
    capabilities = (
        torch.cuda.get_device_capability(device_idx)
        for device_idx in range(torch.cuda.device_count())
    )
    return [f"sm{major}{minor}" for major, minor in capabilities]
|
|
||||||
|
|
||||||
|
|
||||||
def sageattn_varlen(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
    **kwargs: Any,
) -> torch.Tensor:
    """
    Variable-length attention forward pass with int8-quantized ``q`` and ``k``.

    ``q`` and ``k`` are quantized per block, then handed to the kernel imported
    from ``attn_qk_int8_per_block_causal_varlen``.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
        The caller's tensor is not modified.

    v : torch.Tensor
        The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.

    cu_seqlens_q : torch.Tensor
        The cumulative sequence lengths for the query sequences in the batch, used to index into `q`.
        Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.

    cu_seqlens_k : torch.Tensor
        The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`.
        Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.

    max_seqlen_q : int
        The maximum sequence length for the query tensor in the batch.

    max_seqlen_k : int
        The maximum sequence length for the key and value tensors in the batch.

    sm_scale : Optional[float]
        The softmax scale, forwarded unchanged to the quantization helper.
        NOTE(review): the ``1.0 / sqrt(head_dim)`` default is presumably applied
        inside that helper — confirm.

    smooth_k : bool
        Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
        Default: True.

    **kwargs : Any
        Accepted and ignored (for example ``is_causal`` has no effect here).

    Returns
    -------
    torch.Tensor
        The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.

    Raises
    ------
    AssertionError
        If the inputs are not on CUDA, are not float16/bfloat16, differ in
        device or dtype, have an unsupported ``head_dim``, or are not
        contiguous as required below.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``.
    - All tensors must be on the same cuda device.
    - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
    """

    dtype = q.dtype
    assert q.is_cuda, "Input tensors must be on cuda."
    assert dtype in [
        torch.float16,
        torch.bfloat16,
    ], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
    assert q.device == k.device == v.device, "All tensors must be on the same device."
    assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."

    head_dim = q.size(-1)
    assert head_dim in [64, 128], "varlen only support head_dim [64, 128]."

    assert (
        q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
    ), "Last dim of qkv must be contiguous."
    assert (
        cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous()
    ), "cu_seqlens_q and cu_seqlens_k must be contiguous."

    # The attention kernel is given `v` in float16 regardless of the input dtype.
    if dtype in (torch.bfloat16, torch.float32):
        v = v.to(torch.float16)

    if smooth_k:
        # The mean is taken over all tokens of all sequences at once; a
        # per-sequence mean would need a dedicated kernel.
        km = k.mean(dim=0, keepdim=True)
        # Out-of-place subtraction: `k -= km` would silently overwrite the
        # caller's key tensor.
        k = k - km

    (
        q_int8,
        q_scale,
        k_int8,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
    ) = per_block_int8_varlen_triton(
        q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale
    )

    o = attn_true_varlen(
        q_int8,
        k_int8,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        q_scale,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
        output_dtype=dtype,
    )

    return o
|
|
||||||
|
|
||||||
|
|
||||||
class SageAttentionFunction(Function):
    """
    Autograd wrapper around the int8 per-block varlen attention kernels.

    Only causal attention is implemented. ``attn_mask`` and ``dropout_p`` are
    stored on ``ctx`` but are not used by the forward or backward computation.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=False,
        scale=None,
    ):
        """
        query: Tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
        key: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
        value: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
        attn_mask: Optional[Tensor], mask tensor (currently ignored)
        dropout_p: float, dropout probability (currently ignored)
        is_causal: bool, whether to apply causal masking; must be True
        scale: Optional[float], scaling factor for attention scores;
            defaults to ``1 / sqrt(head_dim)``

        Returns a tensor of shape [batch_size, num_heads, seq_len_q, head_dim].

        Raises NotImplementedError when ``is_causal`` is False.
        """
        # Fail fast: there is no non-causal kernel, so reject the call before
        # making contiguous copies and launching the quantization kernel.
        if not is_causal:
            raise NotImplementedError("Non-causal attention is not implemented yet.")

        # Ensure inputs are contiguous
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        # Handle default scale
        if scale is None:
            scale = 1.0 / (query.size(-1) ** 0.5)

        # Save parameters needed for backward
        ctx.scale = scale
        ctx.is_causal = is_causal
        ctx.dropout_p = dropout_p
        ctx.attn_mask = attn_mask

        # Assuming batch sizes are consistent across query, key, and value
        batch_size, num_heads, seq_len_q, head_dim = query.shape
        seq_len_k = key.shape[2]

        # Flatten batch and head dimensions:
        # [batch_size * num_heads, seq_len, head_dim]
        q = query.view(-1, seq_len_q, head_dim)
        k = key.view(-1, seq_len_k, head_dim)
        v = value.view(-1, seq_len_k, head_dim)

        # Every (batch, head) pair is one equal-length "sequence", so the
        # cumulative lengths are evenly spaced with batch_size * num_heads + 1
        # entries.
        cu_seqlens_q = torch.arange(
            0,
            (batch_size * num_heads + 1) * seq_len_q,
            seq_len_q,
            dtype=torch.int32,
            device=query.device,
        )
        cu_seqlens_k = torch.arange(
            0,
            (batch_size * num_heads + 1) * seq_len_k,
            seq_len_k,
            dtype=torch.int32,
            device=key.device,
        )
        max_seqlen_q = seq_len_q
        max_seqlen_k = seq_len_k

        # Per-block int8 quantization of q and k
        (
            q_int8,
            q_scale,
            k_int8,
            k_scale,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
        ) = per_block_int8_varlen_triton(
            q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=scale
        )

        output = attn_true_varlen(
            q_int8,
            k_int8,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            q_scale,
            k_scale,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            output_dtype=query.dtype,
        )

        # Reshape output to match the expected shape
        output = output.view(batch_size, num_heads, seq_len_q, head_dim)

        # Save tensors for backward
        ctx.save_for_backward(
            query,
            key,
            value,
            q_int8,
            k_int8,
            q_scale,
            k_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            output,
        )

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Compute gradients for query, key and value.

        Returns one entry per ``forward`` input after ``ctx``; the entries for
        attn_mask, dropout_p, is_causal and scale are always None.
        """
        (
            query,
            key,
            value,
            q_int8,
            k_int8,
            q_scale,
            k_scale,
            cu_seqlens_q,
            cu_seqlens_k,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            _output,  # saved by forward but not needed here
        ) = ctx.saved_tensors

        # Flatten batch and head dimensions
        batch_size, num_heads, seq_len_q, head_dim = query.shape
        seq_len_k = key.shape[2]
        do = grad_output.contiguous().view(-1, seq_len_q, head_dim)

        # Compute gradients w.r.t. q, k, v
        dq, dk, dv = sageattn_varlen_backward(
            do,
            query.view(-1, seq_len_q, head_dim),
            key.view(-1, seq_len_k, head_dim),
            value.view(-1, seq_len_k, head_dim),
            cu_seqlens_q,
            cu_seqlens_k,
            seq_len_q,
            seq_len_k,
            q_int8,
            k_int8,
            q_scale,
            k_scale,
            cu_seqlens_q_scale,
            cu_seqlens_k_scale,
            ctx.scale,
            ctx.is_causal,
        )

        # Reshape gradients to match the input shapes
        dq = dq.view(batch_size, num_heads, seq_len_q, head_dim)
        dk = dk.view(batch_size, num_heads, seq_len_k, head_dim)
        dv = dv.view(batch_size, num_heads, seq_len_k, head_dim)

        # attn_mask, dropout_p, is_causal and scale receive no gradient.
        return dq, dk, dv, None, None, None, None
|
|
||||||
|
|
||||||
|
|
||||||
def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    """
    Run attention through ``SageAttentionFunction``.

    All arguments are forwarded positionally, in signature order, to
    ``SageAttentionFunction.apply`` and its result is returned unchanged.
    """
    forward_args = (query, key, value, attn_mask, dropout_p, is_causal, scale)
    return SageAttentionFunction.apply(*forward_args)
|
|
||||||
|
|
||||||
|
|
||||||
def monkeypatch_sdp_w_sage_attention():
    """
    Point ``torch.nn.functional.scaled_dot_product_attention`` at this module's
    SageAttention-backed ``scaled_dot_product_attention``.

    The assignment is process-wide and is not undone by this function.
    """
    functional = torch.nn.functional
    functional.scaled_dot_product_attention = scaled_dot_product_attention
|
|
||||||
@@ -1,622 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright (c) 2024 by SageAttention team.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def _attn_fwd_inner(
    acc,
    l_i,
    m_i,
    q,
    q_scale,
    kv_len,
    K_ptrs,
    K_scale_ptr,
    V_ptrs,
    stride_kn,
    stride_vn,
    start_m,
    H: tl.constexpr,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
):
    # One pass of the blocked softmax(QK) @ V accumulation for a single query
    # tile. Walks key/value tiles of BLOCK_N rows over [lo, hi) and returns the
    # updated (acc, l_i, m_i):
    #   acc - un-normalised output accumulator
    #   l_i - running sum of exponentials per query row
    #   m_i - running maximum score per query row
    #
    # STAGE == 1 covers key positions [0, start_m * BLOCK_M) with no mask.
    # STAGE == 2 covers [start_m * BLOCK_M, (start_m + 1) * BLOCK_M) and masks
    #            keys whose position is greater than the query position.
    # NOTE(review): any other STAGE leaves lo/hi unassigned — confirm callers
    # only ever pass 1 or 2.
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        # Skip ahead to the first tile of this stage. Scale entries are laid
        # out H apart per key block.
        K_scale_ptr += (lo // BLOCK_N) * H
        K_ptrs += stride_kn * lo
        V_ptrs += stride_vn * lo
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # Guard the tail tile: only key rows below kv_len are loaded.
        k_mask = offs_n[None, :] < (kv_len - start_n)
        k = tl.load(K_ptrs, mask=k_mask)
        k_scale = tl.load(K_scale_ptr)
        # Rescale the q.k product by the per-block q and k scales.
        qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale

        if STAGE == 2:
            # Causal mask: query row i may attend to key column j only if i >= j.
            # Masked entries get a large negative bias rather than -inf.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk -= m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(qk, 1))
            qk = qk - m_ij[:, None]

        # Base-2 exponent.
        # NOTE(review): presumably log2(e) is folded into the scales upstream —
        # confirm against the quantization helper.
        p = tl.math.exp2(qk)
        l_ij = tl.sum(p, 1)

        # Rescale the previous partial sums to the new running maximum.
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij

        acc = acc * alpha[:, None]

        v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
        p = p.to(tl.float16)

        # The P @ V product is requested in float16 (out_dtype) before being
        # added to acc.
        acc += tl.dot(p, v, out_dtype=tl.float16)
        m_i = m_ij
        # Advance to the next key/value tile and its scale.
        K_ptrs += BLOCK_N * stride_kn
        K_scale_ptr += H
        V_ptrs += BLOCK_N * stride_vn
    return acc, l_i, m_i
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    cu_seqlens_q,
    cu_seqlens_k,
    Q_scale,
    K_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    Out,
    stride_qh,
    stride_qn,
    stride_kh,
    stride_kn,
    stride_vh,
    stride_vn,
    stride_oh,
    stride_on,
    H: tl.constexpr,
    num_kv_groups: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
):
    # Forward attention kernel for packed variable-length sequences.
    # Each program handles one BLOCK_M tile of query rows for one head of one
    # sequence:
    #   program_id(0) -> query tile index within the sequence
    #   program_id(1) -> query head index
    #   program_id(2) -> sequence index into the cu_seqlens_* arrays
    start_m = tl.program_id(0)

    off_z = tl.program_id(2).to(tl.int64)
    off_h = tl.program_id(1).to(tl.int64)

    # Token range of this sequence inside the packed query tensor.
    cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
    cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)

    qo_len = cu_seqlens_q_end - cu_seqlens_q_start

    # Programs whose tile starts past the end of this sequence have nothing to do.
    if (start_m * BLOCK_M) >= qo_len:
        return

    cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
    cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)

    # Scales are indexed as (block, head): H query heads per q block, and
    # H // num_kv_groups key heads per k block.
    q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
    k_scale_offset = (
        cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
    )

    cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
    cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)

    kv_len = cu_seqlens_k_end - cu_seqlens_k_start

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)
    # offs_k is added without a stride, i.e. the head_dim axis is addressed
    # with unit stride for Q, K, V and Out.
    Q_ptrs = (
        Q
        + (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
        + offs_m[:, None] * stride_qn
        + offs_k[None, :]
    )
    Q_scale_ptr = Q_scale + q_scale_offset
    # K is addressed as a [HEAD_DIM, BLOCK_N] tile (already transposed for
    # q @ k). The key/value head is off_h // num_kv_groups, so num_kv_groups
    # consecutive query heads read the same K/V head.
    K_ptrs = (
        K
        + (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
        + offs_n[None, :] * stride_kn
        + offs_k[:, None]
    )
    K_scale_ptr = K_scale + k_scale_offset
    V_ptrs = (
        V
        + (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
        + offs_n[:, None] * stride_vn
        + offs_k[None, :]
    )
    O_block_ptr = (
        Out
        + (cu_seqlens_q_start * stride_on + off_h * stride_oh)
        + offs_m[:, None] * stride_on
        + offs_k[None, :]
    )

    # Softmax running state. m_i starts at -inf, so the first alpha computed in
    # the inner loop is exp2(-inf) = 0 and the initial l_i of 1.0 is discarded.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

    q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
    q_scale = tl.load(Q_scale_ptr)
    # First pass runs the inner loop with stage (4 - STAGE); with STAGE == 3
    # that is stage 1, the unmasked key tiles before this query tile.
    # NOTE(review): the STAGE value passed by the launcher is not visible here —
    # confirm it is 3.
    acc, l_i, m_i = _attn_fwd_inner(
        acc,
        l_i,
        m_i,
        q,
        q_scale,
        kv_len,
        K_ptrs,
        K_scale_ptr,
        V_ptrs,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        4 - STAGE,
        offs_m,
        offs_n,
    )

    # Second pass always runs stage 2: the causally masked tile range that
    # starts at this query tile's own offset.
    acc, l_i, _ = _attn_fwd_inner(
        acc,
        l_i,
        m_i,
        q,
        q_scale,
        kv_len,
        K_ptrs,
        K_scale_ptr,
        V_ptrs,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        2,
        offs_m,
        offs_n,
    )
    # Normalise by the softmax denominator and write back only the valid rows.
    acc = acc / l_i[:, None]
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len))
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def _attn_bwd_inner(
    dq_acc,
    dk_acc,
    dv_acc,
    l_i,
    m_i,
    q,
    k,
    v,
    do,
    q_scale,
    k_scale,
    kv_len,
    stride_kn,
    stride_vn,
    start_m,
    H,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
):
    # One pass of the backward accumulation for a single query tile. Walks
    # key/value tiles over [lo, hi) with the same STAGE ranges as the forward
    # inner loop and returns (dq_acc, dk_acc, dv_acc, l_i, m_i).
    #
    # `k`, `v` and `k_scale` are used as pointers here (they are advanced and
    # passed to tl.load), while `q` and `do` are used as already-loaded tiles.
    # NOTE(review): any STAGE other than 1 or 2 leaves lo/hi unassigned.
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        # NOTE(review): unlike the forward inner loop, k_scale is not advanced
        # to the first tile of this stage — confirm this is intended.
        k += stride_kn * lo
        v += stride_vn * lo

    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        k_mask = offs_n[None, :] < (kv_len - start_n)
        k_curr = tl.load(k, mask=k_mask)
        # NOTE(review): the same [1, BLOCK_N] mask is reused for v; the forward
        # kernel masks V with offs_n[:, None] instead — verify the V tile layout.
        v_curr = tl.load(v, mask=k_mask)
        k_scale_curr = tl.load(k_scale)
        # NOTE(review): `trans_b` is not a tl.dot keyword in recent Triton
        # releases — confirm against the pinned Triton version.
        s = tl.dot(q, k_curr, trans_b=True).to(tl.float32) * q_scale * k_scale_curr

        if STAGE == 2:
            # Causal mask: query row i may attend to key column j only if i >= j.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            s = s + tl.where(mask, 0.0, -float("inf"))
            m_ij = tl.maximum(m_i, tl.max(s, 1))
            s = s - m_ij[:, None]
        else:
            m_ij = tl.maximum(m_i, tl.max(s, 1))
            s = s - m_ij[:, None]

        # Recompute the running softmax statistics, base-2 as in the forward pass.
        p = tl.math.exp2(s)
        l_ij = tl.sum(p, 1)
        alpha = tl.math.exp2(m_i - m_ij)
        l_i = l_i * alpha + l_ij
        m_i = m_ij

        # Normalize probabilities.
        # NOTE(review): l_i is the running sum up to this tile, not the final
        # row sum, so earlier tiles are normalised by a partial denominator —
        # verify.
        p = p / l_i[:, None]

        # Gradient w.r.t. V: P^T @ dO.
        # NOTE(review): dO is divided by l_i in addition to the normalisation
        # of p above — verify this is not applied twice.
        do_scaled = do / l_i[:, None]
        dv_contrib = tl.dot(p.to(tl.float16).T, do_scaled.to(tl.float16))
        dv_acc += dv_contrib

        # Gradient w.r.t. P: dO @ V^T.
        dp = tl.dot(do_scaled.to(tl.float16), v_curr.to(tl.float16).T)

        # Compute ds (gradient w.r.t. logits s): softmax Jacobian applied to dp.
        # The row sum is taken over the current tile only.
        p_dp = p * dp
        sum_p_dp = tl.sum(p_dp, axis=1)
        ds = (p_dp - p * sum_p_dp[:, None]) * tl.math.log(2.0)  # Adjust for exp2

        # Compute gradients w.r.t q and k
        dq_contrib = tl.dot(ds.to(tl.float16), k_curr.to(tl.float16))
        dk_contrib = tl.dot(ds.to(tl.float16).T, q.to(tl.float16))

        # Undo the per-block quantization scales.
        # NOTE(review): dk_acc and dv_acc are single accumulators that receive
        # contributions from every key tile — confirm how the caller maps them
        # back to key positions.
        dq_acc += dq_contrib * (q_scale * k_scale_curr)
        dk_acc += dk_contrib * (q_scale * k_scale_curr)

        # Advance to the next key/value tile and its scale.
        k += BLOCK_N * stride_kn
        k_scale += H
        v += BLOCK_N * stride_vn

    return dq_acc, dk_acc, dv_acc, l_i, m_i
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def _attn_bwd(
    DO,
    Q,
    K,
    V,
    cu_seqlens_q,
    cu_seqlens_k,
    Q_scale,
    K_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    L,
    M,
    DQ,
    DK,
    DV,
    stride_qh,
    stride_qn,
    stride_kh,
    stride_kn,
    stride_vh,
    stride_vn,
    H: tl.constexpr,
    num_kv_groups: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
):
    # Backward kernel for variable-length (cu_seqlens-packed) attention.
    #
    # Launch grid: axis 0 = query tile of BLOCK_M rows, axis 1 = query head,
    # axis 2 = sequence index into the cu_seqlens_* arrays.
    #
    # DO / Q / DQ are all addressed with (stride_qh, stride_qn); K / DK with
    # (stride_kh, stride_kn); V / DV with (stride_vh, stride_vn).  The last
    # (head_dim) axis is addressed with an implicit stride of 1.
    #
    # NOTE(review): every program writes DK/DV rows 0..BLOCK_N-1 of its
    # sequence for KV head `off_h // num_kv_groups`, using a plain store.
    # Programs that share a KV head / sequence therefore write the same
    # addresses - confirm this is the intended reduction scheme.
    start_m = tl.program_id(0)
    off_h = tl.program_id(1).to(tl.int64)
    off_z = tl.program_id(2).to(tl.int64)

    # Query extent of this sequence inside the packed token axis.
    cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
    cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
    qo_len = cu_seqlens_q_end - cu_seqlens_q_start

    # The grid is sized from the longest sequence; tiles past the end of a
    # shorter sequence have nothing to do.
    if (start_m * BLOCK_M) >= qo_len:
        return

    cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
    cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)

    # Key/value extent of this sequence.
    cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
    cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
    kv_len = cu_seqlens_k_end - cu_seqlens_k_start

    # Grouped-query attention: `num_kv_groups` query heads share one KV head.
    kv_head = off_h // num_kv_groups

    # One scale per (query block, head); blocks are laid out head-minor.
    q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
    k_scale_offset = cu_seq_lens_k_scale_start * (H // num_kv_groups) + kv_head

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, HEAD_DIM)

    # Scalar base offsets shared by a tensor and its gradient.
    q_base = cu_seqlens_q_start * stride_qn + off_h * stride_qh
    k_base = cu_seqlens_k_start * stride_kn + kv_head * stride_kh
    v_base = cu_seqlens_k_start * stride_vn + kv_head * stride_vh

    # (BLOCK_M, HEAD_DIM) tiles on the query side.
    Q_ptrs = Q + q_base + offs_m[:, None] * stride_qn + offs_k[None, :]
    DO_ptrs = DO + q_base + offs_m[:, None] * stride_qn + offs_k[None, :]
    DQ_ptrs = DQ + q_base + offs_m[:, None] * stride_qn + offs_k[None, :]
    Q_scale_ptr = Q_scale + q_scale_offset

    # K is handed to the inner loop as a (HEAD_DIM, BLOCK_N) tile of pointers.
    K_ptrs = K + k_base + offs_n[None, :] * stride_kn + offs_k[:, None]
    K_scale_ptr = K_scale + k_scale_offset

    # (BLOCK_N, HEAD_DIM) tiles on the value side.
    V_ptrs = V + v_base + offs_n[:, None] * stride_vn + offs_k[None, :]
    DV_ptrs = DV + v_base + offs_n[:, None] * stride_vn + offs_k[None, :]

    # FIX: dk_acc is a (BLOCK_N, HEAD_DIM) accumulator, so DK must be addressed
    # with the same row-major tile layout as DV.  The previous
    # (HEAD_DIM, BLOCK_N) layout wrote the tile transposed when
    # HEAD_DIM == BLOCK_N and was shape-incompatible otherwise.
    DK_ptrs = DK + k_base + offs_n[:, None] * stride_kn + offs_k[None, :]

    L_ptrs = L + (cu_seqlens_q_start + offs_m)
    M_ptrs = M + (cu_seqlens_q_start + offs_m)

    # Per-row statistics read from M and L; rows past the end of the sequence
    # get the neutral values -inf / 1.0.
    m_i = tl.load(M_ptrs, mask=offs_m < qo_len, other=float("-inf"))
    l_i = tl.load(L_ptrs, mask=offs_m < qo_len, other=1.0)

    dq_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
    dk_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
    dv_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)

    q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
    do = tl.load(DO_ptrs, mask=offs_m[:, None] < qo_len)
    q_scale = tl.load(Q_scale_ptr)

    # First pass over the keys with stage id `4 - STAGE`.
    dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
        dq_acc,
        dk_acc,
        dv_acc,
        l_i,
        m_i,
        q,
        K_ptrs,
        V_ptrs,
        do,
        q_scale,
        K_scale_ptr,
        kv_len,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        4 - STAGE,
        offs_m,
        offs_n,
    )

    # Second pass with stage id 2 (the branch that applies the causal mask in
    # the inner loop).
    # NOTE(review): the same, un-advanced K/V/K_scale pointers are passed to
    # both passes - confirm the inner loop positions itself for each stage.
    dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
        dq_acc,
        dk_acc,
        dv_acc,
        l_i,
        m_i,
        q,
        K_ptrs,
        V_ptrs,
        do,
        q_scale,
        K_scale_ptr,
        kv_len,
        stride_kn,
        stride_vn,
        start_m,
        H // num_kv_groups,
        BLOCK_M,
        HEAD_DIM,
        BLOCK_N,
        2,
        offs_m,
        offs_n,
    )

    # Write the gradients back in the destination dtypes, masking rows that
    # fall outside the sequence.
    tl.store(DQ_ptrs, dq_acc.to(DQ.dtype.element_ty), mask=offs_m[:, None] < qo_len)
    tl.store(DK_ptrs, dk_acc.to(DK.dtype.element_ty), mask=offs_n[:, None] < kv_len)
    tl.store(DV_ptrs, dv_acc.to(DV.dtype.element_ty), mask=offs_n[:, None] < kv_len)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    q_scale,
    k_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    output_dtype=torch.float16,
):
    """Launch the ``_attn_fwd`` Triton kernel over packed variable-length inputs.

    Args:
        q, k, v: 3-D tensors; axis 1 is the head axis and axis 2 the head
            dimension (both are unpacked from ``q.shape`` / ``k.shape``).
        cu_seqlens_q, cu_seqlens_k: cumulative sequence offsets; the batch
            size is ``cu_seqlens_q.shape[0] - 1``.
        max_seqlen_q: longest query sequence, used to size grid axis 0.
        q_scale, k_scale: scale tensors forwarded to the kernel.
        cu_seqlens_q_scale, cu_seqlens_k_scale: cumulative offsets into the
            scale tensors, forwarded to the kernel.
        output_dtype: dtype of the returned tensor.

    Returns:
        A tensor with ``q``'s shape and ``output_dtype``, allocated
        uninitialised here and filled by the kernel.
    """
    # Tile sizes and stage id are fixed for this launcher.
    BLOCK_M, BLOCK_N = 128, 64
    stage = 3

    _, n_q_heads, head_dim = q.shape
    _, n_kv_heads, _ = k.shape
    n_seqs = cu_seqlens_q.shape[0] - 1
    # Number of query heads that share one key/value head.
    num_kv_groups = n_q_heads // n_kv_heads

    out = torch.empty(q.shape, dtype=output_dtype, device=q.device)

    # One program per (query tile, query head, sequence).
    launch_grid = (triton.cdiv(max_seqlen_q, BLOCK_M), n_q_heads, n_seqs)
    _attn_fwd[launch_grid](
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        q_scale,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
        out,
        # Strides are passed head-first, then token.
        q.stride(1),
        q.stride(0),
        k.stride(1),
        k.stride(0),
        v.stride(1),
        v.stride(0),
        out.stride(1),
        out.stride(0),
        n_q_heads,
        num_kv_groups,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        HEAD_DIM=head_dim,
        STAGE=stage,
        num_warps=4 if head_dim == 64 else 8,
        num_stages=4,
    )
    return out
|
|
||||||
|
|
||||||
|
|
||||||
def backward(
    do,
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    q_scale,
    k_scale,
    cu_seqlens_q_scale,
    cu_seqlens_k_scale,
    l,
    m,
    output_dtype=torch.float16,
):
    """Launch the ``_attn_bwd`` Triton kernel and return ``(dq, dk, dv)``.

    Args:
        do: upstream gradient. The kernel receives no separate strides for
            it and indexes it with ``q``'s strides.
        q, k, v: 3-D tensors; axis 1 is the head axis and axis 2 the head
            dimension.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence offsets; the batch
            size is ``cu_seqlens_q.shape[0] - 1``.
        max_seqlen_q: longest query sequence, used to size grid axis 0.
        q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale: scale
            tensors and their cumulative offsets, forwarded to the kernel.
        l, m: per-row statistics forwarded to the kernel's ``L`` and ``M``
            arguments (presumably the softmax normaliser and running max
            saved by the forward pass - confirm against the caller).
        output_dtype: dtype of the returned gradients.

    Returns:
        Tuple ``(dq, dk, dv)`` shaped like ``q``, ``k`` and ``v``. The
        buffers are zero-initialised, so positions the kernel does not write
        stay zero.
    """
    BLOCK_M = 128
    BLOCK_N = 64
    stage = 3

    b = cu_seqlens_q.shape[0] - 1
    _, h_qo, head_dim = q.shape
    _, h_kv, _ = k.shape
    # Number of query heads that share one key/value head.
    num_kv_groups = h_qo // h_kv

    dq = torch.zeros_like(q, dtype=output_dtype)
    dk = torch.zeros_like(k, dtype=output_dtype)
    dv = torch.zeros_like(v, dtype=output_dtype)

    # One program per (query tile, query head, sequence).
    grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b)
    _attn_bwd[grid](
        do,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        q_scale,
        k_scale,
        cu_seqlens_q_scale,
        cu_seqlens_k_scale,
        l,
        m,
        dq,
        dk,
        dv,
        # Strides are passed head-first, then token.
        q.stride(1),
        q.stride(0),
        k.stride(1),
        k.stride(0),
        v.stride(1),
        v.stride(0),
        h_qo,
        num_kv_groups,
        HEAD_DIM=head_dim,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        STAGE=stage,
        num_warps=4 if head_dim == 64 else 8,
        num_stages=4,
    )
    return dq, dk, dv
|
|
||||||
|
|
||||||
|
|
||||||
# class TritonAttentionFunction(torch.autograd.Function):
|
|
||||||
# @staticmethod
|
|
||||||
# def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale):
|
|
||||||
# l = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
|
|
||||||
# m = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
|
|
||||||
# output = forward(q, k, v, cu_seqlens_q, cu_seqlens_k, q.shape[0], q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
|
|
||||||
# ctx.save_for_backward(q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
|
|
||||||
# return output
|
|
||||||
#
|
|
||||||
# @staticmethod
|
|
||||||
# def backward(ctx, do):
|
|
||||||
# q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m = ctx.saved_tensors
|
|
||||||
# dq, dk, dv = backward(
|
|
||||||
# do, q, k, v,
|
|
||||||
# cu_seqlens_q, cu_seqlens_k,
|
|
||||||
# q.shape[0], q_scale, k_scale,
|
|
||||||
# cu_seqlens_q_scale, cu_seqlens_k_scale,
|
|
||||||
# l, m,
|
|
||||||
# )
|
|
||||||
# return dq, dk, dv, None, None, None, None, None, None
|
|
||||||
@@ -1,158 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright (c) 2024 by SageAttention team.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
def quant_per_block_int8_kernel(
    Input,
    Output,
    Scale,
    cu_seqlens_input,
    cu_seqlens_scale,
    stride_ih,
    stride_in,
    stride_oh,
    stride_on,
    sm_scale,
    H: tl.constexpr,
    C: tl.constexpr,
    BLK: tl.constexpr,
):
    # Quantise one (BLK, C) tile of a packed variable-length tensor to int8
    # with a single float scale for the whole tile.
    #
    # Launch grid: axis 0 = block index inside the sequence, axis 1 = head,
    # axis 2 = sequence index into the cu_seqlens_* arrays.
    off_blk = tl.program_id(0)
    off_h = tl.program_id(1)
    off_b = tl.program_id(2)

    cu_seqlens_input_start = tl.load(cu_seqlens_input + off_b)
    cu_seqlens_input_end = tl.load(cu_seqlens_input + off_b + 1)

    L = cu_seqlens_input_end - cu_seqlens_input_start

    # The grid is sized from the longest sequence; blocks past the end of a
    # shorter sequence have nothing to do.
    if (off_blk * BLK) >= L:
        return

    cu_seqlens_scale_start = tl.load(cu_seqlens_scale + off_b)

    offs_n = off_blk * BLK + tl.arange(0, BLK)
    offs_k = tl.arange(0, C)

    input_ptrs = (
        Input
        + cu_seqlens_input_start * stride_in
        + off_h * stride_ih
        + offs_n[:, None] * stride_in
        + offs_k[None, :]
    )
    output_ptrs = (
        Output
        + cu_seqlens_input_start * stride_on
        + off_h * stride_oh
        + offs_n[:, None] * stride_on
        + offs_k[None, :]
    )
    # One scale per (block, head); blocks are laid out head-minor.
    scale_ptrs = Scale + cu_seqlens_scale_start * H + off_h + off_blk * H

    row_mask = offs_n[:, None] < L

    # FIX: give masked-out rows an explicit 0.0.  Without `other=` their
    # contents are undefined and would leak into the max-abs reduction below.
    x = tl.load(input_ptrs, mask=row_mask, other=0.0)
    x = x.to(tl.float32)
    x *= sm_scale

    # Symmetric quantisation: map the largest magnitude in the tile to 127.
    scale = tl.max(tl.abs(x)) / 127.0
    # FIX: an all-zero tile gives scale == 0 and a 0/0 division.  Use 1.0
    # instead so the tile quantises to zeros and dequantises exactly.
    scale = tl.where(scale == 0.0, 1.0, scale)

    x_int8 = x / scale
    # Round half away from zero: nudge by +/-0.5 before the truncating cast.
    x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
    x_int8 = x_int8.to(tl.int8)

    tl.store(output_ptrs, x_int8, mask=row_mask)
    tl.store(scale_ptrs, scale)
|
|
||||||
|
|
||||||
|
|
||||||
def per_block_int8(
    q,
    k,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    BLKQ=128,
    BLKK=64,
    sm_scale=None,
):
    """Quantise packed variable-length ``q`` and ``k`` to int8, one scale per block.

    Args:
        q, k: input tensors; axis 1 is the head axis and the last axis the
            head dimension.
        cu_seqlens_q, cu_seqlens_k: cumulative sequence offsets.
        max_seqlen_q, max_seqlen_k: longest sequence lengths, used to size
            grid axis 0 of each launch.
        BLKQ, BLKK: rows per quantisation block for ``q`` and ``k``.
        sm_scale: multiplier folded into ``q`` before quantisation; defaults
            to ``head_dim ** -0.5``. It is further multiplied by 1.44269504
            (approximately log2(e)) before being passed to the kernel.

    Returns:
        ``(q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale,
        cu_seqlens_k_scale)``.
    """

    def _cumulative_block_counts(cu_seqlens, block):
        # Per-sequence block count (ceil division), then an exclusive prefix
        # sum with a leading zero.
        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        blocks_per_seq = (seq_lens + block - 1) // block
        return torch.nn.functional.pad(
            torch.cumsum(blocks_per_seq, dim=0), (1, 0), value=0
        )

    n_q_heads = q.shape[1]
    n_kv_heads = k.shape[1]
    head_dim = q.shape[-1]
    n_seqs = cu_seqlens_q.shape[0] - 1

    q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
    k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)

    cu_seqlens_q_scale = _cumulative_block_counts(cu_seqlens_q, BLKQ)
    cu_seqlens_k_scale = _cumulative_block_counts(cu_seqlens_k, BLKK)

    # The last cumulative entry is the total number of blocks.
    q_scale = torch.empty(
        (cu_seqlens_q_scale[-1], n_q_heads), device=q.device, dtype=torch.float32
    )
    k_scale = torch.empty(
        (cu_seqlens_k_scale[-1], n_kv_heads), device=k.device, dtype=torch.float32
    )

    if sm_scale is None:
        sm_scale = head_dim**-0.5

    # Queries: the softmax scale is folded in before quantisation.
    q_grid = ((max_seqlen_q + BLKQ - 1) // BLKQ, n_q_heads, n_seqs)
    quant_per_block_int8_kernel[q_grid](
        q,
        q_int8,
        q_scale,
        cu_seqlens_q,
        cu_seqlens_q_scale,
        q.stride(1),
        q.stride(0),
        q_int8.stride(1),
        q_int8.stride(0),
        sm_scale=(sm_scale * 1.44269504),
        H=n_q_heads,
        C=head_dim,
        BLK=BLKQ,
    )

    # Keys: quantised as-is.
    k_grid = ((max_seqlen_k + BLKK - 1) // BLKK, n_kv_heads, n_seqs)
    quant_per_block_int8_kernel[k_grid](
        k,
        k_int8,
        k_scale,
        cu_seqlens_k,
        cu_seqlens_k_scale,
        k.stride(1),
        k.stride(0),
        k_int8.stride(1),
        k_int8.stride(0),
        sm_scale=1.0,
        H=n_kv_heads,
        C=head_dim,
        BLK=BLKK,
    )

    return q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale
|
|
||||||
@@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from functools import partial
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -94,13 +93,32 @@ def replace_llama_qkv_with_fused(model):
|
|||||||
set_module_name(model, name, qkv)
|
set_module_name(model, name, qkv)
|
||||||
|
|
||||||
|
|
||||||
def patch_llama_cross_entropy():
|
def patch_fa_llama_cross_entropy():
|
||||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
LOG.info(
|
||||||
|
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
|
||||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
|
||||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
|
||||||
CrossEntropyLoss, inplace_backward=True
|
|
||||||
)
|
)
|
||||||
|
from flash_attn.ops.triton.cross_entropy import (
|
||||||
|
cross_entropy_loss as flash_attn_cross_entropy_loss,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fa2_fixed_cross_entropy(
|
||||||
|
source,
|
||||||
|
target,
|
||||||
|
num_items_in_batch: int = None,
|
||||||
|
ignore_index: int = -100,
|
||||||
|
**kwargs,
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
reduction = "sum" if num_items_in_batch is not None else "mean"
|
||||||
|
loss, _ = flash_attn_cross_entropy_loss(
|
||||||
|
source, target, ignore_index=ignore_index
|
||||||
|
)
|
||||||
|
if reduction == "sum":
|
||||||
|
loss = loss.sum() / num_items_in_batch
|
||||||
|
else:
|
||||||
|
loss = loss.sum() / (target != ignore_index).sum()
|
||||||
|
return loss
|
||||||
|
|
||||||
|
transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy
|
||||||
|
|
||||||
|
|
||||||
def patch_llama_rms_norm():
|
def patch_llama_rms_norm():
|
||||||
@@ -147,7 +165,7 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
|
|
||||||
# skip only if explicitly disabled
|
# skip only if explicitly disabled
|
||||||
if cross_entropy:
|
if cross_entropy:
|
||||||
patch_llama_cross_entropy()
|
patch_fa_llama_cross_entropy()
|
||||||
|
|
||||||
# skip only if explicitly disabled
|
# skip only if explicitly disabled
|
||||||
if rms_norm:
|
if rms_norm:
|
||||||
|
|||||||
@@ -46,9 +46,10 @@ def reset_optimizer(
|
|||||||
*,
|
*,
|
||||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||||
optimizer_state_keys: List[str],
|
optimizer_state_keys: List[str],
|
||||||
prune_ratio: float = 0.9,
|
optimizer_magnitude_pruning: float = 0.9,
|
||||||
):
|
):
|
||||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
# pylint:disable=unused-argument
|
||||||
|
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
|
||||||
n_zeros = 0
|
n_zeros = 0
|
||||||
n_total = 0
|
n_total = 0
|
||||||
|
|
||||||
@@ -56,16 +57,22 @@ def reset_optimizer(
|
|||||||
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
||||||
optimizer_state = optimizer.optim.state
|
optimizer_state = optimizer.optim.state
|
||||||
|
|
||||||
for param in reset_params:
|
for group in optimizer.param_groups:
|
||||||
param_state = optimizer_state[param]
|
for param in group["params"]:
|
||||||
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
|
state = optimizer_state[param]
|
||||||
continue
|
for key, value in state.items():
|
||||||
for key in optimizer_state_keys:
|
if key not in optimizer_state_keys:
|
||||||
pruning_fn(
|
continue
|
||||||
param_state[key]
|
if torch.is_tensor(value):
|
||||||
) # pruning fn has to be inplace to keep the same keys in the dict
|
try:
|
||||||
n_total += param_state[key].numel()
|
pruning_fn(value)
|
||||||
n_zeros += torch.sum(param_state[key] == 0).item()
|
n_total += value.numel()
|
||||||
|
n_zeros += torch.sum(value == 0).item()
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if "quantile() input tensor is too large" in str(exc):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise exc
|
||||||
|
|
||||||
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
||||||
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
||||||
@@ -129,6 +136,9 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
|
|
||||||
if "adam" in args.optim.lower():
|
if "adam" in args.optim.lower():
|
||||||
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
|
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
|
||||||
|
if "8bit" in args.optim.lower():
|
||||||
|
optimizer_state_keys.append("state1")
|
||||||
|
optimizer_state_keys.append("state2")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
||||||
|
|
||||||
@@ -160,7 +170,7 @@ class ReLoRACallback(TrainerCallback):
|
|||||||
optimizer,
|
optimizer,
|
||||||
reset_params=lora_params,
|
reset_params=lora_params,
|
||||||
optimizer_state_keys=optimizer_state_keys,
|
optimizer_state_keys=optimizer_state_keys,
|
||||||
prune_ratio=args.relora_prune_ratio,
|
optimizer_magnitude_pruning=args.relora_prune_ratio,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.quantized:
|
if self.quantized:
|
||||||
|
|||||||
207
src/axolotl/monkeypatch/trainer_grad_accum.py
Normal file
207
src/axolotl/monkeypatch/trainer_grad_accum.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""
|
||||||
|
fix for FSDP gradient accumulation
|
||||||
|
see https://github.com/huggingface/transformers/pull/35128
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers import LlamaForCausalLM
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
||||||
|
|
||||||
|
ORIGINAL_CONTEXT_CODE = """
|
||||||
|
with self.compute_loss_context_manager():
|
||||||
|
if self.model_accepts_loss_kwargs:
|
||||||
|
loss = self.compute_loss(model, inputs)
|
||||||
|
else:
|
||||||
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_CONTEXT_CODE = """
|
||||||
|
with self.compute_loss_context_manager():
|
||||||
|
if self.model_accepts_loss_kwargs:
|
||||||
|
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
|
||||||
|
else:
|
||||||
|
loss = self.compute_loss(model, inputs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
ORIGINAL_LLAMA_FCLM_CODE = """
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_LLAMA_FCLM_CODE = """
|
||||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
|
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
|
||||||
|
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_step_code() -> str:
    """Return the source text of ``Trainer.training_step``.

    ``inspect.getsource`` raises ``OSError`` when the source is unavailable.
    """
    return inspect.getsource(Trainer.training_step)
|
||||||
|
|
||||||
|
|
||||||
|
def check_training_step_is_patchable() -> bool:
    """Report whether ``Trainer.training_step`` still contains the snippet we rewrite.

    The source is de-indented first so it can be compared against
    ``ORIGINAL_CONTEXT_CODE``.
    """
    dedented_source, _ = detab_code(get_training_step_code())
    return ORIGINAL_CONTEXT_CODE in dedented_source
|
||||||
|
|
||||||
|
|
||||||
|
def patch_training_step_for_ga():
    """Rewrite ``Trainer.training_step`` so the loss-kwargs branches are swapped.

    The method's source is fetched, de-indented, and the
    ``ORIGINAL_CONTEXT_CODE`` snippet is replaced with
    ``PATCHED_CONTEXT_CODE``. The result is compiled into this module under
    the name ``_fixed_training_step`` and installed on ``Trainer``.

    Returns silently, without patching, when the source cannot be read or
    when the expected snippet is absent.
    """
    try:
        original_source = get_training_step_code()
    except OSError:
        return

    # Keep the unmodified source on the class for reference.
    Trainer._original_training_step = original_source  # pylint: disable=protected-access

    patched_source, _ = detab_code(original_source)
    if ORIGINAL_CONTEXT_CODE not in patched_source:
        return

    patched_source = patched_source.replace(
        ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE
    ).replace(
        "def training_step(",
        "def _fixed_training_step(",
        1,
    )

    # The rewritten function is exec'd in this module's namespace, so every
    # name it may reference from transformers.trainer must be imported here.
    import transformers.trainer

    needed_names = [
        name for name in dir(transformers.trainer) if name in patched_source
    ]
    import_statement = (
        "from transformers.trainer import (" + ", ".join(needed_names) + ")"
    )

    exec(import_statement, globals())  # pylint: disable=exec-used  # nosec B102
    exec(patched_source, globals())  # pylint: disable=exec-used  # nosec B102
    LOG.info("patching training_step")
    Trainer.training_step = (  # pylint: disable=protected-access
        _fixed_training_step  # pylint: disable=undefined-variable  # noqa: F821
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_forward_code() -> str:
    """Return the source text of ``LlamaForCausalLM.forward``.

    ``inspect.getsource`` raises ``OSError`` when the source is unavailable.
    """
    return inspect.getsource(LlamaForCausalLM.forward)
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward_is_patchable() -> bool:
    """Report whether ``LlamaForCausalLM.forward`` still contains the snippet we rewrite.

    The source is de-indented first so it can be compared against
    ``ORIGINAL_LLAMA_FCLM_CODE``.
    """
    dedented_source, _ = detab_code(get_model_forward_code())
    return ORIGINAL_LLAMA_FCLM_CODE in dedented_source
|
||||||
|
|
||||||
|
|
||||||
|
def patch_forward_for_ga():
    """Rewrite ``LlamaForCausalLM.forward`` to route ``num_items_in_batch`` to the loss.

    The method's source is fetched, de-indented, and the
    ``ORIGINAL_LLAMA_FCLM_CODE`` snippet is replaced with
    ``PATCHED_LLAMA_FCLM_CODE``. The result is compiled into this module
    under the name ``_fixed_forward`` and installed on ``LlamaForCausalLM``.

    Returns silently, without patching, when the source cannot be read or
    when the expected snippet is absent.
    """
    try:
        original_source = get_model_forward_code()
    except OSError:
        return

    # Keep the unmodified source on the class for reference.
    LlamaForCausalLM._original_forward = original_source  # pylint: disable=protected-access

    patched_source, _ = detab_code(original_source)
    if ORIGINAL_LLAMA_FCLM_CODE not in patched_source:
        return

    patched_source = patched_source.replace(
        ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE
    ).replace(
        "def forward(",
        "def _fixed_forward(",
        1,
    )

    # The rewritten function is exec'd in this module's namespace, so every
    # name it may reference from modeling_llama must be imported here.
    import transformers.models.llama.modeling_llama

    needed_names = [
        name
        for name in dir(transformers.models.llama.modeling_llama)
        if name in patched_source
    ]
    import_statement = (
        "from transformers.models.llama.modeling_llama import ("
        + ", ".join(needed_names)
        + ")"
    )

    exec(import_statement, globals())  # pylint: disable=exec-used  # nosec B102
    exec(patched_source, globals())  # pylint: disable=exec-used  # nosec B102
    LOG.info("patching forward")
    LlamaForCausalLM.forward = (  # pylint: disable=protected-access
        _fixed_forward  # pylint: disable=undefined-variable  # noqa: F821
    )
|
||||||
@@ -9,10 +9,7 @@ import torch
|
|||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||||
LlamaFlashAttention2,
|
|
||||||
LlamaForCausalLM,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
@@ -55,11 +52,6 @@ def original_apply_o(self, hidden_states):
|
|||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
def get_forward_code() -> str:
|
|
||||||
forward = inspect.getsource(LlamaForCausalLM.forward)
|
|
||||||
return forward
|
|
||||||
|
|
||||||
|
|
||||||
def get_self_attn_code() -> str:
|
def get_self_attn_code() -> str:
|
||||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
||||||
return forward
|
return forward
|
||||||
@@ -102,12 +94,22 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
|||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
def detab_code(code: str) -> Tuple[str, str]:
|
||||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
try:
|
||||||
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||||
|
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
||||||
|
except AttributeError:
|
||||||
|
return code, ""
|
||||||
return code, spaces
|
return code, spaces
|
||||||
|
|
||||||
|
|
||||||
|
self_attn_lora_patched = False # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
def patch_self_attn_lora():
|
def patch_self_attn_lora():
|
||||||
|
global self_attn_lora_patched # pylint: disable=global-statement
|
||||||
|
if self_attn_lora_patched:
|
||||||
|
# prevent patching multiple times
|
||||||
|
return
|
||||||
self_attn_forward = get_self_attn_code()
|
self_attn_forward = get_self_attn_code()
|
||||||
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
|
LlamaFlashAttention2._original_forward = ( # pylint: disable=protected-access
|
||||||
self_attn_forward
|
self_attn_forward
|
||||||
@@ -139,6 +141,7 @@ def patch_self_attn_lora():
|
|||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
self_attn_lora_patched = True
|
||||||
LOG.info("patching unsloth attn lora", main_process_only=True)
|
LOG.info("patching unsloth attn lora", main_process_only=True)
|
||||||
LlamaFlashAttention2.forward = (
|
LlamaFlashAttention2.forward = (
|
||||||
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
|||||||
@@ -259,11 +259,31 @@ def train(
|
|||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
|
||||||
if not cfg.hub_model_id:
|
if not cfg.hub_model_id:
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
from huggingface_hub.utils import RepositoryNotFoundError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
trainer.create_model_card(
|
# Check to make sure the base model is from HuggingFace not a local directory
|
||||||
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
|
hf_api = HfApi()
|
||||||
)
|
hf_api.model_info(cfg.base_model)
|
||||||
except (AttributeError, UnicodeDecodeError):
|
|
||||||
|
model_card_kwarg = {
|
||||||
|
"model_name": cfg.output_dir.lstrip("./")
|
||||||
|
.encode("utf-8")
|
||||||
|
.decode("utf-8")
|
||||||
|
}
|
||||||
|
if cfg.datasets is not None:
|
||||||
|
if cfg.rl is not None or cfg.reward_model:
|
||||||
|
model_card_kwarg["dataset_name"] = [
|
||||||
|
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
model_card_kwarg["dataset_tags"] = [
|
||||||
|
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
|
]
|
||||||
|
|
||||||
|
trainer.create_model_card(**model_card_kwarg)
|
||||||
|
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
|
||||||
pass
|
pass
|
||||||
elif cfg.hub_model_id:
|
elif cfg.hub_model_id:
|
||||||
# defensively push to the hub to ensure the model card is updated
|
# defensively push to the hub to ensure the model card is updated
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
"""
|
"""
|
||||||
Basic utils for Axolotl
|
Basic utils for Axolotl
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import re
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def is_mlflow_available():
|
def is_mlflow_available():
|
||||||
@@ -10,3 +14,23 @@ def is_mlflow_available():
|
|||||||
|
|
||||||
def is_comet_available():
|
def is_comet_available():
|
||||||
return importlib.util.find_spec("comet_ml") is not None
|
return importlib.util.find_spec("comet_ml") is not None
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
def get_pytorch_version() -> tuple[int, int, int]:
|
||||||
|
"""
|
||||||
|
Get Pytorch version as a tuple of (major, minor, patch).
|
||||||
|
"""
|
||||||
|
torch_version = torch.__version__
|
||||||
|
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||||
|
|
||||||
|
if not version_match:
|
||||||
|
raise ValueError("Invalid version format")
|
||||||
|
|
||||||
|
major, minor, patch = version_match.groups()
|
||||||
|
major, minor = int(major), int(minor)
|
||||||
|
patch = int(patch) if patch is not None else 0 # Default patch to 0 if not present
|
||||||
|
return major, minor, patch
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: enable=duplicate-code
|
||||||
|
|||||||
@@ -1,13 +1,24 @@
|
|||||||
"""Benchmarking and measurement utilities"""
|
"""Benchmarking and measurement utilities"""
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
import pynvml
|
|
||||||
import torch
|
import torch
|
||||||
from pynvml.nvml import NVMLError
|
|
||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
from axolotl.utils.distributed import get_device_type
|
from axolotl.utils.distributed import get_device_type
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pynvml import (
|
||||||
|
NVMLError,
|
||||||
|
nvmlDeviceGetHandleByIndex,
|
||||||
|
nvmlDeviceGetMemoryInfo,
|
||||||
|
nvmlInit,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
NVMLError = None
|
||||||
|
nvmlDeviceGetHandleByIndex = None
|
||||||
|
nvmlDeviceGetMemoryInfo = None
|
||||||
|
nvmlInit = None
|
||||||
|
|
||||||
|
|
||||||
def check_cuda_device(default_value):
|
def check_cuda_device(default_value):
|
||||||
"""
|
"""
|
||||||
@@ -68,10 +79,12 @@ def gpu_memory_usage_smi(device=0):
|
|||||||
device = device.index
|
device = device.index
|
||||||
if isinstance(device, str) and device.startswith("cuda:"):
|
if isinstance(device, str) and device.startswith("cuda:"):
|
||||||
device = int(device[5:])
|
device = int(device[5:])
|
||||||
|
if not nvmlInit:
|
||||||
|
return 0.0
|
||||||
try:
|
try:
|
||||||
pynvml.nvmlInit()
|
nvmlInit()
|
||||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
|
handle = nvmlDeviceGetHandleByIndex(device)
|
||||||
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
info = nvmlDeviceGetMemoryInfo(handle)
|
||||||
return info.used / 1024.0**3
|
return info.used / 1024.0**3
|
||||||
except NVMLError:
|
except NVMLError:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from transformers import (
|
|||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||||
|
from trl.models import unwrap_model_for_generation
|
||||||
|
|
||||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
@@ -46,6 +47,7 @@ from axolotl.utils.distributed import (
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||||
|
|
||||||
|
|
||||||
IGNORE_INDEX = -100
|
IGNORE_INDEX = -100
|
||||||
LOG = logging.getLogger("axolotl.callbacks")
|
LOG = logging.getLogger("axolotl.callbacks")
|
||||||
|
|
||||||
@@ -64,7 +66,10 @@ class EvalFirstStepCallback(
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
if (
|
||||||
|
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||||
|
and state.global_step == 1
|
||||||
|
):
|
||||||
control.should_evaluate = True
|
control.should_evaluate = True
|
||||||
return control
|
return control
|
||||||
|
|
||||||
@@ -375,7 +380,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
for metric in self.cfg.eval_causal_lm_metrics:
|
for metric in self.cfg.eval_causal_lm_metrics:
|
||||||
if metric == "perplexity":
|
if metric == "perplexity":
|
||||||
max_seq_len = self.cfg.eval_max_new_tokens
|
max_seq_len = self.cfg.eval_max_new_tokens
|
||||||
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
|
metrics[metric] = Perplexity(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
metrics[metric] = evaluate.load(metric)
|
metrics[metric] = evaluate.load(metric)
|
||||||
@@ -392,8 +400,11 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
eval_dataloader,
|
eval_dataloader,
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
trainer.model.eval()
|
trainer.model_wrapped.eval()
|
||||||
device = torch.device(self.cfg.device)
|
|
||||||
|
device = torch.device(
|
||||||
|
self.cfg.device
|
||||||
|
) # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
generation_config = GenerationConfig(
|
generation_config = GenerationConfig(
|
||||||
@@ -430,6 +441,10 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
for k in metric._feature_names() # pylint: disable=protected-access
|
for k in metric._feature_names() # pylint: disable=protected-access
|
||||||
if k in kwargs
|
if k in kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isinstance(metric, Perplexity):
|
||||||
|
metric_kwargs["model"] = trainer.model_wrapped
|
||||||
|
|
||||||
metric_score = metric.compute(**metric_kwargs)
|
metric_score = metric.compute(**metric_kwargs)
|
||||||
return (
|
return (
|
||||||
metric_score["score"]
|
metric_score["score"]
|
||||||
@@ -465,89 +480,97 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
def predict_with_generate():
|
def predict_with_generate():
|
||||||
eval_src, eval_pred, eval_ref = [], [], []
|
eval_src, eval_pred, eval_ref = [], [], []
|
||||||
|
|
||||||
for batch in tqdm(eval_dataloader):
|
with unwrap_model_for_generation(
|
||||||
batch_labels = batch["labels"].to(device)
|
trainer.model_wrapped, trainer.accelerator
|
||||||
batch_input_ids = batch["input_ids"].to(device)
|
) as unwrapped_model:
|
||||||
|
for batch in tqdm(eval_dataloader, disable=not is_main_process()):
|
||||||
|
batch_labels = batch["labels"].to(device)
|
||||||
|
batch_input_ids = batch["input_ids"].to(device)
|
||||||
|
|
||||||
if "position_ids" in batch:
|
if "position_ids" in batch:
|
||||||
batch_pos_ids = batch["position_ids"].tolist()
|
batch_pos_ids = batch["position_ids"].tolist()
|
||||||
else:
|
|
||||||
batch_pos_ids = [None] * len(batch["input_ids"])
|
|
||||||
|
|
||||||
prompt_token_ids_list = []
|
|
||||||
completion_token_ids_list = []
|
|
||||||
|
|
||||||
for input_ids_all, labels_all, pos_ids in zip(
|
|
||||||
batch_input_ids,
|
|
||||||
batch_labels,
|
|
||||||
batch_pos_ids,
|
|
||||||
):
|
|
||||||
if pos_ids is None:
|
|
||||||
pos_ranges = [(0, len(input_ids_all) - 1)]
|
|
||||||
else:
|
else:
|
||||||
pos_ranges = find_ranges(pos_ids)
|
batch_pos_ids = [None] * len(batch["input_ids"])
|
||||||
|
|
||||||
for pos_range in pos_ranges:
|
prompt_token_ids_list = []
|
||||||
start, end = pos_range
|
completion_token_ids_list = []
|
||||||
if start == end:
|
|
||||||
continue
|
|
||||||
|
|
||||||
input_ids = input_ids_all[start : end + 1]
|
for input_ids_all, labels_all, pos_ids in zip(
|
||||||
labels = labels_all[start : end + 1]
|
batch_input_ids,
|
||||||
|
batch_labels,
|
||||||
|
batch_pos_ids,
|
||||||
|
):
|
||||||
|
if pos_ids is None:
|
||||||
|
pos_ranges = [(0, len(input_ids_all) - 1)]
|
||||||
|
else:
|
||||||
|
pos_ranges = find_ranges(pos_ids)
|
||||||
|
|
||||||
tokens_without_loss = labels == IGNORE_INDEX
|
for pos_range in pos_ranges:
|
||||||
tokens_with_loss = labels != IGNORE_INDEX
|
start, end = pos_range
|
||||||
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
|
if start == end:
|
||||||
prompt_token_includes = (
|
continue
|
||||||
tokens_without_loss & tokens_exclude_padding
|
|
||||||
|
input_ids = input_ids_all[start : end + 1]
|
||||||
|
labels = labels_all[start : end + 1]
|
||||||
|
|
||||||
|
tokens_without_loss = labels == IGNORE_INDEX
|
||||||
|
tokens_with_loss = labels != IGNORE_INDEX
|
||||||
|
tokens_exclude_padding = (
|
||||||
|
input_ids != tokenizer.pad_token_id
|
||||||
|
)
|
||||||
|
prompt_token_includes = (
|
||||||
|
tokens_without_loss & tokens_exclude_padding
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_token_ids = input_ids[prompt_token_includes]
|
||||||
|
prompt_token_ids_list.append(prompt_token_ids)
|
||||||
|
|
||||||
|
completion_token_ids = input_ids[tokens_with_loss]
|
||||||
|
completion_token_ids_list.append(completion_token_ids)
|
||||||
|
|
||||||
|
prompt_texts = tokenizer.batch_decode(
|
||||||
|
prompt_token_ids_list, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
completion_texts = tokenizer.batch_decode(
|
||||||
|
completion_token_ids_list, skip_special_tokens=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
prompt_encoding = tokenizer(
|
||||||
|
prompt_texts, padding=True, return_tensors="pt"
|
||||||
|
).to(device)
|
||||||
|
|
||||||
|
predictions = unwrapped_model.generate(
|
||||||
|
**prompt_encoding, generation_config=generation_config
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_token_ids = input_ids[prompt_token_includes]
|
del prompt_encoding
|
||||||
prompt_token_ids_list.append(prompt_token_ids)
|
|
||||||
|
|
||||||
completion_token_ids = input_ids[tokens_with_loss]
|
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
||||||
completion_token_ids_list.append(completion_token_ids)
|
prediction_without_prompt_tokens_list = []
|
||||||
|
for prompt_token_ids, prediction_tokens in zip(
|
||||||
|
prompt_token_ids_list, prediction_all_tokens
|
||||||
|
):
|
||||||
|
prediction_without_prompt_tokens = prediction_tokens[
|
||||||
|
len(prompt_token_ids) :
|
||||||
|
]
|
||||||
|
prediction_without_prompt_tokens_list.append(
|
||||||
|
prediction_without_prompt_tokens
|
||||||
|
)
|
||||||
|
|
||||||
prompt_texts = tokenizer.batch_decode(
|
predicted_texts = tokenizer.batch_decode(
|
||||||
prompt_token_ids_list, skip_special_tokens=True
|
prediction_without_prompt_tokens_list,
|
||||||
)
|
skip_special_tokens=True,
|
||||||
completion_texts = tokenizer.batch_decode(
|
|
||||||
completion_token_ids_list, skip_special_tokens=True
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
prompt_encoding = tokenizer(
|
|
||||||
prompt_texts, padding=True, return_tensors="pt"
|
|
||||||
).to(self.cfg.device)
|
|
||||||
predictions = trainer.model.generate(
|
|
||||||
**prompt_encoding, generation_config=generation_config
|
|
||||||
)
|
)
|
||||||
|
|
||||||
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
eval_src.extend(prompt_texts)
|
||||||
prediction_without_prompt_tokens_list = []
|
eval_pred.extend(predicted_texts)
|
||||||
for prompt_token_ids, prediction_tokens in zip(
|
eval_ref.extend(completion_texts)
|
||||||
prompt_token_ids_list, prediction_all_tokens
|
|
||||||
):
|
|
||||||
prediction_without_prompt_tokens = prediction_tokens[
|
|
||||||
len(prompt_token_ids) :
|
|
||||||
]
|
|
||||||
prediction_without_prompt_tokens_list.append(
|
|
||||||
prediction_without_prompt_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
predicted_texts = tokenizer.batch_decode(
|
|
||||||
prediction_without_prompt_tokens_list, skip_special_tokens=True
|
|
||||||
)
|
|
||||||
|
|
||||||
eval_src.extend(prompt_texts)
|
|
||||||
eval_pred.extend(predicted_texts)
|
|
||||||
eval_ref.extend(completion_texts)
|
|
||||||
|
|
||||||
return eval_src, eval_pred, eval_ref
|
return eval_src, eval_pred, eval_ref
|
||||||
|
|
||||||
if is_main_process():
|
eval_preds = predict_with_generate()
|
||||||
eval_preds = predict_with_generate()
|
trainer.log(evaluate_preds(*eval_preds))
|
||||||
trainer.log(evaluate_preds(*eval_preds))
|
|
||||||
|
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ from transformers.modeling_outputs import CausalLMOutput
|
|||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
|
||||||
class Perplexity:
|
class Perplexity:
|
||||||
"""
|
"""
|
||||||
@@ -17,16 +19,13 @@ class Perplexity:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: PreTrainedModel,
|
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
max_seq_len: int,
|
max_seq_len: int,
|
||||||
stride: int = 512,
|
stride: int = 512,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.max_seq_len = max_seq_len
|
self.max_seq_len = max_seq_len
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.device = model.device
|
|
||||||
self.name = "perplexity"
|
self.name = "perplexity"
|
||||||
|
|
||||||
def _feature_names(self) -> List[str]:
|
def _feature_names(self) -> List[str]:
|
||||||
@@ -34,6 +33,7 @@ class Perplexity:
|
|||||||
|
|
||||||
def compute(
|
def compute(
|
||||||
self,
|
self,
|
||||||
|
model: PreTrainedModel,
|
||||||
references: Optional[List[str]] = None,
|
references: Optional[List[str]] = None,
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
@@ -41,17 +41,21 @@ class Perplexity:
|
|||||||
"""
|
"""
|
||||||
assert references is not None, "Missing parameter: references"
|
assert references is not None, "Missing parameter: references"
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
references_tokenized = self.tokenizer(
|
references_tokenized = self.tokenizer(
|
||||||
references, return_tensors="pt", padding=True, truncation=True
|
references, return_tensors="pt", padding=True, truncation=True
|
||||||
)
|
)
|
||||||
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
|
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
|
||||||
input_ids = input_ids.to(self.device)
|
input_ids = input_ids.to(model.device)
|
||||||
|
|
||||||
sequence_length = input_ids.size(1)
|
sequence_length = input_ids.size(1)
|
||||||
|
|
||||||
losses = []
|
losses = []
|
||||||
prev_end_loc = 0
|
prev_end_loc = 0
|
||||||
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
|
for begin_loc in tqdm(
|
||||||
|
range(0, sequence_length, self.stride), disable=not is_main_process()
|
||||||
|
):
|
||||||
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
|
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
|
||||||
trg_len = end_loc - prev_end_loc
|
trg_len = end_loc - prev_end_loc
|
||||||
input_ids_slice = input_ids[:, begin_loc:end_loc]
|
input_ids_slice = input_ids[:, begin_loc:end_loc]
|
||||||
@@ -59,7 +63,7 @@ class Perplexity:
|
|||||||
labels_slice[:, :-trg_len] = -100
|
labels_slice[:, :-trg_len] = -100
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs: CausalLMOutput = self.model(
|
outputs: CausalLMOutput = model(
|
||||||
input_ids=input_ids_slice, labels=labels_slice
|
input_ids=input_ids_slice, labels=labels_slice
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Collators for multi-modal chat messages and packing
|
Collators for multi-modal chat messages and packing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||||
@@ -30,8 +32,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
raise ValueError("Packing is currently not supported.")
|
raise ValueError("Packing is currently not supported.")
|
||||||
|
|
||||||
def torch_call(
|
def torch_call(
|
||||||
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
self, examples: list[Union[list[int], Any, dict[str, Any]]]
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
# Handle dict or lists with proper padding and conversion to tensor.
|
# Handle dict or lists with proper padding and conversion to tensor.
|
||||||
|
|
||||||
return self.__class__.process_rows(
|
return self.__class__.process_rows(
|
||||||
@@ -46,6 +48,120 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
# *** This is COPIED from the trl example sft_vlm.py code ***
|
# *** This is COPIED from the trl example sft_vlm.py code ***
|
||||||
# use this as a starting point
|
# use this as a starting point
|
||||||
|
|
||||||
|
def _preprocess(examples: list[dict]) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Preprocess conversation examples to ensure consistent format.
|
||||||
|
|
||||||
|
Converts different conversation formats to OpenAI format with 'messages'.
|
||||||
|
Supports two formats:
|
||||||
|
1. OpenAI format with 'messages'
|
||||||
|
2. Legacy format with 'conversations'
|
||||||
|
|
||||||
|
Args:
|
||||||
|
examples: list of conversation dictionaries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict in OpenAI format with 'messages' key
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the conversation format is not supported
|
||||||
|
"""
|
||||||
|
role_mapping = {
|
||||||
|
"human": "user",
|
||||||
|
"gpt": "assistant",
|
||||||
|
}
|
||||||
|
|
||||||
|
def normalize_role(role: str) -> str:
|
||||||
|
"""Normalize role names to OpenAI format. Default to original role if not found."""
|
||||||
|
return role_mapping.get(role, role)
|
||||||
|
|
||||||
|
def convert_legacy_format(example: dict) -> dict:
|
||||||
|
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": normalize_role(convo["from"]),
|
||||||
|
"content": convo["value"],
|
||||||
|
}
|
||||||
|
for convo in example["conversations"]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create new dict without 'conversations' key
|
||||||
|
result = deepcopy(example)
|
||||||
|
result.pop("conversations")
|
||||||
|
return {"messages": messages, **result}
|
||||||
|
|
||||||
|
processed_examples = []
|
||||||
|
for example in examples:
|
||||||
|
# OpenAI format
|
||||||
|
if "messages" in example:
|
||||||
|
processed_examples.append(example)
|
||||||
|
|
||||||
|
# Legacy format
|
||||||
|
elif "conversations" in example:
|
||||||
|
processed_examples.append(convert_legacy_format(example))
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Only `messages` and `conversations` message keys are currently supported."
|
||||||
|
)
|
||||||
|
|
||||||
|
return processed_examples
|
||||||
|
|
||||||
|
def _process_images(examples, max_images):
|
||||||
|
"""
|
||||||
|
Process images from examples, ensuring consistency in image presence and applying max_images limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
examples: List of dictionaries that may contain 'images' key
|
||||||
|
max_images: Maximum number of images to keep per example (0 means no limit)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Either None (if no images) or List[Image objects] (if all examples have images)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If there's a mix of None and non-None images
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_image(example):
|
||||||
|
if "images" not in example:
|
||||||
|
return None
|
||||||
|
images = example["images"]
|
||||||
|
if isinstance(images, str):
|
||||||
|
return Image.open(images)
|
||||||
|
return images
|
||||||
|
|
||||||
|
images = [get_image(example) for example in examples]
|
||||||
|
|
||||||
|
# Count None and non-None images
|
||||||
|
none_count = sum(1 for img in images if img is None)
|
||||||
|
|
||||||
|
# All images are None
|
||||||
|
if none_count == len(images):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Mix of None and non-None images
|
||||||
|
if none_count > 0:
|
||||||
|
raise ValueError(
|
||||||
|
"All images should be either None or not None. "
|
||||||
|
"Please provide images for all examples or None."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply max_images limit if specified
|
||||||
|
if max_images > 0:
|
||||||
|
images = [
|
||||||
|
(
|
||||||
|
img_batch[:max_images]
|
||||||
|
if isinstance(img_batch, (list, tuple))
|
||||||
|
else img_batch
|
||||||
|
)
|
||||||
|
for img_batch in images
|
||||||
|
]
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
# Preprocess the examples
|
||||||
|
examples = _preprocess(examples)
|
||||||
|
|
||||||
# Get the texts and images, and apply the chat template
|
# Get the texts and images, and apply the chat template
|
||||||
texts = [
|
texts = [
|
||||||
processor.apply_chat_template(
|
processor.apply_chat_template(
|
||||||
@@ -53,15 +169,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
|||||||
)
|
)
|
||||||
for example in examples
|
for example in examples
|
||||||
]
|
]
|
||||||
images = [
|
|
||||||
Image.open(example["images"])
|
|
||||||
if isinstance(example["images"], str)
|
|
||||||
else example["images"]
|
|
||||||
for example in examples
|
|
||||||
]
|
|
||||||
|
|
||||||
if max_images > 0:
|
images = _process_images(examples, max_images=max_images)
|
||||||
images = [img_batch[:max_images] for img_batch in images]
|
|
||||||
|
|
||||||
# Tokenize the texts and process the images
|
# Tokenize the texts and process the images
|
||||||
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import torch
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
@@ -152,7 +153,7 @@ def normalize_config(cfg):
|
|||||||
cfg.is_llama_derived_model = (
|
cfg.is_llama_derived_model = (
|
||||||
(
|
(
|
||||||
hasattr(model_config, "model_type")
|
hasattr(model_config, "model_type")
|
||||||
and model_config.model_type == ["llama", "mllama_text_model"]
|
and model_config.model_type in ["llama", "mllama_text_model"]
|
||||||
)
|
)
|
||||||
or cfg.is_llama_derived_model
|
or cfg.is_llama_derived_model
|
||||||
or "llama" in cfg.base_model.lower()
|
or "llama" in cfg.base_model.lower()
|
||||||
@@ -229,7 +230,11 @@ def normalize_cfg_datasets(cfg):
|
|||||||
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
|
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
def validate_config(
|
||||||
|
cfg: DictDefault,
|
||||||
|
capabilities: Optional[dict] = None,
|
||||||
|
env_capabilities: Optional[dict] = None,
|
||||||
|
):
|
||||||
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
|
||||||
AxolotlInputConfig = AxolotlInputConfigBase
|
AxolotlInputConfig = AxolotlInputConfigBase
|
||||||
|
|
||||||
@@ -239,14 +244,35 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
|||||||
AxolotlInputConfig, # pylint: disable=invalid-name
|
AxolotlInputConfig, # pylint: disable=invalid-name
|
||||||
) = merge_input_args()
|
) = merge_input_args()
|
||||||
|
|
||||||
if capabilities:
|
if capabilities or env_capabilities:
|
||||||
|
if (capabilities and not env_capabilities) or (
|
||||||
|
env_capabilities and not capabilities
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Both capabilities and env_capabilities must be provided or not provided."
|
||||||
|
)
|
||||||
|
|
||||||
return DictDefault(
|
return DictDefault(
|
||||||
dict(
|
dict(
|
||||||
AxolotlConfigWCapabilities(
|
AxolotlConfigWCapabilities(
|
||||||
**cfg.to_dict(), capabilities=capabilities
|
**cfg.to_dict(),
|
||||||
|
capabilities=capabilities,
|
||||||
|
env_capabilities=env_capabilities,
|
||||||
).model_dump(exclude_none=True)
|
).model_dump(exclude_none=True)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return DictDefault(
|
return DictDefault(
|
||||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_plugins(cfg):
|
||||||
|
"""
|
||||||
|
Prepare the plugins for the configuration
|
||||||
|
"""
|
||||||
|
|
||||||
|
if cfg.get("plugins"):
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
for plugin_name in cfg["plugins"]:
|
||||||
|
plugin_manager.register(plugin_name)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import os
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
Field,
|
Field,
|
||||||
@@ -21,7 +22,7 @@ from transformers import SchedulerType
|
|||||||
from transformers.training_args import OptimizerNames
|
from transformers.training_args import OptimizerNames
|
||||||
from transformers.utils.import_utils import is_torch_npu_available
|
from transformers.utils.import_utils import is_torch_npu_available
|
||||||
|
|
||||||
from axolotl.utils.config.models.internals import GPUCapabilities
|
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||||
|
|
||||||
@@ -322,11 +323,13 @@ class LoraConfig(BaseModel):
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_adapter(cls, data):
|
def validate_adapter(cls, data):
|
||||||
if not data.get("adapter") and (
|
if (
|
||||||
data.get("load_in_8bit") or data.get("load_in_4bit")
|
not data.get("adapter")
|
||||||
|
and not data.get("inference")
|
||||||
|
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
|
||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
@@ -430,6 +433,8 @@ class HyperparametersConfig(BaseModel):
|
|||||||
group_by_length: Optional[bool] = None
|
group_by_length: Optional[bool] = None
|
||||||
|
|
||||||
learning_rate: Union[str, float]
|
learning_rate: Union[str, float]
|
||||||
|
embedding_lr: Optional[float] = None
|
||||||
|
embedding_lr_scale: Optional[float] = None
|
||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[
|
optimizer: Optional[
|
||||||
Union[
|
Union[
|
||||||
@@ -622,6 +627,7 @@ class AxolotlInputConfig(
|
|||||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||||
)
|
)
|
||||||
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
||||||
|
dataset_exact_deduplication: Optional[bool] = None
|
||||||
dataset_keep_in_memory: Optional[bool] = None
|
dataset_keep_in_memory: Optional[bool] = None
|
||||||
dataloader_pin_memory: Optional[bool] = None
|
dataloader_pin_memory: Optional[bool] = None
|
||||||
dataloader_num_workers: Optional[int] = None
|
dataloader_num_workers: Optional[int] = None
|
||||||
@@ -1474,6 +1480,7 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
|
|
||||||
capabilities: GPUCapabilities
|
capabilities: GPUCapabilities
|
||||||
|
env_capabilities: EnvCapabilities
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_bf16(self):
|
def check_bf16(self):
|
||||||
@@ -1514,19 +1521,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_hopper_8bit_lora(cls, data):
|
|
||||||
is_sm_90: bool = (
|
|
||||||
data["capabilities"]
|
|
||||||
and data["capabilities"].get("compute_capability") == "sm_90"
|
|
||||||
)
|
|
||||||
if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
|
|
||||||
# see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
|
|
||||||
raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_fsdp_deepspeed(cls, data):
|
def check_fsdp_deepspeed(cls, data):
|
||||||
@@ -1548,3 +1542,21 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_adopt_torch_version(cls, data):
|
||||||
|
if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
|
||||||
|
env_capabilities = data.get("env_capabilities", {})
|
||||||
|
torch_version = env_capabilities.get("torch_version")
|
||||||
|
|
||||||
|
if torch_version is None:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
|
||||||
|
|
||||||
|
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||||
|
raise ValueError(
|
||||||
|
"ADOPT optimizer is incompatible with torch version < 2.5.1"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|||||||
@@ -12,3 +12,9 @@ class GPUCapabilities(BaseModel):
|
|||||||
n_gpu: int = Field(default=1)
|
n_gpu: int = Field(default=1)
|
||||||
n_node: int = Field(default=1)
|
n_node: int = Field(default=1)
|
||||||
compute_capability: Optional[str] = Field(default=None)
|
compute_capability: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class EnvCapabilities(BaseModel):
|
||||||
|
"""model to manage the environment capabilities statically"""
|
||||||
|
|
||||||
|
torch_version: Optional[str] = Field(default=None)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
|||||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||||
from axolotl.prompt_strategies.kto import load as load_kto
|
from axolotl.prompt_strategies.kto import load as load_kto
|
||||||
from axolotl.prompt_strategies.orpo import load as load_orpo
|
from axolotl.prompt_strategies.orpo import load as load_orpo
|
||||||
from axolotl.utils.data.utils import md5
|
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
@@ -208,4 +208,9 @@ def load_prepare_dpo_datasets(cfg):
|
|||||||
if eval_dataset and not eval_is_preprocessed:
|
if eval_dataset and not eval_is_preprocessed:
|
||||||
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
||||||
|
|
||||||
|
if cfg.dataset_exact_deduplication:
|
||||||
|
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||||
|
train_dataset=train_dataset, eval_dataset=eval_dataset
|
||||||
|
)
|
||||||
|
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|||||||
@@ -2,11 +2,9 @@
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import requests
|
|
||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
@@ -44,7 +42,11 @@ from axolotl.prompters import (
|
|||||||
UnsupportedPrompter,
|
UnsupportedPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||||
from axolotl.utils.data.utils import md5
|
from axolotl.utils.data.utils import (
|
||||||
|
deduplicate_and_log_datasets,
|
||||||
|
md5,
|
||||||
|
retry_on_request_exceptions,
|
||||||
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_local_main_process, zero_first
|
from axolotl.utils.distributed import is_local_main_process, zero_first
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
@@ -55,27 +57,6 @@ from axolotl.utils.trainer import (
|
|||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
|
||||||
def decorator(func):
|
|
||||||
@functools.wraps(func)
|
|
||||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
except (
|
|
||||||
requests.exceptions.ReadTimeout,
|
|
||||||
requests.exceptions.ConnectionError,
|
|
||||||
) as exc:
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
time.sleep(delay)
|
|
||||||
else:
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||||
def prepare_dataset(cfg, tokenizer, processor=None):
|
def prepare_dataset(cfg, tokenizer, processor=None):
|
||||||
prompters = []
|
prompters = []
|
||||||
@@ -136,8 +117,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||||
train_dataset = train_dataset.with_format("torch")
|
train_dataset = train_dataset.with_format("torch")
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
if cfg.dataset_exact_deduplication:
|
||||||
|
LOG.info("Deduplication not available for pretrained datasets")
|
||||||
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||||
|
|
||||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||||
if total_eval_steps == 0:
|
if total_eval_steps == 0:
|
||||||
@@ -584,7 +566,8 @@ def load_prepare_datasets(
|
|||||||
)
|
)
|
||||||
train_fingerprint = md5(to_hash_train)
|
train_fingerprint = md5(to_hash_train)
|
||||||
test_fingerprint = md5(to_hash_test)
|
test_fingerprint = md5(to_hash_test)
|
||||||
|
if cfg.dataset_exact_deduplication:
|
||||||
|
_, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
|
||||||
dataset = dataset.train_test_split(
|
dataset = dataset.train_test_split(
|
||||||
test_size=val_set_size,
|
test_size=val_set_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
@@ -596,12 +579,17 @@ def load_prepare_datasets(
|
|||||||
train_dataset = dataset["train"]
|
train_dataset = dataset["train"]
|
||||||
eval_dataset = dataset["test"]
|
eval_dataset = dataset["test"]
|
||||||
elif split == "test":
|
elif split == "test":
|
||||||
|
if cfg.dataset_exact_deduplication:
|
||||||
|
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
|
||||||
|
else:
|
||||||
|
eval_dataset = dataset
|
||||||
train_dataset = None
|
train_dataset = None
|
||||||
eval_dataset = dataset
|
|
||||||
else:
|
else:
|
||||||
train_dataset = dataset
|
if cfg.dataset_exact_deduplication:
|
||||||
|
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
|
||||||
|
else:
|
||||||
|
train_dataset = dataset
|
||||||
eval_dataset = None
|
eval_dataset = None
|
||||||
|
|
||||||
return train_dataset, eval_dataset, prompters
|
return train_dataset, eval_dataset, prompters
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,55 @@
|
|||||||
"""data handling helpers"""
|
"""data handling helpers"""
|
||||||
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
|
import requests
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
class RetryStrategy(Enum):
|
||||||
|
"""
|
||||||
|
Enum for retry strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
CONSTANT = 1
|
||||||
|
LINEAR = 2
|
||||||
|
EXPONENTIAL = 3
|
||||||
|
|
||||||
|
|
||||||
|
def retry_on_request_exceptions(
|
||||||
|
max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
|
||||||
|
):
|
||||||
|
def decorator(func):
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
except (
|
||||||
|
requests.exceptions.ReadTimeout,
|
||||||
|
requests.exceptions.ConnectionError,
|
||||||
|
huggingface_hub.errors.HfHubHTTPError,
|
||||||
|
) as exc:
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
if retry_strategy == RetryStrategy.EXPONENTIAL:
|
||||||
|
step_delay = delay * 2**attempt
|
||||||
|
elif retry_strategy == RetryStrategy.LINEAR:
|
||||||
|
step_delay = delay * (attempt + 1)
|
||||||
|
else:
|
||||||
|
step_delay = delay # Use constant delay.
|
||||||
|
time.sleep(step_delay)
|
||||||
|
else:
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||||
@@ -8,3 +57,96 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
|||||||
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
||||||
except TypeError:
|
except TypeError:
|
||||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||||
|
|
||||||
|
|
||||||
|
def sha256(to_hash: str, encoding: str = "utf-8") -> str:
|
||||||
|
return hashlib.sha256(to_hash.encode(encoding)).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def deduplicate_dataset(
|
||||||
|
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
|
||||||
|
) -> Dataset:
|
||||||
|
unique_indices = []
|
||||||
|
|
||||||
|
for idx, row in enumerate(dataset):
|
||||||
|
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
|
||||||
|
if row_hash not in seen_hashes:
|
||||||
|
seen_hashes[row_hash] = [idx]
|
||||||
|
unique_indices.append(idx)
|
||||||
|
else:
|
||||||
|
# Check for collision by looking up the original dataset indices
|
||||||
|
original_indices = seen_hashes[row_hash]
|
||||||
|
is_duplicate = False
|
||||||
|
for original_idx in original_indices:
|
||||||
|
if (
|
||||||
|
not idx == original_idx
|
||||||
|
and original_idx < len(dataset)
|
||||||
|
and str(dataset[original_idx]) == str(row)
|
||||||
|
):
|
||||||
|
is_duplicate = True
|
||||||
|
break
|
||||||
|
# Check in the other dataset if provided
|
||||||
|
if other_dataset is not None:
|
||||||
|
if original_idx < len(other_dataset) and str(
|
||||||
|
other_dataset[original_idx]
|
||||||
|
) == str(row):
|
||||||
|
is_duplicate = True
|
||||||
|
break
|
||||||
|
if not is_duplicate:
|
||||||
|
seen_hashes[row_hash].append(idx)
|
||||||
|
unique_indices.append(idx)
|
||||||
|
continue
|
||||||
|
return dataset.select(unique_indices)
|
||||||
|
|
||||||
|
|
||||||
|
def deduplicate_and_log_datasets(
|
||||||
|
*,
|
||||||
|
train_dataset: Dataset = None,
|
||||||
|
eval_dataset: Dataset = None,
|
||||||
|
dataset: Dataset = None,
|
||||||
|
) -> tuple[Dataset, Dataset, Dataset]:
|
||||||
|
"""
|
||||||
|
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: Deduplicated train, eval, and additional datasets.
|
||||||
|
"""
|
||||||
|
seen_hashes: dict[str, list[int]] = {}
|
||||||
|
|
||||||
|
# Handle cases where datasets are None
|
||||||
|
if train_dataset is not None:
|
||||||
|
LOG.info(
|
||||||
|
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
|
||||||
|
)
|
||||||
|
train_dataset = deduplicate_dataset(
|
||||||
|
dataset=train_dataset, seen_hashes=seen_hashes
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.info("Train dataset is None. Skipping deduplication.")
|
||||||
|
|
||||||
|
if eval_dataset is not None:
|
||||||
|
LOG.info(
|
||||||
|
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
|
||||||
|
)
|
||||||
|
eval_dataset = deduplicate_dataset(
|
||||||
|
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
|
||||||
|
)
|
||||||
|
LOG.info(
|
||||||
|
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.info("Eval dataset is None. Skipping deduplication.")
|
||||||
|
|
||||||
|
if dataset is not None and (eval_dataset is None and train_dataset is None):
|
||||||
|
LOG.info(
|
||||||
|
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
|
||||||
|
)
|
||||||
|
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
|
||||||
|
LOG.info(
|
||||||
|
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_dataset, eval_dataset, dataset
|
||||||
|
|||||||
@@ -2,10 +2,12 @@
|
|||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
import gc
|
import gc
|
||||||
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import types
|
import types
|
||||||
|
from functools import cached_property
|
||||||
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401
|
||||||
|
|
||||||
import addict
|
import addict
|
||||||
@@ -46,7 +48,6 @@ from transformers.integrations.deepspeed import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||||
from axolotl.integrations.sageattention.lib.core import monkeypatch_sdp_w_sage_attention
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
@@ -385,6 +386,15 @@ class ModelLoader:
|
|||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
self.patch_attention()
|
self.patch_attention()
|
||||||
|
|
||||||
|
if self.cfg.model_config_type == "llama":
|
||||||
|
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||||
|
patch_forward_for_ga,
|
||||||
|
patch_training_step_for_ga,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_forward_for_ga()
|
||||||
|
patch_training_step_for_ga()
|
||||||
|
|
||||||
if self.cfg.sample_packing and self.cfg.s2_attention:
|
if self.cfg.sample_packing and self.cfg.s2_attention:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
"Received `sample_packing=true` and `s2_attention=true`; however, \
|
||||||
@@ -410,7 +420,7 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.is_llama_derived_model:
|
if self.cfg.is_llama_derived_model:
|
||||||
self.patch_loss()
|
self.patch_loss_llama()
|
||||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||||
|
|
||||||
@@ -452,27 +462,34 @@ class ModelLoader:
|
|||||||
|
|
||||||
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
replace_stablelm_attn_with_flash_attn(self.cfg.base_model)
|
||||||
|
|
||||||
def patch_loss(self) -> None:
|
@cached_property
|
||||||
|
def has_flash_attn(self) -> bool:
|
||||||
|
"""Check if flash attention is installed"""
|
||||||
|
return importlib.util.find_spec("flash_attn") is not None
|
||||||
|
|
||||||
|
def patch_loss_llama(self) -> None:
|
||||||
"""
|
"""
|
||||||
Patch loss functions
|
Patch loss functions
|
||||||
"""
|
"""
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
if self.has_flash_attn:
|
||||||
patch_llama_cross_entropy,
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
patch_llama_rms_norm,
|
patch_fa_llama_cross_entropy,
|
||||||
)
|
patch_llama_rms_norm,
|
||||||
|
)
|
||||||
|
|
||||||
if self.cfg.flash_attn_cross_entropy:
|
if self.cfg.flash_attn_cross_entropy and self.has_flash_attn:
|
||||||
patch_llama_cross_entropy()
|
patch_fa_llama_cross_entropy()
|
||||||
if self.cfg.flash_attn_rms_norm:
|
elif self.cfg.unsloth_cross_entropy_loss:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||||
|
|
||||||
|
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||||
|
|
||||||
|
if self.cfg.flash_attn_rms_norm and self.has_flash_attn:
|
||||||
patch_llama_rms_norm()
|
patch_llama_rms_norm()
|
||||||
elif self.cfg.unsloth_rms_norm:
|
elif self.cfg.unsloth_rms_norm:
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||||
|
|
||||||
patch_unsloth_layernorm()
|
patch_unsloth_layernorm()
|
||||||
if self.cfg.unsloth_cross_entropy_loss:
|
|
||||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
|
||||||
|
|
||||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
|
||||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||||
|
|
||||||
@@ -482,6 +499,7 @@ class ModelLoader:
|
|||||||
"""
|
"""
|
||||||
Modify all llama derived models in one block
|
Modify all llama derived models in one block
|
||||||
"""
|
"""
|
||||||
|
self.patch_loss_llama()
|
||||||
|
|
||||||
if self.cfg.flash_attention:
|
if self.cfg.flash_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
@@ -529,16 +547,6 @@ class ModelLoader:
|
|||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.unsloth_cross_entropy_loss:
|
|
||||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
|
||||||
|
|
||||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
|
||||||
|
|
||||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
|
||||||
|
|
||||||
patch_self_attn_lora()
|
|
||||||
|
|
||||||
def set_auto_model_loader(self) -> None:
|
def set_auto_model_loader(self) -> None:
|
||||||
"""set self.AutoModelLoader
|
"""set self.AutoModelLoader
|
||||||
- default value: AutoModelForCausalLM (set at __init__)
|
- default value: AutoModelForCausalLM (set at __init__)
|
||||||
@@ -708,7 +716,6 @@ class ModelLoader:
|
|||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"sdpa"
|
"sdpa"
|
||||||
)
|
)
|
||||||
monkeypatch_sdp_w_sage_attention()
|
|
||||||
elif self.cfg.eager_attention:
|
elif self.cfg.eager_attention:
|
||||||
self.model_kwargs["attn_implementation"] = "eager"
|
self.model_kwargs["attn_implementation"] = "eager"
|
||||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
@@ -1086,14 +1093,17 @@ class ModelLoader:
|
|||||||
|
|
||||||
self.prepare_model(qlora_fsdp)
|
self.prepare_model(qlora_fsdp)
|
||||||
|
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
should_convert = (
|
||||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
if (needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp:
|
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
LOG.info(
|
((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
|
||||||
"converting modules to %s for flash attention", self.cfg.torch_dtype
|
or self.cfg.cut_cross_entropy # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if should_convert:
|
||||||
|
LOG.info("Converting modules to %s", self.cfg.torch_dtype)
|
||||||
self.convert_embedding_modules_dtype(
|
self.convert_embedding_modules_dtype(
|
||||||
embedding_modules,
|
embedding_modules=embedding_modules,
|
||||||
dist_dtype=self.cfg.torch_dtype,
|
dist_dtype=self.cfg.torch_dtype,
|
||||||
before_kbit_train_or_finetune=False,
|
before_kbit_train_or_finetune=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,21 +6,29 @@ Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeo
|
|||||||
"""
|
"""
|
||||||
# mypy: ignore-errors
|
# mypy: ignore-errors
|
||||||
# pylint: skip-file
|
# pylint: skip-file
|
||||||
|
# flake8: noqa
|
||||||
# mypy: allow-untyped-decorators
|
# mypy: allow-untyped-decorators
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
from typing import List, Optional, Tuple, Union, cast
|
from typing import Callable, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim.optimizer import (
|
from torch.optim.optimizer import ( # DeviceDict,; _capturable_doc,; _differentiable_doc,; _foreach_doc,; _fused_doc,; _maximize_doc,; _stack_if_compiling,
|
||||||
|
DeviceDict,
|
||||||
Optimizer,
|
Optimizer,
|
||||||
ParamsT,
|
ParamsT,
|
||||||
|
_capturable_doc,
|
||||||
_default_to_fused_or_foreach,
|
_default_to_fused_or_foreach,
|
||||||
_device_dtype_check_for_fused,
|
_device_dtype_check_for_fused,
|
||||||
|
_differentiable_doc,
|
||||||
_disable_dynamo_if_unsupported,
|
_disable_dynamo_if_unsupported,
|
||||||
|
_foreach_doc,
|
||||||
|
_fused_doc,
|
||||||
_get_capturable_supported_devices,
|
_get_capturable_supported_devices,
|
||||||
_get_scalar_dtype,
|
_get_scalar_dtype,
|
||||||
_get_value,
|
_get_value,
|
||||||
|
_maximize_doc,
|
||||||
|
_stack_if_compiling,
|
||||||
_use_grad_for_differentiable,
|
_use_grad_for_differentiable,
|
||||||
_view_as_real,
|
_view_as_real,
|
||||||
)
|
)
|
||||||
@@ -35,8 +43,9 @@ class ADOPT(Optimizer):
|
|||||||
lr: Union[float, Tensor] = 1e-3,
|
lr: Union[float, Tensor] = 1e-3,
|
||||||
betas: Tuple[float, float] = (0.9, 0.9999),
|
betas: Tuple[float, float] = (0.9, 0.9999),
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
|
clip_lambda: Optional[Callable[[int], float]] = lambda step: step**0.25,
|
||||||
weight_decay: float = 0.0,
|
weight_decay: float = 0.0,
|
||||||
decoupled: bool = False,
|
decouple: bool = False,
|
||||||
*,
|
*,
|
||||||
foreach: Optional[bool] = None,
|
foreach: Optional[bool] = None,
|
||||||
maximize: bool = False,
|
maximize: bool = False,
|
||||||
@@ -62,12 +71,14 @@ class ADOPT(Optimizer):
|
|||||||
if not 0.0 <= weight_decay:
|
if not 0.0 <= weight_decay:
|
||||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||||
|
|
||||||
|
self.clip_lambda = clip_lambda
|
||||||
|
|
||||||
defaults = dict(
|
defaults = dict(
|
||||||
lr=lr,
|
lr=lr,
|
||||||
betas=betas,
|
betas=betas,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
weight_decay=weight_decay,
|
weight_decay=weight_decay,
|
||||||
decoupled=decoupled,
|
decouple=decouple,
|
||||||
maximize=maximize,
|
maximize=maximize,
|
||||||
foreach=foreach,
|
foreach=foreach,
|
||||||
capturable=capturable,
|
capturable=capturable,
|
||||||
@@ -219,8 +230,9 @@ class ADOPT(Optimizer):
|
|||||||
beta1=beta1,
|
beta1=beta1,
|
||||||
beta2=beta2,
|
beta2=beta2,
|
||||||
lr=group["lr"],
|
lr=group["lr"],
|
||||||
|
clip_lambda=self.clip_lambda,
|
||||||
weight_decay=group["weight_decay"],
|
weight_decay=group["weight_decay"],
|
||||||
decoupled=group["decoupled"],
|
decouple=group["decouple"],
|
||||||
eps=group["eps"],
|
eps=group["eps"],
|
||||||
maximize=group["maximize"],
|
maximize=group["maximize"],
|
||||||
foreach=group["foreach"],
|
foreach=group["foreach"],
|
||||||
@@ -247,8 +259,9 @@ def _single_tensor_adopt(
|
|||||||
beta1: float,
|
beta1: float,
|
||||||
beta2: float,
|
beta2: float,
|
||||||
lr: Union[float, Tensor],
|
lr: Union[float, Tensor],
|
||||||
|
clip_lambda: Optional[Callable[[int], float]],
|
||||||
weight_decay: float,
|
weight_decay: float,
|
||||||
decoupled: bool,
|
decouple: bool,
|
||||||
eps: float,
|
eps: float,
|
||||||
maximize: bool,
|
maximize: bool,
|
||||||
capturable: bool,
|
capturable: bool,
|
||||||
@@ -276,14 +289,10 @@ def _single_tensor_adopt(
|
|||||||
and param.device.type in capturable_supported_devices
|
and param.device.type in capturable_supported_devices
|
||||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||||
|
|
||||||
# update step
|
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||||
step_t += 1
|
|
||||||
|
|
||||||
if weight_decay != 0:
|
if weight_decay != 0 and not decouple:
|
||||||
if decoupled:
|
grad = grad.add(param, alpha=weight_decay)
|
||||||
param.add_(param, alpha=-lr * weight_decay)
|
|
||||||
else:
|
|
||||||
grad = grad.add(param, alpha=weight_decay)
|
|
||||||
|
|
||||||
if torch.is_complex(param):
|
if torch.is_complex(param):
|
||||||
grad = torch.view_as_real(grad)
|
grad = torch.view_as_real(grad)
|
||||||
@@ -293,20 +302,29 @@ def _single_tensor_adopt(
|
|||||||
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||||
param = torch.view_as_real(param)
|
param = torch.view_as_real(param)
|
||||||
|
|
||||||
step = step_t if capturable or differentiable else _get_value(step_t)
|
if step == 0:
|
||||||
if step == 1:
|
|
||||||
exp_avg_sq.addcmul_(grad, grad.conj())
|
exp_avg_sq.addcmul_(grad, grad.conj())
|
||||||
|
# update step
|
||||||
|
step_t += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if weight_decay != 0 and decouple:
|
||||||
|
param.add_(param, alpha=-lr * weight_decay)
|
||||||
|
|
||||||
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||||
if step == 2:
|
normed_grad = grad.div(denom)
|
||||||
exp_avg.addcdiv_(grad, denom)
|
if clip_lambda is not None:
|
||||||
else:
|
clip = clip_lambda(step)
|
||||||
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
|
normed_grad.clamp_(-clip, clip)
|
||||||
|
|
||||||
|
exp_avg.lerp_(normed_grad, 1 - beta1)
|
||||||
|
|
||||||
param.add_(exp_avg, alpha=-lr)
|
param.add_(exp_avg, alpha=-lr)
|
||||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||||
|
|
||||||
|
# update step
|
||||||
|
step_t += 1
|
||||||
|
|
||||||
|
|
||||||
def _multi_tensor_adopt(
|
def _multi_tensor_adopt(
|
||||||
params: List[Tensor],
|
params: List[Tensor],
|
||||||
@@ -321,8 +339,9 @@ def _multi_tensor_adopt(
|
|||||||
beta1: float,
|
beta1: float,
|
||||||
beta2: float,
|
beta2: float,
|
||||||
lr: Union[float, Tensor],
|
lr: Union[float, Tensor],
|
||||||
|
clip_lambda: Optional[Callable[[int], float]],
|
||||||
weight_decay: float,
|
weight_decay: float,
|
||||||
decoupled: bool,
|
decouple: bool,
|
||||||
eps: float,
|
eps: float,
|
||||||
maximize: bool,
|
maximize: bool,
|
||||||
capturable: bool,
|
capturable: bool,
|
||||||
@@ -376,6 +395,51 @@ def _multi_tensor_adopt(
|
|||||||
if maximize:
|
if maximize:
|
||||||
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||||
|
|
||||||
|
if weight_decay != 0 and not decouple:
|
||||||
|
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||||
|
if maximize:
|
||||||
|
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||||
|
else:
|
||||||
|
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||||
|
device_grads, device_params, alpha=weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_state_steps[0] == 0:
|
||||||
|
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||||
|
|
||||||
|
# Update steps
|
||||||
|
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||||
|
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||||
|
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||||
|
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||||
|
torch._foreach_add_(
|
||||||
|
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch._foreach_add_(device_state_steps, 1)
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
|
if weight_decay != 0 and decouple:
|
||||||
|
torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
|
||||||
|
|
||||||
|
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||||
|
torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
|
||||||
|
|
||||||
|
normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)
|
||||||
|
if clip_lambda is not None:
|
||||||
|
clip = clip_lambda(device_state_steps[0])
|
||||||
|
torch._foreach_maximum_(normed_grad, -clip)
|
||||||
|
torch._foreach_minimum_(normed_grad, clip)
|
||||||
|
|
||||||
|
torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
|
||||||
|
|
||||||
|
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||||
|
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||||
|
torch._foreach_addcmul_(
|
||||||
|
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||||
|
)
|
||||||
|
|
||||||
# Update steps
|
# Update steps
|
||||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||||
@@ -387,41 +451,6 @@ def _multi_tensor_adopt(
|
|||||||
else:
|
else:
|
||||||
torch._foreach_add_(device_state_steps, 1)
|
torch._foreach_add_(device_state_steps, 1)
|
||||||
|
|
||||||
if weight_decay != 0:
|
|
||||||
if decoupled:
|
|
||||||
torch._foreach_add_(
|
|
||||||
device_params, device_params, alpha=-lr * weight_decay
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
|
||||||
if maximize:
|
|
||||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
|
||||||
else:
|
|
||||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
|
||||||
device_grads, device_params, alpha=weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
if device_state_steps[0] == 1:
|
|
||||||
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
|
||||||
continue
|
|
||||||
|
|
||||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
|
||||||
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
|
|
||||||
|
|
||||||
if device_state_steps[0] == 2:
|
|
||||||
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
|
|
||||||
else:
|
|
||||||
torch._foreach_mul_(device_exp_avgs, beta1)
|
|
||||||
torch._foreach_addcdiv_(
|
|
||||||
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
|
|
||||||
)
|
|
||||||
|
|
||||||
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
|
||||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
|
||||||
torch._foreach_addcmul_(
|
|
||||||
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
||||||
def adopt(
|
def adopt(
|
||||||
@@ -443,8 +472,9 @@ def adopt(
|
|||||||
beta1: float,
|
beta1: float,
|
||||||
beta2: float,
|
beta2: float,
|
||||||
lr: Union[float, Tensor],
|
lr: Union[float, Tensor],
|
||||||
|
clip_lambda: Optional[Callable[[int], float]],
|
||||||
weight_decay: float,
|
weight_decay: float,
|
||||||
decoupled: bool,
|
decouple: bool,
|
||||||
eps: float,
|
eps: float,
|
||||||
maximize: bool,
|
maximize: bool,
|
||||||
):
|
):
|
||||||
@@ -497,8 +527,9 @@ def adopt(
|
|||||||
beta1=beta1,
|
beta1=beta1,
|
||||||
beta2=beta2,
|
beta2=beta2,
|
||||||
lr=lr,
|
lr=lr,
|
||||||
|
clip_lambda=clip_lambda,
|
||||||
weight_decay=weight_decay,
|
weight_decay=weight_decay,
|
||||||
decoupled=decoupled,
|
decouple=decouple,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
maximize=maximize,
|
maximize=maximize,
|
||||||
capturable=capturable,
|
capturable=capturable,
|
||||||
|
|||||||
36
tests/cli/conftest.py
Normal file
36
tests/cli/conftest.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""Shared pytest fixtures for cli module."""
|
||||||
|
import pytest
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
# Minimal training config text shared by the CLI tests. The fixtures in this
# module write it to a ``config.yml`` file, so nested keys must keep their
# indentation for the text to remain a well-formed YAML document.
VALID_TEST_CONFIG = """
base_model: HuggingFaceTB/SmolLM2-135M
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca
sequence_len: 2048
max_steps: 1
micro_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1e-3
special_tokens:
  pad_token: <|endoftext|>
"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def cli_runner():
    """Provide a new click ``CliRunner`` instance to the requesting test."""
    runner = CliRunner()
    return runner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def valid_test_config():
    """Expose the module-level ``VALID_TEST_CONFIG`` text as a fixture."""
    config_text = VALID_TEST_CONFIG
    return config_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def config_path(tmp_path):
    """Write ``VALID_TEST_CONFIG`` to ``config.yml`` in ``tmp_path``; return the path."""
    config_file = tmp_path.joinpath("config.yml")
    config_file.write_text(VALID_TEST_CONFIG)
    return config_file
|
||||||
38
tests/cli/test_cli_fetch.py
Normal file
38
tests/cli/test_cli_fetch.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""pytest tests for axolotl CLI fetch command."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.cli.main import fetch
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_cli_examples(cli_runner):
    """``fetch examples`` exits 0 and asks for the ``examples/`` prefix, no dest."""
    cli_args = ["examples"]
    with patch("axolotl.cli.main.fetch_from_github") as fetch_stub:
        outcome = cli_runner.invoke(fetch, cli_args)

    assert outcome.exit_code == 0
    fetch_stub.assert_called_once_with("examples/", None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_cli_deepspeed(cli_runner):
    """``fetch deepspeed_configs`` exits 0 and asks for ``deepspeed_configs/``."""
    cli_args = ["deepspeed_configs"]
    with patch("axolotl.cli.main.fetch_from_github") as fetch_stub:
        outcome = cli_runner.invoke(fetch, cli_args)

    assert outcome.exit_code == 0
    fetch_stub.assert_called_once_with("deepspeed_configs/", None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_cli_with_dest(cli_runner, tmp_path):
    """``--dest`` is forwarded to ``fetch_from_github`` as the second argument."""
    dest_arg = str(tmp_path / "tmp_examples")
    with patch("axolotl.cli.main.fetch_from_github") as fetch_stub:
        outcome = cli_runner.invoke(fetch, ["examples", "--dest", dest_arg])

    assert outcome.exit_code == 0
    fetch_stub.assert_called_once_with("examples/", dest_arg)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_cli_invalid_directory(cli_runner):
    """An unrecognised directory argument makes ``fetch`` exit non-zero."""
    bad_args = ["invalid"]
    outcome = cli_runner.invoke(fetch, bad_args)
    assert outcome.exit_code != 0
|
||||||
30
tests/cli/test_cli_inference.py
Normal file
30
tests/cli/test_cli_inference.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""pytest tests for axolotl CLI inference command."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_inference_basic(cli_runner, config_path):
    """``inference --no-accelerate`` reaches ``do_inference`` and exits 0."""
    argv = ["inference", str(config_path), "--no-accelerate"]
    with patch("axolotl.cli.inference.do_inference") as inference_stub:
        # catch_exceptions=False lets any error inside the command fail the test.
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

    assert inference_stub.called
    assert outcome.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_inference_gradio(cli_runner, config_path):
    """``inference --gradio`` reaches ``do_inference_gradio`` and exits 0."""
    argv = ["inference", str(config_path), "--no-accelerate", "--gradio"]
    with patch("axolotl.cli.inference.do_inference_gradio") as gradio_stub:
        # catch_exceptions=False lets any error inside the command fail the test.
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

    assert gradio_stub.called
    assert outcome.exit_code == 0
|
||||||
47
tests/cli/test_cli_interface.py
Normal file
47
tests/cli/test_cli_interface.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""General pytest tests for axolotl.cli.main interface."""
|
||||||
|
from axolotl.cli.main import build_command, cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_command():
    """``build_command`` appends CLI flags derived from an options dict."""
    # Expected shape: value options become "--kebab-case <value>", a True
    # option becomes a bare flag, and the False / None options do not appear.
    expected = [
        "accelerate",
        "launch",
        "--learning-rate",
        "0.0001",
        "--batch-size",
        "8",
        "--debug",
    ]

    command = build_command(
        ["accelerate", "launch"],
        {
            "learning_rate": 1e-4,
            "batch_size": 8,
            "debug": True,
            "use_fp16": False,
            "null_value": None,
        },
    )

    assert command == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_command_options(cli_runner):
    """An unknown option on ``train`` fails with click's "No such option" text."""
    argv = ["train", "config.yml", "--invalid-option", "value"]
    outcome = cli_runner.invoke(cli, argv)

    assert outcome.exit_code != 0
    assert "No such option" in outcome.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_required_config_argument(cli_runner):
    """``train`` without a config argument fails and names the missing CONFIG."""
    argv = ["train"]
    outcome = cli_runner.invoke(cli, argv)

    assert outcome.exit_code != 0
    assert "Missing argument 'CONFIG'" in outcome.output
|
||||||
56
tests/cli/test_cli_merge_lora.py
Normal file
56
tests/cli/test_cli_merge_lora.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""pytest tests for axolotl CLI merge_lora command."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_lora_basic(cli_runner, config_path):
    """``merge-lora <config>`` calls ``do_cli`` once with that config path."""
    config_arg = str(config_path)
    with patch("axolotl.cli.merge_lora.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, ["merge-lora", config_arg])

    assert outcome.exit_code == 0
    do_cli_stub.assert_called_once()
    passed = do_cli_stub.call_args.kwargs
    assert passed["config"] == config_arg
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_lora_with_dirs(cli_runner, config_path, tmp_path):
    """``--lora-model-dir`` and ``--output-dir`` are forwarded to ``do_cli``."""
    adapter_dir = tmp_path / "lora"
    merged_dir = tmp_path / "output"
    # Only the LoRA directory is created up front; the output one is left absent.
    adapter_dir.mkdir()

    argv = [
        "merge-lora",
        str(config_path),
        "--lora-model-dir",
        str(adapter_dir),
        "--output-dir",
        str(merged_dir),
    ]
    with patch("axolotl.cli.merge_lora.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, argv)

    assert outcome.exit_code == 0
    do_cli_stub.assert_called_once()

    expected_kwargs = {
        "config": str(config_path),
        "lora_model_dir": str(adapter_dir),
        "output_dir": str(merged_dir),
    }
    for key, expected in expected_kwargs.items():
        assert do_cli_stub.call_args.kwargs[key] == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_lora_nonexistent_config(cli_runner, tmp_path):
    """``merge-lora`` exits non-zero when the config file does not exist."""
    missing_config = tmp_path / "nonexistent.yml"
    argv = ["merge-lora", str(missing_config)]
    outcome = cli_runner.invoke(cli, argv)
    assert outcome.exit_code != 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_lora_nonexistent_lora_dir(cli_runner, config_path, tmp_path):
    """``merge-lora`` exits non-zero when ``--lora-model-dir`` does not exist."""
    missing_dir = tmp_path / "nonexistent"
    argv = ["merge-lora", str(config_path), "--lora-model-dir", str(missing_dir)]
    outcome = cli_runner.invoke(cli, argv)
    assert outcome.exit_code != 0
|
||||||
60
tests/cli/test_cli_merge_sharded_fsdp_weights.py
Normal file
60
tests/cli/test_cli_merge_sharded_fsdp_weights.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
    """With ``--no-accelerate`` the command calls ``do_cli`` with the config path."""
    config_arg = str(config_path)
    argv = ["merge-sharded-fsdp-weights", config_arg, "--no-accelerate"]
    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, argv)

    assert do_cli_stub.called
    assert do_cli_stub.call_args.kwargs["config"] == config_arg
    assert outcome.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
    """``--model-dir`` is forwarded to ``do_cli`` as the ``model_dir`` keyword."""
    weights_dir = tmp_path / "model"
    weights_dir.mkdir()

    argv = [
        "merge-sharded-fsdp-weights",
        str(config_path),
        "--no-accelerate",
        "--model-dir",
        str(weights_dir),
    ]
    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, argv)

    assert do_cli_stub.called
    passed = do_cli_stub.call_args.kwargs
    assert passed["config"] == str(config_path)
    assert passed["model_dir"] == str(weights_dir)
    assert outcome.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
    """``--save-path`` is forwarded to ``do_cli`` as the ``save_path`` keyword."""
    save_target = "/path/to/save"
    argv = [
        "merge-sharded-fsdp-weights",
        str(config_path),
        "--no-accelerate",
        "--save-path",
        save_target,
    ]
    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, argv)

    assert do_cli_stub.called
    passed = do_cli_stub.call_args.kwargs
    assert passed["config"] == str(config_path)
    assert passed["save_path"] == save_target
    assert outcome.exit_code == 0
|
||||||
71
tests/cli/test_cli_preprocess.py
Normal file
71
tests/cli/test_cli_preprocess.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""pytest tests for axolotl CLI preprocess command."""
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def cleanup_last_run_prepared():
    """After each test, delete ``last_run_prepared`` from the working directory."""
    yield

    # Teardown: only remove the directory when it is actually present.
    leftover = Path("last_run_prepared")
    if leftover.exists():
        shutil.rmtree(leftover)
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess_config_not_found(cli_runner):
    """``preprocess`` exits non-zero when given ``nonexistent.yml``."""
    argv = ["preprocess", "nonexistent.yml"]
    outcome = cli_runner.invoke(cli, argv)
    assert outcome.exit_code != 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess_basic(cli_runner, config_path):
    """Plain ``preprocess <config>`` calls ``do_cli`` once with ``download=True``."""
    config_arg = str(config_path)
    with patch("axolotl.cli.preprocess.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, ["preprocess", config_arg])

    assert outcome.exit_code == 0
    do_cli_stub.assert_called_once()
    passed = do_cli_stub.call_args.kwargs
    assert passed["config"] == config_arg
    assert passed["download"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess_without_download(cli_runner, config_path):
    """``--no-download`` makes ``do_cli`` receive ``download=False``."""
    config_arg = str(config_path)
    argv = ["preprocess", config_arg, "--no-download"]
    with patch("axolotl.cli.preprocess.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, argv)

    assert outcome.exit_code == 0
    do_cli_stub.assert_called_once()
    passed = do_cli_stub.call_args.kwargs
    assert passed["config"] == config_arg
    assert passed["download"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config):
    """``--dataset-prepared-path`` is forwarded unchanged to ``do_cli``."""
    config_file = tmp_path / "config.yml"
    prepared_dir = tmp_path / "custom_prepared"
    config_file.write_text(valid_test_config)

    prepared_arg = str(prepared_dir.absolute())
    argv = [
        "preprocess",
        str(config_file),
        "--dataset-prepared-path",
        prepared_arg,
    ]
    with patch("axolotl.cli.preprocess.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, argv)

    assert outcome.exit_code == 0
    do_cli_stub.assert_called_once()
    passed = do_cli_stub.call_args.kwargs
    assert passed["config"] == str(config_file)
    assert passed["dataset_prepared_path"] == prepared_arg
|
||||||
76
tests/cli/test_cli_shard.py
Normal file
76
tests/cli/test_cli_shard.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""pytest tests for axolotl CLI shard command."""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_shard_with_accelerate(cli_runner, config_path):
    """``shard --accelerate`` shells out to ``accelerate launch -m axolotl.cli.shard``."""
    expected_cmd = [
        "accelerate",
        "launch",
        "-m",
        "axolotl.cli.shard",
        str(config_path),
        "--debug-num-examples",
        "0",
    ]
    # subprocess.run is patched so no real process is started.
    with patch("subprocess.run") as run_stub:
        outcome = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])

    assert run_stub.called
    launched = run_stub.call_args
    assert launched.args[0] == expected_cmd
    assert launched.kwargs == {"check": True}
    assert outcome.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_shard_no_accelerate(cli_runner, config_path):
    """``shard --no-accelerate`` calls ``axolotl.cli.shard.do_cli`` and exits 0."""
    argv = ["shard", str(config_path), "--no-accelerate"]
    with patch("axolotl.cli.shard.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, argv)

    assert do_cli_stub.called
    assert outcome.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
    """``--model-dir`` is forwarded to ``do_cli`` as the ``model_dir`` keyword."""
    weights_dir = tmp_path / "model"
    weights_dir.mkdir()

    argv = [
        "shard",
        str(config_path),
        "--no-accelerate",
        "--model-dir",
        str(weights_dir),
    ]
    with patch("axolotl.cli.shard.do_cli") as do_cli_stub:
        # catch_exceptions=False lets any error inside the command fail the test.
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

    assert do_cli_stub.called
    passed = do_cli_stub.call_args.kwargs
    assert passed["config"] == str(config_path)
    assert passed["model_dir"] == str(weights_dir)
    assert outcome.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_shard_with_save_dir(cli_runner, config_path):
    """``--save-dir`` is forwarded to ``do_cli`` as the ``save_dir`` keyword."""
    save_target = "/path/to/save"
    argv = [
        "shard",
        str(config_path),
        "--no-accelerate",
        "--save-dir",
        save_target,
    ]
    with patch("axolotl.cli.shard.do_cli") as do_cli_stub:
        outcome = cli_runner.invoke(cli, argv)

    assert do_cli_stub.called
    passed = do_cli_stub.call_args.kwargs
    assert passed["config"] == str(config_path)
    assert passed["save_dir"] == save_target
    assert outcome.exit_code == 0
|
||||||
98
tests/cli/test_cli_train.py
Normal file
98
tests/cli/test_cli_train.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""pytest tests for axolotl CLI train command."""
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_cli_validation(cli_runner):
    """``train`` rejects both a missing and a non-existent config argument."""
    # Case 1: no config argument at all.
    no_config = cli_runner.invoke(cli, ["train", "--no-accelerate"])
    assert no_config.exit_code != 0

    # Case 2: a config path that does not exist on disk.
    bad_config = cli_runner.invoke(
        cli, ["train", "nonexistent.yml", "--no-accelerate"]
    )
    assert bad_config.exit_code != 0
    assert "Error: Invalid value for 'CONFIG'" in bad_config.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_basic_execution(cli_runner, tmp_path, valid_test_config):
    """Plain ``train <config>`` shells out to ``accelerate launch -m axolotl.cli.train``."""
    config_file = tmp_path / "config.yml"
    config_file.write_text(valid_test_config)

    expected_cmd = [
        "accelerate",
        "launch",
        "-m",
        "axolotl.cli.train",
        str(config_file),
        "--debug-num-examples",
        "0",
    ]
    # subprocess.run is patched so no real process is started.
    with patch("subprocess.run") as run_stub:
        outcome = cli_runner.invoke(cli, ["train", str(config_file)])

    assert run_stub.called
    launched = run_stub.call_args
    assert launched.args[0] == expected_cmd
    assert launched.kwargs == {"check": True}
    assert outcome.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_basic_execution_no_accelerate(cli_runner, tmp_path, valid_test_config):
    """``train --no-accelerate`` calls ``axolotl.cli.train.train`` exactly once."""
    config_file = tmp_path / "config.yml"
    config_file.write_text(valid_test_config)

    argv = [
        "train",
        str(config_file),
        "--learning-rate",
        "1e-4",
        "--micro-batch-size",
        "2",
        "--no-accelerate",
    ]
    # The patched train() returns a 2-tuple of mocks.
    with patch(
        "axolotl.cli.train.train", return_value=(MagicMock(), MagicMock())
    ) as train_stub:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

    assert outcome.exit_code == 0
    train_stub.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_cli_overrides(cli_runner, tmp_path, valid_test_config):
    """CLI ``--learning-rate`` / ``--micro-batch-size`` values end up in the cfg."""
    config_file = tmp_path / "config.yml"
    model_out = tmp_path / "model-out"

    # NOTE(review): the VALID_TEST_CONFIG text in tests/cli/conftest.py has no
    # "output_dir: model-out" line, so this replace looks like a no-op — confirm.
    patched_config = valid_test_config.replace(
        "output_dir: model-out", f"output_dir: {model_out}"
    )
    config_file.write_text(patched_config)

    argv = [
        "train",
        str(config_file),
        "--learning-rate",
        "1e-4",
        "--micro-batch-size",
        "2",
        "--no-accelerate",
    ]
    # The patched train() returns a 2-tuple of mocks.
    with patch(
        "axolotl.cli.train.train", return_value=(MagicMock(), MagicMock())
    ) as train_stub:
        outcome = cli_runner.invoke(cli, argv, catch_exceptions=False)

    assert outcome.exit_code == 0
    train_stub.assert_called_once()

    # The file sets learning_rate 1e-3 and micro_batch_size 1; the CLI values win.
    effective_cfg = train_stub.call_args.kwargs["cfg"]
    assert effective_cfg["learning_rate"] == 1e-4
    assert effective_cfg["micro_batch_size"] == 2
|
||||||
72
tests/cli/test_utils.py
Normal file
72
tests/cli/test_utils.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
"""pytest tests for axolotl CLI utils."""
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
import json
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import click
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from axolotl.cli.utils import fetch_from_github
|
||||||
|
|
||||||
|
# Sample GitHub API response: a "tree" list of blob entries, two under
# "examples/" and one under "other/".
MOCK_TREE_RESPONSE = {
    "tree": [
        {"path": blob_path, "type": "blob", "sha": blob_sha}
        for blob_path, blob_sha in (
            ("examples/config1.yml", "abc123"),
            ("examples/config2.yml", "def456"),
            ("other/file.txt", "xyz789"),
        )
    ]
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_responses():
    """Return a stand-in for ``requests.get`` backed by canned responses.

    URLs containing ``api.github.com`` get a response whose ``text`` is the
    JSON-encoded ``MOCK_TREE_RESPONSE``; every other URL gets a response whose
    ``content`` is ``b"file content"``.
    """

    def mock_get(url, timeout=None):  # pylint: disable=unused-argument
        fake = Mock()
        if "api.github.com" not in url:
            fake.content = b"file content"
            return fake
        fake.text = json.dumps(MOCK_TREE_RESPONSE)
        return fake

    return mock_get
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_from_github_new_files(tmp_path, mock_responses):
    """Only blobs under the requested prefix are written into the destination."""
    with patch("requests.get", mock_responses):
        fetch_from_github("examples/", tmp_path)

        # Both "examples/" entries land in tmp_path; the "other/" entry does not.
        for fetched_name in ("config1.yml", "config2.yml"):
            assert (tmp_path / fetched_name).exists()
        assert not (tmp_path / "file.txt").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_from_github_unchanged_files(tmp_path, mock_responses):
    """A file already present in the destination keeps its content after a fetch."""
    payload = b"file content"
    # Pre-create one of the files the fetch would otherwise download.
    preexisting = tmp_path / "config1.yml"
    preexisting.write_bytes(payload)

    with patch("requests.get", mock_responses):
        fetch_from_github("examples/", tmp_path)

        # NOTE(review): the mocked download body is the same bytes, so this holds
        # whether or not the file is fetched again — confirm the skip path is hit.
        assert preexisting.read_bytes() == payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_from_github_invalid_prefix(mock_responses):
    """A prefix with no entries in the mocked tree raises ``click.ClickException``."""
    with patch("requests.get", mock_responses), pytest.raises(click.ClickException):
        fetch_from_github("nonexistent/", None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_from_github_network_error():
    """A ``requests.RequestException`` from ``requests.get`` reaches the caller."""
    failing_get = patch("requests.get", side_effect=requests.RequestException)
    with failing_get, pytest.raises(requests.RequestException):
        fetch_from_github("examples/", None)
|
||||||
144
tests/conftest.py
Normal file
144
tests/conftest.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""
|
||||||
|
shared pytest fixtures
|
||||||
|
"""
|
||||||
|
import functools
|
||||||
|
import importlib
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
|
||||||
|
def retry_on_request_exceptions(max_retries=3, delay=1):
    """Build a decorator that retries a call on transient ``requests`` errors.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as 1,
            so the wrapped function is always called at least once (previously
            such values skipped the call entirely and returned ``None``).
        delay: Seconds to sleep between a failed attempt and the next one.

    Returns:
        A decorator. The decorated function returns whatever the wrapped
        function returns. ``requests.exceptions.ReadTimeout`` and
        ``requests.exceptions.ConnectionError`` are retried; on the final
        attempt they are re-raised. Any other exception propagates immediately.
    """
    # pylint: disable=duplicate-code
    attempts = max(1, max_retries)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
            for attempt in range(1, attempts + 1):
                try:
                    return func(*args, **kwargs)
                except (
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.ConnectionError,
                ):
                    # Out of attempts: surface the last error to the caller.
                    if attempt == attempts:
                        raise
                    time.sleep(delay)

        return wrapper

    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@retry_on_request_exceptions(max_retries=3, delay=5)
def snapshot_download_w_retry(*args, **kwargs):
    """Call ``snapshot_download`` with the given arguments, retried up to 3 times."""
    snapshot = snapshot_download(*args, **kwargs)
    return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
    """Fetch the ``HuggingFaceTB/SmolLM2-135M`` model snapshot once per session."""
    repo_id = "HuggingFaceTB/SmolLM2-135M"
    snapshot_download_w_retry(repo_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_llama_68m_random_model():
    """Fetch the ``JackFram/llama-68m`` model snapshot once per session."""
    repo_id = "JackFram/llama-68m"
    snapshot_download_w_retry(repo_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_qwen_2_5_half_billion_model():
    """Fetch the ``Qwen/Qwen2.5-0.5B`` model snapshot once per session."""
    repo_id = "Qwen/Qwen2.5-0.5B"
    snapshot_download_w_retry(repo_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_tatsu_lab_alpaca_dataset():
    """Fetch the ``tatsu-lab/alpaca`` dataset snapshot once per session."""
    repo_id = "tatsu-lab/alpaca"
    snapshot_download_w_retry(repo_id, repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_dataset():
    """Fetch the ``mhenrichsen/alpaca_2k_test`` dataset snapshot once per session."""
    repo_id = "mhenrichsen/alpaca_2k_test"
    snapshot_download_w_retry(repo_id, repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_mhenrichsen_alpaca_2k_w_revision_dataset():
    """Fetch ``mhenrichsen/alpaca_2k_test`` at revision ``d05c1cb`` once per session."""
    repo_id = "mhenrichsen/alpaca_2k_test"
    pinned_revision = "d05c1cb"
    snapshot_download_w_retry(repo_id, repo_type="dataset", revision=pinned_revision)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_mlabonne_finetome_100k_dataset():
    """Fetch the ``mlabonne/FineTome-100k`` dataset snapshot once per session."""
    repo_id = "mlabonne/FineTome-100k"
    snapshot_download_w_retry(repo_id, repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_argilla_distilabel_capybara_dpo_7k_binarized_dataset():
    """Fetch ``argilla/distilabel-capybara-dpo-7k-binarized`` once per session."""
    repo_id = "argilla/distilabel-capybara-dpo-7k-binarized"
    snapshot_download_w_retry(repo_id, repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
    """Fetch ``argilla/ultrafeedback-binarized-preferences-cleaned`` once per session."""
    repo_id = "argilla/ultrafeedback-binarized-preferences-cleaned"
    snapshot_download_w_retry(repo_id, repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
    """Fetch ``arcee-ai/distilabel-intel-orca-dpo-pairs-binarized`` once per session."""
    repo_id = "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized"
    snapshot_download_w_retry(repo_id, repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def temp_dir():
    """Yield the path of a new temporary directory and delete it afterwards."""
    scratch_dir = tempfile.mkdtemp()
    yield scratch_dir
    # Teardown: remove the directory tree created above.
    shutil.rmtree(scratch_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
    """Undo known monkeypatches after every test.

    Before the test, remembers ``LlamaFlashAttention2.forward``. After the
    test, restores that attribute, then re-imports and reloads each module in
    ``modules_to_reset`` and pops the listed names from this module's
    ``globals()``.
    """
    from transformers.models.llama.modeling_llama import LlamaFlashAttention2

    original_fa2_forward = LlamaFlashAttention2.forward
    # monkey patches can happen inside the tests
    yield
    # Reset LlamaFlashAttention2 forward
    LlamaFlashAttention2.forward = original_fa2_forward

    # Reset other known monkeypatches. Every entry is a (module name, names to
    # drop from globals()) pair so the data matches the declared type; modules
    # with nothing to drop carry an empty list.
    modules_to_reset: list[tuple[str, list[str]]] = [
        ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]),
        ("transformers.trainer", []),
        ("transformers.loss.loss_utils", []),
    ]
    for module_name, stale_globals in modules_to_reset:
        module = importlib.import_module(module_name)
        # Re-register the module before reloading it in place.
        sys.modules[module_name] = module
        importlib.reload(sys.modules[module_name])
        for stale_global in stale_globals:
            globals().pop(stale_global, None)
|
||||||
32
tests/constants.py
Normal file
32
tests/constants.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# constants.py
|
||||||
|
"""
|
||||||
|
This module contains constants and configuration dictionaries used for
|
||||||
|
datasets and other utilities in the Axolotl project, specifically for testing.
|
||||||
|
"""
|
||||||
|
# Dataset config for the Alpaca messages DPO test set (llama3 chat template).
ALPACA_MESSAGES_CONFIG_OG = {
    "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
    "type": "chat_template.default",
    "chat_template": "llama3",
    "field_messages": "conversation",
    "field_chosen": "chosen",
    "field_rejected": "rejected",
    "message_field_role": "role",
    "message_field_content": "content",
    "roles": {
        "system": ["system"],
        "user": ["user"],
        "assistant": ["assistant"],
    },
}

# Same config plus a pinned "revision"; this is a shallow copy, so the nested
# "roles" dict is shared with the original.
ALPACA_MESSAGES_CONFIG_REVISION = {
    **ALPACA_MESSAGES_CONFIG_OG,
    "revision": "ea82cff",
}
|
||||||
|
|
||||||
|
|
||||||
|
# Special-token strings (bos / eos / unk) shared by tests.
SPECIAL_TOKENS = {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>"}
|
||||||
@@ -14,9 +14,7 @@ from axolotl.utils.models import load_model, load_tokenizer
|
|||||||
def fixture_cfg():
|
def fixture_cfg():
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"model_type": "AutoModelForCausalLM",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"learning_rate": 0.00005,
|
"learning_rate": 0.00005,
|
||||||
@@ -33,6 +31,9 @@ def fixture_cfg():
|
|||||||
"dataloader_num_workers": 1,
|
"dataloader_num_workers": 1,
|
||||||
"dataloader_pin_memory": True,
|
"dataloader_pin_memory": True,
|
||||||
"model_config_type": "llama",
|
"model_config_type": "llama",
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,35 +0,0 @@
|
|||||||
"""
|
|
||||||
shared pytest fixtures
|
|
||||||
"""
|
|
||||||
import shutil
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_smollm2_135m_model():
|
|
||||||
# download the model
|
|
||||||
snapshot_download("HuggingFaceTB/SmolLM2-135M")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_tatsu_lab_alpaca_dataset():
|
|
||||||
# download the model
|
|
||||||
snapshot_download("tatsu-lab/alpaca", repo_type="dataset")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
|
||||||
def download_mhenrichsen_alpaca_2k_dataset():
|
|
||||||
# download the model
|
|
||||||
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def temp_dir():
|
|
||||||
# Create a temporary directory
|
|
||||||
_temp_dir = tempfile.mkdtemp()
|
|
||||||
yield _temp_dir
|
|
||||||
# Clean up the directory after the test
|
|
||||||
shutil.rmtree(_temp_dir)
|
|
||||||
@@ -7,7 +7,7 @@ from pathlib import Path
|
|||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import with_temp_dir
|
from ..utils import with_temp_dir
|
||||||
@@ -54,8 +54,10 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
|
"max_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
@@ -99,8 +101,10 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
|||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
"bf16": "auto",
|
"bf16": "auto",
|
||||||
|
"max_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
prepare_plugins(cfg)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|||||||
94
tests/e2e/integrations/test_cut_cross_entropy.py
Normal file
94
tests/e2e/integrations/test_cut_cross_entropy.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
"""
|
||||||
|
Simple end-to-end test for Cut Cross Entropy integration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils import get_pytorch_version
|
||||||
|
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def min_cfg(temp_dir):
|
||||||
|
return {
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"plugins": [
|
||||||
|
"axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin",
|
||||||
|
],
|
||||||
|
"cut_cross_entropy": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"max_steps": 10,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCutCrossEntropyIntegration:
|
||||||
|
"""
|
||||||
|
e2e tests for cut_cross_entropy integration with Axolotl
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
def test_llama_w_cce(self, min_cfg, temp_dir):
|
||||||
|
cfg = DictDefault(min_cfg)
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
major, minor, _ = get_pytorch_version()
|
||||||
|
if (major, minor) < (2, 4):
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
else:
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"attention_type",
|
||||||
|
["flash_attention", "sdp_attention", "xformers_attention"],
|
||||||
|
)
|
||||||
|
def test_llama_w_cce_and_attention(self, min_cfg, temp_dir, attention_type):
|
||||||
|
cfg = DictDefault(
|
||||||
|
min_cfg
|
||||||
|
| {
|
||||||
|
attention_type: True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
prepare_plugins(cfg)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
major, minor, _ = get_pytorch_version()
|
||||||
|
if (major, minor) < (2, 4):
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
else:
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
@@ -11,6 +11,8 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from ..utils import check_tensorboard
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
@@ -26,7 +28,7 @@ class TestMultiGPUEval:
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"load_in_8bit": False,
|
"load_in_8bit": False,
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"strict": False,
|
"strict": False,
|
||||||
@@ -40,8 +42,8 @@ class TestMultiGPUEval:
|
|||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.004,
|
||||||
"special_tokens": {"pad_token": "<|end_of_text|>"},
|
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "teknium/GPT4-LLM-Cleaned",
|
"path": "teknium/GPT4-LLM-Cleaned",
|
||||||
@@ -66,6 +68,7 @@ class TestMultiGPUEval:
|
|||||||
"saves_per_epoch": 1,
|
"saves_per_epoch": 1,
|
||||||
"logging_steps": 1,
|
"logging_steps": 1,
|
||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -88,11 +91,13 @@ class TestMultiGPUEval:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
check_tensorboard(temp_dir + "/runs", "eval/loss", 2.5, "Eval Loss is too high")
|
||||||
|
|
||||||
def test_eval(self, temp_dir):
|
def test_eval(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"load_in_8bit": False,
|
"load_in_8bit": False,
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"strict": False,
|
"strict": False,
|
||||||
@@ -106,8 +111,8 @@ class TestMultiGPUEval:
|
|||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.0004,
|
||||||
"special_tokens": {"pad_token": "<|end_of_text|>"},
|
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "teknium/GPT4-LLM-Cleaned",
|
"path": "teknium/GPT4-LLM-Cleaned",
|
||||||
@@ -132,6 +137,7 @@ class TestMultiGPUEval:
|
|||||||
"saves_per_epoch": 1,
|
"saves_per_epoch": 1,
|
||||||
"logging_steps": 1,
|
"logging_steps": 1,
|
||||||
"weight_decay": 0.0,
|
"weight_decay": 0.0,
|
||||||
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -153,3 +159,5 @@ class TestMultiGPUEval:
|
|||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
check_tensorboard(temp_dir + "/runs", "eval/loss", 2.9, "Eval Loss is too high")
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ from transformers.testing_utils import get_torch_dist_unique_port
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import is_hopper
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
@@ -144,7 +142,6 @@ class TestMultiGPULlama:
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
|
||||||
def test_dpo_lora_ddp(self, temp_dir):
|
def test_dpo_lora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.02,
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
@@ -86,7 +86,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.02,
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
|||||||
47
tests/e2e/patched/test_cli_integrations.py
Normal file
47
tests/e2e/patched/test_cli_integrations.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""
|
||||||
|
test cases to make sure the plugin args are loaded from the config file
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from axolotl.cli import load_cfg
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
class TestPluginArgs:
|
||||||
|
"""
|
||||||
|
test class for plugin args loaded from the config file
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_liger_plugin_args(self, temp_dir):
|
||||||
|
test_cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"learning_rate": 0.000001,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
|
||||||
|
"liger_layer_norm": True,
|
||||||
|
"liger_rope": True,
|
||||||
|
"liger_rms_norm": False,
|
||||||
|
"liger_glu_activation": True,
|
||||||
|
"liger_fused_linear_cross_entropy": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(test_cfg.to_dict()))
|
||||||
|
cfg = load_cfg(str(Path(temp_dir) / "config.yaml"))
|
||||||
|
assert cfg.liger_layer_norm is True
|
||||||
|
assert cfg.liger_rope is True
|
||||||
|
assert cfg.liger_rms_norm is False
|
||||||
|
assert cfg.liger_glu_activation is True
|
||||||
|
assert cfg.liger_fused_linear_cross_entropy is True
|
||||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from importlib import reload
|
from importlib import reload
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -17,7 +16,7 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import with_temp_dir
|
from ..utils import check_tensorboard
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -31,49 +30,55 @@ def reload_transformers():
|
|||||||
reload(transformers.models.llama.modeling_llama)
|
reload(transformers.models.llama.modeling_llama)
|
||||||
|
|
||||||
|
|
||||||
class TestFAXentropyLlama(unittest.TestCase):
|
class TestFAXentropyLlama:
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using LoRA w multipack
|
Test case for Llama models using LoRA w multipack
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_lora_packing_fa_cross_entropy(self, temp_dir):
|
"gradient_accumulation_steps",
|
||||||
|
[1, 4],
|
||||||
|
)
|
||||||
|
def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"flash_attn_cross_entropy": True,
|
"flash_attn_cross_entropy": True,
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 32,
|
"lora_r": 8,
|
||||||
"lora_alpha": 64,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.2,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
|
"chat_template": "chatml",
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "mlabonne/FineTome-100k",
|
||||||
"type": "alpaca",
|
"field_messages": "conversations",
|
||||||
|
"message_field_content": "value",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"type": "chat_template",
|
||||||
|
"split": "train[:2%]",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 10,
|
"max_steps": 5,
|
||||||
"save_steps": 10,
|
"save_steps": 5,
|
||||||
"micro_batch_size": 8,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_gpu_available():
|
||||||
@@ -87,3 +92,7 @@ class TestFAXentropyLlama(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
|
||||||
|
)
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
"lora_dropout": 0.1,
|
"lora_dropout": 0.1,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"lora_modules_to_save": ["word_embeddings", "lm_head"],
|
"lora_modules_to_save": ["word_embeddings", "lm_head"],
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"bos_token": "<|endoftext|>",
|
"bos_token": "<|endoftext|>",
|
||||||
"pad_token": "<|endoftext|>",
|
"pad_token": "<|endoftext|>",
|
||||||
@@ -80,7 +80,7 @@ class TestFalconPatched(unittest.TestCase):
|
|||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"bos_token": "<|endoftext|>",
|
"bos_token": "<|endoftext|>",
|
||||||
"pad_token": "<|endoftext|>",
|
"pad_token": "<|endoftext|>",
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class TestFusedLlama(unittest.TestCase):
|
|||||||
"flash_attn_fuse_mlp": True,
|
"flash_attn_fuse_mlp": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.02,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"unk_token": "<unk>",
|
||||||
"bos_token": "<s>",
|
"bos_token": "<s>",
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"lora_alpha": 64,
|
"lora_alpha": 64,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.02,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"unk_token": "<unk>",
|
||||||
"bos_token": "<s>",
|
"bos_token": "<s>",
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class TestMistral(unittest.TestCase):
|
|||||||
"lora_alpha": 64,
|
"lora_alpha": 64,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"unk_token": "<unk>",
|
||||||
"bos_token": "<s>",
|
"bos_token": "<s>",
|
||||||
@@ -80,7 +80,7 @@ class TestMistral(unittest.TestCase):
|
|||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"unk_token": "<unk>",
|
||||||
"bos_token": "<s>",
|
"bos_token": "<s>",
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"lora_alpha": 32,
|
"lora_alpha": 32,
|
||||||
"lora_dropout": 0.1,
|
"lora_dropout": 0.1,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {},
|
"special_tokens": {},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -78,7 +78,7 @@ class TestMixtral(unittest.TestCase):
|
|||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {},
|
"special_tokens": {},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class TestPhiMultipack(unittest.TestCase):
|
|||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"load_in_8bit": False,
|
"load_in_8bit": False,
|
||||||
"adapter": None,
|
"adapter": None,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|endoftext|>",
|
"pad_token": "<|endoftext|>",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
@@ -17,35 +16,35 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import most_recent_subdir, with_temp_dir
|
from ..utils import most_recent_subdir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
class TestResumeLlama(unittest.TestCase):
|
class TestResumeLlama:
|
||||||
"""
|
"""
|
||||||
Test case for resuming training of llama models
|
Test case for resuming training of llama models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
def test_resume_lora_packed(self, temp_dir):
|
||||||
def test_resume_qlora_packed(self, temp_dir):
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
"load_in_4bit": True,
|
"load_in_8bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "lora",
|
||||||
"lora_r": 32,
|
"lora_r": 8,
|
||||||
"lora_alpha": 64,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.001,
|
||||||
"special_tokens": {},
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "vicgalle/alpaca-gpt4",
|
"path": "vicgalle/alpaca-gpt4",
|
||||||
@@ -57,11 +56,11 @@ class TestResumeLlama(unittest.TestCase):
|
|||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
"save_steps": 10,
|
"save_steps": 3,
|
||||||
"save_total_limit": 5,
|
"save_total_limit": 5,
|
||||||
"max_steps": 40,
|
"max_steps": 15,
|
||||||
"use_tensorboard": True,
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -77,7 +76,7 @@ class TestResumeLlama(unittest.TestCase):
|
|||||||
|
|
||||||
resume_cfg = cfg | DictDefault(
|
resume_cfg = cfg | DictDefault(
|
||||||
{
|
{
|
||||||
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
|
"resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(resume_cfg)
|
normalize_config(resume_cfg)
|
||||||
@@ -93,4 +92,4 @@ class TestResumeLlama(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
pattern = r"first_step\s+(\d+)"
|
pattern = r"first_step\s+(\d+)"
|
||||||
first_steps = int(re.findall(pattern, res.stdout)[0])
|
first_steps = int(re.findall(pattern, res.stdout)[0])
|
||||||
assert first_steps == 31
|
assert first_steps == 10
|
||||||
|
|||||||
186
tests/e2e/patched/test_unsloth_qlora.py
Normal file
186
tests/e2e/patched/test_unsloth_qlora.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
"""
|
||||||
|
e2e tests for unsloth qlora
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from ..utils import check_tensorboard
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
class TestUnslothQLoRA:
|
||||||
|
"""
|
||||||
|
Test class for Unsloth QLoRA Llama models
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sample_packing",
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_unsloth_llama_qlora_fa2(self, temp_dir, sample_packing):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": sample_packing,
|
||||||
|
"flash_attention": True,
|
||||||
|
"unsloth_lora_mlp": True,
|
||||||
|
"unsloth_lora_qkv": True,
|
||||||
|
"unsloth_lora_o": True,
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"adapter": "qlora",
|
||||||
|
"lora_r": 16,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.05,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_steps": 10,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_unsloth_llama_qlora_unpacked(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"unsloth_lora_mlp": True,
|
||||||
|
"unsloth_lora_qkv": True,
|
||||||
|
"unsloth_lora_o": True,
|
||||||
|
"sample_packing": False,
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"adapter": "qlora",
|
||||||
|
"lora_r": 16,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.05,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_steps": 10,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sdp_attention",
|
||||||
|
[True, False],
|
||||||
|
)
|
||||||
|
def test_unsloth_llama_qlora_unpacked_no_fa2_fp16(self, temp_dir, sdp_attention):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"unsloth_lora_mlp": True,
|
||||||
|
"unsloth_lora_qkv": True,
|
||||||
|
"unsloth_lora_o": True,
|
||||||
|
"sample_packing": False,
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"adapter": "qlora",
|
||||||
|
"lora_r": 16,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.05,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_steps": 10,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"sdp_attention": sdp_attention,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
"fp16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||||
|
)
|
||||||
113
tests/e2e/test_embeddings_lr.py
Normal file
113
tests/e2e/test_embeddings_lr.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for llama pretrain
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingsLrScale(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for embedding_lr*
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_train_w_embedding_lr_scale(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"max_steps": 5,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"embedding_lr_scale": 0.5,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
|
||||||
|
)
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_train_w_embedding_lr(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"sample_packing": True,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"max_steps": 5,
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"embedding_lr": 0.000005,
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"use_tensorboard": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
|
||||||
|
)
|
||||||
116
tests/e2e/test_llama_vision.py
Normal file
116
tests/e2e/test_llama_vision.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for lora llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLlamaVision(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama Vision models
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_lora_llama_vision_text_only_dataset(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
|
||||||
|
"processor_type": "AutoProcessor",
|
||||||
|
"skip_prepare_dataset": True,
|
||||||
|
"remove_unused_columns": False,
|
||||||
|
"sample_packing": False,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
|
||||||
|
"val_set_size": 0,
|
||||||
|
"chat_template": "llama3_2_vision",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "LDJnr/Puffin",
|
||||||
|
"type": "chat_template",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
|
||||||
|
"processor_type": "AutoProcessor",
|
||||||
|
"skip_prepare_dataset": True,
|
||||||
|
"remove_unused_columns": False,
|
||||||
|
"sample_packing": False,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_modules": r"language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj",
|
||||||
|
"val_set_size": 0,
|
||||||
|
"chat_template": "llama3_2_vision",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "axolotl-ai-co/llava-instruct-mix-vsft-small",
|
||||||
|
"type": "chat_template",
|
||||||
|
"split": "train",
|
||||||
|
"field_messages": "messages",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 4,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 5,
|
||||||
|
"save_safetensors": True,
|
||||||
|
"bf16": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||||
@@ -57,6 +57,7 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "optimi_adamw",
|
"optimizer": "optimi_adamw",
|
||||||
|
"max_steps": 5,
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -94,6 +95,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
"micro_batch_size": 8,
|
"micro_batch_size": 8,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
@@ -115,7 +117,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
{
|
{
|
||||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"val_set_size": 0.1,
|
"val_set_size": 0.01,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"pad_token": "<|endoftext|>",
|
"pad_token": "<|endoftext|>",
|
||||||
},
|
},
|
||||||
@@ -126,13 +128,14 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "schedule_free_adamw",
|
"optimizer": "schedule_free_adamw",
|
||||||
"lr_scheduler": "constant",
|
"lr_scheduler": "constant",
|
||||||
"save_safetensors": True,
|
"save_safetensors": True,
|
||||||
|
"max_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from tbparse import SummaryReader
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.cli import load_datasets
|
from axolotl.cli import load_datasets
|
||||||
@@ -15,7 +14,7 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import most_recent_subdir, with_temp_dir
|
from .utils import check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -66,9 +65,6 @@ class TestPackedLlama(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
|
||||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
check_tensorboard(
|
||||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||||
reader = SummaryReader(event_file)
|
)
|
||||||
df = reader.scalars # pylint: disable=invalid-name
|
|
||||||
df = df[(df.tag == "train/train_loss")] # pylint: disable=invalid-name
|
|
||||||
assert df.value.values[-1] < 2.0, "Loss is too high"
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
from .utils import check_tensorboard, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -29,35 +29,48 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "JackFram/llama-68m",
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
"sequence_len": 2048,
|
||||||
"sequence_len": 1024,
|
"sample_packing": True,
|
||||||
|
"pad_to_sequence_len": True,
|
||||||
|
"flash_attention": True,
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 32,
|
"lora_r": 8,
|
||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_modules": ["q_proj", "v_proj"],
|
"lora_target_modules": ["q_proj", "v_proj"],
|
||||||
"relora_steps": 25,
|
"relora_steps": 100,
|
||||||
"relora_warmup_steps": 5,
|
"relora_warmup_steps": 20,
|
||||||
"relora_anneal_steps": 5,
|
"relora_anneal_steps": 10,
|
||||||
|
"relora_prune_ratio": 0.9,
|
||||||
"relora_cpu_offload": True,
|
"relora_cpu_offload": True,
|
||||||
"val_set_size": 0.0,
|
"val_set_size": 0.0,
|
||||||
"special_tokens": {},
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"chat_template": "chatml",
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "mlabonne/FineTome-100k",
|
||||||
"type": "alpaca",
|
"type": "chat_template",
|
||||||
|
"split": "train[:10%]",
|
||||||
|
"field_messages": "conversations",
|
||||||
|
"message_field_role": "from",
|
||||||
|
"message_field_content": "value",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"warmup_steps": 15,
|
"warmup_steps": 20,
|
||||||
"num_epochs": 2,
|
"num_epochs": 2,
|
||||||
"micro_batch_size": 4,
|
"max_steps": 205, # at least 2x relora_steps
|
||||||
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
"save_safetensors": True,
|
||||||
|
"use_tensorboard": True,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
@@ -65,4 +78,11 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
assert (
|
||||||
|
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
|
||||||
|
).exists()
|
||||||
|
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
|
||||||
|
|
||||||
|
check_tensorboard(
|
||||||
|
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
|
||||||
|
)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import torch
|
|||||||
|
|
||||||
# from importlib.metadata import version
|
# from importlib.metadata import version
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from tbparse import SummaryReader
|
||||||
|
|
||||||
|
|
||||||
def with_temp_dir(test_func):
|
def with_temp_dir(test_func):
|
||||||
@@ -53,7 +54,7 @@ def require_torch_2_3_1(test_case):
|
|||||||
|
|
||||||
def require_torch_2_5_1(test_case):
|
def require_torch_2_5_1(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires torch >= 2.3.1
|
Decorator marking a test that requires torch >= 2.5.1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def is_min_2_5_1():
|
def is_min_2_5_1():
|
||||||
@@ -66,3 +67,17 @@ def require_torch_2_5_1(test_case):
|
|||||||
def is_hopper():
|
def is_hopper():
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
return compute_capability == (9, 0)
|
return compute_capability == (9, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def check_tensorboard(
|
||||||
|
temp_run_dir: str, tag: str, lt_val: float, assertion_err: str
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
helper function to parse and check tensorboard logs
|
||||||
|
"""
|
||||||
|
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||||
|
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||||
|
reader = SummaryReader(event_file)
|
||||||
|
df = reader.scalars # pylint: disable=invalid-name
|
||||||
|
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||||
|
assert df.value.values[-1] < lt_val, assertion_err
|
||||||
|
|||||||
25
tests/patched/test_llama_trainer_ga.py
Normal file
25
tests/patched/test_llama_trainer_ga.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
""""Test module for checking whether the Hugging Face Transformers is working as expected."""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.trainer_grad_accum import (
|
||||||
|
check_forward_is_patchable,
|
||||||
|
check_training_step_is_patchable,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainerGAIntegration(unittest.TestCase):
|
||||||
|
"""llama monkeypatch integration tests."""
|
||||||
|
|
||||||
|
def test_train_step_patchable(self):
|
||||||
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
|
self.assertTrue(
|
||||||
|
check_training_step_is_patchable(),
|
||||||
|
"HF transformers Trainer.training_step has changed and isn't patchable",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_model_forward_patchable(self):
|
||||||
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
|
self.assertTrue(
|
||||||
|
check_forward_is_patchable(),
|
||||||
|
"HF transformers LlamaForCausalLM.forward has changed and isn't patchable",
|
||||||
|
)
|
||||||
@@ -672,6 +672,9 @@ class TestValidation(BaseValidation):
|
|||||||
{
|
{
|
||||||
"bf16": True,
|
"bf16": True,
|
||||||
"capabilities": {"bf16": False},
|
"capabilities": {"bf16": False},
|
||||||
|
"env_capabilities": {
|
||||||
|
"torch_version": "2.5.1",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -1160,6 +1163,38 @@ class TestValidation(BaseValidation):
|
|||||||
in self._caplog.records[0].message
|
in self._caplog.records[0].message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_torch_version_adopt_req(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"optimizer": "adopt_adamw",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*ADOPT optimizer is incompatible with torch version*",
|
||||||
|
):
|
||||||
|
env_capabilities = {"torch_version": "2.3.0"}
|
||||||
|
capabilities = {"bf16": False}
|
||||||
|
_ = validate_config(
|
||||||
|
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||||
|
)
|
||||||
|
|
||||||
|
env_capabilities = {"torch_version": "2.5.1"}
|
||||||
|
capabilities = {"bf16": False}
|
||||||
|
_ = validate_config(
|
||||||
|
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||||
|
)
|
||||||
|
|
||||||
|
env_capabilities = {"torch_version": "2.5.2"}
|
||||||
|
capabilities = {"bf16": False}
|
||||||
|
_ = validate_config(
|
||||||
|
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestValidationCheckModelConfig(BaseValidation):
|
class TestValidationCheckModelConfig(BaseValidation):
|
||||||
"""
|
"""
|
||||||
@@ -4,6 +4,7 @@ shared fixtures for prompt strategies tests
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -60,6 +61,17 @@ def fixture_basic_dataset():
|
|||||||
|
|
||||||
@pytest.fixture(name="llama3_tokenizer")
|
@pytest.fixture(name="llama3_tokenizer")
|
||||||
def fixture_llama3_tokenizer():
|
def fixture_llama3_tokenizer():
|
||||||
|
hf_hub_download(
|
||||||
|
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
|
||||||
|
filename="special_tokens_map.json",
|
||||||
|
)
|
||||||
|
hf_hub_download(
|
||||||
|
repo_id="NousResearch/Meta-Llama-3-8B-Instruct",
|
||||||
|
filename="tokenizer_config.json",
|
||||||
|
)
|
||||||
|
hf_hub_download(
|
||||||
|
repo_id="NousResearch/Meta-Llama-3-8B-Instruct", filename="tokenizer.json"
|
||||||
|
)
|
||||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||||
|
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
test module for the axolotl.utis.data module
|
test module for the axolotl.utils.data module
|
||||||
"""
|
"""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user