Compare commits: fa-261...remove-gpt
64 commits
| Author | SHA1 | Date |
|---|---|---|
|  | 5b15816cf4 |  |
|  | fefa95e350 |  |
|  | b33dc07a77 |  |
|  | dcbff16983 |  |
|  | 2f8037fee6 |  |
|  | de4ea2d1f2 |  |
|  | 7ed92e61c2 |  |
|  | 9caa3eb699 |  |
|  | 5b0b774e38 |  |
|  | c3fc529bfc |  |
|  | 957c956f89 |  |
|  | f07802f9fa |  |
|  | 9f917245f6 |  |
|  | 649c19aba3 |  |
|  | 5aac4bc284 |  |
|  | e29931259b |  |
|  | b1d2921222 |  |
|  | 803fed3e90 |  |
|  | 68a3c7678a |  |
|  | f18925fb4b |  |
|  | 1853d6021d |  |
|  | 0801f239cc |  |
|  | 54392ac8a6 |  |
|  | 3e2b269d06 |  |
|  | 5ee4b7325f |  |
|  | 70978467a0 |  |
|  | 850f999a76 |  |
|  | c56e0a79a5 |  |
|  | 35d5e59d78 |  |
|  | fbbeb4fee0 |  |
|  | ecdda006de |  |
|  | b7665c26c8 |  |
|  | cb023c70db |  |
|  | 7402eb9dcb |  |
|  | 203816f7b4 |  |
|  | 78b42a3fe1 |  |
|  | 3ebf22464b |  |
|  | dbf8fb549e |  |
|  | 9a63884597 |  |
|  | c5587b45ac |  |
|  | d4f6a6b103 |  |
|  | d8d1788ffc |  |
|  | 3bc8e64557 |  |
|  | 55cc214c76 |  |
|  | 94ba93259f |  |
|  | 22680913f3 |  |
|  | 6a9cfec222 |  |
|  | fe250ada78 |  |
|  | e6b299dd79 |  |
|  | 608a2f3180 |  |
|  | 87455e7f32 |  |
|  | 985819d89b |  |
|  | fa91b698e9 |  |
|  | e4063d60a7 |  |
|  | 7830fe04b5 |  |
|  | c86c32a627 |  |
|  | 8731b95d04 |  |
|  | 8619b2d855 |  |
|  | 976f85195a |  |
|  | 152ab76623 |  |
|  | 5f58555bd0 |  |
|  | cfc533a7f7 |  |
|  | e1725aef2b |  |
|  | 78e12f8ca5 |  |
.github/workflows/base.yml (37 changes)

@@ -12,36 +12,24 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: "118"
-            cuda_version: 11.8.0
+          - cuda: "121"
+            cuda_version: 12.1.1
+            cudnn_version: 8
             python_version: "3.10"
-            pytorch: 2.1.2
+            pytorch: 2.3.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           - cuda: "121"
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.2
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          - cuda: "121"
-            cuda_version: 12.1.0
-            python_version: "3.11"
-            pytorch: 2.1.2
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          - cuda: "121"
-            cuda_version: 12.1.0
-            python_version: "3.11"
-            pytorch: 2.2.2
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          - cuda: "121"
-            cuda_version: 12.1.0
-            python_version: "3.11"
-            pytorch: 2.3.0
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
-          - cuda: "121"
-            cuda_version: 12.1.0
+            cuda_version: 12.1.1
+            cudnn_version: 8
             python_version: "3.11"
             pytorch: 2.3.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+          - cuda: "124"
+            cuda_version: 12.4.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.4.0
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
         uses: actions/checkout@v3
@@ -67,6 +55,7 @@ jobs:
           labels: ${{ steps.metadata.outputs.labels }}
           build-args: |
             CUDA_VERSION=${{ matrix.cuda_version }}
+            CUDNN_VERSION=${{ matrix.cudnn_version }}
             CUDA=${{ matrix.cuda }}
             PYTHON_VERSION=${{ matrix.python_version }}
             PYTORCH_VERSION=${{ matrix.pytorch }}
.github/workflows/main.yml (54 changes)

@@ -13,28 +13,22 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 118
-            cuda_version: 11.8.0
+          - cuda: 121
+            cuda_version: 12.1.1
             python_version: "3.10"
-            pytorch: 2.1.2
-            axolotl_extras:
-            axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
+            pytorch: 2.3.1
+            axolotl_extras: mamba-ssm
           - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.2
-            axolotl_extras:
-          - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.11"
-            pytorch: 2.2.2
-            axolotl_extras:
-          - cuda: 121
-            cuda_version: 12.1.0
+            cuda_version: 12.1.1
             python_version: "3.11"
             pytorch: 2.3.1
-            axolotl_extras:
+            axolotl_extras: mamba-ssm
             is_latest: true
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.4.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -65,6 +59,7 @@ jobs:
           push: ${{ github.event_name != 'pull_request' }}
           tags: |
             ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
             ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
           labels: ${{ steps.metadata.outputs.labels }}

@@ -75,27 +70,22 @@ jobs:
     strategy:
       matrix:
         include:
-          - cuda: 118
-            cuda_version: 11.8.0
+          - cuda: 121
+            cuda_version: 12.1.1
             python_version: "3.10"
-            pytorch: 2.1.2
+            pytorch: 2.3.1
             axolotl_extras:
           - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.2
-            axolotl_extras:
-          - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.11"
-            pytorch: 2.2.2
-            axolotl_extras:
-          - cuda: 121
-            cuda_version: 12.1.0
+            cuda_version: 12.1.1
             python_version: "3.11"
             pytorch: 2.3.1
             axolotl_extras:
             is_latest: true
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.4.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -134,7 +124,7 @@ jobs:
       matrix:
         include:
           - cuda: 121
-            cuda_version: 12.1.0
+            cuda_version: 12.1.1
             python_version: "3.11"
             pytorch: 2.3.1
             axolotl_extras:
.github/workflows/multi-gpu-e2e.yml (new file, 52 lines)

@@ -0,0 +1,52 @@
+name: docker-multigpu-tests-biweekly
+
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: '0 0 * * 1,4'  # Runs at 00:00 UTC every monday & thursday
+
+jobs:
+  test-axolotl-multigpu:
+    if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: 121
+            cuda_version: 12.1.1
+            python_version: "3.11"
+            pytorch: 2.3.1
+            axolotl_extras:
+            num_gpus: 2
+          - cuda: 121
+            cuda_version: 12.1.1
+            python_version: "3.11"
+            pytorch: 2.3.1
+            axolotl_extras:
+            num_gpus: 2
+            nightly_build: "true"
+    runs-on: [self-hosted, modal]
+    timeout-minutes: 120
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Install Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: Install Modal
+        run: |
+          python -m pip install --upgrade pip
+          pip install modal==0.63.64 jinja2
+      - name: Update env vars
+        run: |
+          echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
+          echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
+          echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
+          echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
+          echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
+          echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+          echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
+      - name: Run tests job on Modal
+        run: |
+          modal run cicd.multigpu
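Besides the twice-weekly cron, the workflow declares `workflow_dispatch`, so it can be started by hand. A minimal sketch using the GitHub CLI (assumes an authenticated `gh` with access to the repository; the branch name is illustrative):

```bash
# Kick off the multi-GPU e2e workflow on demand; gh accepts the workflow file name as its id.
gh workflow run multi-gpu-e2e.yml --ref main

# Check on the run it just created.
gh run list --workflow=multi-gpu-e2e.yml --limit 1
```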
.github/workflows/nightlies.yml (47 changes)

@@ -12,28 +12,22 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 118
-            cuda_version: 11.8.0
+          - cuda: 121
+            cuda_version: 12.1.1
             python_version: "3.10"
-            pytorch: 2.1.2
-            axolotl_extras:
-            axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
-          - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.2
+            pytorch: 2.3.1
             axolotl_extras:
           - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.11"
-            pytorch: 2.2.2
-            axolotl_extras:
-          - cuda: 121
-            cuda_version: 12.1.0
+            cuda_version: 12.1.1
             python_version: "3.11"
             pytorch: 2.3.1
             axolotl_extras:
             is_latest: true
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.4.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -75,27 +69,22 @@ jobs:
     strategy:
       matrix:
         include:
-          - cuda: 118
-            cuda_version: 11.8.0
+          - cuda: 121
+            cuda_version: 12.1.1
             python_version: "3.10"
-            pytorch: 2.1.2
+            pytorch: 2.3.1
             axolotl_extras:
           - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.2
-            axolotl_extras:
-          - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.11"
-            pytorch: 2.2.2
-            axolotl_extras:
-          - cuda: 121
-            cuda_version: 12.1.0
+            cuda_version: 12.1.1
             python_version: "3.11"
             pytorch: 2.3.1
             axolotl_extras:
             is_latest: true
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.4.0
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
.github/workflows/tests-nightly.yml (new file, 116 lines)

@@ -0,0 +1,116 @@
+name: Tests Nightly against upstream main
+on:
+  workflow_dispatch:
+  schedule:
+    - cron: '0 0 * * *'  # Runs at 00:00 UTC every day
+
+jobs:
+  pre-commit:
+    name: pre-commit
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+          cache: 'pip' # caching pip dependencies
+      - uses: pre-commit/action@v3.0.0
+        env:
+          SKIP: no-commit-to-branch
+
+  pytest:
+    name: PyTest
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python_version: ["3.10", "3.11"]
+    timeout-minutes: 20
+
+    steps:
+      - name: Check out repository code
+        uses: actions/checkout@v3
+
+      - name: Setup Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python_version }}
+          cache: 'pip' # caching pip dependencies
+
+      - name: Update requirements.txt
+        run: |
+          sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
+          sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
+          sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
+          sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt
+
+      - name: Install dependencies
+        run: |
+          pip3 install --upgrade pip
+          pip3 install --upgrade packaging
+          pip3 install -U -e .
+          pip3 install -r requirements-tests.txt
+
+      - name: Run tests
+        run: |
+          pytest --ignore=tests/e2e/ tests/
+
+      - name: cleanup pip cache
+        run: |
+          find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
+
+  docker-e2e-tests:
+    if: github.repository_owner == 'axolotl-ai-cloud'
+    # this job needs to be run on self-hosted GPU runners...
+    runs-on: [self-hosted, modal]
+    timeout-minutes: 60
+    needs: [pre-commit, pytest]
+
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: 121
+            cuda_version: 12.1.1
+            python_version: "3.10"
+            pytorch: 2.3.1
+            num_gpus: 1
+            axolotl_extras: mamba-ssm
+            nightly_build: "true"
+          - cuda: 121
+            cuda_version: 12.1.1
+            python_version: "3.11"
+            pytorch: 2.3.1
+            num_gpus: 1
+            axolotl_extras: mamba-ssm
+            nightly_build: "true"
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.4.0
+            num_gpus: 1
+            axolotl_extras:
+            nightly_build: "true"
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Install Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: Install Modal
+        run: |
+          python -m pip install --upgrade pip
+          pip install modal==0.63.64 jinja2
+      - name: Update env vars
+        run: |
+          echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
+          echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
+          echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
+          echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
+          echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
+          echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
+          echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
+      - name: Run tests job on Modal
+        run: |
+          modal run cicd.tests
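The `Update requirements.txt` step above swaps the pinned Hugging Face releases for their upstream `main` branches before installing. A sketch of what one of those `sed` substitutions does to a pinned line (the pin shown is illustrative):

```bash
# A pinned release, as it might appear in requirements.txt
echo 'transformers==4.44.0' > /tmp/requirements.txt

# The same substitution the workflow runs: replace the whole line with a VCS install from main
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' /tmp/requirements.txt

cat /tmp/requirements.txt   # -> transformers @ git+https://github.com/huggingface/transformers.git@main
```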
.github/workflows/tests.yml (36 changes)

@@ -26,6 +26,8 @@ jobs:
           python-version: "3.10"
           cache: 'pip' # caching pip dependencies
       - uses: pre-commit/action@v3.0.0
+        env:
+          SKIP: no-commit-to-branch

   pytest:
     name: PyTest
@@ -57,6 +59,10 @@ jobs:
         run: |
           pytest --ignore=tests/e2e/ tests/

+      - name: cleanup pip cache
+        run: |
+          find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
+
   docker-e2e-tests:
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
@@ -68,27 +74,24 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: 118
-            cuda_version: 11.8.0
+          - cuda: 121
+            cuda_version: 12.1.1
             python_version: "3.10"
-            pytorch: 2.1.2
-            axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
+            pytorch: 2.3.1
             num_gpus: 1
+            axolotl_extras: mamba-ssm
           - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.10"
-            pytorch: 2.1.2
-            num_gpus: 1
-          - cuda: 121
-            cuda_version: 12.1.0
-            python_version: "3.11"
-            pytorch: 2.2.2
-            num_gpus: 1
-          - cuda: 121
-            cuda_version: 12.1.0
+            cuda_version: 12.1.1
             python_version: "3.11"
             pytorch: 2.3.1
             num_gpus: 1
+            axolotl_extras: mamba-ssm
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.4.0
+            num_gpus: 1
+            axolotl_extras:
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -99,12 +102,13 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal jinja2
+          pip install modal==0.63.64 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
           echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
           echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
+          echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
           echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
@@ -8,6 +8,8 @@ repos:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
+      - id: no-commit-to-branch
+        args: ['--branch', 'main']
   - repo: https://github.com/psf/black
     rev: 23.3.0
     hooks:
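The new `no-commit-to-branch` hook blocks commits to `main`, which is why the test workflows export `SKIP: no-commit-to-branch` when running pre-commit in CI. A sketch of reproducing that locally (assumes pre-commit is installed):

```bash
# Run every hook the way CI does, skipping the branch-protection hook
SKIP=no-commit-to-branch pre-commit run --all-files
```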
README.md (82 changes)

@@ -1,5 +1,9 @@
 # Axolotl

+
+
+
+
 Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.

 Features:
@@ -22,38 +26,49 @@ Features:
 <td>

 ## Table of Contents
-- [Introduction](#axolotl)
-- [Supported Features](#axolotl-supports)
-- [Quickstart](#quickstart-)
-- [Environment](#environment)
-- [Docker](#docker)
-- [Conda/Pip venv](#condapip-venv)
-- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
-- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
-- [Windows](#windows)
-- [Mac](#mac)
-- [Google Colab](#google-colab)
-- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
-- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
-- [Dataset](#dataset)
-- [Config](#config)
-- [Train](#train)
-- [Inference](#inference-playground)
-- [Merge LORA to Base](#merge-lora-to-base)
-- [Special Tokens](#special-tokens)
-- [All Config Options](#all-config-options)
-- Advanced Topics
-- [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
-- [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
-- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
-- [Common Errors](#common-errors-)
-- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
-- [Debugging Axolotl](#debugging-axolotl)
-- [Need Help?](#need-help-)
-- [Badge](#badge-)
-- [Community Showcase](#community-showcase)
-- [Contributing](#contributing-)
-- [Sponsors](#sponsors-)
+- [Axolotl](#axolotl)
+- [Table of Contents](#table-of-contents)
+- [Axolotl supports](#axolotl-supports)
+- [Quickstart ⚡](#quickstart-)
+- [Usage](#usage)
+- [Advanced Setup](#advanced-setup)
+- [Environment](#environment)
+- [Docker](#docker)
+- [Conda/Pip venv](#condapip-venv)
+- [Cloud GPU](#cloud-gpu)
+- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
+- [LambdaLabs](#lambdalabs)
+- [GCP](#gcp)
+- [Windows](#windows)
+- [Mac](#mac)
+- [Google Colab](#google-colab)
+- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
+- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
+- [Dataset](#dataset)
+- [Config](#config)
+- [All Config Options](#all-config-options)
+- [Train](#train)
+- [Preprocess dataset](#preprocess-dataset)
+- [Multi-GPU](#multi-gpu)
+- [DeepSpeed](#deepspeed)
+- [FSDP](#fsdp)
+- [FSDP + QLoRA](#fsdp--qlora)
+- [Weights \& Biases Logging](#weights--biases-logging)
+- [Special Tokens](#special-tokens)
+- [Inference Playground](#inference-playground)
+- [Merge LORA to base](#merge-lora-to-base)
+- [Common Errors 🧰](#common-errors-)
+- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
+- [Debugging Axolotl](#debugging-axolotl)
+- [Need help? 🙋](#need-help-)
+- [Badge ❤🏷️](#badge-️)
+- [Community Showcase](#community-showcase)
+- [Contributing 🤝](#contributing-)
+- [Sponsors 🤝❤](#sponsors-)
+- [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly)
+- [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo)
+- [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo)
+- [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo)

 </td>
 <td>
@@ -95,6 +110,7 @@ Features:
 | RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
 | Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
 | Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
+| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |

 ✅: supported
 ❌: not supported
@@ -333,7 +349,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc

 Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.

-See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
+See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.

 ### Config

@@ -36,6 +36,7 @@ website:
         - docs/nccl.qmd
         - docs/mac.qmd
         - docs/multi-node.qmd
+        - docs/unsloth.qmd
       - section: "Dataset Formats"
         contents: docs/dataset-formats/*
       - section: "Reference"
@@ -8,6 +8,7 @@ ENV BNB_CUDA_VERSION="{{ CUDA }}"
 ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
 ENV GITHUB_REF="{{ GITHUB_REF }}"
 ENV GITHUB_SHA="{{ GITHUB_SHA }}"
+ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"

 RUN apt-get update && \
     apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
@@ -23,10 +24,17 @@ RUN git fetch origin +$GITHUB_REF && \

 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
+RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
+        sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
+        sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
+        sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
+        sed -i 's#^bitsandbytes.*#bitsandbytes @ git+https://github.com/bitsandbytes-foundation/bitsandbytes.git@main#' requirements.txt; \
+    fi
+
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
@@ -2,5 +2,5 @@
 set -e

 pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
-pytest /workspace/axolotl/tests/e2e/patched/
-pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
+pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/
+pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ /workspace/axolotl/tests/e2e/
cicd/multigpu.py (new file, 77 lines)

@@ -0,0 +1,77 @@
+"""
+modal application to run axolotl gpu tests in Modal
+"""
+# pylint: disable=duplicate-code
+
+import os
+import pathlib
+import tempfile
+
+import jinja2
+import modal
+from jinja2 import select_autoescape
+from modal import Image, Stub
+
+cicd_path = pathlib.Path(__file__).parent.resolve()
+
+template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
+template_env = jinja2.Environment(
+    loader=template_loader, autoescape=select_autoescape()
+)
+df_template = template_env.get_template("Dockerfile.jinja")
+
+df_args = {
+    "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
+    "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
+    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
+    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
+    "CUDA": os.environ.get("CUDA", "121"),
+    "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
+    "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
+}
+
+dockerfile_contents = df_template.render(**df_args)
+
+temp_dir = tempfile.mkdtemp()
+with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
+    f.write(dockerfile_contents)
+
+cicd_image = (
+    Image.from_dockerfile(
+        pathlib.Path(temp_dir) / "Dockerfile",
+        force_build=True,
+        gpu="A10G",
+    )
+    .env(df_args)
+    .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
+)
+
+stub = Stub("Axolotl CI/CD", secrets=[])
+
+
+N_GPUS = int(os.environ.get("N_GPUS", 2))
+GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
+
+
+def run_cmd(cmd: str, run_folder: str):
+    import subprocess  # nosec
+
+    # Propagate errors from subprocess.
+    if exit_code := subprocess.call(cmd.split(), cwd=run_folder):  # nosec
+        exit(exit_code)  # pylint: disable=consider-using-sys-exit
+
+
+@stub.function(
+    image=cicd_image,
+    gpu=GPU_CONFIG,
+    timeout=45 * 60,
+    cpu=8.0,
+    memory=131072 * N_GPUS,
+)
+def cicd_pytest():
+    run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
+
+
+@stub.local_entrypoint()
+def main():
+    cicd_pytest.remote()
cicd/multigpu.sh (new executable file, 5 lines)

@@ -0,0 +1,5 @@
+#!/bin/bash
+set -e
+
+# only run one test at a time so as not to OOM the GPU
+pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
@@ -1,6 +1,8 @@
 """
 modal application to run axolotl gpu tests in Modal
 """
+# pylint: disable=duplicate-code
+
 import os
 import pathlib
 import tempfile
@@ -21,11 +23,12 @@ df_template = template_env.get_template("Dockerfile.jinja")
 df_args = {
     "AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
     "AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
-    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
-    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
-    "CUDA": os.environ.get("CUDA", "118"),
+    "PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
+    "BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
+    "CUDA": os.environ.get("CUDA", "121"),
     "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
     "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
+    "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
 }

 dockerfile_contents = df_template.render(**df_args)
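The `df_args` defaults now match the CUDA 12.1 / PyTorch 2.3.1 base image, so the Modal test job can be reproduced outside of Actions by exporting the same variables the workflows write to `$GITHUB_ENV`. A hedged sketch (assumes a configured Modal token; the values shown are the new defaults, not the only valid combinations):

```bash
export BASE_TAG=main-base-py3.11-cu121-2.3.1
export PYTORCH_VERSION=2.3.1
export CUDA=121
export N_GPUS=1
export NIGHTLY_BUILD=true   # optional: also exercise the upstream-main requirements rewrite

# Same entrypoint the tests workflow invokes
modal run cicd.tests
```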
@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
     fi

 # So we can test the Docker image
@@ -3,7 +3,7 @@ ARG CUDNN_VERSION="8"
 ARG UBUNTU_VERSION="22.04"
 ARG MAX_JOBS=4

-FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder
+FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

 ENV PATH="/root/miniconda3/bin:${PATH}"

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG

 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
-ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"

@@ -3,7 +3,6 @@ FROM winglian/axolotl:$BASE_TAG

 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
-ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
 ENV HF_HOME="/workspace/data/huggingface-cache/hub"
 ENV HF_HUB_ENABLE_HF_TRANSFER="1"

@@ -54,6 +54,14 @@ conversations where `from` is `prompter` `assistant` instead of default sharegpt
 {"conversations": [{"from": "...", "value": "..."}]}
 ```

+## sharegpt.load_ultrachat
+
+conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
+
+```{.json filename="data.jsonl"}
+{"messages": [{"user": "...", "assistant": "..."}]}
+```
+
 ## sharegpt_jokes

 creates a chat where bot is asked to tell a joke, then explain why the joke is funny
docs/torchao.qmd (new file, 19 lines)

@@ -0,0 +1,19 @@
+---
+title: "PyTorch ao"
+description: "Custom data types and layouts for training and inference"
+---
+
+### Installation
+
+Stable Release from the PyTorch index
+
+```bash
+pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
+```
+
+
+Nightly release
+
+```bash
+pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
+```
docs/unsloth.qmd (new file, 49 lines)

@@ -0,0 +1,49 @@
+---
+title: "Unsloth"
+description: "Hyper-optimized QLoRA finetuning for single GPUs"
+---
+
+### Overview
+
+Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
+standard industry baselines.
+
+
+### Installation
+
+The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
+to date libraries.
+
+```bash
+pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
+pip install --no-deps --force-reinstall xformers==0.0.26.post1
+```
+
+### Using unsloth w Axolotl
+
+Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
+
+Our unsloth integration is currently limited to the following model architectures:
+- llama
+
+These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
+```yaml
+unsloth_lora_mlp: true
+unsloth_lora_qkv: true
+unsloth_lora_o: true
+```
+
+These options are composable and can be used with multi-gpu finetuning
+```yaml
+unsloth_cross_entropy_loss: true
+unsloth_rms_norm: true
+unsloth_rope: true
+```
+
+### Limitations
+
+- Single GPU only; e.g. no multi-gpu support
+- No deepspeed or FSDP support (requires multi-gpu)
+- LoRA + QLoRA support only. No full fine tunes or fp8 support.
+- Limited model architecture support. Llama, Phi, Gemma, Mistral only
+- No MoE support.
@@ -43,7 +43,6 @@
    },
    "outputs": [],
    "source": [
-    "!pip install torch==\"2.1.2\"\n",
     "!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
     "!pip install flash-attn==\"2.5.0\"\n",
     "!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
@@ -6,5 +6,5 @@
 - ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?)
 - ✅ qlora single-gpu, ~51GiB VRAM
 - ✅ multipack
-- ❓ FSDP
+- ✅ FSDP
 - ❓ 8-bit LoRA
examples/jamba/qlora_fsdp_large.yaml (new file, 61 lines)

@@ -0,0 +1,61 @@
+base_model: ai21labs/AI21-Jamba-1.5-Large
+tokenizer_type: AutoTokenizer
+
+load_in_4bit: true
+strict: false
+use_tensorboard: true
+datasets:
+  - path: cgato/SlimOrcaDedupCleaned
+    type: chat_template
+    chat_template: jamba
+    drop_system_message: true
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: jamba-large-fsdp-qlora-ft
+save_safetensors: true
+adapter: qlora
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 16
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj]
+lora_target_linear: false
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.00001
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: true
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: false
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sharding_strategy: FULL_SHARD
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
 model_type: LlamaForCausalLM
 tokenizer_type: AutoTokenizer

examples/llama-3/instruct-dpo-lora-8b.yml (new file, 81 lines)

@@ -0,0 +1,81 @@
+base_model: meta-llama/Meta-Llama-3-8B-Instruct
+model_type: LlamaForCausalLM
+tokenizer_type: AutoTokenizer
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+chat_template: llama3
+rl: dpo
+datasets:
+  - path: fozziethebeat/alpaca_messages_2k_dpo_test
+    type: chat_template.default
+    chat_template: llama3
+    field_messages: conversation
+    field_chosen: chosen
+    field_rejected: rejected
+    message_field_role: role
+    message_field_content: content
+    roles:
+      system:
+        - system
+      user:
+        - user
+      assistant:
+        - assistant
+
+dataset_prepared_path:
+val_set_size: 0.05
+output_dir: ./outputs/lora-out
+
+sequence_len: 4096
+sample_packing: false
+pad_to_sequence_len: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 4
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: auto
+fp16:
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+s2_attention:
+
+warmup_steps: 10
+evals_per_epoch: 4
+eval_table_size:
+eval_max_new_tokens: 128
+saves_per_epoch: 1
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B-Instruct
+base_model: NousResearch/Meta-Llama-3-8B-Instruct
 model_type: LlamaForCausalLM
 tokenizer_type: AutoTokenizer

@@ -74,3 +74,5 @@ deepspeed:
 weight_decay: 0.0
 fsdp:
 fsdp_config:
+special_tokens:
+  pad_token: <|end_of_text|>
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
 model_type: LlamaForCausalLM
 tokenizer_type: AutoTokenizer

examples/llama-3/qlora-fsdp-405b.yaml (new file, 63 lines)

@@ -0,0 +1,63 @@
+base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
+tokenizer_type: AutoTokenizer
+
+load_in_4bit: true
+strict: false
+
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out/qlora-llama3_1-405b
+save_safetensors: true
+
+adapter: qlora
+
+sequence_len: 2048
+sample_packing: true
+pad_to_sequence_len: true
+
+lora_r: 16
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+lora_target_linear: true
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 2
+optimizer: adamw_torch
+lr_scheduler: cosine
+learning_rate: 0.00001
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+
+gradient_checkpointing: true
+gradient_checkpointing_kwargs:
+  use_reentrant: true
+logging_steps: 1
+flash_attention: true
+
+warmup_steps: 10
+evals_per_epoch: 4
+saves_per_epoch: 1
+weight_decay: 0.0
+fsdp:
+  - full_shard
+  - auto_wrap
+fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sharding_strategy: FULL_SHARD
+special_tokens:
+  pad_token: <|finetune_right_pad_id|>
@@ -1,4 +1,4 @@
-base_model: meta-llama/Meta-Llama-3-8B
+base_model: NousResearch/Meta-Llama-3-8B
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer

@@ -72,4 +72,5 @@ fsdp_config:
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
   fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
   fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_sharding_strategy: FULL_SHARD
 special_tokens:
@@ -1,4 +1,4 @@
-base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
+base_model: TinyLlama/TinyLlama_v1.1
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer

@@ -1,6 +1,5 @@
-base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
-model_type: LlamaForCausalLM
-tokenizer_type: LlamaTokenizer
+base_model: TinyLlama/TinyLlama_v1.1
+tokenizer_type: AutoTokenizer

 load_in_8bit: true
 load_in_4bit: false
@@ -9,9 +9,9 @@ strict: false

 max_steps: 200
 pretraining_dataset:
-  path: c4
+  - path: allenai/c4
     name: en
     type: pretrain
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./outputs/model-out
@@ -1,4 +1,4 @@
-base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
+base_model: TinyLlama/TinyLlama_v1.1
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer

@@ -1,18 +1,18 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
-peft==0.11.1
-transformers==4.42.3
-tokenizers==0.19.1
-bitsandbytes==0.43.1
-accelerate==0.32.0
-deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
+peft==0.12.0
+transformers==4.44.0
+tokenizers>=0.19.1
+bitsandbytes==0.43.3
+accelerate==0.33.0
+datasets==2.20.0
+deepspeed==0.14.4
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
-datasets==2.19.1
-flash-attn==2.6.1
+flash-attn==2.6.3
 sentencepiece
 wandb
 einops
@@ -21,23 +21,24 @@ optimum==1.16.2
 hf_transfer
 colorama
 numba
-numpy>=1.24.4
+numpy>=1.24.4,<=2.0.1
 # qlora things
 evaluate==0.4.1
 scipy
-scikit-learn==1.2.2
+scikit-learn==1.4.2
 pynvml
 art
 fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
 gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
+autoawq>=0.2.5

 mamba-ssm==1.2.0.post1

 # remote filesystems
-s3fs
-gcsfs
+s3fs>=2024.5.0
+gcsfs>=2024.5.0
 # adlfs

 trl==0.9.6
setup.py (12 changes)

@@ -80,13 +80,13 @@ setup(
         dependency_links=dependency_links,
         extras_require={
             "flash-attn": [
-                "flash-attn==2.6.1",
+                "flash-attn==2.6.3",
             ],
             "fused-dense-lib": [
-                "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
+                "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
             ],
             "deepspeed": [
-                "deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
+                "deepspeed==0.14.4",
                 "deepspeed-kernels",
             ],
             "mamba-ssm": [
@@ -104,5 +104,11 @@ setup(
             "galore": [
                 "galore_torch",
             ],
+            "optimizers": [
+                "galore_torch",
+                "lion-pytorch==0.1.2",
+                "lomo-optim==0.1.1",
+                "torch-optimi==0.2.1",
+            ],
         },
     )
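With the Docker images switching from the separate `mamba-ssm,galore` extras to the new `optimizers` extra, an editable install that mirrors them looks like the following (a sketch; flash-attn and deepspeed still need a CUDA toolchain to build):

```bash
# Editable install with the same extras the updated Dockerfiles use
pip install -e '.[deepspeed,flash-attn,optimizers]'

# Or pull in only the new optimizer backends (galore, lion, lomo, optimi)
pip install -e '.[optimizers]'
```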
@@ -40,7 +40,7 @@ from axolotl.utils.distributed import is_main_process
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.models import load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import prepare_optim_env
+from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
 from axolotl.utils.wandb_ import setup_wandb_env_vars

 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -375,13 +375,15 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
         cfg,
         capabilities={
             "bf16": is_torch_bf16_gpu_available(),
-            "n_gpu": os.environ.get("WORLD_SIZE", 1),
+            "n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
             "compute_capability": gpu_version,
         },
     )

     prepare_optim_env(cfg)

+    prepare_opinionated_env(cfg)
+
     normalize_config(cfg)

     normalize_cfg_datasets(cfg)
src/axolotl/cli/merge_sharded_fsdp_weights.py (new file, 204 lines; listing truncated below)

@@ -0,0 +1,204 @@
+"""
+This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
+"""
+import json
+import logging
+import os
+import shutil
+from pathlib import Path
+from typing import Dict, Union
+
+import fire
+import torch
+import torch.distributed.checkpoint as dist_cp
+import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
+import transformers
+from accelerate.utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    is_torch_version,
+)
+from dotenv import load_dotenv
+from huggingface_hub import split_torch_state_dict_into_shards
+from safetensors.torch import save_file as safe_save_file
+from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
+
+from axolotl.cli import load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs
+
+LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")
+
+
+class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
+    """
+    A custom planner to cast tensors to bfloat16 on the fly during loading.
+    """
+
+    def commit_tensor(self, read_item, tensor):  # pylint: disable=unused-argument
+        tensor.copy_(tensor.to(torch.bfloat16))
+
+
+def _distributed_checkpoint_to_merged_weights(
+    checkpoint_dir: Union[str, Path],
+    save_path: str,
+    safe_serialization: bool = False,
+    max_shard_size: str = "5GB",
+):
+    """
+    Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
+
+    Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
+    """
+
+    state_dict: Dict = {}
+    save_path_ = Path(save_path)
+    save_path_.mkdir(exist_ok=True)
+    dist_cp_format_utils._load_state_dict(  # pylint: disable=protected-access
+        state_dict,
+        storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
+        planner=BFloat16CastPlanner(),  # pylint: disable=protected-access
+        no_dist=True,
+    )
+
+    # To handle if state is a dict like {model: {...}}
+    if len(state_dict.keys()) == 1:
+        state_dict = state_dict[list(state_dict)[0]]
+
+    # Ensure all tensors are in bfloat16
+    for key, value in state_dict.items():
+        if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
+            state_dict[key] = value.to(torch.bfloat16)
+
+    weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+
+    filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+        ".safetensors", "{suffix}.safetensors"
+    )
+    state_dict_split = split_torch_state_dict_into_shards(
+        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+    )
+    # Save index if sharded
+    index = None
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+
+    # Save the model
+    filename_to_tensors = state_dict_split.filename_to_tensors.items()
+
+    for shard_file, tensors in filename_to_tensors:
|
shard = {tensor: state_dict[tensor] for tensor in tensors}
|
||||||
|
|
||||||
|
if safe_serialization:
|
||||||
|
safe_save_file(
|
||||||
|
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch.save(shard, os.path.join(save_path_, shard_file))
|
||||||
|
|
||||||
|
if index is not None:
|
||||||
|
save_index_file = (
|
||||||
|
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||||
|
)
|
||||||
|
save_index_file = os.path.join(save_path_, save_index_file)
|
||||||
|
# Save the index as well
|
||||||
|
with open(save_index_file, "w", encoding="utf-8") as fout:
|
||||||
|
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||||
|
fout.write(content)
|
||||||
|
|
||||||
|
return save_path_
|
||||||
|
|
||||||
|
|
||||||
|
def merge_fsdp_weights(
|
||||||
|
checkpoint_dir: str,
|
||||||
|
output_path: str,
|
||||||
|
safe_serialization: bool = False,
|
||||||
|
remove_checkpoint_dir: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
|
||||||
|
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
|
||||||
|
`safe_serialization` else `pytorch_model.bin`.
|
||||||
|
|
||||||
|
Note: this is a CPU-bound process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_dir (`str`):
|
||||||
|
The directory containing the FSDP checkpoints (can be either the model or optimizer).
|
||||||
|
output_path (`str`):
|
||||||
|
The path to save the merged checkpoint.
|
||||||
|
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether to save the merged weights with safetensors (recommended).
|
||||||
|
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to remove the checkpoint directory after merging.
|
||||||
|
"""
|
||||||
|
checkpoint_dir_ = Path(checkpoint_dir)
|
||||||
|
from accelerate.state import PartialState
|
||||||
|
|
||||||
|
if not is_torch_version(">=", "2.3.0"):
|
||||||
|
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0`")
|
||||||
|
|
||||||
|
# Verify that the checkpoint directory exists
|
||||||
|
if not checkpoint_dir_.exists():
|
||||||
|
model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists()
|
||||||
|
optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists()
|
||||||
|
err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file."
|
||||||
|
if model_path_exists and optimizer_path_exists:
|
||||||
|
err += (
|
||||||
|
" However, potential model and optimizer checkpoint directories exist."
|
||||||
|
)
|
||||||
|
err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0"
|
||||||
|
err += "instead."
|
||||||
|
elif model_path_exists:
|
||||||
|
err += " However, a potential model checkpoint directory exists."
|
||||||
|
err += (
|
||||||
|
f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead."
|
||||||
|
)
|
||||||
|
elif optimizer_path_exists:
|
||||||
|
err += " However, a potential optimizer checkpoint directory exists."
|
||||||
|
err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead."
|
||||||
|
raise ValueError(err)
|
||||||
|
|
||||||
|
# To setup `save` to work
|
||||||
|
state = PartialState()
|
||||||
|
if state.is_main_process:
|
||||||
|
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
|
||||||
|
save_path = _distributed_checkpoint_to_merged_weights(
|
||||||
|
checkpoint_dir_, output_path, safe_serialization
|
||||||
|
)
|
||||||
|
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
|
||||||
|
if remove_checkpoint_dir:
|
||||||
|
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
|
||||||
|
shutil.rmtree(checkpoint_dir_)
|
||||||
|
state.wait_for_everyone()
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
print_axolotl_text_art()
|
||||||
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
|
return_remaining_strings=True
|
||||||
|
)
|
||||||
|
parsed_cli_args.merge_lora = True
|
||||||
|
|
||||||
|
parsed_cfg = load_cfg(
|
||||||
|
config,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
|
||||||
|
merge_fsdp_weights(
|
||||||
|
checkpoint_dir=str(fsdp_dir),
|
||||||
|
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
|
||||||
|
safe_serialization=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
load_dotenv()
|
||||||
|
fire.Fire(do_cli)
|
||||||
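Note: a minimal usage sketch of the new helper above, assuming an FSDP run whose output directory contains a pytorch_model_fsdp_0 shard directory; the path below is a placeholder, not a value from this changeset.

# Hypothetical usage sketch; the output directory is a placeholder.
from pathlib import Path

from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights

output_dir = Path("outputs/my-fsdp-run")  # assumed training output_dir
merge_fsdp_weights(
    checkpoint_dir=str(output_dir / "pytorch_model_fsdp_0"),
    output_path=str(output_dir / "merged"),
    safe_serialization=True,  # write model*.safetensors plus an index file if sharded
)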
@@ -2,6 +2,7 @@
 CLI to run training on a model
 """
 import logging
+import warnings
 from pathlib import Path
 from typing import Union

@@ -76,8 +77,19 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):

     if parsed_cli_args.download:
         model_name = parsed_cfg.base_model
-        with init_empty_weights():
-            AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+        with warnings.catch_warnings():
+            # there are a bunch of useless UserWarnings about
+            # "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
+            warnings.simplefilter("ignore")
+            with init_empty_weights(include_buffers=True):
+                # fmt: off
+                try:
+                    AutoModelForCausalLM.from_pretrained(
+                        model_name, trust_remote_code=True
+                    )
+                except Exception as exc:  # pylint: disable=broad-exception-caught,unused-variable  # nosec B110  # noqa F841
+                    pass
+                # fmt: on

     LOG.info(
         Fore.GREEN
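For context, a small self-contained sketch of the pattern used in this hunk: instantiate the model on the meta device only to trigger the download, while silencing the meta-parameter copy warnings. The model id here is only a placeholder.

# Sketch of the warnings + init_empty_weights pattern; "gpt2" is a placeholder model id.
import warnings

from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence meta-parameter copy warnings
    with init_empty_weights(include_buffers=True):
        try:
            # instantiating on the meta device still downloads config + weights
            AutoModelForCausalLM.from_pretrained("gpt2", trust_remote_code=True)
        except Exception:  # the instantiated weights themselves are not needed here
            pass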
src/axolotl/common/architectures.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+"""
+Common architecture specific constants
+"""
+
+MOE_ARCH_BLOCK = {
+    "dbrx": "DbrxFFN",
+    "jamba": "JambaSparseMoeBlock",
+    "jetmoe": [
+        "JetMoeMoA",
+        "JetMoeMoE",
+    ],
+    "mixtral": "MixtralSparseMoeBlock",
+    "qwen2_moe": "Qwen2MoeSparseMoeBlock",
+    "deepseek_v2": "DeepseekV2MoE",
+}
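A hedged sketch of how a mapping like MOE_ARCH_BLOCK might be consumed, e.g. to collect the MoE block class names a wrapper or quantizer should treat specially; the helper below is illustrative only and not part of this changeset.

# Illustrative helper (not in this PR): normalize the MoE block entry to a list of class names.
from axolotl.common.architectures import MOE_ARCH_BLOCK


def moe_block_class_names(model_type: str) -> list:
    block = MOE_ARCH_BLOCK.get(model_type)
    if block is None:
        return []
    return block if isinstance(block, list) else [block]


print(moe_block_class_names("jetmoe"))   # ['JetMoeMoA', 'JetMoeMoE']
print(moe_block_class_names("mixtral"))  # ['MixtralSparseMoeBlock']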
src/axolotl/core/tokenizer_utils.py (new file, 150 lines)
@@ -0,0 +1,150 @@
+"""
+helper functions for fixing the embeddings/tokenizer
+"""
+
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import itertools
+
+import numpy as np
+import torch
+
+
+@torch.inference_mode
+def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
+    """
+    Many of the newer models have reserved tokens that are not trained.
+    """
+    embedding_matrix = model.get_input_embeddings().weight
+    lm_head_matrix = model.get_output_embeddings().weight
+
+    # Get untrained tokens
+    indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
+    where_untrained = torch.where(indicator_untrained)[0]
+    n_untrained = where_untrained.shape[0]
+    n_trained = embedding_matrix.shape[0] - n_untrained
+
+    # Get set and actual tokens
+    where_untrained = where_untrained.tolist()
+    if len(where_untrained) == 0:
+        return False
+
+    # Remove untrained indices where it's longer
+
+    where_untrained_set = frozenset(where_untrained)
+    actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
+    # Remove None items in actual_bad_tokens
+    actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
+
+    # Check if tokenizer and training datasets have bad tokens
+    if_bad_first = False
+    if_bad_second = False
+    # Check tokenizer's chat template for any untrained tokens
+    chat_template = getattr(tokenizer, "chat_template", None)
+    if chat_template is not None:
+        if_bad_first = any(x in chat_template for x in actual_bad_tokens)
+
+    # Check the first 250, last 250 input_ids
+    size_dataset = len(train_dataset)
+    size = min(size_dataset, 250)
+    for j in range(size):
+        input_ids = train_dataset[j]
+        if "input_ids" in input_ids:
+            input_ids = input_ids["input_ids"]
+            if_bad = any(item in where_untrained_set for item in input_ids)
+            if if_bad:
+                if_bad_second = True
+                break
+
+    # Check last 250
+    if not if_bad_second:
+        left = max(size_dataset - 250, 0)
+        for j in range(left, size_dataset):
+            input_ids = train_dataset[j]
+            if "input_ids" in input_ids:
+                input_ids = input_ids["input_ids"]
+                if_bad = any(item in where_untrained_set for item in input_ids)
+                if if_bad:
+                    if_bad_second = True
+                    break
+
+    # Check if bad tokens exists!
+    if not if_bad_first and not if_bad_second:
+        return False
+
+    # Count all the possible bad tokens
+    final_counts = np.zeros(
+        max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
+    )
+
+    def mapping(examples):
+        input_ids = examples["input_ids"]
+        counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32)
+        np.add.at(final_counts, counter, 1)
+
+    train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
+
+    # Get sum of all items
+    sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
+    sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
+
+    # Remove bad tokens
+    sum_embedding -= torch.sum(
+        embedding_matrix[where_untrained], dtype=torch.float32, axis=0
+    )
+    sum_lm_head -= torch.sum(
+        lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
+    )
+
+    # Find correct average by dividing by sum of trained tokens
+    mean_embedding = sum_embedding / n_trained
+    mean_lm_head = sum_lm_head / n_trained
+
+    # Scale each to be equal to 1/max_frequency. Also set some to 0 if none seen
+    scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
+    scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
+    mean_embedding = (
+        mean_embedding.repeat(
+            (
+                n_untrained,
+                1,
+            )
+        )
+        * scaling
+    )
+    mean_lm_head = (
+        mean_lm_head.repeat(
+            (
+                n_untrained,
+                1,
+            )
+        )
+        * scaling
+    )
+    where_null = scaling.ravel() == 0
+    mean_embedding[where_null] = 0
+    mean_lm_head[where_null] = 0
+
+    # Set them to the mean
+    embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
+    lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
+
+    # Clean up
+    for _ in range(3):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    return True
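A hedged sketch of how the new helper might be called before training; the model, tokenizer and dataset objects below are placeholders for whatever the training pipeline already holds, not wiring taken from this changeset.

# Hypothetical call site; model/tokenizer/train_dataset are placeholders.
from axolotl.core.tokenizer_utils import fix_untrained_tokens

# model: a causal LM exposing get_input_embeddings()/get_output_embeddings()
# tokenizer: the matching tokenizer
# train_dataset: a tokenized datasets.Dataset with an "input_ids" column
changed = fix_untrained_tokens(model, tokenizer, train_dataset)
if changed:
    print("re-initialized untrained embedding rows to a scaled mean")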
@@ -8,6 +8,7 @@ import importlib
 import importlib.util
 import logging
 import math
+import os
 import sys
 from abc import abstractmethod
 from collections import defaultdict
@@ -28,9 +29,18 @@ from transformers import (
     TrainerCallback,
     TrainingArguments,
 )
-from transformers.trainer_utils import seed_worker
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
+from trl import (
+    CPOConfig,
+    CPOTrainer,
+    DPOConfig,
+    DPOTrainer,
+    KTOConfig,
+    KTOTrainer,
+    ORPOConfig,
+    ORPOTrainer,
+)
 from trl.trainer.utils import pad_to_length

 from axolotl.loraplus import create_loraplus_optimizer
@@ -226,6 +236,18 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "whether to use sequential sampling for curriculum learning"},
     )
+    alternate_optimizer: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate optimizer to the HF trainer"
+        },
+    )
+    alternate_lr_scheduler_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate lr scheduler to the HF trainer"
+        },
+    )


 @dataclass
@@ -259,58 +281,24 @@ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
     """


-class AxolotlTrainer(Trainer):
+@dataclass
+class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
     """
-    Extend the base Trainer for axolotl helpers
+    CPO config for CPO training
+    """
+
+    simpo_gamma: Optional[float] = field(
+        default=None,
+        metadata={"help": "simpo gamma parameter"},
+    )
+
+
+class SchedulerMixin(Trainer):
+    """
+    Mixin class for scheduler setup in CausalTrainer.
     """

     args = None  # type: AxolotlTrainingArguments
-    tag_names = ["axolotl"]
-
-    def __init__(
-        self,
-        *_args,
-        num_epochs=1,
-        bench_data_collator=None,
-        eval_data_collator=None,
-        **kwargs,
-    ):
-        self.num_epochs = num_epochs
-        self.bench_data_collator = bench_data_collator
-        self.eval_data_collator = eval_data_collator
-        super().__init__(*_args, **kwargs)
-        self.train_data_collator = self.data_collator
-        self._stored_metrics = defaultdict(lambda: defaultdict(list))
-        if self.args.orpo_alpha:
-            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
-
-    def create_optimizer(self):
-        if self.args.loraplus_lr_ratio is None:
-            return super().create_optimizer()
-
-        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-                self.args,
-                opt_model,
-            )
-
-            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                opt_model,
-                optimizer_cls,
-                optimizer_kwargs,
-                loraplus_lr_ratio,
-                loraplus_lr_embedding,
-            )
-
-        if is_sagemaker_mp_enabled():
-            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
-                self.optimizer
-            )
-
-        return self.optimizer
-
     def create_scheduler(
         self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -336,7 +324,23 @@ class AxolotlTrainer(Trainer):
             # fmt: off
             if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
                 # fmt: on
-                if use_cosine_quadratic:
+                if self.args.alternate_lr_scheduler_type == "one_cycle":
+                    num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
+                    pct_start = num_warmup_steps / num_training_steps
+                    extra_lr_kwargs = {}
+                    if "pct_start" not in self.args.lr_scheduler_kwargs:
+                        extra_lr_kwargs["pct_start"] = pct_start
+                    if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
+                        extra_lr_kwargs["anneal_strategy"] = "cos"
+
+                    self.lr_scheduler = OneCycleLR(
+                        optimizer,
+                        max_lr=self.args.learning_rate,
+                        total_steps=num_training_steps,
+                        **extra_lr_kwargs,
+                        **self.args.lr_scheduler_kwargs,
+                    )
+                elif use_cosine_quadratic:
                     if use_cosine_min_lr:
                         LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")

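The branch above derives pct_start from the warmup-step count before handing the remaining lr_scheduler_kwargs to OneCycleLR. A compact standalone illustration of that arithmetic follows; the numbers are made up.

# Standalone illustration of the pct_start derivation; values are arbitrary.
import torch
from torch.optim.lr_scheduler import OneCycleLR

num_training_steps = 1000
num_warmup_steps = 100
pct_start = num_warmup_steps / num_training_steps  # 10% of the cycle is spent ramping up

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=2e-4)
scheduler = OneCycleLR(
    optimizer,
    max_lr=2e-4,
    total_steps=num_training_steps,
    pct_start=pct_start,
    anneal_strategy="cos",
)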
@@ -374,6 +378,125 @@ class AxolotlTrainer(Trainer):

         return self.lr_scheduler
+
+
+class AxolotlTrainer(SchedulerMixin, Trainer):
+    """
+    Extend the base Trainer for axolotl helpers
+    """
+
+    args = None  # type: AxolotlTrainingArguments
+    tag_names = ["axolotl"]
+
+    def __init__(
+        self,
+        *_args,
+        num_epochs=1,
+        bench_data_collator=None,
+        eval_data_collator=None,
+        **kwargs,
+    ):
+        self.num_epochs = num_epochs
+        self.bench_data_collator = bench_data_collator
+        self.eval_data_collator = eval_data_collator
+        super().__init__(*_args, **kwargs)
+        self.train_data_collator = self.data_collator
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+        if self.args.orpo_alpha:
+            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+
+    def _wrap_model(self, model, training=True, dataloader=None):
+        if self.args.torch_compile:
+            torch._dynamo.config.accumulated_cache_size_limit = (  # pylint: disable=protected-access
+                256
+            )
+            model = torch.compile(
+                model,
+                backend=self.args.torch_compile_backend,
+                mode=self.args.torch_compile_mode,
+            )
+        return super()._wrap_model(model, training=training, dataloader=dataloader)
+
+    def create_optimizer(self):
+        if (
+            self.args.loraplus_lr_ratio is None
+            and self.args.alternate_optimizer
+            not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
+        ):
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+                opt_model,
+            )
+
+            if self.args.loraplus_lr_ratio is not None:
+                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+                loraplus_lr_embedding = getattr(
+                    self.args, "loraplus_lr_embedding", None
+                )
+                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                    opt_model,
+                    optimizer_cls,
+                    optimizer_kwargs,
+                    loraplus_lr_ratio,
+                    loraplus_lr_embedding,
+                )
+            elif self.args.alternate_optimizer == "optimi_adamw":
+                from optimi import AdamW
+
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    AdamW(
+                        optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
+                    )
+                )
+            elif self.args.alternate_optimizer == "ao_adamw_4bit":
+                from torchao.prototype.low_bit_optim import AdamW4bit
+
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
+                )
+            elif self.args.alternate_optimizer == "ao_adamw_8bit":
+                from torchao.prototype.low_bit_optim import AdamW8bit
+
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
+                )
+            elif self.args.alternate_optimizer == "ao_adamw_fp8":
+                from torchao.prototype.low_bit_optim import AdamWFp8
+
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
+                )
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
+                self.optimizer
+            )
+
+        return self.optimizer
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if self.args.sample_packing and not self.args.pretraining:
             if self.args.multipack_real_batches:
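The create_optimizer changes above split parameters into decay/no-decay groups before handing them to an optimi or torchao AdamW variant. Below is a minimal sketch of the same grouping on a toy module; torchao is an optional dependency here, so the sketch falls back to plain AdamW if the import fails.

# Minimal sketch of the decay/no-decay grouping, on a toy model (assumptions noted inline).
import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
decay, no_decay = [], []
for name, param in model.named_parameters():
    # biases and norm weights are conventionally excluded from weight decay
    (no_decay if param.ndim < 2 else decay).append(param)

groups = [
    {"params": decay, "weight_decay": 0.01},
    {"params": no_decay, "weight_decay": 0.0},
]
try:
    from torchao.prototype.low_bit_optim import AdamW8bit as AdamWCls
except ImportError:  # torchao not installed; fall back to the stock optimizer
    AdamWCls = torch.optim.AdamW
optimizer = AdamWCls(groups, lr=2e-4)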
@@ -738,6 +861,14 @@ class AxolotlTrainer(Trainer):
         for key, value in metrics.items():
             self._stored_metrics[train_eval][key].append(value)

+    def _save_checkpoint(self, model, trial, metrics=None):
+        # make sure the checkpoint dir exists, since trainer is flakey
+        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+        run_dir = self._get_output_dir(trial=trial)
+        output_dir = os.path.join(run_dir, checkpoint_folder)
+        os.makedirs(output_dir, exist_ok=True)
+        return super()._save_checkpoint(model, trial, metrics=metrics)
+

 class AxolotlMambaTrainer(AxolotlTrainer):
     """
@@ -767,37 +898,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
         return lm_loss


-class OneCycleLRSchedulerTrainer(AxolotlTrainer):
-    """
-    Trainer subclass that uses the OneCycleLR scheduler
-    """
-
-    tag_names = ["axolotl", "onecycle"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.lr_scheduler = None
-
-    def create_scheduler(
-        self,
-        num_training_steps: int,
-        optimizer: Optional[torch.optim.Optimizer] = None,
-    ):
-        optimizer = self.optimizer if optimizer is None else optimizer
-        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
-        pct_start = num_warmup_steps / num_training_steps
-
-        self.lr_scheduler = OneCycleLR(
-            optimizer,
-            max_lr=self.args.learning_rate,
-            total_steps=num_training_steps,
-            pct_start=pct_start,
-            div_factor=6,
-        )
-
-        return self.lr_scheduler
-
-
 class ReLoRATrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
@@ -837,7 +937,7 @@ class ReLoRATrainer(AxolotlTrainer):
         return self.lr_scheduler


-class AxolotlDPOTrainer(DPOTrainer):
+class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
     """
     Extend the base DPOTrainer for axolotl helpers
     """
@@ -898,7 +998,7 @@ class AxolotlDPOTrainer(DPOTrainer):
         return res


-class AxolotlORPOTrainer(ORPOTrainer):
+class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
     """
     Extend the base ORPOTrainer for axolotl helpers
     """
@@ -906,7 +1006,7 @@ class AxolotlORPOTrainer(ORPOTrainer):
     tag_names = ["axolotl", "orpo"]


-class AxolotlKTOTrainer(KTOTrainer):
+class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
     """
     Extend the base KTOTrainer for axolotl helpers
     """
@@ -914,6 +1014,14 @@ class AxolotlKTOTrainer(KTOTrainer):
     tag_names = ["axolotl", "kto"]


+class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
+    """
+    Extend the base CPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "cpo"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
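A tiny, self-contained illustration of why the mixin is listed first in the bases: with class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer), Python's MRO resolves create_scheduler on the mixin before the upstream trainer, so all of the RL trainers share the same scheduler logic. The classes below are stand-ins, not the real trainers.

# Toy MRO demonstration; Base/Mixin stand in for trl's trainers and SchedulerMixin.
class Base:
    def create_scheduler(self):
        return "base scheduler"


class Mixin:
    def create_scheduler(self):
        return "axolotl scheduler"


class DPOLike(Mixin, Base):
    pass


print(DPOLike().create_scheduler())           # "axolotl scheduler"
print([c.__name__ for c in DPOLike.__mro__])  # ['DPOLike', 'Mixin', 'Base', 'object']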
@@ -1073,10 +1181,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         return callbacks

     def _get_trainer_cls(self):
-        if self.cfg.lr_scheduler == "one_cycle" and (
-            self.cfg.fsdp or self.cfg.adapter == "qlora"
-        ):
-            return OneCycleLRSchedulerTrainer
         if self.cfg.relora_steps:
             return ReLoRATrainer
         if self.cfg.model_config_type == "mamba":
@@ -1126,7 +1230,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
             if self.cfg.fsdp_config:
-                training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
+                training_arguments_kwargs["fsdp_config"] = {
+                    k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
+                }

         if self.cfg.adapter == "qlora":
             training_arguments_kwargs["qlora"] = True
@@ -1235,6 +1341,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "torch_compile_backend"
             ] = self.cfg.torch_compile_backend
+            if self.cfg.torch_compile_mode:
+                training_arguments_kwargs[
+                    "torch_compile_mode"
+                ] = self.cfg.torch_compile_mode

         # DDP Config
         if self.cfg.ddp_timeout:
@@ -1320,12 +1430,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "loraplus_lr_embedding"
             ] = self.cfg.loraplus_lr_embedding
-        training_arguments_kwargs["lr_scheduler_type"] = (
-            self.cfg.lr_scheduler
-            if self.cfg.lr_scheduler
-            and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
-            else "cosine"
-        )
+        if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
+            training_arguments_kwargs["lr_scheduler_type"] = "cosine"
+            training_arguments_kwargs[
+                "alternate_lr_scheduler_type"
+            ] = self.cfg.lr_scheduler
+        else:
+            training_arguments_kwargs["lr_scheduler_type"] = (
+                self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
+            )
         training_arguments_kwargs["lr_scheduler_kwargs"] = (
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
@@ -1396,6 +1509,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         trainer_kwargs = {}

+        if self.cfg.optimizer in [
+            "optimi_adamw",
+            "ao_adamw_4bit",
+            "ao_adamw_8bit",
+            "ao_adamw_fp8",
+        ]:
+            # Set default so transformers doesn't throw
+            training_arguments_kwargs["optim"] = "adamw_hf"
+            training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
+
         if self.cfg.optimizer == "lion_pytorch":
             from lion_pytorch import Lion

@@ -1424,6 +1547,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             sys.path.append(self.cfg.torchdistx_path)
             importlib.import_module("torchdistx")

+        if self.cfg.accelerator_config:
+            training_arguments_kwargs[
+                "accelerator_config"
+            ] = self.cfg.accelerator_config
+
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
@@ -1617,16 +1745,27 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             # default to saving each epoch if not defined
             training_args_kwargs["save_strategy"] = "epoch"

+        if self.cfg.rl_beta:
+            training_args_kwargs["beta"] = self.cfg.rl_beta
         if self.cfg.orpo_alpha:
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha
+
+        training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
         training_args_cls = AxolotlDPOConfig
         if self.cfg.rpo_alpha is not None:
             training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
+
+        if self.cfg.rl == "simpo":
+            training_args_cls = AxolotlCPOConfig
+            training_args_kwargs["loss_type"] = "simpo"
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+            training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
+            if self.cfg.cpo_alpha is not None:
+                training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha

         if self.cfg.rl == "orpo":
             training_args_cls = AxolotlORPOConfig
-            training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
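Based on the fields read above (cfg.rl, cfg.rl_beta, cfg.simpo_gamma, cfg.cpo_alpha, cfg.sequence_len, cfg.dataset_processes), a SimPO run would presumably be configured with values like the following. The exact user-facing config keys are not shown in this changeset, so treat this as an assumed sketch, not documented usage.

# Assumed config values, mirroring the cfg attributes referenced in the builder above.
simpo_cfg = {
    "rl": "simpo",         # routes to AxolotlCPOTrainer with loss_type="simpo"
    "rl_beta": 0.1,
    "simpo_gamma": 0.5,
    "cpo_alpha": 1.0,
    "sequence_len": 2048,  # becomes max_length for the CPO config
    "dataset_processes": 8,
}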
@@ -1634,7 +1773,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.rl == "kto":
             training_args_cls = AxolotlKTOConfig

-            training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
             training_args_kwargs["desirable_weight"] = (
                 self.cfg.kto_desirable_weight or 1.0
             )
@@ -1680,7 +1818,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.precompute_ref_log_probs
         if self.cfg.rl in ["dpo", "ipo"]:
             trainer_cls = AxolotlDPOTrainer
-            dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
             trainer_cls_args = [self.model, self.model_ref]

             # these aren't used for the ORPO trainer
@@ -1688,14 +1825,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs["max_target_length"] = None
             dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
             dpo_trainer_kwargs["generate_during_eval"] = True
-            if self.cfg.rl == "dpo":
-                dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
         elif self.cfg.rl in ["kto"]:
             trainer_cls = AxolotlKTOTrainer
             trainer_cls_args = [self.model]
+        elif self.cfg.rl in ["simpo"]:
+            trainer_cls = AxolotlCPOTrainer
+            trainer_cls_args = [self.model]
         else:
             raise ValueError(f"Unsupported RL: {self.cfg.rl}")
         dpo_trainer = trainer_cls(
@@ -1708,6 +1846,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         )
         if self.cfg.fsdp:
             ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
+            if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
+                ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)

         dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
             set_module_name(model, name, qkv)


+def patch_llama_cross_entropy():
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+    LOG.info("patching with flash_attn.losses.cross_entropy")
+    transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+        CrossEntropyLoss, inplace_backward=True
+    )
+
+
+def patch_llama_rms_norm():
+    try:
+        from flash_attn.ops.rms_norm import RMSNorm
+
+        class LlamaRMSNorm(RMSNorm):
+            """Patched LLamaRMSNorm"""
+
+            def __init__(self, hidden_size, eps=1e-6):
+                super().__init__(hidden_size, eps=eps)
+
+        LOG.info("patching with flash_attn.ops.rms_norm")
+        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+    except ImportError:
+        LOG.warning(
+            "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
+        )
+
+
 def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
@@ -104,30 +131,11 @@ def replace_llama_attn_with_flash_attn(

     # skip only if explicitly disabled
     if cross_entropy:
-        from flash_attn.losses.cross_entropy import CrossEntropyLoss
-
-        LOG.info("patching with flash_attn.losses.cross_entropy")
-        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
-            CrossEntropyLoss, inplace_backward=True
-        )
+        patch_llama_cross_entropy()

     # skip only if explicitly disabled
     if rms_norm:
-        try:
-            from flash_attn.ops.rms_norm import RMSNorm
-
-            class LlamaRMSNorm(RMSNorm):
-                """Patched LLamaRMSNorm"""
-
-                def __init__(self, hidden_size, eps=1e-6):
-                    super().__init__(hidden_size, eps=eps)
-
-            LOG.info("patching with flash_attn.ops.rms_norm")
-            transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
-        except ImportError:
-            LOG.warning(
-                "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
-            )
+        patch_llama_rms_norm()


 class FusedAttention(LlamaAttention):
@@ -2,6 +2,7 @@
 # pylint: disable=duplicate-code

 import logging
+from functools import partial
 from typing import List, Optional, Tuple, Union

 import torch
@@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn(
     )


+def patch_mistral_cross_entropy():
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+    LOG.info("patching with flash_attn.losses.cross_entropy")
+    transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
+        CrossEntropyLoss, inplace_backward=True
+    )
+
+
 @torch.jit.script
 def _make_sliding_window_causal_mask(
     bsz: int,
@@ -10,11 +10,14 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
 from axolotl.monkeypatch.utils import get_unpad_data

 SUPPORTED_MULTIPACK_MODEL_TYPES = [
+    "llama",
+    "mistral",
     "mixtral",
     "qwen2",
     "qwen2_moe",
     "falcon",
     "phi",
+    "phi3",
     "gemma",
     "gemma2",
     "gemmoe",
@@ -23,13 +26,36 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
 ]


-def patch_for_multipack(model_type, model_name=None):
+def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
+    if model_type == "gemmoe":
+        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+    elif model_type == "deepseek_v2":
+        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
+    elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
+        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+        if model_type == "mixtral" and is_deepspeed_zero3_enabled():
+            patch_mixtral_moe_forward_zero3()
+        return
+
+    # retain for legacy
     if model_type == "mixtral":
         transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
         if is_deepspeed_zero3_enabled():
             patch_mixtral_moe_forward_zero3()
+    elif model_type == "llama":
+        if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
+            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+    elif model_type == "mistral":
+        if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
+            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
     elif model_type == "qwen2":
         transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
@@ -58,12 +84,6 @@ def patch_for_multipack(model_type, model_name=None):
         transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
-    elif model_type == "gemmoe":
-        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-    elif model_type == "jamba":
-        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
-    elif model_type == "deepseek_v2":
-        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")


 def patch_remote(model_name, config_name, modeling_name):
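The dispatch above follows one core idea: when transformers exposes the shared modeling_flash_attention_utils module, override its _get_unpad_data once rather than per architecture. A condensed sketch of that check follows; note it relies on a private transformers attribute, so it is inherently version-dependent.

# Condensed version of the global-patch branch above; relies on private transformers internals.
import transformers

from axolotl.monkeypatch.utils import get_unpad_data


def patch_unpad_data_globally() -> bool:
    if hasattr(transformers, "modeling_flash_attention_utils"):
        transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
        return True
    return False  # fall back to the per-model patches retained for legacy versions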
@@ -1,18 +1,20 @@
|
|||||||
"""module for patching with unsloth optimizations"""
|
"""module for patching with unsloth optimizations"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import types
|
import types
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate.logging import get_logger
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
|
from torch import nn
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaFlashAttention2,
|
LlamaFlashAttention2,
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
ORIGINAL_CEL_CODE = """ if labels is not None:
|
ORIGINAL_CEL_CODE = """ if labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
@@ -97,48 +99,51 @@ def check_self_attn_is_patchable() -> bool:
|
|||||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
||||||
|
|
||||||
|
|
||||||
def integrate_cross_entropy_loss_patch():
|
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||||
forward = get_forward_code()
|
if model_type == "llama":
|
||||||
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
forward = get_forward_code()
|
||||||
forward, _ = detab_code(forward)
|
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||||
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
forward, _ = detab_code(forward)
|
||||||
|
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
||||||
|
|
||||||
forward = forward.replace(
|
forward = forward.replace(
|
||||||
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
||||||
)
|
)
|
||||||
forward = forward.replace(
|
forward = forward.replace(
|
||||||
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
||||||
forward = forward.replace(
|
forward = forward.replace(
|
||||||
"def forward(",
|
"def forward(",
|
||||||
"def fast_cross_entropy_loss_forward(",
|
"def fast_cross_entropy_loss_forward(",
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load imports necessary
|
# load imports necessary
|
||||||
import transformers.models.llama.modeling_llama
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
items_to_import = []
|
items_to_import = []
|
||||||
for item in dir(transformers.models.llama.modeling_llama):
|
for item in dir(transformers.models.llama.modeling_llama):
|
||||||
if item in forward:
|
if item in forward:
|
||||||
items_to_import.append(item)
|
items_to_import.append(item)
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
"from transformers.models.llama.modeling_llama import ("
|
"from transformers.models.llama.modeling_llama import ("
|
||||||
+ ", ".join(x for x in items_to_import)
|
+ ", ".join(x for x in items_to_import)
|
||||||
+ ")",
|
+ ")",
|
||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
print("patching unsloth fast_cross_entropy_loss")
|
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
|
||||||
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported model type")
|
||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
def detab_code(code: str) -> Tuple[str, str]:
|
||||||
@@ -179,12 +184,30 @@ def patch_self_attn_lora():
|
|||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
print("patching unsloth attn lora")
|
LOG.info("patching unsloth attn lora", main_process_only=True)
|
||||||
LlamaFlashAttention2.forward = (
|
LlamaFlashAttention2.forward = (
|
||||||
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def integrate_rope_embeddings():
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
from unsloth.kernels.rope_embedding import fast_rope_embedding
|
||||||
|
|
||||||
|
def apply_rotary_pos_emb( # pylint: disable=unused-argument
|
||||||
|
q, # pylint: disable=invalid-name
|
||||||
|
k, # pylint: disable=invalid-name
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
position_ids=None,
|
||||||
|
unsqueeze_dim=1,
|
||||||
|
):
|
||||||
|
return fast_rope_embedding(q, k, cos, sin)
|
||||||
|
|
||||||
|
LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
|
||||||
|
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
    if peft_model.base_model.config.model_type in ["llama", "mistral"]:
        from unsloth.kernels import apply_lora_mlp_swiglu
@@ -217,7 +240,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
            if is_mlp_lora and mlp_no_bias and mlp_not_dora:
                layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
            else:
-                logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
+                LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)


def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -243,9 +266,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
                layer.self_attn.apply_qkv = apply_lora_qkv
            else:
                layer.self_attn.apply_qkv = original_apply_qkv
-                logging.warning(
-                    "unable to apply unsloth lora qkv patch to layer %d", idx
-                )
+                LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
        if cfg.unsloth_lora_o:
            layer_modules = [
                getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -264,6 +285,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
                layer.self_attn.apply_o = apply_lora_o
            else:
                layer.self_attn.apply_o = original_apply_o
-                logging.warning(
+                LOG.warning(
                    "unable to apply unsloth lora o_proj patch to layer %d", idx
                )


+def patch_unsloth_layernorm():
+    try:
+        import transformers.models.llama.modeling_llama
+        from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
+
+        class LlamaRMSNorm(nn.Module):
+            """LlamaRMSNorm"""
+
+            def __init__(self, hidden_size, eps=1e-6):
+                """
+                LlamaRMSNorm is equivalent to T5LayerNorm
+                """
+                super().__init__()
+                self.weight = nn.Parameter(torch.ones(hidden_size))
+                self.variance_epsilon = eps
+
+            def forward(self, hidden_states):
+                return Fast_RMS_Layernorm.apply(
+                    hidden_states, self.weight, self.variance_epsilon, False
+                )
+
+        LOG.info("patching with unsloth.kernels.rms_layernorm")
+        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+    except ImportError:
+        LOG.warning("missing unsloth library")
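A minimal sketch (not part of the diff) of how these hooks might be driven by the unsloth_* flags this compare adds to the config schema; the wrapper name apply_unsloth_optims and its call site are assumptions, only the integrate_*/patch_* helpers are defined above.

def apply_unsloth_optims(cfg, peft_model=None):
    # hypothetical wiring; class-/function-level monkeypatches must run before training
    if cfg.unsloth_rope:
        integrate_rope_embeddings()
    if cfg.unsloth_rms_norm:
        patch_unsloth_layernorm()
    # the LoRA kernel patches operate on an already-built PEFT model
    if peft_model is not None and cfg.unsloth_lora_mlp:
        integrate_lora_mlp_patch(peft_model)
    if peft_model is not None and (cfg.unsloth_lora_qkv or cfg.unsloth_lora_o):
        integrate_lora_patch(peft_model, cfg)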
@@ -6,14 +6,16 @@ import logging
from typing import Any, Dict, List, Optional

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
-from axolotl.prompters import Prompter
+from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.utils.chat_templates import chat_templates

+# Configure the logger
LOG = logging.getLogger("axolotl")
+LOG.setLevel(logging.INFO)


class ChatTemplatePrompter(Prompter):
-    """prompter for HF chat templates"""
+    """Prompter for HF chat templates"""

    def __init__(
        self,
@@ -22,6 +24,8 @@ class ChatTemplatePrompter(Prompter):
        max_length=2048,
        message_field_role: str = "from",
        message_field_content: str = "value",
+        message_field_training: str = "train",
+        message_field_training_detail: str = "train_detail",
        roles: Optional[Dict[str, List[str]]] = None,
        drop_system_message: bool = False,
    ):
@@ -37,6 +41,8 @@ class ChatTemplatePrompter(Prompter):
        }
        self.message_field_role = message_field_role
        self.message_field_content = message_field_content
+        self.message_field_training = message_field_training
+        self.message_field_training_detail = message_field_training_detail
        self.tokenizer = tokenizer
        self.chat_template = chat_template
        self.max_length = max_length
@@ -47,6 +53,7 @@ class ChatTemplatePrompter(Prompter):
            {
                "role": self.roles[t[self.message_field_role]],
                "content": t[self.message_field_content],
+                "training": t.get(self.message_field_training, None),
            }
            for t in conversation
        ]
@@ -62,6 +69,108 @@ class ChatTemplatePrompter(Prompter):
            chat_template=self.chat_template,
        )

+    def get_offsets_for_train_detail(
+        self, text: str, train_details: List[Dict], mask_untrainable: bool = True
+    ) -> List[int]:
+        tokenized_output = self.tokenizer(
+            text, return_offsets_mapping=True, add_special_tokens=False
+        )
+        tokens = tokenized_output.tokens()
+        token_offsets = tokenized_output["offset_mapping"]
+
+        LOG.debug(f"Tokenizing text: {text}")
+        LOG.debug(f"Tokens: {tokens}")
+        # Adjust the end offsets. For some reason by default they are set to the same value as the start offsets.
+        for i in range(len(token_offsets) - 1):
+            token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
+        # Ensure the last token's end offset is set correctly
+        token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)
+        LOG.debug(f"Token offsets: {token_offsets}")
+
+        # Initialize all offsets as IGNORE_TOKEN_ID (not trained)
+        result = [IGNORE_TOKEN_ID] * len(token_offsets)
+
+        # Adjust train_details to align with token boundaries
+        adjusted_train_details = self.adjust_train_details(train_details, token_offsets)
+
+        for idx, (start, end) in enumerate(token_offsets):
+            for detail in adjusted_train_details:
+                # Check if the token is completely within the detail's range
+                if start >= detail["begin_offset"] and end <= detail["end_offset"]:
+                    if detail["train"] or not mask_untrainable:
+                        result[idx] = start
+                        LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training")
+                    else:
+                        LOG.debug(
+                            f"Token {idx} ({tokens[idx]}) marked as non-trainable"
+                        )
+                elif start < detail["end_offset"] and end > detail["begin_offset"]:
+                    # Token partially overlaps with detail, always mark as non-trainable
+                    LOG.debug(
+                        f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable"
+                    )
+
+        LOG.debug(f"Final result: {result}")
+        return result
+
+    def adjust_train_details(
+        self, train_details: List[Dict], token_offsets: List[tuple]
+    ) -> List[Dict]:
+        adjusted_details = []
+        for detail in train_details:
+            begin_offset = detail["begin_offset"]
+            end_offset = detail["end_offset"]
+
+            # Find the first token that starts after or at the begin_offset
+            begin_token = next(
+                (
+                    i
+                    for i, (t_start, t_end) in enumerate(token_offsets)
+                    if t_start >= begin_offset
+                ),
+                len(token_offsets),
+            )
+            if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:
+                begin_token -= 1
+
+            # Find the last token that ends before or at the end_offset
+            end_token = next(
+                (
+                    i
+                    for i in range(len(token_offsets) - 1, -1, -1)
+                    if token_offsets[i][1] <= end_offset
+                ),
+                -1,
+            )
+            if (
+                end_token < len(token_offsets) - 1
+                and token_offsets[end_token + 1][0] < end_offset
+            ):
+                end_token += 1
+
+            if begin_token <= end_token:
+                adjusted_begin = token_offsets[begin_token][0]
+                adjusted_end = token_offsets[end_token][1]
+
+                if adjusted_begin != begin_offset or adjusted_end != end_offset:
+                    LOG.warning(
+                        f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})"
+                    )
+
+                adjusted_details.append(
+                    {
+                        "begin_offset": adjusted_begin,
+                        "end_offset": adjusted_end,
+                        "train": detail["train"],
+                    }
+                )
+            else:
+                LOG.warning(
+                    f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail."
+                )
+
+        return adjusted_details

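For reference, a hedged sketch of the per-message training metadata these helpers consume; the field names follow the defaults above ("train" and "train_detail"), while the message text and offsets are invented for illustration.

# Hypothetical message illustrating the train/train_detail fields.
message = {
    "from": "assistant",
    "value": "The answer is 42. Let me know if you need the derivation.",
    "train": None,  # defer to train_detail / roles_to_train
    "train_detail": [
        {"begin_offset": 0, "end_offset": 17, "train": True},    # keep "The answer is 42."
        {"begin_offset": 18, "end_offset": 58, "train": False},  # mask the follow-up
    ],
}
# get_offsets_for_train_detail(text, train_details) returns one entry per token:
# IGNORE_TOKEN_ID for masked tokens, the token's character start offset otherwise.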
class ChatTemplateStrategy(PromptTokenizingStrategy):
    """
@@ -70,6 +179,19 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

    _messages = "conversations"

+    def __init__(
+        self,
+        prompter,
+        tokenizer,
+        train_on_inputs,
+        sequence_len,
+        roles_to_train=None,
+        train_on_eos="last",
+    ):
+        super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
+        self.roles_to_train = roles_to_train if roles_to_train is not None else []
+        self.train_on_eos = train_on_eos
+
    @property
    def messages(self):
        return self._messages
@@ -79,62 +201,170 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
        self._messages = messages

    def tokenize_prompt(self, prompt):
-        turns = self.get_conversation_thread(prompt)
-        prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
+        turns = prompt[self.messages]
        input_ids = self.prompter.build_prompt(turns)
+        labels = [IGNORE_TOKEN_ID] * len(input_ids)

-        if not self.train_on_inputs:
-            user_prompt_len = len(prompt_ids)
-            labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
-        else:
-            labels = input_ids
-
-        tokenized_prompt = {
+        last_eos_idx = -1
+        for index, turn in enumerate(turns):
+            role = turn.get(self.prompter.message_field_role)
+            content = turn.get(self.prompter.message_field_content)
+            train_turn = turn.get(self.prompter.message_field_training)
+            train_detail = turn.get(self.prompter.message_field_training_detail)
+
+            LOG.debug(
+                f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
+            )
+
+            should_train = (
+                train_turn
+                if train_turn is not None
+                else bool(train_detail is not None)
+                if train_detail is not None
+                else self.train_on_inputs or role in self.roles_to_train
+            )
+
+            LOG.debug(f"Should train: {should_train}")
+
+            turn_start_idx, turn_end_idx = self.find_turn(
+                conversation_ids=input_ids, turn=index, turn_content=turn
+            )
+
+            LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
+
+            if should_train and turn_start_idx != -1 and turn_end_idx != -1:
+                if train_detail:
+                    token_offsets = self.prompter.get_offsets_for_train_detail(
+                        content, train_detail
+                    )
+                    LOG.debug(f"Token offsets: {token_offsets}")
+                    for i, offset in enumerate(token_offsets):
+                        if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(
+                            input_ids
+                        ):
+                            labels[turn_start_idx + i] = input_ids[turn_start_idx + i]
+                            LOG.debug(
+                                f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
+                            )
+                else:
+                    labels[turn_start_idx:turn_end_idx] = input_ids[
+                        turn_start_idx:turn_end_idx
+                    ]
+                    LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
+
+                LOG.debug(f"Labels after processing turn {index}: {labels}")
+
+            # Handle EOS token
+            eos_idx = self.find_eos_token(input_ids, turn_end_idx)
+            if eos_idx == turn_end_idx:
+                last_eos_idx = eos_idx
+                if self.train_on_eos == "all" or (
+                    self.train_on_eos == "turn" and should_train
+                ):
+                    labels[eos_idx] = input_ids[eos_idx]
+                    LOG.debug(f"EOS token set for training at index {eos_idx}")
+            else:
+                LOG.debug(
+                    f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
+                )
+
+        # Handle 'last' option for train_on_eos
+        if self.train_on_eos == "last" and last_eos_idx != -1:
+            labels[last_eos_idx] = input_ids[last_eos_idx]
+            LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
+
+        LOG.debug(f"Final labels: {labels}")
+
+        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": [1] * len(input_ids),
        }

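A hedged sketch of the kind of dataset row the rewritten tokenize_prompt expects; the field names follow the defaults ("conversations", "from", "value"), the content is invented.

sample = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "gpt", "value": "2 + 2 = 4."},
    ]
}
# With roles_to_train=["gpt", "assistant"] and train_on_eos="turn", only the tokens
# of the second turn (plus its EOS) keep their ids in `labels`; every other position
# stays IGNORE_TOKEN_ID, so the loss is computed on the assistant reply alone.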
-        return tokenized_prompt
+    def find_eos_token(self, input_ids, start_idx):
+        eos_token_id = self.tokenizer.eos_token_id
+        for i in range(start_idx, len(input_ids)):
+            if input_ids[i] == eos_token_id:
+                return i
+        return -1
+
+    def find_turn(self, conversation_ids, turn, turn_content):
+        """
+        Locate the starting and ending indices of the specified turn in a conversation.
+
+        Args:
+            conversation_ids (list[int]): Token IDs representing the conversation.
+            turn (int): The turn number to locate (based on EOS tokens).
+            turn_content (str): String containing the content of the turn.
+
+        Returns:
+            tuple: (start_idx, end_idx) indices of the start and end of the turn content.
+                   Returns (-1, -1) if the turn content is not found.
+        """
+        content = turn_content.get(self.prompter.message_field_content, "")
+        content_ids = self.tokenizer.encode(content, add_special_tokens=False)
+
+        eos_token_id = self.tokenizer.eos_token_id
+        eos_count = 0
+        start_search_idx = 0
+
+        # Locate the starting index after the specified number of EOS tokens
+        for i, token_id in enumerate(conversation_ids):
+            if token_id == eos_token_id:
+                eos_count += 1
+                if eos_count == turn:
+                    start_search_idx = (
+                        i + 1
+                    )  # Start searching after the specified turn's EOS token
+                    break
+
+        # Find the start index of the content within the conversation
+        start_idx = -1
+        for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
+            if conversation_ids[i : i + len(content_ids)] == content_ids:
+                start_idx = i
+                break
+
+        if start_idx != -1:
+            end_idx = start_idx + len(content_ids)
+        else:
+            end_idx = -1
+
+        return start_idx, end_idx
+
    def get_conversation_thread(self, prompt):
        return prompt[self.messages]

def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    chat_template = (
-        ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
-    )
-    message_field_role = (
-        ds_cfg["message_field_role"]
-        if ds_cfg and "message_field_role" in ds_cfg
-        else "from"
-    )
-    message_field_content = (
-        ds_cfg["message_field_content"]
-        if ds_cfg and "message_field_content" in ds_cfg
-        else "value"
-    )
-    roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
-    drop_system_message = (
-        ds_cfg["drop_system_message"]
-        if ds_cfg and "drop_system_message" in ds_cfg
-        else False
-    )
+    ds_cfg = ds_cfg or {}
+
+    prompter_params = {
+        "tokenizer": tokenizer,
+        "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
+        "message_field_role": ds_cfg.get("message_field_role", "from"),
+        "message_field_content": ds_cfg.get("message_field_content", "value"),
+        "message_field_training": ds_cfg.get("message_field_training", "training"),
+        "message_field_training_detail": ds_cfg.get(
+            "message_field_training_detail", "train_detail"
+        ),
+        "roles": ds_cfg.get("roles"),
+        "drop_system_message": ds_cfg.get("drop_system_message", False),
+        "max_length": cfg.sequence_len,
+    }
+
+    strategy_params = {
+        "train_on_inputs": cfg.train_on_inputs,
+        "sequence_len": cfg.sequence_len,
+        "roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]),
+        "train_on_eos": ds_cfg.get("train_on_eos", "turn"),
+    }

    strategy = ChatTemplateStrategy(
-        ChatTemplatePrompter(
-            tokenizer,
-            chat_templates(chat_template),
-            message_field_role=message_field_role,
-            message_field_content=message_field_content,
-            roles=roles,
-            drop_system_message=drop_system_message,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
+        ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
    )
-    if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
+
+    if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
        strategy.messages = ds_cfg["field_messages"]
+
    return strategy
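An illustrative ds_cfg for the refactored load() above; the keys mirror the ds_cfg.get(...) lookups in the diff, while the dataset path and values are placeholders, not from this compare.

ds_cfg = {
    "path": "my-org/my-sharegpt-data",  # hypothetical dataset
    "type": "chat_template",
    "chat_template": "chatml",
    "field_messages": "conversations",
    "message_field_role": "from",
    "message_field_content": "value",
    "message_field_training": "train",
    "message_field_training_detail": "train_detail",
    "roles_to_train": ["assistant"],
    "train_on_eos": "turn",
}
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)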
78  src/axolotl/prompt_strategies/dpo/chat_template.py  Normal file
@@ -0,0 +1,78 @@
"""
DPO prompt strategies for using tokenizer chat templates.
"""

from axolotl.utils.chat_templates import chat_templates


def default(
    cfg, dataset_idx=0, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    ds_cfg = cfg["datasets"][dataset_idx]
    chat_template_str = chat_templates(cfg.chat_template)

    field_messages = ds_cfg.get("field_messages", "messages")
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")
    field_message_role = ds_cfg.get("message_field_role", "role")
    field_message_content = ds_cfg.get("message_field_content", "content")
    role_map_inv = ds_cfg.get(
        "roles",
        {
            "user": ["user"],
            "assistant": ["assistant"],
            "system": ["system"],
        },
    )
    role_map = {}
    for target, sources in role_map_inv.items():
        for source in sources:
            role_map[source] = target

    def transform_fn(sample, tokenizer=None):
        messages = sample[field_messages]
        messages = [
            {
                "role": role_map[m[field_message_role]],
                "content": m[field_message_content],
            }
            for m in messages
        ]
        chosen = {
            "role": role_map[sample[field_chosen][field_message_role]],
            "content": sample[field_chosen][field_message_content],
        }
        rejected = {
            "role": role_map[sample[field_rejected][field_message_role]],
            "content": sample[field_rejected][field_message_content],
        }

        result = {}
        result["prompt"] = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            chat_template=chat_template_str,
            tokenize=False,
        )

        result["chosen"] = tokenizer.apply_chat_template(
            [chosen],
            add_generation_prompt=False,
            chat_template=chat_template_str,
            tokenize=False,
        )
        chosen_strip_index = result["chosen"].find(chosen["content"])
        result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()

        result["rejected"] = tokenizer.apply_chat_template(
            [rejected],
            add_generation_prompt=False,
            chat_template=chat_template_str,
            tokenize=False,
        )
        rejected_strip_index = result["rejected"].find(rejected["content"])
        result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()

        return result

    return transform_fn
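A hedged usage sketch for the new DPO strategy; the cfg/tokenizer objects and the sample below are placeholders, only default() and transform_fn come from the file.

transform_fn = default(cfg, dataset_idx=0)
sample = {
    "messages": [{"role": "user", "content": "Summarize the report."}],
    "chosen": {"role": "assistant", "content": "Here is a concise summary..."},
    "rejected": {"role": "assistant", "content": "I refuse."},
}
row = transform_fn(sample, tokenizer=tokenizer)
# row["prompt"] ends with the generation prompt; row["chosen"] / row["rejected"]
# are the rendered single-turn completions with the leading template text stripped.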
@@ -65,8 +65,10 @@ class AlpacaPrompter(Prompter):
            self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
        elif self.prompt_style == PromptStyle.PHI.value:
            self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
-            self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
-            self.system_format = "<|system|>{system}\n"
+            self.turn_no_input_format = (
+                "<|user|>\n{instruction}<|end|>\n<|assistant|>\n"
+            )
+            self.system_format = "<|system|>\n{system}<|end|>\n"

    def _build_result(self, instruction, input_text, output):
        # returns the full prompt from instruction and optional input
@@ -12,6 +12,7 @@ import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
+from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel
from pkg_resources import get_distribution  # type: ignore
@@ -19,6 +20,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from axolotl.common.cli import TrainerCliArgs
+from axolotl.core.tokenizer_utils import fix_untrained_tokens
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
@@ -52,6 +54,15 @@ class TrainDatasetMeta:
def train(
    *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
+    # enable expandable segments for cuda allocation to improve VRAM usage
+    torch_version = torch.__version__.split(".")
+    torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
+    if torch_major == 2 and torch_minor >= 2:
+        if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
+            os.environ[
+                "PYTORCH_CUDA_ALLOC_CONF"
+            ] = "expandable_segments:True,roundup_power2_divisions:16"
+
    # load the tokenizer first
    LOG.debug(
        f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
@@ -114,6 +125,13 @@ def train(
        total_num_steps,
    )

+    if cfg.fix_untrained_tokens:
+        fix_untrained_tokens(model, tokenizer, train_dataset)
+        if cfg.local_rank == 0:
+            model.save_pretrained(
+                str(Path(cfg.output_dir)), safe_serialization=safe_serialization
+            )
+
    # go ahead and presave, so we have the adapter config available to inspect
    if peft_config:
        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
@@ -177,9 +195,12 @@ def train(
        if hasattr(module, "_post_training"):
            module._post_training(model, name)  # pylint: disable=protected-access

+    state_dict_type = "FULL_STATE_DICT"
    if trainer.is_fsdp_enabled:
-        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
-        LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
+        if cfg.fsdp_final_state_dict_type:
+            state_dict_type = cfg.fsdp_final_state_dict_type
+        trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
+        LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")

    if cfg.relora_steps:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
@@ -191,30 +212,38 @@ def train(
    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
    if cfg.fsdp:
-        trainer.save_model(cfg.output_dir)
+        if (
+            state_dict_type == "SHARDED_STATE_DICT"
+            and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
+        ):
+            save_fsdp_model(
+                trainer.accelerator.state.fsdp_plugin,
+                trainer.accelerator,
+                trainer.model,
+                cfg.output_dir,
+            )
+        elif state_dict_type == "FULL_STATE_DICT":
+            trainer.save_model(cfg.output_dir)
    elif cfg.deepspeed and is_deepspeed_zero3_enabled():
        # Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
        trainer.accelerator.wait_for_everyone()
-        unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
+        trainer.save_model(cfg.output_dir)

        # the trainer saved a model.safetensors file in the output directory,
-        # but it is a proxy model and should be deleted
-        if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
+        # but it is most likely a proxy model and if so, should be deleted
+        maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
+        maybe_sharded = os.path.exists(
+            os.path.join(cfg.output_dir, "model.safetensors.index.json")
+        )
+
+        if maybe_proxy and maybe_sharded:
            LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
            LOG.info("This is a proxy model and should be deleted")
-            os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
+            try:
+                os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
+            except FileNotFoundError:
+                pass

-        # Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
-        # `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
-        # `zero3_save_16bit_model` is True in DeepSpeed Plugin.
-        # For Zero Stages 1 and 2, models are saved as usual in the output directory.
-        # The model name saved is `pytorch_model.bin`
-        unwrapped_model.save_pretrained(
-            cfg.output_dir,
-            is_main_process=trainer.accelerator.is_main_process,
-            save_function=trainer.accelerator.save,
-            state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
-        )
    elif cfg.local_rank == 0:
        if cfg.flash_optimum and BetterTransformer:
            model = BetterTransformer.reverse(model)
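An illustrative config fragment for the new FSDP save path above, expressed as a DictDefault; the config keys come from this compare, the wrap settings are placeholders.

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "fsdp": ["full_shard", "auto_wrap"],  # placeholder wrap policy
        "fsdp_config": {"fsdp_state_dict_type": "SHARDED_STATE_DICT"},
        # train() falls back to FULL_STATE_DICT when this is unset
        "fsdp_final_state_dict_type": "SHARDED_STATE_DICT",
    }
)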
File diff suppressed because one or more lines are too long
@@ -7,6 +7,7 @@ Module for pydantic models for configuration
import logging
import os
from enum import Enum
+from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from pydantic import BaseModel, Field, conlist, field_validator, model_validator
@@ -77,6 +78,7 @@ class PretrainingDataset(BaseModel):
    split: Optional[str] = "train"
    text_column: Optional[str] = "text"
    type: Optional[str] = "pretrain"
+    trust_remote_code: Optional[bool] = False


class UserDefinedPrompterType(BaseModel):
@@ -114,10 +116,16 @@ class SFTDataset(BaseModel):
    field_messages: Optional[str] = None
    message_field_role: Optional[str] = None
    message_field_content: Optional[str] = None
+    message_field_training: Optional[str] = None
+    message_field_training_detail: Optional[str] = None
+    roles_to_train: Optional[List[str]] = None
+    train_on_eos: Optional[str] = None
+
    roles: Optional[Dict[str, List[str]]] = None
    drop_system_message: Optional[bool] = None
+
+    trust_remote_code: Optional[bool] = False


class UserDefinedDPOType(BaseModel):
    """User defined typing for DPO"""
@@ -158,6 +166,7 @@ class KTODataset(BaseModel):
    split: Optional[str] = None
    type: Optional[Union[UserDefinedKTOType, str]] = None
    data_files: Optional[List[str]] = None
+    trust_remote_code: Optional[bool] = False


class RLType(str, Enum):
@@ -167,6 +176,7 @@ class RLType(str, Enum):
    ipo = "ipo"  # pylint: disable=invalid-name
    orpo = "orpo"  # pylint: disable=invalid-name
    kto = "kto"  # pylint: disable=invalid-name
+    simpo = "simpo"  # pylint: disable=invalid-name


class ChatTemplate(str, Enum):
@@ -179,6 +189,8 @@ class ChatTemplate(str, Enum):
    cohere = "cohere"  # pylint: disable=invalid-name
    llama3 = "llama3"  # pylint: disable=invalid-name
    phi_3 = "phi_3"  # pylint: disable=invalid-name
+    deepseek_v2 = "deepseek_v2"  # pylint: disable=invalid-name
+    jamba = "jamba"  # pylint: disable=invalid-name


class LoftQConfig(BaseModel):
@@ -225,6 +237,12 @@ class LoraConfig(BaseModel):
    peft_use_rslora: Optional[bool] = None
    peft_layer_replication: Optional[List[Tuple[int, int]]] = None
+
+    qlora_sharded_model_loading: Optional[bool] = Field(
+        default=False,
+        metadata={
+            "help": "load qlora model in sharded format for FSDP using answer.ai technique."
+        },
+    )
    lora_on_cpu: Optional[bool] = None
    gptq: Optional[bool] = None
    bnb_config_kwargs: Optional[Dict[str, Any]] = None
@@ -304,6 +322,8 @@ class ModelInputConfig(BaseModel):
    )
    trust_remote_code: Optional[bool] = None

+    model_kwargs: Optional[Dict[str, Any]] = None
+
    @field_validator("trust_remote_code")
    @classmethod
    def hint_trust_remote_code(cls, trust_remote_code):
@@ -341,7 +361,16 @@ class HyperparametersConfig(BaseModel):
    learning_rate: Union[str, float]
    weight_decay: Optional[float] = 0.0
    optimizer: Optional[
-        Union[OptimizerNames, Literal["lion_pytorch"]]
+        Union[
+            OptimizerNames,
+            Literal[
+                "lion_pytorch",
+                "optimi_adamw",
+                "ao_adamw_4bit",
+                "ao_adamw_8bit",
+                "ao_adamw_fp8",
+            ],
+        ]
    ] = OptimizerNames.ADAMW_HF.value
    optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
        default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -353,7 +382,7 @@ class HyperparametersConfig(BaseModel):
        },
    )
    torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[SchedulerType] = "cosine"
+    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
    lr_quadratic_warmup: Optional[bool] = None
    cosine_min_lr_ratio: Optional[float] = None
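An illustrative hyperparameter fragment exercising the widened optimizer and scheduler typing above; the values are placeholders, only the literal names come from the diff.

hyperparams = {
    "optimizer": "ao_adamw_8bit",  # one of the newly allowed literals
    "optim_args": {"weight_decay": 0.0},
    "lr_scheduler": "one_cycle",   # now accepted alongside transformers SchedulerType
    "learning_rate": 2e-4,
}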
@@ -504,6 +533,8 @@ class AxolotlInputConfig(
    dataloader_prefetch_factor: Optional[int] = None
    dataloader_drop_last: Optional[bool] = None
+
+    accelerator_config: Optional[Dict[str, Any]] = None
+
    remove_unused_columns: Optional[bool] = None

    push_dataset_to_hub: Optional[str] = None
@@ -586,14 +617,21 @@ class AxolotlInputConfig(
    flash_attn_fuse_mlp: Optional[bool] = None
    flash_optimum: Optional[bool] = None
+
+    eager_attention: Optional[bool] = None
+
    unsloth_cross_entropy_loss: Optional[bool] = None
    unsloth_lora_mlp: Optional[bool] = None
    unsloth_lora_qkv: Optional[bool] = None
    unsloth_lora_o: Optional[bool] = None
+    unsloth_rms_norm: Optional[bool] = None
+    unsloth_rope: Optional[bool] = None

    deepspeed: Optional[Union[str, Dict[str, Any]]] = None
    fsdp: Optional[List[str]] = None
    fsdp_config: Optional[Dict[str, Any]] = None
+    fsdp_final_state_dict_type: Optional[
+        Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
+    ] = None

    val_set_size: Optional[float] = Field(default=0.0)

@@ -602,6 +640,9 @@ class AxolotlInputConfig(

    torch_compile: Optional[bool] = None
    torch_compile_backend: Optional[str] = None
+    torch_compile_mode: Optional[
+        Literal["default", "reduce-overhead", "max-autotune"]
+    ] = None

    max_steps: Optional[int] = None
    warmup_steps: Optional[int] = None
@@ -623,6 +664,8 @@ class AxolotlInputConfig(

    orpo_alpha: Optional[float] = None
    rpo_alpha: Optional[float] = None
+    simpo_gamma: Optional[float] = None
+    cpo_alpha: Optional[float] = None

    kto_desirable_weight: Optional[float] = None
    kto_undesirable_weight: Optional[float] = None
@@ -637,6 +680,8 @@ class AxolotlInputConfig(
    chat_template: Optional[ChatTemplate] = None
    default_system_message: Optional[str] = None

+    fix_untrained_tokens: Optional[bool] = None
+
    # INTERNALS - document for now, generally not set externally
    is_preprocess: Optional[bool] = None

@@ -702,6 +747,24 @@ class AxolotlInputConfig(
        )
        return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_pretraining_split_batches_accelerate(cls, data):
+        # alternatively set ACCELERATE_SPLIT_BATCHES=False
+        if data.get("pretraining_dataset"):
+            accelerator_config = data.get("accelerator_config", {})
+            if not accelerator_config:
+                data["accelerator_config"] = {
+                    "split_batches": False,
+                    "dispatch_batches": False,
+                }
+            else:
+                if accelerator_config.get("split_batches") is None:
+                    data["accelerator_config"]["split_batches"] = False
+                if accelerator_config.get("dispatch_batches") is None:
+                    data["accelerator_config"]["dispatch_batches"] = False
+        return data

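For clarity, the explicit setting that check_pretraining_split_batches_accelerate injects when a pretraining dataset is configured; the dataset path is a placeholder.

cfg_fragment = {
    "pretraining_dataset": "my-org/streaming-corpus",  # hypothetical
    "accelerator_config": {
        "split_batches": False,
        "dispatch_batches": False,
    },
}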
    @model_validator(mode="before")
    @classmethod
    def check_gptq_w_revision(cls, data):
@@ -820,7 +883,7 @@ class AxolotlInputConfig(
    @model_validator(mode="after")
    def check_adamw_optimizer_params(self):
        if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
-            not self.optimizer or "adamw" not in self.optimizer.value
+            not self.optimizer or "adamw" not in str(self.optimizer).lower()
        ):
            LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
        return self
@@ -891,6 +954,8 @@ class AxolotlInputConfig(
    @model_validator(mode="before")
    @classmethod
    def check_eval_packing(cls, data):
+        # TODO also should check test_datasets and val_set_size as we can skip
+        # if there are no eval datasets/splits
        if (
            data.get("sample_packing")
            and data.get("eval_table_size")
@@ -1087,6 +1152,20 @@ class AxolotlInputConfig(
            )
        return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
+        if (
+            data.get("fsdp")
+            and data.get("save_safetensors")
+            and data.get("fsdp_config")
+            and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
+        ):
+            raise ValueError(
+                "FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
+            )
+        return data
+
    @model_validator(mode="before")
    @classmethod
    def check_causal_lm_evals(cls, data):
@@ -1112,6 +1191,55 @@ class AxolotlInputConfig(
            raise ValueError("either datasets or pretraining_dataset is required")
        return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_xentropy_patch_conflicts(cls, data):
+        if data.get("flash_attn_cross_entropy") and data.get(
+            "unsloth_cross_entropy_loss"
+        ):
+            raise ValueError(
+                "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
+            )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_qlora_unsloth(cls, data):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
+            if data.get("adapter") == "lora" or data.get("load_in_8bit"):
+                raise ValueError(
+                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
+                )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_unsloth_xformers_version(cls, data):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
+            xformers_version = version("xformers")
+            if xformers_version == "0.0.27":
+                raise ValueError(
+                    "xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
+                )
+        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_torch_compile_deepspeed(cls, data):
+        if data.get("deepspeed") and data.get("torch_compile"):
+            raise ValueError(
+                "torch_compile should be set within your deepspeed config file"
+            )
+        return data
+

class AxolotlConfigWCapabilities(AxolotlInputConfig):
    """wrapper to valdiate gpu capabilities with the configured options"""
@@ -1157,9 +1285,37 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):

        return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_hopper_8bit_lora(cls, data):
+        is_sm_90: bool = (
+            data["capabilities"]
+            and data["capabilities"].get("compute_capability") == "sm_90"
+        )
+        if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
+            # see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
+            raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
+
+        return data
+
    @model_validator(mode="before")
    @classmethod
    def check_fsdp_deepspeed(cls, data):
        if data.get("deepspeed") and data.get("fsdp"):
            raise ValueError("deepspeed and fsdp cannot be used together.")
        return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_multigpu_unsloth(cls, data):
+        if (
+            data.get("unsloth_lora_mlp")
+            or data.get("unsloth_lora_qkv")
+            or data.get("unsloth_lora_o")
+        ):
+            capabilities = data.get("capabilities")
+            if capabilities and capabilities.get("n_gpu", 0) > 1:
+                raise ValueError(
+                    "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
+                )
+        return data
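A hedged sketch of how a compute_capability string like "sm_90" could be derived for the capabilities dict that check_hopper_8bit_lora and check_multigpu_unsloth inspect; the torch calls are real, but whether axolotl populates the dict exactly this way is an assumption.

import torch

major, minor = torch.cuda.get_device_capability(0)
capabilities = {
    "compute_capability": f"sm_{major}{minor}",  # e.g. "sm_90" on an H100
    "n_gpu": torch.cuda.device_count(),
}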
@@ -18,10 +18,10 @@ LOG = logging.getLogger("axolotl")


def encode_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
) -> Dict[str, List]:
    res = tokenizer(
-        examples,
+        examples["text"],
        truncation=True,
        max_length=max_tokens - 2,
        add_special_tokens=True,
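The signature change above reflects how datasets passes data to a batched map: with batched=True the mapped function receives a dict of columns (e.g. {"text": [...]}) rather than a bare list of strings. A small sketch, assuming a preloaded tokenizer and a toy dataset:

from functools import partial
from datasets import Dataset

ds = Dataset.from_dict({"text": ["first document", "second document"]})  # placeholder data
encoded = ds.map(
    partial(encode_pretraining, tokenizer, 2048),  # binds tokenizer and max_tokens
    batched=True,
    remove_columns=["text"],
)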
@@ -1,4 +1,5 @@
"""data handling specific to DPO"""
+
import inspect
import logging
from functools import partial
@@ -42,7 +42,7 @@ from axolotl.prompters import (
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
    calculate_total_num_steps,
    process_datasets_for_packing,
@@ -54,7 +54,7 @@ LOG = logging.getLogger("axolotl")
def prepare_dataset(cfg, tokenizer):
    prompters = []
    if not cfg.pretraining_dataset:
-        with zero_first(is_main_process()):
+        with zero_first(is_local_main_process()):
            if cfg.test_datasets:
                train_dataset, _, prompters = load_prepare_datasets(
                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
@@ -160,8 +160,12 @@ def load_tokenized_prepared_datasets(
    use_auth_token = cfg.hf_use_auth_token
    try:
        if cfg.push_dataset_to_hub:
+            LOG.info(
+                f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
+            )
            dataset = load_dataset(
-                f"{cfg.push_dataset_to_hub}/{ds_hash}",
+                cfg.push_dataset_to_hub,
+                ds_hash,
                token=use_auth_token,
            )
            dataset = dataset[split]
@@ -170,6 +174,7 @@ def load_tokenized_prepared_datasets(

    # pylint: disable=duplicate-code
    if dataset:
+        # This is for the case where we already loaded a pretokenized dataset from the hub
        ...
    elif (
        cfg.dataset_prepared_path
@@ -180,7 +185,14 @@ def load_tokenized_prepared_datasets(
        dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")
    else:
-        LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
+        if cfg.push_dataset_to_hub:
+            LOG.info("Unable to find prepared dataset in Huggingface hub")
+        if cfg.is_preprocess:
+            LOG.info(
+                f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..."
+            )
+        else:
+            LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
        LOG.info("Loading raw datasets...")
        if not cfg.is_preprocess:
            LOG.warning(
@@ -198,6 +210,8 @@ def load_tokenized_prepared_datasets(
    def for_d_in_datasets(dataset_configs):
        for dataset in dataset_configs:
            if dataset.name and isinstance(dataset.name, list):
+                # load_dataset doesn't properly handle multiple named configurations
+                # at the same time for a given dataset
                for name in dataset.name:
                    yield DictDefault({**dataset, "name": name})
            else:
@@ -208,6 +222,8 @@ def load_tokenized_prepared_datasets(
    ds: Optional[Union[Dataset, DatasetDict]] = None
    ds_from_hub = False
    try:
+        # this is just a basic check to see if the path is a
+        # valid HF dataset that's loadable
        load_dataset(
            config_dataset.path,
            name=config_dataset.name,
@@ -428,10 +444,12 @@ def load_tokenized_prepared_datasets(
        dataset.save_to_disk(str(prepared_ds_path))
        if cfg.push_dataset_to_hub:
            LOG.info(
-                f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
+                f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
            )
            dataset.push_to_hub(
-                f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+                cfg.push_dataset_to_hub,
+                ds_hash,
+                private=True,
            )

    return dataset, prompters
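The hub naming change above switches from a repo-per-hash scheme to one repo with the hash as a dataset configuration name (the second positional argument of load_dataset and push_to_hub). A hedged sketch with placeholder values:

from datasets import load_dataset

repo_id = "my-org/prepared-data"               # hypothetical push_dataset_to_hub value
ds_hash = "d41d8cd98f00b204e9800998ecf8427e"   # placeholder dataset-config hash
# Before: load_dataset("my-org/prepared-data/<hash>"); after: one repo, many configs.
dataset = load_dataset(repo_id, ds_hash)["train"]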
@@ -44,6 +44,10 @@ def is_main_process():
    return dist.get_rank() == 0


+def is_local_main_process():
+    return PartialState().is_main_process
+
+
def get_world_size():
    return int(os.getenv("WORLD_SIZE", "1"))

@@ -149,11 +153,11 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
    if is_main_process():
        value_scalar = fn()
        value_tensor = torch.tensor(
-            value_scalar, device=torch.cuda.current_device()
-        ).float()
+            value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
+        )
    else:
        value_tensor = torch.tensor(
-            0.0, device=torch.cuda.current_device()
+            0.0, device=torch.cuda.current_device(), dtype=torch.float32
        )  # Placeholder tensor

    # Broadcast the tensor to all processes.
@@ -13,6 +13,7 @@ from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
+from transformers.quantizers import AutoHfQuantizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub


@@ -173,6 +174,7 @@ def load_sharded_model_quant(
    low_memory=True,
    verbose=False,
    loading_workers=2,
+    quantization_config=None,
):
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(
@@ -186,15 +188,26 @@ def load_sharded_model_quant(
            compute_dtype=compute_dtype,
            quant_type="nf4",
            quant_storage=quant_storage,
+            compress_statistics=True,  # bnb_4bit_use_double_quant
+            skip_modules=[
+                "lm_head",
+                "embed_out",
+            ],
        )
    else:
        # this is the more common case with HF transformers
+        # TODO can we detect the model arch and dynamically set skip_modules
        model.model = _replace_linear(
            model.model,
            Linear4bit,
            compute_dtype=compute_dtype,
            quant_type="nf4",
            quant_storage=quant_storage,
+            compress_statistics=True,  # bnb_4bit_use_double_quant
+            skip_modules=[
+                "lm_head",
+                "embed_out",
+            ],
        )
    model.is_loaded_in_4bit = True

@@ -251,6 +264,11 @@ def load_sharded_model_quant(
            quant_method=quant_method,
        )

+    # these attributes are needed to inform transformers/peft of the quantization
+    model.is_quantized = True
+    model.quantization_method = "bitsandbytes"
+    model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
+
    if cfg.local_rank == 0 and verbose:
        print(f"Loaded model weights in {time.time()-start:.3f} seconds")
    # cleanup any extra memory usage from parallel loading
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Module for models and model loading"""
|
"""Module for models and model loading"""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -29,6 +29,7 @@ from transformers import ( # noqa: F401
|
|||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
AwqConfig,
|
||||||
BitsAndBytesConfig,
|
BitsAndBytesConfig,
|
||||||
GPTQConfig,
|
GPTQConfig,
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
@@ -36,6 +37,7 @@ from transformers import ( # noqa: F401
|
|||||||
)
|
)
|
||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
|
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.monkeypatch.multipack import (
|
from axolotl.monkeypatch.multipack import (
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
@@ -94,12 +96,6 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
|
|||||||
"Please make sure to point to a GPTQ model."
|
"Please make sure to point to a GPTQ model."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not cfg.gptq and quant_config_exists:
|
|
||||||
raise ValueError(
|
|
||||||
"model_config.quantization_config is set but `gptq` flag is not. "
|
|
||||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
|
||||||
)
|
|
||||||
|
|
||||||
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
|
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
|
||||||
if (
|
if (
|
||||||
cfg.adapter
|
cfg.adapter
|
||||||
@@ -346,7 +342,36 @@ def load_model(
|
|||||||
and cfg.flash_attention
|
and cfg.flash_attention
|
||||||
and cfg.sample_packing
|
and cfg.sample_packing
|
||||||
):
|
):
|
||||||
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
patch_for_multipack(
|
||||||
|
cfg.model_config_type,
|
||||||
|
model_name=cfg.base_model,
|
||||||
|
is_remote_code=cfg.trust_remote_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.is_llama_derived_model:
|
||||||
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
|
patch_llama_cross_entropy,
|
||||||
|
patch_llama_rms_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.flash_attn_cross_entropy:
|
||||||
|
patch_llama_cross_entropy()
|
||||||
|
if cfg.flash_attn_rms_norm:
|
||||||
|
patch_llama_rms_norm()
|
||||||
|
elif cfg.unsloth_rms_norm:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||||
|
|
||||||
|
patch_unsloth_layernorm()
|
||||||
|
if cfg.unsloth_cross_entropy_loss:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import (
|
||||||
|
integrate_cross_entropy_loss_patch,
|
||||||
|
)
|
||||||
|
|
||||||
|
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||||
|
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||||
|
|
||||||
|
patch_self_attn_lora()
|
||||||
elif cfg.is_llama_derived_model:
|
elif cfg.is_llama_derived_model:
|
||||||
# Modify all llama derived models in one block
|
# Modify all llama derived models in one block
|
||||||
|
|
||||||
@@ -399,7 +424,7 @@ def load_model(
|
|||||||
if cfg.unsloth_cross_entropy_loss:
|
if cfg.unsloth_cross_entropy_loss:
|
||||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||||
|
|
||||||
integrate_cross_entropy_loss_patch()
|
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||||
|
|
||||||
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||||
@@ -407,23 +432,12 @@ def load_model(
|
|||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
# Modify mistral derived models
|
# Modify mistral derived models
|
||||||
if (
|
if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
|
||||||
cfg.model_config_type == "mistral"
|
|
||||||
and cfg.flash_attention
|
|
||||||
and cfg.sample_packing
|
|
||||||
):
|
|
||||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
replace_mistral_attn_with_flash_attn,
|
patch_mistral_cross_entropy,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("patching mistral with flash attention")
|
patch_mistral_cross_entropy()
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
|
||||||
|
|
||||||
LOG.info("patching _expand_mask")
|
|
||||||
hijack_expand_mask()
|
|
||||||
|
|
||||||
model_kwargs: Dict[str, Any] = {}
|
model_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
@@ -496,7 +510,25 @@ def load_model(
|
|||||||
model_kwargs["quantization_config"] = GPTQConfig(
|
model_kwargs["quantization_config"] = GPTQConfig(
|
||||||
**model_config.quantization_config
|
**model_config.quantization_config
|
||||||
)
|
)
|
||||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
if (
|
||||||
|
cfg.adapter in ["qlora", "lora"]
|
||||||
|
and hasattr(model_config, "quantization_config")
|
||||||
|
and model_config.quantization_config["quant_method"]
|
||||||
|
in ["gptq", "awq", "bitsandbytes"]
|
||||||
|
):
|
||||||
|
if model_config.quantization_config["quant_method"] == "gptq":
|
||||||
|
model_kwargs["quantization_config"] = GPTQConfig(
|
||||||
|
**model_config.quantization_config
|
||||||
|
)
|
||||||
|
elif model_config.quantization_config["quant_method"] == "awq":
|
||||||
|
model_kwargs["quantization_config"] = AwqConfig(
|
||||||
|
**model_config.quantization_config
|
||||||
|
)
|
||||||
|
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
|
||||||
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
**model_config.quantization_config
|
||||||
|
)
|
||||||
|
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"llm_int8_threshold": 6.0,
|
"llm_int8_threshold": 6.0,
|
||||||
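Note on the quantization dispatch above: the new branch keys off model_config.quantization_config["quant_method"] and wraps the dict in the matching transformers config class. A minimal sketch of that mapping, using a hypothetical quant_cfg dict standing in for a pre-quantized model's config.json entry (the field values below are illustrative assumptions, not taken from this diff):

from transformers import AwqConfig, BitsAndBytesConfig, GPTQConfig

# Hypothetical quantization_config as it might appear in a pre-quantized model's config.json.
quant_cfg = {"quant_method": "bitsandbytes", "load_in_4bit": True, "bnb_4bit_quant_type": "nf4"}

# Same dispatch as the hunk above, written as a lookup table.
dispatch = {"gptq": GPTQConfig, "awq": AwqConfig, "bitsandbytes": BitsAndBytesConfig}
config_cls = dispatch[quant_cfg["quant_method"]]
# Drop quant_method before constructing; the chosen class already implies it.
quantization_config = config_cls(**{k: v for k, v in quant_cfg.items() if k != "quant_method"})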
@@ -506,7 +538,9 @@ def load_model(
             "bnb_4bit_quant_type": "nf4",
             "bnb_4bit_quant_storage": torch.bfloat16,
         }
-        if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
+        if cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
+            cfg.deepspeed or cfg.fsdp
+        ):
             # for some reason, this causes the loss to be off by an order of magnitude
             # but deepspeed needs this still in bfloat16
             bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -551,16 +585,10 @@ def load_model(
                 "flash_attention_2"
             )
         else:
-            if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
-                model_kwargs["attn_implementation"] = "flash_attention_2"
-                model_config._attn_implementation = (  # pylint: disable=protected-access
-                    "flash_attention_2"
-                )
-            else:
-                model_kwargs["attn_implementation"] = "eager"
-                model_config._attn_implementation = (  # pylint: disable=protected-access
-                    "eager"
-                )
+            model_kwargs["attn_implementation"] = "flash_attention_2"
+            model_config._attn_implementation = (  # pylint: disable=protected-access
+                "flash_attention_2"
+            )
     elif cfg.sdp_attention:
         model_kwargs["attn_implementation"] = "sdpa"
         model_config._attn_implementation = "sdpa"  # pylint: disable=protected-access
@@ -590,14 +618,21 @@ def load_model(
     elif (
         qlora_fsdp
         and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
-        and cfg.model_config_type == "dbrx"
+        and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
     ):
         quant_storage = cfg.torch_dtype
+        quantization_config = hasattr(
+            model_config, "quantization_config"
+        ) and getattr(model_config, "quantization_config")
+        quantization_config = (
+            quantization_config or model_kwargs["quantization_config"]
+        )
         model = load_sharded_model_quant(
             base_model,
             model_config,
             cfg,
             quant_storage=quant_storage,
+            quantization_config=quantization_config,
         )
         skip_move_to_device = True
     elif (
@@ -605,7 +640,7 @@ def load_model(
         and not cfg.trust_remote_code
         and not cfg.gptq
     ):
-        if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+        if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
             skip_move_to_device = True
             if "device_map" in model_kwargs:
                 del model_kwargs["device_map"]
@@ -687,7 +722,7 @@ def load_model(
                 **model_kwargs,
             )
         else:
-            if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+            if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
                 # disabling either of these two still leads to VRAM spike before setting back down
                 skip_move_to_device = True
                 if "device_map" in model_kwargs:
@@ -771,12 +806,16 @@ def load_model(
                 set_z3_leaf_modules,
             )

-            if cfg.model_config_type == "mixtral":
-                moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
-                set_z3_leaf_modules(model, [moe_block])
-            elif cfg.model_config_type == "dbrx":
-                moe_block = get_module_class_from_name(model, "DbrxFFN")
-                set_z3_leaf_modules(model, [moe_block])
+            if cfg.model_config_type in MOE_ARCH_BLOCK:
+                moe_blocks = MOE_ARCH_BLOCK[cfg.model_config_type]
+                moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
+                set_z3_leaf_modules(
+                    model,
+                    [
+                        get_module_class_from_name(model, module_name)
+                        for module_name in moe_blocks
+                    ],
+                )

     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
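The MOE_ARCH_BLOCK lookup above replaces the hard-coded mixtral/dbrx branches when marking DeepSpeed ZeRO-3 leaf modules. Its actual contents live in axolotl.common.architectures and are not part of this diff; the sketch below only illustrates the shape the new code relies on (a value may be a single class name or a list), with assumed entries:

# Illustrative stand-in for axolotl.common.architectures.MOE_ARCH_BLOCK -- entries are assumptions.
MOE_ARCH_BLOCK = {
    "mixtral": "MixtralSparseMoeBlock",
    "dbrx": "DbrxFFN",
}

moe_blocks = MOE_ARCH_BLOCK["mixtral"]
# Normalize to a list exactly as the hunk above does before resolving module classes by name.
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks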
@@ -790,6 +829,9 @@ def load_model(
         # make sure everything is in the same dtype
         skip_prepare_model_for_kbit_training = True

+    if is_deepspeed_zero3_enabled():
+        skip_prepare_model_for_kbit_training = True
+
     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable(
@@ -824,6 +866,9 @@ def load_model(
     else:
         model, lora_config = load_adapter(model, cfg, cfg.adapter)

+    if is_deepspeed_zero3_enabled():
+        skip_move_to_device = True
+
     if (
         cfg.ddp
         and not load_in_8bit
@@ -863,6 +908,15 @@ def load_model(

         integrate_lora_patch(model, cfg)

+    if cfg.unsloth_rope:
+        from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
+
+        integrate_rope_embeddings()
+
+    for _ in range(3):
+        gc.collect()
+        torch.cuda.empty_cache()
+
     # TODO resume_from_checkpoint handling
     return model, lora_config

@@ -960,7 +1014,7 @@ def load_lora(model, cfg, inference=False, config_only=False):

     if cfg.lora_target_linear:
         linear_names = find_all_linear_names(model)
-        LOG.info(f"found linear modules: {repr(linear_names)}")
+        LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
         lora_target_modules = list(set(lora_target_modules + linear_names))

     lora_config_kwargs = {}
@@ -1036,9 +1090,20 @@ def load_lora(model, cfg, inference=False, config_only=False):

 def ensure_dtype(model, dtype=torch.bfloat16):
     for name, module in model.named_modules():
+        weight_mismatch = False
+        bias_mismatch = False
         try:
-            if module.weight.dtype != dtype:
-                print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
-                module.to(dtype)
+            weight_mismatch = module.weight.dtype != dtype
         except AttributeError:
             pass
+        try:
+            bias_mismatch = module.bias.dtype != dtype
+        except AttributeError:
+            pass
+
+        if weight_mismatch:
+            print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
+        if bias_mismatch:
+            print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
+        if weight_mismatch or bias_mismatch:
+            module.to(dtype)
@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
     """Helper function to process and color tokens."""
     colored_tokens = [
         color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
-        for token in tokenizer.encode(tokens)
+        for token in tokenizer.encode(tokens, add_special_tokens=False)
     ]
     return colored_tokens

@@ -1,4 +1,5 @@
 """Module containing the Trainer class and related functions"""
+import json
 import math
 import os
 import random
@@ -15,7 +16,7 @@ from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
-from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
+from axolotl.utils.distributed import reduce_and_broadcast
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

 LOG = get_logger("axolotl")
@@ -182,90 +183,88 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         sequence_len=cfg.sequence_len,
         min_sequence_len=cfg.min_sample_len or 2,
     )
-    with zero_first(is_main_process()):
-        if cfg.is_preprocess:
-            min_input_len = np.min(get_dataset_lengths(train_dataset))
-            LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-            max_input_len = np.max(get_dataset_lengths(train_dataset))
-            LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
-
-        if (
-            cfg.is_mistral_derived_model and cfg.flash_attention
-        ) or cfg.model_config_type == "mamba":
-            LOG.info("dropping attention_mask column")
-            train_dataset = train_dataset.remove_columns("attention_mask")
-            if eval_dataset:
-                eval_dataset = eval_dataset.remove_columns("attention_mask")
-
-        if cfg.model_config_type == "falcon":
-            LOG.info("dropping token_type_ids column if it exists")
-            if "token_type_ids" in train_dataset.column_names:
-                train_dataset = train_dataset.remove_columns("token_type_ids")
-            if eval_dataset and "token_type_ids" in eval_dataset.column_names:
-                eval_dataset = eval_dataset.remove_columns("token_type_ids")
-
-        train_dataset = train_dataset.filter(
-            drop_long,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Dropping Long Sequences",
-        )
-        if eval_dataset:
-            eval_dataset = eval_dataset.filter(
-                drop_long,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Dropping Long Sequences",
-            )
+    if cfg.is_preprocess:
+        min_input_len = np.min(get_dataset_lengths(train_dataset))
+        LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
+        max_input_len = np.max(get_dataset_lengths(train_dataset))
+        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+
+    if cfg.model_config_type == "mamba":
+        LOG.info("dropping attention_mask column")
+        train_dataset = train_dataset.remove_columns("attention_mask")
+        if eval_dataset:
+            eval_dataset = eval_dataset.remove_columns("attention_mask")
+
+    if cfg.model_config_type == "falcon":
+        LOG.info("dropping token_type_ids column if it exists")
+        if "token_type_ids" in train_dataset.column_names:
+            train_dataset = train_dataset.remove_columns("token_type_ids")
+        if eval_dataset and "token_type_ids" in eval_dataset.column_names:
+            eval_dataset = eval_dataset.remove_columns("token_type_ids")
+
+    train_dataset = train_dataset.filter(
+        drop_long,
+        num_proc=cfg.dataset_processes,
+        load_from_cache_file=not cfg.is_preprocess,
+        desc="Dropping Long Sequences",
+    )
+    if eval_dataset:
+        eval_dataset = eval_dataset.filter(
+            drop_long,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+            desc="Dropping Long Sequences",
+        )

     if cfg.group_by_length:
         train_dataset = train_dataset.map(
             add_length,
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Group By Length",
         )

     if cfg.use_pose:
         pose_kwargs = {}
         if cfg.pose_num_chunks is not None:
             pose_kwargs["chunks"] = cfg.pose_num_chunks
         pose_fn = partial(
             add_pose_position_ids,
             max_context_len=cfg.pose_max_context_len,
             split_on_token_ids=cfg.pose_split_on_token_ids,
             **pose_kwargs,
         )
         train_dataset = train_dataset.map(
             pose_fn,
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (PoSE)",
         )
         train_dataset = train_dataset.sort("sequence_len")
         if cfg.eval_sample_packing is not False:
             if eval_dataset:
                 eval_dataset = eval_dataset.map(
                     pose_fn,
                     num_proc=cfg.dataset_processes,
                     load_from_cache_file=not cfg.is_preprocess,
                     desc="Add position_id column (PoSE)",
                 )
     elif cfg.sample_packing:
         train_dataset = train_dataset.map(
             add_position_ids,
             num_proc=cfg.dataset_processes,
             load_from_cache_file=not cfg.is_preprocess,
             desc="Add position_id column (Sample Packing)",
         )
         if cfg.eval_sample_packing is not False:
             if eval_dataset:
                 eval_dataset = eval_dataset.map(
                     add_position_ids,
                     num_proc=cfg.dataset_processes,
                     load_from_cache_file=not cfg.is_preprocess,
                     desc="Add position_id column (Sample Packing)",
                 )

     return train_dataset, eval_dataset

@@ -391,6 +390,26 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
     return total_num_steps


+def setup_torch_compile_env(cfg):
+    if cfg.torch_compile:
+        if not cfg.torch_compile_backend:
+            os.environ["ACCELERATE_DYNAMO_BACKEND"] = "INDUCTOR"
+        else:
+            os.environ["ACCELERATE_DYNAMO_BACKEND"] = cfg.torch_compile_backend.upper()
+
+
+def setup_deepspeed_env(cfg, stage=None):
+    from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+    os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+    os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
+    if stage:
+        os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
+        if stage == 3:
+            os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
+    HfTrainerDeepSpeedConfig(cfg.deepspeed)
+
+
 def setup_fsdp_envs(cfg):
     os.environ["ACCELERATE_USE_FSDP"] = "true"
     if cfg.fsdp_config.fsdp_activation_checkpointing:
@@ -417,8 +436,16 @@ def prepare_optim_env(cfg):
     if cfg.fsdp:
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed:
-        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
-        os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
+        stage = None
+        # check if the cfg.deepspeed is a file
+        if os.path.isfile(cfg.deepspeed):
+            # parse with json
+            with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
+                deepspeed_config = json.load(fin)
+            stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
+        setup_deepspeed_env(cfg, stage=stage)
+
+    setup_torch_compile_env(cfg)

     if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
         os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
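prepare_optim_env now reads the ZeRO stage out of the DeepSpeed JSON before exporting the accelerate environment variables. A small sketch of that lookup against an assumed in-memory config (a real cfg.deepspeed points at a JSON file with many more fields):

import json

# Assumed example of what json.load(fin) might return for a ZeRO-3 config.
deepspeed_config = json.loads('{"zero_optimization": {"stage": 3}, "bf16": {"enabled": "auto"}}')

stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
# stage == 3 here, so ACCELERATE_DEEPSPEED_ZERO_STAGE and the ZeRO-3 init flag would be set.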
@@ -426,8 +453,14 @@ def prepare_optim_env(cfg):
         os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"


+def prepare_opinionated_env(cfg):
+    if cfg.qlora_sharded_model_loading:
+        # model loading is forked after the tokenizer
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
+    if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
tests/e2e/multigpu/__init__.py (new file, 0 lines)
tests/e2e/multigpu/test_llama.py (new file, 341 lines)
@@ -0,0 +1,341 @@
+"""
+E2E tests for multigpu lora tinyllama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+import pytest
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMultiGPULlama(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_lora_ddp(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "TinyLlama/TinyLlama_v1.1",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 2048,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.05,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 100,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
+
+    @with_temp_dir
+    def test_lora_ddp_packed(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "TinyLlama/TinyLlama_v1.1",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 2048,
+                "sample_packing": True,
+                "eval_sample_packing": False,
+                "pad_to_sequence_len": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.05,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 50,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
+
+    @with_temp_dir
+    def test_fsdp(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "TinyLlama/TinyLlama_v1.1",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 2048,
+                "val_set_size": 0.05,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 100,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "fsdp": [
+                    "full_shard",
+                    "auto_wrap",
+                ],
+                "fsdp_config": {
+                    "fsdp_limit_all_gathers": True,
+                    "fsdp_offload_params": False,
+                    "fsdp_sync_module_states": True,
+                    "fsdp_use_orig_params": False,
+                    "fsdp_cpu_ram_efficient_loading": False,
+                    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
+                    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
+                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                },
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
+
+    @with_temp_dir
+    def test_fsdp_packed(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "TinyLlama/TinyLlama_v1.1",
+                "tokenizer_type": "LlamaTokenizer",
+                "sample_packing": True,
+                "eval_sample_packing": False,
+                "pad_to_sequence_len": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.05,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 100,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "fsdp": [
+                    "full_shard",
+                    "auto_wrap",
+                ],
+                "fsdp_config": {
+                    "fsdp_limit_all_gathers": True,
+                    "fsdp_offload_params": False,
+                    "fsdp_sync_module_states": True,
+                    "fsdp_use_orig_params": False,
+                    "fsdp_cpu_ram_efficient_loading": False,
+                    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
+                    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
+                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                },
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
+
+    @pytest.mark.skip("disabled due to upstream issue")
+    @with_temp_dir
+    def test_fsdp_qlora_prequant_packed(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
+                "tokenizer_type": "AutoTokenizer",
+                "adapter": "qlora",
+                "load_in_4bit": True,
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "lora_modules_to_save": [
+                    "embed_tokens",
+                    "lm_head",
+                ],
+                "sample_packing": True,
+                "eval_sample_packing": False,
+                "pad_to_sequence_len": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.05,
+                "special_tokens": {
+                    "pad_token": "<|end_of_text|>",
+                },
+                "datasets": [
+                    {
+                        "path": "tatsu-lab/alpaca",
+                        "type": "alpaca",
+                        "split": "train[:25%]",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 100,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "fsdp": [
+                    "full_shard",
+                    "auto_wrap",
+                ],
+                "fsdp_config": {
+                    "fsdp_limit_all_gathers": True,
+                    "fsdp_offload_params": False,
+                    "fsdp_sync_module_states": True,
+                    "fsdp_use_orig_params": False,
+                    "fsdp_cpu_ram_efficient_loading": True,
+                    "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
+                    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
+                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                },
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
tests/e2e/multigpu/test_qwen2.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+"""
+E2E tests for multigpu qwen2
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+
+from axolotl.utils.dict import DictDefault
+
+from ..utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMultiGPUQwen2(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_qlora_fsdp_dpo(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "Qwen/Qwen2-1.5B",
+                "load_in_4bit": True,
+                "rl": "dpo",
+                "chat_template": "chatml",
+                "sequence_len": 2048,
+                "adapter": "qlora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.05,
+                "datasets": [
+                    {
+                        "path": "Intel/orca_dpo_pairs",
+                        "split": "train",
+                        "type": "chatml.intel",
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 100,
+                "warmup_steps": 20,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "bf16": "auto",
+                "tf32": True,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {
+                    "use_reentrant": False,
+                },
+                "fsdp": [
+                    "full_shard",
+                    "auto_wrap",
+                ],
+                "fsdp_config": {
+                    "fsdp_limit_all_gathers": True,
+                    "fsdp_offload_params": False,
+                    "fsdp_sync_module_states": True,
+                    "fsdp_use_orig_params": False,
+                    "fsdp_cpu_ram_efficient_loading": False,
+                    "fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
+                    "fsdp_state_dict_type": "FULL_STATE_DICT",
+                    "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
+                    "fsdp_sharding_strategy": "FULL_SHARD",
+                },
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
@@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu

 import unittest

+import transformers
+
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase):
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         tokenizer = load_tokenizer(cfg)
-        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+        load_model(cfg, tokenizer, inference=cli_args.inference)

         assert (
-            "axolotl.monkeypatch.mistral_attn_hijack_flash"
-            in model.model.layers[0].self_attn.forward.__module__
+            "torch.jit"
+            in transformers.modeling_flash_attention_utils._get_unpad_data.__module__  # pylint: disable=protected-access
         )
tests/e2e/test_imports.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+"""
+test module to import various submodules that have historically broken due to dependency issues
+"""
+import unittest
+
+
+class TestImports(unittest.TestCase):
+    """
+    Test class to import various submodules that have historically broken due to dependency issues
+    """
+
+    def test_import_causal_trainer(self):
+        from axolotl.core.trainer_builder import (  # pylint: disable=unused-import # noqa: F401
+            HFCausalTrainerBuilder,
+        )
+
+    def test_import_rl_trainer(self):
+        from axolotl.core.trainer_builder import (  # pylint: disable=unused-import # noqa: F401
+            HFRLTrainerBuilder,
+        )
tests/e2e/test_llama_pretrain.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+"""
+E2E tests for llama pretrain
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestPretrainLlama(unittest.TestCase):
+    """
+    Test case for Llama models w pretraining
+    """
+
+    @with_temp_dir
+    def test_pretrain_w_sample_packing(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "pretraining_dataset": [
+                    {
+                        "path": "allenai/c4",
+                        "name": "en",
+                        "type": "pretrain",
+                    }
+                ],
+                "max_steps": 5,
+                "num_epochs": 1,
+                "micro_batch_size": 1,
+                "gradient_accumulation_steps": 1,
+                "val_set_size": 0.0,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "save_safetensors": True,
+                "bf16": "auto",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()
@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",
-                "lora_r": 32,
-                "lora_alpha": 64,
+                "lora_r": 8,
+                "lora_alpha": 16,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
                 "val_set_size": 0.1,
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
                         "type": "alpaca",
                     },
                 ],
-                "num_epochs": 2,
+                "num_epochs": 1,
                 "micro_batch_size": 8,
                 "gradient_accumulation_steps": 1,
                 "output_dir": temp_dir,
tests/e2e/test_optimizers.py (new file, 67 lines)
@@ -0,0 +1,67 @@
+"""
+E2E tests for custom optimizers using Llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestCustomOptimizers(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_optimi_adamw(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "optimi_adamw",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
@@ -2,6 +2,7 @@
 tests for chat_template prompt strategy
 """

+import logging
 import unittest

 import pytest
@@ -13,33 +14,24 @@ from axolotl.prompt_strategies.chat_template import (
     ChatTemplateStrategy,
     load,
 )
+from axolotl.prompters import IGNORE_TOKEN_ID
 from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault

+logging.basicConfig(level=logging.DEBUG)
+LOG = logging.getLogger("axolotl")
+
+
 @pytest.fixture(name="assistant_dataset")
 def fixture_assistant_dataset():
-    # pylint: disable=duplicate-code
     return Dataset.from_list(
         [
             {
                 "messages": [
-                    {
-                        "role": "user",
-                        "content": "hello",
-                    },
-                    {
-                        "role": "assistant",
-                        "content": "hello",
-                    },
-                    {
-                        "role": "user",
-                        "content": "goodbye",
-                    },
-                    {
-                        "role": "assistant",
-                        "content": "goodbye",
-                    },
+                    {"role": "user", "content": "hello"},
+                    {"role": "assistant", "content": "hello"},
+                    {"role": "user", "content": "goodbye"},
+                    {"role": "assistant", "content": "goodbye"},
                 ]
             }
         ]
@@ -53,22 +45,28 @@ def fixture_sharegpt_dataset():
         [
             {
                 "conversations": [
-                    {
-                        "from": "human",
-                        "value": "hello",
-                    },
-                    {
-                        "from": "gpt",
-                        "value": "hello",
-                    },
-                    {
-                        "from": "human",
-                        "value": "goodbye",
-                    },
-                    {
-                        "from": "gpt",
-                        "value": "goodbye",
-                    },
+                    {"from": "human", "value": "hello"},
+                    {"from": "gpt", "value": "hello"},
+                    {"from": "human", "value": "goodbye"},
+                    {"from": "gpt", "value": "goodbye"},
+                ]
+            }
+        ]
+    )
+
+
+@pytest.fixture(name="basic_dataset")
+def fixture_basic_dataset():
+    # pylint: disable=duplicate-code
+    return Dataset.from_list(
+        [
+            {
+                "conversations": [
+                    {"from": "system", "value": "You are an AI assistant."},
+                    {"from": "human", "value": "Hello"},
+                    {"from": "assistant", "value": "Hi there!"},
+                    {"from": "human", "value": "How are you?"},
+                    {"from": "assistant", "value": "I'm doing well, thank you!"},
                 ]
             }
         ]
@@ -77,19 +75,611 @@ def fixture_sharegpt_dataset():

 @pytest.fixture(name="llama3_tokenizer")
 def fixture_llama3_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
-    tokenizer.eos_token = "<|eot_id|>"
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")

     return tokenizer


+class TestChatTemplateConfigurations:
+    """
+    Test class for various configurations of ChatTemplateStrategy.
+    """
+
+    @staticmethod
+    def find_sublist(full_list, sub_list):
+        token_count = len(sub_list)
+        for index in range(len(full_list) - token_count + 1):
+            if full_list[index : index + token_count] == sub_list:
+                return index
+        return -1
+
+    def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset):
+        LOG.info("Testing with train_on_inputs=True")
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            tokenizer=llama3_tokenizer,
+            train_on_inputs=True,
+            sequence_len=512,
+            roles_to_train=["assistant"],
+        )
+        res = strategy.tokenize_prompt(basic_dataset[0])
+        labels = res["labels"]
+        input_ids = res["input_ids"]
+
+        # Verify that assistant responses are labeled
+        assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+        for response in assistant_responses:
+            response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+            start_idx = self.find_sublist(input_ids, response_ids)
+            LOG.debug(
+                f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+            )
+            assert start_idx != -1, f"Could not find '{response}' in input_ids"
+            assert all(
+                label != IGNORE_TOKEN_ID
+                for label in labels[start_idx : start_idx + len(response_ids)]
+            ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+        # Check the behavior of human inputs
+        human_inputs = ["Hello", "How are you?"]
+        for input_text in human_inputs:
+            input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
+            start_idx = self.find_sublist(input_ids, input_ids)
+            labeled = all(
+                label != IGNORE_TOKEN_ID
+                for label in labels[start_idx : start_idx + len(input_ids)]
+            )
+            LOG.debug(
+                f"Human input '{input_text}' is {'labeled' if labeled else 'not labeled'}, expected IDs: {input_ids}, found at: {start_idx}"
+            )
+
+        LOG.debug("Full labels: %s", labels)
+        LOG.debug("Full input_ids: %s", input_ids)
+
+    def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset):
+        LOG.info("Testing with train_on_inputs=False")
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            tokenizer=llama3_tokenizer,
+            train_on_inputs=False,
+            sequence_len=512,
+            roles_to_train=["assistant"],
+        )
+        res = strategy.tokenize_prompt(basic_dataset[0])
+        labels = res["labels"]
+        input_ids = res["input_ids"]
+
+        # Verify that only assistant responses are labeled
+        assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+        for response in assistant_responses:
+            response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+            start_idx = self.find_sublist(input_ids, response_ids)
+            LOG.debug(
+                f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+            )
+            assert start_idx != -1, f"Could not find '{response}' in input_ids"
+            assert all(
+                label != IGNORE_TOKEN_ID
+                for label in labels[start_idx : start_idx + len(response_ids)]
+            ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+        # Verify that human inputs are not labeled
+        human_inputs = ["Hello", "How are you?"]
+        for input_text in human_inputs:
+            input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
+            start_idx = self.find_sublist(input_ids, input_ids)
+            LOG.debug(
+                f"Human input '{input_text}' expected IDs: {input_ids}, found at: {start_idx}"
+            )
+            assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
+            assert all(
+                label == IGNORE_TOKEN_ID
+                for label in labels[start_idx : start_idx + len(input_ids)]
+            ), f"Expected labels for human input '{input_text}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:start_idx+len(input_ids)]}"
+
+    def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset):
+        LOG.info("Testing roles_to_train with assistant only")
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            tokenizer=llama3_tokenizer,
+            train_on_inputs=False,
+            sequence_len=512,
+            roles_to_train=["assistant"],
+        )
+        res = strategy.tokenize_prompt(basic_dataset[0])
+        labels = res["labels"]
+        input_ids = res["input_ids"]
+
+        # Verify that only assistant responses are labeled
+        assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
+        for response in assistant_responses:
+            response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
+            start_idx = self.find_sublist(input_ids, response_ids)
+            LOG.debug(
+                f"Assistant response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
+            )
+            assert all(
+                label != IGNORE_TOKEN_ID
+                for label in labels[start_idx : start_idx + len(response_ids)]
+            ), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
+
+    def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset):
+        LOG.info("Testing roles_to_train with all roles")
+        strategy = ChatTemplateStrategy(
+            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
+            tokenizer=llama3_tokenizer,
+            train_on_inputs=True,
+            sequence_len=512,
+            roles_to_train=["human", "assistant"],
+        )
+        res = strategy.tokenize_prompt(basic_dataset[0])
+        labels = res["labels"]
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Verify that all responses are labeled (except for special tokens)
|
||||||
|
all_responses = [
|
||||||
|
"Hello",
|
||||||
|
"Hi there!",
|
||||||
|
"How are you?",
|
||||||
|
"I'm doing well, thank you!",
|
||||||
|
]
|
||||||
|
for response in all_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
LOG.debug(
|
||||||
|
f"Response '{response}' expected IDs: {response_ids}, found at: {start_idx}"
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:start_idx+len(response_ids)]}"
|
||||||
|
|
||||||
|
def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with empty roles_to_train")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=[],
|
||||||
|
train_on_eos="none", # Add this line
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
|
||||||
|
# Verify that no labels are set when roles_to_train is empty
|
||||||
|
LOG.debug("Full labels: %s", labels)
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID for label in labels
|
||||||
|
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
|
||||||
|
|
||||||
|
def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='all'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="all",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||||
|
for eos_idx in eos_indices:
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token at index {eos_idx} to be labeled"
|
||||||
|
|
||||||
|
def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='turn'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="turn",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
assistant_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
|
||||||
|
for response in assistant_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
assert start_idx != -1, f"Could not find '{response}' in input_ids"
|
||||||
|
|
||||||
|
eos_idx = start_idx + len(response_ids)
|
||||||
|
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||||
|
eos_idx += 1
|
||||||
|
|
||||||
|
assert eos_idx < len(
|
||||||
|
input_ids
|
||||||
|
), f"Could not find EOS token after '{response}'"
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token after assistant response '{response}' to be labeled"
|
||||||
|
|
||||||
|
# Check that EOS tokens after human inputs are not labeled
|
||||||
|
human_inputs = ["Hello", "How are you?"]
|
||||||
|
for input_text in human_inputs:
|
||||||
|
input_ids = llama3_tokenizer.encode(input_text, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, input_ids)
|
||||||
|
assert start_idx != -1, f"Could not find '{input_text}' in input_ids"
|
||||||
|
|
||||||
|
eos_idx = start_idx + len(input_ids)
|
||||||
|
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||||
|
eos_idx += 1
|
||||||
|
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token after human input '{input_text}' to not be labeled"
|
||||||
|
|
||||||
|
def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='last'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="last",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||||
|
last_eos_idx = eos_indices[-1]
|
||||||
|
|
||||||
|
# Check that only the last EOS token is labeled
|
||||||
|
for idx in eos_indices[:-1]:
|
||||||
|
assert (
|
||||||
|
labels[idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token at index {idx} to not be labeled"
|
||||||
|
assert (
|
||||||
|
labels[last_eos_idx] != IGNORE_TOKEN_ID
|
||||||
|
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
|
||||||
|
|
||||||
|
def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with train_on_eos='none'")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
train_on_eos="none",
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
eos_token_id = llama3_tokenizer.eos_token_id
|
||||||
|
eos_indices = [
|
||||||
|
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||||
|
for eos_idx in eos_indices:
|
||||||
|
assert (
|
||||||
|
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||||
|
), f"Expected EOS token at index {eos_idx} to not be labeled"
|
||||||
|
|
||||||
|
def test_drop_system_message(self, llama3_tokenizer, basic_dataset):
|
||||||
|
LOG.info("Testing with drop_system_message=True")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_templates("llama3"), drop_system_message=True
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["assistant"],
|
||||||
|
)
|
||||||
|
res = strategy.tokenize_prompt(basic_dataset[0])
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Check if system message is not present in input_ids
|
||||||
|
system_message = "You are an AI assistant."
|
||||||
|
system_ids = llama3_tokenizer.encode(system_message, add_special_tokens=False)
|
||||||
|
assert (
|
||||||
|
self.find_sublist(input_ids, system_ids) == -1
|
||||||
|
), "Expected system message to be dropped"
|
||||||
|
|
||||||
|
def test_custom_roles(self, llama3_tokenizer):
|
||||||
|
LOG.info("Testing with custom roles mapping")
|
||||||
|
custom_roles = {
|
||||||
|
"user": ["human", "user"],
|
||||||
|
"assistant": ["ai", "assistant"],
|
||||||
|
"system": ["context"],
|
||||||
|
}
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer, chat_templates("llama3"), roles=custom_roles
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=["ai"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new dataset with modified role names
|
||||||
|
modified_conversations = [
|
||||||
|
{"from": "context", "value": "You are an AI assistant."},
|
||||||
|
{"from": "human", "value": "Hello"},
|
||||||
|
{"from": "ai", "value": "Hi there!"},
|
||||||
|
{"from": "human", "value": "How are you?"},
|
||||||
|
{"from": "ai", "value": "I'm doing well, thank you!"},
|
||||||
|
]
|
||||||
|
|
||||||
|
modified_dataset = Dataset.from_dict(
|
||||||
|
{"conversations": [modified_conversations]}
|
||||||
|
)
|
||||||
|
|
||||||
|
res = strategy.tokenize_prompt(modified_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Check if AI responses are labeled correctly
|
||||||
|
ai_responses = ["Hi there!", "I'm doing well, thank you!"]
|
||||||
|
for response in ai_responses:
|
||||||
|
response_ids = llama3_tokenizer.encode(response, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, response_ids)
|
||||||
|
assert start_idx != -1, f"Could not find response '{response}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(response_ids)]
|
||||||
|
), f"Expected labels for AI response '{response}' to be set"
|
||||||
|
|
||||||
|
# Check if human messages are not labeled
|
||||||
|
human_messages = ["Hello", "How are you?"]
|
||||||
|
for message in human_messages:
|
||||||
|
message_ids = llama3_tokenizer.encode(message, add_special_tokens=False)
|
||||||
|
start_idx = self.find_sublist(input_ids, message_ids)
|
||||||
|
assert start_idx != -1, f"Could not find message '{message}' in input_ids"
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID
|
||||||
|
for label in labels[start_idx : start_idx + len(message_ids)]
|
||||||
|
), f"Expected labels for human message '{message}' to be IGNORE_TOKEN_ID"
|
||||||
|
|
||||||
|
def test_message_field_training(self, llama3_tokenizer):
|
||||||
|
LOG.info("Testing with message_field_training")
|
||||||
|
strategy = ChatTemplateStrategy(
|
||||||
|
ChatTemplatePrompter(
|
||||||
|
llama3_tokenizer,
|
||||||
|
chat_templates("llama3"),
|
||||||
|
message_field_training="train",
|
||||||
|
message_field_training_detail="train_detail",
|
||||||
|
),
|
||||||
|
tokenizer=llama3_tokenizer,
|
||||||
|
train_on_inputs=False,
|
||||||
|
sequence_len=512,
|
||||||
|
roles_to_train=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a new dataset with the train and train_detail fields
|
||||||
|
modified_conversation = [
|
||||||
|
{"from": "system", "value": "You are an AI assistant.", "train": False},
|
||||||
|
{"from": "human", "value": "Hello", "train": False},
|
||||||
|
{"from": "assistant", "value": "Hello", "train": True},
|
||||||
|
{"from": "human", "value": "How are you?", "train": True},
|
||||||
|
{
|
||||||
|
"from": "assistant",
|
||||||
|
"value": "I'm doing very well, thank you!",
|
||||||
|
"train_detail": [
|
||||||
|
{"begin_offset": 0, "end_offset": 8, "train": False},
|
||||||
|
{"begin_offset": 9, "end_offset": 18, "train": True},
|
||||||
|
{"begin_offset": 19, "end_offset": 30, "train": False},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"from": "human",
|
||||||
|
"value": "I'm doing very well, thank you!",
|
||||||
|
"train": False,
|
||||||
|
},
|
||||||
|
{"from": "assistant", "value": "Hi there!", "train": True},
|
||||||
|
]
|
||||||
|
|
||||||
|
modified_dataset = Dataset.from_dict({"conversations": [modified_conversation]})
|
||||||
|
|
||||||
|
res = strategy.tokenize_prompt(modified_dataset[0])
|
||||||
|
labels = res["labels"]
|
||||||
|
input_ids = res["input_ids"]
|
||||||
|
|
||||||
|
# Function to find all occurrences of a sublist
|
||||||
|
def find_all_sublists(full_list, sub_list):
|
||||||
|
indices = []
|
||||||
|
for index in range(len(full_list) - len(sub_list) + 1):
|
||||||
|
if full_list[index : index + len(sub_list)] == sub_list:
|
||||||
|
indices.append(index)
|
||||||
|
return indices
|
||||||
|
|
||||||
|
# Keep track of which occurrences we've processed
|
||||||
|
processed_occurrences = {}
|
||||||
|
# Check if messages are labeled correctly based on train or train_detail
|
||||||
|
for i, turn in enumerate(modified_conversation):
|
||||||
|
turn_tokens = llama3_tokenizer.encode(
|
||||||
|
turn["value"], add_special_tokens=False
|
||||||
|
)
|
||||||
|
occurrences = find_all_sublists(input_ids, turn_tokens)
|
||||||
|
turn_key = turn["value"]
|
||||||
|
if turn_key not in processed_occurrences:
|
||||||
|
processed_occurrences[turn_key] = 0
|
||||||
|
current_occurrence = processed_occurrences[turn_key]
|
||||||
|
|
||||||
|
if current_occurrence >= len(occurrences):
|
||||||
|
assert (
|
||||||
|
False
|
||||||
|
), f"Not enough occurrences found for message: {turn['value']}"
|
||||||
|
|
||||||
|
start_idx = occurrences[current_occurrence]
|
||||||
|
processed_occurrences[turn_key] += 1
|
||||||
|
end_idx = start_idx + len(turn_tokens)
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"Processing turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "train_detail" in turn:
|
||||||
|
# Get token offsets
|
||||||
|
tokenized_output = llama3_tokenizer(
|
||||||
|
turn["value"], return_offsets_mapping=True, add_special_tokens=False
|
||||||
|
)
|
||||||
|
token_offsets = tokenized_output["offset_mapping"]
|
||||||
|
|
||||||
|
# Adjust token offsets as done in the implementation
|
||||||
|
for i in range(len(token_offsets) - 1):
|
||||||
|
token_offsets[i] = (
|
||||||
|
token_offsets[i][0],
|
||||||
|
token_offsets[i + 1][0] - 1,
|
||||||
|
)
|
||||||
|
token_offsets[-1] = (token_offsets[-1][0], len(turn["value"]) - 1)
|
||||||
|
|
||||||
|
# Adjust train_details
|
||||||
|
adjusted_train_details = strategy.prompter.adjust_train_details(
|
||||||
|
turn["train_detail"], token_offsets
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(f"Original train_details: {turn['train_detail']}")
|
||||||
|
LOG.debug(f"Adjusted train_details: {adjusted_train_details}")
|
||||||
|
|
||||||
|
# Handle train_detail
|
||||||
|
token_offsets = strategy.prompter.get_offsets_for_train_detail(
|
||||||
|
text=turn["value"],
|
||||||
|
train_details=adjusted_train_details,
|
||||||
|
mask_untrainable=False,
|
||||||
|
)
|
||||||
|
token_offsets_masked = strategy.prompter.get_offsets_for_train_detail(
|
||||||
|
text=turn["value"],
|
||||||
|
train_details=adjusted_train_details,
|
||||||
|
mask_untrainable=True,
|
||||||
|
)
|
||||||
|
LOG.debug(f"Token offsets: {token_offsets_masked}")
|
||||||
|
|
||||||
|
expected_labels = [IGNORE_TOKEN_ID] * len(turn_tokens)
|
||||||
|
for i, offset in enumerate(token_offsets_masked):
|
||||||
|
if offset != IGNORE_TOKEN_ID:
|
||||||
|
expected_labels[i] = turn_tokens[i]
|
||||||
|
actual_labels = labels[
|
||||||
|
start_idx : start_idx + len(token_offsets_masked)
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
actual_labels == expected_labels
|
||||||
|
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
|
||||||
|
|
||||||
|
for detail in adjusted_train_details:
|
||||||
|
# Find the token indices that correspond to the character offsets
|
||||||
|
detail_start = start_idx + next(
|
||||||
|
i
|
||||||
|
for i, offset in enumerate(token_offsets)
|
||||||
|
if offset >= detail["begin_offset"]
|
||||||
|
)
|
||||||
|
detail_end = start_idx + next(
|
||||||
|
(
|
||||||
|
i
|
||||||
|
for i, offset in enumerate(token_offsets)
|
||||||
|
if offset > detail["end_offset"]
|
||||||
|
),
|
||||||
|
len(token_offsets),
|
||||||
|
)
|
||||||
|
|
||||||
|
detail_text = turn["value"][
|
||||||
|
detail["begin_offset"] : detail["end_offset"] + 1
|
||||||
|
]
|
||||||
|
detail_labels = labels[detail_start:detail_end]
|
||||||
|
detail_input_ids = input_ids[detail_start:detail_end]
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"Detail: '{detail_text}', Start: {detail_start}, End: {detail_end}"
|
||||||
|
)
|
||||||
|
LOG.debug(f"Detail input_ids: {detail_input_ids}")
|
||||||
|
LOG.debug(f"Detail labels: {detail_labels}")
|
||||||
|
LOG.debug(
|
||||||
|
f"Decoded detail: {llama3_tokenizer.decode(detail_input_ids)}"
|
||||||
|
)
|
||||||
|
LOG.debug(
|
||||||
|
f"Token offsets for this detail: {token_offsets[detail_start-start_idx:detail_end-start_idx]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if detail["train"]:
|
||||||
|
assert all(
|
||||||
|
label != IGNORE_TOKEN_ID for label in detail_labels
|
||||||
|
), (
|
||||||
|
f"Expected labels for trainable detail '{detail_text}' to be set, but some were IGNORE_TOKEN_ID. "
|
||||||
|
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
||||||
|
f"InputIDs: {detail_input_ids}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert all(
|
||||||
|
label == IGNORE_TOKEN_ID for label in detail_labels
|
||||||
|
), (
|
||||||
|
f"Expected all labels for non-trainable detail '{detail_text}' to be IGNORE_TOKEN_ID, but some were not. "
|
||||||
|
f"Labels({detail_start}:{detail_end}): {detail_labels}, "
|
||||||
|
f"InputIDs: {detail_input_ids}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(detail_input_ids)}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
should_train = turn.get("train", False)
|
||||||
|
turn_labels = labels[start_idx:end_idx]
|
||||||
|
|
||||||
|
LOG.debug(f"Should train: {should_train}")
|
||||||
|
LOG.debug(f"Turn indices: start={start_idx}, end={end_idx}")
|
||||||
|
LOG.debug(f"Turn labels: {turn_labels}")
|
||||||
|
LOG.debug(f"Turn input IDs: {input_ids[start_idx:end_idx]}")
|
||||||
|
LOG.debug(
|
||||||
|
f"Decoded turn: {llama3_tokenizer.decode(input_ids[start_idx:end_idx])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_train:
|
||||||
|
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
||||||
|
f"Expected all labels for '{turn['value']}' to be set\n"
|
||||||
|
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
||||||
|
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert all(label == IGNORE_TOKEN_ID for label in turn_labels), (
|
||||||
|
f"Expected all labels for '{turn['value']}' to be IGNORE_TOKEN_ID\n"
|
||||||
|
f"Labels({start_idx}:{end_idx}): {turn_labels}, "
|
||||||
|
f"InputIDs: {input_ids[start_idx:end_idx]}, "
|
||||||
|
f"Decoded: '{llama3_tokenizer.decode(input_ids[start_idx:end_idx])}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"Processed turn: {turn['from']}, content: '{turn['value']}', "
|
||||||
|
f"start_idx: {start_idx}, end_idx: {end_idx}, "
|
||||||
|
f"labels: {labels[start_idx:end_idx]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug(f"Final labels: {labels}")
|
||||||
|
LOG.debug(f"Final input_ids: {input_ids}")
|
||||||
|
|
||||||
|
|
||||||
class TestAssistantChatTemplateLlama3:
|
class TestAssistantChatTemplateLlama3:
|
||||||
"""
|
"""
|
||||||
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
|
def test_llama3_load(self, llama3_tokenizer, assistant_dataset):
|
||||||
# pylint: disable=duplicate-code
|
LOG.info("Loading llama-3 tokenizer with assistant dataset")
|
||||||
strategy = load(
|
strategy = load(
|
||||||
llama3_tokenizer,
|
llama3_tokenizer,
|
||||||
DictDefault(
|
DictDefault(
|
||||||
@@ -115,21 +705,26 @@ class TestAssistantChatTemplateLlama3:
        res = strategy.tokenize_prompt(assistant_dataset[0])
        input_ids = res["input_ids"]
        # fmt: off
-        assert input_ids == [
+        expected_input_ids = [
            128000, # bos
            128006, 882, 128007, # user header
            271, 15339, 128009, # user prompt eot
            128006, 78191, 128007, # assistant header
            271, 15339, 128009, # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        # fmt: on
        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"

    def test_llama3(self, llama3_tokenizer, assistant_dataset):
-        # pylint: disable=duplicate-code
+        LOG.info("Testing llama-3 with assistant dataset")
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
@@ -142,15 +737,16 @@ class TestAssistantChatTemplateLlama3:
                    "system": ["system"],
                },
            ),
-            llama3_tokenizer,
+            tokenizer=llama3_tokenizer,
-            False,
+            train_on_inputs=False,
-            512,
+            sequence_len=512,
            roles_to_train=["assistant"],
        )
        strategy.messages = "messages"
        res = strategy.tokenize_prompt(assistant_dataset[0])
        input_ids = res["input_ids"]
        # fmt: off
-        assert input_ids == [
+        expected_input_ids = [
            128000, # bos
            128006, 882, 128007, # user header
            271, 15339, 128009, # user prompt eot
@@ -162,6 +758,64 @@ class TestAssistantChatTemplateLlama3:
            271, 19045, 29474, 128009,
        ]
        # fmt: on
        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"

    def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
        LOG.info("Testing llama-3 with assistant dataset including training data")
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(
                llama3_tokenizer,
                chat_templates("llama3"),
                message_field_role="role",
                message_field_content="content",
                message_field_training="training",
                roles={
                    "user": ["user"],
                    "assistant": ["assistant"],
                    "system": ["system"],
                },
            ),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            train_on_eos="none",
            sequence_len=512,
            roles_to_train=["assistant"],
        )
        strategy.messages = "messages"
        prompt_tokens = strategy.prompter.build_prompt(
            assistant_dataset[0]["messages"], False
        )
        prompt = llama3_tokenizer.decode(prompt_tokens, skip_special_tokens=False)
        LOG.debug(f"Generated prompt: {prompt}")
        res = strategy.tokenize_prompt(assistant_dataset[0])
        labels = res["labels"]
        input_ids = res["input_ids"]
        # fmt: off
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
        ]
        # fmt: on

        LOG.debug(f"Expected labels: {expected_labels}")
        LOG.debug(f"Actual labels: {labels}")
        assert labels == expected_labels, (
            f"Labels mismatch:\n"
            f"Expected: {expected_labels}\n"
            f"Actual: {labels}\n"
            f"Input IDs: {input_ids}\n"
        )


class TestSharegptChatTemplateLlama3:
@@ -169,30 +823,160 @@ class TestSharegptChatTemplateLlama3:
    Test class for ShareGPT style datasets with llama-3 prompts using the chat_template strategy.
    """

-    def test_llama3(self, llama3_tokenizer, sharegpt_dataset):
+    def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
-        # pylint: disable=duplicate-code
+        LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
-            llama3_tokenizer,
+            tokenizer=llama3_tokenizer,
-            False,
+            train_on_inputs=False,
-            512,
+            train_on_eos="none",
            sequence_len=512,
            roles_to_train=["gpt"],
        )
        res = strategy.tokenize_prompt(sharegpt_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
        # fmt: off
-        assert input_ids == [
+        expected_input_ids = [
            128000, # bos
            128006, 882, 128007, # user header
            271, 15339, 128009, # user prompt eot
            128006, 78191, 128007, # assistant header
            271, 15339, 128009, # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # assistant response eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
        ]
        # fmt: on

        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        LOG.debug(f"Expected labels: {expected_labels}")
        LOG.debug(f"Actual labels: {labels}")

        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
        assert (
            labels == expected_labels
        ), f"Labels mismatch: {labels} != {expected_labels}"

    def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
        LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            train_on_eos="none",
            sequence_len=512,
            roles_to_train=["human"],
        )
        res = strategy.tokenize_prompt(sharegpt_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
        # fmt: off
        expected_input_ids = [
            128000, # bos
            128006, 882, 128007, # user header
            271, 15339, 128009, # user prompt eot
            128006, 78191, 128007, # assistant header
            271, 15339, 128009, # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, 15339, IGNORE_TOKEN_ID, # user prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, 19045, 29474, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
        ]
        # fmt: on

        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        LOG.debug(f"Expected labels: {expected_labels}")
        LOG.debug(f"Actual labels: {labels}")

        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
        assert (
            labels == expected_labels
        ), f"Labels mismatch: {labels} != {expected_labels}"

    def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
        LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
        strategy = ChatTemplateStrategy(
            ChatTemplatePrompter(llama3_tokenizer, chat_templates("llama3")),
            tokenizer=llama3_tokenizer,
            train_on_inputs=False,
            train_on_eos="none",
            sequence_len=512,
            roles_to_train=["system", "human"],
        )
        res = strategy.tokenize_prompt(basic_dataset[0])
        input_ids = res["input_ids"]
        labels = res["labels"]
        # fmt: off
        expected_input_ids = [
            128000, # bos
            128006, 9125, 128007,
            271, 2675, 527, 459, 15592, 18328, 13, 128009,
            128006, 882, 128007, # user header
            271, 9906, 128009, # user prompt eot
            128006, 78191, 128007, # assistant header
            271, 13347, 1070, 0, 128009, # assistant response eot
            128006, 882, 128007,
            271, 4438, 527, 499, 30, 128009,
            128006, 78191, 128007,
            271, 40, 2846, 3815, 1664, 11, 9901, 499, 0, 128009,
        ]
        expected_labels = [
            IGNORE_TOKEN_ID, # bos
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # system header
            IGNORE_TOKEN_ID, 2675, 527, 459, 15592, 18328, 13, IGNORE_TOKEN_ID, # system prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # user header
            IGNORE_TOKEN_ID, 9906, IGNORE_TOKEN_ID, # user prompt eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant header
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, # assistant response eot
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, 4438, 527, 499, 30, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
            IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID, IGNORE_TOKEN_ID,
        ]
        # fmt: on

        LOG.debug(f"Expected input_ids: {expected_input_ids}")
        LOG.debug(f"Actual input_ids: {input_ids}")
        LOG.debug(f"Expected labels: {expected_labels}")
        LOG.debug(f"Actual labels: {labels}")

        assert (
            input_ids == expected_input_ids
        ), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
        assert (
            labels == expected_labels
        ), f"Labels mismatch: {labels} != {expected_labels}"


if __name__ == "__main__":
    unittest.main()
156
tests/prompt_strategies/test_dpo_chat_templates.py
Normal file
@@ -0,0 +1,156 @@
"""
tests for chat_template prompt strategy
"""

import unittest

import pytest
from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.prompt_strategies.dpo.chat_template import default
from axolotl.utils.dict import DictDefault


@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
    # pylint: disable=duplicate-code
    return Dataset.from_list(
        [
            {
                "messages": [
                    {
                        "role": "user",
                        "content": "hello",
                    },
                    {
                        "role": "assistant",
                        "content": "hello",
                    },
                    {
                        "role": "user",
                        "content": "goodbye",
                    },
                ],
                "chosen": {
                    "role": "assistant",
                    "content": "goodbye",
                },
                "rejected": {
                    "role": "assistant",
                    "content": "party on",
                },
            }
        ]
    )


@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
    # pylint: disable=duplicate-code
    return Dataset.from_list(
        [
            {
                "conversation": [
                    {
                        "speaker": "human",
                        "text": "hello",
                    },
                    {
                        "speaker": "agent",
                        "text": "hello",
                    },
                    {
                        "speaker": "human",
                        "text": "goodbye",
                    },
                ],
                "better": {
                    "speaker": "agent",
                    "text": "goodbye",
                },
                "worse": {
                    "speaker": "agent",
                    "text": "party on",
                },
            }
        ]
    )


@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
    tokenizer.eos_token = "<|eot_id|>"

    return tokenizer


class TestAssistantDPOChatTemplateLlama3:
    """
    Test class for assistant style datasets with llama-3 prompts using the chat_template strategy.
    """

    def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
        # pylint: disable=duplicate-code
        transform_fn = default(
            DictDefault(
                {
                    "chat_template": "llama3",
                    "datasets": [
                        {
                            "chat_template": "llama3",
                        }
                    ],
                }
            )
        )
        result = transform_fn(assistant_dataset[0], tokenizer=llama3_tokenizer)
        assert result["prompt"] == (
            "<|begin_of_text|>"
            + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
            + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
            + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
            + "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"

    def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
        # pylint: disable=duplicate-code
        transform_fn = default(
            DictDefault(
                {
                    "chat_template": "llama3",
                    "datasets": [
                        {
                            "chat_template": "llama3",
                            "field_messages": "conversation",
                            "field_chosen": "better",
                            "field_rejected": "worse",
                            "message_field_role": "speaker",
                            "message_field_content": "text",
                            "roles": {
                                "user": ["human"],
                                "assistant": ["agent"],
                                "system": ["sys"],
                            },
                        }
                    ],
                }
            )
        )
        result = transform_fn(custom_assistant_dataset[0], tokenizer=llama3_tokenizer)
        assert result["prompt"] == (
            "<|begin_of_text|>"
            + "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
            + "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
            + "<|start_header_id|>user<|end_header_id|>\n\ngoodbye<|eot_id|>"
            + "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        assert result["chosen"] == "goodbye<|eot_id|>"
        assert result["rejected"] == "party on<|eot_id|>"


if __name__ == "__main__":
    unittest.main()
@@ -192,6 +192,7 @@ class TestSharegptLlama3:
        input_ids = dataset_wrapper[0]["input_ids"]

        # fmt: off
+        # pylint: disable=duplicate-code
        assert input_ids == [
            128000, # bos
            128006, 9125, 128007, # system header
@@ -228,6 +229,7 @@ class TestSharegptLlama3:
        input_ids = dataset_wrapper[0]["input_ids"]

        # fmt: off
+        # pylint: disable=duplicate-code
        assert input_ids == [
            128000, # bos
            128006, 9125, 128007, # system header
@@ -35,7 +35,7 @@ class TestEncodePretraining(unittest.TestCase):
                "hello, hello",
            ]
        }
-        result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
+        result = encode_pretraining(self.tokenizer, self.max_tokens, examples)

        self.assertEqual(len(result["input_ids"]), 3)
@@ -24,7 +24,7 @@ class TestPretrainingPacking(unittest.TestCase):
    def test_packing_stream_dataset(self):
        # pylint: disable=duplicate-code
        dataset = load_dataset(
-            "c4",
+            "allenai/c4",
            "en",
            streaming=True,
        )["train"]
@@ -33,7 +33,7 @@ class TestPretrainingPacking(unittest.TestCase):
            {
                "pretraining_dataset": [
                    {
-                        "path": "c4",
+                        "path": "allenai/c4",
                        "name": "en",
                        "type": "pretrain",
                    }
@@ -42,6 +42,19 @@ class AlpacaPrompterTest(unittest.TestCase):
        assert "USER:" not in res
        assert "ASSISTANT:" not in res

    def test_prompt_style_w_phi(self):
        prompter = AlpacaPrompter(prompt_style=PromptStyle.PHI.value)
        res = next(prompter.build_prompt("tell me a joke about the following"))
        assert (
            """<|system|>
Below is an instruction that describes a task. Write a response that appropriately completes the request.<|end|>
<|user|>
tell me a joke about the following<|end|>
<|assistant|>
"""
            == res
        )

    def test_prompt_style_w_chat(self):
        prompter = AlpacaPrompter(prompt_style=PromptStyle.CHAT.value)
        res = next(