Compare commits
37 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb8bfab9cc | ||
|
|
0c8b1d824a | ||
|
|
fd70eec577 | ||
|
|
d42f202046 | ||
|
|
0dabde1962 | ||
|
|
15f1462ccd | ||
|
|
521e62daf1 | ||
|
|
c16ec398d7 | ||
|
|
2f20cb7ebf | ||
|
|
71d4030b79 | ||
|
|
f3a5d119af | ||
|
|
ba219b51a5 | ||
|
|
5be8e13d35 | ||
|
|
2d7830fda6 | ||
|
|
5e98cdddac | ||
|
|
1d7aee0ad2 | ||
|
|
659ee5d723 | ||
|
|
342935cff3 | ||
|
|
c5eb9ea2c2 | ||
|
|
f2145a3ccb | ||
|
|
010d0e7ff3 | ||
|
|
01881c3113 | ||
|
|
0e8eb96e07 | ||
|
|
4e1891b12b | ||
|
|
28924fc791 | ||
|
|
8c480b2804 | ||
|
|
a4b1cc6df0 | ||
|
|
7b78a31593 | ||
|
|
810ebc2c0e | ||
|
|
ad435a3b09 | ||
|
|
9f1cf9b17c | ||
|
|
3931a42763 | ||
|
|
dc8f9059f7 | ||
|
|
234e94e9dd | ||
|
|
f68fb71005 | ||
|
|
9bc3ee6c75 | ||
|
|
d356740ffa |
12
.github/workflows/base.yml
vendored
12
.github/workflows/base.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
|||||||
- cuda: "124"
|
- cuda: "124"
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
python_version: "3.11"
|
python_version: "3.10"
|
||||||
pytorch: 2.4.1
|
pytorch: 2.4.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "124"
|
- cuda: "124"
|
||||||
@@ -44,19 +44,21 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Docker metadata
|
- name: Docker metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v3
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-base
|
images: |
|
||||||
|
winglian/axolotl-base
|
||||||
|
axolotlai/axolotl-base
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/docs.yml
vendored
2
.github/workflows/docs.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
|||||||
- name: Set up Quarto
|
- name: Set up Quarto
|
||||||
uses: quarto-dev/quarto-actions/setup@v2
|
uses: quarto-dev/quarto-actions/setup@v2
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v3
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
- name: install dependencies
|
- name: install dependencies
|
||||||
|
|||||||
6
.github/workflows/lint.yml
vendored
6
.github/workflows/lint.yml
vendored
@@ -15,9 +15,9 @@ jobs:
|
|||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.0
|
- uses: pre-commit/action@v3.0.1
|
||||||
|
|||||||
33
.github/workflows/main.yml
vendored
33
.github/workflows/main.yml
vendored
@@ -4,6 +4,8 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
|
tags:
|
||||||
|
- "v*"
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
@@ -42,7 +44,12 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl
|
images: |
|
||||||
|
winglian/axolotl
|
||||||
|
axolotlai/axolotl
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=semver,pattern={{version}}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
@@ -56,7 +63,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||||
@@ -104,20 +111,25 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-cloud
|
images: |
|
||||||
|
winglian/axolotl-cloud
|
||||||
|
axolotlai/axolotl-cloud
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=semver,pattern={{version}}
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
file: ./docker/Dockerfile-cloud
|
file: ./docker/Dockerfile-cloud
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
@@ -146,20 +158,25 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-cloud-term
|
images: |
|
||||||
|
winglian/axolotl-cloud-term
|
||||||
|
axolotlai/axolotl-cloud-term
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=semver,pattern={{version}}
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
file: ./docker/Dockerfile-cloud-no-tmux
|
file: ./docker/Dockerfile-cloud-no-tmux
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
|
|||||||
5
.github/workflows/multi-gpu-e2e.yml
vendored
5
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,6 +8,11 @@ on:
|
|||||||
schedule:
|
schedule:
|
||||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||||
|
|
||||||
|
# Cancel jobs on the same ref if a new one is triggered
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-axolotl-multigpu:
|
test-axolotl-multigpu:
|
||||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||||
|
|||||||
10
.github/workflows/nightlies.yml
vendored
10
.github/workflows/nightlies.yml
vendored
@@ -41,7 +41,9 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl
|
images: |
|
||||||
|
winglian/axolotl
|
||||||
|
axolotlai/axolotl
|
||||||
tags: |
|
tags: |
|
||||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
@@ -103,7 +105,9 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-cloud
|
images: |
|
||||||
|
winglian/axolotl-cloud
|
||||||
|
axolotlai/axolotl-cloud
|
||||||
tags: |
|
tags: |
|
||||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
@@ -112,7 +116,7 @@ jobs:
|
|||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
24
.github/workflows/pypi.yml
vendored
24
.github/workflows/pypi.yml
vendored
@@ -3,13 +3,31 @@ name: publish pypi
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- 'v*'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
setup_release:
|
||||||
|
name: Create Release
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Get the tag version
|
||||||
|
id: extract_branch
|
||||||
|
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
||||||
|
shell: bash
|
||||||
|
|
||||||
|
- name: Create Release
|
||||||
|
id: create_release
|
||||||
|
uses: actions/create-release@v1
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
with:
|
||||||
|
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
||||||
|
release_name: ${{ steps.extract_branch.outputs.branch }}
|
||||||
pypi-publish:
|
pypi-publish:
|
||||||
name: Upload release to PyPI
|
name: Upload release to PyPI
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
needs: [setup_release]
|
||||||
environment:
|
environment:
|
||||||
name: pypi
|
name: pypi
|
||||||
url: https://pypi.org/p/axolotl
|
url: https://pypi.org/p/axolotl
|
||||||
@@ -17,10 +35,10 @@ jobs:
|
|||||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
|
|
||||||
|
|||||||
11
.github/workflows/tests-nightly.yml
vendored
11
.github/workflows/tests-nightly.yml
vendored
@@ -9,12 +9,12 @@ jobs:
|
|||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.0
|
- uses: pre-commit/action@v3.0.1
|
||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
|
|
||||||
@@ -30,10 +30,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
@@ -48,6 +48,7 @@ jobs:
|
|||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||||
|
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
15
.github/workflows/tests.yml
vendored
15
.github/workflows/tests.yml
vendored
@@ -15,17 +15,22 @@ on:
|
|||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
# Cancel jobs on the same ref if a new one is triggered
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pre-commit:
|
pre-commit:
|
||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.0
|
- uses: pre-commit/action@v3.0.1
|
||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
|
|
||||||
@@ -41,10 +46,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
|
|||||||
#### Docker
|
#### Docker
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
|
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
Or run on the current files for development:
|
Or run on the current files for development:
|
||||||
@@ -178,7 +178,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
|
|||||||
A more powerful Docker command to run would be this:
|
A more powerful Docker command to run would be this:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
|
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
It additionally:
|
It additionally:
|
||||||
@@ -210,7 +210,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
|||||||
|
|
||||||
#### Cloud GPU
|
#### Cloud GPU
|
||||||
|
|
||||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
|
||||||
|
|
||||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||||
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
|
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
|
||||||
@@ -319,7 +319,7 @@ Write a job description in YAML as below:
|
|||||||
# dstack.yaml
|
# dstack.yaml
|
||||||
type: task
|
type: task
|
||||||
|
|
||||||
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
|
image: axolotlai/axolotl-cloud:main-latest
|
||||||
|
|
||||||
env:
|
env:
|
||||||
- HUGGING_FACE_HUB_TOKEN
|
- HUGGING_FACE_HUB_TOKEN
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
FROM winglian/axolotl-base:{{ BASE_TAG }}
|
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||||
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||||
@@ -28,6 +28,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||||
|
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import tempfile
|
|||||||
import jinja2
|
import jinja2
|
||||||
import modal
|
import modal
|
||||||
from jinja2 import select_autoescape
|
from jinja2 import select_autoescape
|
||||||
from modal import Image, Stub
|
from modal import App, Image
|
||||||
|
|
||||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ cicd_image = (
|
|||||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||||
)
|
)
|
||||||
|
|
||||||
stub = Stub("Axolotl CI/CD", secrets=[])
|
app = App("Axolotl CI/CD", secrets=[])
|
||||||
|
|
||||||
|
|
||||||
N_GPUS = int(os.environ.get("N_GPUS", 2))
|
N_GPUS = int(os.environ.get("N_GPUS", 2))
|
||||||
@@ -61,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||||
|
|
||||||
|
|
||||||
@stub.function(
|
@app.function(
|
||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=60 * 60,
|
timeout=60 * 60,
|
||||||
@@ -72,6 +72,6 @@ def cicd_pytest():
|
|||||||
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
||||||
|
|
||||||
|
|
||||||
@stub.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
def main():
|
def main():
|
||||||
cicd_pytest.remote()
|
cicd_pytest.remote()
|
||||||
|
|||||||
@@ -2,4 +2,4 @@
|
|||||||
set -e
|
set -e
|
||||||
|
|
||||||
# only run one test at a time so as not to OOM the GPU
|
# only run one test at a time so as not to OOM the GPU
|
||||||
pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
|
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import tempfile
|
|||||||
import jinja2
|
import jinja2
|
||||||
import modal
|
import modal
|
||||||
from jinja2 import select_autoescape
|
from jinja2 import select_autoescape
|
||||||
from modal import Image, Stub
|
from modal import App, Image
|
||||||
|
|
||||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ cicd_image = (
|
|||||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||||
)
|
)
|
||||||
|
|
||||||
stub = Stub("Axolotl CI/CD", secrets=[])
|
app = App("Axolotl CI/CD", secrets=[])
|
||||||
|
|
||||||
|
|
||||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||||
@@ -62,7 +62,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||||
|
|
||||||
|
|
||||||
@stub.function(
|
@app.function(
|
||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=60 * 60,
|
timeout=60 * 60,
|
||||||
@@ -73,6 +73,6 @@ def cicd_pytest():
|
|||||||
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
||||||
|
|
||||||
|
|
||||||
@stub.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
def main():
|
def main():
|
||||||
cicd_pytest.remote()
|
cicd_pytest.remote()
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
ARG BASE_TAG=main-base
|
ARG BASE_TAG=main-base
|
||||||
FROM winglian/axolotl-base:$BASE_TAG
|
FROM axolotlai/axolotl-base:$BASE_TAG
|
||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
|
|||||||
@@ -35,7 +35,3 @@ RUN git lfs install --skip-repo && \
|
|||||||
pip3 install awscli && \
|
pip3 install awscli && \
|
||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
|
|
||||||
RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
|
|
||||||
pip3 install flash-attn==2.6.3; \
|
|
||||||
fi
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
ARG BASE_TAG=main
|
ARG BASE_TAG=main
|
||||||
FROM winglian/axolotl:$BASE_TAG
|
FROM axolotlai/axolotl:$BASE_TAG
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
ARG BASE_TAG=main
|
ARG BASE_TAG=main
|
||||||
FROM winglian/axolotl:$BASE_TAG
|
FROM axolotlai/axolotl:$BASE_TAG
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
ARG BASE_TAG=main-base
|
ARG BASE_TAG=main-base
|
||||||
FROM winglian/axolotl-base:$BASE_TAG
|
FROM axolotlai/axolotl-base:$BASE_TAG
|
||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
|
|||||||
@@ -91,6 +91,7 @@ datasets:
|
|||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
train_on_split: train # Optional[str] name of dataset split to load from
|
train_on_split: train # Optional[str] name of dataset split to load from
|
||||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||||
|
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
|
||||||
|
|
||||||
# Custom user instruction prompt
|
# Custom user instruction prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
@@ -405,6 +406,7 @@ lr_div_factor: # Learning rate div factor
|
|||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
# - adamw_apex_fused
|
# - adamw_apex_fused
|
||||||
|
# - adopt_adamw (only for torch version >= 2.5.1)
|
||||||
# - adafactor
|
# - adafactor
|
||||||
# - adamw_anyprecision
|
# - adamw_anyprecision
|
||||||
# - sgd
|
# - sgd
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
|
|||||||
|
|
||||||
## Debugging With Docker
|
## Debugging With Docker
|
||||||
|
|
||||||
Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
|
Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
@@ -202,11 +202,11 @@ cd axolotl
|
|||||||
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
|
||||||
```
|
```
|
||||||
|
|
||||||
>[!Tip]
|
>[!Tip]
|
||||||
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
||||||
|
|
||||||
You will now be in the container. Next, perform an editable install of Axolotl:
|
You will now be in the container. Next, perform an editable install of Axolotl:
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
|
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
|
||||||
"!pip install flash-attn==\"2.5.0\"\n",
|
"!pip install flash-attn==\"2.7.0.post2\"\n",
|
||||||
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
93
examples/mistral/mistral-dpo-qlora.yml
Normal file
93
examples/mistral/mistral-dpo-qlora.yml
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
#Note that we are switching from the regular chat template to chatml.
|
||||||
|
#If you experience problems with the special tokens, training for more epochs can help.
|
||||||
|
#After training, merge the model before inference otherwise you might
|
||||||
|
#face problems with the special tokens.
|
||||||
|
|
||||||
|
base_model: mistralai/Mistral-7B-Instruct-v0.2
|
||||||
|
model_type: MistralForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: chatml
|
||||||
|
rl: dpo
|
||||||
|
datasets:
|
||||||
|
- path: olivermolenschot/alpaca_messages_dpo_test
|
||||||
|
type: chat_template.default
|
||||||
|
field_messages: conversation
|
||||||
|
field_chosen: chosen
|
||||||
|
field_rejected: rejected
|
||||||
|
message_field_role: role
|
||||||
|
message_field_content: content
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/dpo-qlora
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 8
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.2
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
lora_modules_to_save:
|
||||||
|
- embed_tokens
|
||||||
|
- lm_head
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 16
|
||||||
|
num_epochs: 6
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: false
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<|im_start|>"
|
||||||
|
eos_token: "<|im_end|>"
|
||||||
67
examples/qwen2/dpo.yaml
Normal file
67
examples/qwen2/dpo.yaml
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
base_model: Qwen/Qwen2.5-0.5B
|
||||||
|
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: qwen_25
|
||||||
|
rl: dpo
|
||||||
|
datasets:
|
||||||
|
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||||
|
type: chat_template.default
|
||||||
|
field_messages: conversation
|
||||||
|
field_chosen: chosen
|
||||||
|
field_rejected: rejected
|
||||||
|
message_field_role: role
|
||||||
|
message_field_content: content
|
||||||
|
roles:
|
||||||
|
system:
|
||||||
|
- system
|
||||||
|
user:
|
||||||
|
- user
|
||||||
|
assistant:
|
||||||
|
- assistant
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./outputs/dpo-out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 2
|
||||||
|
num_epochs: 4
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -1,2 +1,3 @@
|
|||||||
pytest
|
pytest
|
||||||
pytest-xdist
|
pytest-xdist
|
||||||
|
pytest-retry
|
||||||
|
|||||||
@@ -1,18 +1,18 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.13.2
|
peft==0.13.2
|
||||||
transformers==4.46.1
|
transformers==4.46.2
|
||||||
tokenizers>=0.20.1
|
tokenizers>=0.20.1
|
||||||
bitsandbytes==0.44.1
|
bitsandbytes==0.44.1
|
||||||
accelerate==1.1.0
|
accelerate==1.1.0
|
||||||
datasets==3.0.1
|
datasets==3.1.0
|
||||||
deepspeed==0.15.3
|
deepspeed==0.15.3
|
||||||
pydantic==2.6.3
|
pydantic==2.6.3
|
||||||
addict
|
addict
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
flash-attn==2.6.3
|
flash-attn==2.7.0.post2
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
@@ -33,7 +33,7 @@ tensorboard
|
|||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
autoawq>=0.2.5
|
autoawq>=0.2.5
|
||||||
triton>=2.3.0
|
triton>=2.3.0
|
||||||
liger-kernel==0.4.0
|
liger-kernel==0.4.1
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
@@ -53,3 +53,4 @@ immutabledict==4.2.0
|
|||||||
antlr4-python3-runtime==4.13.2
|
antlr4-python3-runtime==4.13.2
|
||||||
|
|
||||||
torchao==0.5.0
|
torchao==0.5.0
|
||||||
|
schedulefree==1.3.0
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
# Export specific ENV variables to /etc/rp_environment
|
# Export specific ENV variables to /etc/rp_environment
|
||||||
echo "Exporting environment variables..."
|
echo "Exporting environment variables..."
|
||||||
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
printenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\([^=]*\)=\(.*\)$/export \1="\2"/' | grep -v 'printenv' >> /etc/rp_environment
|
||||||
echo 'source /etc/rp_environment' >> ~/.bashrc
|
echo 'source /etc/rp_environment' >> ~/.bashrc
|
||||||
|
|
||||||
add_keys_to_authorized() {
|
add_keys_to_authorized() {
|
||||||
|
|||||||
11
setup.py
11
setup.py
@@ -39,7 +39,10 @@ def parse_requirements():
|
|||||||
else:
|
else:
|
||||||
# detect the version of torch already installed
|
# detect the version of torch already installed
|
||||||
# and set it so dependencies don't clobber the torch version
|
# and set it so dependencies don't clobber the torch version
|
||||||
torch_version = version("torch")
|
try:
|
||||||
|
torch_version = version("torch")
|
||||||
|
except PackageNotFoundError:
|
||||||
|
torch_version = "2.5.1"
|
||||||
_install_requires.append(f"torch=={torch_version}")
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||||
@@ -54,6 +57,10 @@ def parse_requirements():
|
|||||||
|
|
||||||
if (major, minor) >= (2, 5):
|
if (major, minor) >= (2, 5):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.append("xformers==0.0.28.post2")
|
||||||
|
else:
|
||||||
|
_install_requires.append("xformers==0.0.28.post3")
|
||||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||||
elif (major, minor) >= (2, 4):
|
elif (major, minor) >= (2, 4):
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
@@ -98,7 +105,7 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.6.3",
|
"flash-attn==2.7.0.post2",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.14.4",
|
"deepspeed==0.14.4",
|
||||||
|
|||||||
@@ -190,18 +190,15 @@ def do_inference(
|
|||||||
):
|
):
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
|
||||||
# If the token isn't already specified in the config, add it
|
|
||||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
|
||||||
tokenizer.add_special_tokens({token: symbol})
|
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
|
chat_template_str = None
|
||||||
if prompter:
|
if prompter:
|
||||||
prompter_module = getattr(
|
prompter_module = getattr(
|
||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
elif cfg.chat_template:
|
||||||
|
chat_template_str = get_chat_template(cfg.chat_template)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
@@ -211,13 +208,31 @@ def do_inference(
|
|||||||
instruction = get_multi_line_input()
|
instruction = get_multi_line_input()
|
||||||
if not instruction:
|
if not instruction:
|
||||||
return
|
return
|
||||||
|
|
||||||
if prompter_module:
|
if prompter_module:
|
||||||
prompt: str = next(
|
prompt: str = next(
|
||||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt = instruction.strip()
|
prompt = instruction.strip()
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
|
||||||
|
if chat_template_str:
|
||||||
|
batch = tokenizer.apply_chat_template(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
return_tensors="pt",
|
||||||
|
add_special_tokens=True,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
print("=" * 40)
|
print("=" * 40)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -257,13 +272,6 @@ def do_inference_gradio(
|
|||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
|
||||||
default_tokens: Dict[str, str] = {}
|
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
|
||||||
# If the token isn't already specified in the config, add it
|
|
||||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
|
||||||
tokenizer.add_special_tokens({token: symbol})
|
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
chat_template_str = None
|
chat_template_str = None
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ MOE_ARCH_BLOCK = {
|
|||||||
"JetMoeMoE",
|
"JetMoeMoE",
|
||||||
],
|
],
|
||||||
"mixtral": "MixtralSparseMoeBlock",
|
"mixtral": "MixtralSparseMoeBlock",
|
||||||
|
"phimoe": "PhiMoESparseMoeBlock",
|
||||||
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
|
||||||
"deepseek_v2": "DeepseekV2MoE",
|
"deepseek_v2": "DeepseekV2MoE",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -436,7 +436,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
not in [
|
||||||
|
"optimi_adamw",
|
||||||
|
"ao_adamw_8bit",
|
||||||
|
"ao_adamw_4bit",
|
||||||
|
"ao_adamw_fp8",
|
||||||
|
"adopt_adamw",
|
||||||
|
]
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
@@ -505,6 +511,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
)
|
)
|
||||||
|
elif self.args.alternate_optimizer == "adopt_adamw":
|
||||||
|
from axolotl.utils.optimizers.adopt import ADOPT
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
ADOPT(
|
||||||
|
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
@@ -1024,24 +1038,37 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
|
|
||||||
return super().push_to_hub(*args, **kwargs)
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def tokenize_row(
|
def tokenize_row(
|
||||||
self,
|
|
||||||
features,
|
features,
|
||||||
processing_class,
|
processing_class,
|
||||||
max_prompt_length,
|
max_prompt_length,
|
||||||
max_completion_length,
|
max_completion_length,
|
||||||
add_special_tokens,
|
add_special_tokens,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
res = super().tokenize_row(
|
res = DPOTrainer.tokenize_row(
|
||||||
features,
|
features,
|
||||||
processing_class,
|
processing_class,
|
||||||
max_prompt_length,
|
max_prompt_length,
|
||||||
max_completion_length,
|
max_completion_length,
|
||||||
add_special_tokens,
|
add_special_tokens,
|
||||||
)
|
)
|
||||||
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||||
|
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||||
for key in res.keys():
|
for key in res.keys():
|
||||||
res[key] = res[key][1:]
|
res[key] = res[key][1:]
|
||||||
|
|
||||||
|
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||||
|
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||||
|
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||||
|
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||||
|
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||||
|
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||||
|
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||||
|
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||||
|
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||||
|
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def training_step(
|
def training_step(
|
||||||
@@ -1273,6 +1300,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||||
callbacks.append(lisa_callback_factory(trainer))
|
callbacks.append(lisa_callback_factory(trainer))
|
||||||
|
|
||||||
|
if self.cfg.plugins:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
callbacks.extend(
|
||||||
|
[
|
||||||
|
cb
|
||||||
|
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||||
|
self.cfg, trainer
|
||||||
|
)
|
||||||
|
if cb
|
||||||
|
]
|
||||||
|
)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def _get_trainer_cls(self):
|
def _get_trainer_cls(self):
|
||||||
@@ -1390,17 +1429,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||||
# no eval set, so don't eval
|
# no eval set, so don't eval
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
training_arguments_kwargs["eval_strategy"] = "no"
|
||||||
elif self.cfg.eval_steps:
|
elif self.cfg.eval_steps:
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
training_arguments_kwargs["eval_strategy"] = "steps"
|
||||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
elif self.cfg.evaluation_strategy:
|
elif self.cfg.eval_strategy:
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||||
"evaluation_strategy"
|
|
||||||
] = self.cfg.evaluation_strategy
|
|
||||||
else:
|
else:
|
||||||
# we have an eval set, but no steps defined, default to use epoch
|
# we have an eval set, but no steps defined, default to use epoch
|
||||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
training_arguments_kwargs["eval_strategy"] = "epoch"
|
||||||
|
|
||||||
if self.cfg.save_steps:
|
if self.cfg.save_steps:
|
||||||
training_arguments_kwargs["save_strategy"] = "steps"
|
training_arguments_kwargs["save_strategy"] = "steps"
|
||||||
@@ -1625,11 +1662,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
if self.cfg.optimizer in [
|
if self.cfg.optimizer in [
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"adopt_adamw",
|
||||||
]:
|
]:
|
||||||
# Set default so transformers doesn't throw
|
# Set default so transformers doesn't throw
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
@@ -1832,10 +1871,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
training_args_kwargs["evaluation_strategy"] = "steps"
|
training_args_kwargs["eval_strategy"] = "steps"
|
||||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["evaluation_strategy"] = "no"
|
training_args_kwargs["eval_strategy"] = "no"
|
||||||
|
|
||||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||||
training_args_kwargs["bf16"] = True
|
training_args_kwargs["bf16"] = True
|
||||||
@@ -1933,6 +1972,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
|
if self.cfg.rl == "ipo":
|
||||||
|
training_args_kwargs["loss_type"] = "ipo"
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
training_args_kwargs["max_completion_length"] = None
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
|
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
|
|
||||||
@@ -1956,7 +2001,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args = self.build_training_arguments(total_num_steps)
|
training_args = self.build_training_arguments(total_num_steps)
|
||||||
dpo_trainer_kwargs = {}
|
dpo_trainer_kwargs = {}
|
||||||
if self.cfg.rl == "ipo":
|
if self.cfg.rl == "ipo":
|
||||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
@@ -1970,12 +2014,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl in ["dpo", "ipo"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
# these aren't used for the ORPO trainer
|
|
||||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["max_target_length"] = None
|
|
||||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class BasePlugin:
|
|||||||
|
|
||||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Adds callbacks to the trainer before training.
|
setup callbacks before creating the trainer.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugin.
|
cfg (dict): The configuration for the plugin.
|
||||||
@@ -155,14 +155,15 @@ class BasePlugin:
|
|||||||
self, cfg, trainer
|
self, cfg, trainer
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Adds callbacks to the trainer after training.
|
Adds callbacks to the trainer after creating the trainer.
|
||||||
|
This is useful for callbacks that require access to the model or trainer.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugin.
|
cfg (dict): The configuration for the plugin.
|
||||||
trainer (object): The trainer object for training.
|
trainer (object): The trainer object for training.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
List[callable]: A list of callback functions to be added
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -393,7 +394,9 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
callbacks = []
|
callbacks = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
|
||||||
|
if plugin_callbacks: # if the plugin returned a list of callbacks
|
||||||
|
callbacks.extend(plugin_callbacks)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
@@ -409,7 +412,9 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
callbacks = []
|
callbacks = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
|
||||||
|
if plugin_callbacks:
|
||||||
|
callbacks.extend(plugin_callbacks)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
def post_train_unload(self, cfg):
|
||||||
|
|||||||
21
src/axolotl/integrations/grokfast/LICENSE
Normal file
21
src/axolotl/integrations/grokfast/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
13
src/axolotl/integrations/grokfast/README.md
Normal file
13
src/axolotl/integrations/grokfast/README.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Grokfast Optimizer
|
||||||
|
|
||||||
|
See https://github.com/ironjr/grokfast
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.grokfast.GrokfastPlugin
|
||||||
|
|
||||||
|
grokfast_alpha: 2.0
|
||||||
|
grokfast_lamb: 0.98
|
||||||
|
```
|
||||||
50
src/axolotl/integrations/grokfast/__init__.py
Normal file
50
src/axolotl/integrations/grokfast/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""
|
||||||
|
Grokfast plugin for Axolotl
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
|
||||||
|
from ..base import BasePlugin
|
||||||
|
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
from .optimizer import gradfilter_ema
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.grokfast")
|
||||||
|
|
||||||
|
|
||||||
|
class GrokfastCallbackHandler(TrainerCallback):
|
||||||
|
"""
|
||||||
|
Transformer trainer callbacks for Grokfast
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
|
||||||
|
super().__init__(*args_, **kwargs)
|
||||||
|
self.grads = None
|
||||||
|
self.alpha = alpha
|
||||||
|
self.lamb = lamb
|
||||||
|
|
||||||
|
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
|
||||||
|
self.grads = None
|
||||||
|
|
||||||
|
def on_pre_optimizer_step(
|
||||||
|
self, args_, state, control, **kwargs
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
model = kwargs.pop("model")
|
||||||
|
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class GrokfastPlugin(BasePlugin):
|
||||||
|
"""
|
||||||
|
Plugin for Grokfast optimizer integraton with Axolotl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_input_args(self):
|
||||||
|
return "axolotl.integrations.grokfast.GrokfastArgs"
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
|
LOG.info("Adding Grokfast callback to the trainer")
|
||||||
|
callback = GrokfastCallbackHandler(
|
||||||
|
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
|
||||||
|
)
|
||||||
|
return [callback]
|
||||||
15
src/axolotl/integrations/grokfast/args.py
Normal file
15
src/axolotl/integrations/grokfast/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
config args for grokfast plugin
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GrokfastArgs(BaseModel):
|
||||||
|
"""
|
||||||
|
Input args for Grokfast optimizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
grokfast_alpha: Optional[float] = 0.98
|
||||||
|
grokfast_lamb: Optional[float] = 2.0
|
||||||
63
src/axolotl/integrations/grokfast/optimizer.py
Normal file
63
src/axolotl/integrations/grokfast/optimizer.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||||
|
# Reference: https://github.com/ironjr/grokfast
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
from collections import deque
|
||||||
|
from typing import Dict, Literal, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def gradfilter_ma(
|
||||||
|
m: nn.Module,
|
||||||
|
grads: Optional[Dict[str, deque]] = None,
|
||||||
|
window_size: int = 100,
|
||||||
|
lamb: float = 5.0,
|
||||||
|
filter_type: Literal["mean", "sum"] = "mean",
|
||||||
|
warmup: bool = True,
|
||||||
|
trigger: bool = False, # For ablation study.
|
||||||
|
) -> Dict[str, deque]:
|
||||||
|
if grads is None:
|
||||||
|
grads = {
|
||||||
|
n: deque(maxlen=window_size)
|
||||||
|
for n, p in m.named_parameters()
|
||||||
|
if p.requires_grad and p.grad is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, p in m.named_parameters():
|
||||||
|
if p.requires_grad and p.grad is not None:
|
||||||
|
grads[n].append(p.grad.data.detach()) # .cpu())
|
||||||
|
|
||||||
|
# Modify the gradients.
|
||||||
|
if not warmup or len(grads[n]) == window_size and not trigger:
|
||||||
|
if filter_type == "mean":
|
||||||
|
avg = sum(grads[n]) / len(grads[n])
|
||||||
|
elif filter_type == "sum":
|
||||||
|
avg = sum(grads[n])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unrecognized filter_type {filter_type}")
|
||||||
|
p.grad.data = p.grad.data + avg * lamb
|
||||||
|
|
||||||
|
return grads
|
||||||
|
|
||||||
|
|
||||||
|
def gradfilter_ema(
|
||||||
|
m: nn.Module,
|
||||||
|
grads: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
alpha: float = 0.98,
|
||||||
|
lamb: float = 2.0,
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
if grads is None:
|
||||||
|
grads = {
|
||||||
|
n: p.grad.data.detach()
|
||||||
|
for n, p in m.named_parameters()
|
||||||
|
if p.requires_grad and p.grad is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, p in m.named_parameters():
|
||||||
|
if p.requires_grad and p.grad is not None:
|
||||||
|
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
|
||||||
|
p.grad.data = p.grad.data + grads[n] * lamb
|
||||||
|
|
||||||
|
return grads
|
||||||
@@ -23,6 +23,7 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||||
@@ -82,7 +83,9 @@ class LigerPlugin(BasePlugin):
|
|||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
|
nn.functional.cross_entropy = liger_cross_entropy
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||||
elif cfg.model_config_type == "deepseek_v2":
|
elif cfg.model_config_type == "deepseek_v2":
|
||||||
@@ -106,6 +109,8 @@ class LigerPlugin(BasePlugin):
|
|||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
|
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||||
|
# nn.CrossEntropyLoss in the forward method.
|
||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""multipack patching for v2 of sample packing"""
|
"""multipack patching for v2 of sample packing"""
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
@@ -19,6 +20,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"falcon",
|
"falcon",
|
||||||
"phi",
|
"phi",
|
||||||
"phi3",
|
"phi3",
|
||||||
|
"phimoe",
|
||||||
"gemma",
|
"gemma",
|
||||||
"gemma2",
|
"gemma2",
|
||||||
"gemmoe",
|
"gemmoe",
|
||||||
@@ -27,71 +29,28 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
|
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
||||||
if model_type == "gemmoe":
|
if has_remote_code:
|
||||||
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
patch_remote(model_name)
|
||||||
elif model_type == "deepseek_v2":
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
|
||||||
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
|
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
|
||||||
patch_mixtral_moe_forward_zero3()
|
|
||||||
return
|
|
||||||
|
|
||||||
# retain for legacy
|
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||||
if model_type == "mixtral":
|
patch_mixtral_moe_forward_zero3()
|
||||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
patch_mixtral_moe_forward_zero3()
|
|
||||||
elif model_type == "llama":
|
|
||||||
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
|
|
||||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "mistral":
|
|
||||||
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
|
|
||||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "qwen2":
|
|
||||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "qwen2_moe":
|
|
||||||
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "falcon":
|
|
||||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "phi":
|
|
||||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "gemma":
|
|
||||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "gemma2":
|
|
||||||
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
elif model_type == "starcoder2":
|
|
||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_remote(model_name, config_name, modeling_name):
|
def patch_remote(model_name):
|
||||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
# we need to load the model here in order for modeling_* to be available
|
# we need to load the model here in order for modeling_* to be available
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||||
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
|
parts = model_config.__class__.__module__.split(".")
|
||||||
|
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
|
||||||
|
module_name = ".".join(parts)
|
||||||
modeling_arch = importlib.import_module(module_name)
|
modeling_arch = importlib.import_module(module_name)
|
||||||
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
|
if hasattr(modeling_arch, "_get_unpad_data"):
|
||||||
|
modeling_arch._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
|||||||
83
src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py
Normal file
83
src/axolotl/monkeypatch/trainer_fsdp_grad_accum.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
fix for FSDP gradient accumulation
|
||||||
|
see https://github.com/huggingface/transformers/pull/34645
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from accelerate.logging import get_logger
|
||||||
|
from transformers.trainer import Trainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||||
|
|
||||||
|
LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")
|
||||||
|
|
||||||
|
ORIGINAL_CONTEXT_CODE = """
|
||||||
|
context = (
|
||||||
|
functools.partial(self.accelerator.no_sync, model=model)
|
||||||
|
if i == len(batch_samples) - 1
|
||||||
|
else contextlib.nullcontext
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATCHED_CONTEXT_CODE = """
|
||||||
|
context = (
|
||||||
|
functools.partial(self.accelerator.no_sync, model=model)
|
||||||
|
if i != len(batch_samples) - 1
|
||||||
|
else contextlib.nullcontext
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_loop_code() -> str:
|
||||||
|
training_loop = inspect.getsource(
|
||||||
|
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
return training_loop
|
||||||
|
|
||||||
|
|
||||||
|
def check_training_loop_is_patchable() -> bool:
|
||||||
|
train_loop = get_training_loop_code()
|
||||||
|
train_loop, _ = detab_code(train_loop)
|
||||||
|
return ORIGINAL_CONTEXT_CODE in train_loop
|
||||||
|
|
||||||
|
|
||||||
|
def patch_training_loop_for_fsdp_grad_accum():
|
||||||
|
"""
|
||||||
|
monkeypatch for fixing the training loop for FSDP gradient accumulation
|
||||||
|
"""
|
||||||
|
|
||||||
|
train_loop = get_training_loop_code()
|
||||||
|
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
train_loop
|
||||||
|
)
|
||||||
|
train_loop, _ = detab_code(train_loop)
|
||||||
|
assert (
|
||||||
|
ORIGINAL_CONTEXT_CODE in train_loop
|
||||||
|
), "Original _inner_training_loop code not found"
|
||||||
|
|
||||||
|
train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
|
||||||
|
train_loop = train_loop.replace(
|
||||||
|
"def _inner_training_loop(",
|
||||||
|
"def _fixed_inner_training_loop(",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load imports necessary
|
||||||
|
import transformers.trainer
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(transformers.trainer):
|
||||||
|
if item in train_loop:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
"from transformers.trainer import ("
|
||||||
|
+ ", ".join(x for x in items_to_import)
|
||||||
|
+ ")",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec(train_loop, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
|
LOG.info("patching _inner_training_loop", main_process_only=True)
|
||||||
|
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||||
|
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
)
|
||||||
@@ -64,10 +64,7 @@ class EvalFirstStepCallback(
|
|||||||
control: TrainerControl,
|
control: TrainerControl,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if (
|
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
|
||||||
and state.global_step == 1
|
|
||||||
):
|
|
||||||
control.should_evaluate = True
|
control.should_evaluate = True
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,8 +1,6 @@
|
|||||||
"""Module for working with config dicts"""
|
"""Module for working with config dicts"""
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -10,7 +8,6 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
|
|
||||||
from axolotl.integrations.config import merge_input_args
|
from axolotl.integrations.config import merge_input_args
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
|
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
)
|
)
|
||||||
@@ -247,370 +244,3 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
|||||||
return DictDefault(
|
return DictDefault(
|
||||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def legacy_validate_config(cfg):
|
|
||||||
"""
|
|
||||||
This is a "pre-validation" step that handles the yaml configuration before we have any
|
|
||||||
information about the model architecture
|
|
||||||
"""
|
|
||||||
if is_torch_bf16_gpu_available():
|
|
||||||
if not cfg.bf16 and not cfg.bfloat16:
|
|
||||||
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
|
||||||
else:
|
|
||||||
if (
|
|
||||||
not cfg.merge_lora
|
|
||||||
and not cfg.is_preprocess
|
|
||||||
and (cfg.bf16 is True or cfg.bfloat16 is True)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
# pylint: disable=too-many-boolean-expressions
|
|
||||||
not (cfg.bf16 or cfg.bfloat16)
|
|
||||||
and (cfg.fp16 or cfg.float16)
|
|
||||||
and not cfg.adapter
|
|
||||||
and not cfg.flash_attention
|
|
||||||
and cfg.sample_packing
|
|
||||||
):
|
|
||||||
LOG.warning(
|
|
||||||
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
|
||||||
)
|
|
||||||
# ValueError: Attempting to unscale FP16 gradients.
|
|
||||||
# OR
|
|
||||||
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
|
||||||
if cfg.max_packed_sequence_len:
|
|
||||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
|
||||||
|
|
||||||
if cfg.sample_packing and cfg.rl:
|
|
||||||
raise ValueError("`sample_packing: true` does not work with RLHF training")
|
|
||||||
|
|
||||||
if cfg.sample_packing and not cfg.pad_to_sequence_len:
|
|
||||||
LOG.warning(
|
|
||||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
|
||||||
raise ValueError(
|
|
||||||
"please set only one of gradient_accumulation_steps or batch_size"
|
|
||||||
)
|
|
||||||
if cfg.batch_size:
|
|
||||||
LOG.warning(
|
|
||||||
"%s\n%s",
|
|
||||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
|
||||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
cfg.eval_batch_size
|
|
||||||
and cfg.micro_batch_size
|
|
||||||
and cfg.eval_batch_size != cfg.micro_batch_size
|
|
||||||
):
|
|
||||||
LOG.warning(
|
|
||||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.adapter == "qlora":
|
|
||||||
if cfg.merge_lora:
|
|
||||||
# can't merge qlora if loaded in 8bit or 4bit
|
|
||||||
if cfg.load_in_8bit:
|
|
||||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
|
||||||
|
|
||||||
if cfg.gptq:
|
|
||||||
raise ValueError("Can't merge qlora if gptq")
|
|
||||||
|
|
||||||
if cfg.load_in_4bit:
|
|
||||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
|
||||||
|
|
||||||
else:
|
|
||||||
if cfg.load_in_8bit:
|
|
||||||
raise ValueError("Can't load qlora in 8bit")
|
|
||||||
|
|
||||||
if cfg.gptq:
|
|
||||||
raise ValueError("Can't load qlora if gptq")
|
|
||||||
|
|
||||||
if not cfg.load_in_4bit:
|
|
||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
|
||||||
|
|
||||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
|
||||||
raise ValueError("Fused modules are not supported with QLoRA")
|
|
||||||
|
|
||||||
loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
|
||||||
if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
|
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
|
||||||
|
|
||||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
|
||||||
raise ValueError("Fused modules are not supported with LoRA")
|
|
||||||
|
|
||||||
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
|
|
||||||
raise ValueError(
|
|
||||||
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.relora_steps:
|
|
||||||
if cfg.adapter not in ("lora", "qlora"):
|
|
||||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
|
||||||
|
|
||||||
if cfg.fsdp:
|
|
||||||
raise ValueError("fsdp not supported with ReLoRA")
|
|
||||||
|
|
||||||
if cfg.deepspeed:
|
|
||||||
raise ValueError("deepspeed not supported with ReLoRA")
|
|
||||||
|
|
||||||
if cfg.lr_scheduler == "one_cycle":
|
|
||||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
|
||||||
|
|
||||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
|
||||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
|
||||||
|
|
||||||
if cfg.trust_remote_code:
|
|
||||||
LOG.warning(
|
|
||||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
|
|
||||||
raise ValueError(
|
|
||||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
|
||||||
raise ValueError("FSDP is not supported for falcon models")
|
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.base_model and "mpt" in cfg.base_model.lower()
|
|
||||||
) and cfg.gradient_checkpointing:
|
|
||||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
|
||||||
|
|
||||||
if cfg.flash_optimum is True:
|
|
||||||
if cfg.adapter:
|
|
||||||
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
|
||||||
if cfg.fp16 or cfg.bf16:
|
|
||||||
raise ValueError("AMP is not supported with BetterTransformer")
|
|
||||||
if cfg.float16 is not True and cfg.bfloat16 is not True:
|
|
||||||
LOG.warning(
|
|
||||||
"You should probably set bfloat16 or float16 to true to "
|
|
||||||
"load the model in float16 for BetterTransformers"
|
|
||||||
)
|
|
||||||
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
|
|
||||||
LOG.warning("torch>=2.0.0 required")
|
|
||||||
raise ValueError(
|
|
||||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
|
||||||
LOG.warning(
|
|
||||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
|
||||||
)
|
|
||||||
if cfg.pretraining_dataset and not cfg.max_steps:
|
|
||||||
raise ValueError(
|
|
||||||
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
|
|
||||||
)
|
|
||||||
|
|
||||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
|
||||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
|
||||||
):
|
|
||||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
|
||||||
|
|
||||||
if cfg.push_to_hub_model_id:
|
|
||||||
raise ValueError(
|
|
||||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
|
||||||
LOG.warning(
|
|
||||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.gptq and cfg.revision_of_model:
|
|
||||||
raise ValueError(
|
|
||||||
"revision_of_model is not supported for GPTQ models. "
|
|
||||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
|
||||||
+ "point to its path, and remove revision_of_model from the config."
|
|
||||||
)
|
|
||||||
|
|
||||||
# if cfg.sample_packing and cfg.sdp_attention:
|
|
||||||
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
|
||||||
# raise ValueError(
|
|
||||||
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
|
||||||
# )
|
|
||||||
|
|
||||||
if cfg.sample_packing and cfg.xformers_attention:
|
|
||||||
raise ValueError(
|
|
||||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
|
||||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
|
||||||
LOG.warning(
|
|
||||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
|
||||||
"This may work on H100s."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.early_stopping_patience:
|
|
||||||
if not cfg.save_steps or not cfg.eval_steps:
|
|
||||||
raise ValueError(
|
|
||||||
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
|
||||||
)
|
|
||||||
if cfg.save_steps % cfg.eval_steps != 0:
|
|
||||||
raise ValueError(
|
|
||||||
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.saves_per_epoch and cfg.save_steps:
|
|
||||||
raise ValueError(
|
|
||||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
|
||||||
)
|
|
||||||
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
|
||||||
raise ValueError(
|
|
||||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
|
||||||
)
|
|
||||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
|
||||||
raise ValueError(
|
|
||||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
|
||||||
)
|
|
||||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
|
||||||
raise ValueError(
|
|
||||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
cfg.evals_per_epoch
|
|
||||||
and cfg.evaluation_strategy
|
|
||||||
and cfg.evaluation_strategy != "steps"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
cfg.evaluation_strategy
|
|
||||||
and cfg.eval_steps
|
|
||||||
and cfg.evaluation_strategy != "steps"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.val_set_size == 0
|
|
||||||
and (cfg.eval_steps or cfg.evaluation_strategy)
|
|
||||||
and not cfg.test_datasets
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.sample_packing
|
|
||||||
and cfg.eval_table_size
|
|
||||||
and cfg.eval_sample_packing is not False
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
|
||||||
)
|
|
||||||
|
|
||||||
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
|
||||||
raise ValueError(
|
|
||||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
|
||||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.rope_scaling:
|
|
||||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
|
||||||
|
|
||||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
|
||||||
cfg.wandb_name = cfg.wandb_run_id
|
|
||||||
|
|
||||||
LOG.warning(
|
|
||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.noisy_embedding_alpha is not None:
|
|
||||||
# Deprecated, use neftune_noise_alpha
|
|
||||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
|
||||||
if cfg.neftune_noise_alpha is None:
|
|
||||||
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
|
||||||
else:
|
|
||||||
# User is providing both; bail and have them sort out their settings
|
|
||||||
raise ValueError(
|
|
||||||
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
|
||||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
|
||||||
|
|
||||||
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.unfrozen_parameters
|
|
||||||
and cfg.gradient_checkpointing_kwargs
|
|
||||||
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
|
|
||||||
):
|
|
||||||
# https://github.com/huggingface/transformers/issues/21381
|
|
||||||
raise ValueError(
|
|
||||||
"`use_reentrant` must be false when used with partially frozen model."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
|
||||||
with open(cfg.deepspeed, encoding="utf-8") as file:
|
|
||||||
contents = file.read()
|
|
||||||
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
|
||||||
if cfg.flash_attention:
|
|
||||||
if (
|
|
||||||
deepspeed_cfg.zero_optimization
|
|
||||||
and deepspeed_cfg.zero_optimization.stage == 3
|
|
||||||
):
|
|
||||||
if not (
|
|
||||||
(
|
|
||||||
deepspeed_cfg.bf16
|
|
||||||
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
or (
|
|
||||||
deepspeed_cfg.fp16
|
|
||||||
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
|
||||||
is True
|
|
||||||
)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
|
||||||
)
|
|
||||||
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
|
|
||||||
LOG.warning(
|
|
||||||
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.test_datasets and cfg.val_set_size:
|
|
||||||
raise ValueError(
|
|
||||||
"non-zero val_set_size should not be used with test_datasets configuration"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.fsdp and "bnb" in cfg.optimizer:
|
|
||||||
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
|
||||||
|
|
||||||
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
|
|
||||||
raise ValueError(
|
|
||||||
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.eval_causal_lm_metrics:
|
|
||||||
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
|
||||||
raise ValueError("eval_causal_lm_metrics must be a list")
|
|
||||||
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
|
||||||
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
|
|
||||||
raise ValueError(
|
|
||||||
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO
|
|
||||||
# MPT 7b
|
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
||||||
# no 8bit adaAmw w bf16
|
|
||||||
|
|
||||||
# GPT-NeoX
|
|
||||||
# evals broken when extending context len
|
|
||||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
|
||||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
|
||||||
# attention_mask = causal_mask + attention_mask
|
|
||||||
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ class ChatTemplate(str, Enum):
|
|||||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||||
exaone = "exaone" # pylint: disable=invalid-name
|
exaone = "exaone" # pylint: disable=invalid-name
|
||||||
|
metharme = "metharme" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedParameters(BaseModel):
|
class DeprecatedParameters(BaseModel):
|
||||||
@@ -67,6 +68,7 @@ class DeprecatedParameters(BaseModel):
|
|||||||
rope_scaling: Optional[Any] = None
|
rope_scaling: Optional[Any] = None
|
||||||
noisy_embedding_alpha: Optional[float] = None
|
noisy_embedding_alpha: Optional[float] = None
|
||||||
dpo_beta: Optional[float] = None
|
dpo_beta: Optional[float] = None
|
||||||
|
evaluation_strategy: Optional[str] = None
|
||||||
|
|
||||||
@field_validator("max_packed_sequence_len")
|
@field_validator("max_packed_sequence_len")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -98,6 +100,13 @@ class DeprecatedParameters(BaseModel):
|
|||||||
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
||||||
return dpo_beta
|
return dpo_beta
|
||||||
|
|
||||||
|
@field_validator("evaluation_strategy")
|
||||||
|
@classmethod
|
||||||
|
def validate_evaluation_strategy(cls, evaluation_strategy):
|
||||||
|
if evaluation_strategy is not None:
|
||||||
|
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
||||||
|
return evaluation_strategy
|
||||||
|
|
||||||
|
|
||||||
class RemappedParameters(BaseModel):
|
class RemappedParameters(BaseModel):
|
||||||
"""parameters that have been remapped to other names"""
|
"""parameters that have been remapped to other names"""
|
||||||
@@ -427,6 +436,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"adopt_adamw",
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
] = OptimizerNames.ADAMW_HF.value
|
] = OptimizerNames.ADAMW_HF.value
|
||||||
@@ -729,7 +739,7 @@ class AxolotlInputConfig(
|
|||||||
warmup_ratio: Optional[float] = None
|
warmup_ratio: Optional[float] = None
|
||||||
eval_steps: Optional[Union[int, float]] = None
|
eval_steps: Optional[Union[int, float]] = None
|
||||||
evals_per_epoch: Optional[Union[int]] = None
|
evals_per_epoch: Optional[Union[int]] = None
|
||||||
evaluation_strategy: Optional[str] = None
|
eval_strategy: Optional[str] = None
|
||||||
save_steps: Optional[Union[int, float]] = None
|
save_steps: Optional[Union[int, float]] = None
|
||||||
saves_per_epoch: Optional[int] = None
|
saves_per_epoch: Optional[int] = None
|
||||||
save_strategy: Optional[str] = None
|
save_strategy: Optional[str] = None
|
||||||
@@ -781,6 +791,8 @@ class AxolotlInputConfig(
|
|||||||
is_mistral_derived_model: Optional[bool] = Field(default=None)
|
is_mistral_derived_model: Optional[bool] = Field(default=None)
|
||||||
is_qwen_derived_model: Optional[bool] = Field(default=None)
|
is_qwen_derived_model: Optional[bool] = Field(default=None)
|
||||||
|
|
||||||
|
plugins: Optional[List[str]] = Field(default=None)
|
||||||
|
|
||||||
@field_validator("datasets", mode="before")
|
@field_validator("datasets", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def deprecate_sharegpt_datasets(cls, datasets):
|
def deprecate_sharegpt_datasets(cls, datasets):
|
||||||
@@ -788,7 +800,12 @@ class AxolotlInputConfig(
|
|||||||
if not ds_cfg.get("type"):
|
if not ds_cfg.get("type"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if ds_cfg["type"].startswith("sharegpt"):
|
ds_type = ds_cfg["type"]
|
||||||
|
# skip if it's a dict (for custom user instruction prompt)
|
||||||
|
if isinstance(ds_type, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
|
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
|
||||||
)
|
)
|
||||||
@@ -1024,21 +1041,21 @@ class AxolotlInputConfig(
|
|||||||
@classmethod
|
@classmethod
|
||||||
def check_evals(cls, data):
|
def check_evals(cls, data):
|
||||||
if (
|
if (
|
||||||
data.get("evaluation_strategy")
|
data.get("eval_strategy")
|
||||||
and data.get("eval_steps")
|
and data.get("eval_steps")
|
||||||
and data.get("evaluation_strategy") != "steps"
|
and data.get("eval_strategy") != "steps"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
data.get("val_set_size") == 0
|
data.get("val_set_size") == 0
|
||||||
and (data.get("eval_steps") or data.get("evaluation_strategy"))
|
and (data.get("eval_steps") or data.get("eval_strategy"))
|
||||||
and not data.get("test_datasets")
|
and not data.get("test_datasets")
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
||||||
)
|
)
|
||||||
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1046,11 +1063,11 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
data.get("evals_per_epoch")
|
data.get("evals_per_epoch")
|
||||||
and data.get("evaluation_strategy")
|
and data.get("eval_strategy")
|
||||||
and data.get("evaluation_strategy") != "steps"
|
and data.get("eval_strategy") != "steps"
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
)
|
)
|
||||||
|
|
||||||
if data.get("do_bench_eval") and not (
|
if data.get("do_bench_eval") and not (
|
||||||
@@ -1282,6 +1299,25 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def warn_qlora_zero3_w_use_reentrant(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("adapter") == "qlora"
|
||||||
|
and data.get("gradient_checkpointing_kwargs", {})
|
||||||
|
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
||||||
|
is False
|
||||||
|
and "zero3" in data.get("deepspeed", "")
|
||||||
|
):
|
||||||
|
# may result in:
|
||||||
|
# torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:
|
||||||
|
# Recomputed values for the following tensors have different metadata
|
||||||
|
# than during the forward pass.
|
||||||
|
LOG.warning(
|
||||||
|
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_val_w_test_datasets(cls, data):
|
def check_val_w_test_datasets(cls, data):
|
||||||
@@ -1291,6 +1327,19 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_eval_strategy(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("evaluation_strategy") is not None
|
||||||
|
and data.get("eval_strategy") is None
|
||||||
|
):
|
||||||
|
LOG.info(
|
||||||
|
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
|
||||||
|
)
|
||||||
|
data["eval_strategy"] = data.get("evaluation_strategy")
|
||||||
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
||||||
|
|||||||
@@ -260,6 +260,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||||
ds: Optional[Union[Dataset, DatasetDict]] = None
|
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
|
ds_trust_remote_code = config_dataset.trust_remote_code
|
||||||
try:
|
try:
|
||||||
# this is just a basic check to see if the path is a
|
# this is just a basic check to see if the path is a
|
||||||
# valid HF dataset that's loadable
|
# valid HF dataset that's loadable
|
||||||
@@ -269,6 +270,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=True,
|
streaming=True,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
revision=config_dataset.revision,
|
||||||
|
trust_remote_code=ds_trust_remote_code,
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||||
@@ -348,7 +350,15 @@ def load_tokenized_prepared_datasets(
|
|||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ds = load_from_disk(config_dataset.path)
|
try:
|
||||||
|
ds = load_from_disk(config_dataset.path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
ds = load_dataset(
|
||||||
|
config_dataset.path,
|
||||||
|
name=config_dataset.name,
|
||||||
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = get_ds_type(config_dataset)
|
ds_type = get_ds_type(config_dataset)
|
||||||
|
|
||||||
@@ -366,7 +376,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
elif ds_from_hub:
|
elif ds_from_hub:
|
||||||
load_ds_kwargs = {}
|
load_ds_kwargs = {}
|
||||||
if config_dataset.split:
|
if config_dataset.split:
|
||||||
load_ds_kwargs = {"split": config_dataset.split}
|
load_ds_kwargs["split"] = config_dataset.split
|
||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
config_dataset.path,
|
config_dataset.path,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
@@ -374,6 +384,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
revision=config_dataset.revision,
|
revision=config_dataset.revision,
|
||||||
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
**load_ds_kwargs,
|
**load_ds_kwargs,
|
||||||
)
|
)
|
||||||
elif ds_from_cloud and remote_file_system:
|
elif ds_from_cloud and remote_file_system:
|
||||||
@@ -391,6 +402,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
)
|
)
|
||||||
elif config_dataset.path.startswith("https://"):
|
elif config_dataset.path.startswith("https://"):
|
||||||
ds_type = get_ds_type(config_dataset)
|
ds_type = get_ds_type(config_dataset)
|
||||||
@@ -401,6 +413,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
streaming=False,
|
streaming=False,
|
||||||
split=None,
|
split=None,
|
||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if isinstance(config_dataset.data_files, str):
|
if isinstance(config_dataset.data_files, str):
|
||||||
|
|||||||
25
src/axolotl/utils/environment.py
Normal file
25
src/axolotl/utils/environment.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""
|
||||||
|
utils to get GPU info for the current environment
|
||||||
|
"""
|
||||||
|
from accelerate.utils.environment import (
|
||||||
|
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
||||||
|
)
|
||||||
|
from accelerate.utils.environment import get_gpu_info
|
||||||
|
|
||||||
|
|
||||||
|
def check_cuda_p2p_ib_support():
|
||||||
|
if not accelerate_check_cuda_p2p_ib_support():
|
||||||
|
return False
|
||||||
|
unsupported_devices = {"RTX 6000 Ada"}
|
||||||
|
try:
|
||||||
|
device_names, device_count = get_gpu_info()
|
||||||
|
if 1 < device_count < 8:
|
||||||
|
if any(
|
||||||
|
unsupported_device in device_name
|
||||||
|
for device_name in device_names
|
||||||
|
for unsupported_device in unsupported_devices
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
except Exception: # pylint: disable=broad-except # nosec
|
||||||
|
pass
|
||||||
|
return True
|
||||||
@@ -14,6 +14,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
torch_version = version.parse(torch.__version__)
|
||||||
|
|
||||||
|
if torch_version < version.parse("2.4.0"):
|
||||||
|
torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
||||||
|
torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||||
|
else:
|
||||||
|
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||||
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
|
||||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||||
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.cuda.amp.custom_fwd
|
@torch_cuda_amp_custom_fwd
|
||||||
def forward(ctx, forward_function, hidden_states, *args):
|
def forward(ctx, forward_function, hidden_states, *args):
|
||||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -36,7 +46,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.cuda.amp.custom_bwd
|
@torch_cuda_amp_custom_bwd
|
||||||
def backward(ctx, dY):
|
def backward(ctx, dY):
|
||||||
(hidden_states,) = ctx.saved_tensors
|
(hidden_states,) = ctx.saved_tensors
|
||||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||||
|
|||||||
@@ -238,6 +238,7 @@ def load_tokenizer(cfg):
|
|||||||
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
and k != "pad_token"
|
||||||
):
|
):
|
||||||
lora_modules_to_save = ", ".join(
|
lora_modules_to_save = ", ".join(
|
||||||
[f"`{x}`" for x in lora_modules_to_save]
|
[f"`{x}`" for x in lora_modules_to_save]
|
||||||
@@ -394,10 +395,17 @@ class ModelLoader:
|
|||||||
and self.cfg.flash_attention
|
and self.cfg.flash_attention
|
||||||
and self.cfg.sample_packing
|
and self.cfg.sample_packing
|
||||||
):
|
):
|
||||||
|
has_remote_code = (
|
||||||
|
"auto_map" in self.model_config
|
||||||
|
and "AutoModelForCausalLM" in self.model_config["auto_map"]
|
||||||
|
)
|
||||||
|
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||||
|
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
||||||
|
has_remote_code = self.cfg.trust_remote_code
|
||||||
patch_for_multipack(
|
patch_for_multipack(
|
||||||
self.cfg.model_config_type,
|
self.cfg.model_config_type,
|
||||||
model_name=self.cfg.base_model,
|
model_name=self.cfg.base_model,
|
||||||
is_remote_code=self.cfg.trust_remote_code,
|
has_remote_code=has_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.cfg.is_llama_derived_model:
|
if self.cfg.is_llama_derived_model:
|
||||||
|
|||||||
508
src/axolotl/utils/optimizers/adopt.py
Normal file
508
src/axolotl/utils/optimizers/adopt.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
"""
|
||||||
|
Copied from https://github.com/iShohei220/adopt
|
||||||
|
|
||||||
|
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
|
||||||
|
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
|
||||||
|
"""
|
||||||
|
# mypy: ignore-errors
|
||||||
|
# pylint: skip-file
|
||||||
|
# mypy: allow-untyped-decorators
|
||||||
|
# mypy: allow-untyped-defs
|
||||||
|
from typing import List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.optim.optimizer import (
|
||||||
|
Optimizer,
|
||||||
|
ParamsT,
|
||||||
|
_default_to_fused_or_foreach,
|
||||||
|
_device_dtype_check_for_fused,
|
||||||
|
_disable_dynamo_if_unsupported,
|
||||||
|
_get_capturable_supported_devices,
|
||||||
|
_get_scalar_dtype,
|
||||||
|
_get_value,
|
||||||
|
_use_grad_for_differentiable,
|
||||||
|
_view_as_real,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["ADOPT", "adopt"]
|
||||||
|
|
||||||
|
|
||||||
|
class ADOPT(Optimizer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: ParamsT,
|
||||||
|
lr: Union[float, Tensor] = 1e-3,
|
||||||
|
betas: Tuple[float, float] = (0.9, 0.9999),
|
||||||
|
eps: float = 1e-6,
|
||||||
|
weight_decay: float = 0.0,
|
||||||
|
decoupled: bool = False,
|
||||||
|
*,
|
||||||
|
foreach: Optional[bool] = None,
|
||||||
|
maximize: bool = False,
|
||||||
|
capturable: bool = False,
|
||||||
|
differentiable: bool = False,
|
||||||
|
fused: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
if isinstance(lr, Tensor):
|
||||||
|
if foreach and not capturable:
|
||||||
|
raise ValueError(
|
||||||
|
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||||
|
)
|
||||||
|
if lr.numel() != 1:
|
||||||
|
raise ValueError("Tensor lr must be 1-element")
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError(f"Invalid learning rate: {lr}")
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError(f"Invalid epsilon value: {eps}")
|
||||||
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
|
||||||
|
if not 0.0 <= betas[1] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
||||||
|
if not 0.0 <= weight_decay:
|
||||||
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||||
|
|
||||||
|
defaults = dict(
|
||||||
|
lr=lr,
|
||||||
|
betas=betas,
|
||||||
|
eps=eps,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
decoupled=decoupled,
|
||||||
|
maximize=maximize,
|
||||||
|
foreach=foreach,
|
||||||
|
capturable=capturable,
|
||||||
|
differentiable=differentiable,
|
||||||
|
fused=fused,
|
||||||
|
)
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
if fused:
|
||||||
|
# TODO: support fused
|
||||||
|
raise RuntimeError("`fused` is not currently supported")
|
||||||
|
|
||||||
|
if differentiable:
|
||||||
|
raise RuntimeError("`fused` does not support `differentiable`")
|
||||||
|
self._step_supports_amp_scaling = True
|
||||||
|
# TODO(crcrpar): [low prec params & their higher prec copy]
|
||||||
|
# Support AMP with FP16/BF16 model params which would need
|
||||||
|
# higher prec copy of params to do update math in higher prec to
|
||||||
|
# alleviate the loss of information.
|
||||||
|
if foreach:
|
||||||
|
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super().__setstate__(state)
|
||||||
|
for group in self.param_groups:
|
||||||
|
group.setdefault("maximize", False)
|
||||||
|
group.setdefault("foreach", None)
|
||||||
|
group.setdefault("capturable", False)
|
||||||
|
group.setdefault("differentiable", False)
|
||||||
|
fused = group.setdefault("fused", None)
|
||||||
|
for p in group["params"]:
|
||||||
|
p_state = self.state.get(p, [])
|
||||||
|
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
|
||||||
|
step_val = float(p_state["step"])
|
||||||
|
p_state["step"] = (
|
||||||
|
torch.tensor(
|
||||||
|
step_val,
|
||||||
|
dtype=_get_scalar_dtype(is_fused=fused),
|
||||||
|
device=p.device,
|
||||||
|
)
|
||||||
|
if group["capturable"] or group["fused"]
|
||||||
|
else torch.tensor(step_val, dtype=_get_scalar_dtype())
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_group(
|
||||||
|
self,
|
||||||
|
group,
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
):
|
||||||
|
has_complex = False
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is not None:
|
||||||
|
has_complex |= torch.is_complex(p)
|
||||||
|
params_with_grad.append(p)
|
||||||
|
if p.grad.is_sparse:
|
||||||
|
raise RuntimeError("ADOPT does not support sparse gradients")
|
||||||
|
grads.append(p.grad)
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
# Lazy state initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
if group["fused"]:
|
||||||
|
_device_dtype_check_for_fused(p)
|
||||||
|
# note(crcrpar): [special device hosting for step]
|
||||||
|
# Deliberately host `step` on CPU if both capturable and fused are off.
|
||||||
|
# This is because kernel launches are costly on CUDA and XLA.
|
||||||
|
state["step"] = (
|
||||||
|
torch.zeros(
|
||||||
|
(),
|
||||||
|
dtype=_get_scalar_dtype(is_fused=group["fused"]),
|
||||||
|
device=p.device,
|
||||||
|
)
|
||||||
|
if group["capturable"] or group["fused"]
|
||||||
|
else torch.tensor(0.0, dtype=_get_scalar_dtype())
|
||||||
|
)
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state["exp_avg"] = torch.zeros_like(
|
||||||
|
p, memory_format=torch.preserve_format
|
||||||
|
)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state["exp_avg_sq"] = torch.zeros_like(
|
||||||
|
p, memory_format=torch.preserve_format
|
||||||
|
)
|
||||||
|
|
||||||
|
exp_avgs.append(state["exp_avg"])
|
||||||
|
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||||
|
|
||||||
|
if group["differentiable"] and state["step"].requires_grad:
|
||||||
|
raise RuntimeError(
|
||||||
|
"`requires_grad` is not supported for `step` in differentiable mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Foreach without capturable does not support a tensor lr
|
||||||
|
if (
|
||||||
|
group["foreach"]
|
||||||
|
and torch.is_tensor(group["lr"])
|
||||||
|
and not group["capturable"]
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||||
|
)
|
||||||
|
|
||||||
|
state_steps.append(state["step"])
|
||||||
|
return has_complex
|
||||||
|
|
||||||
|
@_use_grad_for_differentiable
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Perform a single optimization step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
closure (Callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
self._cuda_graph_capture_health_check()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
params_with_grad: List[Tensor] = []
|
||||||
|
grads: List[Tensor] = []
|
||||||
|
exp_avgs: List[Tensor] = []
|
||||||
|
exp_avg_sqs: List[Tensor] = []
|
||||||
|
state_steps: List[Tensor] = []
|
||||||
|
beta1, beta2 = group["betas"]
|
||||||
|
|
||||||
|
has_complex = self._init_group(
|
||||||
|
group,
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
adopt(
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
has_complex=has_complex,
|
||||||
|
beta1=beta1,
|
||||||
|
beta2=beta2,
|
||||||
|
lr=group["lr"],
|
||||||
|
weight_decay=group["weight_decay"],
|
||||||
|
decoupled=group["decoupled"],
|
||||||
|
eps=group["eps"],
|
||||||
|
maximize=group["maximize"],
|
||||||
|
foreach=group["foreach"],
|
||||||
|
capturable=group["capturable"],
|
||||||
|
differentiable=group["differentiable"],
|
||||||
|
fused=group["fused"],
|
||||||
|
grad_scale=getattr(self, "grad_scale", None),
|
||||||
|
found_inf=getattr(self, "found_inf", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def _single_tensor_adopt(
|
||||||
|
params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[Tensor],
|
||||||
|
grad_scale: Optional[Tensor],
|
||||||
|
found_inf: Optional[Tensor],
|
||||||
|
*,
|
||||||
|
has_complex: bool,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: Union[float, Tensor],
|
||||||
|
weight_decay: float,
|
||||||
|
decoupled: bool,
|
||||||
|
eps: float,
|
||||||
|
maximize: bool,
|
||||||
|
capturable: bool,
|
||||||
|
differentiable: bool,
|
||||||
|
):
|
||||||
|
assert grad_scale is None and found_inf is None
|
||||||
|
|
||||||
|
if torch.jit.is_scripting():
|
||||||
|
# this assert is due to JIT being dumb and not realizing that the ops below
|
||||||
|
# have overloads to handle both float and Tensor lrs, so we just assert it's
|
||||||
|
# a float since most people using JIT are using floats
|
||||||
|
assert isinstance(lr, float)
|
||||||
|
|
||||||
|
for i, param in enumerate(params):
|
||||||
|
grad = grads[i] if not maximize else -grads[i]
|
||||||
|
exp_avg = exp_avgs[i]
|
||||||
|
exp_avg_sq = exp_avg_sqs[i]
|
||||||
|
step_t = state_steps[i]
|
||||||
|
|
||||||
|
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||||
|
if not torch._utils.is_compiling() and capturable:
|
||||||
|
capturable_supported_devices = _get_capturable_supported_devices()
|
||||||
|
assert (
|
||||||
|
param.device.type == step_t.device.type
|
||||||
|
and param.device.type in capturable_supported_devices
|
||||||
|
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||||
|
|
||||||
|
# update step
|
||||||
|
step_t += 1
|
||||||
|
|
||||||
|
if weight_decay != 0:
|
||||||
|
if decoupled:
|
||||||
|
param.add_(param, alpha=-lr * weight_decay)
|
||||||
|
else:
|
||||||
|
grad = grad.add(param, alpha=weight_decay)
|
||||||
|
|
||||||
|
if torch.is_complex(param):
|
||||||
|
grad = torch.view_as_real(grad)
|
||||||
|
if exp_avg is not None:
|
||||||
|
exp_avg = torch.view_as_real(exp_avg)
|
||||||
|
if exp_avg_sq is not None:
|
||||||
|
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||||
|
param = torch.view_as_real(param)
|
||||||
|
|
||||||
|
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||||
|
if step == 1:
|
||||||
|
exp_avg_sq.addcmul_(grad, grad.conj())
|
||||||
|
continue
|
||||||
|
|
||||||
|
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||||
|
if step == 2:
|
||||||
|
exp_avg.addcdiv_(grad, denom)
|
||||||
|
else:
|
||||||
|
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
|
||||||
|
|
||||||
|
param.add_(exp_avg, alpha=-lr)
|
||||||
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||||
|
|
||||||
|
|
||||||
|
def _multi_tensor_adopt(
|
||||||
|
params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[Tensor],
|
||||||
|
grad_scale: Optional[Tensor],
|
||||||
|
found_inf: Optional[Tensor],
|
||||||
|
*,
|
||||||
|
has_complex: bool,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: Union[float, Tensor],
|
||||||
|
weight_decay: float,
|
||||||
|
decoupled: bool,
|
||||||
|
eps: float,
|
||||||
|
maximize: bool,
|
||||||
|
capturable: bool,
|
||||||
|
differentiable: bool,
|
||||||
|
):
|
||||||
|
if len(params) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(lr, Tensor) and not capturable:
|
||||||
|
raise RuntimeError(
|
||||||
|
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||||
|
if not torch._utils.is_compiling() and capturable:
|
||||||
|
capturable_supported_devices = _get_capturable_supported_devices(
|
||||||
|
supports_xla=False
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
p.device.type == step.device.type
|
||||||
|
and p.device.type in capturable_supported_devices
|
||||||
|
for p, step in zip(params, state_steps)
|
||||||
|
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||||
|
|
||||||
|
assert grad_scale is None and found_inf is None
|
||||||
|
|
||||||
|
assert not differentiable, "_foreach ops don't support autograd"
|
||||||
|
|
||||||
|
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
|
||||||
|
[params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
|
||||||
|
)
|
||||||
|
for (
|
||||||
|
device_params_,
|
||||||
|
device_grads_,
|
||||||
|
device_exp_avgs_,
|
||||||
|
device_exp_avg_sqs_,
|
||||||
|
device_state_steps_,
|
||||||
|
), _ in grouped_tensors.values():
|
||||||
|
device_params = cast(List[Tensor], device_params_)
|
||||||
|
device_grads = cast(List[Tensor], device_grads_)
|
||||||
|
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
|
||||||
|
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
|
||||||
|
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||||
|
|
||||||
|
# Handle complex parameters
|
||||||
|
if has_complex:
|
||||||
|
_view_as_real(
|
||||||
|
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
|
||||||
|
)
|
||||||
|
|
||||||
|
if maximize:
|
||||||
|
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Update steps
|
||||||
|
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||||
|
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||||
|
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||||
|
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||||
|
torch._foreach_add_(
|
||||||
|
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch._foreach_add_(device_state_steps, 1)
|
||||||
|
|
||||||
|
if weight_decay != 0:
|
||||||
|
if decoupled:
|
||||||
|
torch._foreach_add_(
|
||||||
|
device_params, device_params, alpha=-lr * weight_decay
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||||
|
if maximize:
|
||||||
|
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||||
|
else:
|
||||||
|
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||||
|
device_grads, device_params, alpha=weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_state_steps[0] == 1:
|
||||||
|
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||||
|
continue
|
||||||
|
|
||||||
|
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||||
|
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
|
||||||
|
|
||||||
|
if device_state_steps[0] == 2:
|
||||||
|
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
|
||||||
|
else:
|
||||||
|
torch._foreach_mul_(device_exp_avgs, beta1)
|
||||||
|
torch._foreach_addcdiv_(
|
||||||
|
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
|
||||||
|
)
|
||||||
|
|
||||||
|
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||||
|
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||||
|
torch._foreach_addcmul_(
|
||||||
|
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
||||||
|
def adopt(
|
||||||
|
params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[Tensor],
|
||||||
|
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||||
|
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||||
|
foreach: Optional[bool] = None,
|
||||||
|
capturable: bool = False,
|
||||||
|
differentiable: bool = False,
|
||||||
|
fused: Optional[bool] = None,
|
||||||
|
grad_scale: Optional[Tensor] = None,
|
||||||
|
found_inf: Optional[Tensor] = None,
|
||||||
|
has_complex: bool = False,
|
||||||
|
*,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: Union[float, Tensor],
|
||||||
|
weight_decay: float,
|
||||||
|
decoupled: bool,
|
||||||
|
eps: float,
|
||||||
|
maximize: bool,
|
||||||
|
):
|
||||||
|
r"""Functional API that performs ADOPT algorithm computation."""
|
||||||
|
# Respect when the user inputs False/True for foreach or fused. We only want to change
|
||||||
|
# the default when neither have been user-specified. Note that we default to foreach
|
||||||
|
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
|
||||||
|
# bake-in time before making it the default, even if it is typically faster.
|
||||||
|
if fused is None and foreach is None:
|
||||||
|
_, foreach = _default_to_fused_or_foreach(
|
||||||
|
params, differentiable, use_fused=False
|
||||||
|
)
|
||||||
|
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
|
||||||
|
if foreach and isinstance(lr, Tensor) and not capturable:
|
||||||
|
foreach = False
|
||||||
|
if fused is None:
|
||||||
|
fused = False
|
||||||
|
if foreach is None:
|
||||||
|
foreach = False
|
||||||
|
|
||||||
|
# this check is slow during compilation, so we skip it
|
||||||
|
# if it's strictly needed we can add this check back in dynamo
|
||||||
|
if not torch._utils.is_compiling() and not all(
|
||||||
|
isinstance(t, torch.Tensor) for t in state_steps
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
"API has changed, `state_steps` argument must contain a list of singleton tensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
if foreach and torch.jit.is_scripting():
|
||||||
|
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
|
||||||
|
if fused and torch.jit.is_scripting():
|
||||||
|
raise RuntimeError("torch.jit.script not supported with fused optimizers")
|
||||||
|
|
||||||
|
# if fused and not torch.jit.is_scripting():
|
||||||
|
# func = _fused_adopt
|
||||||
|
# elif foreach and not torch.jit.is_scripting():
|
||||||
|
if foreach and not torch.jit.is_scripting():
|
||||||
|
func = _multi_tensor_adopt
|
||||||
|
else:
|
||||||
|
func = _single_tensor_adopt
|
||||||
|
|
||||||
|
func(
|
||||||
|
params,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
has_complex=has_complex,
|
||||||
|
beta1=beta1,
|
||||||
|
beta2=beta2,
|
||||||
|
lr=lr,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
decoupled=decoupled,
|
||||||
|
eps=eps,
|
||||||
|
maximize=maximize,
|
||||||
|
capturable=capturable,
|
||||||
|
differentiable=differentiable,
|
||||||
|
grad_scale=grad_scale,
|
||||||
|
found_inf=found_inf,
|
||||||
|
)
|
||||||
@@ -16,7 +16,11 @@ from torch.utils.data import DataLoader, RandomSampler
|
|||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
|
||||||
|
patch_training_loop_for_fsdp_grad_accum,
|
||||||
|
)
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
LOG = get_logger("axolotl")
|
LOG = get_logger("axolotl")
|
||||||
@@ -184,11 +188,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
min_sequence_len=cfg.min_sample_len or 2,
|
min_sequence_len=cfg.min_sample_len or 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.is_preprocess:
|
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
|
||||||
|
|
||||||
if cfg.model_config_type == "mamba":
|
if cfg.model_config_type == "mamba":
|
||||||
LOG.info("dropping attention_mask column")
|
LOG.info("dropping attention_mask column")
|
||||||
@@ -461,6 +464,9 @@ def setup_fsdp_envs(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_optim_env(cfg):
|
def prepare_optim_env(cfg):
|
||||||
|
if not check_cuda_p2p_ib_support():
|
||||||
|
if os.getenv("NCCL_P2P_DISABLE") is None:
|
||||||
|
os.environ["NCCL_P2P_DISABLE"] = "1"
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
@@ -490,6 +496,11 @@ def prepare_opinionated_env(cfg):
|
|||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||||
):
|
):
|
||||||
|
if cfg.fsdp:
|
||||||
|
try:
|
||||||
|
patch_training_loop_for_fsdp_grad_accum()
|
||||||
|
except AssertionError:
|
||||||
|
pass
|
||||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
|
|||||||
16
tests/e2e/conftest.py
Normal file
16
tests/e2e/conftest.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""
|
||||||
|
shared pytest fixtures
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
# Create a temporary directory
|
||||||
|
_temp_dir = tempfile.mkdtemp()
|
||||||
|
yield _temp_dir
|
||||||
|
# Clean up the directory after the test
|
||||||
|
shutil.rmtree(_temp_dir)
|
||||||
@@ -3,28 +3,25 @@ E2E tests for multigpu eval
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPUEval(unittest.TestCase):
|
class TestMultiGPUEval:
|
||||||
"""
|
"""
|
||||||
Test case for MultiGPU Eval Sample Packing
|
Test case for MultiGPU Eval Sample Packing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_eval_sample_packing(self, temp_dir):
|
def test_eval_sample_packing(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -83,13 +80,14 @@ class TestMultiGPUEval(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_eval(self, temp_dir):
|
def test_eval(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -148,6 +146,8 @@ class TestMultiGPUEval(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
|||||||
@@ -4,17 +4,17 @@ E2E tests for multigpu lora tinyllama
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import is_hopper, with_temp_dir
|
from ..utils import is_hopper
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -28,18 +28,16 @@ def download_model():
|
|||||||
snapshot_download("TinyLlama/TinyLlama_v1.1")
|
snapshot_download("TinyLlama/TinyLlama_v1.1")
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPULlama(unittest.TestCase):
|
class TestMultiGPULlama:
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using LoRA
|
Test case for Llama models using LoRA
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_lora_ddp(self, temp_dir):
|
def test_lora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
@@ -48,9 +46,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -81,19 +77,23 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_lora_ddp_packed(self, temp_dir):
|
"gradient_accumulation_steps",
|
||||||
|
[1, 4],
|
||||||
|
)
|
||||||
|
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
"eval_sample_packing": False,
|
||||||
@@ -105,9 +105,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -118,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 15,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -138,6 +136,8 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
@@ -145,7 +145,6 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
||||||
@with_temp_dir
|
|
||||||
def test_dpo_lora_ddp(self, temp_dir):
|
def test_dpo_lora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -210,13 +209,14 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_dpo_qlora_ddp(self, temp_dir):
|
def test_dpo_qlora_ddp(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -278,25 +278,27 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_fsdp(self, temp_dir):
|
"gradient_accumulation_steps",
|
||||||
|
[1, 4],
|
||||||
|
)
|
||||||
|
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.01,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -305,9 +307,9 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 10,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
@@ -324,7 +326,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"fsdp_use_orig_params": False,
|
"fsdp_use_orig_params": False,
|
||||||
"fsdp_cpu_ram_efficient_loading": False,
|
"fsdp_cpu_ram_efficient_loading": False,
|
||||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -341,28 +343,29 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_fsdp_packed(self, temp_dir):
|
"fsdp_state_dict_type",
|
||||||
|
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
||||||
|
)
|
||||||
|
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -390,7 +393,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"fsdp_use_orig_params": False,
|
"fsdp_use_orig_params": False,
|
||||||
"fsdp_cpu_ram_efficient_loading": False,
|
"fsdp_cpu_ram_efficient_loading": False,
|
||||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
"fsdp_state_dict_type": fsdp_state_dict_type,
|
||||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -407,13 +410,14 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -483,28 +487,29 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize(
|
||||||
def test_ds_zero3_packed(self, temp_dir):
|
"gradient_accumulation_steps",
|
||||||
|
[1, 4],
|
||||||
|
)
|
||||||
|
def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"eval_sample_packing": False,
|
|
||||||
"pad_to_sequence_len": True,
|
"pad_to_sequence_len": True,
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -515,7 +520,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 15,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 4,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
@@ -536,19 +541,19 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@with_temp_dir
|
|
||||||
def test_ds_zero3_qlora_packed(self, temp_dir):
|
def test_ds_zero3_qlora_packed(self, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
"lora_r": 8,
|
"lora_r": 8,
|
||||||
@@ -561,9 +566,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"val_set_size": 0.05,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {
|
"special_tokens": {
|
||||||
"unk_token": "<unk>",
|
"pad_token": "<|endoftext|>",
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -595,6 +598,8 @@ class TestMultiGPULlama(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
|||||||
@@ -4,31 +4,30 @@ E2E tests for multigpu qwen2
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from accelerate.test_utils import execute_subprocess_async
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from ..utils import with_temp_dir
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
class TestMultiGPUQwen2(unittest.TestCase):
|
class TestMultiGPUQwen2:
|
||||||
"""
|
"""
|
||||||
Test case for Llama models using LoRA
|
Test case for Llama models using LoRA
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@with_temp_dir
|
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
||||||
def test_qlora_fsdp_dpo(self, temp_dir):
|
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"base_model": "Qwen/Qwen2-1.5B",
|
"base_model": base_model,
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"rl": "dpo",
|
"rl": "dpo",
|
||||||
"chat_template": "chatml",
|
"chat_template": "chatml",
|
||||||
@@ -47,9 +46,9 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
"max_steps": 15,
|
"max_steps": 5,
|
||||||
"warmup_steps": 20,
|
"warmup_steps": 20,
|
||||||
"micro_batch_size": 4,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
@@ -91,6 +90,8 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
|||||||
"launch",
|
"launch",
|
||||||
"--num-processes",
|
"--num-processes",
|
||||||
"2",
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
"-m",
|
"-m",
|
||||||
"axolotl.cli.train",
|
"axolotl.cli.train",
|
||||||
str(Path(temp_dir) / "config.yaml"),
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
|||||||
15
tests/e2e/patched/test_trainer_fsdp.py
Normal file
15
tests/e2e/patched/test_trainer_fsdp.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainerFSDPIntegration(unittest.TestCase):
|
||||||
|
"""Unsloth monkeypatch integration tests."""
|
||||||
|
|
||||||
|
def test_train_loop_patchable(self):
|
||||||
|
# ensures the current version of transformers has loss code that matches our patching code
|
||||||
|
self.assertTrue(
|
||||||
|
check_training_loop_is_patchable(),
|
||||||
|
"HF transformers _inner_training_loop has changed and isn't patchable",
|
||||||
|
)
|
||||||
66
tests/e2e/test_llama.py
Normal file
66
tests/e2e/test_llama.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLlama(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_fft_trust_remote_code(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"trust_remote_code": True,
|
||||||
|
"sequence_len": 512,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"sample_packing": True,
|
||||||
|
"bf16": True,
|
||||||
|
"save_safetensors": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
from .utils import require_torch_2_5_1, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -65,3 +65,80 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
@require_torch_2_5_1
|
||||||
|
def test_adopt_adamw(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adopt_adamw",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 4,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "schedule_free_adamw",
|
||||||
|
"lr_scheduler": "constant",
|
||||||
|
"save_safetensors": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||||
|
|||||||
85
tests/e2e/test_qwen.py
Normal file
85
tests/e2e/test_qwen.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for qwen
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from accelerate.test_utils import execute_subprocess_async
|
||||||
|
from transformers.testing_utils import get_torch_dist_unique_port
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.qwen")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestE2eQwen:
|
||||||
|
"""
|
||||||
|
Test cases for qwen models
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
||||||
|
def test_dpo(self, base_model, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": base_model,
|
||||||
|
"rl": "dpo",
|
||||||
|
"chat_template": "qwen_25",
|
||||||
|
"sequence_len": 2048,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||||
|
"split": "train",
|
||||||
|
"type": "chat_template.default",
|
||||||
|
"field_messages": "conversation",
|
||||||
|
"field_chosen": "chosen",
|
||||||
|
"field_rejected": "rejected",
|
||||||
|
"message_field_role": "role",
|
||||||
|
"message_field_content": "content",
|
||||||
|
"roles": {
|
||||||
|
"system": ["system"],
|
||||||
|
"user": ["user"],
|
||||||
|
"assistant": ["assistant"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"max_steps": 5,
|
||||||
|
"warmup_steps": 20,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 2,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_bnb_8bit",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"flash_attention": True,
|
||||||
|
"bf16": "auto",
|
||||||
|
"tf32": True,
|
||||||
|
"gradient_checkpointing": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# write cfg to yaml file
|
||||||
|
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
execute_subprocess_async(
|
||||||
|
[
|
||||||
|
"accelerate",
|
||||||
|
"launch",
|
||||||
|
"--num-processes",
|
||||||
|
"2",
|
||||||
|
"--main_process_port",
|
||||||
|
f"{get_torch_dist_unique_port()}",
|
||||||
|
"-m",
|
||||||
|
"axolotl.cli.train",
|
||||||
|
str(Path(temp_dir) / "config.yaml"),
|
||||||
|
]
|
||||||
|
)
|
||||||
@@ -6,11 +6,13 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from importlib.metadata import version
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# from importlib.metadata import version
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
|
||||||
def with_temp_dir(test_func):
|
def with_temp_dir(test_func):
|
||||||
@wraps(test_func)
|
@wraps(test_func)
|
||||||
@@ -43,12 +45,24 @@ def require_torch_2_3_1(test_case):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def is_min_2_3_1():
|
def is_min_2_3_1():
|
||||||
torch_version = version("torch")
|
torch_version = version.parse(torch.__version__)
|
||||||
return torch_version >= "2.3.1"
|
return torch_version >= version.parse("2.3.1")
|
||||||
|
|
||||||
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_2_5_1(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires torch >= 2.3.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
def is_min_2_5_1():
|
||||||
|
torch_version = version.parse(torch.__version__)
|
||||||
|
return torch_version >= version.parse("2.5.1")
|
||||||
|
|
||||||
|
return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def is_hopper():
|
def is_hopper():
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
return compute_capability == (9, 0)
|
return compute_capability == (9, 0)
|
||||||
|
|||||||
@@ -371,44 +371,79 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
def test_load_local_hub_with_revision(self):
|
def test_load_local_hub_with_revision(self):
|
||||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir2:
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
snapshot_download(
|
||||||
snapshot_download(
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
repo_type="dataset",
|
||||||
repo_type="dataset",
|
local_dir=tmp_ds_path,
|
||||||
local_dir=tmp_ds_path,
|
revision="d05c1cb",
|
||||||
revision="d05c1cb",
|
)
|
||||||
)
|
|
||||||
|
|
||||||
prepared_path = Path(tmp_dir) / "prepared"
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
"ds_type": "parquet",
|
"ds_type": "parquet",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
"data_files": [
|
"data_files": [
|
||||||
f"{tmp_ds_path}/alpaca_2000.parquet",
|
f"{tmp_ds_path}/alpaca_2000.parquet",
|
||||||
],
|
],
|
||||||
"revision": "d05c1cb",
|
"revision": "d05c1cb",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
self.tokenizer, cfg, prepared_path
|
self.tokenizer, cfg, prepared_path
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
assert len(dataset) == 2000
|
||||||
assert "input_ids" in dataset.features
|
assert "input_ids" in dataset.features
|
||||||
assert "attention_mask" in dataset.features
|
assert "attention_mask" in dataset.features
|
||||||
assert "labels" in dataset.features
|
assert "labels" in dataset.features
|
||||||
shutil.rmtree(tmp_ds_path)
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
def test_loading_local_dataset_folder(self):
|
||||||
|
"""Verify that a dataset downloaded to a local folder can be loaded"""
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
snapshot_download(
|
||||||
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
|
repo_type="dataset",
|
||||||
|
local_dir=tmp_ds_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
prepared_path = Path(tmp_dir) / "prepared"
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": str(tmp_ds_path),
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
|
self.tokenizer, cfg, prepared_path
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(dataset) == 2000
|
||||||
|
assert "input_ids" in dataset.features
|
||||||
|
assert "attention_mask" in dataset.features
|
||||||
|
assert "labels" in dataset.features
|
||||||
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@@ -21,6 +22,7 @@ class TestPretrainingPacking(unittest.TestCase):
|
|||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
self.tokenizer.pad_token = "</s>"
|
self.tokenizer.pad_token = "</s>"
|
||||||
|
|
||||||
|
@pytest.mark.flaky(retries=3, delay=5)
|
||||||
def test_packing_stream_dataset(self):
|
def test_packing_stream_dataset(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
|
|||||||
@@ -726,7 +726,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"evaluation_strategy": "epoch",
|
"eval_strategy": "epoch",
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -734,14 +734,14 @@ class TestValidation(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"evaluation_strategy": "no",
|
"eval_strategy": "no",
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -749,14 +749,14 @@ class TestValidation(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"evaluation_strategy": "steps",
|
"eval_strategy": "steps",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -767,7 +767,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"evaluation_strategy": "steps",
|
"eval_strategy": "steps",
|
||||||
"eval_steps": 10,
|
"eval_steps": 10,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -790,7 +790,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"evaluation_strategy": "no",
|
"eval_strategy": "no",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -801,7 +801,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"evaluation_strategy": "epoch",
|
"eval_strategy": "epoch",
|
||||||
"val_set_size": 0,
|
"val_set_size": 0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -810,7 +810,7 @@ class TestValidation(BaseValidation):
|
|||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
@@ -826,7 +826,7 @@ class TestValidation(BaseValidation):
|
|||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
@@ -856,7 +856,7 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"evaluation_strategy": "epoch",
|
"eval_strategy": "epoch",
|
||||||
"val_set_size": 0.01,
|
"val_set_size": 0.01,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -1095,6 +1095,24 @@ class TestValidation(BaseValidation):
|
|||||||
assert new_cfg["dpo_beta"] is None
|
assert new_cfg["dpo_beta"] is None
|
||||||
assert len(self._caplog.records) == 1
|
assert len(self._caplog.records) == 1
|
||||||
|
|
||||||
|
def test_eval_strategy_remap(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"evaluation_strategy": "steps",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.eval_strategy == "steps"
|
||||||
|
assert (
|
||||||
|
"evaluation_strategy is deprecated, use eval_strategy instead"
|
||||||
|
in self._caplog.records[0].message
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestValidationCheckModelConfig(BaseValidation):
|
class TestValidationCheckModelConfig(BaseValidation):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -234,3 +234,59 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
_check_config()
|
_check_config()
|
||||||
|
|
||||||
|
def test_dataset_sharegpt_deprecation(self, minimal_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"chat_template": "chatml",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "LDJnr/Puffin",
|
||||||
|
"type": "sharegpt",
|
||||||
|
"conversation": "chatml",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check sharegpt deprecation is raised
|
||||||
|
with pytest.raises(ValueError, match=r".*type: sharegpt.*` is deprecated.*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
# Check that deprecation is not thrown for non-str type
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": {
|
||||||
|
"field_instruction": "instruction",
|
||||||
|
"field_output": "output",
|
||||||
|
"field_system": "system",
|
||||||
|
"format": "<|user|> {instruction} {input} <|model|>",
|
||||||
|
"no_input_format": "<|user|> {instruction} <|model|>",
|
||||||
|
"system_prompt": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
# Check that deprecation is not thrown for non-sharegpt type
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user