Compare commits: shampoo-lo...phi-moe

43 commits (author and date columns did not survive capture; SHAs listed in page order):

cb8bfab9cc, 0c8b1d824a, fd70eec577, d42f202046, 0dabde1962, 15f1462ccd,
521e62daf1, c16ec398d7, 2f20cb7ebf, 71d4030b79, f3a5d119af, ba219b51a5,
5be8e13d35, 2d7830fda6, 5e98cdddac, 1d7aee0ad2, 659ee5d723, 342935cff3,
c5eb9ea2c2, f2145a3ccb, 010d0e7ff3, 01881c3113, 0e8eb96e07, 4e1891b12b,
28924fc791, 8c480b2804, a4b1cc6df0, 7b78a31593, 810ebc2c0e, ad435a3b09,
9f1cf9b17c, 3931a42763, dc8f9059f7, 234e94e9dd, f68fb71005, 9bc3ee6c75,
d356740ffa, e4af51eb66, e20b15bee3, d4796cb645, fd3b80716a, 3265b7095e,
3cb2d75de1
.github/workflows/base.yml (12 changes)

```diff
@@ -27,7 +27,7 @@ jobs:
           - cuda: "124"
             cuda_version: 12.4.1
             cudnn_version: ""
-            python_version: "3.11"
+            python_version: "3.10"
             pytorch: 2.4.1
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           - cuda: "124"
@@ -44,19 +44,21 @@ jobs:
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
      - name: Docker metadata
        id: metadata
-        uses: docker/metadata-action@v3
+        uses: docker/metadata-action@v5
        with:
-          images: winglian/axolotl-base
+          images: |
+            winglian/axolotl-base
+            axolotlai/axolotl-base
      - name: Login to Docker Hub
        uses: docker/login-action@v2
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
      - name: Build
        uses: docker/build-push-action@v4
        with:
```
.github/workflows/docs.yml (2 changes)

```diff
@@ -17,7 +17,7 @@ jobs:
      - name: Set up Quarto
        uses: quarto-dev/quarto-actions/setup@v2
      - name: Setup Python
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: install dependencies
```
.github/workflows/lint.yml (6 changes)

```diff
@@ -15,9 +15,9 @@ jobs:
    name: pre-commit
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
          cache: 'pip' # caching pip dependencies
-      - uses: pre-commit/action@v3.0.0
+      - uses: pre-commit/action@v3.0.1
```
.github/workflows/main.yml (37 changes)

```diff
@@ -4,6 +4,8 @@ on:
   push:
     branches:
       - "main"
+    tags:
+      - "v*"
   workflow_dispatch:

 jobs:
@@ -32,7 +34,7 @@ jobs:
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
            axolotl_extras:
    runs-on: axolotl-gpu-runner
    steps:
@@ -42,7 +44,12 @@ jobs:
        id: metadata
        uses: docker/metadata-action@v5
        with:
-          images: winglian/axolotl
+          images: |
+            winglian/axolotl
+            axolotlai/axolotl
+          tags: |
+            type=ref,event=branch
+            type=semver,pattern={{version}}
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login to Docker Hub
@@ -56,7 +63,7 @@ jobs:
        with:
          context: .
          build-args: |
-            BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
+            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
            CUDA=${{ matrix.cuda }}
            PYTORCH_VERSION=${{ matrix.pytorch }}
            AXOLOTL_ARGS=${{ matrix.axolotl_args }}
@@ -94,7 +101,7 @@ jobs:
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
            axolotl_extras:
    runs-on: axolotl-gpu-runner
    steps:
@@ -104,20 +111,25 @@ jobs:
        id: metadata
        uses: docker/metadata-action@v5
        with:
-          images: winglian/axolotl-cloud
+          images: |
+            winglian/axolotl-cloud
+            axolotlai/axolotl-cloud
+          tags: |
+            type=ref,event=branch
+            type=semver,pattern={{version}}
      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
      - name: Build
        uses: docker/build-push-action@v5
        with:
          context: .
          build-args: |
-            BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
            CUDA=${{ matrix.cuda }}
          file: ./docker/Dockerfile-cloud
          push: ${{ github.event_name != 'pull_request' }}
@@ -146,20 +158,25 @@ jobs:
        id: metadata
        uses: docker/metadata-action@v5
        with:
-          images: winglian/axolotl-cloud-term
+          images: |
+            winglian/axolotl-cloud-term
+            axolotlai/axolotl-cloud-term
+          tags: |
+            type=ref,event=branch
+            type=semver,pattern={{version}}
      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
      - name: Build
        uses: docker/build-push-action@v5
        with:
          context: .
          build-args: |
-            BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
            CUDA=${{ matrix.cuda }}
          file: ./docker/Dockerfile-cloud-no-tmux
          push: ${{ github.event_name != 'pull_request' }}
```
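The `BASE_TAG` edits above lean on GitHub Actions' `&&`/`||` short-circuit idiom, since the expression syntax has no ternary operator. A minimal Python sketch of the equivalent logic (function and variable names are illustrative, not part of the workflow):

```python
def resolve_base_tag(ref_type: str, ref_name: str) -> str:
    """Equivalent of ${{ github.ref_type == 'tag' && 'main' || github.ref_name }}.

    Tag pushes (enabled by the new `tags: - "v*"` trigger) have no base image
    built under the tag's own name, so they fall back to the `main` base image.
    """
    return "main" if ref_type == "tag" else ref_name


assert resolve_base_tag("tag", "v0.5.0") == "main"
assert resolve_base_tag("branch", "some-branch") == "some-branch"
```

Note the `cond && a || b` idiom only behaves like a ternary when `a` is truthy, which the literal `'main'` always is.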
.github/workflows/multi-gpu-e2e.yml (7 changes)

```diff
@@ -8,6 +8,11 @@ on:
   schedule:
     - cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday

+# Cancel jobs on the same ref if a new one is triggered
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
 jobs:
   test-axolotl-multigpu:
     if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
@@ -31,7 +36,7 @@ jobs:
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
            axolotl_extras:
            num_gpus: 2
            nightly_build: "true"
```
.github/workflows/nightlies.yml (14 changes)

```diff
@@ -31,7 +31,7 @@ jobs:
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
            axolotl_extras:
    runs-on: axolotl-gpu-runner
    steps:
@@ -41,7 +41,9 @@ jobs:
        id: metadata
        uses: docker/metadata-action@v5
        with:
-          images: winglian/axolotl
+          images: |
+            winglian/axolotl
+            axolotlai/axolotl
          tags: |
            type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
      - name: Set up Docker Buildx
@@ -93,7 +95,7 @@ jobs:
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
            axolotl_extras:
    runs-on: axolotl-gpu-runner
    steps:
@@ -103,7 +105,9 @@ jobs:
        id: metadata
        uses: docker/metadata-action@v5
        with:
-          images: winglian/axolotl-cloud
+          images: |
+            winglian/axolotl-cloud
+            axolotlai/axolotl-cloud
          tags: |
            type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
      - name: Login to Docker Hub
@@ -112,7 +116,7 @@ jobs:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v2
+        uses: docker/setup-buildx-action@v3
      - name: Build
        uses: docker/build-push-action@v5
        with:
```
.github/workflows/pypi.yml (25 changes)

```diff
@@ -3,12 +3,31 @@ name: publish pypi
 on:
   push:
     tags:
-      - '*'
+      - 'v*'
+  workflow_dispatch:

 jobs:
+  setup_release:
+    name: Create Release
+    runs-on: ubuntu-latest
+    steps:
+      - name: Get the tag version
+        id: extract_branch
+        run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
+        shell: bash
+
+      - name: Create Release
+        id: create_release
+        uses: actions/create-release@v1
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        with:
+          tag_name: ${{ steps.extract_branch.outputs.branch }}
+          release_name: ${{ steps.extract_branch.outputs.branch }}
   pypi-publish:
     name: Upload release to PyPI
     runs-on: ubuntu-latest
+    needs: [setup_release]
     environment:
       name: pypi
       url: https://pypi.org/p/axolotl
@@ -16,10 +35,10 @@ jobs:
       id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
     steps:
       - name: Check out repository code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

      - name: Setup Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
```
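The new `Get the tag version` step derives the release name with the shell parameter expansion `${GITHUB_REF#refs/tags/}`. For readers unfamiliar with the `#` prefix-strip syntax, a sketch of the same operation in Python:

```python
def extract_tag(github_ref: str) -> str:
    # ${GITHUB_REF#refs/tags/} strips the shortest leading match of
    # "refs/tags/"; on a v-tag push GITHUB_REF looks like "refs/tags/v0.5.0".
    return github_ref.removeprefix("refs/tags/")


assert extract_tag("refs/tags/v0.5.0") == "v0.5.0"
```

Worth noting: `echo ::set-output` is a deprecated Actions workflow command (superseded by writing to `$GITHUB_OUTPUT`), though it still functioned at the time of this change.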
.github/workflows/tests-nightly.yml (15 changes)

```diff
@@ -9,12 +9,12 @@ jobs:
    name: pre-commit
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
          cache: 'pip' # caching pip dependencies
-      - uses: pre-commit/action@v3.0.0
+      - uses: pre-commit/action@v3.0.1
        env:
          SKIP: no-commit-to-branch
@@ -25,15 +25,15 @@ jobs:
      fail-fast: false
      matrix:
        python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
+        pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
    timeout-minutes: 20

    steps:
      - name: Check out repository code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

      - name: Setup Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python_version }}
          cache: 'pip' # caching pip dependencies
@@ -48,6 +48,7 @@ jobs:
          sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
          sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
          sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
+          sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt

      - name: Install dependencies
        run: |
@@ -92,7 +93,7 @@ jobs:
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
            num_gpus: 1
            axolotl_extras:
            nightly_build: "true"
```
.github/workflows/tests.yml (19 changes)

```diff
@@ -15,17 +15,22 @@ on:
      - '.github/workflows/*.yml'
  workflow_dispatch:

+# Cancel jobs on the same ref if a new one is triggered
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
 jobs:
  pre-commit:
    name: pre-commit
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
          cache: 'pip' # caching pip dependencies
-      - uses: pre-commit/action@v3.0.0
+      - uses: pre-commit/action@v3.0.1
        env:
          SKIP: no-commit-to-branch
@@ -36,15 +41,15 @@ jobs:
      fail-fast: false
      matrix:
        python_version: ["3.10", "3.11"]
-        pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
+        pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
    timeout-minutes: 20

    steps:
      - name: Check out repository code
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4

      - name: Setup Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python_version }}
          cache: 'pip' # caching pip dependencies
@@ -132,7 +137,7 @@ jobs:
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
-            pytorch: 2.5.0
+            pytorch: 2.5.1
            num_gpus: 1
            axolotl_extras:
    steps:
```
README.md (15 changes)

````diff
@@ -159,7 +159,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
 #### Docker

 ```bash
-docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
+docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
 ```

 Or run on the current files for development:
@@ -178,7 +178,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
 A more powerful Docker command to run would be this:

 ```bash
-docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
+docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
 ```

 It additionally:
@@ -210,7 +210,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
 #### Cloud GPU

-For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
+For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)

 - on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
 - on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
@@ -319,7 +319,7 @@ Write a job description in YAML as below:
 # dstack.yaml
 type: task

-image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
+image: axolotlai/axolotl-cloud:main-latest

 env:
   - HUGGING_FACE_HUB_TOKEN
@@ -383,11 +383,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
       - typescript
     type: ... # unimplemented custom format

-  # fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
-  # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+  # chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
   - path: ...
-    type: sharegpt
-    conversation: chatml # default: vicuna_v1.1
+    type: chat_template
+    chat_template: chatml # defaults to tokenizer's chat_template

   # local
   - path: data.jsonl # or json
````
Dockerfile (Jinja-templated; filename not captured in this view)

```diff
@@ -1,4 +1,4 @@
-FROM winglian/axolotl-base:{{ BASE_TAG }}
+FROM axolotlai/axolotl-base:{{ BASE_TAG }}

 ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
@@ -28,6 +28,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
        sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
        sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
        sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
+        sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
    fi

 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
```
Modal CI script for multi-GPU tests (filename not captured in this view)

```diff
@@ -10,7 +10,7 @@ import tempfile
 import jinja2
 import modal
 from jinja2 import select_autoescape
-from modal import Image, Stub
+from modal import App, Image

 cicd_path = pathlib.Path(__file__).parent.resolve()

@@ -46,7 +46,7 @@ cicd_image = (
     .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
 )

-stub = Stub("Axolotl CI/CD", secrets=[])
+app = App("Axolotl CI/CD", secrets=[])


 N_GPUS = int(os.environ.get("N_GPUS", 2))
@@ -61,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
     exit(exit_code)  # pylint: disable=consider-using-sys-exit


-@stub.function(
+@app.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
     timeout=60 * 60,
@@ -72,6 +72,6 @@ def cicd_pytest():
     run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")


-@stub.local_entrypoint()
+@app.local_entrypoint()
 def main():
     cicd_pytest.remote()
```
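For context on the `Stub` → `App` rename: newer releases of the `modal` client renamed the top-level object while keeping the decorator surface the same. A minimal standalone skeleton, assuming a current `modal` install (the app name and function body are illustrative):

```python
import modal

# modal.App replaces the deprecated modal.Stub; @app.function and
# @app.local_entrypoint keep the same shape as the old @stub.* decorators.
app = modal.App("example-ci")


@app.function(timeout=60)
def job() -> str:
    return "ok"


@app.local_entrypoint()
def main():
    # .remote() runs the function in Modal's cloud rather than locally.
    print(job.remote())
```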
Multi-GPU test runner script (filename not captured in this view)

```diff
@@ -2,4 +2,4 @@
 set -e

 # only run one test at a time so as not to OOM the GPU
-pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
+pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
```
Second Modal CI script, which runs ./cicd/cicd.sh (filename not captured in this view)

```diff
@@ -10,7 +10,7 @@ import tempfile
 import jinja2
 import modal
 from jinja2 import select_autoescape
-from modal import Image, Stub
+from modal import App, Image

 cicd_path = pathlib.Path(__file__).parent.resolve()

@@ -47,7 +47,7 @@ cicd_image = (
     .pip_install("fastapi==0.110.0", "pydantic==2.6.3")
 )

-stub = Stub("Axolotl CI/CD", secrets=[])
+app = App("Axolotl CI/CD", secrets=[])


 N_GPUS = int(os.environ.get("N_GPUS", 1))
@@ -62,7 +62,7 @@ def run_cmd(cmd: str, run_folder: str):
     exit(exit_code)  # pylint: disable=consider-using-sys-exit


-@stub.function(
+@app.function(
     image=cicd_image,
     gpu=GPU_CONFIG,
     timeout=60 * 60,
@@ -73,6 +73,6 @@ def cicd_pytest():
     run_cmd("./cicd/cicd.sh", "/workspace/axolotl")


-@stub.local_entrypoint()
+@app.local_entrypoint()
 def main():
     cicd_pytest.remote()
```
Example debugging config (filename not captured in this view)

```diff
@@ -1,4 +1,4 @@
-# Example config for debugging the sharegpt prompt format
+# Example config for debugging the chat_template prompt format
 base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
```
Dockerfile (main image variant with awscli/pydantic pinning; filename not captured in this view)

```diff
@@ -1,5 +1,5 @@
 ARG BASE_TAG=main-base
-FROM winglian/axolotl-base:$BASE_TAG
+FROM axolotlai/axolotl-base:$BASE_TAG

 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
@@ -35,7 +35,3 @@ RUN git lfs install --skip-repo && \
     pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10
-
-RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
-        pip3 install flash-attn==2.6.3; \
-    fi
```
Two cloud Dockerfile variants (filenames not captured in this view; the main.yml workflow above references docker/Dockerfile-cloud and docker/Dockerfile-cloud-no-tmux) receive the same base-image change:

```diff
@@ -1,5 +1,5 @@
 ARG BASE_TAG=main
-FROM winglian/axolotl:$BASE_TAG
+FROM axolotlai/axolotl:$BASE_TAG

 ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
 ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
```
Another base-image consumer Dockerfile (filename not captured in this view)

```diff
@@ -1,5 +1,5 @@
 ARG BASE_TAG=main-base
-FROM winglian/axolotl-base:$BASE_TAG
+FROM axolotlai/axolotl-base:$BASE_TAG

 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
```
Config reference docs (filename not captured in this view)

```diff
@@ -83,7 +83,7 @@ lora_on_cpu: true
 datasets:
   # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
-    # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
+    # The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
     type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
     ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
     data_files: # Optional[str] path to source data files
@@ -91,15 +91,7 @@ datasets:
     name: # Optional[str] name of dataset configuration to load
     train_on_split: train # Optional[str] name of dataset split to load from
     revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
-
-    # Optional[str] fastchat conversation type, only used with type: sharegpt
-    conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
-    field_human: # Optional[str]. Human key to use for conversation.
-    field_model: # Optional[str]. Assistant key to use for conversation.
-    # Add additional keys from your dataset as input or output roles
-    roles:
-      input: # Optional[List[str]]. These will be masked based on train_on_input
-      output: # Optional[List[str]].
+    trust_remote_code: # Optional[bool] Trust remote code for untrusted source

   # Custom user instruction prompt
   - path: repo
@@ -183,6 +175,8 @@ test_datasets:

 # use RL training: 'dpo', 'ipo', 'kto'
 rl:
+# whether to perform weighting if doing DPO training. Boolean.
+dpo_use_weighting:

 # The name of the chat template to use for training, following values are supported:
 # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
@@ -412,6 +406,7 @@ lr_div_factor: # Learning rate div factor
 # - adamw_torch_fused
 # - adamw_torch_xla
 # - adamw_apex_fused
+# - adopt_adamw (only for torch version >= 2.5.1)
 # - adafactor
 # - adamw_anyprecision
 # - sgd
```
Dataset format docs: conversation (filename not captured in this view)

````diff
@@ -6,33 +6,8 @@ order: 3

 ## sharegpt

-UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
-
-conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
-
-```{.json filename="data.jsonl"}
-{"conversations": [{"from": "...", "value": "..."}]}
-```
-
-Note: `type: sharegpt` opens special configs:
-- `conversation`: enables conversions to many Conversation types. Refer to the 'name' [here](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) for options.
-- `roles`: allows you to specify the roles for input and output. This is useful for datasets with custom roles such as `tool` etc to support masking.
-- `field_human`: specify the key to use instead of `human` in the conversation.
-- `field_model`: specify the key to use instead of `gpt` in the conversation.
-
-```yaml
-datasets:
-  path: ...
-  type: sharegpt
-
-  conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
-  field_human: # Optional[str]. Human key to use for conversation.
-  field_model: # Optional[str]. Assistant key to use for conversation.
-  # Add additional keys from your dataset as input or output roles
-  roles:
-    input: # Optional[List[str]]. These will be masked based on train_on_input
-    output: # Optional[List[str]].
-```
+IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.

 ## pygmalion
@@ -40,38 +15,6 @@ datasets:
 {"conversations": [{"role": "...", "value": "..."}]}
 ```

-## sharegpt.load_role
-
-conversations where `role` is used instead of `from`
-
-```{.json filename="data.jsonl"}
-{"conversations": [{"role": "...", "value": "..."}]}
-```
-
-## sharegpt.load_guanaco
-
-conversations where `from` is `prompter` `assistant` instead of default sharegpt
-
-```{.json filename="data.jsonl"}
-{"conversations": [{"from": "...", "value": "..."}]}
-```
-
-## sharegpt.load_ultrachat
-
-conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
-
-```{.json filename="data.jsonl"}
-{"messages": [{"user": "...", "assistant": "..."}]}
-```
-
-## sharegpt_jokes
-
-creates a chat where bot is asked to tell a joke, then explain why the joke is funny
-
-```{.json filename="data.jsonl"}
-{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
-```
-
 ## chat_template
````
Debugging docs (filename not captured in this view)

````diff
@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3

 ## Debugging With Docker

-Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
+Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.

 ### Setup

@@ -202,11 +202,11 @@ cd axolotl
 Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]

 ```bash
-docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
+docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
 ```

 >[!Tip]
-> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
+> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).

 You will now be in the container. Next, perform an editable install of Axolotl:
````
Notebook (filename not captured in this view)

```diff
@@ -44,7 +44,7 @@
    "outputs": [],
    "source": [
     "!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
-    "!pip install flash-attn==\"2.5.0\"\n",
+    "!pip install flash-attn==\"2.7.0.post2\"\n",
     "!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
    ]
   },
```
examples/mistral/mistral-dpo-qlora.yml (new file, 93 lines)

```yaml
#Note that we are switching from the regular chat template to chatml.
#If you experience problems with the special tokens, training for more epochs can help.
#After training, merge the model before inference otherwise you might
#face problems with the special tokens.

base_model: mistralai/Mistral-7B-Instruct-v0.2
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

chat_template: chatml
rl: dpo
datasets:
  - path: olivermolenschot/alpaca_messages_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content

dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-qlora

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true

adapter: qlora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.2
lora_target_linear: true
lora_fan_in_fan_out:

lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
lora_modules_to_save:
  - embed_tokens
  - lm_head

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 16
num_epochs: 6
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
s2_attention:

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<|im_start|>"
  eos_token: "<|im_end|>"
```
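The header comments above say to merge the adapter before inference. A hedged sketch of one way to do that with axolotl's LoRA-merge entry point; the paths mirror this example config, and exact flags may vary by axolotl version, so check `--help` on your install:

```python
import subprocess

# Merge the trained QLoRA adapter into the base weights before inference,
# as the config header recommends. Invocation is a sketch, not a guarantee
# of the flag surface on every release.
subprocess.run(
    [
        "python", "-m", "axolotl.cli.merge_lora",
        "examples/mistral/mistral-dpo-qlora.yml",
        "--lora_model_dir=./outputs/dpo-qlora",
    ],
    check=True,
)
```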
examples/qwen2/dpo.yaml (new file, 67 lines)

```yaml
base_model: Qwen/Qwen2.5-0.5B

strict: false

chat_template: qwen_25
rl: dpo
datasets:
  - path: fozziethebeat/alpaca_messages_2k_dpo_test
    type: chat_template.default
    field_messages: conversation
    field_chosen: chosen
    field_rejected: rejected
    message_field_role: role
    message_field_content: content
    roles:
      system:
        - system
      user:
        - user
      assistant:
        - assistant

dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/dpo-out

sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
```
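If the field mapping in this config is unclear, the quickest check is to peek at a row of the referenced dataset and confirm it carries `conversation`, `chosen`, and `rejected` keys. A sketch, assuming the dataset exposes a `train` split:

```python
from datasets import load_dataset

# Inspect one row of the DPO dataset the config points at, to confirm the
# field_messages / field_chosen / field_rejected mapping lines up.
ds = load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
row = ds[0]
print(sorted(row.keys()))
print(row["chosen"])
print(row["rejected"])
```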
Test requirements file (filename not captured in this view)

```diff
@@ -1,2 +1,3 @@
 pytest
 pytest-xdist
+pytest-retry
```
requirements file

```diff
@@ -1,18 +1,18 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.13.2
-transformers==4.46.1
+transformers==4.46.2
 tokenizers>=0.20.1
 bitsandbytes==0.44.1
 accelerate==1.1.0
-datasets==3.0.1
+datasets==3.1.0
 deepspeed==0.15.3
 pydantic==2.6.3
 addict
 fire
 PyYAML>=6.0
 requests
-flash-attn==2.6.3
+flash-attn==2.7.0.post2
 sentencepiece
 wandb
 einops
@@ -28,13 +28,12 @@ scipy
 scikit-learn==1.4.2
 pynvml
 art
-fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
 gradio==3.50.2
 tensorboard
 python-dotenv==1.0.1
 autoawq>=0.2.5
 triton>=2.3.0
-liger-kernel==0.4.0
+liger-kernel==0.4.1

 mamba-ssm==1.2.0.post1

@@ -43,7 +42,7 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs

-trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
+trl==0.12.0
 zstandard==0.22.0
 fastcore

@@ -54,3 +53,4 @@ immutabledict==4.2.0
 antlr4-python3-runtime==4.13.2

 torchao==0.5.0
+schedulefree==1.3.0
```
RunPod entrypoint script (filename not captured in this view)

```diff
@@ -2,7 +2,7 @@

 # Export specific ENV variables to /etc/rp_environment
 echo "Exporting environment variables..."
-printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
+printenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\([^=]*\)=\(.*\)$/export \1="\2"/' | grep -v 'printenv' >> /etc/rp_environment
 echo 'source /etc/rp_environment' >> ~/.bashrc

 add_keys_to_authorized() {
```
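Two things changed in that pipeline: the allow-list of exported variables grew, and the sed capture was tightened from `\(.*\)` to `\([^=]*\)` so that names are cut at the first `=` rather than the last. A Python sketch of the resulting behavior (the function name is illustrative, not from the script):

```python
import re


def export_lines(environ: dict[str, str]) -> list[str]:
    # Keep HF_/BNB_/CUDA_/NCCL_/NV*/RUNPOD_ prefixed vars plus PATH and _,
    # mirroring the widened grep. Splitting the name at the first '=' is
    # what the sed's \([^=]*\) capture guarantees; the old greedy \(.*\)
    # matched up to the *last* '=', mangling values such as FLAGS=a=b.
    prefix = re.compile(r"^(HF_|BNB_|CUDA_|NCCL_|NV|RUNPOD_)")
    kept = {k: v for k, v in environ.items()
            if prefix.match(k) or k in ("PATH", "_")}
    return [f'export {k}="{v}"' for k, v in kept.items()]


print(export_lines({"HF_HOME": "/cache", "SECRET_KEY": "x", "PATH": "/usr/bin"}))
# -> ['export HF_HOME="/cache"', 'export PATH="/usr/bin"']
```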
setup.py (16 changes)

```diff
@@ -39,7 +39,10 @@ def parse_requirements():
     else:
         # detect the version of torch already installed
         # and set it so dependencies don't clobber the torch version
-        torch_version = version("torch")
+        try:
+            torch_version = version("torch")
+        except PackageNotFoundError:
+            torch_version = "2.5.1"
         _install_requires.append(f"torch=={torch_version}")

         version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -54,6 +57,10 @@ def parse_requirements():

         if (major, minor) >= (2, 5):
             _install_requires.pop(_install_requires.index(xformers_version))
+            if patch == 0:
+                _install_requires.append("xformers==0.0.28.post2")
+            else:
+                _install_requires.append("xformers==0.0.28.post3")
             _install_requires.pop(_install_requires.index(autoawq_version))
         elif (major, minor) >= (2, 4):
             if patch == 0:
@@ -89,7 +96,7 @@ install_requires, dependency_links = parse_requirements()

 setup(
     name="axolotl",
-    version="0.4.1",
+    version="0.5.0",
     description="LLM Trainer",
     long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
     package_dir={"": "src"},
@@ -98,10 +105,7 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn==2.6.3",
-        ],
-        "fused-dense-lib": [
-            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
+            "flash-attn==2.7.0.post2",
         ],
         "deepspeed": [
             "deepspeed==0.14.4",
```
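Distilled from the hunk above, the new xformers pin selection for torch >= 2.5 as a self-contained sketch (version strings are copied from the diff; xformers 0.0.28.post2 was built against torch 2.5.0, post3 against 2.5.1):

```python
import re


def xformers_pin(torch_version: str) -> str | None:
    """Pick the xformers pin the setup.py hunk appends for torch >= 2.5."""
    match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
    major, minor, patch = (int(part) if part else 0 for part in match.groups())
    if (major, minor) >= (2, 5):
        return "xformers==0.0.28.post2" if patch == 0 else "xformers==0.0.28.post3"
    return None  # older torch keeps the default xformers_version pin


assert xformers_pin("2.5.0") == "xformers==0.0.28.post2"
assert xformers_pin("2.5.1") == "xformers==0.0.28.post3"
assert xformers_pin("2.4.1") is None
```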
CLI inference module (filename not captured in this view)

```diff
@@ -190,18 +190,15 @@ def do_inference(
 ):
     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
     prompter = cli_args.prompter
-    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
-
-    for token, symbol in default_tokens.items():
-        # If the token isn't already specified in the config, add it
-        if not (cfg.special_tokens and token in cfg.special_tokens):
-            tokenizer.add_special_tokens({token: symbol})

     prompter_module = None
+    chat_template_str = None
     if prompter:
         prompter_module = getattr(
             importlib.import_module("axolotl.prompters"), prompter
         )
+    elif cfg.chat_template:
+        chat_template_str = get_chat_template(cfg.chat_template)

     model = model.to(cfg.device, dtype=cfg.torch_dtype)

@@ -211,13 +208,31 @@ def do_inference(
         instruction = get_multi_line_input()
         if not instruction:
             return

         if prompter_module:
             prompt: str = next(
                 prompter_module().build_prompt(instruction=instruction.strip("\n"))
             )
         else:
             prompt = instruction.strip()
-        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+        if chat_template_str:
+            batch = tokenizer.apply_chat_template(
+                [
+                    {
+                        "role": "user",
+                        "content": prompt,
+                    }
+                ],
+                return_tensors="pt",
+                add_special_tokens=True,
+                add_generation_prompt=True,
+                chat_template=chat_template_str,
+                tokenize=True,
+                return_dict=True,
+            )
+        else:
+            batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

         print("=" * 40)
         model.eval()
@@ -257,13 +272,6 @@ def do_inference_gradio(

     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
     prompter = cli_args.prompter
-    # default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
-    default_tokens: Dict[str, str] = {}
-
-    for token, symbol in default_tokens.items():
-        # If the token isn't already specified in the config, add it
-        if not (cfg.special_tokens and token in cfg.special_tokens):
-            tokenizer.add_special_tokens({token: symbol})

     prompter_module = None
     chat_template_str = None
```
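For reference, `apply_chat_template` with `tokenize=True` and `return_dict=True` yields the same kind of `BatchEncoding` (with `input_ids` and `attention_mask`) as calling the tokenizer directly, which is why the two branches above are interchangeable downstream. A standalone usage sketch; the checkpoint name is just an example of a chat-tuned model that ships a template:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

batch = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,  # append the assistant header so generation starts there
    tokenize=True,
    return_dict=True,            # BatchEncoding with input_ids and attention_mask
    return_tensors="pt",
)
print(batch["input_ids"].shape)
```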
CLI preprocess module (filename not captured in this view)

```diff
@@ -23,10 +23,6 @@ from axolotl.cli import (
 )
 from axolotl.common.cli import PreprocessCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
-from axolotl.prompt_strategies.sharegpt import (
-    register_chatml_template,
-    register_llama3_template,
-)
 from axolotl.utils.trainer import disable_datasets_caching

 LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -44,23 +40,6 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         return_remaining_strings=True
     )

-    if parsed_cfg.chat_template == "chatml":
-        if parsed_cfg.default_system_message:
-            LOG.info(
-                f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
-            )
-            register_chatml_template(parsed_cfg.default_system_message)
-        else:
-            register_chatml_template()
-    elif parsed_cfg.chat_template == "llama3":
-        if parsed_cfg.default_system_message:
-            LOG.info(
-                f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
-            )
-            register_llama3_template(parsed_cfg.default_system_message)
-        else:
-            register_llama3_template()
-
     if not parsed_cfg.dataset_prepared_path:
         msg = (
             Fore.RED
```
|
|||||||
@@ -19,10 +19,6 @@ from axolotl.cli import (
 )
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.integrations.base import PluginManager
-from axolotl.prompt_strategies.sharegpt import (
-    register_chatml_template,
-    register_llama3_template,
-)
 from axolotl.train import train

 LOG = logging.getLogger("axolotl.cli.train")
@@ -42,21 +38,6 @@ def do_train(cfg, cli_args) -> None:
     print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()
-    if cfg.chat_template == "chatml" and cfg.default_system_message:
-        LOG.info(
-            f"ChatML set. Adding default system message: {cfg.default_system_message}"
-        )
-        register_chatml_template(cfg.default_system_message)
-    else:
-        register_chatml_template()
-
-    if cfg.chat_template == "llama3" and cfg.default_system_message:
-        LOG.info(
-            f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
-        )
-        register_llama3_template(cfg.default_system_message)
-    else:
-        register_llama3_template()

     if cfg.rl:  # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
@@ -10,6 +10,7 @@ MOE_ARCH_BLOCK = {
         "JetMoeMoE",
     ],
     "mixtral": "MixtralSparseMoeBlock",
+    "phimoe": "PhiMoESparseMoeBlock",
    "qwen2_moe": "Qwen2MoeSparseMoeBlock",
    "deepseek_v2": "DeepseekV2MoE",
 }
@@ -436,7 +436,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         if (
             self.args.loraplus_lr_ratio is None
             and self.args.alternate_optimizer
-            not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
+            not in [
+                "optimi_adamw",
+                "ao_adamw_8bit",
+                "ao_adamw_4bit",
+                "ao_adamw_fp8",
+                "adopt_adamw",
+            ]
         ):
             return super().create_optimizer()

@@ -505,6 +511,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
             self.optimizer = (  # pylint: disable=attribute-defined-outside-init
                 AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
             )
+        elif self.args.alternate_optimizer == "adopt_adamw":
+            from axolotl.utils.optimizers.adopt import ADOPT
+
+            self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                ADOPT(
+                    optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
+                )
+            )

         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -1024,24 +1038,37 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):

         return super().push_to_hub(*args, **kwargs)

+    @staticmethod
     def tokenize_row(
-        self,
         features,
         processing_class,
         max_prompt_length,
         max_completion_length,
         add_special_tokens,
     ) -> Dict:
-        res = super().tokenize_row(
+        res = DPOTrainer.tokenize_row(
             features,
             processing_class,
             max_prompt_length,
             max_completion_length,
             add_special_tokens,
         )
-        if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
+        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
+        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
             for key in res.keys():
                 res[key] = res[key][1:]
+
+        if processing_class.bos_token and processing_class.bos_token_id is not None:
+            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
+            if res["chosen_input_ids"][0] == processing_class.bos_token_id:
+                res["chosen_input_ids"] = res["chosen_input_ids"][1:]
+                res["chosen_labels"] = res["chosen_labels"][1:]
+                res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
+            if res["rejected_input_ids"][0] == processing_class.bos_token_id:
+                res["rejected_input_ids"] = res["rejected_input_ids"][1:]
+                res["rejected_labels"] = res["rejected_labels"][1:]
+                res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]

         return res

     def training_step(
@@ -1273,6 +1300,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             callbacks.append(lisa_callback_factory(trainer))
+
+        if self.cfg.plugins:
+            plugin_manager = PluginManager.get_instance()
+            callbacks.extend(
+                [
+                    cb
+                    for cb in plugin_manager.add_callbacks_post_trainer(
+                        self.cfg, trainer
+                    )
+                    if cb
+                ]
+            )
         return callbacks

     def _get_trainer_cls(self):
@@ -1390,17 +1429,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):

         if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
             # no eval set, so don't eval
-            training_arguments_kwargs["evaluation_strategy"] = "no"
+            training_arguments_kwargs["eval_strategy"] = "no"
         elif self.cfg.eval_steps:
-            training_arguments_kwargs["evaluation_strategy"] = "steps"
+            training_arguments_kwargs["eval_strategy"] = "steps"
             training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
-        elif self.cfg.evaluation_strategy:
-            training_arguments_kwargs[
-                "evaluation_strategy"
-            ] = self.cfg.evaluation_strategy
+        elif self.cfg.eval_strategy:
+            training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
         else:
             # we have an eval set, but no steps defined, default to use epoch
-            training_arguments_kwargs["evaluation_strategy"] = "epoch"
+            training_arguments_kwargs["eval_strategy"] = "epoch"

         if self.cfg.save_steps:
             training_arguments_kwargs["save_strategy"] = "steps"
@@ -1625,11 +1662,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.reward_model:
             trainer_kwargs["max_length"] = self.cfg.sequence_len

+        # pylint: disable=duplicate-code
         if self.cfg.optimizer in [
             "optimi_adamw",
             "ao_adamw_4bit",
             "ao_adamw_8bit",
             "ao_adamw_fp8",
+            "adopt_adamw",
         ]:
             # Set default so transformers doesn't throw
             training_arguments_kwargs["optim"] = "adamw_hf"
@@ -1832,10 +1871,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors

         if self.eval_dataset:
-            training_args_kwargs["evaluation_strategy"] = "steps"
+            training_args_kwargs["eval_strategy"] = "steps"
             training_args_kwargs["eval_steps"] = self.cfg.eval_steps
         else:
-            training_args_kwargs["evaluation_strategy"] = "no"
+            training_args_kwargs["eval_strategy"] = "no"

         if self.cfg.bf16 or self.cfg.bfloat16:
             training_args_kwargs["bf16"] = True
@@ -1890,17 +1929,18 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             # default to saving each epoch if not defined
             training_args_kwargs["save_strategy"] = "epoch"

+        training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+
         if self.cfg.rl_beta:
             training_args_kwargs["beta"] = self.cfg.rl_beta
         if self.cfg.orpo_alpha:
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha

-        training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
-        training_args_cls = AxolotlDPOConfig
         if self.cfg.rpo_alpha is not None:
             training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha

+        training_args_cls = None
         if self.cfg.rl == "simpo":
             training_args_cls = AxolotlCPOConfig
             training_args_kwargs["loss_type"] = "simpo"
@@ -1909,13 +1949,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha

-        if self.cfg.rl == "orpo":
+        elif self.cfg.rl == "orpo":
             training_args_cls = AxolotlORPOConfig
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

-        if self.cfg.rl == "kto":
+        elif self.cfg.rl == "kto":
             training_args_cls = AxolotlKTOConfig

             training_args_kwargs["desirable_weight"] = (
@@ -1930,6 +1970,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len

+        else:
+            training_args_cls = AxolotlDPOConfig
+            if self.cfg.rl == "ipo":
+                training_args_kwargs["loss_type"] = "ipo"
+            training_args_kwargs["max_length"] = self.cfg.sequence_len
+            training_args_kwargs["max_completion_length"] = None
+            training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
+            training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
+            if self.cfg.dpo_use_weighting is not None:
+                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
+
         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             output_dir=self.cfg.output_dir,
             per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1950,7 +2001,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         training_args = self.build_training_arguments(total_num_steps)
         dpo_trainer_kwargs = {}
         if self.cfg.rl == "ipo":
-            dpo_trainer_kwargs["loss_type"] = "ipo"
             if self.cfg.dpo_label_smoothing:
                 dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         if self.eval_dataset:
@@ -1964,12 +2014,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.rl in ["dpo", "ipo"]:
             trainer_cls = AxolotlDPOTrainer
             trainer_cls_args = [self.model, self.model_ref]
-
-            # these aren't used for the ORPO trainer
-            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
-            dpo_trainer_kwargs["max_target_length"] = None
-            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
-            dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
         elif self.cfg.rl == "orpo":
             trainer_cls = AxolotlORPOTrainer
             trainer_cls_args = [self.model]
@@ -140,7 +140,7 @@ class BasePlugin:

     def add_callbacks_pre_trainer(self, cfg, model):  # pylint: disable=unused-argument
         """
-        Adds callbacks to the trainer before training.
+        setup callbacks before creating the trainer.

         Parameters:
         cfg (dict): The configuration for the plugin.
@@ -155,14 +155,15 @@ class BasePlugin:
         self, cfg, trainer
     ):  # pylint: disable=unused-argument
         """
-        Adds callbacks to the trainer after training.
+        Adds callbacks to the trainer after creating the trainer.
+        This is useful for callbacks that require access to the model or trainer.

         Parameters:
         cfg (dict): The configuration for the plugin.
         trainer (object): The trainer object for training.

         Returns:
-        List[callable]: A list of callback functions to be added to the TrainingArgs
+        List[callable]: A list of callback functions to be added
         """
         return []

@@ -393,7 +394,9 @@ class PluginManager:
         """
         callbacks = []
         for plugin in self.plugins.values():
-            callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
+            plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
+            if plugin_callbacks:  # if the plugin returned a list of callbacks
+                callbacks.extend(plugin_callbacks)
         return callbacks

     def add_callbacks_post_trainer(self, cfg, trainer):
@@ -409,7 +412,9 @@ class PluginManager:
         """
         callbacks = []
         for plugin in self.plugins.values():
-            callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
+            plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
+            if plugin_callbacks:
+                callbacks.extend(plugin_callbacks)
         return callbacks

     def post_train_unload(self, cfg):
21 src/axolotl/integrations/grokfast/LICENSE Normal file
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
13 src/axolotl/integrations/grokfast/README.md Normal file
@@ -0,0 +1,13 @@
+# Grokfast Optimizer
+
+See https://github.com/ironjr/grokfast
+
+### Usage
+
+```yaml
+plugins:
+  - axolotl.integrations.grokfast.GrokfastPlugin
+
+grokfast_alpha: 2.0
+grokfast_lamb: 0.98
+```
50 src/axolotl/integrations/grokfast/__init__.py Normal file
@@ -0,0 +1,50 @@
+"""
+Grokfast plugin for Axolotl
+"""
+import logging
+
+from transformers.trainer_callback import TrainerCallback
+
+from ..base import BasePlugin
+from .args import GrokfastArgs  # pylint: disable=unused-import. # noqa: F401
+from .optimizer import gradfilter_ema
+
+LOG = logging.getLogger("axolotl.integrations.grokfast")
+
+
+class GrokfastCallbackHandler(TrainerCallback):
+    """
+    Transformer trainer callbacks for Grokfast
+    """
+
+    def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
+        super().__init__(*args_, **kwargs)
+        self.grads = None
+        self.alpha = alpha
+        self.lamb = lamb
+
+    def on_train_begin(self, *args_, **kwargs):  # pylint: disable=unused-argument
+        self.grads = None
+
+    def on_pre_optimizer_step(
+        self, args_, state, control, **kwargs
+    ):  # pylint: disable=unused-argument
+        model = kwargs.pop("model")
+        self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
+        return control
+
+
+class GrokfastPlugin(BasePlugin):
+    """
+    Plugin for Grokfast optimizer integraton with Axolotl.
+    """
+
+    def get_input_args(self):
+        return "axolotl.integrations.grokfast.GrokfastArgs"
+
+    def add_callbacks_post_trainer(self, cfg, trainer):
+        LOG.info("Adding Grokfast callback to the trainer")
+        callback = GrokfastCallbackHandler(
+            alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
+        )
+        return [callback]
15 src/axolotl/integrations/grokfast/args.py Normal file
@@ -0,0 +1,15 @@
+"""
+config args for grokfast plugin
+"""
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class GrokfastArgs(BaseModel):
+    """
+    Input args for Grokfast optimizer.
+    """
+
+    grokfast_alpha: Optional[float] = 0.98
+    grokfast_lamb: Optional[float] = 2.0
63 src/axolotl/integrations/grokfast/optimizer.py Normal file
@@ -0,0 +1,63 @@
+# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
+# Reference: https://github.com/ironjr/grokfast
+
+# pylint: skip-file
+from collections import deque
+from typing import Dict, Literal, Optional
+
+import torch
+import torch.nn as nn
+
+
+def gradfilter_ma(
+    m: nn.Module,
+    grads: Optional[Dict[str, deque]] = None,
+    window_size: int = 100,
+    lamb: float = 5.0,
+    filter_type: Literal["mean", "sum"] = "mean",
+    warmup: bool = True,
+    trigger: bool = False,  # For ablation study.
+) -> Dict[str, deque]:
+    if grads is None:
+        grads = {
+            n: deque(maxlen=window_size)
+            for n, p in m.named_parameters()
+            if p.requires_grad and p.grad is not None
+        }
+
+    for n, p in m.named_parameters():
+        if p.requires_grad and p.grad is not None:
+            grads[n].append(p.grad.data.detach())  # .cpu())
+
+            # Modify the gradients.
+            if not warmup or len(grads[n]) == window_size and not trigger:
+                if filter_type == "mean":
+                    avg = sum(grads[n]) / len(grads[n])
+                elif filter_type == "sum":
+                    avg = sum(grads[n])
+                else:
+                    raise ValueError(f"Unrecognized filter_type {filter_type}")
+                p.grad.data = p.grad.data + avg * lamb
+
+    return grads
+
+
+def gradfilter_ema(
+    m: nn.Module,
+    grads: Optional[Dict[str, torch.Tensor]] = None,
+    alpha: float = 0.98,
+    lamb: float = 2.0,
+) -> Dict[str, torch.Tensor]:
+    if grads is None:
+        grads = {
+            n: p.grad.data.detach()
+            for n, p in m.named_parameters()
+            if p.requires_grad and p.grad is not None
+        }
+
+    for n, p in m.named_parameters():
+        if p.requires_grad and p.grad is not None:
+            grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
+            p.grad.data = p.grad.data + grads[n] * lamb
+
+    return grads
@@ -23,6 +23,7 @@ import logging
 import sys

 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
+from liger_kernel.transformers.functional import liger_cross_entropy
 from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -82,7 +83,9 @@ class LigerPlugin(BasePlugin):
         if cfg.liger_glu_activation:
             modeling_jamba.JambaMLP = LigerSwiGLUMLP
         if cfg.liger_cross_entropy:
-            modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
         if cfg.liger_fused_linear_cross_entropy:
             modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
     elif cfg.model_config_type == "deepseek_v2":
@@ -106,6 +109,8 @@ class LigerPlugin(BasePlugin):
         if cfg.liger_glu_activation:
             modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
         if cfg.liger_cross_entropy:
+            # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
+            # nn.CrossEntropyLoss in the forward method.
             modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
         if cfg.liger_fused_linear_cross_entropy:
             modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
@@ -1,231 +0,0 @@
-"""
-monkeypatch to add a get_turns method
-"""
-
-import logging
-from typing import Generator, Tuple
-
-from fastchat.conversation import SeparatorStyle
-
-LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
-
-
-def get_prompt(self) -> str:
-    ret = ""
-    for role, msg in self.get_turns():
-        ret += role + msg
-    return ret
-
-
-def get_turns(  # pylint: disable=too-many-return-statements
-    self,
-) -> Generator[Tuple[str, str], None, None]:
-    """Get the prompt for generation."""
-    system_prompt = self.system_template.format(system_message=self.system_message)
-    if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
-        yield "", system_prompt + self.sep
-        for role, message in self.messages:
-            if message:
-                yield role + ": ", message + self.sep
-            else:
-                yield role + ":", ""
-        return
-    if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
-        seps = [self.sep, self.sep2]
-        yield "", system_prompt + seps[0]
-        for i, (role, message) in enumerate(self.messages):
-            if message:
-                yield role + ": ", message + seps[i % 2]
-            else:
-                yield role + ":", ""
-        return
-    if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
-        yield "", system_prompt + self.sep
-        for role, message in self.messages:
-            if message:
-                yield role + ": ", message + self.sep
-            else:
-                yield role + ": ", ""  # must be end with a space
-        return
-    if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
-        yield "", "" if system_prompt == "" else system_prompt + self.sep
-        for role, message in self.messages:
-            if message:
-                yield role + "\n", message + self.sep
-            else:
-                yield role + "\n", ""
-        return
-    if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
-        yield "", system_prompt
-        for role, message in self.messages:
-            if message:
-                yield role, message + self.sep
-            else:
-                yield role, ""
-        return
-    if self.sep_style == SeparatorStyle.NO_COLON_TWO:
-        seps = [self.sep, self.sep2]
-        yield "", system_prompt
-        for i, (role, message) in enumerate(self.messages):
-            if message:
-                yield role, message + seps[i % 2]
-            else:
-                yield role, ""
-        return
-    if self.sep_style == SeparatorStyle.RWKV:
-        yield "", system_prompt
-        for i, (role, message) in enumerate(self.messages):
-            if message:
-                yield role + ": ", message.replace("\r\n", "\n").replace(
-                    "\n\n", "\n"
-                ) + "\n\n"
-            else:
-                yield role + ":", ""
-        return
-    if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
-        if self.system_message:
-            if self.messages:
-                # For llama, the system message is incorporated into the first human instruction
-                first_role, first_msg = self.messages[0]
-                if first_role == self.roles[0]:
-                    system_prompt += first_msg
-                    self.messages.pop(0)
-        yield "", system_prompt
-        for i, (role, message) in enumerate(self.messages):
-            if message:
-                if (i % 2 == 0 and not self.system_message) or (
-                    i % 2 != 0 and self.system_message
-                ):
-                    role = "<s> " + role
-                yield role + " ", message
-            else:
-                yield role, ""
-        return
-    if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
-        contains_sys_msg = False
-        if self.system_message:
-            contains_sys_msg = True
-            if self.messages:
-                # There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
-                first_role, first_msg = self.messages[0]
-                if first_role == self.roles[0]:
-                    system_prompt = self.system_template.format(
-                        system_message=" " + self.system_message
-                    )
-                    system_prompt += first_msg
-                    self.messages.pop(0)
-        yield "", system_prompt
-        for i, (role, message) in enumerate(self.messages):
-            if message and i == 0 and not contains_sys_msg:
-                yield "", system_prompt.strip() + " " + message  # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
-            elif message:
-                yield role + " ", message
-            else:
-                yield role, ""
-        return
-    if self.sep_style == SeparatorStyle.LLAMA3:
-        if self.system_message:
-            # For llama3, the system message is NOT incorporated into the first human instruction
-            # All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
-            yield "", system_prompt
-        for i, (role, message) in enumerate(self.messages):
-            if message:
-                yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
-            else:
-                yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
-        return
-    if self.sep_style == SeparatorStyle.GEMMA:
-        if self.system_message:
-            raise ValueError("Gemma chat template does not support system messages")
-        for i, (role, message) in enumerate(self.messages):
-            prefix = "<bos>" if i == 0 else ""
-            message_str = message if message else ""
-            yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
-        return
-    if self.sep_style == SeparatorStyle.CHATGLM:
-        # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
-        # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
-        round_add_n = 1 if self.name == "chatglm2" else 0
-        if system_prompt:
-            yield "", system_prompt + self.sep
-
-        for i, (role, message) in enumerate(self.messages):
-            if i % 2 == 0:
-                yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
-
-            if message:
-                yield f"{role}:", f"{message}{self.sep}"
-            else:
-                yield f"{role}:", ""
-        return
-    if self.sep_style == SeparatorStyle.CHATML:
-        yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
-        for role, message in self.messages:
-            if message:
-                yield role + "\n", message + self.sep + "\n"
-            else:
-                yield role + "\n", ""
-        return
-    if self.sep_style == SeparatorStyle.CHATGLM3:
-        if self.system_message:
-            yield "", system_prompt
-        for role, message in self.messages:
-            if message:
-                yield role + "\n", " " + message
-            else:
-                yield role
-        return
-    if self.sep_style == SeparatorStyle.CHATINTERN:
-        # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
-        seps = [self.sep, self.sep2]
-        yield "", system_prompt
-        for i, (role, message) in enumerate(self.messages):
-            prefix = "<s>" if i % 2 == 0 else ""
-            if message:
-                yield prefix + role + ":", message + seps[i % 2] + "\n"
-            else:
-                yield role + ":", ""
-        return
-    if self.sep_style == SeparatorStyle.DOLLY:
-        seps = [self.sep, self.sep2]
-        yield "", system_prompt
-        for i, (role, message) in enumerate(self.messages):
-            if message:
-                suffix = "\n\n" if i % 2 == 1 else ""
-                yield role + ":\n", message + seps[i % 2] + suffix
-            else:
-                yield role + ":\n", ""
-        return
-    if self.sep_style == SeparatorStyle.PHOENIX:
-        yield "", system_prompt
-        for role, message in self.messages:
-            if message:
-                yield role + ": ", "<s>" + message + "</s>"
-            else:
-                yield role + ": " + "<s>", ""
-        return
-    if self.sep_style == SeparatorStyle.ROBIN:
-        yield "", system_prompt + self.sep
-        for role, message in self.messages:
-            if message:
-                yield role + ":\n", message + self.sep
-            else:
-                yield role + ":\n", ""
-        return
-    if self.sep_style == SeparatorStyle.FALCON_CHAT:
-        if self.system_message:
-            yield "", system_prompt + self.sep
-        for role, message in self.messages:
-            if message:
-                yield role + ": ", message + self.sep
-            else:
-                yield role + ":", ""
-    else:
-        raise ValueError(f"Invalid style: {self.sep_style}")
-
-
-def add_get_turns_to_conversation():
-    import fastchat.conversation
-
-    fastchat.conversation.Conversation.get_turns = get_turns
-    fastchat.conversation.Conversation.get_prompt = get_prompt
@@ -1,4 +1,5 @@
 """multipack patching for v2 of sample packing"""
+
 import importlib

 import transformers
@@ -19,6 +20,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "falcon",
     "phi",
     "phi3",
+    "phimoe",
     "gemma",
     "gemma2",
     "gemmoe",
@@ -27,71 +29,28 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
 ]


-def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
-    if model_type == "gemmoe":
-        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-    elif model_type == "deepseek_v2":
-        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
-    elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
+def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
+    if has_remote_code:
+        patch_remote(model_name)
+    elif hasattr(transformers, "modeling_flash_attention_utils"):
         transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
-        if model_type == "mixtral" and is_deepspeed_zero3_enabled():
-            patch_mixtral_moe_forward_zero3()
-        return
-
-    # retain for legacy
-    if model_type == "mixtral":
-        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-        if is_deepspeed_zero3_enabled():
-            patch_mixtral_moe_forward_zero3()
-    elif model_type == "llama":
-        if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
-            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-                get_unpad_data
-            )
-    elif model_type == "mistral":
-        if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
-            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-                get_unpad_data
-            )
-    elif model_type == "qwen2":
-        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "qwen2_moe":
-        transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "falcon":
-        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "phi":
-        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemma":
-        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "gemma2":
-        transformers.models.gemma2.modeling_gemma2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
-    elif model_type == "starcoder2":
-        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+        if model_type == "mixtral" and is_deepspeed_zero3_enabled():
+            patch_mixtral_moe_forward_zero3()


-def patch_remote(model_name, config_name, modeling_name):
+def patch_remote(model_name):
     model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     # we need to load the model here in order for modeling_* to be available
     with init_empty_weights():
         AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
-    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
+    parts = model_config.__class__.__module__.split(".")
+    parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
+    module_name = ".".join(parts)
     modeling_arch = importlib.import_module(module_name)
-    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
+    if hasattr(modeling_arch, "_get_unpad_data"):
+        modeling_arch._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
83 src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py Normal file
@@ -0,0 +1,83 @@
+"""
+fix for FSDP gradient accumulation
+see https://github.com/huggingface/transformers/pull/34645
+"""
+import inspect
+
+from accelerate.logging import get_logger
+from transformers.trainer import Trainer
+
+from axolotl.monkeypatch.unsloth_ import detab_code
+
+LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")
+
+ORIGINAL_CONTEXT_CODE = """
+        context = (
+            functools.partial(self.accelerator.no_sync, model=model)
+            if i == len(batch_samples) - 1
+            else contextlib.nullcontext
+        )
+"""
+
+PATCHED_CONTEXT_CODE = """
+        context = (
+            functools.partial(self.accelerator.no_sync, model=model)
+            if i != len(batch_samples) - 1
+            else contextlib.nullcontext
+        )
+"""
+
+
+def get_training_loop_code() -> str:
+    training_loop = inspect.getsource(
+        Trainer._inner_training_loop  # pylint: disable=protected-access
+    )
+    return training_loop
+
+
+def check_training_loop_is_patchable() -> bool:
+    train_loop = get_training_loop_code()
+    train_loop, _ = detab_code(train_loop)
+    return ORIGINAL_CONTEXT_CODE in train_loop
+
+
+def patch_training_loop_for_fsdp_grad_accum():
+    """
+    monkeypatch for fixing the training loop for FSDP gradient accumulation
+    """
+
+    train_loop = get_training_loop_code()
+    Trainer._original_inner_training_loop = (  # pylint: disable=protected-access
+        train_loop
+    )
+    train_loop, _ = detab_code(train_loop)
+    assert (
+        ORIGINAL_CONTEXT_CODE in train_loop
+    ), "Original _inner_training_loop code not found"
+
+    train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
+    train_loop = train_loop.replace(
+        "def _inner_training_loop(",
+        "def _fixed_inner_training_loop(",
+        1,
+    )
+
+    # load imports necessary
+    import transformers.trainer
+
+    items_to_import = []
+    for item in dir(transformers.trainer):
+        if item in train_loop:
+            items_to_import.append(item)
+
+    exec(  # pylint: disable=exec-used  # nosec B102
+        "from transformers.trainer import ("
+        + ", ".join(x for x in items_to_import)
+        + ")",
+        globals(),
+    )
+    exec(train_loop, globals())  # pylint: disable=exec-used  # nosec B102
+    LOG.info("patching _inner_training_loop", main_process_only=True)
+    Trainer._inner_training_loop = (  # pylint: disable=protected-access
+        _fixed_inner_training_loop  # pylint: disable=undefined-variable  # noqa: F821
+    )
@@ -1,33 +0,0 @@
-"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
-
-from typing import Any, Dict, Optional
-
-from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import ShareGPTPrompterV2
-
-
-def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-    conversation = (
-        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
-    )
-    strategy = InstructShareGPTPromptTokenizingStrategy(
-        # pylint: disable=duplicate-code
-        ShareGPTPrompterV2(
-            conversation=conversation,
-        ),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-    return strategy
-
-
-class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
-    """
-    basic sharegpt strategy to grab conversations from the sample row
-    """
-
-    def get_conversation_thread(self, prompt):
-        return [
-            {"from": "human", "value": prompt["instruction"]},
-            {"from": "gpt", "value": prompt["output"]},
-        ]
@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
 from typing import Generator, List, Sequence

 from axolotl.prompt_tokenizers import PromptTokenizingStrategy
-from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
+from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID


 @dataclass
@@ -75,7 +75,7 @@ class Llama2ChatConversation:

 class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
     """
-    Tokenizing strategy for ShareGPT prompts.
+    Tokenizing strategy for Llama2 prompts.
     adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
     """

@@ -191,7 +191,7 @@ class Llama2ChatPrompter:  # pylint: disable=too-few-public-methods
         conv.messages = []  # pylint: disable=R0801
         for j, sentence in enumerate(source):
             role = roles[sentence["from"]]
-            assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
+            assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE
             if sentence["value"]:
                 conv.append_message(role, sentence["value"])
         yield conv
@@ -1,223 +0,0 @@
-"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
-
-import logging
-from typing import Any, Dict, Optional, Type
-
-from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
-
-from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import ShareGPTPrompterV2
-from axolotl.utils.tokenization import (
-    chatml_to_conversation,
-    merge_consecutive_messages,
-)
-
-LOG = logging.getLogger("axolotl")
-
-
-def register_chatml_template(system_message=None):
-    system_message = system_message or "You are a helpful assistant."
-    register_conv_template(
-        Conversation(
-            name="chatml",
-            system_template="<|im_start|>system\n{system_message}",
-            system_message=system_message,
-            roles=("<|im_start|>user", "<|im_start|>assistant"),
-            sep_style=SeparatorStyle.CHATML,
-            sep="<|im_end|>",
-        )
-    )
-    register_conv_template(
-        Conversation(
-            name="chatml_glaive",
-            system_template="<|im_start|>system\n{system_message}",
-            system_message=system_message,
-            roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
-            sep_style=SeparatorStyle.CHATML,
-            sep="<|im_end|>",
-        )
-    )
-
-
-def register_llama3_template(system_message=None):
-    system_message = system_message or "You are a helpful assistant."
-    register_conv_template(
-        Conversation(
-            name="llama3",
-            system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
-            system_message=system_message,
-            roles=("user", "assistant"),
-            sep_style=SeparatorStyle.LLAMA3,
-            sep="",
-            stop_str="<|eot_id|>",
-            stop_token_ids=[128001, 128009],
-        )
-    )
-
-
-def build_loader(
-    tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
-    prompter_cls: Type["ShareGPTPrompterV2"],
-    default_conversation: Optional[str] = None,
-):
-    def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
-        LOG.warning(
-            "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
-        )
-        conversation = (
-            ds_cfg["conversation"]
-            if ds_cfg and "conversation" in ds_cfg
-            else default_conversation
-        )
-        field_human = (
-            ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
-        )
-        field_model = (
-            ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-        )
-        roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
-        strategy = tokenization_strategy_cls(
-            prompter_cls(
-                conversation=conversation,
-                role_key_model=field_model,
-                role_key_human=field_human,
-                roles=roles,
-            ),
-            tokenizer,
-            cfg.train_on_inputs,
-            cfg.sequence_len,
-        )
-        if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
-            strategy.strict = ds_cfg["strict"]
-        if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
-            strategy.messages = ds_cfg["field_messages"]
-        return strategy
-
-    return _load
-
-
-class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
-    """
-    basic sharegpt strategy to grab conversations from the sample row
-    """
-
-    _strict = False
-    _messages = "conversations"
-
-    @property
-    def strict(self):
-        return self._strict
-
-    @strict.setter
-    def strict(self, strict):
-        self._strict = strict
-
-    @property
-    def messages(self):
-        return self._messages
-
-    @messages.setter
-    def messages(self, messages):
-        self._messages = messages
-
-    def get_conversation_thread(self, prompt):
-        conversations = prompt[self.messages]
-        if self.strict:
-            return conversations
-        role_key = "from"
-        if "role" in conversations[0].keys():
-            role_key = "role"
-        value_key = "value"
-        if "text" in conversations[0].keys():
-            value_key = "text"
-        elif "content" in conversations[0].keys():
-            value_key = "content"
-        # remap roles - allow for assistant turn"
-        role_map = {
-            "user": "human",
-            "human": "human",
-            "assistant": "gpt",
-            "gpt": "gpt",
-            "system": "system",
-        }
-        turns = [
-            {
-                "from": (
-                    role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
-                ),
-                "value": t[value_key],
-                "weight": 1
-                if "weight" not in t or t["weight"] is None
-                else t["weight"],
-            }
-            for t in conversations
-        ]
-        return turns
-
-
-class SimpleRoleShareGPTPromptTokenizingStrategy(
-    SimpleShareGPTPromptTokenizingStrategy
-):
-    """
-    basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
-    """
-
-    def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
-        # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
-        turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
-        return turns
-
-
-class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
-    """
-    sharegpt strategy that remaps oasst data to sharegpt format
-    """
-
-    def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
-        # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
-        role_map = {"prompter": "human", "assistant": "gpt"}
-        turns = [
-            {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
-        ]
-        return turns
-
-
-class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
-    """
-    sharegpt strategy that remaps ultrachat data to sharegpt format
-    """
-
-    def get_conversation_thread(self, prompt):
-        conversations = prompt["messages"]
-        role_map = {"user": "human", "assistant": "gpt"}
-        turns = [
-            {"from": role_map[t["role"]], "value": t["content"]} for t in conversations
-        ]
-        return turns
-
-
-class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
-    """
-    sharegpt strategy that remaps glaive data to sharegpt format
-    """
-
-    def get_conversation_thread(self, prompt):
-        conversation = chatml_to_conversation(prompt)
-        conversation = merge_consecutive_messages(conversation)
-
-        return conversation
-
-
-load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
-load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
-load_ultrachat = build_loader(
-    UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
-)
-load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
-load_glaive = build_loader(
-    GlaiveShareGPTPromptTokenizingStrategy,
-    ShareGPTPrompterV2,
-    default_conversation="chatml_glaive",
-)
@@ -1,28 +0,0 @@
-"""Module for Jokes prompts using sharegpt style """
-
-from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
-from axolotl.prompters import ShareGPTPrompterV2
-
-
-def load(tokenizer, cfg):
-    return SimpleJokesShareGPTPromptTokenizingStrategy(
-        ShareGPTPrompterV2(),
-        tokenizer,
-        cfg.train_on_inputs,
-        cfg.sequence_len,
-    )
-
-
-class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
-    """
-    Tokenization strategy for asking bot to tell a joke and then explain why its funny
-    """
-
-    # title, text, explanation
-    def get_conversation_thread(self, prompt):
-        title = "" if not prompt["title"] else prompt["title"] + " "
-        return [
-            {"from": "human", "value": "Tell me a joke."},
-            {"from": "gpt", "value": title + prompt["text"]},
-            {"from": "human", "value": "Why is that joke funny?"},
-            {"from": "gpt", "value": prompt["explanation"]},
-        ]
@@ -1,17 +1,12 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""

import abc
import copy
import logging
from typing import Dict, List, Tuple, Union

from fastchat.conversation import Conversation
from transformers import BatchEncoding, PreTrainedTokenizer

from axolotl.monkeypatch.fastchat_conversation_turns import (
    add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.prompters import Prompter

LOG = logging.getLogger("axolotl")

@@ -21,8 +16,6 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>"  # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>"  # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>"  # nosec

add_get_turns_to_conversation()


class InvalidDataException(Exception):
    """
@@ -331,154 +324,6 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
        )


class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for ShareGPT prompts.
    """

    def get_conversation_thread(self, prompt):
        return prompt["conversations"]

    def tokenize_prompt(self, prompt):
        # Initial values. We will append to these as we go through the conversation.
        result, current_len = tokenize_prompt_default()
        conversation: Conversation = (
            self.prompter._conversation.copy()  # pylint: disable=protected-access
        )

        input_roles = {conversation.roles[0]}
        output_roles = {conversation.roles[1]}

        if len(conversation.roles) == 3:
            tool_role_label = conversation.roles[2]
            input_roles.add(tool_role_label)

        # Add roles from the config
        if self.prompter.roles:
            if "input" in self.prompter.roles and self.prompter.roles["input"]:
                for role in self.prompter.roles["input"]:
                    input_roles.add(role)

            if "output" in self.prompter.roles and self.prompter.roles["output"]:
                for role in self.prompter.roles["output"]:
                    output_roles.add(role)

        # support for custom roles from the dataset, only useful for vicuna style prompts/roles
        role_remap = []
        if (
            conversation.name == "vicuna_v1.1"
            and "roles" in prompt
            and len(prompt["roles"]) >= 2
        ):
            role_remap = [
                {"from": conversation.roles[0], "to": prompt["roles"][0]},
                {"from": conversation.roles[1], "to": prompt["roles"][1]},
            ]

        try:
            for _, part in enumerate(
                self.prompter.build_prompt(self.get_conversation_thread(prompt))
            ):
                if not isinstance(part, tuple):
                    LOG.warning(f"expected tuple, got {part}")
                    continue

                if len(part) <= 2:
                    role, content = part
                    weight = 1
                else:
                    role, content, weight = part

                # Uses "in" because role contains extra characters
                input_turn = any(r.lower() in role.lower() for r in input_roles)
                output_turn = any(r.lower() in role.lower() for r in output_roles)
                empty_role = role.strip() == ""

                if not any([input_turn, output_turn, empty_role]):
                    LOG.warning(f"unhandled role: {role}")
                    continue

                if input_turn:
                    role = (
                        role.replace(role_remap[0]["from"], role_remap[0]["to"])
                        if role_remap
                        else role
                    )
                    turn = role + content
                    # this is still the user query, we should
                    if not content.strip():
                        LOG.warning(f"user turn has empty text: {prompt}")
                    res = self._tokenize(
                        turn,
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    if self.train_on_inputs and weight == 1:
                        labels = copy.deepcopy(res["input_ids"])
                    else:
                        # everything from this is masked out from the labels
                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                elif output_turn:
                    role = (
                        role.replace(role_remap[1]["from"], role_remap[1]["to"])
                        if role_remap
                        else role
                    )
                    turn = role + content
                    # this should be the assistant response, should end with an eos token
                    if not content.strip():
                        LOG.warning(f"assistant turn has empty text: {prompt}")
                    add_eos_token = not (
                        conversation.name == "chatml"
                        and conversation.sep == self.tokenizer.eos_token
                    )
                    res = self._tokenize(
                        turn,
                        add_eos_token=add_eos_token,
                        strip_bos_token=True,
                    )
                    role_res = self._tokenize(
                        role.rstrip(),
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    labels = copy.deepcopy(res["input_ids"])
                    if not self.train_on_inputs:
                        # mask out role tokens from the labels
                        len_role = len(role_res["input_ids"])
                        labels[:len_role] = [IGNORE_TOKEN_ID] * min(
                            len_role, len(labels)
                        )
                    if weight == 0:
                        # everything from this is masked out from the labels
                        # (role is masked out too because it makes no sense if contents is masked out)
                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])

                elif empty_role:
                    turn = content
                    # this is only ever the first part, should include the bos token and the user query
                    res = self._tokenize(
                        turn, add_eos_token=False, strip_bos_token=False
                    )
                    if self.train_on_inputs and weight == 1:
                        labels = copy.deepcopy(res["input_ids"])
                    else:
                        # everything from this is masked out from the labels
                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])

                # pylint: disable=duplicate-code
                result, current_len = parse_tokenized_to_result(
                    result,
                    current_len,
                    res,
                    labels,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            return result
        except (KeyError, AssertionError, IndexError) as err:
            raise InvalidDataException(str(err)) from err


def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
    """
    Returns the default values for the tokenize prompt function
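
As a minimal sketch of the masking convention the removed strategy implemented (and which its replacements share): label positions covered by the role header, or by zero-weight turns, are set to IGNORE_TOKEN_ID so the loss ignores them. Token ids here are made up:

IGNORE_TOKEN_ID = -100  # the HF cross-entropy ignore index

input_ids = [28705, 13, 422, 4792, 2]  # hypothetical role tokens then content tokens
len_role = 2  # suppose the first two tokens render the role header

labels = list(input_ids)
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
# labels -> [-100, -100, 422, 4792, 2]; only the response content contributes to loss
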
@@ -5,7 +5,6 @@ from enum import Enum
from typing import Generator, Optional, Union

from colorama import Fore
from fastchat.conversation import Conversation, get_conv_template

LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
@@ -262,166 +261,10 @@ class ReflectAlpacaPrompter(Prompter):
    )


SHAREGPT_ASSERTION_FAILED_ROLE = (
ALTERNATING_ASSERTION_FAILED_ROLE = (
    "Role did not alternate between turns (gpt and human). Please check your data."
)

CONVERSATION_ROLE_FORMAT = {
    "chatml": "<|im_start|>{ROLE}",
    "zephyr": "<|{ROLE}|>",
    "vicuna_v1.1": "{ROLE}",
    "llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
}


class ShareGPTPrompter(Prompter):  # pylint: disable=too-few-public-methods
    """
    A prompter that generates prompts for the ShareGPT
    """

    role_key_human = "human"
    role_key_model = "gpt"
    # Optional, only used for tool usage datasets.
    role_key_tool: Optional[str] = None
    # Optional, role input/output mapping
    roles: Optional[dict] = None

    def __init__(
        self,
        prompt_style=None,  # pylint: disable=unused-argument
        conversation: Optional[Union[str, Conversation]] = None,
        role_key_human: Optional[str] = None,
        role_key_model: Optional[str] = None,
        role_key_tool: Optional[str] = None,
        roles: Optional[dict] = None,
    ):
        if conversation:
            if isinstance(conversation, Conversation):
                self._conversation = conversation
            else:
                self._conversation = get_conv_template(conversation)
        else:
            self._conversation = get_conv_template("vicuna_v1.1")
        if role_key_human:
            self.role_key_human = role_key_human
        if role_key_model:
            self.role_key_model = role_key_model
        if role_key_tool:
            self.role_key_tool = role_key_tool
        if roles:
            self.roles = roles

    def _build_result(self, source):
        if len(source) < 2:
            # If there isn't a back and forth conversation, ignore it
            # also happens on the data splitting leaving empty conversations
            raise IndexError(
                f"A conversation entry has less than 2 messages :\n{source}"
            )

        conv = self._conversation.copy()

        original_source = source.copy()
        # Add the conversation system prompt if provided, otherwise use the default one
        if source[0]["from"] == "system":
            conv.set_system_message(source[0]["value"])
            source.pop(0)

        roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
        if self.role_key_tool:
            roles[self.role_key_tool] = conv.roles[2]

        try:
            # Apply prompt templates
            if source[0]["from"] not in roles:
                # Skip the first one if it is not from human
                source = source[1:]
        except IndexError as err:
            # sometimes there is a bing or system chat
            raise err

        conv.messages = []
        for _, sentence in enumerate(source):
            from_role = sentence["from"]
            if from_role in roles:
                role = roles[from_role]
            else:
                if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
                    raise NotImplementedError(
                        f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
                        "Please help us by creating an Issue to add support for this conversation type."
                    )

                if self._conversation.name in ["llama3"]:
                    role = from_role
                else:
                    role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
                        ROLE=from_role
                    )

            if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
                if (
                    role != "assistant"
                ):  # back to back assistant calls may be okay for tool calls
                    LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")

            conv.append_message(role, sentence["value"])
        turns = list(conv.get_turns())
        original_source_length = len(original_source)
        assert len(turns) in [
            original_source_length - 1,
            original_source_length,
            original_source_length + 1,
        ]
        if len(turns) == original_source_length + 1:
            original_source = [{"weight": None}] + original_source
        elif len(turns) == original_source_length - 1:
            original_source = original_source[1:]
        return [
            (*turn, weight)
            for turn, weight in zip(
                turns,
                [
                    1 if "weight" not in e or e["weight"] is None else e["weight"]
                    for e in original_source
                ],
            )
        ]

    def build_prompt(self, source) -> Generator[str, None, None]:
        turns = self._build_result(source)

        for part in turns:
            if part[0] and not part[1]:
                LOG.warning(f"role with empty message: {part[0]}")
            yield part

    def __repr__(self) -> str:
        turns = self._build_result([{"from": "{from}", "value": "{value}"}])
        return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])


class ShareGPTPrompterV2(ShareGPTPrompter):
    """
    A V2 prompter that generates prompts for the ShareGPT
    """

    def __init__(
        self,
        conversation: Optional[Union[str, Conversation]] = None,
        role_key_human: Optional[str] = None,
        role_key_model: Optional[str] = None,
        role_key_tool: Optional[str] = None,
        roles: Optional[dict] = None,
    ):
        super().__init__(
            conversation=conversation,
            role_key_human=role_key_human,
            role_key_model=role_key_model,
            role_key_tool=role_key_tool,
            roles=roles,
        )


class UnsupportedPrompter(Prompter):
    """
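
A small sketch of the turn/weight alignment done at the end of `_build_result` above, with hypothetical turns: weights default to 1 whenever the key is absent or None.

original_source = [
    {"from": "human", "value": "hi"},
    {"from": "gpt", "value": "hello", "weight": 0},
]
weights = [
    1 if "weight" not in e or e["weight"] is None else e["weight"]
    for e in original_source
]
assert weights == [1, 0]  # the zero-weight assistant turn will be fully masked
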
@@ -64,10 +64,7 @@ class EvalFirstStepCallback(
        control: TrainerControl,
        **kwargs,
    ):
        if (
            args.evaluation_strategy == IntervalStrategy.STEPS
            and state.global_step == 1
        ):
        if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
            control.should_evaluate = True
        return control

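For context, `transformers` renamed `TrainingArguments.evaluation_strategy` to `eval_strategy`, which is what the hunk above tracks. A hedged sketch, assuming transformers >= 4.41:

from transformers import IntervalStrategy, TrainingArguments

args = TrainingArguments(output_dir="out", eval_strategy="steps")
assert args.eval_strategy == IntervalStrategy.STEPS  # the callback now reads this field
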
File diff suppressed because one or more lines are too long
@@ -1,8 +1,6 @@
"""Module for working with config dicts"""
import json
import logging
import os
from pathlib import Path
from typing import Optional

import torch
@@ -10,7 +8,6 @@ from transformers.utils import is_torch_bf16_gpu_available

from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
from axolotl.utils.config.models.input.v0_4_1 import (
    AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
@@ -215,11 +212,6 @@ def normalize_cfg_datasets(cfg):
    if cfg.chat_template:
        if cfg.datasets:
            for idx, ds_cfg in enumerate(cfg.datasets):
                if ds_cfg.type == "sharegpt" and not ds_cfg.conversation:
                    LOG.info(
                        f"updating dataset {ds_cfg.path} with `conversation: {cfg.chat_template}` to match your chat_template"
                    )
                    cfg.datasets[idx].conversation = cfg.chat_template
                if (
                    ds_cfg.type in ["orpo.chat_template", "chat_template"]
                    and not ds_cfg.chat_template
@@ -252,391 +244,3 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
    return DictDefault(
        dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
    )


def legacy_validate_config(cfg):
    """
    This is a "pre-validation" step that handles the yaml configuration before we have any
    information about the model architecture
    """
    if is_torch_bf16_gpu_available():
        if not cfg.bf16 and not cfg.bfloat16:
            LOG.info("bf16 support detected, but not enabled for this configuration.")
    else:
        if (
            not cfg.merge_lora
            and not cfg.is_preprocess
            and (cfg.bf16 is True or cfg.bfloat16 is True)
        ):
            raise ValueError(
                "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
            )
    if (
        # pylint: disable=too-many-boolean-expressions
        not (cfg.bf16 or cfg.bfloat16)
        and (cfg.fp16 or cfg.float16)
        and not cfg.adapter
        and not cfg.flash_attention
        and cfg.sample_packing
    ):
        LOG.warning(
            "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
        )
        # ValueError: Attempting to unscale FP16 gradients.
        # OR
        # RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
    if cfg.max_packed_sequence_len:
        raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")

    if cfg.sample_packing and cfg.rl:
        raise ValueError("`sample_packing: true` does not work with RLHF training")

    if cfg.sample_packing and not cfg.pad_to_sequence_len:
        LOG.warning(
            "`pad_to_sequence_len: true` is recommended when using sample_packing"
        )

    if cfg.gradient_accumulation_steps and cfg.batch_size:
        raise ValueError(
            "please set only one of gradient_accumulation_steps or batch_size"
        )
    if cfg.batch_size:
        LOG.warning(
            "%s\n%s",
            "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
            "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
        )
    if (
        cfg.eval_batch_size
        and cfg.micro_batch_size
        and cfg.eval_batch_size != cfg.micro_batch_size
    ):
        LOG.warning(
            "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
        )

    if cfg.adapter == "qlora":
        if cfg.merge_lora:
            # can't merge qlora if loaded in 8bit or 4bit
            if cfg.load_in_8bit:
                raise ValueError("Can't merge qlora if loaded in 8bit")

            if cfg.gptq:
                raise ValueError("Can't merge qlora if gptq")

            if cfg.load_in_4bit:
                raise ValueError("Can't merge qlora if loaded in 4bit")

        else:
            if cfg.load_in_8bit:
                raise ValueError("Can't load qlora in 8bit")

            if cfg.gptq:
                raise ValueError("Can't load qlora if gptq")

            if not cfg.load_in_4bit:
                raise ValueError("Require cfg.load_in_4bit to be True for qlora")

        if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
            raise ValueError("Fused modules are not supported with QLoRA")

    loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
        LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")

    if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
        raise ValueError("Fused modules are not supported with LoRA")

    if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
        raise ValueError(
            "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
        )

    if cfg.relora_steps:
        if cfg.adapter not in ("lora", "qlora"):
            raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")

        if cfg.fsdp:
            raise ValueError("fsdp not supported with ReLoRA")

        if cfg.deepspeed:
            raise ValueError("deepspeed not supported with ReLoRA")

        if cfg.lr_scheduler == "one_cycle":
            raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")

        if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
            raise ValueError("Fused modules are not supported with ReLoRA")

    if cfg.trust_remote_code:
        LOG.warning(
            "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
        )

    if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
        raise ValueError(
            "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
        )

    if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
        raise ValueError("FSDP is not supported for falcon models")

    if (
        cfg.base_model and "mpt" in cfg.base_model.lower()
    ) and cfg.gradient_checkpointing:
        raise ValueError("gradient_checkpointing is not supported for MPT models")

    if cfg.flash_optimum is True:
        if cfg.adapter:
            LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
        if cfg.fp16 or cfg.bf16:
            raise ValueError("AMP is not supported with BetterTransformer")
        if cfg.float16 is not True and cfg.bfloat16 is not True:
            LOG.warning(
                "You should probably set bfloat16 or float16 to true to "
                "load the model in float16 for BetterTransformers"
            )
        if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
            LOG.warning("torch>=2.0.0 required")
            raise ValueError(
                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
            )

    if cfg.pretraining_dataset and cfg.group_by_length:
        LOG.warning(
            "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
        )
    if cfg.pretraining_dataset and not cfg.max_steps:
        raise ValueError(
            "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
        )

    if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
        not cfg.optimizer or "adamw" not in cfg.optimizer
    ):
        LOG.warning("adamw hyperparameters found, but no adamw optimizer set")

    if cfg.push_to_hub_model_id:
        raise ValueError(
            "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
        )

    if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
        LOG.warning(
            "hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
        )

    if cfg.gptq and cfg.revision_of_model:
        raise ValueError(
            "revision_of_model is not supported for GPTQ models. "
            + "Please download the model from HuggingFace Hub manually for correct branch, "
            + "point to its path, and remove revision_of_model from the config."
        )

    # if cfg.sample_packing and cfg.sdp_attention:
    #     # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
    #     raise ValueError(
    #         "sample_packing not compatible with sdp_attention. Use flash_attention"
    #     )

    if cfg.sample_packing and cfg.xformers_attention:
        raise ValueError(
            "sample_packing not compatible with xformers_attention. Use flash_attention"
        )

    if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
        # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
        LOG.warning(
            "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
            "This may work on H100s."
        )

    if cfg.early_stopping_patience:
        if not cfg.save_steps or not cfg.eval_steps:
            raise ValueError(
                "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
            )
        if cfg.save_steps % cfg.eval_steps != 0:
            raise ValueError(
                "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
            )

    if cfg.datasets:
        for idx, ds_cfg in enumerate(cfg.datasets):
            if not ds_cfg.type:
                continue
            if ds_cfg.type == "sharegpt:chat":
                LOG.warning(
                    PendingDeprecationWarning(
                        "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
                    )
                )
                cfg.datasets[idx].type = "sharegpt"
            if "sharegpt_simple" in ds_cfg.type:
                LOG.warning(
                    PendingDeprecationWarning(
                        "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
                    )
                )
                cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
                    "sharegpt_simple", "sharegpt"
                )

    if cfg.saves_per_epoch and cfg.save_steps:
        raise ValueError(
            "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
        )
    if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
        raise ValueError(
            "save_strategy must be empty or set to `steps` when used with saves_per_epoch."
        )
    if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
        raise ValueError(
            "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
        )
    if cfg.evals_per_epoch and cfg.eval_steps:
        raise ValueError(
            "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
        )
    if (
        cfg.evals_per_epoch
        and cfg.evaluation_strategy
        and cfg.evaluation_strategy != "steps"
    ):
        raise ValueError(
            "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
        )
    if (
        cfg.evaluation_strategy
        and cfg.eval_steps
        and cfg.evaluation_strategy != "steps"
    ):
        raise ValueError(
            "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
        )

    if (
        cfg.val_set_size == 0
        and (cfg.eval_steps or cfg.evaluation_strategy)
        and not cfg.test_datasets
    ):
        raise ValueError(
            "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
        )

    if (
        cfg.sample_packing
        and cfg.eval_table_size
        and cfg.eval_sample_packing is not False
    ):
        raise ValueError(
            "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
        )

    if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
        raise ValueError(
            "load_in_8bit and load_in_4bit are not supported without setting an adapter."
            "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
        )

    if cfg.rope_scaling:
        LOG.warning("`rope_scaling` should now be be a key under `model_config`")

    if cfg.wandb_run_id and not cfg.wandb_name:
        cfg.wandb_name = cfg.wandb_run_id

        LOG.warning(
            "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
        )

    if cfg.noisy_embedding_alpha is not None:
        # Deprecated, use neftune_noise_alpha
        LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
        if cfg.neftune_noise_alpha is None:
            cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
        else:
            # User is providing both; bail and have them sort out their settings
            raise ValueError(
                "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
            )

    if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
        raise ValueError("neftune_noise_alpha must be > 0.0")

    if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
        raise ValueError(
            "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
        )

    if (
        cfg.unfrozen_parameters
        and cfg.gradient_checkpointing_kwargs
        and cfg.gradient_checkpointing_kwargs.use_reentrant is True
    ):
        # https://github.com/huggingface/transformers/issues/21381
        raise ValueError(
            "`use_reentrant` must be false when used with partially frozen model."
        )

    if cfg.deepspeed and Path(cfg.deepspeed).is_file():
        with open(cfg.deepspeed, encoding="utf-8") as file:
            contents = file.read()
            deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
            if cfg.flash_attention:
                if (
                    deepspeed_cfg.zero_optimization
                    and deepspeed_cfg.zero_optimization.stage == 3
                ):
                    if not (
                        (
                            deepspeed_cfg.bf16
                            and deepspeed_cfg.bf16.enabled  # pylint: disable=no-member
                            is True
                        )
                        or (
                            deepspeed_cfg.fp16
                            and deepspeed_cfg.fp16.enabled  # pylint: disable=no-member
                            is True
                        )
                    ):
                        raise ValueError(
                            "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
                        )
            if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
                LOG.warning(
                    f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
                )

    if cfg.test_datasets and cfg.val_set_size:
        raise ValueError(
            "non-zero val_set_size should not be used with test_datasets configuration"
        )

    if cfg.fsdp and "bnb" in cfg.optimizer:
        raise ValueError(f"FSDP not compatible with {cfg.optimizer}")

    if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
        raise ValueError(
            "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
        )

    if cfg.eval_causal_lm_metrics:
        if not isinstance(cfg.eval_causal_lm_metrics, list):
            raise ValueError("eval_causal_lm_metrics must be a list")
        # only ["sacrebleu", "comet", "ter", "chrf"] supported
        if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
            raise ValueError(
                f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
            )

    # TODO
    # MPT 7b
    # https://github.com/facebookresearch/bitsandbytes/issues/25
    # no 8bit adaAmw w bf16

    # GPT-NeoX
    # evals broken when extending context len
    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
    # attention_mask = causal_mask + attention_mask
    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
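
To illustrate the kind of check the removed `legacy_validate_config` performed (its successors in the pydantic models behave similarly), a hedged sketch with a hypothetical config:

from axolotl.utils.dict import DictDefault

cfg = DictDefault({"saves_per_epoch": 2, "save_steps": 500})
# legacy_validate_config(cfg) would raise:
# ValueError: save_steps and saves_per_epoch are mutually exclusive and cannot be used together.
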
@@ -58,6 +58,7 @@ class ChatTemplate(str, Enum):
    qwen_25 = "qwen_25"  # pylint: disable=invalid-name
    tokenizer_default = "tokenizer_default"  # pylint: disable=invalid-name
    exaone = "exaone"  # pylint: disable=invalid-name
    metharme = "metharme"  # pylint: disable=invalid-name


class DeprecatedParameters(BaseModel):
@@ -67,6 +68,7 @@ class DeprecatedParameters(BaseModel):
    rope_scaling: Optional[Any] = None
    noisy_embedding_alpha: Optional[float] = None
    dpo_beta: Optional[float] = None
    evaluation_strategy: Optional[str] = None

    @field_validator("max_packed_sequence_len")
    @classmethod
@@ -98,6 +100,13 @@ class DeprecatedParameters(BaseModel):
        LOG.warning("dpo_beta is deprecated, use rl_beta instead")
        return dpo_beta

    @field_validator("evaluation_strategy")
    @classmethod
    def validate_evaluation_strategy(cls, evaluation_strategy):
        if evaluation_strategy is not None:
            LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
        return evaluation_strategy


class RemappedParameters(BaseModel):
    """parameters that have been remapped to other names"""
@@ -427,6 +436,7 @@ class HyperparametersConfig(BaseModel):
            "ao_adamw_4bit",
            "ao_adamw_8bit",
            "ao_adamw_fp8",
            "adopt_adamw",
        ],
    ]
] = OptimizerNames.ADAMW_HF.value
@@ -588,6 +598,9 @@ class AxolotlInputConfig(

    rl: Optional[RLType] = None
    reward_model: Optional[bool] = None
    dpo_use_weighting: Optional[
        bool
    ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.

    datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
    test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
@@ -726,7 +739,7 @@ class AxolotlInputConfig(
    warmup_ratio: Optional[float] = None
    eval_steps: Optional[Union[int, float]] = None
    evals_per_epoch: Optional[Union[int]] = None
    evaluation_strategy: Optional[str] = None
    eval_strategy: Optional[str] = None
    save_steps: Optional[Union[int, float]] = None
    saves_per_epoch: Optional[int] = None
    save_strategy: Optional[str] = None
@@ -778,28 +791,25 @@ class AxolotlInputConfig(
    is_mistral_derived_model: Optional[bool] = Field(default=None)
    is_qwen_derived_model: Optional[bool] = Field(default=None)

    plugins: Optional[List[str]] = Field(default=None)

    @field_validator("datasets", mode="before")
    @classmethod
    def fix_sharegpt_datasets(cls, datasets):
        for idx, ds_cfg in enumerate(datasets):
            if not ds_cfg["type"]:
                continue
            if ds_cfg["type"] == "sharegpt:chat":
                LOG.warning(
                    PendingDeprecationWarning(
                        "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
                    )
                )
                datasets[idx]["type"] = "sharegpt"
            if "sharegpt_simple" in ds_cfg["type"]:
                LOG.warning(
                    PendingDeprecationWarning(
                        "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
                    )
                )
                datasets[idx]["type"] = datasets[idx]["type"].replace(
                    "sharegpt_simple", "sharegpt"
                )
        return datasets
    def deprecate_sharegpt_datasets(cls, datasets):
        for _, ds_cfg in enumerate(datasets):
            if not ds_cfg.get("type"):
                continue
            ds_type = ds_cfg["type"]
            # skip if it's a dict (for custom user instruction prompt)
            if isinstance(ds_type, dict):
                continue
            if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
                raise ValueError(
                    "`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
                )
        return datasets

    @model_validator(mode="before")
@@ -1031,21 +1041,21 @@ class AxolotlInputConfig(
    @classmethod
    def check_evals(cls, data):
        if (
            data.get("evaluation_strategy")
            data.get("eval_strategy")
            and data.get("eval_steps")
            and data.get("evaluation_strategy") != "steps"
            and data.get("eval_strategy") != "steps"
        ):
            raise ValueError(
                "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
                "eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
            )

        if (
            data.get("val_set_size") == 0
            and (data.get("eval_steps") or data.get("evaluation_strategy"))
            and (data.get("eval_steps") or data.get("eval_strategy"))
            and not data.get("test_datasets")
        ):
            raise ValueError(
                "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
                "eval_steps and eval_strategy are not supported with val_set_size == 0"
            )
        if data.get("evals_per_epoch") and data.get("eval_steps"):
            raise ValueError(
@@ -1053,11 +1063,11 @@ class AxolotlInputConfig(
            )
        if (
            data.get("evals_per_epoch")
            and data.get("evaluation_strategy")
            and data.get("eval_strategy")
            and data.get("evaluation_strategy") != "steps"
            and data.get("eval_strategy") != "steps"
        ):
            raise ValueError(
                "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
                "eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
            )

        if data.get("do_bench_eval") and not (
@@ -1289,6 +1299,25 @@ class AxolotlInputConfig(
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def warn_qlora_zero3_w_use_reentrant(cls, data):
        if (
            data.get("adapter") == "qlora"
            and data.get("gradient_checkpointing_kwargs", {})
            and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
            is False
            and "zero3" in data.get("deepspeed", "")
        ):
            # may result in:
            # torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:
            # Recomputed values for the following tensors have different metadata
            # than during the forward pass.
            LOG.warning(
                "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_val_w_test_datasets(cls, data):
@@ -1298,6 +1327,19 @@ class AxolotlInputConfig(
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_eval_strategy(cls, data):
        if (
            data.get("evaluation_strategy") is not None
            and data.get("eval_strategy") is None
        ):
            LOG.info(
                "explicitly setting `eval_strategy` from the `evaluation_strategy`"
            )
            data["eval_strategy"] = data.get("evaluation_strategy")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_fsdp_offload_w_8bit_optimizer(cls, data):
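
The new `check_eval_strategy` validator above keeps old configs working by copying the deprecated key forward. A minimal sketch of that mapping:

data = {"evaluation_strategy": "steps"}
if data.get("evaluation_strategy") is not None and data.get("eval_strategy") is None:
    data["eval_strategy"] = data.get("evaluation_strategy")
assert data["eval_strategy"] == "steps"  # old key honored, new key populated
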
@@ -260,6 +260,7 @@ def load_tokenized_prepared_datasets(
    for config_dataset in for_d_in_datasets(cfg_datasets):
        ds: Optional[Union[Dataset, DatasetDict]] = None
        ds_from_hub = False
        ds_trust_remote_code = config_dataset.trust_remote_code
        try:
            # this is just a basic check to see if the path is a
            # valid HF dataset that's loadable
@@ -269,6 +270,7 @@ def load_tokenized_prepared_datasets(
                streaming=True,
                token=use_auth_token,
                revision=config_dataset.revision,
                trust_remote_code=ds_trust_remote_code,
            )
            ds_from_hub = True
        except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -348,7 +350,15 @@ def load_tokenized_prepared_datasets(
                    split=None,
                )
            else:
                ds = load_from_disk(config_dataset.path)
                try:
                    ds = load_from_disk(config_dataset.path)
                except FileNotFoundError:
                    ds = load_dataset(
                        config_dataset.path,
                        name=config_dataset.name,
                        streaming=False,
                        split=None,
                    )
        elif local_path.is_file():
            ds_type = get_ds_type(config_dataset)

@@ -366,7 +376,7 @@ def load_tokenized_prepared_datasets(
        elif ds_from_hub:
            load_ds_kwargs = {}
            if config_dataset.split:
                load_ds_kwargs = {"split": config_dataset.split}
                load_ds_kwargs["split"] = config_dataset.split
            ds = load_dataset(
                config_dataset.path,
                name=config_dataset.name,
@@ -374,6 +384,7 @@ def load_tokenized_prepared_datasets(
                data_files=config_dataset.data_files,
                token=use_auth_token,
                revision=config_dataset.revision,
                trust_remote_code=config_dataset.trust_remote_code,
                **load_ds_kwargs,
            )
        elif ds_from_cloud and remote_file_system:
@@ -391,6 +402,7 @@ def load_tokenized_prepared_datasets(
                streaming=False,
                split=None,
                storage_options=storage_options,
                trust_remote_code=config_dataset.trust_remote_code,
            )
        elif config_dataset.path.startswith("https://"):
            ds_type = get_ds_type(config_dataset)
@@ -401,6 +413,7 @@ def load_tokenized_prepared_datasets(
                streaming=False,
                split=None,
                storage_options=storage_options,
                trust_remote_code=config_dataset.trust_remote_code,
            )
        else:
            if isinstance(config_dataset.data_files, str):
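
A hedged sketch of the knob this hunk threads through: each dataset entry's `trust_remote_code` now reaches `load_dataset`. The dataset path is hypothetical:

from datasets import load_dataset

ds = load_dataset(
    "someuser/some-dataset",  # hypothetical path with a loading script
    streaming=True,
    trust_remote_code=True,  # taken from config_dataset.trust_remote_code above
)
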
25
src/axolotl/utils/environment.py
Normal file
@@ -0,0 +1,25 @@
"""
utils to get GPU info for the current environment
"""
from accelerate.utils.environment import (
    check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
)
from accelerate.utils.environment import get_gpu_info


def check_cuda_p2p_ib_support():
    if not accelerate_check_cuda_p2p_ib_support():
        return False
    unsupported_devices = {"RTX 6000 Ada"}
    try:
        device_names, device_count = get_gpu_info()
        if 1 < device_count < 8:
            if any(
                unsupported_device in device_name
                for device_name in device_names
                for unsupported_device in unsupported_devices
            ):
                return False
    except Exception:  # pylint: disable=broad-except  # nosec
        pass
    return True
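
A hedged usage sketch for the new helper; exporting the NCCL variables is one plausible way to consume the result, not something this diff itself does:

import os

from axolotl.utils.environment import check_cuda_p2p_ib_support

if not check_cuda_p2p_ib_support():
    os.environ["NCCL_P2P_DISABLE"] = "1"
    os.environ["NCCL_IB_DISABLE"] = "1"
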
@@ -14,6 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from packaging import version

torch_version = version.parse(torch.__version__)

if torch_version < version.parse("2.4.0"):
    torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
    torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
    torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
    torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")


class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    @torch_cuda_amp_custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
@@ -36,7 +46,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
        return output

    @staticmethod
    @torch.cuda.amp.custom_bwd
    @torch_cuda_amp_custom_bwd
    def backward(ctx, dY):
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
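
Background for the version gate above: torch 2.4 deprecated `torch.cuda.amp.custom_fwd`/`custom_bwd` in favor of `torch.amp.custom_fwd(device_type="cuda")`. A minimal sketch of applying the gated aliases to a custom autograd function (assumes the aliases defined in the hunk above are in scope):

import torch


class Square(torch.autograd.Function):
    @staticmethod
    @torch_cuda_amp_custom_fwd  # alias resolves per the installed torch version
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @torch_cuda_amp_custom_bwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out
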
@@ -238,6 +238,7 @@ def load_tokenizer(cfg):
                    x in cfg.lora_modules_to_save for x in lora_modules_to_save
                )
            )
            and k != "pad_token"
        ):
            lora_modules_to_save = ", ".join(
                [f"`{x}`" for x in lora_modules_to_save]
@@ -394,10 +395,17 @@ class ModelLoader:
            and self.cfg.flash_attention
            and self.cfg.sample_packing
        ):
            has_remote_code = (
                "auto_map" in self.model_config
                and "AutoModelForCausalLM" in self.model_config["auto_map"]
            )
            if has_remote_code and self.cfg.trust_remote_code is False:
                # if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
                has_remote_code = self.cfg.trust_remote_code
            patch_for_multipack(
                self.cfg.model_config_type,
                model_name=self.cfg.base_model,
                is_remote_code=self.cfg.trust_remote_code,
                has_remote_code=has_remote_code,
            )

        if self.cfg.is_llama_derived_model:
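
A small sketch of the `has_remote_code` decision introduced above, with hypothetical inputs; an explicit `trust_remote_code: false` in the YAML overrides the `auto_map` probe:

model_config = {"auto_map": {"AutoModelForCausalLM": "modeling_foo.FooForCausalLM"}}
trust_remote_code = False  # explicitly disabled in the YAML

has_remote_code = (
    "auto_map" in model_config and "AutoModelForCausalLM" in model_config["auto_map"]
)
if has_remote_code and trust_remote_code is False:
    has_remote_code = trust_remote_code
assert has_remote_code is False
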
508
src/axolotl/utils/optimizers/adopt.py
Normal file
@@ -0,0 +1,508 @@
"""
Copied from https://github.com/iShohei220/adopt

ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
"""
# mypy: ignore-errors
# pylint: skip-file
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import List, Optional, Tuple, Union, cast

import torch
from torch import Tensor
from torch.optim.optimizer import (
    Optimizer,
    ParamsT,
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _disable_dynamo_if_unsupported,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _use_grad_for_differentiable,
    _view_as_real,
)

__all__ = ["ADOPT", "adopt"]


class ADOPT(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.9999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        decoupled: bool = False,
        *,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if isinstance(lr, Tensor):
            if foreach and not capturable:
                raise ValueError(
                    "lr as a Tensor is not supported for capturable=False and foreach=True"
                )
            if lr.numel() != 1:
                raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            decoupled=decoupled,
            maximize=maximize,
            foreach=foreach,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            # TODO: support fused
            raise RuntimeError("`fused` is not currently supported")

            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val,
                            dtype=_get_scalar_dtype(is_fused=fused),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
    ):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ADOPT does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    if group["fused"]:
                        _device_dtype_check_for_fused(p)
                    # note(crcrpar): [special device hosting for step]
                    # Deliberately host `step` on CPU if both capturable and fused are off.
                    # This is because kernel launches are costly on CUDA and XLA.
                    state["step"] = (
                        torch.zeros(
                            (),
                            dtype=_get_scalar_dtype(is_fused=group["fused"]),
                            device=p.device,
                        )
                        if group["capturable"] or group["fused"]
                        else torch.tensor(0.0, dtype=_get_scalar_dtype())
                    )
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])

                if group["differentiable"] and state["step"].requires_grad:
                    raise RuntimeError(
                        "`requires_grad` is not supported for `step` in differentiable mode"
                    )

                # Foreach without capturable does not support a tensor lr
                if (
                    group["foreach"]
                    and torch.is_tensor(group["lr"])
                    and not group["capturable"]
                ):
                    raise RuntimeError(
                        "lr as a Tensor is not supported for capturable=False and foreach=True"
                    )

                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
self._cuda_graph_capture_health_check()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
params_with_grad: List[Tensor] = []
|
||||||
|
grads: List[Tensor] = []
|
||||||
|
exp_avgs: List[Tensor] = []
|
||||||
|
exp_avg_sqs: List[Tensor] = []
|
||||||
|
state_steps: List[Tensor] = []
|
||||||
|
beta1, beta2 = group["betas"]
|
||||||
|
|
||||||
|
has_complex = self._init_group(
|
||||||
|
group,
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
adopt(
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
has_complex=has_complex,
|
||||||
|
beta1=beta1,
|
||||||
|
beta2=beta2,
|
||||||
|
lr=group["lr"],
|
||||||
|
weight_decay=group["weight_decay"],
|
||||||
|
decoupled=group["decoupled"],
|
||||||
|
eps=group["eps"],
|
||||||
|
maximize=group["maximize"],
|
||||||
|
foreach=group["foreach"],
|
||||||
|
capturable=group["capturable"],
|
||||||
|
differentiable=group["differentiable"],
|
||||||
|
fused=group["fused"],
|
||||||
|
grad_scale=getattr(self, "grad_scale", None),
|
||||||
|
found_inf=getattr(self, "found_inf", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def _single_tensor_adopt(
|
||||||
|
params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[Tensor],
|
||||||
|
grad_scale: Optional[Tensor],
|
||||||
|
found_inf: Optional[Tensor],
|
||||||
|
*,
|
||||||
|
has_complex: bool,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: Union[float, Tensor],
|
||||||
|
weight_decay: float,
|
||||||
|
decoupled: bool,
|
||||||
|
eps: float,
|
||||||
|
maximize: bool,
|
||||||
|
capturable: bool,
|
||||||
|
differentiable: bool,
|
||||||
|
):
|
||||||
|
assert grad_scale is None and found_inf is None
|
||||||
|
|
||||||
|
if torch.jit.is_scripting():
|
||||||
|
# this assert is due to JIT being dumb and not realizing that the ops below
|
||||||
|
# have overloads to handle both float and Tensor lrs, so we just assert it's
|
||||||
|
# a float since most people using JIT are using floats
|
||||||
|
assert isinstance(lr, float)
|
||||||
|
|
||||||
|
for i, param in enumerate(params):
|
||||||
|
grad = grads[i] if not maximize else -grads[i]
|
||||||
|
exp_avg = exp_avgs[i]
|
||||||
|
exp_avg_sq = exp_avg_sqs[i]
|
||||||
|
step_t = state_steps[i]
|
||||||
|
|
||||||
|
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||||
|
if not torch._utils.is_compiling() and capturable:
|
||||||
|
capturable_supported_devices = _get_capturable_supported_devices()
|
||||||
|
assert (
|
||||||
|
param.device.type == step_t.device.type
|
||||||
|
and param.device.type in capturable_supported_devices
|
||||||
|
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||||
|
|
||||||
|
# update step
|
||||||
|
step_t += 1
|
||||||
|
|
||||||
|
if weight_decay != 0:
|
||||||
|
if decoupled:
|
||||||
|
param.add_(param, alpha=-lr * weight_decay)
|
||||||
|
else:
|
||||||
|
grad = grad.add(param, alpha=weight_decay)
|
||||||
|
|
||||||
|
if torch.is_complex(param):
|
||||||
|
grad = torch.view_as_real(grad)
|
||||||
|
if exp_avg is not None:
|
||||||
|
exp_avg = torch.view_as_real(exp_avg)
|
||||||
|
if exp_avg_sq is not None:
|
||||||
|
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||||
|
param = torch.view_as_real(param)
|
||||||
|
|
||||||
|
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||||
|
if step == 1:
|
||||||
|
exp_avg_sq.addcmul_(grad, grad.conj())
|
||||||
|
continue
|
||||||
|
|
||||||
|
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||||
|
if step == 2:
|
||||||
|
exp_avg.addcdiv_(grad, denom)
|
||||||
|
else:
|
||||||
|
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
|
||||||
|
|
||||||
|
param.add_(exp_avg, alpha=-lr)
|
||||||
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||||
|
|
||||||
|
|
||||||
|
def _multi_tensor_adopt(
|
||||||
|
params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[Tensor],
|
||||||
|
grad_scale: Optional[Tensor],
|
||||||
|
found_inf: Optional[Tensor],
|
||||||
|
*,
|
||||||
|
has_complex: bool,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: Union[float, Tensor],
|
||||||
|
weight_decay: float,
|
||||||
|
decoupled: bool,
|
||||||
|
eps: float,
|
||||||
|
maximize: bool,
|
||||||
|
capturable: bool,
|
||||||
|
differentiable: bool,
|
||||||
|
):
|
||||||
|
if len(params) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(lr, Tensor) and not capturable:
|
||||||
|
raise RuntimeError(
|
||||||
|
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||||
|
if not torch._utils.is_compiling() and capturable:
|
||||||
|
capturable_supported_devices = _get_capturable_supported_devices(
|
||||||
|
supports_xla=False
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
p.device.type == step.device.type
|
||||||
|
and p.device.type in capturable_supported_devices
|
||||||
|
for p, step in zip(params, state_steps)
|
||||||
|
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||||
|
|
||||||
|
assert grad_scale is None and found_inf is None
|
||||||
|
|
||||||
|
assert not differentiable, "_foreach ops don't support autograd"
|
||||||
|
|
||||||
|
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
|
||||||
|
[params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
|
||||||
|
)
|
||||||
|
for (
|
||||||
|
device_params_,
|
||||||
|
device_grads_,
|
||||||
|
device_exp_avgs_,
|
||||||
|
device_exp_avg_sqs_,
|
||||||
|
device_state_steps_,
|
||||||
|
), _ in grouped_tensors.values():
|
||||||
|
device_params = cast(List[Tensor], device_params_)
|
||||||
|
device_grads = cast(List[Tensor], device_grads_)
|
||||||
|
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
|
||||||
|
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
|
||||||
|
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||||
|
|
||||||
|
# Handle complex parameters
|
||||||
|
if has_complex:
|
||||||
|
_view_as_real(
|
||||||
|
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
|
||||||
|
)
|
||||||
|
|
||||||
|
if maximize:
|
||||||
|
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Update steps
|
||||||
|
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||||
|
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||||
|
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||||
|
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||||
|
torch._foreach_add_(
|
||||||
|
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch._foreach_add_(device_state_steps, 1)
|
||||||
|
|
||||||
|
if weight_decay != 0:
|
||||||
|
if decoupled:
|
||||||
|
torch._foreach_add_(
|
||||||
|
device_params, device_params, alpha=-lr * weight_decay
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||||
|
if maximize:
|
||||||
|
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||||
|
else:
|
||||||
|
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||||
|
device_grads, device_params, alpha=weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_state_steps[0] == 1:
|
||||||
|
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||||
|
continue
|
||||||
|
|
||||||
|
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||||
|
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
|
||||||
|
|
||||||
|
if device_state_steps[0] == 2:
|
||||||
|
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
|
||||||
|
else:
|
||||||
|
torch._foreach_mul_(device_exp_avgs, beta1)
|
||||||
|
torch._foreach_addcdiv_(
|
||||||
|
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
|
||||||
|
)
|
||||||
|
|
||||||
|
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||||
|
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||||
|
torch._foreach_addcmul_(
|
||||||
|
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
||||||
|
def adopt(
|
||||||
|
params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[Tensor],
|
||||||
|
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||||
|
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||||
|
foreach: Optional[bool] = None,
|
||||||
|
capturable: bool = False,
|
||||||
|
differentiable: bool = False,
|
||||||
|
fused: Optional[bool] = None,
|
||||||
|
grad_scale: Optional[Tensor] = None,
|
||||||
|
found_inf: Optional[Tensor] = None,
|
||||||
|
has_complex: bool = False,
|
||||||
|
*,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: Union[float, Tensor],
|
||||||
|
weight_decay: float,
|
||||||
|
decoupled: bool,
|
||||||
|
eps: float,
|
||||||
|
maximize: bool,
|
||||||
|
):
|
||||||
|
r"""Functional API that performs ADOPT algorithm computation."""
|
||||||
|
# Respect when the user inputs False/True for foreach or fused. We only want to change
|
||||||
|
# the default when neither have been user-specified. Note that we default to foreach
|
||||||
|
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
|
||||||
|
# bake-in time before making it the default, even if it is typically faster.
|
||||||
|
if fused is None and foreach is None:
|
||||||
|
_, foreach = _default_to_fused_or_foreach(
|
||||||
|
params, differentiable, use_fused=False
|
||||||
|
)
|
||||||
|
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
|
||||||
|
if foreach and isinstance(lr, Tensor) and not capturable:
|
||||||
|
foreach = False
|
||||||
|
if fused is None:
|
||||||
|
fused = False
|
||||||
|
if foreach is None:
|
||||||
|
foreach = False
|
||||||
|
|
||||||
|
# this check is slow during compilation, so we skip it
|
||||||
|
# if it's strictly needed we can add this check back in dynamo
|
||||||
|
if not torch._utils.is_compiling() and not all(
|
||||||
|
isinstance(t, torch.Tensor) for t in state_steps
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
"API has changed, `state_steps` argument must contain a list of singleton tensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
if foreach and torch.jit.is_scripting():
|
||||||
|
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
|
||||||
|
if fused and torch.jit.is_scripting():
|
||||||
|
raise RuntimeError("torch.jit.script not supported with fused optimizers")
|
||||||
|
|
||||||
|
# if fused and not torch.jit.is_scripting():
|
||||||
|
# func = _fused_adopt
|
||||||
|
# elif foreach and not torch.jit.is_scripting():
|
||||||
|
if foreach and not torch.jit.is_scripting():
|
||||||
|
func = _multi_tensor_adopt
|
||||||
|
else:
|
||||||
|
func = _single_tensor_adopt
|
||||||
|
|
||||||
|
func(
|
||||||
|
params,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
has_complex=has_complex,
|
||||||
|
beta1=beta1,
|
||||||
|
beta2=beta2,
|
||||||
|
lr=lr,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
decoupled=decoupled,
|
||||||
|
eps=eps,
|
||||||
|
maximize=maximize,
|
||||||
|
capturable=capturable,
|
||||||
|
differentiable=differentiable,
|
||||||
|
grad_scale=grad_scale,
|
||||||
|
found_inf=found_inf,
|
||||||
|
)
|
||||||
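The new optimizer differs from Adam in the order of operations: the second-moment estimate from the previous step normalizes the gradient before the momentum update, which is why step 1 above only seeds exp_avg_sq and the first parameter update happens at step 2. A minimal usage sketch of the class defined above, assuming the module path from this diff; the hyperparameters shown are illustrative, not axolotl defaults:

import torch

from axolotl.utils.optimizers.adopt import ADOPT

model = torch.nn.Linear(16, 4)
optimizer = ADOPT(model.parameters(), lr=1e-3, betas=(0.9, 0.9999), decoupled=True)

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()  # step 1 only seeds exp_avg_sq; parameter updates begin at step 2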
@@ -1,250 +0,0 @@
from typing import Optional

import torch
from torch import Tensor
from torch.distributed._tensor import DTensor
from torch.optim import Optimizer
from torchao.prototype.low_bit_optim.subclass_4bit import OptimState4bit
from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8


class _ShampooBase(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size,
        quantization_bits,
        optimizer_state_class,
    ):
        if lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if eps < 0.0:
            raise ValueError(f"Invalid eps value: {eps}")
        if update_freq < 1:
            raise ValueError(f"Invalid update_freq value: {update_freq}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            eps=eps,
            update_freq=update_freq,
        )
        super().__init__(params, defaults)
        self.block_size = block_size
        self.quantization_bits = quantization_bits
        self.optimizer_state_class = optimizer_state_class

    def step(self, closure: Optional[callable] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["momentum_buffer"] = self._new_buffer(grad, True)
                    state["preconds"] = []
                    state["inv_preconds"] = []
                    for dim in grad.size():
                        state["preconds"].append(
                            self.optimizer_state_class.zeros(
                                (dim, dim),
                                signed=False,
                                block_size=self.block_size,
                                device=grad.device,
                            )
                        )
                        state["inv_preconds"].append(
                            torch.zeros((dim, dim), device=grad.device)
                        )

                state["step"] += 1
                beta = group["momentum"]
                weight_decay = group["weight_decay"]
                lr = group["lr"]
                eps = group["eps"]
                update_freq = group["update_freq"]

                # Apply momentum
                if beta > 0:
                    state["momentum_buffer"].mul_(beta).add_(grad, alpha=1 - beta)
                    grad = state["momentum_buffer"]

                # Apply weight decay
                if weight_decay > 0:
                    grad = grad.add(p.data, alpha=weight_decay)

                # Preconditioning
                order = grad.ndimension()
                original_size = grad.size()
                for dim_id, dim in enumerate(grad.size()):
                    precond = state["preconds"][dim_id]
                    inv_precond = state["inv_preconds"][dim_id]

                    # Reshape grad
                    grad = grad.transpose(0, dim_id).contiguous()
                    transposed_size = grad.size()
                    grad = grad.view(dim, -1)

                    grad_t = grad.t()

                    # Update preconditioner
                    precond_fp32 = precond.dequantize()
                    precond_update = grad @ grad_t
                    precond_fp32.add_(precond_update)

                    # Quantize preconditioner back
                    precond.copy_(precond_fp32)

                    # Update inverse preconditioner
                    if state["step"] % update_freq == 0:
                        inv_precond.copy_(
                            self._compute_inv_precond(precond_fp32, eps, order)
                        )

                    # Precondition grad
                    if dim_id == order - 1:
                        # Last dimension
                        grad = grad_t @ inv_precond
                        grad = grad.view(original_size)
                    else:
                        grad = inv_precond @ grad
                        grad = grad.view(transposed_size)

                # Update parameter
                p.data.add_(grad, alpha=-lr)

        return loss

    def _compute_inv_precond(self, precond: Tensor, eps: float, order: int):
        # Add eps for numerical stability
        precond = precond + torch.eye(precond.size(0), device=precond.device) * eps

        # Compute matrix power
        inv_precond = self._matrix_power(precond, -1.0 / (2 * order))

        return inv_precond

    def _matrix_power(self, matrix: Tensor, power: float) -> Tensor:
        # Compute matrix power using SVD
        u, s, v = torch.svd(matrix)
        s_pow = s.pow(power)
        return u @ torch.diag(s_pow) @ v.t()

    # bring your own function to create zero-filled subclass
    @staticmethod
    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
        raise NotImplementedError

    # follow bitsandbytes, only quantize tensors >= 4096 values
    # also wrap subclass in DTensor when needed
    def _new_buffer(self, p: Tensor, signed: bool):
        if p.numel() >= 4096 and p.numel() % self.block_size == 0:
            if isinstance(p, DTensor):
                out = DTensor.from_local(
                    local_tensor=self._subclass_zeros(
                        p.to_local(), signed, self.block_size
                    ),
                    device_mesh=p.device_mesh,
                    placements=p.placements,
                    run_check=False,
                )
            else:
                out = self._subclass_zeros(p, signed, self.block_size)
        else:
            out = torch.zeros_like(p)
        return out


class Shampoo8bit(_ShampooBase):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size=256,
    ):
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            eps,
            update_freq,
            block_size=block_size,
            quantization_bits=8,
            optimizer_state_class=OptimState8bit,
        )


class Shampoo4bit(_ShampooBase):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size=128,
    ):
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            eps,
            update_freq,
            block_size=block_size,
            quantization_bits=4,
            optimizer_state_class=OptimState4bit,
        )


class ShampooFp8(_ShampooBase):
    def __init__(
        self,
        params,
        lr=1e-1,
        momentum=0.0,
        weight_decay=0.0,
        eps=1e-4,
        update_freq=1,
        *,
        block_size=256,
    ):
        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            eps,
            update_freq,
            block_size=block_size,
            quantization_bits=8,  # FP8 uses 8 bits
            optimizer_state_class=OptimStateFp8,
        )
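The deleted optimizer implemented low-bit Shampoo: each gradient dimension keeps a Kronecker-factor preconditioner G @ G.T in a quantized torchao state subclass, and `_matrix_power` takes its inverse (2 * order)-th root via SVD. A standalone sketch of that inverse-root step for a 2-D weight, using torch.linalg.svd (the non-deprecated form of the torch.svd call above):

import torch

def matrix_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
    # an eigendecomposition via SVD; valid here because G @ G.T is symmetric PSD
    u, s, vh = torch.linalg.svd(matrix, full_matrices=False)
    return u @ torch.diag(s.pow(power)) @ vh

grad = torch.randn(32, 64)
order = grad.ndim  # 2 for a weight matrix
precond = grad @ grad.t() + 1e-4 * torch.eye(32)  # eps * I for numerical stability
inv_precond = matrix_power_svd(precond, -1.0 / (2 * order))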
@@ -1,8 +1,6 @@
 """Module for tokenization utilities"""

 import logging
-import re
-from typing import Dict, List

 from termcolor import colored
@@ -93,65 +91,3 @@ def check_rl_example_labels(example, tokenizer, text_only=False):
     LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")

     return delimiter.join(colored_tokens)
-
-
-GLAIVE_ROLES = ["USER", "ASSISTANT", "FUNCTION RESPONSE"]
-GLAIVE_TO_SHAREGPT_ROLE = {
-    "SYSTEM": "system",
-    "USER": "human",
-    "ASSISTANT": "gpt",
-    "FUNCTION RESPONSE": "tool",
-}
-
-GLAIVE_MSG_REGEX = re.compile(rf"({'|'.join(GLAIVE_ROLES)}): ")
-
-
-def chatml_to_conversation(row: Dict[str, str]) -> List[Dict[str, str]]:
-    """
-    Converts a ChatML formatted row to a list of messages in ShareGPT format.
-    Initially based off https://github.com/lilacai/lilac/blob/main/notebooks/GlaiveToShareGPT.ipynb.
-    """
-
-    system_prompt = row.get("system")
-    if system_prompt:
-        system_prompt = system_prompt.removeprefix("SYSTEM: ")
-
-    chat_str = row["chat"]
-    chat_msgs = [s.strip() for s in GLAIVE_MSG_REGEX.split(chat_str) if s]
-
-    chat_msg_dicts = [
-        {"from": GLAIVE_TO_SHAREGPT_ROLE[role], "value": value}
-        for role, value in zip(chat_msgs[::2], chat_msgs[1::2])
-    ]
-
-    if system_prompt:
-        chat_msg_dicts = [
-            {"from": GLAIVE_TO_SHAREGPT_ROLE["SYSTEM"], "value": system_prompt}
-        ] + chat_msg_dicts
-
-    return chat_msg_dicts
-
-
-def merge_consecutive_messages(messages):
-    """
-    Merge consecutive messages from the same sender into a single message.
-    This can be useful with datasets that contain multiple consecutive tool calls.
-    """
-
-    merged_messages = []
-    current_from = None
-    current_message = ""
-
-    for msg in messages:
-        if current_from == msg["from"]:
-            current_message += msg["value"]
-        else:
-            if current_from is not None:
-                merged_messages.append({"from": current_from, "value": current_message})
-            current_from = msg["from"]
-            current_message = msg["value"]
-
-    if current_from is not None:
-        merged_messages.append({"from": current_from, "value": current_message})
-
-    return merged_messages
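For reference, the removed `chatml_to_conversation` split a Glaive transcript on the role prefixes in GLAIVE_MSG_REGEX and mapped roles through GLAIVE_TO_SHAREGPT_ROLE. On a made-up row it behaved like this:

row = {
    "system": "SYSTEM: You are a helpful assistant.",
    "chat": "USER: hi ASSISTANT: hello",
}
# chatml_to_conversation(row) returned ShareGPT-style messages:
# [
#     {"from": "system", "value": "You are a helpful assistant."},
#     {"from": "human", "value": "hi"},
#     {"from": "gpt", "value": "hello"},
# ]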
@@ -16,7 +16,11 @@ from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available

 from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
+    patch_training_loop_for_fsdp_grad_accum,
+)
 from axolotl.utils.distributed import reduce_and_broadcast
+from axolotl.utils.environment import check_cuda_p2p_ib_support
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

 LOG = get_logger("axolotl")
@@ -184,11 +188,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
         min_sequence_len=cfg.min_sample_len or 2,
     )

-    if cfg.is_preprocess:
-        min_input_len = np.min(get_dataset_lengths(train_dataset))
-        LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-        max_input_len = np.max(get_dataset_lengths(train_dataset))
-        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+    min_input_len = np.min(get_dataset_lengths(train_dataset))
+    LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
+    max_input_len = np.max(get_dataset_lengths(train_dataset))
+    LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

     if cfg.model_config_type == "mamba":
         LOG.info("dropping attention_mask column")
@@ -461,6 +464,9 @@ def setup_fsdp_envs(cfg):


 def prepare_optim_env(cfg):
+    if not check_cuda_p2p_ib_support():
+        if os.getenv("NCCL_P2P_DISABLE") is None:
+            os.environ["NCCL_P2P_DISABLE"] = "1"
     if cfg.fsdp:
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed:
@@ -490,6 +496,11 @@ def prepare_opinionated_env(cfg):
 def setup_trainer(
     cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
 ):
+    if cfg.fsdp:
+        try:
+            patch_training_loop_for_fsdp_grad_accum()
+        except AssertionError:
+            pass
     if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
         trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
         trainer_builder.model_ref = model[1]
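The `prepare_optim_env` change only sets NCCL_P2P_DISABLE when peer-to-peer/IB support is missing and the user has not already chosen a value. A compact sketch of the same idempotent-default pattern, with a boolean standing in for `check_cuda_p2p_ib_support()`:

import os

def disable_nccl_p2p_if_unsupported(p2p_supported: bool) -> None:
    # stand-in for the check_cuda_p2p_ib_support() guard above
    if not p2p_supported:
        os.environ.setdefault("NCCL_P2P_DISABLE", "1")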
16 tests/e2e/conftest.py Normal file
@@ -0,0 +1,16 @@
"""
shared pytest fixtures
"""
import shutil
import tempfile

import pytest


@pytest.fixture
def temp_dir():
    # Create a temporary directory
    _temp_dir = tempfile.mkdtemp()
    yield _temp_dir
    # Clean up the directory after the test
    shutil.rmtree(_temp_dir)
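This fixture replaces the `@with_temp_dir` decorator the e2e tests used previously: pytest injects the fixture by argument name, which is what lets the parametrized tests in the diffs below drop the decorator. A hypothetical consumer:

def test_writes_output(temp_dir):
    # pytest supplies temp_dir from the fixture above; no decorator needed
    from pathlib import Path

    (Path(temp_dir) / "marker.txt").write_text("ok")
    assert (Path(temp_dir) / "marker.txt").exists()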
@@ -3,28 +3,25 @@ E2E tests for multigpu eval
 """
 import logging
 import os
-import unittest
 from pathlib import Path

 import yaml
 from accelerate.test_utils import execute_subprocess_async
+from transformers.testing_utils import get_torch_dist_unique_port

 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
-
 LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
 os.environ["WANDB_DISABLED"] = "true"

 AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent


-class TestMultiGPUEval(unittest.TestCase):
+class TestMultiGPUEval:
     """
     Test case for MultiGPU Eval Sample Packing
     """

-    @with_temp_dir
     def test_eval_sample_packing(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -83,13 +80,14 @@ class TestMultiGPUEval(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
         ]
     )

-    @with_temp_dir
     def test_eval(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -148,6 +146,8 @@ class TestMultiGPUEval(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
@@ -4,17 +4,17 @@ E2E tests for multigpu lora tinyllama

 import logging
 import os
-import unittest
 from pathlib import Path

 import pytest
 import yaml
 from accelerate.test_utils import execute_subprocess_async
 from huggingface_hub import snapshot_download
+from transformers.testing_utils import get_torch_dist_unique_port

 from axolotl.utils.dict import DictDefault

-from ..utils import is_hopper, with_temp_dir
+from ..utils import is_hopper

 LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
 os.environ["WANDB_DISABLED"] = "true"
@@ -28,18 +28,16 @@ def download_model():
     snapshot_download("TinyLlama/TinyLlama_v1.1")


-class TestMultiGPULlama(unittest.TestCase):
+class TestMultiGPULlama:
     """
     Test case for Llama models using LoRA
     """

-    @with_temp_dir
     def test_lora_ddp(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "TinyLlama/TinyLlama_v1.1",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM-135M",
                 "sequence_len": 2048,
                 "adapter": "lora",
                 "lora_r": 8,
@@ -48,9 +46,7 @@ class TestMultiGPULlama(unittest.TestCase):
                 "lora_target_linear": True,
                 "val_set_size": 0.05,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
                 "datasets": [
                     {
@@ -81,19 +77,23 @@ class TestMultiGPULlama(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
         ]
     )

-    @with_temp_dir
-    def test_lora_ddp_packed(self, temp_dir):
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 4],
+    )
+    def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "TinyLlama/TinyLlama_v1.1",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM-135M",
                 "sequence_len": 2048,
                 "sample_packing": True,
                 "eval_sample_packing": False,
@@ -105,9 +105,7 @@ class TestMultiGPULlama(unittest.TestCase):
                 "lora_target_linear": True,
                 "val_set_size": 0.05,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
                 "datasets": [
                     {
@@ -118,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
                 "num_epochs": 1,
                 "max_steps": 15,
                 "micro_batch_size": 4,
-                "gradient_accumulation_steps": 4,
+                "gradient_accumulation_steps": gradient_accumulation_steps,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_8bit",
@@ -138,6 +136,8 @@ class TestMultiGPULlama(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
@@ -145,7 +145,6 @@ class TestMultiGPULlama(unittest.TestCase):
     )

     @pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
-    @with_temp_dir
     def test_dpo_lora_ddp(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -210,13 +209,14 @@ class TestMultiGPULlama(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
         ]
     )

-    @with_temp_dir
     def test_dpo_qlora_ddp(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -278,25 +278,27 @@ class TestMultiGPULlama(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
         ]
     )

-    @with_temp_dir
-    def test_fsdp(self, temp_dir):
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 4],
+    )
+    def test_fsdp(self, temp_dir, gradient_accumulation_steps):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "TinyLlama/TinyLlama_v1.1",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM-135M",
                 "sequence_len": 2048,
-                "val_set_size": 0.05,
+                "val_set_size": 0.01,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
                 "datasets": [
                     {
@@ -305,9 +307,9 @@ class TestMultiGPULlama(unittest.TestCase):
                    },
                ],
                "num_epochs": 1,
-               "max_steps": 15,
+               "max_steps": 10,
                "micro_batch_size": 4,
-               "gradient_accumulation_steps": 4,
+               "gradient_accumulation_steps": gradient_accumulation_steps,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
@@ -324,7 +326,7 @@ class TestMultiGPULlama(unittest.TestCase):
                     "fsdp_use_orig_params": False,
                     "fsdp_cpu_ram_efficient_loading": False,
                     "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
-                    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
+                    "fsdp_state_dict_type": "FULL_STATE_DICT",
                     "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                 },
             }
@@ -341,28 +343,29 @@ class TestMultiGPULlama(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
         ]
     )

-    @with_temp_dir
-    def test_fsdp_packed(self, temp_dir):
+    @pytest.mark.parametrize(
+        "fsdp_state_dict_type",
+        ["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
+    )
+    def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "TinyLlama/TinyLlama_v1.1",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM-135M",
                 "sample_packing": True,
-                "eval_sample_packing": False,
                 "pad_to_sequence_len": True,
                 "sequence_len": 2048,
                 "val_set_size": 0.05,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
                 "datasets": [
                     {
@@ -390,7 +393,7 @@ class TestMultiGPULlama(unittest.TestCase):
                     "fsdp_use_orig_params": False,
                     "fsdp_cpu_ram_efficient_loading": False,
                     "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
-                    "fsdp_state_dict_type": "SHARDED_STATE_DICT",
+                    "fsdp_state_dict_type": fsdp_state_dict_type,
                     "fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
                 },
             }
@@ -407,13 +410,14 @@ class TestMultiGPULlama(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
         ]
     )

-    @with_temp_dir
     def test_fsdp_qlora_prequant_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -483,28 +487,29 @@ class TestMultiGPULlama(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
         ]
     )

-    @with_temp_dir
-    def test_ds_zero3_packed(self, temp_dir):
+    @pytest.mark.parametrize(
+        "gradient_accumulation_steps",
+        [1, 4],
+    )
+    def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "TinyLlama/TinyLlama_v1.1",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM-135M",
                 "sample_packing": True,
-                "eval_sample_packing": False,
                 "pad_to_sequence_len": True,
                 "sequence_len": 2048,
                 "val_set_size": 0.05,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
                 "datasets": [
                     {
@@ -515,7 +520,7 @@ class TestMultiGPULlama(unittest.TestCase):
                 "num_epochs": 1,
                 "max_steps": 15,
                 "micro_batch_size": 4,
-                "gradient_accumulation_steps": 4,
+                "gradient_accumulation_steps": gradient_accumulation_steps,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
@@ -536,19 +541,19 @@ class TestMultiGPULlama(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
         ]
     )

-    @with_temp_dir
     def test_ds_zero3_qlora_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "TinyLlama/TinyLlama_v1.1",
-                "tokenizer_type": "LlamaTokenizer",
+                "base_model": "HuggingFaceTB/SmolLM-135M",
                 "load_in_4bit": True,
                 "adapter": "qlora",
                 "lora_r": 8,
@@ -561,9 +566,7 @@ class TestMultiGPULlama(unittest.TestCase):
                 "sequence_len": 2048,
                 "val_set_size": 0.05,
                 "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
+                    "pad_token": "<|endoftext|>",
                 },
                 "datasets": [
                     {
@@ -595,6 +598,8 @@ class TestMultiGPULlama(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
@@ -4,31 +4,30 @@ E2E tests for multigpu qwen2

 import logging
 import os
-import unittest
 from pathlib import Path

+import pytest
 import yaml
 from accelerate.test_utils import execute_subprocess_async
+from transformers.testing_utils import get_torch_dist_unique_port

 from axolotl.utils.dict import DictDefault

-from ..utils import with_temp_dir
-
 LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
 os.environ["WANDB_DISABLED"] = "true"


-class TestMultiGPUQwen2(unittest.TestCase):
+class TestMultiGPUQwen2:
     """
     Test case for Llama models using LoRA
     """

-    @with_temp_dir
-    def test_qlora_fsdp_dpo(self, temp_dir):
+    @pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
+    def test_qlora_fsdp_dpo(self, base_model, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "Qwen/Qwen2-1.5B",
+                "base_model": base_model,
                 "load_in_4bit": True,
                 "rl": "dpo",
                 "chat_template": "chatml",
@@ -47,9 +46,9 @@ class TestMultiGPUQwen2(unittest.TestCase):
                 },
             ],
             "num_epochs": 1,
-            "max_steps": 15,
+            "max_steps": 5,
             "warmup_steps": 20,
-            "micro_batch_size": 4,
+            "micro_batch_size": 2,
             "gradient_accumulation_steps": 2,
             "output_dir": temp_dir,
             "learning_rate": 0.00001,
@@ -91,6 +90,8 @@ class TestMultiGPUQwen2(unittest.TestCase):
             "launch",
             "--num-processes",
             "2",
+            "--main_process_port",
+            f"{get_torch_dist_unique_port()}",
             "-m",
             "axolotl.cli.train",
             str(Path(temp_dir) / "config.yaml"),
15 tests/e2e/patched/test_trainer_fsdp.py Normal file
@@ -0,0 +1,15 @@
"""Test module for checking that the FSDP gradient accumulation monkeypatch still applies to Hugging Face transformers."""
import unittest

from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable


class TestTrainerFSDPIntegration(unittest.TestCase):
    """Trainer FSDP grad-accum monkeypatch integration tests."""

    def test_train_loop_patchable(self):
        # ensures the current version of transformers has loss code that matches our patching code
        self.assertTrue(
            check_training_loop_is_patchable(),
            "HF transformers _inner_training_loop has changed and isn't patchable",
        )
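`check_training_loop_is_patchable` guards against upstream drift: the monkeypatch rewrites transformers' `_inner_training_loop`, so this test fails fast when a transformers upgrade changes that code. A sketch of how such a check is commonly written, with an illustrative anchor string (not axolotl's actual implementation):

import inspect

from transformers.trainer import Trainer

EXPECTED_SNIPPET = "if self.control.should_epoch_stop"  # hypothetical anchor text

def training_loop_is_patchable() -> bool:
    # if the anchor no longer appears, a source-level patch would misapply
    source = inspect.getsource(Trainer._inner_training_loop)
    return EXPECTED_SNIPPET in source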
@@ -115,6 +115,51 @@ class TestDPOLlamaLora(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

+    @with_temp_dir
+    def test_dpo_use_weighting(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "dpo",
+                "dpo_use_weighting": True,
+                "datasets": [
+                    {
+                        "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
+                        "type": "chatml.ultra",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
     @pytest.mark.skip("kto_pair no longer supported in trl")
     @with_temp_dir
     def test_kto_pair_lora(self, temp_dir):
66 tests/e2e/test_llama.py Normal file
@@ -0,0 +1,66 @@
"""
E2E tests for llama
"""

import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from .utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestLlama(unittest.TestCase):
    """
    Test case for Llama models
    """

    @with_temp_dir
    def test_fft_trust_remote_code(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "trust_remote_code": True,
                "sequence_len": 512,
                "val_set_size": 0.1,
                "special_tokens": {
                    "unk_token": "<unk>",
                    "bos_token": "<s>",
                    "eos_token": "</s>",
                },
                "datasets": [
                    {
                        "path": "mhenrichsen/alpaca_2k_test",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 1,
                "micro_batch_size": 8,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_bnb_8bit",
                "lr_scheduler": "cosine",
                "flash_attention": True,
                "sample_packing": True,
                "bf16": True,
                "save_safetensors": True,
            }
        )
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "model.safetensors").exists()
@@ -13,7 +13,7 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
-from .utils import with_temp_dir
+from .utils import require_torch_2_5_1, with_temp_dir
 
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,3 +65,80 @@ class TestCustomOptimizers(unittest.TestCase):
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    @require_torch_2_5_1
+    def test_adopt_adamw(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adopt_adamw",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_fft_schedule_free_adamw(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM-135M",
+                "sequence_len": 1024,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "schedule_free_adamw",
+                "lr_scheduler": "constant",
+                "save_safetensors": True,
+            }
+        )
+        # pylint: disable=duplicate-code
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()
85 tests/e2e/test_qwen.py Normal file
@@ -0,0 +1,85 @@
+"""
+E2E tests for qwen
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import pytest
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+from transformers.testing_utils import get_torch_dist_unique_port
+
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.tests.qwen")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestE2eQwen:
+    """
+    Test cases for qwen models
+    """
+
+    @pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
+    def test_dpo(self, base_model, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": base_model,
+                "rl": "dpo",
+                "chat_template": "qwen_25",
+                "sequence_len": 2048,
+                "val_set_size": 0.0,
+                "datasets": [
+                    {
+                        "path": "fozziethebeat/alpaca_messages_2k_dpo_test",
+                        "split": "train",
+                        "type": "chat_template.default",
+                        "field_messages": "conversation",
+                        "field_chosen": "chosen",
+                        "field_rejected": "rejected",
+                        "message_field_role": "role",
+                        "message_field_content": "content",
+                        "roles": {
+                            "system": ["system"],
+                            "user": ["user"],
+                            "assistant": ["assistant"],
+                        },
+                    },
+                ],
+                "num_epochs": 1,
+                "max_steps": 5,
+                "warmup_steps": 20,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 2,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "flash_attention": True,
+                "bf16": "auto",
+                "tf32": True,
+                "gradient_checkpointing": True,
+            }
+        )
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "accelerate",
+                "launch",
+                "--num-processes",
+                "2",
+                "--main_process_port",
+                f"{get_torch_dist_unique_port()}",
+                "-m",
+                "axolotl.cli.train",
+                str(Path(temp_dir) / "config.yaml"),
+            ]
+        )
@@ -6,11 +6,13 @@ import shutil
 import tempfile
 import unittest
 from functools import wraps
-from importlib.metadata import version
 from pathlib import Path
 
 import torch
 
+# from importlib.metadata import version
+from packaging import version
+
 
 def with_temp_dir(test_func):
     @wraps(test_func)
@@ -43,12 +45,24 @@ def require_torch_2_3_1(test_case):
     """
 
     def is_min_2_3_1():
-        torch_version = version("torch")
-        return torch_version >= "2.3.1"
+        torch_version = version.parse(torch.__version__)
+        return torch_version >= version.parse("2.3.1")
 
     return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
 
 
+def require_torch_2_5_1(test_case):
+    """
+    Decorator marking a test that requires torch >= 2.5.1
+    """
+
+    def is_min_2_5_1():
+        torch_version = version.parse(torch.__version__)
+        return torch_version >= version.parse("2.5.1")
+
+    return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
+
+
 def is_hopper():
     compute_capability = torch.cuda.get_device_capability()
     return compute_capability == (9, 0)
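Aside: the hunk above swaps raw-string version checks for packaging.version. Comparing version numbers as strings is lexicographic, so a check like `torch_version >= "2.3.1"` silently breaks once a component reaches two digits, and local-build suffixes such as "+cu124" need real parsing. A minimal sketch of the difference (values are illustrative, not taken from this diff):

# Sketch: string comparison of versions is lexicographic and wrong;
# packaging.version compares release components numerically.
from packaging import version

assert ("2.10.0" >= "2.3.1") is False  # lexicographic: '1' < '3' at the third char
assert version.parse("2.10.0") >= version.parse("2.3.1")  # numeric: 10 >= 3
assert version.parse("2.5.1+cu124") >= version.parse("2.5.1")  # local tags parse too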
@@ -1,500 +0,0 @@
-"""
-Test module for sharegpt integration w chatml
-"""
-
-import pytest
-from datasets import Dataset
-from tokenizers import AddedToken
-from transformers import AutoTokenizer
-
-from axolotl.datasets import TokenizedPromptDataset
-from axolotl.prompt_strategies.sharegpt import (
-    GlaiveShareGPTPromptTokenizingStrategy,
-    SimpleShareGPTPromptTokenizingStrategy,
-    register_chatml_template,
-    register_llama3_template,
-)
-from axolotl.prompters import ShareGPTPrompterV2
-
-register_chatml_template()
-register_llama3_template()
-
-
-@pytest.fixture(name="sharegpt_dataset")
-def fixture_sharegpt_dataset():
-    return Dataset.from_list(
-        [
-            {
-                "conversations": [
-                    {
-                        "from": "system",
-                        "value": "repeat",
-                    },
-                    {
-                        "from": "human",
-                        "value": "hello",
-                    },
-                    {
-                        "from": "gpt",
-                        "value": "hello",
-                    },
-                    {
-                        "from": "human",
-                        "value": "goodbye",
-                    },
-                    {
-                        "from": "gpt",
-                        "value": "goodbye",
-                    },
-                ]
-            }
-        ]
-    )
-
-
-@pytest.fixture(name="sharegpt_dataset_with_weights")
-def fixture_sharegpt_dataset_with_weights():
-    return Dataset.from_list(
-        [
-            {
-                "conversations": [
-                    {
-                        "from": "system",
-                        "value": "repeat",
-                    },
-                    {
-                        "from": "human",
-                        "value": "hello",
-                        "weight": 1,
-                    },
-                    {
-                        "from": "gpt",
-                        "value": "hello",
-                        "weight": 0,
-                    },
-                    {
-                        "from": "human",
-                        "value": "rehello",
-                        "weight": 0,
-                    },
-                    {
-                        "from": "gpt",
-                        "value": "rehello",
-                        "weight": 1,
-                    },
-                    {
-                        "from": "human",
-                        "value": "goodbye",
-                    },
-                    {
-                        "from": "gpt",
-                        "value": "goodbye",
-                        "weight": 0,
-                    },
-                ]
-            }
-        ]
-    )
-
-
-@pytest.fixture(name="glaive_dataset")
-def fixture_sharegpt_glaive_dataset():
-    return Dataset.from_list(
-        [
-            {
-                "system": "SYSTEM: This is a system prompt",
-                "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
-            }
-        ]
-    )
-
-
-@pytest.fixture(name="multi_role_dataset")
-def fixture_multi_role_dataset():
-    return Dataset.from_list(
-        [
-            {
-                "conversations": [
-                    {
-                        "from": "system",
-                        "value": "use get_weather(city) to get the weather for a city",
-                    },
-                    {
-                        "from": "human",
-                        "value": "hello, what's the weather in New York?",
-                    },
-                    {
-                        "from": "gpt",
-                        "value": "let me get that for you",
-                    },
-                    {
-                        "from": "tool",
-                        "value": "get_weather(New York)",
-                    },
-                    {
-                        "from": "gpt",
-                        "value": "the weather in New York is 70 degrees and sunny",
-                    },
-                ]
-            }
-        ]
-    )
-
-
-@pytest.fixture(name="tokenizer")
-def fixture_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained(
-        "casperhansen/mistral-7b-instruct-v0.1-awq"
-    )
-    tokenizer.add_special_tokens(
-        {
-            "eos_token": AddedToken(
-                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
-            )
-        }
-    )
-    tokenizer.add_tokens(
-        [
-            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
-        ]
-    )
-
-    return tokenizer
-
-
-@pytest.fixture(name="llama3_tokenizer")
-def fixture_llama3_tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
-    tokenizer.eos_token = "<|eot_id|>"
-
-    return tokenizer
-
-
-class TestSharegptLlama3:
-    """Test class for ShareGPT style datasets with llama-3 prompts"""
-
-    def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
-        strategy = SimpleShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(
-                conversation="llama3",
-                role_key_model=None,
-                role_key_human=None,
-            ),
-            llama3_tokenizer,
-            False,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, sharegpt_dataset, process_count=1
-        )
-
-        input_ids = dataset_wrapper[0]["input_ids"]
-
-        # fmt: off
-        # pylint: disable=duplicate-code
-        assert input_ids == [
-            128000,  # bos
-            128006, 9125, 128007,  # system header
-            271, 31724, 128009,  # sys prompt, eot
-            128006, 882, 128007,  # user header
-            271, 15339, 128009,  # user prompt eot
-            128006, 78191, 128007,  # assistant header
-            271, 15339, 128009,  # assistant response eot
-            128006, 882, 128007,
-            271, 19045, 29474, 128009,
-            128006, 78191, 128007,
-            271, 19045, 29474, 128009,
-        ]
-        # fmt: on
-
-    def test_tokenization_with_weights(
-        self, sharegpt_dataset_with_weights, llama3_tokenizer
-    ):
-        strategy = SimpleShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(
-                conversation="llama3",
-                role_key_model=None,
-                role_key_human=None,
-            ),
-            llama3_tokenizer,
-            False,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, sharegpt_dataset_with_weights, process_count=1
-        )
-
-        input_ids = dataset_wrapper[0]["input_ids"]
-
-        # fmt: off
-        # pylint: disable=duplicate-code
-        assert input_ids == [
-            128000,  # bos
-            128006, 9125, 128007,  # system header
-            271, 31724, 128009,  # sys prompt, eot
-            128006, 882, 128007,  # user header
-            271, 15339, 128009,  # user prompt eot
-            128006, 78191, 128007,  # assistant header
-            271, 15339, 128009,  # assistant response eot
-            128006, 882, 128007,
-            271, 11310, 4896, 128009,
-            128006, 78191, 128007,
-            271, 11310, 4896, 128009,
-            128006, 882, 128007,
-            271, 19045, 29474, 128009,
-            128006, 78191, 128007,
-            271, 19045, 29474, 128009,
-        ]
-        # fmt: on
-
-
-class TestSharegptChatML:
-    """
-    Test class for sharegpt prompter
-    """
-
-    def test_no_double_im_end(self, sharegpt_dataset, tokenizer):
-        strategy = SimpleShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(
-                conversation="chatml",
-                role_key_model=None,
-                role_key_human=None,
-            ),
-            tokenizer,
-            False,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, sharegpt_dataset, process_count=1
-        )
-
-        input_ids = dataset_wrapper[0]["input_ids"]
-        # fmt: off
-        assert input_ids == [
-            # 28705, 13, is " \n"
-            1,  # bos
-            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
-            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
-            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
-            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
-            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
-        ]
-        # fmt: on
-
-    def test_no_double_im_end_with_weights(
-        self, sharegpt_dataset_with_weights, tokenizer
-    ):
-        strategy = SimpleShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(
-                conversation="chatml",
-                role_key_model=None,
-                role_key_human=None,
-            ),
-            tokenizer,
-            False,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, sharegpt_dataset_with_weights, process_count=1
-        )
-
-        input_ids = dataset_wrapper[0]["input_ids"]
-        # fmt: off
-        assert input_ids == [
-            # 28705, 13, is " \n"
-            1,  # bos
-            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
-            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
-            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
-            32001, 2188, 13, 267, 21558, 32000, 28705, 13,  # human
-            32001, 13892, 13, 267, 21558, 32000, 28705, 13,  # gpt
-            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
-            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
-        ]
-        # fmt: on
-
-    def test_no_train_on_input(self, sharegpt_dataset, tokenizer):
-        strategy = SimpleShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(
-                conversation="chatml",
-                role_key_model=None,
-                role_key_human=None,
-            ),
-            tokenizer,
-            False,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, sharegpt_dataset, process_count=1
-        )
-
-        labels = dataset_wrapper[0]["labels"]
-        # fmt: off
-        assert labels == [
-            -100,  # bos
-            -100, -100, -100, -100, -100, -100, -100,  # system
-            -100, -100, -100, -100, -100, -100, -100,  # human
-            -100, -100, 13, 21558, 32000, 28705, 13,  # gpt
-            -100, -100, -100, -100, -100, -100, -100, -100,  # human
-            -100, -100, 13, 12684, 17664, 32000, 28705, 13,  # gpt
-        ]
-        # fmt: on
-
-    def test_no_train_on_input_with_weights(
-        self, sharegpt_dataset_with_weights, tokenizer
-    ):
-        strategy = SimpleShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(
-                conversation="chatml",
-                role_key_model=None,
-                role_key_human=None,
-            ),
-            tokenizer,
-            False,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, sharegpt_dataset_with_weights, process_count=1
-        )
-
-        labels = dataset_wrapper[0]["labels"]
-        # fmt: off
-        assert labels == [
-            -100,  # bos
-            -100, -100, -100, -100, -100, -100, -100,  # system
-            -100, -100, -100, -100, -100, -100, -100,  # human
-            -100, -100, -100, -100, -100, -100, -100,  # gpt with weight zero
-            -100, -100, -100, -100, -100, -100, -100, -100,  # human
-            -100, -100, 13, 267, 21558, 32000, 28705, 13,  # gpt
-            -100, -100, -100, -100, -100, -100, -100, -100,  # human
-            -100, -100, -100, -100, -100, -100, -100, -100  # gpt with weight zero
-        ]
-        # fmt: on
-
-    def test_w_train_on_input(self, sharegpt_dataset, tokenizer):
-        strategy = SimpleShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(
-                conversation="chatml",
-                role_key_model=None,
-                role_key_human=None,
-            ),
-            tokenizer,
-            True,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, sharegpt_dataset, process_count=1
-        )
-
-        labels = dataset_wrapper[0]["labels"]
-        # fmt: off
-        assert labels == [
-            1,  # bos
-            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
-            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
-            32001, 13892, 13, 21558, 32000, 28705, 13,  # gpt
-            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
-            32001, 13892, 13, 12684, 17664, 32000, 28705, 13,  # gpt
-        ]
-        # fmt: on
-
-    def test_w_train_on_input_with_weights(
-        self, sharegpt_dataset_with_weights, tokenizer
-    ):
-        strategy = SimpleShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(
-                conversation="chatml",
-                role_key_model=None,
-                role_key_human=None,
-            ),
-            tokenizer,
-            True,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, sharegpt_dataset_with_weights, process_count=1
-        )
-
-        labels = dataset_wrapper[0]["labels"]
-        # fmt: off
-        assert labels == [
-            1,  # bos
-            32001, 1587, 13, 25997, 32000, 28705, 13,  # system
-            32001, 2188, 13, 21558, 32000, 28705, 13,  # human
-            -100, -100, -100, -100, -100, -100, -100,  # gpt with weight 0
-            -100, -100, -100, -100, -100, -100, -100, -100,  # human with weight 0
-            32001, 13892, 13, 267, 21558, 32000, 28705, 13,  # gpt
-            32001, 2188, 13, 12684, 17664, 32000, 28705, 13,  # human
-            -100, -100, -100, -100, -100, -100, -100, -100  # gpt with weight 0
-        ]
-        # fmt: on
-
-    def test_chatml_glaive(self, glaive_dataset, tokenizer):
-        strategy = GlaiveShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(
-                conversation="chatml",
-                role_key_model=None,
-                role_key_human=None,
-            ),
-            tokenizer,
-            True,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, glaive_dataset, process_count=1
-        )
-
-        labels = dataset_wrapper[0]["labels"]
-        # fmt: off
-        assert labels == [
-            1,  # bos
-            32001, 1587, 13, 3260, 349, 264, 1587, 11510, 32000, 28705, 13,  # system
-            32001, 2188, 13, 6325, 368, 1820, 264, 9314, 354, 528, 477, 1450, 2726, 298, 4222, 28804, 32000, 28705, 13,  # human
-            32001, 13892, 13, 28737, 28742, 28719, 7371, 28725, 562, 315, 949, 28742, 28707, 506, 272, 21368, 298, 1820, 22447, 28723, 28705, 523, 28766, 416, 1009, 772, 28766, 28767, 32000, 28705, 13  # gpt
-        ]
-        # fmt: on
-
-    def test_multi_role_dataset(self, multi_role_dataset, tokenizer):
-        strategy = SimpleShareGPTPromptTokenizingStrategy(
-            ShareGPTPrompterV2(conversation="chatml", roles={"input": ["tool"]}),
-            tokenizer,
-            False,  # train_on_inputs
-            2048,  # sequence_len
-        )
-
-        dataset_wrapper = TokenizedPromptDataset(
-            strategy, multi_role_dataset, process_count=1
-        )
-
-        input_ids = dataset_wrapper[0]["input_ids"]
-        # fmt: off
-        assert input_ids == [
-            1,  # bos
-            32001, 1587, 13, 1730, 625, 28730, 769, 1223, 28732, 18373, 28731, 298, 625, 272, 8086, 354, 264, 2990, 32000, 28705, 13,  # system
-            32001, 2188, 13, 21558, 28725, 767, 28742, 28713, 272, 8086, 297, 1450, 2726, 28804, 32000, 28705, 13,  # human
-            32001, 13892, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
-            32001, 3921, 13, 527, 28730, 769, 1223, 28732, 2972, 2726, 28731, 32000, 28705, 13,  # tool
-            32001, 13892, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
-        ]
-        # fmt: on
-
-        labels = dataset_wrapper[0]["labels"]
-        # fmt: off
-        assert labels == [
-            -100,  # bos
-            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # system
-            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # human
-            -100, -100, 13, 895, 528, 625, 369, 354, 368, 32000, 28705, 13,  # gpt
-            -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,  # tool
-            -100, -100, 13, 1237, 8086, 297, 1450, 2726, 349, 28705, 28787, 28734, 11182, 304, 4376, 1780, 32000, 28705, 13  # gpt
-        ]
-        # fmt: on
@@ -371,44 +371,79 @@ class TestDatasetPreparation(unittest.TestCase):
     def test_load_local_hub_with_revision(self):
         """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
-            with tempfile.TemporaryDirectory() as tmp_dir2:
-                tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
-                tmp_ds_path.mkdir(parents=True, exist_ok=True)
-                snapshot_download(
-                    repo_id="mhenrichsen/alpaca_2k_test",
-                    repo_type="dataset",
-                    local_dir=tmp_ds_path,
-                    revision="d05c1cb",
-                )
+            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
+            tmp_ds_path.mkdir(parents=True, exist_ok=True)
+            snapshot_download(
+                repo_id="mhenrichsen/alpaca_2k_test",
+                repo_type="dataset",
+                local_dir=tmp_ds_path,
+                revision="d05c1cb",
+            )
 
             prepared_path = Path(tmp_dir) / "prepared"
             cfg = DictDefault(
                 {
                     "tokenizer_config": "huggyllama/llama-7b",
                     "sequence_len": 1024,
                     "datasets": [
                        {
                             "path": "mhenrichsen/alpaca_2k_test",
                             "ds_type": "parquet",
                             "type": "alpaca",
                             "data_files": [
                                 f"{tmp_ds_path}/alpaca_2000.parquet",
                             ],
                             "revision": "d05c1cb",
                         },
                     ],
                 }
             )
 
             dataset, _ = load_tokenized_prepared_datasets(
                 self.tokenizer, cfg, prepared_path
             )
 
             assert len(dataset) == 2000
             assert "input_ids" in dataset.features
             assert "attention_mask" in dataset.features
             assert "labels" in dataset.features
             shutil.rmtree(tmp_ds_path)
 
+    def test_loading_local_dataset_folder(self):
+        """Verify that a dataset downloaded to a local folder can be loaded"""
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
+            tmp_ds_path.mkdir(parents=True, exist_ok=True)
+            snapshot_download(
+                repo_id="mhenrichsen/alpaca_2k_test",
+                repo_type="dataset",
+                local_dir=tmp_ds_path,
+            )
+
+            prepared_path = Path(tmp_dir) / "prepared"
+            cfg = DictDefault(
+                {
+                    "tokenizer_config": "huggyllama/llama-7b",
+                    "sequence_len": 1024,
+                    "datasets": [
+                        {
+                            "path": str(tmp_ds_path),
+                            "type": "alpaca",
+                        },
+                    ],
+                }
+            )
+
+            dataset, _ = load_tokenized_prepared_datasets(
+                self.tokenizer, cfg, prepared_path
+            )
+
+            assert len(dataset) == 2000
+            assert "input_ids" in dataset.features
+            assert "attention_mask" in dataset.features
+            assert "labels" in dataset.features
+            shutil.rmtree(tmp_ds_path)
 
 
 if __name__ == "__main__":
@@ -39,12 +39,12 @@ class NormalizeConfigTestCase(unittest.TestCase):
                 "datasets": [
                     {
                         "path": "lorem/ipsum",
-                        "type": "sharegpt",
-                        "conversation": "vicuna_v1.1",
+                        "type": "chat_template",
+                        "chat_template": "gemma",
                     },
                     {
                         "path": "sit/amet",
-                        "type": "sharegpt",
+                        "type": "chat_template",
                    },
                 ],
             }
@@ -52,8 +52,8 @@ class NormalizeConfigTestCase(unittest.TestCase):
 
         normalize_cfg_datasets(cfg)
 
-        assert cfg.datasets[0].conversation == "vicuna_v1.1"
-        assert cfg.datasets[1].conversation == "chatml"
+        assert cfg.datasets[0].chat_template == "gemma"
+        assert cfg.datasets[1].chat_template == "chatml"
 
     @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
     def test_bf16_auto_setter_available(self, mock_bf16_avail):
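Aside: the asserts in the hunk above imply the rule normalize_cfg_datasets applies after the sharegpt-to-chat_template migration: a dataset that sets its own chat_template keeps it, while a chat_template-typed dataset without one inherits a default. A rough sketch of that rule, not axolotl's actual implementation (the "chatml" fallback is an assumption read off the second assert):

# Rough sketch of the normalization rule the test asserts; hypothetical helper.
def normalize_cfg_datasets_sketch(cfg: dict) -> None:
    for dataset in cfg.get("datasets", []):
        if dataset.get("type") == "chat_template" and not dataset.get("chat_template"):
            # explicit per-dataset values win; others inherit the top-level
            # chat_template, assumed here to fall back to "chatml"
            dataset["chat_template"] = cfg.get("chat_template") or "chatml"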
@@ -2,6 +2,7 @@
 import functools
 import unittest
 
+import pytest
 import torch
 from datasets import load_dataset
 from torch.utils.data import DataLoader
@@ -21,6 +22,7 @@ class TestPretrainingPacking(unittest.TestCase):
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
         self.tokenizer.pad_token = "</s>"
 
+    @pytest.mark.flaky(retries=3, delay=5)
     def test_packing_stream_dataset(self):
         # pylint: disable=duplicate-code
         dataset = load_dataset(
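Aside: the @pytest.mark.flaky(retries=3, delay=5) marker added above re-runs a failing test a few times before failing the run, which suits a test that streams a remote dataset. A self-contained illustration of the semantics, assuming the pytest-retry plugin (which provides the retries/delay arguments used here) is installed:

# Illustration only: with a retrying marker, a transient failure is re-run.
import pytest

_attempts = {"count": 0}

@pytest.mark.flaky(retries=3, delay=5)  # assumes the pytest-retry plugin
def test_transient_network_hiccup():
    _attempts["count"] += 1
    if _attempts["count"] < 2:
        raise ConnectionError("transient failure")  # first run fails, retry passes
    assert _attempts["count"] >= 2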
@@ -3,7 +3,6 @@
 import json
 import logging
 import unittest
-from copy import deepcopy
 from pathlib import Path
 from typing import Optional
 
@@ -21,12 +20,8 @@ from axolotl.prompt_strategies.llama2_chat import (
     LLama2ChatTokenizingStrategy,
 )
 from axolotl.prompt_strategies.orpo.chat_template import load
-from axolotl.prompt_strategies.sharegpt import GlaiveShareGPTPromptTokenizingStrategy
-from axolotl.prompt_tokenizers import (
-    AlpacaPromptTokenizingStrategy,
-    ShareGPTPromptTokenizingStrategy,
-)
-from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
 from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger("axolotl")
@@ -65,17 +60,6 @@ test_data = {
 }
 
 
-def prompt_strat(conversation, tokenizer):
-    "Helper function to create a prompt strategy for testing."
-    prompter = ShareGPTPrompterV2(conversation=conversation)
-    return ShareGPTPromptTokenizingStrategy(
-        prompter,
-        tokenizer,
-        False,
-        2048,
-    )
-
-
 class TestPromptTokenizationStrategies(unittest.TestCase):
     """
     Test class for prompt tokenization strategies.
@@ -98,196 +82,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
         }
     )
 
-    def test_sharegpt_integration(self):
-        with open(
-            Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
-        ) as fin:
-            data = fin.read()
-            conversation = json.loads(data)
-        with open(
-            Path(__file__).parent / "fixtures/conversation.tokenized.json",
-            encoding="utf-8",
-        ) as fin:
-            data = fin.read()
-            tokenized_conversation = json.loads(data)
-        prompter = ShareGPTPrompterV2()
-        strat = ShareGPTPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
-        example = strat.tokenize_prompt(conversation)
-        for fields in ["input_ids", "attention_mask", "labels"]:
-            self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
-            self.assertEqual(example[fields], tokenized_conversation[fields])
-
-    def test_sharegpt_warnings_integration(self):
-        with open(
-            Path(__file__).parent / "fixtures/conversation.missingturns.json",
-            encoding="utf-8",
-        ) as fin:
-            data = fin.read()
-            conversation = json.loads(data)
-        prompter = ShareGPTPrompterV2()
-        strat = ShareGPTPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
-        with self._caplog.at_level(logging.WARNING):
-            strat.tokenize_prompt(conversation)
-        assert "assistant turn has empty text" in self._caplog.records[1].message
-
-    def test_sharegpt_warnings_turns(self):
-        conversation = {
-            "conversations": [
-                {"from": "system", "value": "lorem"},
-                {"from": "gpt", "value": "ipsum"},
-                {"from": "human", "value": "dolor"},
-                {"from": "human", "value": "dolor"},
-                {"from": "gpt", "value": "sit"},
-            ]
-        }
-        prompter = ShareGPTPrompterV2()
-        strat = ShareGPTPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
-        with self._caplog.at_level(logging.WARNING):
-            strat.tokenize_prompt(conversation)
-        assert (
-            "Role did not alternate between turns (gpt and human)"
-            in self._caplog.records[0].message
-        )
-
-    def test_sharegpt_llama(self):
-        "Make sure the sharegpt/llama is tokenized and formatted correctly."
-        strat = prompt_strat("llama-2", self.tokenizer)
-
-        def tokenize(conv):
-            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
-
-        def decode(ids):
-            return strat.tokenizer.decode(ids)
-
-        # fmt: off
-        # System message, multi-turn conversations
-        mt_ids = tokenize(test_data['multi_turn_sys'])
-        assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
-        assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
-
-        # System message, single-turn conversations
-        st_ids = tokenize(test_data['single_turn_sys'])
-        assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
-        assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
-
-        # No system message, single-turn
-        ns_ids = tokenize(test_data['single_turn_no_sys'])
-        assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
-        assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
-
-        # No system message, multi-turn
-        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
-        assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
-        assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
-        # fmt: on
-
-    def test_sharegpt_mistral(self):
-        "Make sure the sharegpt/mistral is tokenized and formatted correctly."
-        strat = prompt_strat("mistral", self.tokenizer)
-
-        def tokenize(conv):
-            return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
-
-        def decode(ids):
-            return strat.tokenizer.decode(ids)
-
-        # fmt: off
-        # System message, multi-turn conversations
-        mt_ids = tokenize(test_data['multi_turn_sys'])
-        assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
-        assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
-
-        # System message, single-turn conversations
-        st_ids = tokenize(test_data['single_turn_sys'])
-        assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
-        assert st_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
-
-        # No system message, single-turn
-        ns_ids = tokenize(test_data['single_turn_no_sys'])
-        assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
-        assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
-
-        # No system message, multi-turn
-        ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
-        assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
-        assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
-        # fmt: on
-
-    def test_sharegpt_changes_roles(self):
-        conversation = {
-            "roles": ["USER", "CHARACTER"],
-            "conversations": [
-                {"from": "system", "value": "lorem"},
-                {"from": "gpt", "value": "ipsum"},
-                {"from": "human", "value": "dolor"},
-                {"from": "gpt", "value": "sit"},
-            ],
-        }
-        prompter = ShareGPTPrompterV2()
-        strat = ShareGPTPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
-        with self._caplog.at_level(logging.WARNING):
-            res = strat.tokenize_prompt(conversation)
-        assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
-
-    def test_sharegpt_assistant_label_ignore(self):
-        conversation = {
-            "roles": ["user", "assistant"],
-            "conversations": [
-                {"from": "system", "value": "lorem"},
-                {"from": "gpt", "value": "ipsum"},
-                {"from": "human", "value": "dolor"},
-                {"from": "gpt", "value": "sit"},
-            ],
-        }
-        prompter = ShareGPTPrompterV2()
-        strat = ShareGPTPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
-        with self._caplog.at_level(logging.WARNING):
-            res = strat.tokenize_prompt(conversation)
-        idx = res["input_ids"].index(20255)  # assistant token
-        assert res["labels"][idx] == -100
-
-    def test_glaive_tool_label_ignore(self):
-        conversation = {
-            "system": "SYSTEM: This is a system prompt",
-            "chat": "USER: Can you book a flight for me from New York to London? ASSISTANT: I'm sorry, but I don't have the capability to book flights. <|endoftext|>",
-        }
-        prompter = ShareGPTPrompterV2()
-        strat = GlaiveShareGPTPromptTokenizingStrategy(
-            prompter,
-            self.tokenizer,
-            False,
-            2048,
-        )
-        with self._caplog.at_level(logging.WARNING):
-            res = strat.tokenize_prompt(conversation)
-        idx = res["input_ids"].index(13566)  # assistant token
-        assert res["labels"][idx] == -100
-
     def test_no_sys_prompt(self):
         """
         tests the interface between the user and assistant parts
@@ -646,39 +646,6 @@ class TestValidation(BaseValidation):
 
         validate_config(cfg)
 
-    def test_sharegpt_deprecation(self, minimal_cfg):
-        cfg = (
-            DictDefault(
-                {"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
-            )
-            | minimal_cfg
-        )
-        with self._caplog.at_level(logging.WARNING):
-            new_cfg = validate_config(cfg)
-        assert any(
-            "`type: sharegpt:chat` will soon be deprecated." in record.message
-            for record in self._caplog.records
-        )
-        assert new_cfg.datasets[0].type == "sharegpt"
-
-        cfg = (
-            DictDefault(
-                {
-                    "datasets": [
-                        {"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}
-                    ]
-                }
-            )
-            | minimal_cfg
-        )
-        with self._caplog.at_level(logging.WARNING):
-            new_cfg = validate_config(cfg)
-        assert any(
-            "`type: sharegpt_simple` will soon be deprecated." in record.message
-            for record in self._caplog.records
-        )
-        assert new_cfg.datasets[0].type == "sharegpt:load_role"
-
     def test_no_conflict_save_strategy(self, minimal_cfg):
         cfg = (
             DictDefault(
@@ -759,7 +726,7 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(
                 {
-                    "evaluation_strategy": "epoch",
+                    "eval_strategy": "epoch",
                     "eval_steps": 10,
                 }
             )
@@ -767,14 +734,14 @@ class TestValidation(BaseValidation):
         )
 
         with pytest.raises(
-            ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
+            ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
         ):
             validate_config(cfg)
 
         cfg = (
             DictDefault(
                 {
-                    "evaluation_strategy": "no",
+                    "eval_strategy": "no",
                     "eval_steps": 10,
                 }
             )
@@ -782,14 +749,14 @@ class TestValidation(BaseValidation):
         )
 
         with pytest.raises(
-            ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
+            ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
         ):
             validate_config(cfg)
 
         cfg = (
             DictDefault(
                 {
-                    "evaluation_strategy": "steps",
+                    "eval_strategy": "steps",
                 }
             )
             | minimal_cfg
@@ -800,7 +767,7 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(
                 {
-                    "evaluation_strategy": "steps",
+                    "eval_strategy": "steps",
                     "eval_steps": 10,
                 }
             )
@@ -823,7 +790,7 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(
                 {
-                    "evaluation_strategy": "no",
+                    "eval_strategy": "no",
                 }
             )
             | minimal_cfg
@@ -834,7 +801,7 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(
                 {
-                    "evaluation_strategy": "epoch",
+                    "eval_strategy": "epoch",
                     "val_set_size": 0,
                 }
             )
@@ -843,7 +810,7 @@ class TestValidation(BaseValidation):
 
         with pytest.raises(
             ValueError,
-            match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
+            match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
         ):
             validate_config(cfg)
 
@@ -859,7 +826,7 @@ class TestValidation(BaseValidation):
 
         with pytest.raises(
             ValueError,
-            match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
+            match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
        ):
             validate_config(cfg)
 
@@ -889,7 +856,7 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(
                 {
-                    "evaluation_strategy": "epoch",
+                    "eval_strategy": "epoch",
                     "val_set_size": 0.01,
                 }
             )
@@ -1128,6 +1095,24 @@ class TestValidation(BaseValidation):
         assert new_cfg["dpo_beta"] is None
         assert len(self._caplog.records) == 1
 
+    def test_eval_strategy_remap(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "evaluation_strategy": "steps",
+                }
+            )
+            | minimal_cfg
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            new_cfg = validate_config(cfg)
+        assert new_cfg.eval_strategy == "steps"
+        assert (
+            "evaluation_strategy is deprecated, use eval_strategy instead"
+            in self._caplog.records[0].message
+        )
+
 
 class TestValidationCheckModelConfig(BaseValidation):
     """
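Aside: the new test above pins down the deprecation path for evaluation_strategy. A minimal sketch of the remap it exercises — not axolotl's actual validator, just a mirror of the observable behaviour the asserts describe:

# Minimal sketch of the deprecation remap the test asserts; hypothetical helper.
import logging

LOG = logging.getLogger("axolotl")

def remap_eval_strategy(cfg: dict) -> dict:
    if "evaluation_strategy" in cfg:
        LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
        # pop() runs first, so the old key is removed and the new key is
        # only filled in when it is not already set
        cfg.setdefault("eval_strategy", cfg.pop("evaluation_strategy"))
    return cfg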
@@ -48,9 +48,8 @@ class TestValidationCheckDatasetConfig(BaseValidation):
             | {
                 "datasets": [
                     {
-                        "path": "LDJnr/Puffin",
-                        "type": "sharegpt",
-                        "conversation": "chatml",
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
                         "shards": 10,
                     }
                 ]
@@ -62,7 +61,6 @@ class TestValidationCheckDatasetConfig(BaseValidation):
         def _check_config():
             assert checked_cfg.datasets[0].path == cfg.datasets[0].path
             assert checked_cfg.datasets[0].type == cfg.datasets[0].type
-            assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation
             assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards
 
         _check_config()
@@ -236,3 +234,59 @@ class TestValidationCheckDatasetConfig(BaseValidation):
         )
 
         _check_config()
+
+    def test_dataset_sharegpt_deprecation(self, minimal_cfg):
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "chat_template": "chatml",
+                "datasets": [
+                    {
+                        "path": "LDJnr/Puffin",
+                        "type": "sharegpt",
+                        "conversation": "chatml",
+                    }
+                ],
+            }
+        )
+
+        # Check sharegpt deprecation is raised
+        with pytest.raises(ValueError, match=r".*type: sharegpt.*` is deprecated.*"):
+            validate_config(cfg)
+
+        # Check that deprecation is not thrown for non-str type
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": {
+                            "field_instruction": "instruction",
+                            "field_output": "output",
+                            "field_system": "system",
+                            "format": "<|user|> {instruction} {input} <|model|>",
+                            "no_input_format": "<|user|> {instruction} <|model|>",
+                            "system_prompt": "",
+                        },
+                    }
+                ],
+            }
+        )
+
+        validate_config(cfg)
+
+        # Check that deprecation is not thrown for non-sharegpt type
+        cfg = DictDefault(
+            minimal_cfg
+            | {
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    }
+                ],
+            }
+        )
+
+        validate_config(cfg)