Compare commits

54 Commits

| SHA1 |
|---|
| cb8bfab9cc |
| 0c8b1d824a |
| fd70eec577 |
| d42f202046 |
| 0dabde1962 |
| 15f1462ccd |
| 521e62daf1 |
| c16ec398d7 |
| 2f20cb7ebf |
| 71d4030b79 |
| f3a5d119af |
| ba219b51a5 |
| 5be8e13d35 |
| 2d7830fda6 |
| 5e98cdddac |
| 1d7aee0ad2 |
| 659ee5d723 |
| 342935cff3 |
| c5eb9ea2c2 |
| f2145a3ccb |
| 010d0e7ff3 |
| 01881c3113 |
| 0e8eb96e07 |
| 4e1891b12b |
| 28924fc791 |
| 8c480b2804 |
| a4b1cc6df0 |
| 7b78a31593 |
| 810ebc2c0e |
| ad435a3b09 |
| 9f1cf9b17c |
| 3931a42763 |
| dc8f9059f7 |
| 234e94e9dd |
| f68fb71005 |
| 9bc3ee6c75 |
| d356740ffa |
| e4af51eb66 |
| e20b15bee3 |
| d4796cb645 |
| fd3b80716a |
| 3265b7095e |
| 3cb2d75de1 |
| 035e9f9dd7 |
| 02ce520b7e |
| 052a9a79b4 |
| 3591bcfaf9 |
| dc1de7d81b |
| d4dbfa02fe |
| 5c7e89105d |
| 74db2a1bae |
| e62554c419 |
| 32c60765ef |
| 8c3a727f9d |
.github/workflows/base.yml (vendored, 14 changes)

```diff
@@ -27,7 +27,7 @@ jobs:
           - cuda: "124"
             cuda_version: 12.4.1
             cudnn_version: ""
-            python_version: "3.11"
+            python_version: "3.10"
             pytorch: 2.4.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           - cuda: "124"
@@ -40,23 +40,25 @@ jobs:
             cuda_version: 12.4.1
             cudnn_version: ""
             python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: Docker metadata
         id: metadata
-        uses: docker/metadata-action@v3
+        uses: docker/metadata-action@v5
         with:
-          images: winglian/axolotl-base
+          images: |
+            winglian/axolotl-base
+            axolotlai/axolotl-base
       - name: Login to Docker Hub
         uses: docker/login-action@v2
         with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
       - name: Build
         uses: docker/build-push-action@v4
         with:
```

.github/workflows/docs.yml (vendored, 2 changes)

```diff
@@ -17,7 +17,7 @@ jobs:
       - name: Set up Quarto
         uses: quarto-dev/quarto-actions/setup@v2
       - name: Setup Python
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: install dependencies
```

.github/workflows/lint.yml (vendored, 6 changes)

```diff
@@ -15,9 +15,9 @@ jobs:
     name: pre-commit
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
          cache: 'pip' # caching pip dependencies
-      - uses: pre-commit/action@v3.0.0
+      - uses: pre-commit/action@v3.0.1
```

.github/workflows/main.yml (vendored, 37 changes)

```diff
@@ -4,6 +4,8 @@ on:
   push:
     branches:
       - "main"
+    tags:
+      - "v*"
   workflow_dispatch:

 jobs:
@@ -32,7 +34,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
             axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
@@ -42,7 +44,12 @@ jobs:
         id: metadata
         uses: docker/metadata-action@v5
         with:
-          images: winglian/axolotl
+          images: |
+            winglian/axolotl
+            axolotlai/axolotl
+          tags: |
+            type=ref,event=branch
+            type=semver,pattern={{version}}
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to Docker Hub
@@ -56,7 +63,7 @@ jobs:
        with:
          context: .
          build-args: |
-            BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
+            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
            CUDA=${{ matrix.cuda }}
            PYTORCH_VERSION=${{ matrix.pytorch }}
            AXOLOTL_ARGS=${{ matrix.axolotl_args }}
@@ -94,7 +101,7 @@ jobs:
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
-           pytorch: 2.5.0
+           pytorch: 2.5.1
            axolotl_extras:
    runs-on: axolotl-gpu-runner
    steps:
@@ -104,20 +111,25 @@ jobs:
        id: metadata
        uses: docker/metadata-action@v5
        with:
-          images: winglian/axolotl-cloud
+          images: |
+            winglian/axolotl-cloud
+            axolotlai/axolotl-cloud
+          tags: |
+            type=ref,event=branch
+            type=semver,pattern={{version}}
      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
      - name: Build
        uses: docker/build-push-action@v5
        with:
          context: .
          build-args: |
-            BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
            CUDA=${{ matrix.cuda }}
          file: ./docker/Dockerfile-cloud
          push: ${{ github.event_name != 'pull_request' }}
@@ -146,20 +158,25 @@ jobs:
        id: metadata
        uses: docker/metadata-action@v5
        with:
-          images: winglian/axolotl-cloud-term
+          images: |
+            winglian/axolotl-cloud-term
+            axolotlai/axolotl-cloud-term
+          tags: |
+            type=ref,event=branch
+            type=semver,pattern={{version}}
      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
      - name: Build
        uses: docker/build-push-action@v5
        with:
          context: .
          build-args: |
-            BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
            CUDA=${{ matrix.cuda }}
          file: ./docker/Dockerfile-cloud-no-tmux
          push: ${{ github.event_name != 'pull_request' }}
```

.github/workflows/multi-gpu-e2e.yml (vendored, 7 changes)

```diff
@@ -8,6 +8,11 @@ on:
   schedule:
     - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

+# Cancel jobs on the same ref if a new one is triggered
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
 jobs:
   test-axolotl-multigpu:
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -31,7 +36,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
             axolotl_extras:
             num_gpus: 2
             nightly_build: "true"
```

.github/workflows/nightlies.yml (vendored, 14 changes)

```diff
@@ -31,7 +31,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
             axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
@@ -41,7 +41,9 @@ jobs:
         id: metadata
         uses: docker/metadata-action@v5
         with:
-          images: winglian/axolotl
+          images: |
+            winglian/axolotl
+            axolotlai/axolotl
           tags: |
             type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
       - name: Set up Docker Buildx
@@ -93,7 +95,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
             axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
@@ -103,7 +105,9 @@ jobs:
         id: metadata
         uses: docker/metadata-action@v5
         with:
-          images: winglian/axolotl-cloud
+          images: |
+            winglian/axolotl-cloud
+            axolotlai/axolotl-cloud
           tags: |
             type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
       - name: Login to Docker Hub
@@ -112,7 +116,7 @@ jobs:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
       - name: Build
         uses: docker/build-push-action@v5
         with:
```

.github/workflows/pypi.yml (vendored, 25 changes)

```diff
@@ -3,12 +3,31 @@ name: publish pypi
 on:
   push:
     tags:
-      - '*'
+      - 'v*'
   workflow_dispatch:

 jobs:
+  setup_release:
+    name: Create Release
+    runs-on: ubuntu-latest
+    steps:
+      - name: Get the tag version
+        id: extract_branch
+        run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
+        shell: bash
+
+      - name: Create Release
+        id: create_release
+        uses: actions/create-release@v1
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          tag_name: ${{ steps.extract_branch.outputs.branch }}
+          release_name: ${{ steps.extract_branch.outputs.branch }}
   pypi-publish:
     name: Upload release to PyPI
     runs-on: ubuntu-latest
+    needs: [setup_release]
     environment:
       name: pypi
       url: https://pypi.org/p/axolotl
@@ -16,10 +35,10 @@ jobs:
       id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
     steps:
       - name: Check out repository code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

       - name: Setup Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: "3.10"
```

.github/workflows/tests-nightly.yml (vendored, 22 changes)

```diff
@@ -9,12 +9,12 @@ jobs:
     name: pre-commit
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
          cache: 'pip' # caching pip dependencies
-      - uses: pre-commit/action@v3.0.0
+      - uses: pre-commit/action@v3.0.1
        env:
          SKIP: no-commit-to-branch

@@ -25,15 +25,15 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
+        pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
     timeout-minutes: 20

     steps:
       - name: Check out repository code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

       - name: Setup Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python_version }}
          cache: 'pip' # caching pip dependencies
@@ -48,6 +48,7 @@ jobs:
          sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
          sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
          sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
+         sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt

       - name: Install dependencies
         run: |
@@ -82,13 +83,6 @@ jobs:
             num_gpus: 1
             axolotl_extras: mamba-ssm
             nightly_build: "true"
-          - cuda: 121
-            cuda_version: 12.1.1
-            python_version: "3.11"
-            pytorch: 2.3.1
-            num_gpus: 1
-            axolotl_extras: mamba-ssm
-            nightly_build: "true"
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
@@ -99,7 +93,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
             num_gpus: 1
             axolotl_extras:
             nightly_build: "true"
```

.github/workflows/tests.yml (vendored, 83 changes)

```diff
@@ -15,17 +15,22 @@ on:
       - '.github/workflows/*.yml'
   workflow_dispatch:

+# Cancel jobs on the same ref if a new one is triggered
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
 jobs:
   pre-commit:
     name: pre-commit
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
          cache: 'pip' # caching pip dependencies
-      - uses: pre-commit/action@v3.0.0
+      - uses: pre-commit/action@v3.0.1
        env:
          SKIP: no-commit-to-branch

@@ -36,15 +41,15 @@ jobs:
       fail-fast: false
       matrix:
         python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
+        pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
     timeout-minutes: 20

     steps:
       - name: Check out repository code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

       - name: Setup Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python_version }}
          cache: 'pip' # caching pip dependencies
@@ -72,7 +77,7 @@ jobs:
        run: |
          find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;

-  docker-e2e-tests:
+  docker-e2e-tests-1st:
     if: github.repository_owner == 'axolotl-ai-cloud'
     # this job needs to be run on self-hosted GPU runners...
     runs-on: [self-hosted, modal]
@@ -83,30 +88,58 @@ jobs:
       fail-fast: false
       matrix:
         include:
           - cuda: 121
             cuda_version: 12.1.1
             python_version: "3.10"
             pytorch: 2.3.1
             num_gpus: 1
             axolotl_extras: mamba-ssm
           - cuda: 121
             cuda_version: 12.1.1
             python_version: "3.11"
             pytorch: 2.3.1
             num_gpus: 1
             axolotl_extras: mamba-ssm
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
             pytorch: 2.4.1
             num_gpus: 1
             axolotl_extras:
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
             pytorch: 2.5.0
             num_gpus: 1
             axolotl_extras:
     steps:
       - name: Checkout
         uses: actions/checkout@v4
       - name: Install Python
         uses: actions/setup-python@v5
         with:
           python-version: "3.10"
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
           pip install modal==0.63.64 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
           echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
           echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
           echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
           echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
           echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
       - name: Run tests job on Modal
         run: |
           modal run cicd.tests
+
+  docker-e2e-tests:
+    if: github.repository_owner == 'axolotl-ai-cloud'
+    # this job needs to be run on self-hosted GPU runners...
+    runs-on: [self-hosted, modal]
+    timeout-minutes: 90
+    needs: [pre-commit, pytest, docker-e2e-tests-1st]
+
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: 121
+            cuda_version: 12.1.1
+            python_version: "3.10"
+            pytorch: 2.3.1
+            num_gpus: 1
+            axolotl_extras: mamba-ssm
+          - cuda: 124
+            cuda_version: 12.4.1
+            python_version: "3.11"
+            pytorch: 2.5.1
+            num_gpus: 1
+            axolotl_extras:
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
```

1991.yml (deleted file, 295 lines)

```yaml
base_model: Qwen/Qwen2.5-14B-Instruct
model_type: AutoModelForCausalLM #nohup accelerate launch -m axolotl.cli.train /home/ubuntu/qwen2.5_14B.yml > training_output.log 2>&1 &
tokenizer_type: AutoTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca

chat_template: chatml
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/out

sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

unfrozen_parameters:
  - ^lm_head.weight$
  - ^model.embed_tokens.weight$
  # input_layernorm layers
  # lm_head layers
  # mlp.down_proj layers
  # mlp.gate_proj layers
  # mlp.up_proj layers
  # model.norm layers
  # post_attention_layernorm layers
  # self_attn.k_proj layers
  # self_attn.o_proj layers
  # self_attn.q_proj layers
  # self_attn.v_proj layers
  # model.embed_tokens layers
  # (the deleted file enumerated roughly 220 individual model.layers.<n>.<module> entries under the section comments above)

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: linear
learning_rate: 5e-6

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true

gradient_checkpointing: unsloth
gradient_checkpointing_kwargs:
  use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
save_total_limit: 4
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.05
special_tokens:
  eos_token: <|im_end|>
```

README.md (18 changes)

````diff
@@ -159,7 +159,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
 #### Docker

 ```bash
-docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
+docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
 ```

 Or run on the current files for development:
@@ -178,7 +178,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
 A more powerful Docker command to run would be this:

 ```bash
-docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
+docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
 ```

 It additionally:
@@ -210,7 +210,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --

 #### Cloud GPU

-For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
+For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)

 - on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
 - on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
@@ -319,7 +319,7 @@ Write a job description in YAML as below:
 # dstack.yaml
 type: task

-image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
+image: axolotlai/axolotl-cloud:main-latest

 env:
   - HUGGING_FACE_HUB_TOKEN
@@ -383,11 +383,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
   - typescript
     type: ... # unimplemented custom format

-  # fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
-  # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+  # chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
   - path: ...
-    type: sharegpt
-    conversation: chatml # default: vicuna_v1.1
+    type: chat_template
+    chat_template: chatml # defaults to tokenizer's chat_template

   # local
   - path: data.jsonl # or json
@@ -562,7 +561,8 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
+liger_layer_norm: true
 liger_fused_linear_cross_entropy: true
 ```
````

```diff
@@ -1,4 +1,4 @@
-FROM winglian/axolotl-base:{{ BASE_TAG }}
+FROM axolotlai/axolotl-base:{{ BASE_TAG }}

 ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
@@ -28,6 +28,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
         sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
         sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
         sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
+        sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
     fi

 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
```

```diff
@@ -10,7 +10,7 @@ import tempfile
 import jinja2
 import modal
 from jinja2 import select_autoescape
-from modal import Image, Stub
+from modal import App, Image

 cicd_path = pathlib.Path(__file__).parent.resolve()

@@ -46,7 +46,7 @@ cicd_image = (
     .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
 )

-stub = Stub("Axolotl CI/CD", secrets=[])
+app = App("Axolotl CI/CD", secrets=[])


 N_GPUS = int(os.environ.get("N_GPUS", 2))
@@ -61,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
         exit(exit_code)  # pylint: disable=consider-using-sys-exit


-@stub.function(
+@app.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
     timeout=60 * 60,
@@ -72,6 +72,6 @@ def cicd_pytest():
     run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")


-@stub.local_entrypoint()
+@app.local_entrypoint()
 def main():
     cicd_pytest.remote()
```

```diff
@@ -2,4 +2,4 @@
 set -e

 # only run one test at a time so as not to OOM the GPU
-pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
+pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
```

```diff
@@ -10,7 +10,7 @@ import tempfile
 import jinja2
 import modal
 from jinja2 import select_autoescape
-from modal import Image, Stub
+from modal import App, Image

 cicd_path = pathlib.Path(__file__).parent.resolve()

@@ -47,7 +47,7 @@ cicd_image = (
     .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
 )

-stub = Stub("Axolotl CI/CD", secrets=[])
+app = App("Axolotl CI/CD", secrets=[])


 N_GPUS = int(os.environ.get("N_GPUS", 1))
@@ -62,7 +62,7 @@ def run_cmd(cmd: str, run_folder: str):
         exit(exit_code)  # pylint: disable=consider-using-sys-exit


-@stub.function(
+@app.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
     timeout=60 * 60,
@@ -73,6 +73,6 @@ def cicd_pytest():
     run_cmd("./cicd/cicd.sh", "/workspace/axolotl")


-@stub.local_entrypoint()
+@app.local_entrypoint()
 def main():
     cicd_pytest.remote()
```

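The two CI entrypoint scripts above migrate from Modal's retired `Stub` API to the current `App` API. As a minimal sketch of that shape, detached from this repository (the app name, image contents, and function body here are placeholders, not the project's actual CI setup):

```python
import modal

# modal.App replaces the old modal.Stub; the image defines the remote container.
app = modal.App("example-ci")
image = modal.Image.debian_slim().pip_install("pytest")


@app.function(image=image, timeout=60 * 60)
def run_tests():
    # This body executes remotely inside the Modal container.
    print("running tests...")


@app.local_entrypoint()
def main():
    # Invoked locally by `modal run <module>`; .remote() runs the function on Modal.
    run_tests.remote()
```
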
```diff
@@ -1,4 +1,4 @@
-# Example config for debugging the sharegpt prompt format
+# Example config for debugging the chat_template prompt format
 base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
@@ -7,8 +7,8 @@ load_in_8bit: true
 load_in_4bit: false

 datasets:
-  - path: philschmid/guanaco-sharegpt-style
-    type: sharegpt
+  - path: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
     shards: 10
 val_set_size: 0
 output_dir: temp_debug/axolotl_outputs/model
```

```diff
@@ -1,5 +1,5 @@
 ARG BASE_TAG=main-base
-FROM winglian/axolotl-base:$BASE_TAG
+FROM axolotlai/axolotl-base:$BASE_TAG

 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
@@ -1,5 +1,5 @@
 ARG BASE_TAG=main
-FROM winglian/axolotl:$BASE_TAG
+FROM axolotlai/axolotl:$BASE_TAG

 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
@@ -1,5 +1,5 @@
 ARG BASE_TAG=main
-FROM winglian/axolotl:$BASE_TAG
+FROM axolotlai/axolotl:$BASE_TAG

 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
@@ -1,5 +1,5 @@
 ARG BASE_TAG=main-base
-FROM winglian/axolotl-base:$BASE_TAG
+FROM axolotlai/axolotl-base:$BASE_TAG

 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
```

```diff
@@ -83,7 +83,7 @@ lora_on_cpu: true
 datasets:
   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
-    # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
+    # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
     type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
     ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
     data_files: # Optional[str] path to source data files
@@ -91,15 +91,7 @@ datasets:
     name: # Optional[str] name of dataset configuration to load
     train_on_split: train # Optional[str] name of dataset split to load from
     revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
-
-    # Optional[str] fastchat conversation type, only used with type: sharegpt
-    conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
-    field_human: # Optional[str]. Human key to use for conversation.
-    field_model: # Optional[str]. Assistant key to use for conversation.
-    # Add additional keys from your dataset as input or output roles
-    roles:
-      input: # Optional[List[str]]. These will be masked based on train_on_input
-      output: # Optional[List[str]].
     trust_remote_code: # Optional[bool] Trust remote code for untrusted source

   # Custom user instruction prompt
   - path: repo
@@ -183,6 +175,8 @@ test_datasets:

 # use RL training: 'dpo', 'ipo', 'kto'
 rl:
+# whether to perform weighting if doing DPO training. Boolean.
+dpo_use_weighting:

 # The name of the chat template to use for training, following values are supported:
 # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
@@ -412,6 +406,7 @@ lr_div_factor: # Learning rate div factor
 # - adamw_torch_fused
 # - adamw_torch_xla
 # - adamw_apex_fused
+# - adopt_adamw (only for torch version >= 2.5.1)
 # - adafactor
 # - adamw_anyprecision
 # - sgd
```

````diff
@@ -6,33 +6,8 @@ order: 3

 ## sharegpt

-UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
+IMPORTANT: ShareGPT is deprecated! Please see `chat_template` section below.

 conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)

 ```{.json filename="data.jsonl"}
 {"conversations": [{"from": "...", "value": "..."}]}
 ```

-Note: `type: sharegpt` opens special configs:
-- `conversation`: enables conversions to many Conversation types. Refer to the 'name' [here](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) for options.
-- `roles`: allows you to specify the roles for input and output. This is useful for datasets with custom roles such as `tool` etc to support masking.
-- `field_human`: specify the key to use instead of `human` in the conversation.
-- `field_model`: specify the key to use instead of `gpt` in the conversation.
-
-```yaml
-datasets:
-  path: ...
-  type: sharegpt
-
-  conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
-  field_human: # Optional[str]. Human key to use for conversation.
-  field_model: # Optional[str]. Assistant key to use for conversation.
-  # Add additional keys from your dataset as input or output roles
-  roles:
-    input: # Optional[List[str]]. These will be masked based on train_on_input
-    output: # Optional[List[str]].
-```

 ## pygmalion

@@ -40,38 +15,6 @@ datasets:
 {"conversations": [{"role": "...", "value": "..."}]}
 ```

-## sharegpt.load_role
-
-conversations where `role` is used instead of `from`
-
-```{.json filename="data.jsonl"}
-{"conversations": [{"role": "...", "value": "..."}]}
-```
-
-## sharegpt.load_guanaco
-
-conversations where `from` is `prompter` `assistant` instead of default sharegpt
-
-```{.json filename="data.jsonl"}
-{"conversations": [{"from": "...", "value": "..."}]}
-```
-
-## sharegpt.load_ultrachat
-
-conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
-
-```{.json filename="data.jsonl"}
-{"messages": [{"user": "...", "assistant": "..."}]}
-```
-
-## sharegpt_jokes
-
-creates a chat where bot is asked to tell a joke, then explain why the joke is funny
-
-```{.json filename="data.jsonl"}
-{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
-```

 ## chat_template
````

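Since the dataset-format docs above steer users from `sharegpt` to `chat_template`, a quick sketch of how one messages-style JSONL row could be produced may help. This is only an illustration of the record shape (the roles and contents are made up; nothing here is an Axolotl API):

```python
import json

# One training example in messages/chat_template style: a list of role/content turns.
example = {
    "messages": [
        {"role": "user", "content": "What is Axolotl used for?"},
        {"role": "assistant", "content": "Fine-tuning LLMs driven by YAML configs."},
    ]
}

# Write one JSON object per line (JSONL), the local-file format the docs use.
with open("data.jsonl", "w", encoding="utf-8") as f:
    f.write(json.dumps(example) + "\n")
```
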
````diff
@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.

 ### Background

-The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
+The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:

 ```yaml
 datasets:
-  - path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
-    type: sharegpt
+  - path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
+    type: chat_template
 ```

 >[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.

 The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.

-For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
+For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.

 ```jsonc
 // .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Debug axolotl prompt - sharegpt",
+            "name": "Debug axolotl prompt - chat_template",
             "type": "python",
             "module": "accelerate.commands.launch",
             "request": "launch",
             "args": [
-                "-m", "axolotl.cli.train", "dev_sharegpt.yml",
+                "-m", "axolotl.cli.train", "dev_chat_template.yml",
                 // The flags below simplify debugging by overriding the axolotl config
                 // with the debugging tips above. Modify as needed.
                 "--dataset_processes=1", // limits data preprocessing to one process
@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3

 ## Debugging With Docker

-Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
+Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.

 ### Setup

@@ -202,11 +202,11 @@ cd axolotl
 Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]

 ```bash
-docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
+docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
 ```

 >[!Tip]
-> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
+> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).

 You will now be in the container. Next, perform an editable install of Axolotl:

@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
 </div>
 <br>

-[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
+[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.

 [^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
````

```diff
@@ -44,7 +44,7 @@
    "outputs": [],
    "source": [
     "!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
-    "!pip install flash-attn==\"2.5.0\"\n",
+    "!pip install flash-attn==\"2.7.0.post2\"\n",
     "!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
    ]
   },
```

```diff
@@ -9,14 +9,17 @@ strict: false
 plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true

 chat_template: deepseek_v2
 datasets:
   - path: mlabonne/FineTome-100k
     type: chat_template
-    split: train
+    split: train[:20%]
+    field_messages: conversations
+    message_field_role: from
+    message_field_content: value

 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
```

```diff
@@ -11,8 +11,11 @@ chat_template: gemma
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template
     chat_template: gemma
     drop_system_message: true
+    field_messages: conversations
+    message_field_role: from
+    message_field_content: value

 val_set_size: 0.0
 output_dir: ./outputs/out
```

```diff
@@ -4,11 +4,15 @@ tokenizer_type: AutoTokenizer
 load_in_4bit: true
 strict: false
+use_tensorboard: true

 chat_template: jamba
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template
     chat_template: jamba
     drop_system_message: true
+    field_messages: conversations
+    message_field_role: from
+    message_field_content: value

 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: jamba-large-fsdp-qlora-ft
```

```diff
@@ -4,7 +4,7 @@ plugins:
   - axolotl.integrations.liger.LigerPlugin
 liger_rope: true
 liger_rms_norm: true
-liger_swiglu: true
+liger_glu_activation: true
 liger_fused_linear_cross_entropy: true

 strict: false
@@ -14,6 +14,10 @@ datasets:
   - path: mlabonne/FineTome-100k
     type: chat_template
     split: train[:20%]
+    field_messages: conversations
+    message_field_role: from
+    message_field_content: value

 dataset_prepared_path: last_run_prepared
 val_set_size: 0.02
 output_dir: ./outputs/out
```

examples/mistral/mistral-dpo-qlora.yml (new file, 93 lines)

```yaml
#Note that we are switching from the regular chat template to chatml.
#If you experience problems with the special tokens, training for more epochs can help.
#After training, merge the model before inference otherwise you might
#face problems with the special tokens.

base_model: mistralai/Mistral-7B-Instruct-v0.2
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

chat_template: chatml
rl: dpo
datasets:
  - path: olivermolenschot/alpaca_messages_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content

dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-qlora

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true

adapter: qlora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.2
lora_target_linear: true
lora_fan_in_fan_out:

lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
lora_modules_to_save:
  - embed_tokens
  - lm_head

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 16
num_epochs: 6
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<|im_start|>"
  eos_token: "<|im_end|>"
```

```diff
@@ -10,7 +10,6 @@ chat_template: phi_3
 datasets:
   - path: fozziethebeat/alpaca_messages_2k_test
     type: chat_template
-    chat_template: phi_3
     field_messages: messages
     message_field_role: role
     message_field_content: content
```

examples/qwen2/dpo.yaml (new file, 67 lines)

```yaml
base_model: Qwen/Qwen2.5-0.5B

strict: false

chat_template: qwen_25
rl: dpo
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant

dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/dpo-out

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
```

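The two new DPO example configs above read preference data through `chat_template.default` with `field_messages: conversation`, `field_chosen: chosen`, and `field_rejected: rejected`. As a rough sketch of what one such dataset row could look like, the exact record layout below is an assumption inferred from those field names, not taken from the referenced datasets:

```python
import json

# A single preference example: a shared conversation plus a chosen and a rejected reply.
row = {
    "conversation": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize what DPO training does."},
    ],
    "chosen": {"role": "assistant", "content": "It tunes the model to prefer the chosen reply over the rejected one."},
    "rejected": {"role": "assistant", "content": "It is a kind of tokenizer."},
}

# Append the record to a JSONL file, one JSON object per line.
with open("dpo_data.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(row) + "\n")
```
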
```diff
@@ -1,2 +1,3 @@
 pytest
 pytest-xdist
+pytest-retry
```

```diff
@@ -1,18 +1,18 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.46.0
+transformers==4.46.2
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
-accelerate==1.0.1
-datasets==3.0.1
+accelerate==1.1.0
+datasets==3.1.0
 deepspeed==0.15.3
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
-flash-attn==2.6.3
+flash-attn==2.7.0.post2
 sentencepiece
 wandb
 einops
@@ -28,13 +28,12 @@ scipy
 scikit-learn==1.4.2
 pynvml
 art
-fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
 gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
 autoawq>=0.2.5
 triton>=2.3.0
-liger-kernel==0.3.0
+liger-kernel==0.4.1

 mamba-ssm==1.2.0.post1

@@ -43,7 +42,7 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs

-trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
+trl==0.12.0
 zstandard==0.22.0
 fastcore

@@ -54,3 +53,4 @@ immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2

 torchao==0.5.0
+schedulefree==1.3.0
```

```diff
@@ -2,7 +2,7 @@

 # Export specific ENV variables to /etc/rp_environment
 echo "Exporting environment variables..."
-printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
+printenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\([^=]*\)=\(.*\)$/export \1="\2"/' | grep -v 'printenv' >> /etc/rp_environment
 echo 'source /etc/rp_environment' >> ~/.bashrc

 add_keys_to_authorized() {
```

16
setup.py
@@ -39,7 +39,10 @@ def parse_requirements():
    else:
        # detect the version of torch already installed
        # and set it so dependencies don't clobber the torch version
        torch_version = version("torch")
        try:
            torch_version = version("torch")
        except PackageNotFoundError:
            torch_version = "2.5.1"
        _install_requires.append(f"torch=={torch_version}")

        version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -54,6 +57,10 @@ def parse_requirements():

        if (major, minor) >= (2, 5):
            _install_requires.pop(_install_requires.index(xformers_version))
            if patch == 0:
                _install_requires.append("xformers==0.0.28.post2")
            else:
                _install_requires.append("xformers==0.0.28.post3")
            _install_requires.pop(_install_requires.index(autoawq_version))
        elif (major, minor) >= (2, 4):
            if patch == 0:
@@ -89,7 +96,7 @@ install_requires, dependency_links = parse_requirements()

setup(
    name="axolotl",
    version="0.4.1",
    version="0.5.0",
    description="LLM Trainer",
    long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
    package_dir={"": "src"},
@@ -98,10 +105,7 @@ setup(
    dependency_links=dependency_links,
    extras_require={
        "flash-attn": [
            "flash-attn==2.6.3",
        ],
        "fused-dense-lib": [
            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
            "flash-attn==2.7.0.post2",
        ],
        "deepspeed": [
            "deepspeed==0.14.4",
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
chat_template_str = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
@@ -211,13 +208,31 @@ def do_inference(
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
if chat_template_str:
|
||||
batch = tokenizer.apply_chat_template(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
return_tensors="pt",
|
||||
add_special_tokens=True,
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
else:
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
@@ -257,13 +272,6 @@ def do_inference_gradio(
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
default_tokens: Dict[str, str] = {}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
chat_template_str = None
|
||||
@@ -272,7 +280,7 @@ def do_inference_gradio(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
|
||||
@@ -23,10 +23,6 @@ from axolotl.cli import (
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
@@ -44,23 +40,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
if parsed_cfg.chat_template == "chatml":
|
||||
if parsed_cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||
)
|
||||
register_chatml_template(parsed_cfg.default_system_message)
|
||||
else:
|
||||
register_chatml_template()
|
||||
elif parsed_cfg.chat_template == "llama3":
|
||||
if parsed_cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||
)
|
||||
register_llama3_template(parsed_cfg.default_system_message)
|
||||
else:
|
||||
register_llama3_template()
|
||||
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
|
||||
@@ -19,10 +19,6 @@ from axolotl.cli import (
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
@@ -42,21 +38,6 @@ def do_train(cfg, cli_args) -> None:
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
if cfg.chat_template == "chatml" and cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"ChatML set. Adding default system message: {cfg.default_system_message}"
|
||||
)
|
||||
register_chatml_template(cfg.default_system_message)
|
||||
else:
|
||||
register_chatml_template()
|
||||
|
||||
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
||||
)
|
||||
register_llama3_template(cfg.default_system_message)
|
||||
else:
|
||||
register_llama3_template()
|
||||
|
||||
if cfg.rl: # and cfg.rl != "orpo":
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -10,6 +10,7 @@ MOE_ARCH_BLOCK = {
|
||||
"JetMoeMoE",
|
||||
],
|
||||
"mixtral": "MixtralSparseMoeBlock",
|
||||
"phimoe": "PhiMoESparseMoeBlock",
|
||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||
"deepseek_v2": "DeepseekV2MoE",
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ from trl import (
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
@@ -435,7 +436,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.alternate_optimizer
|
||||
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
||||
not in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
@@ -504,6 +511,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "adopt_adamw":
|
||||
from axolotl.utils.optimizers.adopt import ADOPT
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
ADOPT(
|
||||
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
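The new `adopt_adamw` branch above constructs the ADOPT optimizer from `axolotl.utils.optimizers.adopt` with decoupled weight decay. A minimal sketch of the same instantiation outside the trainer, assuming that module path from this changeset; the model and hyperparameters are placeholders:

```python
# Sketch only: mirrors the ADOPT instantiation added above. The import path is
# taken from the diff; the model and hyperparameters are illustrative.
import torch.nn as nn
from axolotl.utils.optimizers.adopt import ADOPT

model = nn.Linear(16, 16)  # stand-in for the real model
optimizer = ADOPT(
    model.parameters(),    # the trainer passes grouped parameters here
    decoupled=True,        # decoupled weight decay, as in the new branch
    lr=2e-4,               # illustrative learning rate
)
```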
@@ -895,13 +910,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def _save_checkpoint(self, model, trial):
|
||||
def _save_checkpoint(self, model, trial, **kwargs):
|
||||
# make sure the checkpoint dir exists, since trainer is flakey
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
run_dir = self._get_output_dir(trial=trial)
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
@@ -1023,24 +1038,37 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
self,
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = super().tokenize_row(
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
|
||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
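Per the comments above, TRL's row tokenization can leave a stray leading BOS id (or, for tokenizers without a BOS such as Qwen, a leading `None`), so the override trims the first element when it matches `bos_token_id`. A tiny standalone sketch of that trim; the token ids are made up:

```python
# Sketch: trimming a stray leading BOS id, as the override above does.
# Token ids are illustrative.
bos_token_id = 151643
chosen_input_ids = [151643, 151643, 872, 198, 2610]  # BOS ended up duplicated

if chosen_input_ids[0] == bos_token_id:
    chosen_input_ids = chosen_input_ids[1:]

assert chosen_input_ids == [151643, 872, 198, 2610]
```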
def training_step(
|
||||
@@ -1147,6 +1175,12 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
def get_callbacks(self) -> List[TrainerCallback]:
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
@@ -1173,11 +1207,17 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
return callbacks
|
||||
|
||||
@abstractmethod
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
"""
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
# TODO
|
||||
@@ -1223,7 +1263,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "wandb"
|
||||
@@ -1260,6 +1300,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
callbacks.append(lisa_callback_factory(trainer))
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
[
|
||||
cb
|
||||
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||
self.cfg, trainer
|
||||
)
|
||||
if cb
|
||||
]
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
@@ -1377,17 +1429,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
training_arguments_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.evaluation_strategy:
|
||||
training_arguments_kwargs[
|
||||
"evaluation_strategy"
|
||||
] = self.cfg.evaluation_strategy
|
||||
elif self.cfg.eval_strategy:
|
||||
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
training_arguments_kwargs["eval_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
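The edits above switch the builder from the deprecated `evaluation_strategy` keyword to `eval_strategy`, matching the rename in recent transformers releases. A small sketch of the renamed argument as `TrainingArguments` consumes it; the values are illustrative:

```python
# Sketch: the renamed keyword as accepted by recent transformers versions.
# Only the rename is taken from the diff; the values are illustrative.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",   # illustrative
    eval_strategy="steps",    # formerly `evaluation_strategy`
    eval_steps=100,           # illustrative
)
```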
@@ -1595,7 +1645,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||
if self.cfg.chat_template:
|
||||
training_arguments_kwargs["chat_template"] = get_chat_template(
|
||||
self.cfg.chat_template
|
||||
self.cfg.chat_template,
|
||||
tokenizer=self.tokenizer,
|
||||
)
|
||||
|
||||
if self.cfg.rl == "orpo":
|
||||
@@ -1611,11 +1662,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.reward_model:
|
||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
if self.cfg.optimizer in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]:
|
||||
# Set default so transformers doesn't throw
|
||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||
@@ -1790,7 +1843,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build_training_arguments(self, total_num_steps):
|
||||
@@ -1818,10 +1871,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.eval_dataset:
|
||||
training_args_kwargs["evaluation_strategy"] = "steps"
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
else:
|
||||
training_args_kwargs["evaluation_strategy"] = "no"
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
|
||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||
training_args_kwargs["bf16"] = True
|
||||
@@ -1876,17 +1929,18 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if self.cfg.rl_beta:
|
||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
||||
if self.cfg.orpo_alpha:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rpo_alpha is not None:
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
training_args_cls = None
|
||||
if self.cfg.rl == "simpo":
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
@@ -1895,13 +1949,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
if self.cfg.rl == "orpo":
|
||||
elif self.cfg.rl == "orpo":
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
if self.cfg.rl == "kto":
|
||||
elif self.cfg.rl == "kto":
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
|
||||
training_args_kwargs["desirable_weight"] = (
|
||||
@@ -1916,6 +1970,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
else:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rl == "ipo":
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
output_dir=self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
@@ -1936,7 +2001,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args = self.build_training_arguments(total_num_steps)
|
||||
dpo_trainer_kwargs = {}
|
||||
if self.cfg.rl == "ipo":
|
||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||
if self.cfg.dpo_label_smoothing:
|
||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
if self.eval_dataset:
|
||||
@@ -1950,12 +2014,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
|
||||
# these aren't used for the ORPO trainer
|
||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["max_target_length"] = None
|
||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
@@ -1999,11 +2057,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = []
|
||||
callbacks = super().get_callbacks()
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
|
||||
@@ -18,9 +18,10 @@ Plugins can be used to integrate third-party models, modify the training process
|
||||
|
||||
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
|
||||
"""
|
||||
import collections
|
||||
import importlib
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import OrderedDict
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
@@ -47,7 +48,7 @@ class BasePlugin:
|
||||
Initializes the BasePlugin.
|
||||
"""
|
||||
|
||||
def register(self, cfg):
|
||||
def register(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Registers the plugin with the given configuration.
|
||||
|
||||
@@ -63,7 +64,7 @@ class BasePlugin:
|
||||
Returns a pydantic model for the plugin's input arguments.
|
||||
"""
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before the model is loaded.
|
||||
|
||||
@@ -74,7 +75,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_model_load(self, cfg, model):
|
||||
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after the model is loaded.
|
||||
|
||||
@@ -86,7 +87,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def pre_lora_load(self, cfg, model):
|
||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before LoRA weights are loaded.
|
||||
|
||||
@@ -98,7 +99,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_lora_load(self, cfg, model):
|
||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after LoRA weights are loaded.
|
||||
|
||||
@@ -110,7 +111,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates and returns an optimizer for training.
|
||||
|
||||
@@ -122,7 +123,9 @@ class BasePlugin:
|
||||
object: The created optimizer.
|
||||
"""
|
||||
|
||||
def create_lr_scheduler(self, cfg, trainer, optimizer):
|
||||
def create_lr_scheduler(
|
||||
self, cfg, trainer, optimizer
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates and returns a learning rate scheduler.
|
||||
|
||||
@@ -135,9 +138,9 @@ class BasePlugin:
|
||||
object: The created learning rate scheduler.
|
||||
"""
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg, model):
|
||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Adds callbacks to the trainer before training.
|
||||
setup callbacks before creating the trainer.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
@@ -146,20 +149,25 @@ class BasePlugin:
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
return []
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
def add_callbacks_post_trainer(
|
||||
self, cfg, trainer
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Adds callbacks to the trainer after training.
|
||||
Adds callbacks to the trainer after creating the trainer.
|
||||
This is useful for callbacks that require access to the model or trainer.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
List[callable]: A list of callback functions to be added
|
||||
"""
|
||||
return []
|
||||
|
||||
def post_train(self, cfg, model):
|
||||
def post_train(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete.
|
||||
|
||||
@@ -171,7 +179,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete and the model is unloaded.
|
||||
|
||||
@@ -227,7 +235,7 @@ class PluginManager:
|
||||
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
|
||||
"""
|
||||
|
||||
plugins: List[BasePlugin] = []
|
||||
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
|
||||
|
||||
_instance = None
|
||||
|
||||
@@ -237,7 +245,7 @@ class PluginManager:
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(PluginManager, cls).__new__(cls)
|
||||
cls._instance.plugins: List[BasePlugin] = []
|
||||
cls._instance.plugins = collections.OrderedDict()
|
||||
return cls._instance
|
||||
|
||||
@staticmethod
|
||||
@@ -265,7 +273,7 @@ class PluginManager:
|
||||
"""
|
||||
try:
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins.append(plugin)
|
||||
self.plugins[plugin_name] = plugin
|
||||
except ImportError:
|
||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||
|
||||
@@ -277,7 +285,7 @@ class PluginManager:
|
||||
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
|
||||
"""
|
||||
input_args = []
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
input_args_from_plugin = plugin.get_input_args()
|
||||
if input_args_from_plugin is not None:
|
||||
input_args.append(input_args_from_plugin)
|
||||
@@ -293,7 +301,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.pre_model_load(cfg)
|
||||
|
||||
def post_model_load(self, cfg, model):
|
||||
@@ -307,7 +315,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_model_load(cfg, model)
|
||||
|
||||
def pre_lora_load(self, cfg, model):
|
||||
@@ -321,7 +329,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.pre_lora_load(cfg, model)
|
||||
|
||||
def post_lora_load(self, cfg, model):
|
||||
@@ -335,7 +343,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_lora_load(cfg, model)
|
||||
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
@@ -349,7 +357,7 @@ class PluginManager:
|
||||
Returns:
|
||||
object: The created optimizer, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
optimizer = plugin.create_optimizer(cfg, trainer)
|
||||
if optimizer is not None:
|
||||
return optimizer
|
||||
@@ -367,7 +375,7 @@ class PluginManager:
|
||||
Returns:
|
||||
object: The created learning rate scheduler, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
|
||||
if scheduler is not None:
|
||||
return scheduler
|
||||
@@ -385,8 +393,10 @@ class PluginManager:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins:
|
||||
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
||||
for plugin in self.plugins.values():
|
||||
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
|
||||
if plugin_callbacks: # if the plugin returned a list of callbacks
|
||||
callbacks.extend(plugin_callbacks)
|
||||
return callbacks
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
@@ -401,8 +411,10 @@ class PluginManager:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins:
|
||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||
for plugin in self.plugins.values():
|
||||
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
|
||||
if plugin_callbacks:
|
||||
callbacks.extend(plugin_callbacks)
|
||||
return callbacks
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
@@ -416,5 +428,5 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins:
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train_unload(cfg)
|
||||
|
||||
21
src/axolotl/integrations/grokfast/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
13
src/axolotl/integrations/grokfast/README.md
Normal file
@@ -0,0 +1,13 @@
# Grokfast Optimizer

See https://github.com/ironjr/grokfast

### Usage

```yaml
plugins:
  - axolotl.integrations.grokfast.GrokfastPlugin

grokfast_alpha: 0.98
grokfast_lamb: 2.0
```
50
src/axolotl/integrations/grokfast/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Grokfast plugin for Axolotl
|
||||
"""
|
||||
import logging
|
||||
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
from ..base import BasePlugin
|
||||
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
||||
from .optimizer import gradfilter_ema
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.grokfast")
|
||||
|
||||
|
||||
class GrokfastCallbackHandler(TrainerCallback):
|
||||
"""
|
||||
Transformer trainer callbacks for Grokfast
|
||||
"""
|
||||
|
||||
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
|
||||
super().__init__(*args_, **kwargs)
|
||||
self.grads = None
|
||||
self.alpha = alpha
|
||||
self.lamb = lamb
|
||||
|
||||
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
|
||||
self.grads = None
|
||||
|
||||
def on_pre_optimizer_step(
|
||||
self, args_, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
model = kwargs.pop("model")
|
||||
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
|
||||
return control
|
||||
|
||||
|
||||
class GrokfastPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for Grokfast optimizer integration with Axolotl.
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.grokfast.GrokfastArgs"
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
LOG.info("Adding Grokfast callback to the trainer")
|
||||
callback = GrokfastCallbackHandler(
|
||||
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
|
||||
)
|
||||
return [callback]
|
||||
15
src/axolotl/integrations/grokfast/args.py
Normal file
@@ -0,0 +1,15 @@
"""
config args for grokfast plugin
"""
from typing import Optional

from pydantic import BaseModel


class GrokfastArgs(BaseModel):
    """
    Input args for Grokfast optimizer.
    """

    grokfast_alpha: Optional[float] = 0.98
    grokfast_lamb: Optional[float] = 2.0
63
src/axolotl/integrations/grokfast/optimizer.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||
# Reference: https://github.com/ironjr/grokfast
|
||||
|
||||
# pylint: skip-file
|
||||
from collections import deque
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def gradfilter_ma(
|
||||
m: nn.Module,
|
||||
grads: Optional[Dict[str, deque]] = None,
|
||||
window_size: int = 100,
|
||||
lamb: float = 5.0,
|
||||
filter_type: Literal["mean", "sum"] = "mean",
|
||||
warmup: bool = True,
|
||||
trigger: bool = False, # For ablation study.
|
||||
) -> Dict[str, deque]:
|
||||
if grads is None:
|
||||
grads = {
|
||||
n: deque(maxlen=window_size)
|
||||
for n, p in m.named_parameters()
|
||||
if p.requires_grad and p.grad is not None
|
||||
}
|
||||
|
||||
for n, p in m.named_parameters():
|
||||
if p.requires_grad and p.grad is not None:
|
||||
grads[n].append(p.grad.data.detach()) # .cpu())
|
||||
|
||||
# Modify the gradients.
|
||||
if not warmup or len(grads[n]) == window_size and not trigger:
|
||||
if filter_type == "mean":
|
||||
avg = sum(grads[n]) / len(grads[n])
|
||||
elif filter_type == "sum":
|
||||
avg = sum(grads[n])
|
||||
else:
|
||||
raise ValueError(f"Unrecognized filter_type {filter_type}")
|
||||
p.grad.data = p.grad.data + avg * lamb
|
||||
|
||||
return grads
|
||||
|
||||
|
||||
def gradfilter_ema(
|
||||
m: nn.Module,
|
||||
grads: Optional[Dict[str, torch.Tensor]] = None,
|
||||
alpha: float = 0.98,
|
||||
lamb: float = 2.0,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if grads is None:
|
||||
grads = {
|
||||
n: p.grad.data.detach()
|
||||
for n, p in m.named_parameters()
|
||||
if p.requires_grad and p.grad is not None
|
||||
}
|
||||
|
||||
for n, p in m.named_parameters():
|
||||
if p.requires_grad and p.grad is not None:
|
||||
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
|
||||
p.grad.data = p.grad.data + grads[n] * lamb
|
||||
|
||||
return grads
|
||||
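`gradfilter_ema` above keeps an exponential moving average of each parameter's gradient and adds it back scaled by `lamb` just before the optimizer step, which is exactly when the Grokfast callback invokes it. A minimal sketch of using it in a plain training loop, assuming the function as defined above (imported from `axolotl.integrations.grokfast.optimizer` in the repo); the model, data, and optimizer are placeholders:

```python
# Sketch only: calling gradfilter_ema by hand, outside the trainer callback.
# Everything except the gradfilter_ema call itself is illustrative.
import torch
import torch.nn as nn
from axolotl.integrations.grokfast.optimizer import gradfilter_ema

model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
grads = None  # EMA state carried across steps

for _ in range(10):
    optimizer.zero_grad()
    x = torch.randn(4, 8)
    loss = model(x).pow(2).mean()
    loss.backward()
    # amplify the slow (EMA) component of the gradients before stepping
    grads = gradfilter_ema(model, grads, alpha=0.98, lamb=2.0)
    optimizer.step()
```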
@@ -18,20 +18,24 @@ Module for the Plugin for LIGER integration with Axolotl.
|
||||
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
||||
It is designed to be performant, correct, and light-weight.
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.liger")
|
||||
|
||||
|
||||
class LigerPlugin(BasePlugin):
|
||||
"""
|
||||
@@ -42,59 +46,31 @@ class LigerPlugin(BasePlugin):
|
||||
return "axolotl.integrations.liger.LigerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
if cfg.model_config_type == "llama":
|
||||
from liger_kernel.transformers.model.llama import (
|
||||
lce_forward as llama_lce_forward,
|
||||
)
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_llama.LlamaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_llama.LlamaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
elif cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "mistral":
|
||||
from liger_kernel.transformers.model.mistral import (
|
||||
lce_forward as mistral_lce_forward,
|
||||
)
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mistral.MistralRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "gemma":
|
||||
from liger_kernel.transformers.model.gemma import (
|
||||
lce_forward as gemma_lce_forward,
|
||||
)
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_gemma.GemmaRMSNorm = partial(
|
||||
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||
kwargs = {}
|
||||
if "rope" in liger_fn_sig.parameters:
|
||||
kwargs["rope"] = cfg.liger_rope
|
||||
if "cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
||||
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs[
|
||||
"fused_linear_cross_entropy"
|
||||
] = cfg.liger_fused_linear_cross_entropy
|
||||
if "rms_norm" in liger_fn_sig.parameters:
|
||||
kwargs["rms_norm"] = cfg.liger_rms_norm
|
||||
if "layer_norm" in liger_fn_sig.parameters:
|
||||
kwargs["layer_norm"] = cfg.liger_layer_norm
|
||||
if "geglu" in liger_fn_sig.parameters:
|
||||
kwargs["geglu"] = cfg.liger_glu_activation
|
||||
elif "swiglu" in liger_fn_sig.parameters:
|
||||
kwargs["swiglu"] = cfg.liger_glu_activation
|
||||
with zero_only():
|
||||
LOG.info(
|
||||
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
|
||||
)
|
||||
if cfg.liger_swiglu:
|
||||
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
|
||||
|
||||
apply_liger_fn(**kwargs)
|
||||
elif cfg.model_config_type == "jamba":
|
||||
from transformers.models.jamba import modeling_jamba
|
||||
|
||||
@@ -104,30 +80,14 @@ class LigerPlugin(BasePlugin):
|
||||
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "qwen2":
|
||||
from liger_kernel.transformers.model.qwen2 import (
|
||||
lce_forward as qwen2_lce_forward,
|
||||
)
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "deepseek_v2":
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoModelForCausalLM
|
||||
@@ -146,44 +106,11 @@ class LigerPlugin(BasePlugin):
|
||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
if cfg.liger_glu_activation:
|
||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||
if cfg.liger_cross_entropy:
|
||||
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||
# nn.CrossEntropyLoss in the forward method.
|
||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "gemma2":
|
||||
from transformers.models.gemma2 import modeling_gemma2
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_gemma2.Gemma2RMSNorm = partial(
|
||||
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||
)
|
||||
if cfg.liger_swiglu:
|
||||
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
logging.warning(
|
||||
"Fused linear cross entropy is not supported for Gemma 2."
|
||||
)
|
||||
|
||||
elif cfg.model_config_type == "phi3":
|
||||
from liger_kernel.transformers.model.phi3 import (
|
||||
lce_forward as phi3_lce_forward,
|
||||
)
|
||||
from transformers.models.phi3 import modeling_phi3
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
||||
|
||||
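The rewritten `pre_model_load` above stops hand-patching each architecture and instead looks up the model type in liger-kernel's `MODEL_TYPE_TO_APPLY_LIGER_FN`, then builds kwargs for whatever parameters that function's signature actually exposes. A self-contained sketch of that signature-filtering idea with a stand-in function (the real entries are liger-kernel's `apply_liger_kernel_to_*` functions):

```python
# Sketch: filter config flags down to the parameters a patch function accepts,
# as the refactored pre_model_load does. The function below is a stand-in.
import inspect

def apply_liger_kernel_to_fake(rope=False, rms_norm=False, swiglu=False):
    return {"rope": rope, "rms_norm": rms_norm, "swiglu": swiglu}

cfg = {
    "liger_rope": True,
    "liger_rms_norm": True,
    "liger_glu_activation": True,
    "liger_layer_norm": None,  # ignored: not in the stand-in's signature
}

sig = inspect.signature(apply_liger_kernel_to_fake)
kwargs = {}
if "rope" in sig.parameters:
    kwargs["rope"] = cfg["liger_rope"]
if "rms_norm" in sig.parameters:
    kwargs["rms_norm"] = cfg["liger_rms_norm"]
if "swiglu" in sig.parameters:
    kwargs["swiglu"] = cfg["liger_glu_activation"]

print(apply_liger_kernel_to_fake(**kwargs))
# -> {'rope': True, 'rms_norm': True, 'swiglu': True}
```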
@@ -15,9 +15,12 @@
|
||||
"""
|
||||
Module for handling LIGER input arguments.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.liger.args")
|
||||
|
||||
|
||||
class LigerArgs(BaseModel):
|
||||
@@ -27,6 +30,24 @@ class LigerArgs(BaseModel):
|
||||
|
||||
liger_rope: Optional[bool] = None
|
||||
liger_rms_norm: Optional[bool] = None
|
||||
liger_layer_norm: Optional[bool] = None
|
||||
liger_swiglu: Optional[bool] = None
|
||||
liger_glu_activation: Optional[bool] = None
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_deprecated_swiglu(cls, data):
|
||||
if data.get("liger_swiglu") is not None:
|
||||
if data.get("liger_glu_activation") is not None:
|
||||
raise ValueError(
|
||||
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
|
||||
)
|
||||
|
||||
LOG.warning(
|
||||
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
|
||||
"Please use 'liger_glu_activation' instead."
|
||||
)
|
||||
data["liger_glu_activation"] = data.pop("liger_swiglu")
|
||||
return data
|
||||
|
||||
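The `model_validator` above migrates the deprecated `liger_swiglu` flag onto `liger_glu_activation`, warning on use and rejecting configs that set both spellings. A small sketch of that behaviour with the `LigerArgs` model from this hunk (imported from `axolotl.integrations.liger.args` in the repo); the flag values are illustrative:

```python
# Sketch: exercising the deprecation shim added above. Values are illustrative.
from axolotl.integrations.liger.args import LigerArgs

args = LigerArgs(liger_swiglu=True, liger_rms_norm=True)  # logs a deprecation warning
assert args.liger_glu_activation is True   # old flag was remapped
assert args.liger_swiglu is None           # and popped from the raw data

# Supplying both spellings fails validation:
# LigerArgs(liger_swiglu=True, liger_glu_activation=True)  # raises a validation error
```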
@@ -1,231 +0,0 @@
|
||||
"""
|
||||
monkeypatch to add a get_turns method
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from fastchat.conversation import SeparatorStyle
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
|
||||
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
ret = ""
|
||||
for role, msg in self.get_turns():
|
||||
ret += role + msg
|
||||
return ret
|
||||
|
||||
|
||||
def get_turns( # pylint: disable=too-many-return-statements
|
||||
self,
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
"""Get the prompt for generation."""
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message + seps[i % 2]
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ": ", "" # must be end with a space
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role, message + self.sep
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role, message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.RWKV:
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message.replace("\r\n", "\n").replace(
|
||||
"\n\n", "\n"
|
||||
) + "\n\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
|
||||
if self.system_message:
|
||||
if self.messages:
|
||||
# For llama, the system message is incorporated into the first human instruction
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
if (i % 2 == 0 and not self.system_message) or (
|
||||
i % 2 != 0 and self.system_message
|
||||
):
|
||||
role = "<s> " + role
|
||||
yield role + " ", message
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
|
||||
contains_sys_msg = False
|
||||
if self.system_message:
|
||||
contains_sys_msg = True
|
||||
if self.messages:
|
||||
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt = self.system_template.format(
|
||||
system_message=" " + self.system_message
|
||||
)
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message and i == 0 and not contains_sys_msg:
|
||||
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
|
||||
elif message:
|
||||
yield role + " ", message
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA3:
|
||||
if self.system_message:
|
||||
# For llama3, the system message is NOT incorporated into the first human instruction
|
||||
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
|
||||
else:
|
||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.GEMMA:
|
||||
if self.system_message:
|
||||
raise ValueError("Gemma chat template does not support system messages")
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<bos>" if i == 0 else ""
|
||||
message_str = message if message else ""
|
||||
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
round_add_n = 1 if self.name == "chatglm2" else 0
|
||||
if system_prompt:
|
||||
yield "", system_prompt + self.sep
|
||||
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||
|
||||
if message:
|
||||
yield f"{role}:", f"{message}{self.sep}"
|
||||
else:
|
||||
yield f"{role}:", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATML:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep + "\n"
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM3:
|
||||
if self.system_message:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", " " + message
|
||||
else:
|
||||
yield role
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATINTERN:
|
||||
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<s>" if i % 2 == 0 else ""
|
||||
if message:
|
||||
yield prefix + role + ":", message + seps[i % 2] + "\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.DOLLY:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
suffix = "\n\n" if i % 2 == 1 else ""
|
||||
yield role + ":\n", message + seps[i % 2] + suffix
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.PHOENIX:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", "<s>" + message + "</s>"
|
||||
else:
|
||||
yield role + ": " + "<s>", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ROBIN:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ":\n", message + self.sep
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.FALCON_CHAT:
|
||||
if self.system_message:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
|
||||
def add_get_turns_to_conversation():
|
||||
import fastchat.conversation
|
||||
|
||||
fastchat.conversation.Conversation.get_turns = get_turns
|
||||
fastchat.conversation.Conversation.get_prompt = get_prompt
|
||||
@@ -1,4 +1,5 @@
|
||||
"""multipack patching for v2 of sample packing"""
|
||||
|
||||
import importlib
|
||||
|
||||
import transformers
|
||||
@@ -19,6 +20,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"falcon",
|
||||
"phi",
|
||||
"phi3",
|
||||
"phimoe",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"gemmoe",
|
||||
@@ -27,74 +29,28 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
]
|
||||
|
||||
|
||||
# def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
|
||||
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
||||
if model_type == "gemmoe":
|
||||
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
||||
elif model_type == "deepseek_v2":
|
||||
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
||||
# elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
|
||||
if has_remote_code:
|
||||
patch_remote(model_name)
|
||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||
if not has_remote_code:
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
return
|
||||
|
||||
# retain for legacy
|
||||
if model_type == "mixtral":
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
elif model_type == "llama":
|
||||
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
|
||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "mistral":
|
||||
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
|
||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "qwen2":
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "qwen2_moe":
|
||||
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "falcon":
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "phi":
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemma":
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemma2":
|
||||
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "starcoder2":
|
||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
|
||||
def patch_remote(model_name, config_name, modeling_name):
|
||||
|
||||
def patch_remote(model_name):
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_* to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
|
||||
parts = model_config.__class__.__module__.split(".")
|
||||
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
|
||||
module_name = ".".join(parts)
|
||||
modeling_arch = importlib.import_module(module_name)
|
||||
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
|
||||
if hasattr(modeling_arch, "_get_unpad_data"):
|
||||
modeling_arch._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
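The refactor above converges on a single override point: once `transformers.modeling_flash_attention_utils` is available, replacing its module-level `_get_unpad_data` covers every architecture that imports it, and remote-code models are handled by importing their `modeling_*` module and patching the same symbol. A simplified sketch of the kind of helper being swapped in; this mirrors the stock unpadding logic and is not axolotl's packing-aware `get_unpad_data`:

```python
import torch
import torch.nn.functional as F
import transformers.modeling_flash_attention_utils as fa_utils


def _get_unpad_data_sketch(attention_mask):
    # same return contract as the upstream helper:
    # (flat token indices, cumulative sequence lengths, max sequence length)
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = F.pad(
        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
    )
    return indices, cu_seqlens, max_seqlen_in_batch


# module-level swap, mirroring the patch above
fa_utils._get_unpad_data = _get_unpad_data_sketch  # pylint: disable=protected-access
```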
83 src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
fix for FSDP gradient accumulation
|
||||
see https://github.com/huggingface/transformers/pull/34645
|
||||
"""
|
||||
import inspect
|
||||
|
||||
from accelerate.logging import get_logger
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||
|
||||
LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")
|
||||
|
||||
ORIGINAL_CONTEXT_CODE = """
|
||||
context = (
|
||||
functools.partial(self.accelerator.no_sync, model=model)
|
||||
if i == len(batch_samples) - 1
|
||||
else contextlib.nullcontext
|
||||
)
|
||||
"""
|
||||
|
||||
PATCHED_CONTEXT_CODE = """
|
||||
context = (
|
||||
functools.partial(self.accelerator.no_sync, model=model)
|
||||
if i != len(batch_samples) - 1
|
||||
else contextlib.nullcontext
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def get_training_loop_code() -> str:
|
||||
training_loop = inspect.getsource(
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
return training_loop
|
||||
|
||||
|
||||
def check_training_loop_is_patchable() -> bool:
|
||||
train_loop = get_training_loop_code()
|
||||
train_loop, _ = detab_code(train_loop)
|
||||
return ORIGINAL_CONTEXT_CODE in train_loop
|
||||
|
||||
|
||||
def patch_training_loop_for_fsdp_grad_accum():
|
||||
"""
|
||||
monkeypatch for fixing the training loop for FSDP gradient accumulation
|
||||
"""
|
||||
|
||||
train_loop = get_training_loop_code()
|
||||
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
||||
train_loop
|
||||
)
|
||||
train_loop, _ = detab_code(train_loop)
|
||||
assert (
|
||||
ORIGINAL_CONTEXT_CODE in train_loop
|
||||
), "Original _inner_training_loop code not found"
|
||||
|
||||
train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
|
||||
train_loop = train_loop.replace(
|
||||
"def _inner_training_loop(",
|
||||
"def _fixed_inner_training_loop(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.trainer
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.trainer):
|
||||
if item in train_loop:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.trainer import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(train_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching _inner_training_loop", main_process_only=True)
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
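The new module exists because the upstream loop had the `no_sync` condition inverted: gradient all-reduce should be skipped on every micro-batch except the last one of an accumulation window. A standalone sketch of the corrected pattern, illustrative rather than the patched `Trainer` code:

```python
import contextlib
import functools


def grad_sync_context(accelerator, model, i, num_micro_batches):
    """Skip gradient synchronization for all but the last micro-batch."""
    return (
        functools.partial(accelerator.no_sync, model=model)
        if i != num_micro_batches - 1
        else contextlib.nullcontext
    )


# usage inside a hand-rolled accumulation loop (sketch):
# for i, batch in enumerate(batch_samples):
#     with grad_sync_context(accelerator, model, i, len(batch_samples))():
#         loss = model(**batch).loss / len(batch_samples)
#         accelerator.backward(loss)
# optimizer.step(); optimizer.zero_grad()
```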
@@ -1,33 +0,0 @@
|
||||
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
strategy = InstructShareGPTPromptTokenizingStrategy(
|
||||
# pylint: disable=duplicate-code
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
return strategy
|
||||
|
||||
|
||||
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return [
|
||||
{"from": "human", "value": prompt["instruction"]},
|
||||
{"from": "gpt", "value": prompt["output"]},
|
||||
]
|
||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Generator, List, Sequence
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -75,7 +75,7 @@ class Llama2ChatConversation:
|
||||
|
||||
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
Tokenizing strategy for Llama2 prompts.
|
||||
adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
|
||||
"""
|
||||
|
||||
@@ -191,7 +191,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
|
||||
conv.messages = [] # pylint: disable=R0801
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE
|
||||
if sentence["value"]:
|
||||
conv.append_message(role, sentence["value"])
|
||||
yield conv
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
from axolotl.utils.tokenization import (
|
||||
chatml_to_conversation,
|
||||
merge_consecutive_messages,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def register_chatml_template(system_message=None):
|
||||
system_message = system_message or "You are a helpful assistant."
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml_glaive",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def register_llama3_template(system_message=None):
|
||||
system_message = system_message or "You are a helpful assistant."
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llama3",
|
||||
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
||||
system_message=system_message,
|
||||
roles=("user", "assistant"),
|
||||
sep_style=SeparatorStyle.LLAMA3,
|
||||
sep="",
|
||||
stop_str="<|eot_id|>",
|
||||
stop_token_ids=[128001, 128009],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_loader(
|
||||
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
||||
prompter_cls: Type["ShareGPTPrompterV2"],
|
||||
default_conversation: Optional[str] = None,
|
||||
):
|
||||
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
LOG.warning(
|
||||
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
|
||||
)
|
||||
conversation = (
|
||||
ds_cfg["conversation"]
|
||||
if ds_cfg and "conversation" in ds_cfg
|
||||
else default_conversation
|
||||
)
|
||||
field_human = (
|
||||
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
)
|
||||
field_model = (
|
||||
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
)
|
||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||
strategy = tokenization_strategy_cls(
|
||||
prompter_cls(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
roles=roles,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
return strategy
|
||||
|
||||
return _load
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
_strict = False
|
||||
_messages = "conversations"
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
return self._strict
|
||||
|
||||
@strict.setter
|
||||
def strict(self, strict):
|
||||
self._strict = strict
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
return self._messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, messages):
|
||||
self._messages = messages
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt[self.messages]
|
||||
if self.strict:
|
||||
return conversations
|
||||
role_key = "from"
|
||||
if "role" in conversations[0].keys():
|
||||
role_key = "role"
|
||||
value_key = "value"
|
||||
if "text" in conversations[0].keys():
|
||||
value_key = "text"
|
||||
elif "content" in conversations[0].keys():
|
||||
value_key = "content"
|
||||
# remap roles - allow for assistant turn
|
||||
role_map = {
|
||||
"user": "human",
|
||||
"human": "human",
|
||||
"assistant": "gpt",
|
||||
"gpt": "gpt",
|
||||
"system": "system",
|
||||
}
|
||||
turns = [
|
||||
{
|
||||
"from": (
|
||||
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
|
||||
),
|
||||
"value": t[value_key],
|
||||
"weight": 1
|
||||
if "weight" not in t or t["weight"] is None
|
||||
else t["weight"],
|
||||
}
|
||||
for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
SimpleShareGPTPromptTokenizingStrategy
|
||||
):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
||||
return turns
|
||||
|
||||
|
||||
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps oasst data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps ultrachat data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["messages"]
|
||||
role_map = {"user": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps glaive data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversation = chatml_to_conversation(prompt)
|
||||
conversation = merge_consecutive_messages(conversation)
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_ultrachat = build_loader(
|
||||
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
|
||||
)
|
||||
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_glaive = build_loader(
|
||||
GlaiveShareGPTPromptTokenizingStrategy,
|
||||
ShareGPTPrompterV2,
|
||||
default_conversation="chatml_glaive",
|
||||
)
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Module for Jokes prompts using sharegpt style """
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization strategy for asking the bot to tell a joke and then explain why it's funny
|
||||
"""
|
||||
|
||||
# title, text, explanation
|
||||
def get_conversation_thread(self, prompt):
|
||||
title = "" if not prompt["title"] else prompt["title"] + " "
|
||||
return [
|
||||
{"from": "human", "value": "Tell me a joke."},
|
||||
{"from": "gpt", "value": title + prompt["text"]},
|
||||
{"from": "human", "value": "Why is that joke funny?"},
|
||||
{"from": "gpt", "value": prompt["explanation"]},
|
||||
]
|
||||
@@ -1,17 +1,12 @@
|
||||
"""Module containing PromptTokenizingStrategy and Prompter classes"""
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
|
||||
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||
add_get_turns_to_conversation,
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
from axolotl.prompters import Prompter
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -21,8 +16,6 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
||||
|
||||
add_get_turns_to_conversation()
|
||||
|
||||
|
||||
class InvalidDataException(Exception):
|
||||
"""
|
||||
@@ -331,154 +324,6 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# Initial values. We will append to these as we go through the conversation.
|
||||
result, current_len = tokenize_prompt_default()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
input_roles = {conversation.roles[0]}
|
||||
output_roles = {conversation.roles[1]}
|
||||
|
||||
if len(conversation.roles) == 3:
|
||||
tool_role_label = conversation.roles[2]
|
||||
input_roles.add(tool_role_label)
|
||||
|
||||
# Add roles from the config
|
||||
if self.prompter.roles:
|
||||
if "input" in self.prompter.roles and self.prompter.roles["input"]:
|
||||
for role in self.prompter.roles["input"]:
|
||||
input_roles.add(role)
|
||||
|
||||
if "output" in self.prompter.roles and self.prompter.roles["output"]:
|
||||
for role in self.prompter.roles["output"]:
|
||||
output_roles.add(role)
|
||||
|
||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||
role_remap = []
|
||||
if (
|
||||
conversation.name == "vicuna_v1.1"
|
||||
and "roles" in prompt
|
||||
and len(prompt["roles"]) >= 2
|
||||
):
|
||||
role_remap = [
|
||||
{"from": conversation.roles[0], "to": prompt["roles"][0]},
|
||||
{"from": conversation.roles[1], "to": prompt["roles"][1]},
|
||||
]
|
||||
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if not isinstance(part, tuple):
|
||||
LOG.warning(f"expected tuple, got {part}")
|
||||
continue
|
||||
|
||||
if len(part) <= 2:
|
||||
role, content = part
|
||||
weight = 1
|
||||
else:
|
||||
role, content, weight = part
|
||||
|
||||
# Uses "in" because role contains extra characters
|
||||
input_turn = any(r.lower() in role.lower() for r in input_roles)
|
||||
output_turn = any(r.lower() in role.lower() for r in output_roles)
|
||||
empty_role = role.strip() == ""
|
||||
|
||||
if not any([input_turn, output_turn, empty_role]):
|
||||
LOG.warning(f"unhandled role: {role}")
|
||||
continue
|
||||
|
||||
if input_turn:
|
||||
role = (
|
||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this is still the user query, we should
|
||||
if not content.strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if self.train_on_inputs and weight == 1:
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif output_turn:
|
||||
role = (
|
||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not content.strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
add_eos_token = not (
|
||||
conversation.name == "chatml"
|
||||
and conversation.sep == self.tokenizer.eos_token
|
||||
)
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=add_eos_token,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
role_res = self._tokenize(
|
||||
role.rstrip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
if not self.train_on_inputs:
|
||||
# mask out role tokens from the labels
|
||||
len_role = len(role_res["input_ids"])
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||
len_role, len(labels)
|
||||
)
|
||||
if weight == 0:
|
||||
# everything from this is masked out from the labels
|
||||
# (role is masked out too because it makes no sense if contents is masked out)
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
|
||||
elif empty_role:
|
||||
turn = content
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
if self.train_on_inputs and weight == 1:
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
|
||||
|
||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
Returns the default values for the tokenize prompt function
|
||||
|
||||
@@ -5,7 +5,6 @@ from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from colorama import Fore
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
@@ -262,166 +261,10 @@ class ReflectAlpacaPrompter(Prompter):
|
||||
)
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
ALTERNATING_ASSERTION_FAILED_ROLE = (
|
||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||
)
|
||||
|
||||
CONVERSATION_ROLE_FORMAT = {
|
||||
"chatml": "<|im_start|>{ROLE}",
|
||||
"zephyr": "<|{ROLE}|>",
|
||||
"vicuna_v1.1": "{ROLE}",
|
||||
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
|
||||
}
|
||||
|
||||
|
||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
role_key_human = "human"
|
||||
role_key_model = "gpt"
|
||||
# Optional, only used for tool usage datasets.
|
||||
role_key_tool: Optional[str] = None
|
||||
# Optional, role input/output mapping
|
||||
roles: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_style=None, # pylint: disable=unused-argument
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
role_key_tool: Optional[str] = None,
|
||||
roles: Optional[dict] = None,
|
||||
):
|
||||
if conversation:
|
||||
if isinstance(conversation, Conversation):
|
||||
self._conversation = conversation
|
||||
else:
|
||||
self._conversation = get_conv_template(conversation)
|
||||
else:
|
||||
self._conversation = get_conv_template("vicuna_v1.1")
|
||||
if role_key_human:
|
||||
self.role_key_human = role_key_human
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
if role_key_tool:
|
||||
self.role_key_tool = role_key_tool
|
||||
if roles:
|
||||
self.roles = roles
|
||||
|
||||
def _build_result(self, source):
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
raise IndexError(
|
||||
f"A conversation entry has less than 2 messages :\n{source}"
|
||||
)
|
||||
|
||||
conv = self._conversation.copy()
|
||||
|
||||
original_source = source.copy()
|
||||
# Add the conversation system prompt if provided, otherwise use the default one
|
||||
if source[0]["from"] == "system":
|
||||
conv.set_system_message(source[0]["value"])
|
||||
source.pop(0)
|
||||
|
||||
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||
if self.role_key_tool:
|
||||
roles[self.role_key_tool] = conv.roles[2]
|
||||
|
||||
try:
|
||||
# Apply prompt templates
|
||||
if source[0]["from"] not in roles:
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
except IndexError as err:
|
||||
# sometimes there is a bing or system chat
|
||||
raise err
|
||||
|
||||
conv.messages = []
|
||||
for _, sentence in enumerate(source):
|
||||
from_role = sentence["from"]
|
||||
if from_role in roles:
|
||||
role = roles[from_role]
|
||||
else:
|
||||
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
|
||||
raise NotImplementedError(
|
||||
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
|
||||
"Please help us by creating an Issue to add support for this conversation type."
|
||||
)
|
||||
|
||||
if self._conversation.name in ["llama3"]:
|
||||
role = from_role
|
||||
else:
|
||||
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
|
||||
ROLE=from_role
|
||||
)
|
||||
|
||||
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
||||
if (
|
||||
role != "assistant"
|
||||
): # back to back assistant calls may be okay for tool calls
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
|
||||
conv.append_message(role, sentence["value"])
|
||||
turns = list(conv.get_turns())
|
||||
original_source_length = len(original_source)
|
||||
assert len(turns) in [
|
||||
original_source_length - 1,
|
||||
original_source_length,
|
||||
original_source_length + 1,
|
||||
]
|
||||
if len(turns) == original_source_length + 1:
|
||||
original_source = [{"weight": None}] + original_source
|
||||
elif len(turns) == original_source_length - 1:
|
||||
original_source = original_source[1:]
|
||||
return [
|
||||
(*turn, weight)
|
||||
for turn, weight in zip(
|
||||
turns,
|
||||
[
|
||||
1 if "weight" not in e or e["weight"] is None else e["weight"]
|
||||
for e in original_source
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
turns = self._build_result(source)
|
||||
|
||||
for part in turns:
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
yield part
|
||||
|
||||
def __repr__(self) -> str:
|
||||
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
A V2 prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
role_key_tool: Optional[str] = None,
|
||||
roles: Optional[dict] = None,
|
||||
):
|
||||
super().__init__(
|
||||
conversation=conversation,
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
role_key_tool=role_key_tool,
|
||||
roles=roles,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedPrompter(Prompter):
|
||||
"""
|
||||
|
||||
@@ -64,10 +64,7 @@ class EvalFirstStepCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and state.global_step == 1
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1,8 +1,6 @@
|
||||
"""Module for working with config dicts"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -10,7 +8,6 @@ from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.integrations.config import merge_input_args
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
)
|
||||
@@ -215,11 +212,6 @@ def normalize_cfg_datasets(cfg):
|
||||
if cfg.chat_template:
|
||||
if cfg.datasets:
|
||||
for idx, ds_cfg in enumerate(cfg.datasets):
|
||||
if ds_cfg.type == "sharegpt" and not ds_cfg.conversation:
|
||||
LOG.info(
|
||||
f"updating dataset {ds_cfg.path} with `conversation: {cfg.chat_template}` to match your chat_template"
|
||||
)
|
||||
cfg.datasets[idx].conversation = cfg.chat_template
|
||||
if (
|
||||
ds_cfg.type in ["orpo.chat_template", "chat_template"]
|
||||
and not ds_cfg.chat_template
|
||||
@@ -252,391 +244,3 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
return DictDefault(
|
||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
||||
)
|
||||
|
||||
|
||||
def legacy_validate_config(cfg):
|
||||
"""
|
||||
This is a "pre-validation" step that handles the yaml configuration before we have any
|
||||
information about the model architecture
|
||||
"""
|
||||
if is_torch_bf16_gpu_available():
|
||||
if not cfg.bf16 and not cfg.bfloat16:
|
||||
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
||||
else:
|
||||
if (
|
||||
not cfg.merge_lora
|
||||
and not cfg.is_preprocess
|
||||
and (cfg.bf16 is True or cfg.bfloat16 is True)
|
||||
):
|
||||
raise ValueError(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||
)
|
||||
if (
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
not (cfg.bf16 or cfg.bfloat16)
|
||||
and (cfg.fp16 or cfg.float16)
|
||||
and not cfg.adapter
|
||||
and not cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
LOG.warning(
|
||||
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
||||
)
|
||||
# ValueError: Attempting to unscale FP16 gradients.
|
||||
# OR
|
||||
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
||||
if cfg.max_packed_sequence_len:
|
||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||
|
||||
if cfg.sample_packing and cfg.rl:
|
||||
raise ValueError("`sample_packing: true` does not work with RLHF training")
|
||||
|
||||
if cfg.sample_packing and not cfg.pad_to_sequence_len:
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
raise ValueError(
|
||||
"please set only one of gradient_accumulation_steps or batch_size"
|
||||
)
|
||||
if cfg.batch_size:
|
||||
LOG.warning(
|
||||
"%s\n%s",
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if (
|
||||
cfg.eval_batch_size
|
||||
and cfg.micro_batch_size
|
||||
and cfg.eval_batch_size != cfg.micro_batch_size
|
||||
):
|
||||
LOG.warning(
|
||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||
)
|
||||
|
||||
if cfg.adapter == "qlora":
|
||||
if cfg.merge_lora:
|
||||
# can't merge qlora if loaded in 8bit or 4bit
|
||||
if cfg.load_in_8bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||
|
||||
if cfg.gptq:
|
||||
raise ValueError("Can't merge qlora if gptq")
|
||||
|
||||
if cfg.load_in_4bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||
|
||||
else:
|
||||
if cfg.load_in_8bit:
|
||||
raise ValueError("Can't load qlora in 8bit")
|
||||
|
||||
if cfg.gptq:
|
||||
raise ValueError("Can't load qlora if gptq")
|
||||
|
||||
if not cfg.load_in_4bit:
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with QLoRA")
|
||||
|
||||
loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||
raise ValueError("Fused modules are not supported with LoRA")
|
||||
|
||||
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
|
||||
raise ValueError(
|
||||
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
||||
)
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
|
||||
if cfg.fsdp:
|
||||
raise ValueError("fsdp not supported with ReLoRA")
|
||||
|
||||
if cfg.deepspeed:
|
||||
raise ValueError("deepspeed not supported with ReLoRA")
|
||||
|
||||
if cfg.lr_scheduler == "one_cycle":
|
||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
)
|
||||
|
||||
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
|
||||
raise ValueError(
|
||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||
)
|
||||
|
||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||
raise ValueError("FSDP is not supported for falcon models")
|
||||
|
||||
if (
|
||||
cfg.base_model and "mpt" in cfg.base_model.lower()
|
||||
) and cfg.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||
|
||||
if cfg.flash_optimum is True:
|
||||
if cfg.adapter:
|
||||
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bfloat16 is not True:
|
||||
LOG.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
)
|
||||
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
|
||||
LOG.warning("torch>=2.0.0 required")
|
||||
raise ValueError(
|
||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||
)
|
||||
|
||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||
LOG.warning(
|
||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||
)
|
||||
if cfg.pretraining_dataset and not cfg.max_steps:
|
||||
raise ValueError(
|
||||
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
|
||||
)
|
||||
|
||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||
):
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
|
||||
if cfg.push_to_hub_model_id:
|
||||
raise ValueError(
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
raise ValueError(
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
# raise ValueError(
|
||||
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
# )
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||
LOG.warning(
|
||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||
"This may work on H100s."
|
||||
)
|
||||
|
||||
if cfg.early_stopping_patience:
|
||||
if not cfg.save_steps or not cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
||||
)
|
||||
if cfg.save_steps % cfg.eval_steps != 0:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||
)
|
||||
|
||||
if cfg.datasets:
|
||||
for idx, ds_cfg in enumerate(cfg.datasets):
|
||||
if not ds_cfg.type:
|
||||
continue
|
||||
if ds_cfg.type == "sharegpt:chat":
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
cfg.datasets[idx].type = "sharegpt"
|
||||
if "sharegpt_simple" in ds_cfg.type:
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
)
|
||||
|
||||
if cfg.saves_per_epoch and cfg.save_steps:
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
)
|
||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if (
|
||||
cfg.evals_per_epoch
|
||||
and cfg.evaluation_strategy
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
if (
|
||||
cfg.evaluation_strategy
|
||||
and cfg.eval_steps
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.val_set_size == 0
|
||||
and (cfg.eval_steps or cfg.evaluation_strategy)
|
||||
and not cfg.test_datasets
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.sample_packing
|
||||
and cfg.eval_table_size
|
||||
and cfg.eval_sample_packing is not False
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||
)
|
||||
|
||||
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be a key under `model_config`")
|
||||
|
||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||
cfg.wandb_name = cfg.wandb_run_id
|
||||
|
||||
LOG.warning(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
)
|
||||
|
||||
if cfg.noisy_embedding_alpha is not None:
|
||||
# Deprecated, use neftune_noise_alpha
|
||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||
if cfg.neftune_noise_alpha is None:
|
||||
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||
else:
|
||||
# User is providing both; bail and have them sort out their settings
|
||||
raise ValueError(
|
||||
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||
)
|
||||
|
||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||
|
||||
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
||||
raise ValueError(
|
||||
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.unfrozen_parameters
|
||||
and cfg.gradient_checkpointing_kwargs
|
||||
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
|
||||
):
|
||||
# https://github.com/huggingface/transformers/issues/21381
|
||||
raise ValueError(
|
||||
"`use_reentrant` must be false when used with partially frozen model."
|
||||
)
|
||||
|
||||
if cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
||||
with open(cfg.deepspeed, encoding="utf-8") as file:
|
||||
contents = file.read()
|
||||
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
||||
if cfg.flash_attention:
|
||||
if (
|
||||
deepspeed_cfg.zero_optimization
|
||||
and deepspeed_cfg.zero_optimization.stage == 3
|
||||
):
|
||||
if not (
|
||||
(
|
||||
deepspeed_cfg.bf16
|
||||
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
||||
is True
|
||||
)
|
||||
or (
|
||||
deepspeed_cfg.fp16
|
||||
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
||||
is True
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
||||
)
|
||||
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
|
||||
LOG.warning(
|
||||
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
|
||||
)
|
||||
|
||||
if cfg.test_datasets and cfg.val_set_size:
|
||||
raise ValueError(
|
||||
"non-zero val_set_size should not be used with test_datasets configuration"
|
||||
)
|
||||
|
||||
if cfg.fsdp and "bnb" in cfg.optimizer:
|
||||
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
||||
|
||||
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
|
||||
raise ValueError(
|
||||
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
||||
)
|
||||
|
||||
if cfg.eval_causal_lm_metrics:
|
||||
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
||||
raise ValueError("eval_causal_lm_metrics must be a list")
|
||||
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
||||
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
|
||||
raise ValueError(
|
||||
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
# no 8bit adaAmw w bf16
|
||||
|
||||
# GPT-NeoX
|
||||
# evals broken when extending context len
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
||||
# attention_mask = causal_mask + attention_mask
|
||||
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
||||
|
||||
@@ -57,6 +57,8 @@ class ChatTemplate(str, Enum):
|
||||
jinja = "jinja" # pylint: disable=invalid-name
|
||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||
exaone = "exaone" # pylint: disable=invalid-name
|
||||
metharme = "metharme" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DeprecatedParameters(BaseModel):
|
||||
@@ -66,6 +68,7 @@ class DeprecatedParameters(BaseModel):
|
||||
rope_scaling: Optional[Any] = None
|
||||
noisy_embedding_alpha: Optional[float] = None
|
||||
dpo_beta: Optional[float] = None
|
||||
evaluation_strategy: Optional[str] = None
|
||||
|
||||
@field_validator("max_packed_sequence_len")
|
||||
@classmethod
|
||||
@@ -97,6 +100,13 @@ class DeprecatedParameters(BaseModel):
|
||||
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
||||
return dpo_beta
|
||||
|
||||
@field_validator("evaluation_strategy")
|
||||
@classmethod
|
||||
def validate_evaluation_strategy(cls, evaluation_strategy):
|
||||
if evaluation_strategy is not None:
|
||||
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
||||
return evaluation_strategy
|
||||
|
||||
|
||||
class RemappedParameters(BaseModel):
|
||||
"""parameters that have been remapped to other names"""
|
||||
@@ -426,6 +436,7 @@ class HyperparametersConfig(BaseModel):
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
],
|
||||
]
|
||||
] = OptimizerNames.ADAMW_HF.value
|
||||
@@ -587,6 +598,9 @@ class AxolotlInputConfig(
|
||||
|
||||
rl: Optional[RLType] = None
|
||||
reward_model: Optional[bool] = None
|
||||
dpo_use_weighting: Optional[
|
||||
bool
|
||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
@@ -725,7 +739,7 @@ class AxolotlInputConfig(
|
||||
warmup_ratio: Optional[float] = None
|
||||
eval_steps: Optional[Union[int, float]] = None
|
||||
evals_per_epoch: Optional[Union[int]] = None
|
||||
evaluation_strategy: Optional[str] = None
|
||||
eval_strategy: Optional[str] = None
|
||||
save_steps: Optional[Union[int, float]] = None
|
||||
saves_per_epoch: Optional[int] = None
|
||||
save_strategy: Optional[str] = None
|
||||
@@ -777,28 +791,25 @@ class AxolotlInputConfig(
|
||||
is_mistral_derived_model: Optional[bool] = Field(default=None)
|
||||
is_qwen_derived_model: Optional[bool] = Field(default=None)
|
||||
|
||||
plugins: Optional[List[str]] = Field(default=None)
|
||||
|
||||
@field_validator("datasets", mode="before")
|
||||
@classmethod
|
||||
def fix_sharegpt_datasets(cls, datasets):
|
||||
for idx, ds_cfg in enumerate(datasets):
|
||||
if not ds_cfg["type"]:
|
||||
def deprecate_sharegpt_datasets(cls, datasets):
|
||||
for _, ds_cfg in enumerate(datasets):
|
||||
if not ds_cfg.get("type"):
|
||||
continue
|
||||
if ds_cfg["type"] == "sharegpt:chat":
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
datasets[idx]["type"] = "sharegpt"
|
||||
if "sharegpt_simple" in ds_cfg["type"]:
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
datasets[idx]["type"] = datasets[idx]["type"].replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
|
||||
ds_type = ds_cfg["type"]
|
||||
# skip if it's a dict (for custom user instruction prompt)
|
||||
if isinstance(ds_type, dict):
|
||||
continue
|
||||
|
||||
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
|
||||
raise ValueError(
|
||||
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
|
||||
)
|
||||
|
||||
return datasets
|
||||
|
||||
@model_validator(mode="before")
|
||||
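With the validator above, any dataset entry whose `type` still starts with `sharegpt` is rejected outright instead of being silently remapped. A hypothetical illustration of the failure mode; a real config also needs `base_model` and the other required keys, which are elided here:

```python
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        # ... other required fields elided ...
        "datasets": [{"path": "my/dataset", "type": "sharegpt"}],
    }
)

# raises ValueError: "`type: sharegpt.*` is deprecated. Please use
# `type: chat_template` instead."
validate_config(cfg)
```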
@@ -1030,21 +1041,21 @@ class AxolotlInputConfig(
|
||||
@classmethod
|
||||
def check_evals(cls, data):
|
||||
if (
|
||||
data.get("evaluation_strategy")
|
||||
data.get("eval_strategy")
|
||||
and data.get("eval_steps")
|
||||
and data.get("evaluation_strategy") != "steps"
|
||||
and data.get("eval_strategy") != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
data.get("val_set_size") == 0
|
||||
and (data.get("eval_steps") or data.get("evaluation_strategy"))
|
||||
and (data.get("eval_steps") or data.get("eval_strategy"))
|
||||
and not data.get("test_datasets")
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
||||
raise ValueError(
|
||||
@@ -1052,11 +1063,11 @@ class AxolotlInputConfig(
|
||||
)
|
||||
if (
|
||||
data.get("evals_per_epoch")
|
||||
and data.get("evaluation_strategy")
|
||||
and data.get("evaluation_strategy") != "steps"
|
||||
and data.get("eval_strategy")
|
||||
and data.get("eval_strategy") != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
|
||||
if data.get("do_bench_eval") and not (
|
||||
@@ -1288,6 +1299,25 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def warn_qlora_zero3_w_use_reentrant(cls, data):
|
||||
if (
|
||||
data.get("adapter") == "qlora"
|
||||
and data.get("gradient_checkpointing_kwargs", {})
|
||||
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
||||
is False
|
||||
and "zero3" in data.get("deepspeed", "")
|
||||
):
|
||||
# may result in:
|
||||
# torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:
|
||||
# Recomputed values for the following tensors have different metadata
|
||||
# than during the forward pass.
|
||||
LOG.warning(
|
||||
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_val_w_test_datasets(cls, data):
|
||||
@@ -1297,6 +1327,19 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_eval_strategy(cls, data):
|
||||
if (
|
||||
data.get("evaluation_strategy") is not None
|
||||
and data.get("eval_strategy") is None
|
||||
):
|
||||
LOG.info(
|
||||
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
|
||||
)
|
||||
data["eval_strategy"] = data.get("evaluation_strategy")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import requests
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
@@ -53,6 +55,28 @@ from axolotl.utils.trainer import (
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except (
|
||||
requests.exceptions.ReadTimeout,
|
||||
requests.exceptions.ConnectionError,
|
||||
) as exc:
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(delay)
|
||||
else:
|
||||
raise exc
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||
def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
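`retry_on_request_exceptions` wraps dataset preparation so transient Hub timeouts or dropped connections are retried after a fixed delay instead of failing the run; only the final attempt re-raises. A usage sketch with a hypothetical flaky helper, assuming the decorator defined above is in scope:

```python
import requests


@retry_on_request_exceptions(max_retries=3, delay=5)
def fetch_dataset_index(url: str) -> dict:
    # hypothetical helper: any ReadTimeout/ConnectionError here triggers a retry
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    return response.json()
```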
@@ -236,6 +260,7 @@ def load_tokenized_prepared_datasets(
|
||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||
ds_from_hub = False
|
||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
||||
try:
|
||||
# this is just a basic check to see if the path is a
|
||||
# valid HF dataset that's loadable
|
||||
@@ -245,6 +270,7 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=True,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=ds_trust_remote_code,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||
@@ -324,7 +350,15 @@ def load_tokenized_prepared_datasets(
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
ds = load_from_disk(config_dataset.path)
|
||||
try:
|
||||
ds = load_from_disk(config_dataset.path)
|
||||
except FileNotFoundError:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
@@ -342,7 +376,7 @@ def load_tokenized_prepared_datasets(
|
||||
elif ds_from_hub:
|
||||
load_ds_kwargs = {}
|
||||
if config_dataset.split:
|
||||
load_ds_kwargs = {"split": config_dataset.split}
|
||||
load_ds_kwargs["split"] = config_dataset.split
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
@@ -350,6 +384,7 @@ def load_tokenized_prepared_datasets(
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
@@ -367,6 +402,7 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
@@ -377,6 +413,7 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
|
||||
25 src/axolotl/utils/environment.py Normal file
@@ -0,0 +1,25 @@
"""
utils to get GPU info for the current environment
"""
from accelerate.utils.environment import (
    check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import get_gpu_info


def check_cuda_p2p_ib_support():
    if not accelerate_check_cuda_p2p_ib_support():
        return False
    unsupported_devices = {"RTX 6000 Ada"}
    try:
        device_names, device_count = get_gpu_info()
        if 1 < device_count < 8:
            if any(
                unsupported_device in device_name
                for device_name in device_names
                for unsupported_device in unsupported_devices
            ):
                return False
    except Exception: # pylint: disable=broad-except # nosec
        pass
    return True
|
||||
@@ -14,6 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from packaging import version

torch_version = version.parse(torch.__version__)

if torch_version < version.parse("2.4.0"):
    torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
    torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
    torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
    torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")


class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_fwd
|
||||
@torch_cuda_amp_custom_fwd
|
||||
def forward(ctx, forward_function, hidden_states, *args):
|
||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
@@ -36,7 +46,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.cuda.amp.custom_bwd
|
||||
@torch_cuda_amp_custom_bwd
|
||||
def backward(ctx, dY):
|
||||
(hidden_states,) = ctx.saved_tensors
|
||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||
|
||||
@@ -238,6 +238,7 @@ def load_tokenizer(cfg):
|
||||
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||
)
|
||||
)
|
||||
and k != "pad_token"
|
||||
):
|
||||
lora_modules_to_save = ", ".join(
|
||||
[f"`{x}`" for x in lora_modules_to_save]
|
||||
@@ -396,9 +397,11 @@ class ModelLoader:
|
||||
):
|
||||
has_remote_code = (
|
||||
"auto_map" in self.model_config
|
||||
and self.model_type in self.model_config["auto_map"]
|
||||
and "AutoModelForCausalLM" in self.model_config["auto_map"]
|
||||
)
|
||||
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
patch_for_multipack(
|
||||
self.cfg.model_config_type,
|
||||
model_name=self.cfg.base_model,
|
||||
@@ -645,9 +648,7 @@ class ModelLoader:
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
elif self.cfg.adapter == "qlora" and (
|
||||
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
|
||||
):
|
||||
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
||||
bnb_config = {
|
||||
"load_in_4bit": True,
|
||||
"llm_int8_threshold": 6.0,
|
||||
@@ -670,9 +671,7 @@ class ModelLoader:
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
elif self.cfg.adapter == "lora" and (
|
||||
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
|
||||
):
|
||||
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
|
||||
bnb_config = {
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -685,10 +684,8 @@ class ModelLoader:
|
||||
|
||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
||||
if "load_in_8bit" in self.model_kwargs:
|
||||
del self.model_kwargs["load_in_8bit"]
|
||||
if "load_in_4bit" in self.model_kwargs:
|
||||
del self.model_kwargs["load_in_4bit"]
|
||||
self.model_kwargs.pop("load_in_8bit", None)
|
||||
self.model_kwargs.pop("load_in_4bit", None)
|
||||
|
||||
def set_attention_config(self) -> None:
|
||||
"""
|
||||
@@ -973,17 +970,10 @@ class ModelLoader:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
is_load_in_8bit = (
|
||||
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
|
||||
)
|
||||
is_load_in_4bit = (
|
||||
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
|
||||
)
|
||||
|
||||
if (
|
||||
not skip_prepare_model_for_kbit_training
|
||||
and self.cfg.adapter in ["lora", "qlora"]
|
||||
and (is_load_in_8bit or is_load_in_4bit)
|
||||
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
|
||||
):
|
||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
self.model = prepare_model_for_kbit_training(
|
||||
@@ -1121,16 +1111,10 @@ class ModelLoader:
|
||||
# ---------------------------------------------------------
|
||||
# put model to accelerator
|
||||
# ---------------------------------------------------------
|
||||
is_load_in_8bit = (
|
||||
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
|
||||
)
|
||||
is_load_in_4bit = (
|
||||
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
|
||||
)
|
||||
if (
|
||||
self.cfg.ddp
|
||||
and not is_load_in_8bit
|
||||
and not (self.cfg.rl and is_load_in_4bit)
|
||||
and not self.cfg.load_in_8bit
|
||||
and not (self.cfg.rl and self.cfg.load_in_4bit)
|
||||
and not skip_move_to_device
|
||||
):
|
||||
# TODO revalidate this conditional
|
||||
|
||||
508
src/axolotl/utils/optimizers/adopt.py
Normal file
@@ -0,0 +1,508 @@
"""
Copied from https://github.com/iShohei220/adopt

ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
"""
# mypy: ignore-errors
# pylint: skip-file
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim.optimizer import (
|
||||
Optimizer,
|
||||
ParamsT,
|
||||
_default_to_fused_or_foreach,
|
||||
_device_dtype_check_for_fused,
|
||||
_disable_dynamo_if_unsupported,
|
||||
_get_capturable_supported_devices,
|
||||
_get_scalar_dtype,
|
||||
_get_value,
|
||||
_use_grad_for_differentiable,
|
||||
_view_as_real,
|
||||
)
|
||||
|
||||
__all__ = ["ADOPT", "adopt"]
|
||||
|
||||
|
||||
class ADOPT(Optimizer):
|
||||
def __init__(
|
||||
self,
|
||||
params: ParamsT,
|
||||
lr: Union[float, Tensor] = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.9999),
|
||||
eps: float = 1e-6,
|
||||
weight_decay: float = 0.0,
|
||||
decoupled: bool = False,
|
||||
*,
|
||||
foreach: Optional[bool] = None,
|
||||
maximize: bool = False,
|
||||
capturable: bool = False,
|
||||
differentiable: bool = False,
|
||||
fused: Optional[bool] = None,
|
||||
):
|
||||
if isinstance(lr, Tensor):
|
||||
if foreach and not capturable:
|
||||
raise ValueError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
if lr.numel() != 1:
|
||||
raise ValueError("Tensor lr must be 1-element")
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError(f"Invalid epsilon value: {eps}")
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
decoupled=decoupled,
|
||||
maximize=maximize,
|
||||
foreach=foreach,
|
||||
capturable=capturable,
|
||||
differentiable=differentiable,
|
||||
fused=fused,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
if fused:
|
||||
# TODO: support fused
|
||||
raise RuntimeError("`fused` is not currently supported")
|
||||
|
||||
if differentiable:
|
||||
raise RuntimeError("`fused` does not support `differentiable`")
|
||||
self._step_supports_amp_scaling = True
|
||||
# TODO(crcrpar): [low prec params & their higher prec copy]
|
||||
# Support AMP with FP16/BF16 model params which would need
|
||||
# higher prec copy of params to do update math in higher prec to
|
||||
# alleviate the loss of information.
|
||||
if foreach:
|
||||
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault("maximize", False)
|
||||
group.setdefault("foreach", None)
|
||||
group.setdefault("capturable", False)
|
||||
group.setdefault("differentiable", False)
|
||||
fused = group.setdefault("fused", None)
|
||||
for p in group["params"]:
|
||||
p_state = self.state.get(p, [])
|
||||
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
|
||||
step_val = float(p_state["step"])
|
||||
p_state["step"] = (
|
||||
torch.tensor(
|
||||
step_val,
|
||||
dtype=_get_scalar_dtype(is_fused=fused),
|
||||
device=p.device,
|
||||
)
|
||||
if group["capturable"] or group["fused"]
|
||||
else torch.tensor(step_val, dtype=_get_scalar_dtype())
|
||||
)
|
||||
|
||||
def _init_group(
|
||||
self,
|
||||
group,
|
||||
params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
):
|
||||
has_complex = False
|
||||
for p in group["params"]:
|
||||
if p.grad is not None:
|
||||
has_complex |= torch.is_complex(p)
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError("ADOPT does not support sparse gradients")
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
# Lazy state initialization
|
||||
if len(state) == 0:
|
||||
if group["fused"]:
|
||||
_device_dtype_check_for_fused(p)
|
||||
# note(crcrpar): [special device hosting for step]
|
||||
# Deliberately host `step` on CPU if both capturable and fused are off.
|
||||
# This is because kernel launches are costly on CUDA and XLA.
|
||||
state["step"] = (
|
||||
torch.zeros(
|
||||
(),
|
||||
dtype=_get_scalar_dtype(is_fused=group["fused"]),
|
||||
device=p.device,
|
||||
)
|
||||
if group["capturable"] or group["fused"]
|
||||
else torch.tensor(0.0, dtype=_get_scalar_dtype())
|
||||
)
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format
|
||||
)
|
||||
|
||||
exp_avgs.append(state["exp_avg"])
|
||||
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||
|
||||
if group["differentiable"] and state["step"].requires_grad:
|
||||
raise RuntimeError(
|
||||
"`requires_grad` is not supported for `step` in differentiable mode"
|
||||
)
|
||||
|
||||
# Foreach without capturable does not support a tensor lr
|
||||
if (
|
||||
group["foreach"]
|
||||
and torch.is_tensor(group["lr"])
|
||||
and not group["capturable"]
|
||||
):
|
||||
raise RuntimeError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
|
||||
state_steps.append(state["step"])
|
||||
return has_complex
|
||||
|
||||
@_use_grad_for_differentiable
|
||||
def step(self, closure=None):
|
||||
"""Perform a single optimization step.
|
||||
|
||||
Args:
|
||||
closure (Callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
self._cuda_graph_capture_health_check()
|
||||
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad: List[Tensor] = []
|
||||
grads: List[Tensor] = []
|
||||
exp_avgs: List[Tensor] = []
|
||||
exp_avg_sqs: List[Tensor] = []
|
||||
state_steps: List[Tensor] = []
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
has_complex = self._init_group(
|
||||
group,
|
||||
params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
)
|
||||
|
||||
adopt(
|
||||
params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
has_complex=has_complex,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group["lr"],
|
||||
weight_decay=group["weight_decay"],
|
||||
decoupled=group["decoupled"],
|
||||
eps=group["eps"],
|
||||
maximize=group["maximize"],
|
||||
foreach=group["foreach"],
|
||||
capturable=group["capturable"],
|
||||
differentiable=group["differentiable"],
|
||||
fused=group["fused"],
|
||||
grad_scale=getattr(self, "grad_scale", None),
|
||||
found_inf=getattr(self, "found_inf", None),
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def _single_tensor_adopt(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
has_complex: bool,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
weight_decay: float,
|
||||
decoupled: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
capturable: bool,
|
||||
differentiable: bool,
|
||||
):
|
||||
assert grad_scale is None and found_inf is None
|
||||
|
||||
if torch.jit.is_scripting():
|
||||
# this assert is due to JIT being dumb and not realizing that the ops below
|
||||
# have overloads to handle both float and Tensor lrs, so we just assert it's
|
||||
# a float since most people using JIT are using floats
|
||||
assert isinstance(lr, float)
|
||||
|
||||
for i, param in enumerate(params):
|
||||
grad = grads[i] if not maximize else -grads[i]
|
||||
exp_avg = exp_avgs[i]
|
||||
exp_avg_sq = exp_avg_sqs[i]
|
||||
step_t = state_steps[i]
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices()
|
||||
assert (
|
||||
param.device.type == step_t.device.type
|
||||
and param.device.type in capturable_supported_devices
|
||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
|
||||
# update step
|
||||
step_t += 1
|
||||
|
||||
if weight_decay != 0:
|
||||
if decoupled:
|
||||
param.add_(param, alpha=-lr * weight_decay)
|
||||
else:
|
||||
grad = grad.add(param, alpha=weight_decay)
|
||||
|
||||
if torch.is_complex(param):
|
||||
grad = torch.view_as_real(grad)
|
||||
if exp_avg is not None:
|
||||
exp_avg = torch.view_as_real(exp_avg)
|
||||
if exp_avg_sq is not None:
|
||||
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||
param = torch.view_as_real(param)
|
||||
|
||||
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||
if step == 1:
|
||||
exp_avg_sq.addcmul_(grad, grad.conj())
|
||||
continue
|
||||
|
||||
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||
if step == 2:
|
||||
exp_avg.addcdiv_(grad, denom)
|
||||
else:
|
||||
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
|
||||
|
||||
param.add_(exp_avg, alpha=-lr)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||
|
||||
|
||||
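Read directly off the single-tensor loop above (omitting weight decay, maximize, and the complex-view handling), the ADOPT update applied to each parameter is, with gradient $g_t$, learning rate $\eta$, and $\varepsilon$ the eps value:

\[
v_1 = g_1^2 \quad\text{(no parameter update at } t = 1\text{)}
\]
\[
m_t =
\begin{cases}
\dfrac{g_t}{\max(\sqrt{v_{t-1}},\, \varepsilon)} & t = 2 \\[1ex]
\beta_1 m_{t-1} + (1-\beta_1)\,\dfrac{g_t}{\max(\sqrt{v_{t-1}},\, \varepsilon)} & t > 2
\end{cases}
\]
\[
\theta_t = \theta_{t-1} - \eta\, m_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 .
\]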
def _multi_tensor_adopt(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
grad_scale: Optional[Tensor],
|
||||
found_inf: Optional[Tensor],
|
||||
*,
|
||||
has_complex: bool,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
weight_decay: float,
|
||||
decoupled: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
capturable: bool,
|
||||
differentiable: bool,
|
||||
):
|
||||
if len(params) == 0:
|
||||
return
|
||||
|
||||
if isinstance(lr, Tensor) and not capturable:
|
||||
raise RuntimeError(
|
||||
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||
)
|
||||
|
||||
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||
if not torch._utils.is_compiling() and capturable:
|
||||
capturable_supported_devices = _get_capturable_supported_devices(
|
||||
supports_xla=False
|
||||
)
|
||||
assert all(
|
||||
p.device.type == step.device.type
|
||||
and p.device.type in capturable_supported_devices
|
||||
for p, step in zip(params, state_steps)
|
||||
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||
|
||||
assert grad_scale is None and found_inf is None
|
||||
|
||||
assert not differentiable, "_foreach ops don't support autograd"
|
||||
|
||||
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
|
||||
[params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
|
||||
)
|
||||
for (
|
||||
device_params_,
|
||||
device_grads_,
|
||||
device_exp_avgs_,
|
||||
device_exp_avg_sqs_,
|
||||
device_state_steps_,
|
||||
), _ in grouped_tensors.values():
|
||||
device_params = cast(List[Tensor], device_params_)
|
||||
device_grads = cast(List[Tensor], device_grads_)
|
||||
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
|
||||
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
|
||||
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||
|
||||
# Handle complex parameters
|
||||
if has_complex:
|
||||
_view_as_real(
|
||||
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
|
||||
)
|
||||
|
||||
if maximize:
|
||||
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||
|
||||
# Update steps
|
||||
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||
torch._foreach_add_(
|
||||
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||
)
|
||||
else:
|
||||
torch._foreach_add_(device_state_steps, 1)
|
||||
|
||||
if weight_decay != 0:
|
||||
if decoupled:
|
||||
torch._foreach_add_(
|
||||
device_params, device_params, alpha=-lr * weight_decay
|
||||
)
|
||||
else:
|
||||
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||
if maximize:
|
||||
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||
else:
|
||||
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||
device_grads, device_params, alpha=weight_decay
|
||||
)
|
||||
|
||||
if device_state_steps[0] == 1:
|
||||
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||
continue
|
||||
|
||||
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
|
||||
|
||||
if device_state_steps[0] == 2:
|
||||
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
|
||||
else:
|
||||
torch._foreach_mul_(device_exp_avgs, beta1)
|
||||
torch._foreach_addcdiv_(
|
||||
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
|
||||
)
|
||||
|
||||
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||
torch._foreach_addcmul_(
|
||||
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||
)
|
||||
|
||||
|
||||
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
||||
def adopt(
|
||||
params: List[Tensor],
|
||||
grads: List[Tensor],
|
||||
exp_avgs: List[Tensor],
|
||||
exp_avg_sqs: List[Tensor],
|
||||
state_steps: List[Tensor],
|
||||
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||
foreach: Optional[bool] = None,
|
||||
capturable: bool = False,
|
||||
differentiable: bool = False,
|
||||
fused: Optional[bool] = None,
|
||||
grad_scale: Optional[Tensor] = None,
|
||||
found_inf: Optional[Tensor] = None,
|
||||
has_complex: bool = False,
|
||||
*,
|
||||
beta1: float,
|
||||
beta2: float,
|
||||
lr: Union[float, Tensor],
|
||||
weight_decay: float,
|
||||
decoupled: bool,
|
||||
eps: float,
|
||||
maximize: bool,
|
||||
):
|
||||
r"""Functional API that performs ADOPT algorithm computation."""
|
||||
# Respect when the user inputs False/True for foreach or fused. We only want to change
|
||||
# the default when neither have been user-specified. Note that we default to foreach
|
||||
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
|
||||
# bake-in time before making it the default, even if it is typically faster.
|
||||
if fused is None and foreach is None:
|
||||
_, foreach = _default_to_fused_or_foreach(
|
||||
params, differentiable, use_fused=False
|
||||
)
|
||||
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
|
||||
if foreach and isinstance(lr, Tensor) and not capturable:
|
||||
foreach = False
|
||||
if fused is None:
|
||||
fused = False
|
||||
if foreach is None:
|
||||
foreach = False
|
||||
|
||||
# this check is slow during compilation, so we skip it
|
||||
# if it's strictly needed we can add this check back in dynamo
|
||||
if not torch._utils.is_compiling() and not all(
|
||||
isinstance(t, torch.Tensor) for t in state_steps
|
||||
):
|
||||
raise RuntimeError(
|
||||
"API has changed, `state_steps` argument must contain a list of singleton tensors"
|
||||
)
|
||||
|
||||
if foreach and torch.jit.is_scripting():
|
||||
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
|
||||
if fused and torch.jit.is_scripting():
|
||||
raise RuntimeError("torch.jit.script not supported with fused optimizers")
|
||||
|
||||
# if fused and not torch.jit.is_scripting():
|
||||
# func = _fused_adopt
|
||||
# elif foreach and not torch.jit.is_scripting():
|
||||
if foreach and not torch.jit.is_scripting():
|
||||
func = _multi_tensor_adopt
|
||||
else:
|
||||
func = _single_tensor_adopt
|
||||
|
||||
func(
|
||||
params,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
state_steps,
|
||||
has_complex=has_complex,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=lr,
|
||||
weight_decay=weight_decay,
|
||||
decoupled=decoupled,
|
||||
eps=eps,
|
||||
maximize=maximize,
|
||||
capturable=capturable,
|
||||
differentiable=differentiable,
|
||||
grad_scale=grad_scale,
|
||||
found_inf=found_inf,
|
||||
)
|
||||
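As a rough usage sketch for the optimizer added in this file (the toy model, data, and hyperparameters below are illustrative, not taken from the diff):

import torch
from torch import nn

from axolotl.utils.optimizers.adopt import ADOPT

# Toy regression setup; only the optimizer construction and step() matter here.
model = nn.Linear(16, 1)
optimizer = ADOPT(model.parameters(), lr=1e-3, betas=(0.9, 0.9999), eps=1e-6)

inputs = torch.randn(32, 16)
targets = torch.randn(32, 1)

for _ in range(10):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()

Note that the first optimizer step only initializes the second-moment estimate, so parameters start moving from the second step onward.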
@@ -1,8 +1,6 @@
|
||||
"""Module for tokenization utilities"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
@@ -93,65 +91,3 @@ def check_rl_example_labels(example, tokenizer, text_only=False):
|
||||
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
|
||||
|
||||
return delimiter.join(colored_tokens)
|
||||
|
||||
|
||||
GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
|
||||
GLAIVE_TO_SHAREGPT_ROLE = {
|
||||
"SYSTEM": "system",
|
||||
"USER": "human",
|
||||
"ASSISTANT": "gpt",
|
||||
"FUNCTION RESPONSE": "tool",
|
||||
}
|
||||
|
||||
GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
|
||||
|
||||
|
||||
def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Converts a ChatML formatted row to a list of messages in ShareGPT format.
|
||||
Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
|
||||
"""
|
||||
|
||||
system_prompt = row.get("system")
|
||||
if system_prompt:
|
||||
system_prompt = system_prompt.removeprefix("SYSTEM: ")
|
||||
|
||||
chat_str = row["chat"]
|
||||
chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
|
||||
|
||||
chat_msg_dicts = [
|
||||
{"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
|
||||
for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
|
||||
]
|
||||
|
||||
if system_prompt:
|
||||
chat_msg_dicts = [
|
||||
{"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
|
||||
] + chat_msg_dicts
|
||||
|
||||
return chat_msg_dicts
|
||||
|
||||
|
||||
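For clarity on what chatml_to_conversation produces, a hypothetical Glaive-style row (the example data is made up) would be converted as follows:

row = {
    "system": "SYSTEM: You are a helpful assistant.",
    "chat": "USER: What is 2 + 2? ASSISTANT: 4.",
}
chatml_to_conversation(row)
# -> [
#     {"from": "system", "value": "You are a helpful assistant."},
#     {"from": "human", "value": "What is 2 + 2?"},
#     {"from": "gpt", "value": "4."},
# ]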
def merge_consecutive_messages(messages):
|
||||
"""
|
||||
Merge consecutive messages from the same sender into a single message.
|
||||
This can be useful with datasets that contain multiple consecutive tool calls.
|
||||
"""
|
||||
|
||||
merged_messages = []
|
||||
current_from = None
|
||||
current_message = ""
|
||||
|
||||
for msg in messages:
|
||||
if current_from == msg["from"]:
|
||||
current_message += msg["value"]
|
||||
else:
|
||||
if current_from is not None:
|
||||
merged_messages.append({"from": current_from, "value": current_message})
|
||||
current_from = msg["from"]
|
||||
current_message = msg["value"]
|
||||
|
||||
if current_from is not None:
|
||||
merged_messages.append({"from": current_from, "value": current_message})
|
||||
|
||||
return merged_messages
|
||||
|
||||
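A small illustration of merge_consecutive_messages (the messages are made up): two consecutive turns from the same sender collapse into one, and other turns are left untouched:

messages = [
    {"from": "gpt", "value": "Calling get_weather... "},
    {"from": "gpt", "value": "It is sunny today."},
    {"from": "human", "value": "Thanks!"},
]
merge_consecutive_messages(messages)
# -> [
#     {"from": "gpt", "value": "Calling get_weather... It is sunny today."},
#     {"from": "human", "value": "Thanks!"},
# ]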
@@ -16,7 +16,11 @@ from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||
from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
|
||||
patch_training_loop_for_fsdp_grad_accum,
|
||||
)
|
||||
from axolotl.utils.distributed import reduce_and_broadcast
|
||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
LOG = get_logger("axolotl")
|
||||
@@ -184,11 +188,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
min_sequence_len=cfg.min_sample_len or 2,
|
||||
)
|
||||
|
||||
if cfg.is_preprocess:
|
||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||
|
||||
if cfg.model_config_type == "mamba":
|
||||
LOG.info("dropping attention_mask column")
|
||||
@@ -461,6 +464,9 @@ def setup_fsdp_envs(cfg):
|
||||
|
||||
|
||||
def prepare_optim_env(cfg):
|
||||
if not check_cuda_p2p_ib_support():
|
||||
if os.getenv("NCCL_P2P_DISABLE") is None:
|
||||
os.environ["NCCL_P2P_DISABLE"] = "1"
|
||||
if cfg.fsdp:
|
||||
setup_fsdp_envs(cfg)
|
||||
elif cfg.deepspeed:
|
||||
@@ -490,6 +496,11 @@ def prepare_opinionated_env(cfg):
|
||||
def setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||
):
|
||||
if cfg.fsdp:
|
||||
try:
|
||||
patch_training_loop_for_fsdp_grad_accum()
|
||||
except AssertionError:
|
||||
pass
|
||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||
trainer_builder.model_ref = model[1]
|
||||
|
||||
16
tests/e2e/conftest.py
Normal file
@@ -0,0 +1,16 @@
"""
shared pytest fixtures
"""
import shutil
import tempfile

import pytest


@pytest.fixture
def temp_dir():
    # Create a temporary directory
    _temp_dir = tempfile.mkdtemp()
    yield _temp_dir
    # Clean up the directory after the test
    shutil.rmtree(_temp_dir)
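A minimal illustration of a test consuming this fixture (the test name and file contents are hypothetical): pytest injects the fixture by parameter name and removes the directory afterwards.

from pathlib import Path


def test_temp_dir_is_writable(temp_dir):
    out_file = Path(temp_dir) / "example.txt"
    out_file.write_text("ok", encoding="utf-8")
    assert out_file.exists()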
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
Simple end-to-end test for Liger integration
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -3,28 +3,25 @@ E2E tests for multigpu eval
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
class TestMultiGPUEval(unittest.TestCase):
|
||||
class TestMultiGPUEval:
|
||||
"""
|
||||
Test case for MultiGPU Eval Sample Packing
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_eval_sample_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -83,13 +80,14 @@ class TestMultiGPUEval(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_eval(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -148,6 +146,8 @@ class TestMultiGPUEval(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
|
||||
@@ -4,17 +4,17 @@ E2E tests for multigpu lora tinyllama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import is_hopper
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -28,18 +28,16 @@ def download_model():
|
||||
snapshot_download("TinyLlama/TinyLlama_v1.1")
|
||||
|
||||
|
||||
class TestMultiGPULlama(unittest.TestCase):
|
||||
class TestMultiGPULlama:
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sequence_len": 2048,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
@@ -48,9 +46,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -59,7 +55,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -81,19 +77,23 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_ddp_packed(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 4],
|
||||
)
|
||||
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
@@ -105,9 +105,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -116,9 +114,9 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 50,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -138,26 +136,170 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_fsdp(self, temp_dir):
|
||||
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
||||
def test_dpo_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": False,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"type": "chat_template.default",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
def test_dpo_qlora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": False,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"type": "chat_template.default",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
"warmup_steps": 0,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 4],
|
||||
)
|
||||
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.01,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "tatsu-lab/alpaca",
|
||||
@@ -165,9 +307,9 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 10,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
@@ -184,7 +326,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"fsdp_use_orig_params": False,
|
||||
"fsdp_cpu_ram_efficient_loading": False,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
}
|
||||
@@ -201,28 +343,29 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_fsdp_packed(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"fsdp_state_dict_type",
|
||||
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
||||
)
|
||||
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -231,7 +374,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -250,7 +393,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"fsdp_use_orig_params": False,
|
||||
"fsdp_cpu_ram_efficient_loading": False,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
"fsdp_state_dict_type": fsdp_state_dict_type,
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
}
|
||||
@@ -267,14 +410,14 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.skip("disabled due to upstream issue")
|
||||
@with_temp_dir
|
||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -282,6 +425,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"adapter": "qlora",
|
||||
"mean_resizing_embeddings": True,
|
||||
"load_in_4bit": True,
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
@@ -297,7 +441,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|end_of_text|>",
|
||||
"pad_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -307,7 +451,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -343,28 +487,29 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ds_zero3_packed(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 4],
|
||||
)
|
||||
def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -373,9 +518,9 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
@@ -396,19 +541,19 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ds_zero3_qlora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
@@ -421,9 +566,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -432,7 +575,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"output_dir": temp_dir,
|
||||
@@ -455,6 +598,8 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
|
||||
@@ -4,31 +4,30 @@ E2E tests for multigpu qwen2
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMultiGPUQwen2(unittest.TestCase):
|
||||
class TestMultiGPUQwen2:
|
||||
"""
|
||||
Test case for Qwen2 models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_fsdp_dpo(self, temp_dir):
|
||||
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
||||
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2-1.5B",
|
||||
"base_model": base_model,
|
||||
"load_in_4bit": True,
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
@@ -47,9 +46,9 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 100,
|
||||
"max_steps": 5,
|
||||
"warmup_steps": 20,
|
||||
"micro_batch_size": 4,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
@@ -91,6 +90,8 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
|
||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import require_torch_2_1_1, with_temp_dir
|
||||
from ..utils import require_torch_2_3_1, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -24,7 +24,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
Test case for Llama models using 4d attention with multipack
|
||||
"""
|
||||
|
||||
@require_torch_2_1_1
|
||||
@require_torch_2_3_1
|
||||
@with_temp_dir
|
||||
def test_sdp_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
15
tests/e2e/patched/test_trainer_fsdp.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||
import unittest
|
||||
|
||||
from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable
|
||||
|
||||
|
||||
class TestTrainerFSDPIntegration(unittest.TestCase):
|
||||
"""Unsloth monkeypatch integration tests."""
|
||||
|
||||
def test_train_loop_patchable(self):
|
||||
# ensures the current version of transformers has loss code that matches our patching code
|
||||
self.assertTrue(
|
||||
check_training_loop_is_patchable(),
|
||||
"HF transformers _inner_training_loop has changed and isn't patchable",
|
||||
)
|
||||
@@ -115,6 +115,51 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_use_weighting(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 64,
|
||||
"lora_alpha": 32,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_linear": True,
|
||||
"special_tokens": {},
|
||||
"rl": "dpo",
|
||||
"dpo_use_weighting": True,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
|
||||
"type": "chatml.ultra",
|
||||
"split": "train",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "paged_adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"warmup_steps": 5,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
|
||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||
@with_temp_dir
|
||||
def test_kto_pair_lora(self, temp_dir):
|
||||
|
||||
66
tests/e2e/test_llama.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
E2E tests for llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_trust_remote_code(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"trust_remote_code": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import require_torch_2_5_1, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -65,3 +65,80 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
def test_adopt_adamw(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adopt_adamw",
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "schedule_free_adamw",
|
||||
"lr_scheduler": "constant",
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
85
tests/e2e/test_qwen.py
Normal file
85
tests/e2e/test_qwen.py
Normal file
@@ -0,0 +1,85 @@
"""
E2E tests for qwen
"""

import logging
import os
from pathlib import Path

import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from transformers.testing_utils import get_torch_dist_unique_port

from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl.tests.qwen")
os.environ["WANDB_DISABLED"] = "true"


class TestE2eQwen:
    """
    Test cases for qwen models
    """

    @pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
    def test_dpo(self, base_model, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": base_model,
                "rl": "dpo",
                "chat_template": "qwen_25",
                "sequence_len": 2048,
                "val_set_size": 0.0,
                "datasets": [
                    {
                        "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
                        "split": "train",
                        "type": "chat_template.default",
                        "field_messages": "conversation",
                        "field_chosen": "chosen",
                        "field_rejected": "rejected",
                        "message_field_role": "role",
                        "message_field_content": "content",
                        "roles": {
                            "system": ["system"],
                            "user": ["user"],
                            "assistant": ["assistant"],
                        },
                    },
                ],
                "num_epochs": 1,
                "max_steps": 5,
                "warmup_steps": 20,
                "micro_batch_size": 2,
                "gradient_accumulation_steps": 2,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "bf16": "auto",
                "tf32": True,
                "gradient_checkpointing": True,
            }
        )

        # write cfg to yaml file
        Path(temp_dir).mkdir(parents=True, exist_ok=True)
        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

        execute_subprocess_async(
            [
                "accelerate",
                "launch",
                "--num-processes",
                "2",
                "--main_process_port",
                f"{get_torch_dist_unique_port()}",
                "-m",
                "axolotl.cli.train",
                str(Path(temp_dir) / "config.yaml"),
            ]
        )
@@ -6,9 +6,13 @@ import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import version
from pathlib import Path

import torch

# from importlib.metadata import version
from packaging import version


def with_temp_dir(test_func):
    @wraps(test_func)
@@ -35,13 +39,30 @@ def most_recent_subdir(path):
    return subdir


def require_torch_2_1_1(test_case):
def require_torch_2_3_1(test_case):
    """
    Decorator marking a test that requires torch >= 2.1.1
    Decorator marking a test that requires torch >= 2.3.1
    """

    def is_min_2_1_1():
        torch_version = version("torch")
        return torch_version >= "2.1.1"
    def is_min_2_3_1():
        torch_version = version.parse(torch.__version__)
        return torch_version >= version.parse("2.3.1")

    return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
    return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)


def require_torch_2_5_1(test_case):
    """
    Decorator marking a test that requires torch >= 2.5.1
    """

    def is_min_2_5_1():
        torch_version = version.parse(torch.__version__)
        return torch_version >= version.parse("2.5.1")

    return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)


def is_hopper():
    compute_capability = torch.cuda.get_device_capability()
    return compute_capability == (9, 0)

0
tests/integrations/__init__.py
Normal file
80
tests/integrations/liger.py
Normal file
@@ -0,0 +1,80 @@
"""
config validation tests for swiglu args
"""
# pylint: disable=duplicate-code
import logging
from typing import Optional

import pytest

from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault


@pytest.fixture(name="minimal_base_cfg")
def fixture_cfg():
    return DictDefault(
        {
            "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
            "learning_rate": 0.000001,
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                }
            ],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
        }
    )


class BaseValidation:
    """
    Base validation module to setup the log capture
    """

    _caplog: Optional[pytest.LogCaptureFixture] = None

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, caplog):
        self._caplog = caplog


# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
    """
    Test the validation module for liger
    """

    def test_deprecated_swiglu(self, minimal_cfg):
        test_cfg = DictDefault(
            {
                "liger_swiglu": False,
            }
            | minimal_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            updated_cfg = validate_config(test_cfg)
            assert (
                "The 'liger_swiglu' argument is deprecated"
                in self._caplog.records[0].message
            )
            assert updated_cfg.liger_swiglu is None
            assert updated_cfg.liger_glu_activations is False

    def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
        test_cfg = DictDefault(
            {
                "liger_swiglu": False,
                "liger_glu_activations": True,
            }
            | minimal_cfg
        )

        with pytest.raises(
            ValueError,
            match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
        ):
            validate_config(test_cfg)
@@ -1,500 +0,0 @@
"""
Test module for sharegpt integration w chatml
"""

import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer

from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.sharegpt import (
    GlaiveShareGPTPromptTokenizingStrategy,
    SimpleShareGPTPromptTokenizingStrategy,
    register_chatml_template,
    register_llama3_template,
)
from axolotl.prompters import ShareGPTPrompterV2

register_chatml_template()
register_llama3_template()


@pytest.fixture(name="sharegpt_dataset")
def fixture_sharegpt_dataset():
    return Dataset.from_list(
        [
            {
                "conversations": [
                    {
                        "from": "system",
                        "value": "repeat",
                    },
                    {
                        "from": "human",
                        "value": "hello",
                    },
                    {
                        "from": "gpt",
                        "value": "hello",
                    },
                    {
                        "from": "human",
                        "value": "goodbye",
                    },
                    {
                        "from": "gpt",
                        "value": "goodbye",
                    },
                ]
            }
        ]
    )


@pytest.fixture(name="sharegpt_dataset_with_weights")
def fixture_sharegpt_dataset_with_weights():
    return Dataset.from_list(
        [
            {
                "conversations": [
                    {
                        "from": "system",
                        "value": "repeat",
                    },
                    {
                        "from": "human",
                        "value": "hello",
                        "weight": 1,
                    },
                    {
                        "from": "gpt",
                        "value": "hello",
                        "weight": 0,
                    },
                    {
                        "from": "human",
                        "value": "rehello",
                        "weight": 0,
                    },
                    {
                        "from": "gpt",
                        "value": "rehello",
                        "weight": 1,
                    },
                    {
                        "from": "human",
                        "value": "goodbye",
                    },
                    {
                        "from": "gpt",
                        "value": "goodbye",
                        "weight": 0,
                    },
                ]
            }
        ]
    )


@pytest.fixture(name="glaive_dataset")
def fixture_sharegpt_glaive_dataset():
    return Dataset.from_list(
        [
            {
                "system": "SYSTEM: This is a system prompt",
                "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
            }
        ]
    )


@pytest.fixture(name="multi_role_dataset")
def fixture_multi_role_dataset():
    return Dataset.from_list(
        [
            {
                "conversations": [
                    {
                        "from": "system",
                        "value": "use get_weather(city) to get the weather for a city",
                    },
                    {
                        "from": "human",
                        "value": "hello, what's the weather in New York?",
                    },
                    {
                        "from": "gpt",
                        "value": "let me get that for you",
                    },
                    {
                        "from": "tool",
                        "value": "get_weather(New York)",
                    },
                    {
                        "from": "gpt",
                        "value": "the weather in New York is 70 degrees and sunny",
                    },
                ]
            }
        ]
    )


@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(
        "casperhansen/mistral-7b-instruct-v0.1-awq"
    )
    tokenizer.add_special_tokens(
        {
            "eos_token": AddedToken(
                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
            )
        }
    )
    tokenizer.add_tokens(
        [
            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
        ]
    )

    return tokenizer


@pytest.fixture(name="llama3_tokenizer")
def fixture_llama3_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
    tokenizer.eos_token = "<|eot_id|>"

    return tokenizer


class TestSharegptLlama3:
    """Test class for ShareGPT style datasets with llama-3 prompts"""

    def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="llama3",
                role_key_model=None,
                role_key_human=None,
            ),
            llama3_tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, sharegpt_dataset, process_count=1
        )

        input_ids = dataset_wrapper[0]["input_ids"]

        # fmt: off
        # pylint: disable=duplicate-code
        assert input_ids == [
            128000,  # bos
            128006, 9125, 128007,  # system header
            271, 31724, 128009,  # sys prompt, eot
            128006, 882, 128007,  # user header
            271, 15339, 128009,  # user prompt eot
            128006, 78191, 128007,  # assistant header
            271, 15339, 128009,  # assistant response eot
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        # fmt: on

    def test_tokenization_with_weights(
        self, sharegpt_dataset_with_weights, llama3_tokenizer
    ):
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="llama3",
                role_key_model=None,
                role_key_human=None,
            ),
            llama3_tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, sharegpt_dataset_with_weights, process_count=1
        )

        input_ids = dataset_wrapper[0]["input_ids"]

        # fmt: off
        # pylint: disable=duplicate-code
        assert input_ids == [
            128000,  # bos
            128006, 9125, 128007,  # system header
            271, 31724, 128009,  # sys prompt, eot
            128006, 882, 128007,  # user header
            271, 15339, 128009,  # user prompt eot
            128006, 78191, 128007,  # assistant header
            271, 15339, 128009,  # assistant response eot
            128006, 882, 128007,
            271, 11310, 4896, 128009,
            128006, 78191, 128007,
            271, 11310, 4896, 128009,
            128006, 882, 128007,
            271, 19045, 29474, 128009,
            128006, 78191, 128007,
            271, 19045, 29474, 128009,
        ]
        # fmt: on


class TestSharegptChatML:
    """
    Test class for sharegpt prompter
    """

    def test_no_double_im_end(self, sharegpt_dataset, tokenizer):
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, sharegpt_dataset, process_count=1
        )

        input_ids = dataset_wrapper[0]["input_ids"]
        # fmt: off
        assert input_ids == [
            # 28705, 13, is " \n"
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_no_double_im_end_with_weights(
        self, sharegpt_dataset_with_weights, tokenizer
    ):
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, sharegpt_dataset_with_weights, process_count=1
        )

        input_ids = dataset_wrapper[0]["input_ids"]
        # fmt: off
        assert input_ids == [
            # 28705, 13, is " \n"
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 267, 21558, 32000, 28705, 13,  # human
            32001, 13892, 13, 267, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, sharegpt_dataset, process_count=1
        )

        labels = dataset_wrapper[0]["labels"]
        # fmt: off
        assert labels == [
            -100,  # bos
            -100, -100, -100, -100, -100, -100, -100,  # system
            -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 21558, 32000, 28705, 13,  # gpt
            -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_no_train_on_input_with_weights(
        self, sharegpt_dataset_with_weights, tokenizer
    ):
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, sharegpt_dataset_with_weights, process_count=1
        )

        labels = dataset_wrapper[0]["labels"]
        # fmt: off
        assert labels == [
            -100,  # bos
            -100, -100, -100, -100, -100, -100, -100,  # system
            -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, -100, -100, -100, -100, -100,  # gpt with weight zero
            -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 267, 21558, 32000, 28705, 13,  # gpt
            -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, -100, -100, -100, -100, -100, -100  # gpt with weight zero
        ]
        # fmt: on

    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            True,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, sharegpt_dataset, process_count=1
        )

        labels = dataset_wrapper[0]["labels"]
        # fmt: off
        assert labels == [
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
        ]
        # fmt: on

    def test_w_train_on_input_with_weights(
        self, sharegpt_dataset_with_weights, tokenizer
    ):
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            True,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, sharegpt_dataset_with_weights, process_count=1
        )

        labels = dataset_wrapper[0]["labels"]
        # fmt: off
        assert labels == [
            1,  # bos
            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
            -100, -100, -100, -100, -100, -100, -100,  # gpt with weight 0
            -100, -100, -100, -100, -100, -100, -100, -100,  # human with weight 0
            32001, 13892, 13, 267, 21558, 32000, 28705, 13,  # gpt
            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
            -100, -100, -100, -100, -100, -100, -100, -100  # gpt with weight 0
        ]
        # fmt: on

    def test_chatml_glaive(self, glaive_dataset, tokenizer):
        strategy = GlaiveShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(
                conversation="chatml",
                role_key_model=None,
                role_key_human=None,
            ),
            tokenizer,
            True,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, glaive_dataset, process_count=1
        )

        labels = dataset_wrapper[0]["labels"]
        # fmt: off
        assert labels == [
            1,  # bos
            32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13,  # system
            32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13,  # human
            32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13  # gpt
        ]
        # fmt: on

    def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
        strategy = SimpleShareGPTPromptTokenizingStrategy(
            ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
            tokenizer,
            False,  # train_on_inputs
            2048,  # sequence_len
        )

        dataset_wrapper = TokenizedPromptDataset(
            strategy, multi_role_dataset, process_count=1
        )

        input_ids = dataset_wrapper[0]["input_ids"]
        # fmt: off
        assert input_ids == [
            1,  # bos
            32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13,  # system
            32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13,  # human
            32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
            32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13,  # tool
            32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
        ]
        # fmt: on

        labels = dataset_wrapper[0]["labels"]
        # fmt: off
        assert labels == [
            -100,  # bos
            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # system
            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # human
            -100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # tool
            -100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
        ]
        # fmt: on
@@ -306,6 +306,10 @@ class TestDatasetPreparation(unittest.TestCase):
        """Verify that processing data from the hub works with a specific revision"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            prepared_path = Path(tmp_dir) / "prepared"

            # make sure prepared_path is empty
            shutil.rmtree(prepared_path, ignore_errors=True)

            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
@@ -367,7 +371,7 @@ class TestDatasetPreparation(unittest.TestCase):
    def test_load_local_hub_with_revision(self):
        """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
@@ -387,7 +391,7 @@ class TestDatasetPreparation(unittest.TestCase):
                            "ds_type": "parquet",
                            "type": "alpaca",
                            "data_files": [
                                "mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
                                f"{tmp_ds_path}/alpaca_2000.parquet",
                            ],
                            "revision": "d05c1cb",
                        },
@@ -405,6 +409,42 @@ class TestDatasetPreparation(unittest.TestCase):
            assert "labels" in dataset.features
            shutil.rmtree(tmp_ds_path)

    def test_loading_local_dataset_folder(self):
        """Verify that a dataset downloaded to a local folder can be loaded"""

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
            tmp_ds_path.mkdir(parents=True, exist_ok=True)
            snapshot_download(
                repo_id="mhenrichsen/alpaca_2k_test",
                repo_type="dataset",
                local_dir=tmp_ds_path,
            )

            prepared_path = Path(tmp_dir) / "prepared"
            cfg = DictDefault(
                {
                    "tokenizer_config": "huggyllama/llama-7b",
                    "sequence_len": 1024,
                    "datasets": [
                        {
                            "path": str(tmp_ds_path),
                            "type": "alpaca",
                        },
                    ],
                }
            )

            dataset, _ = load_tokenized_prepared_datasets(
                self.tokenizer, cfg, prepared_path
            )

            assert len(dataset) == 2000
            assert "input_ids" in dataset.features
            assert "attention_mask" in dataset.features
            assert "labels" in dataset.features
            shutil.rmtree(tmp_ds_path)


if __name__ == "__main__":
    unittest.main()

@@ -39,12 +39,12 @@ class NormalizeConfigTestCase(unittest.TestCase):
                "datasets": [
                    {
                        "path": "lorem/ipsum",
                        "type": "sharegpt",
                        "conversation": "vicuna_v1.1",
                        "type": "chat_template",
                        "chat_template": "gemma",
                    },
                    {
                        "path": "sit/amet",
                        "type": "sharegpt",
                        "type": "chat_template",
                    },
                ],
            }
@@ -52,8 +52,8 @@ class NormalizeConfigTestCase(unittest.TestCase):

        normalize_cfg_datasets(cfg)

        assert cfg.datasets[0].conversation == "vicuna_v1.1"
        assert cfg.datasets[1].conversation == "chatml"
        assert cfg.datasets[0].chat_template == "gemma"
        assert cfg.datasets[1].chat_template == "chatml"

    @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
    def test_bf16_auto_setter_available(self, mock_bf16_avail):

@@ -2,6 +2,7 @@
import functools
import unittest

import pytest
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
@@ -21,6 +22,7 @@ class TestPretrainingPacking(unittest.TestCase):
        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
        self.tokenizer.pad_token = "</s>"

    @pytest.mark.flaky(retries=3, delay=5)
    def test_packing_stream_dataset(self):
        # pylint: disable=duplicate-code
        dataset = load_dataset(

@@ -3,7 +3,6 @@
import json
import logging
import unittest
from copy import deepcopy
from pathlib import Path
from typing import Optional

@@ -21,12 +20,8 @@ from axolotl.prompt_strategies.llama2_chat import (
    LLama2ChatTokenizingStrategy,
)
from axolotl.prompt_strategies.orpo.chat_template import load
from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
from axolotl.prompt_tokenizers import (
    AlpacaPromptTokenizingStrategy,
    ShareGPTPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger("axolotl")
@@ -65,17 +60,6 @@ test_data = {
}


def prompt_strat(conversation, tokenizer):
    "Helper function to create a prompt strategy for testing."
    prompter = ShareGPTPrompterV2(conversation=conversation)
    return ShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        False,
        2048,
    )


class TestPromptTokenizationStrategies(unittest.TestCase):
    """
    Test class for prompt tokenization strategies.
@@ -98,196 +82,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
            }
        )

    def test_sharegpt_integration(self):
        with open(
            Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
        ) as fin:
            data = fin.read()
            conversation = json.loads(data)
        with open(
            Path(__file__).parent / "fixtures/conversation.tokenized.json",
            encoding="utf-8",
        ) as fin:
            data = fin.read()
            tokenized_conversation = json.loads(data)
        prompter = ShareGPTPrompterV2()
        strat = ShareGPTPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,
            2048,
        )
        example = strat.tokenize_prompt(conversation)
        for fields in ["input_ids", "attention_mask", "labels"]:
            self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
            self.assertEqual(example[fields], tokenized_conversation[fields])

    def test_sharegpt_warnings_integration(self):
        with open(
            Path(__file__).parent / "fixtures/conversation.missingturns.json",
            encoding="utf-8",
        ) as fin:
            data = fin.read()
            conversation = json.loads(data)
        prompter = ShareGPTPrompterV2()
        strat = ShareGPTPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,
            2048,
        )
        with self._caplog.at_level(logging.WARNING):
            strat.tokenize_prompt(conversation)
            assert "assistant turn has empty text" in self._caplog.records[1].message

    def test_sharegpt_warnings_turns(self):
        conversation = {
            "conversations": [
                {"from": "system", "value": "lorem"},
                {"from": "gpt", "value": "ipsum"},
                {"from": "human", "value": "dolor"},
                {"from": "human", "value": "dolor"},
                {"from": "gpt", "value": "sit"},
            ]
        }
        prompter = ShareGPTPrompterV2()
        strat = ShareGPTPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,
            2048,
        )
        with self._caplog.at_level(logging.WARNING):
            strat.tokenize_prompt(conversation)
            assert (
                "Role did not alternate between turns (gpt and human)"
                in self._caplog.records[0].message
            )

    def test_sharegpt_llama(self):
        "Make sure the sharegpt/llama is tokenized and formatted correctly."
        strat = prompt_strat("llama-2", self.tokenizer)

        def tokenize(conv):
            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]

        def decode(ids):
            return strat.tokenizer.decode(ids)

        # fmt: off
        # System message, multi-turn conversations
        mt_ids = tokenize(test_data['multi_turn_sys'])
        assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
        assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]

        # System message, single-turn conversations
        st_ids = tokenize(test_data['single_turn_sys'])
        assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
        assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]

        # No system message, single-turn
        ns_ids = tokenize(test_data['single_turn_no_sys'])
        assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
        assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]

        # No system message, multi-turn
        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
        assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
        assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
        # fmt: on

    def test_sharegpt_mistral(self):
        "Make sure the sharegpt/mistral is tokenized and formatted correctly."
        strat = prompt_strat("mistral", self.tokenizer)

        def tokenize(conv):
            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]

        def decode(ids):
            return strat.tokenizer.decode(ids)

        # fmt: off
        # System message, multi-turn conversations
        mt_ids = tokenize(test_data['multi_turn_sys'])
        assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
        assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]

        # System message, single-turn conversations
        st_ids = tokenize(test_data['single_turn_sys'])
        assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
        assert st_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]

        # No system message, single-turn
        ns_ids = tokenize(test_data['single_turn_no_sys'])
        assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
        assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]

        # No system message, multi-turn
        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
        assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
        assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
        # fmt: on

    def test_sharegpt_changes_roles(self):
        conversation = {
            "roles": ["USER", "CHARACTER"],
            "conversations": [
                {"from": "system", "value": "lorem"},
                {"from": "gpt", "value": "ipsum"},
                {"from": "human", "value": "dolor"},
                {"from": "gpt", "value": "sit"},
            ],
        }
        prompter = ShareGPTPrompterV2()
        strat = ShareGPTPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,
            2048,
        )
        with self._caplog.at_level(logging.WARNING):
            res = strat.tokenize_prompt(conversation)
            assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])

    def test_sharegpt_assistant_label_ignore(self):
        conversation = {
            "roles": ["user", "assistant"],
            "conversations": [
                {"from": "system", "value": "lorem"},
                {"from": "gpt", "value": "ipsum"},
                {"from": "human", "value": "dolor"},
                {"from": "gpt", "value": "sit"},
            ],
        }
        prompter = ShareGPTPrompterV2()
        strat = ShareGPTPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,
            2048,
        )
        with self._caplog.at_level(logging.WARNING):
            res = strat.tokenize_prompt(conversation)
            idx = res["input_ids"].index(20255)  # assistant token
            assert res["labels"][idx] == -100

    def test_glaive_tool_label_ignore(self):
        conversation = {
            "system": "SYSTEM: This is a system prompt",
            "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
        }
        prompter = ShareGPTPrompterV2()
        strat = GlaiveShareGPTPromptTokenizingStrategy(
            prompter,
            self.tokenizer,
            False,
            2048,
        )
        with self._caplog.at_level(logging.WARNING):
            res = strat.tokenize_prompt(conversation)
            idx = res["input_ids"].index(13566)  # assistant token
            assert res["labels"][idx] == -100

    def test_no_sys_prompt(self):
        """
        tests the interface between the user and assistant parts

@@ -646,39 +646,6 @@ class TestValidation(BaseValidation):

        validate_config(cfg)

    def test_sharegpt_deprecation(self, minimal_cfg):
        cfg = (
            DictDefault(
                {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
            )
            | minimal_cfg
        )
        with self._caplog.at_level(logging.WARNING):
            new_cfg = validate_config(cfg)
            assert any(
                "`type: sharegpt:chat` will soon be deprecated." in record.message
                for record in self._caplog.records
            )
            assert new_cfg.datasets[0].type == "sharegpt"

        cfg = (
            DictDefault(
                {
                    "datasets": [
                        {"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}
                    ]
                }
            )
            | minimal_cfg
        )
        with self._caplog.at_level(logging.WARNING):
            new_cfg = validate_config(cfg)
            assert any(
                "`type: sharegpt_simple` will soon be deprecated." in record.message
                for record in self._caplog.records
            )
            assert new_cfg.datasets[0].type == "sharegpt:load_role"

    def test_no_conflict_save_strategy(self, minimal_cfg):
        cfg = (
            DictDefault(
@@ -759,7 +726,7 @@ class TestValidation(BaseValidation):
        cfg = (
            DictDefault(
                {
                    "evaluation_strategy": "epoch",
                    "eval_strategy": "epoch",
                    "eval_steps": 10,
                }
            )
@@ -767,14 +734,14 @@ class TestValidation(BaseValidation):
        )

        with pytest.raises(
            ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
            ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
        ):
            validate_config(cfg)

        cfg = (
            DictDefault(
                {
                    "evaluation_strategy": "no",
                    "eval_strategy": "no",
                    "eval_steps": 10,
                }
            )
@@ -782,14 +749,14 @@ class TestValidation(BaseValidation):
        )

        with pytest.raises(
            ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
            ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
        ):
            validate_config(cfg)

        cfg = (
            DictDefault(
                {
                    "evaluation_strategy": "steps",
                    "eval_strategy": "steps",
                }
            )
            | minimal_cfg
@@ -800,7 +767,7 @@ class TestValidation(BaseValidation):
        cfg = (
            DictDefault(
                {
                    "evaluation_strategy": "steps",
                    "eval_strategy": "steps",
                    "eval_steps": 10,
                }
            )
@@ -823,7 +790,7 @@ class TestValidation(BaseValidation):
        cfg = (
            DictDefault(
                {
                    "evaluation_strategy": "no",
                    "eval_strategy": "no",
                }
            )
            | minimal_cfg
@@ -834,7 +801,7 @@ class TestValidation(BaseValidation):
        cfg = (
            DictDefault(
                {
                    "evaluation_strategy": "epoch",
                    "eval_strategy": "epoch",
                    "val_set_size": 0,
                }
            )
@@ -843,7 +810,7 @@ class TestValidation(BaseValidation):

        with pytest.raises(
            ValueError,
            match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
            match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
        ):
            validate_config(cfg)

@@ -859,7 +826,7 @@ class TestValidation(BaseValidation):

        with pytest.raises(
            ValueError,
            match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
            match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
        ):
            validate_config(cfg)

@@ -889,7 +856,7 @@ class TestValidation(BaseValidation):
        cfg = (
            DictDefault(
                {
                    "evaluation_strategy": "epoch",
                    "eval_strategy": "epoch",
                    "val_set_size": 0.01,
                }
            )
@@ -1128,6 +1095,24 @@ class TestValidation(BaseValidation):
            assert new_cfg["dpo_beta"] is None
            assert len(self._caplog.records) == 1

    def test_eval_strategy_remap(self, minimal_cfg):
        cfg = (
            DictDefault(
                {
                    "evaluation_strategy": "steps",
                }
            )
            | minimal_cfg
        )

        with self._caplog.at_level(logging.WARNING):
            new_cfg = validate_config(cfg)
            assert new_cfg.eval_strategy == "steps"
            assert (
                "evaluation_strategy is deprecated, use eval_strategy instead"
                in self._caplog.records[0].message
            )


class TestValidationCheckModelConfig(BaseValidation):
    """

@@ -48,9 +48,8 @@ class TestValidationCheckDatasetConfig(BaseValidation):
            | {
                "datasets": [
                    {
                        "path": "LDJnr/Puffin",
                        "type": "sharegpt",
                        "conversation": "chatml",
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                        "shards": 10,
                    }
                ]
@@ -62,7 +61,6 @@ class TestValidationCheckDatasetConfig(BaseValidation):
        def _check_config():
            assert checked_cfg.datasets[0].path == cfg.datasets[0].path
            assert checked_cfg.datasets[0].type == cfg.datasets[0].type
            assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
            assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards

        _check_config()
@@ -236,3 +234,59 @@ class TestValidationCheckDatasetConfig(BaseValidation):
        )

        _check_config()

    def test_dataset_sharegpt_deprecation(self, minimal_cfg):
        cfg = DictDefault(
            minimal_cfg
            | {
                "chat_template": "chatml",
                "datasets": [
                    {
                        "path": "LDJnr/Puffin",
                        "type": "sharegpt",
                        "conversation": "chatml",
                    }
                ],
            }
        )

        # Check sharegpt deprecation is raised
        with pytest.raises(ValueError, match=r".*type: sharegpt.*` is deprecated.*"):
            validate_config(cfg)

        # Check that deprecation is not thrown for non-str type
        cfg = DictDefault(
            minimal_cfg
            | {
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": {
                            "field_instruction": "instruction",
                            "field_output": "output",
                            "field_system": "system",
                            "format": "<|user|> {instruction} {input} <|model|>",
                            "no_input_format": "<|user|> {instruction} <|model|>",
                            "system_prompt": "",
                        },
                    }
                ],
            }
        )

        validate_config(cfg)

        # Check that deprecation is not thrown for non-sharegpt type
        cfg = DictDefault(
            minimal_cfg
            | {
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    }
                ],
            }
        )

        validate_config(cfg)