Compare commits
26 Commits
v0.5.0
...
transforme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60763b2e61 | ||
|
|
082a41af9d | ||
|
|
2d7830fda6 | ||
|
|
5e98cdddac | ||
|
|
1d7aee0ad2 | ||
|
|
659ee5d723 | ||
|
|
342935cff3 | ||
|
|
c5eb9ea2c2 | ||
|
|
f2145a3ccb | ||
|
|
010d0e7ff3 | ||
|
|
01881c3113 | ||
|
|
0e8eb96e07 | ||
|
|
4e1891b12b | ||
|
|
28924fc791 | ||
|
|
8c480b2804 | ||
|
|
a4b1cc6df0 | ||
|
|
7b78a31593 | ||
|
|
810ebc2c0e | ||
|
|
ad435a3b09 | ||
|
|
9f1cf9b17c | ||
|
|
3931a42763 | ||
|
|
dc8f9059f7 | ||
|
|
234e94e9dd | ||
|
|
f68fb71005 | ||
|
|
9bc3ee6c75 | ||
|
|
d356740ffa |
10
.github/workflows/base.yml
vendored
10
.github/workflows/base.yml
vendored
@@ -44,19 +44,21 @@ jobs:
|
|||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Docker metadata
|
- name: Docker metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v3
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-base
|
images: |
|
||||||
|
winglian/axolotl-base
|
||||||
|
axolotlai/axolotl-base
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v2
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
2
.github/workflows/docs.yml
vendored
2
.github/workflows/docs.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
|||||||
- name: Set up Quarto
|
- name: Set up Quarto
|
||||||
uses: quarto-dev/quarto-actions/setup@v2
|
uses: quarto-dev/quarto-actions/setup@v2
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v3
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
- name: install dependencies
|
- name: install dependencies
|
||||||
|
|||||||
6
.github/workflows/lint.yml
vendored
6
.github/workflows/lint.yml
vendored
@@ -15,9 +15,9 @@ jobs:
|
|||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.0
|
- uses: pre-commit/action@v3.0.1
|
||||||
|
|||||||
30
.github/workflows/main.yml
vendored
30
.github/workflows/main.yml
vendored
@@ -4,6 +4,8 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
|
tags:
|
||||||
|
- "v*"
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
@@ -42,7 +44,12 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl
|
images: |
|
||||||
|
winglian/axolotl
|
||||||
|
axolotlai/axolotl
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=semver,pattern={{version}}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
@@ -56,7 +63,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||||
@@ -104,20 +111,22 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-cloud
|
images: |
|
||||||
|
winglian/axolotl-cloud
|
||||||
|
axolotlai/axolotl-cloud
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
file: ./docker/Dockerfile-cloud
|
file: ./docker/Dockerfile-cloud
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
@@ -146,20 +155,25 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-cloud-term
|
images: |
|
||||||
|
winglian/axolotl-cloud-term
|
||||||
|
axolotlai/axolotl-cloud-term
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=semver,pattern={{version}}
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
file: ./docker/Dockerfile-cloud-no-tmux
|
file: ./docker/Dockerfile-cloud-no-tmux
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
|
|||||||
10
.github/workflows/nightlies.yml
vendored
10
.github/workflows/nightlies.yml
vendored
@@ -41,7 +41,9 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl
|
images: |
|
||||||
|
winglian/axolotl
|
||||||
|
axolotlai/axolotl
|
||||||
tags: |
|
tags: |
|
||||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
@@ -103,7 +105,9 @@ jobs:
|
|||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v5
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-cloud
|
images: |
|
||||||
|
winglian/axolotl-cloud
|
||||||
|
axolotlai/axolotl-cloud
|
||||||
tags: |
|
tags: |
|
||||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
@@ -112,7 +116,7 @@ jobs:
|
|||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v5
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
24
.github/workflows/pypi.yml
vendored
24
.github/workflows/pypi.yml
vendored
@@ -3,13 +3,31 @@ name: publish pypi
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- '*'
|
- 'v*'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
setup_release:
|
||||||
|
name: Create Release
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Get the tag version
|
||||||
|
id: extract_branch
|
||||||
|
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
||||||
|
shell: bash
|
||||||
|
|
||||||
|
- name: Create Release
|
||||||
|
id: create_release
|
||||||
|
uses: actions/create-release@v1
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
with:
|
||||||
|
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
||||||
|
release_name: ${{ steps.extract_branch.outputs.branch }}
|
||||||
pypi-publish:
|
pypi-publish:
|
||||||
name: Upload release to PyPI
|
name: Upload release to PyPI
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
needs: [setup_release]
|
||||||
environment:
|
environment:
|
||||||
name: pypi
|
name: pypi
|
||||||
url: https://pypi.org/p/axolotl
|
url: https://pypi.org/p/axolotl
|
||||||
@@ -17,10 +35,10 @@ jobs:
|
|||||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
|
|
||||||
|
|||||||
10
.github/workflows/tests-nightly.yml
vendored
10
.github/workflows/tests-nightly.yml
vendored
@@ -9,12 +9,12 @@ jobs:
|
|||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.0
|
- uses: pre-commit/action@v3.0.1
|
||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
|
|
||||||
@@ -30,10 +30,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
|
|||||||
15
.github/workflows/tests.yml
vendored
15
.github/workflows/tests.yml
vendored
@@ -15,17 +15,22 @@ on:
|
|||||||
- '.github/workflows/*.yml'
|
- '.github/workflows/*.yml'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
# Cancel jobs on the same ref if a new one is triggered
|
||||||
|
concurrency:
|
||||||
|
group: ${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
pre-commit:
|
pre-commit:
|
||||||
name: pre-commit
|
name: pre-commit
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- uses: actions/setup-python@v4
|
- uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.10"
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
- uses: pre-commit/action@v3.0.0
|
- uses: pre-commit/action@v3.0.1
|
||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
|
|
||||||
@@ -41,10 +46,10 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python_version }}
|
python-version: ${{ matrix.python_version }}
|
||||||
cache: 'pip' # caching pip dependencies
|
cache: 'pip' # caching pip dependencies
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
|
|||||||
#### Docker
|
#### Docker
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
|
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
Or run on the current files for development:
|
Or run on the current files for development:
|
||||||
@@ -178,7 +178,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
|
|||||||
A more powerful Docker command to run would be this:
|
A more powerful Docker command to run would be this:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
|
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
|
||||||
```
|
```
|
||||||
|
|
||||||
It additionally:
|
It additionally:
|
||||||
@@ -210,7 +210,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
|||||||
|
|
||||||
#### Cloud GPU
|
#### Cloud GPU
|
||||||
|
|
||||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
|
||||||
|
|
||||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||||
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
|
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
|
||||||
@@ -319,7 +319,7 @@ Write a job description in YAML as below:
|
|||||||
# dstack.yaml
|
# dstack.yaml
|
||||||
type: task
|
type: task
|
||||||
|
|
||||||
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
|
image: axolotlai/axolotl-cloud:main-latest
|
||||||
|
|
||||||
env:
|
env:
|
||||||
- HUGGING_FACE_HUB_TOKEN
|
- HUGGING_FACE_HUB_TOKEN
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
FROM winglian/axolotl-base:{{ BASE_TAG }}
|
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||||
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import tempfile
|
|||||||
import jinja2
|
import jinja2
|
||||||
import modal
|
import modal
|
||||||
from jinja2 import select_autoescape
|
from jinja2 import select_autoescape
|
||||||
from modal import Image, Stub
|
from modal import App, Image
|
||||||
|
|
||||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ cicd_image = (
|
|||||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||||
)
|
)
|
||||||
|
|
||||||
stub = Stub("Axolotl CI/CD", secrets=[])
|
app = App("Axolotl CI/CD", secrets=[])
|
||||||
|
|
||||||
|
|
||||||
N_GPUS = int(os.environ.get("N_GPUS", 2))
|
N_GPUS = int(os.environ.get("N_GPUS", 2))
|
||||||
@@ -61,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||||
|
|
||||||
|
|
||||||
@stub.function(
|
@app.function(
|
||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=60 * 60,
|
timeout=60 * 60,
|
||||||
@@ -72,6 +72,6 @@ def cicd_pytest():
|
|||||||
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
||||||
|
|
||||||
|
|
||||||
@stub.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
def main():
|
def main():
|
||||||
cicd_pytest.remote()
|
cicd_pytest.remote()
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import tempfile
|
|||||||
import jinja2
|
import jinja2
|
||||||
import modal
|
import modal
|
||||||
from jinja2 import select_autoescape
|
from jinja2 import select_autoescape
|
||||||
from modal import Image, Stub
|
from modal import App, Image
|
||||||
|
|
||||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ cicd_image = (
|
|||||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||||
)
|
)
|
||||||
|
|
||||||
stub = Stub("Axolotl CI/CD", secrets=[])
|
app = App("Axolotl CI/CD", secrets=[])
|
||||||
|
|
||||||
|
|
||||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||||
@@ -62,7 +62,7 @@ def run_cmd(cmd: str, run_folder: str):
|
|||||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||||
|
|
||||||
|
|
||||||
@stub.function(
|
@app.function(
|
||||||
image=cicd_image,
|
image=cicd_image,
|
||||||
gpu=GPU_CONFIG,
|
gpu=GPU_CONFIG,
|
||||||
timeout=60 * 60,
|
timeout=60 * 60,
|
||||||
@@ -73,6 +73,6 @@ def cicd_pytest():
|
|||||||
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
||||||
|
|
||||||
|
|
||||||
@stub.local_entrypoint()
|
@app.local_entrypoint()
|
||||||
def main():
|
def main():
|
||||||
cicd_pytest.remote()
|
cicd_pytest.remote()
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
ARG BASE_TAG=main-base
|
ARG BASE_TAG=main-base
|
||||||
FROM winglian/axolotl-base:$BASE_TAG
|
FROM axolotlai/axolotl-base:$BASE_TAG
|
||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
|
|||||||
@@ -35,7 +35,3 @@ RUN git lfs install --skip-repo && \
|
|||||||
pip3 install awscli && \
|
pip3 install awscli && \
|
||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
|
|
||||||
RUN if [ "$PYTHON_VERSION" != "2.5.1" ] ; then \
|
|
||||||
pip3 install flash-attn==2.6.3; \
|
|
||||||
fi
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
ARG BASE_TAG=main
|
ARG BASE_TAG=main
|
||||||
FROM winglian/axolotl:$BASE_TAG
|
FROM axolotlai/axolotl:$BASE_TAG
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
ARG BASE_TAG=main
|
ARG BASE_TAG=main
|
||||||
FROM winglian/axolotl:$BASE_TAG
|
FROM axolotlai/axolotl:$BASE_TAG
|
||||||
|
|
||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
ARG BASE_TAG=main-base
|
ARG BASE_TAG=main-base
|
||||||
FROM winglian/axolotl-base:$BASE_TAG
|
FROM axolotlai/axolotl-base:$BASE_TAG
|
||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
|
|||||||
@@ -405,6 +405,7 @@ lr_div_factor: # Learning rate div factor
|
|||||||
# - adamw_torch_fused
|
# - adamw_torch_fused
|
||||||
# - adamw_torch_xla
|
# - adamw_torch_xla
|
||||||
# - adamw_apex_fused
|
# - adamw_apex_fused
|
||||||
|
# - adopt_adamw (only for torch version >= 2.5.1)
|
||||||
# - adafactor
|
# - adafactor
|
||||||
# - adamw_anyprecision
|
# - adamw_anyprecision
|
||||||
# - sgd
|
# - sgd
|
||||||
|
|||||||
@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
|
|||||||
|
|
||||||
## Debugging With Docker
|
## Debugging With Docker
|
||||||
|
|
||||||
Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
|
Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
|
||||||
|
|
||||||
### Setup
|
### Setup
|
||||||
|
|
||||||
@@ -202,11 +202,11 @@ cd axolotl
|
|||||||
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
|
||||||
```
|
```
|
||||||
|
|
||||||
>[!Tip]
|
>[!Tip]
|
||||||
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
||||||
|
|
||||||
You will now be in the container. Next, perform an editable install of Axolotl:
|
You will now be in the container. Next, perform an editable install of Axolotl:
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
|
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
|
||||||
"!pip install flash-attn==\"2.5.0\"\n",
|
"!pip install flash-attn==\"2.7.0.post2\"\n",
|
||||||
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
93
examples/mistral/mistral-dpo-qlora.yml
Normal file
93
examples/mistral/mistral-dpo-qlora.yml
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
#Note that we are switching from the regular chat template to chatml.
|
||||||
|
#If you experience problems with the special tokens, training for more epochs can help.
|
||||||
|
#After training, merge the model before inference otherwise you might
|
||||||
|
#face problems with the special tokens.
|
||||||
|
|
||||||
|
base_model: mistralai/Mistral-7B-Instruct-v0.2
|
||||||
|
model_type: MistralForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
chat_template: chatml
|
||||||
|
rl: dpo
|
||||||
|
datasets:
|
||||||
|
- path: olivermolenschot/alpaca_messages_dpo_test
|
||||||
|
type: chat_template.default
|
||||||
|
field_messages: conversation
|
||||||
|
field_chosen: chosen
|
||||||
|
field_rejected: rejected
|
||||||
|
message_field_role: role
|
||||||
|
message_field_content: content
|
||||||
|
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/dpo-qlora
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 8
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.2
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
|
||||||
|
lora_target_modules:
|
||||||
|
- gate_proj
|
||||||
|
- down_proj
|
||||||
|
- up_proj
|
||||||
|
- q_proj
|
||||||
|
- v_proj
|
||||||
|
- k_proj
|
||||||
|
- o_proj
|
||||||
|
lora_modules_to_save:
|
||||||
|
- embed_tokens
|
||||||
|
- lm_head
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 16
|
||||||
|
num_epochs: 6
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: auto
|
||||||
|
fp16:
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: false
|
||||||
|
s2_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<|im_start|>"
|
||||||
|
eos_token: "<|im_end|>"
|
||||||
@@ -1,2 +1,3 @@
|
|||||||
pytest
|
pytest
|
||||||
pytest-xdist
|
pytest-xdist
|
||||||
|
pytest-retry
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ addict
|
|||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
requests
|
requests
|
||||||
flash-attn==2.6.3
|
flash-attn==2.7.0.post2
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
@@ -33,7 +33,7 @@ tensorboard
|
|||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
autoawq>=0.2.5
|
autoawq>=0.2.5
|
||||||
triton>=2.3.0
|
triton>=2.3.0
|
||||||
liger-kernel==0.4.0
|
liger-kernel==0.4.1
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
|
|||||||
11
setup.py
11
setup.py
@@ -39,7 +39,10 @@ def parse_requirements():
|
|||||||
else:
|
else:
|
||||||
# detect the version of torch already installed
|
# detect the version of torch already installed
|
||||||
# and set it so dependencies don't clobber the torch version
|
# and set it so dependencies don't clobber the torch version
|
||||||
torch_version = version("torch")
|
try:
|
||||||
|
torch_version = version("torch")
|
||||||
|
except PackageNotFoundError:
|
||||||
|
torch_version = "2.5.1"
|
||||||
_install_requires.append(f"torch=={torch_version}")
|
_install_requires.append(f"torch=={torch_version}")
|
||||||
|
|
||||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||||
@@ -54,6 +57,10 @@ def parse_requirements():
|
|||||||
|
|
||||||
if (major, minor) >= (2, 5):
|
if (major, minor) >= (2, 5):
|
||||||
_install_requires.pop(_install_requires.index(xformers_version))
|
_install_requires.pop(_install_requires.index(xformers_version))
|
||||||
|
if patch == 0:
|
||||||
|
_install_requires.append("xformers==0.0.28.post2")
|
||||||
|
else:
|
||||||
|
_install_requires.append("xformers==0.0.28.post3")
|
||||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||||
elif (major, minor) >= (2, 4):
|
elif (major, minor) >= (2, 4):
|
||||||
if patch == 0:
|
if patch == 0:
|
||||||
@@ -98,7 +105,7 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.6.3",
|
"flash-attn==2.7.0.post2",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.14.4",
|
"deepspeed==0.14.4",
|
||||||
|
|||||||
@@ -190,18 +190,15 @@ def do_inference(
|
|||||||
):
|
):
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
|
||||||
# If the token isn't already specified in the config, add it
|
|
||||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
|
||||||
tokenizer.add_special_tokens({token: symbol})
|
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
|
chat_template_str = None
|
||||||
if prompter:
|
if prompter:
|
||||||
prompter_module = getattr(
|
prompter_module = getattr(
|
||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
elif cfg.chat_template:
|
||||||
|
chat_template_str = get_chat_template(cfg.chat_template)
|
||||||
|
|
||||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|
||||||
@@ -211,13 +208,31 @@ def do_inference(
|
|||||||
instruction = get_multi_line_input()
|
instruction = get_multi_line_input()
|
||||||
if not instruction:
|
if not instruction:
|
||||||
return
|
return
|
||||||
|
|
||||||
if prompter_module:
|
if prompter_module:
|
||||||
prompt: str = next(
|
prompt: str = next(
|
||||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt = instruction.strip()
|
prompt = instruction.strip()
|
||||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
|
||||||
|
if chat_template_str:
|
||||||
|
batch = tokenizer.apply_chat_template(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
return_tensors="pt",
|
||||||
|
add_special_tokens=True,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
chat_template=chat_template_str,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||||
|
|
||||||
print("=" * 40)
|
print("=" * 40)
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -257,13 +272,6 @@ def do_inference_gradio(
|
|||||||
|
|
||||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||||
prompter = cli_args.prompter
|
prompter = cli_args.prompter
|
||||||
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
|
||||||
default_tokens: Dict[str, str] = {}
|
|
||||||
|
|
||||||
for token, symbol in default_tokens.items():
|
|
||||||
# If the token isn't already specified in the config, add it
|
|
||||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
|
||||||
tokenizer.add_special_tokens({token: symbol})
|
|
||||||
|
|
||||||
prompter_module = None
|
prompter_module = None
|
||||||
chat_template_str = None
|
chat_template_str = None
|
||||||
|
|||||||
@@ -436,7 +436,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
not in [
|
||||||
|
"optimi_adamw",
|
||||||
|
"ao_adamw_8bit",
|
||||||
|
"ao_adamw_4bit",
|
||||||
|
"ao_adamw_fp8",
|
||||||
|
"adopt_adamw",
|
||||||
|
]
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
@@ -505,6 +511,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||||
)
|
)
|
||||||
|
elif self.args.alternate_optimizer == "adopt_adamw":
|
||||||
|
from axolotl.utils.optimizers.adopt import ADOPT
|
||||||
|
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
ADOPT(
|
||||||
|
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
@@ -1273,6 +1287,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||||
callbacks.append(lisa_callback_factory(trainer))
|
callbacks.append(lisa_callback_factory(trainer))
|
||||||
|
|
||||||
|
if self.cfg.plugins:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
callbacks.extend(
|
||||||
|
[
|
||||||
|
cb
|
||||||
|
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||||
|
self.cfg, trainer
|
||||||
|
)
|
||||||
|
if cb
|
||||||
|
]
|
||||||
|
)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def _get_trainer_cls(self):
|
def _get_trainer_cls(self):
|
||||||
@@ -1625,11 +1651,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
if self.cfg.optimizer in [
|
if self.cfg.optimizer in [
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"adopt_adamw",
|
||||||
]:
|
]:
|
||||||
# Set default so transformers doesn't throw
|
# Set default so transformers doesn't throw
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
@@ -1933,6 +1961,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
training_args_cls = AxolotlDPOConfig
|
training_args_cls = AxolotlDPOConfig
|
||||||
|
if self.cfg.rl == "ipo":
|
||||||
|
training_args_kwargs["loss_type"] = "ipo"
|
||||||
|
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||||
|
training_args_kwargs["max_completion_length"] = None
|
||||||
|
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||||
|
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||||
if self.cfg.dpo_use_weighting is not None:
|
if self.cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||||
|
|
||||||
@@ -1956,7 +1990,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args = self.build_training_arguments(total_num_steps)
|
training_args = self.build_training_arguments(total_num_steps)
|
||||||
dpo_trainer_kwargs = {}
|
dpo_trainer_kwargs = {}
|
||||||
if self.cfg.rl == "ipo":
|
if self.cfg.rl == "ipo":
|
||||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
@@ -1970,12 +2003,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl in ["dpo", "ipo"]:
|
if self.cfg.rl in ["dpo", "ipo"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|
||||||
# these aren't used for the ORPO trainer
|
|
||||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["max_target_length"] = None
|
|
||||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
|
||||||
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
|
||||||
elif self.cfg.rl == "orpo":
|
elif self.cfg.rl == "orpo":
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
trainer_cls_args = [self.model]
|
trainer_cls_args = [self.model]
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ class BasePlugin:
|
|||||||
|
|
||||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Adds callbacks to the trainer before training.
|
setup callbacks before creating the trainer.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugin.
|
cfg (dict): The configuration for the plugin.
|
||||||
@@ -155,14 +155,15 @@ class BasePlugin:
|
|||||||
self, cfg, trainer
|
self, cfg, trainer
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Adds callbacks to the trainer after training.
|
Adds callbacks to the trainer after creating the trainer.
|
||||||
|
This is useful for callbacks that require access to the model or trainer.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
cfg (dict): The configuration for the plugin.
|
cfg (dict): The configuration for the plugin.
|
||||||
trainer (object): The trainer object for training.
|
trainer (object): The trainer object for training.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
List[callable]: A list of callback functions to be added
|
||||||
"""
|
"""
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -393,7 +394,9 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
callbacks = []
|
callbacks = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
|
||||||
|
if plugin_callbacks: # if the plugin returned a list of callbacks
|
||||||
|
callbacks.extend(plugin_callbacks)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
@@ -409,7 +412,9 @@ class PluginManager:
|
|||||||
"""
|
"""
|
||||||
callbacks = []
|
callbacks = []
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
|
||||||
|
if plugin_callbacks:
|
||||||
|
callbacks.extend(plugin_callbacks)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
def post_train_unload(self, cfg):
|
||||||
|
|||||||
21
src/axolotl/integrations/grokfast/LICENSE
Normal file
21
src/axolotl/integrations/grokfast/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
13
src/axolotl/integrations/grokfast/README.md
Normal file
13
src/axolotl/integrations/grokfast/README.md
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Grokfast Optimizer
|
||||||
|
|
||||||
|
See https://github.com/ironjr/grokfast
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
plugins:
|
||||||
|
- axolotl.integrations.grokfast.GrokfastPlugin
|
||||||
|
|
||||||
|
grokfast_alpha: 2.0
|
||||||
|
grokfast_lamb: 0.98
|
||||||
|
```
|
||||||
50
src/axolotl/integrations/grokfast/__init__.py
Normal file
50
src/axolotl/integrations/grokfast/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""
|
||||||
|
Grokfast plugin for Axolotl
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from transformers.trainer_callback import TrainerCallback
|
||||||
|
|
||||||
|
from ..base import BasePlugin
|
||||||
|
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
from .optimizer import gradfilter_ema
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.grokfast")
|
||||||
|
|
||||||
|
|
||||||
|
class GrokfastCallbackHandler(TrainerCallback):
|
||||||
|
"""
|
||||||
|
Transformer trainer callbacks for Grokfast
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
|
||||||
|
super().__init__(*args_, **kwargs)
|
||||||
|
self.grads = None
|
||||||
|
self.alpha = alpha
|
||||||
|
self.lamb = lamb
|
||||||
|
|
||||||
|
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
|
||||||
|
self.grads = None
|
||||||
|
|
||||||
|
def on_pre_optimizer_step(
|
||||||
|
self, args_, state, control, **kwargs
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
model = kwargs.pop("model")
|
||||||
|
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
|
||||||
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
class GrokfastPlugin(BasePlugin):
|
||||||
|
"""
|
||||||
|
Plugin for Grokfast optimizer integraton with Axolotl.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_input_args(self):
|
||||||
|
return "axolotl.integrations.grokfast.GrokfastArgs"
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
|
LOG.info("Adding Grokfast callback to the trainer")
|
||||||
|
callback = GrokfastCallbackHandler(
|
||||||
|
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
|
||||||
|
)
|
||||||
|
return [callback]
|
||||||
15
src/axolotl/integrations/grokfast/args.py
Normal file
15
src/axolotl/integrations/grokfast/args.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
config args for grokfast plugin
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GrokfastArgs(BaseModel):
|
||||||
|
"""
|
||||||
|
Input args for Grokfast optimizer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
grokfast_alpha: Optional[float] = 0.98
|
||||||
|
grokfast_lamb: Optional[float] = 2.0
|
||||||
63
src/axolotl/integrations/grokfast/optimizer.py
Normal file
63
src/axolotl/integrations/grokfast/optimizer.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||||
|
# Reference: https://github.com/ironjr/grokfast
|
||||||
|
|
||||||
|
# pylint: skip-file
|
||||||
|
from collections import deque
|
||||||
|
from typing import Dict, Literal, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def gradfilter_ma(
|
||||||
|
m: nn.Module,
|
||||||
|
grads: Optional[Dict[str, deque]] = None,
|
||||||
|
window_size: int = 100,
|
||||||
|
lamb: float = 5.0,
|
||||||
|
filter_type: Literal["mean", "sum"] = "mean",
|
||||||
|
warmup: bool = True,
|
||||||
|
trigger: bool = False, # For ablation study.
|
||||||
|
) -> Dict[str, deque]:
|
||||||
|
if grads is None:
|
||||||
|
grads = {
|
||||||
|
n: deque(maxlen=window_size)
|
||||||
|
for n, p in m.named_parameters()
|
||||||
|
if p.requires_grad and p.grad is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, p in m.named_parameters():
|
||||||
|
if p.requires_grad and p.grad is not None:
|
||||||
|
grads[n].append(p.grad.data.detach()) # .cpu())
|
||||||
|
|
||||||
|
# Modify the gradients.
|
||||||
|
if not warmup or len(grads[n]) == window_size and not trigger:
|
||||||
|
if filter_type == "mean":
|
||||||
|
avg = sum(grads[n]) / len(grads[n])
|
||||||
|
elif filter_type == "sum":
|
||||||
|
avg = sum(grads[n])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unrecognized filter_type {filter_type}")
|
||||||
|
p.grad.data = p.grad.data + avg * lamb
|
||||||
|
|
||||||
|
return grads
|
||||||
|
|
||||||
|
|
||||||
|
def gradfilter_ema(
|
||||||
|
m: nn.Module,
|
||||||
|
grads: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
alpha: float = 0.98,
|
||||||
|
lamb: float = 2.0,
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
if grads is None:
|
||||||
|
grads = {
|
||||||
|
n: p.grad.data.detach()
|
||||||
|
for n, p in m.named_parameters()
|
||||||
|
if p.requires_grad and p.grad is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
for n, p in m.named_parameters():
|
||||||
|
if p.requires_grad and p.grad is not None:
|
||||||
|
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
|
||||||
|
p.grad.data = p.grad.data + grads[n] * lamb
|
||||||
|
|
||||||
|
return grads
|
||||||
@@ -23,6 +23,7 @@ import logging
|
|||||||
import sys
|
import sys
|
||||||
|
|
||||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||||
|
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||||
@@ -82,7 +83,9 @@ class LigerPlugin(BasePlugin):
|
|||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
from transformers.loss.loss_utils import nn
|
||||||
|
|
||||||
|
nn.functional.cross_entropy = liger_cross_entropy
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||||
elif cfg.model_config_type == "deepseek_v2":
|
elif cfg.model_config_type == "deepseek_v2":
|
||||||
@@ -106,6 +109,8 @@ class LigerPlugin(BasePlugin):
|
|||||||
if cfg.liger_glu_activation:
|
if cfg.liger_glu_activation:
|
||||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||||
if cfg.liger_cross_entropy:
|
if cfg.liger_cross_entropy:
|
||||||
|
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||||
|
# nn.CrossEntropyLoss in the forward method.
|
||||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||||
if cfg.liger_fused_linear_cross_entropy:
|
if cfg.liger_fused_linear_cross_entropy:
|
||||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -58,6 +58,7 @@ class ChatTemplate(str, Enum):
|
|||||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||||
exaone = "exaone" # pylint: disable=invalid-name
|
exaone = "exaone" # pylint: disable=invalid-name
|
||||||
|
metharme = "metharme" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedParameters(BaseModel):
|
class DeprecatedParameters(BaseModel):
|
||||||
@@ -427,6 +428,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"adopt_adamw",
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
] = OptimizerNames.ADAMW_HF.value
|
] = OptimizerNames.ADAMW_HF.value
|
||||||
@@ -781,6 +783,8 @@ class AxolotlInputConfig(
|
|||||||
is_mistral_derived_model: Optional[bool] = Field(default=None)
|
is_mistral_derived_model: Optional[bool] = Field(default=None)
|
||||||
is_qwen_derived_model: Optional[bool] = Field(default=None)
|
is_qwen_derived_model: Optional[bool] = Field(default=None)
|
||||||
|
|
||||||
|
plugins: Optional[List[str]] = Field(default=None)
|
||||||
|
|
||||||
@field_validator("datasets", mode="before")
|
@field_validator("datasets", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def deprecate_sharegpt_datasets(cls, datasets):
|
def deprecate_sharegpt_datasets(cls, datasets):
|
||||||
@@ -788,7 +792,12 @@ class AxolotlInputConfig(
|
|||||||
if not ds_cfg.get("type"):
|
if not ds_cfg.get("type"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if ds_cfg["type"].startswith("sharegpt"):
|
ds_type = ds_cfg["type"]
|
||||||
|
# skip if it's a dict (for custom user instruction prompt)
|
||||||
|
if isinstance(ds_type, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
|
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
|
||||||
)
|
)
|
||||||
@@ -1393,6 +1402,17 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_fsdp_grad_accum_4_46_2(cls, data):
|
||||||
|
if data.get("fsdp") and data.get("gradient_accumulation_steps") > 1:
|
||||||
|
if version("transformers") == "4.46.2":
|
||||||
|
raise ValueError(
|
||||||
|
"FSDP w/ gradient_accumulation_steps > 1 is broken with transformers==4.46.2. "
|
||||||
|
"Please use a lower value or switch to an older version of transformers."
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||||
|
|||||||
25
src/axolotl/utils/environment.py
Normal file
25
src/axolotl/utils/environment.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""
|
||||||
|
utils to get GPU info for the current environment
|
||||||
|
"""
|
||||||
|
from accelerate.utils.environment import (
|
||||||
|
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
||||||
|
)
|
||||||
|
from accelerate.utils.environment import get_gpu_info
|
||||||
|
|
||||||
|
|
||||||
|
def check_cuda_p2p_ib_support():
|
||||||
|
if not accelerate_check_cuda_p2p_ib_support():
|
||||||
|
return False
|
||||||
|
unsupported_devices = {"RTX 6000 Ada"}
|
||||||
|
try:
|
||||||
|
device_names, device_count = get_gpu_info()
|
||||||
|
if 1 < device_count < 8:
|
||||||
|
if any(
|
||||||
|
unsupported_device in device_name
|
||||||
|
for device_name in device_names
|
||||||
|
for unsupported_device in unsupported_devices
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
except Exception: # pylint: disable=broad-except # nosec
|
||||||
|
pass
|
||||||
|
return True
|
||||||
@@ -14,6 +14,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
torch_version = version.parse(torch.__version__)
|
||||||
|
|
||||||
|
if torch_version < version.parse("2.4.0"):
|
||||||
|
torch_cuda_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
||||||
|
torch_cuda_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
||||||
|
else:
|
||||||
|
torch_cuda_amp_custom_fwd = torch.amp.custom_fwd(device_type="cuda")
|
||||||
|
torch_cuda_amp_custom_bwd = torch.amp.custom_bwd(device_type="cuda")
|
||||||
|
|
||||||
|
|
||||||
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
||||||
@@ -25,7 +35,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.cuda.amp.custom_fwd
|
@torch_cuda_amp_custom_fwd
|
||||||
def forward(ctx, forward_function, hidden_states, *args):
|
def forward(ctx, forward_function, hidden_states, *args):
|
||||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -36,7 +46,7 @@ class Unsloth_Offloaded_Gradient_Checkpointer( # pylint: disable=invalid-name
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.cuda.amp.custom_bwd
|
@torch_cuda_amp_custom_bwd
|
||||||
def backward(ctx, dY):
|
def backward(ctx, dY):
|
||||||
(hidden_states,) = ctx.saved_tensors
|
(hidden_states,) = ctx.saved_tensors
|
||||||
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
|
||||||
|
|||||||
508
src/axolotl/utils/optimizers/adopt.py
Normal file
508
src/axolotl/utils/optimizers/adopt.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
"""
|
||||||
|
Copied from https://github.com/iShohei220/adopt
|
||||||
|
|
||||||
|
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate (2024)
|
||||||
|
Taniguchi, Shohei and Harada, Keno and Minegishi, Gouki and Oshima, Yuta and Jeong, Seong Cheol and Nagahara, Go and Iiyama, Tomoshi and Suzuki, Masahiro and Iwasawa, Yusuke and Matsuo, Yutaka
|
||||||
|
"""
|
||||||
|
# mypy: ignore-errors
|
||||||
|
# pylint: skip-file
|
||||||
|
# mypy: allow-untyped-decorators
|
||||||
|
# mypy: allow-untyped-defs
|
||||||
|
from typing import List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.optim.optimizer import (
|
||||||
|
Optimizer,
|
||||||
|
ParamsT,
|
||||||
|
_default_to_fused_or_foreach,
|
||||||
|
_device_dtype_check_for_fused,
|
||||||
|
_disable_dynamo_if_unsupported,
|
||||||
|
_get_capturable_supported_devices,
|
||||||
|
_get_scalar_dtype,
|
||||||
|
_get_value,
|
||||||
|
_use_grad_for_differentiable,
|
||||||
|
_view_as_real,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["ADOPT", "adopt"]
|
||||||
|
|
||||||
|
|
||||||
|
class ADOPT(Optimizer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: ParamsT,
|
||||||
|
lr: Union[float, Tensor] = 1e-3,
|
||||||
|
betas: Tuple[float, float] = (0.9, 0.9999),
|
||||||
|
eps: float = 1e-6,
|
||||||
|
weight_decay: float = 0.0,
|
||||||
|
decoupled: bool = False,
|
||||||
|
*,
|
||||||
|
foreach: Optional[bool] = None,
|
||||||
|
maximize: bool = False,
|
||||||
|
capturable: bool = False,
|
||||||
|
differentiable: bool = False,
|
||||||
|
fused: Optional[bool] = None,
|
||||||
|
):
|
||||||
|
if isinstance(lr, Tensor):
|
||||||
|
if foreach and not capturable:
|
||||||
|
raise ValueError(
|
||||||
|
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||||
|
)
|
||||||
|
if lr.numel() != 1:
|
||||||
|
raise ValueError("Tensor lr must be 1-element")
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError(f"Invalid learning rate: {lr}")
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError(f"Invalid epsilon value: {eps}")
|
||||||
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
|
||||||
|
if not 0.0 <= betas[1] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
|
||||||
|
if not 0.0 <= weight_decay:
|
||||||
|
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
|
||||||
|
|
||||||
|
defaults = dict(
|
||||||
|
lr=lr,
|
||||||
|
betas=betas,
|
||||||
|
eps=eps,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
decoupled=decoupled,
|
||||||
|
maximize=maximize,
|
||||||
|
foreach=foreach,
|
||||||
|
capturable=capturable,
|
||||||
|
differentiable=differentiable,
|
||||||
|
fused=fused,
|
||||||
|
)
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
if fused:
|
||||||
|
# TODO: support fused
|
||||||
|
raise RuntimeError("`fused` is not currently supported")
|
||||||
|
|
||||||
|
if differentiable:
|
||||||
|
raise RuntimeError("`fused` does not support `differentiable`")
|
||||||
|
self._step_supports_amp_scaling = True
|
||||||
|
# TODO(crcrpar): [low prec params & their higher prec copy]
|
||||||
|
# Support AMP with FP16/BF16 model params which would need
|
||||||
|
# higher prec copy of params to do update math in higher prec to
|
||||||
|
# alleviate the loss of information.
|
||||||
|
if foreach:
|
||||||
|
raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super().__setstate__(state)
|
||||||
|
for group in self.param_groups:
|
||||||
|
group.setdefault("maximize", False)
|
||||||
|
group.setdefault("foreach", None)
|
||||||
|
group.setdefault("capturable", False)
|
||||||
|
group.setdefault("differentiable", False)
|
||||||
|
fused = group.setdefault("fused", None)
|
||||||
|
for p in group["params"]:
|
||||||
|
p_state = self.state.get(p, [])
|
||||||
|
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
|
||||||
|
step_val = float(p_state["step"])
|
||||||
|
p_state["step"] = (
|
||||||
|
torch.tensor(
|
||||||
|
step_val,
|
||||||
|
dtype=_get_scalar_dtype(is_fused=fused),
|
||||||
|
device=p.device,
|
||||||
|
)
|
||||||
|
if group["capturable"] or group["fused"]
|
||||||
|
else torch.tensor(step_val, dtype=_get_scalar_dtype())
|
||||||
|
)
|
||||||
|
|
||||||
|
def _init_group(
|
||||||
|
self,
|
||||||
|
group,
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
):
|
||||||
|
has_complex = False
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is not None:
|
||||||
|
has_complex |= torch.is_complex(p)
|
||||||
|
params_with_grad.append(p)
|
||||||
|
if p.grad.is_sparse:
|
||||||
|
raise RuntimeError("ADOPT does not support sparse gradients")
|
||||||
|
grads.append(p.grad)
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
# Lazy state initialization
|
||||||
|
if len(state) == 0:
|
||||||
|
if group["fused"]:
|
||||||
|
_device_dtype_check_for_fused(p)
|
||||||
|
# note(crcrpar): [special device hosting for step]
|
||||||
|
# Deliberately host `step` on CPU if both capturable and fused are off.
|
||||||
|
# This is because kernel launches are costly on CUDA and XLA.
|
||||||
|
state["step"] = (
|
||||||
|
torch.zeros(
|
||||||
|
(),
|
||||||
|
dtype=_get_scalar_dtype(is_fused=group["fused"]),
|
||||||
|
device=p.device,
|
||||||
|
)
|
||||||
|
if group["capturable"] or group["fused"]
|
||||||
|
else torch.tensor(0.0, dtype=_get_scalar_dtype())
|
||||||
|
)
|
||||||
|
# Exponential moving average of gradient values
|
||||||
|
state["exp_avg"] = torch.zeros_like(
|
||||||
|
p, memory_format=torch.preserve_format
|
||||||
|
)
|
||||||
|
# Exponential moving average of squared gradient values
|
||||||
|
state["exp_avg_sq"] = torch.zeros_like(
|
||||||
|
p, memory_format=torch.preserve_format
|
||||||
|
)
|
||||||
|
|
||||||
|
exp_avgs.append(state["exp_avg"])
|
||||||
|
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||||
|
|
||||||
|
if group["differentiable"] and state["step"].requires_grad:
|
||||||
|
raise RuntimeError(
|
||||||
|
"`requires_grad` is not supported for `step` in differentiable mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Foreach without capturable does not support a tensor lr
|
||||||
|
if (
|
||||||
|
group["foreach"]
|
||||||
|
and torch.is_tensor(group["lr"])
|
||||||
|
and not group["capturable"]
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||||
|
)
|
||||||
|
|
||||||
|
state_steps.append(state["step"])
|
||||||
|
return has_complex
|
||||||
|
|
||||||
|
@_use_grad_for_differentiable
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Perform a single optimization step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
closure (Callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
self._cuda_graph_capture_health_check()
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
params_with_grad: List[Tensor] = []
|
||||||
|
grads: List[Tensor] = []
|
||||||
|
exp_avgs: List[Tensor] = []
|
||||||
|
exp_avg_sqs: List[Tensor] = []
|
||||||
|
state_steps: List[Tensor] = []
|
||||||
|
beta1, beta2 = group["betas"]
|
||||||
|
|
||||||
|
has_complex = self._init_group(
|
||||||
|
group,
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
adopt(
|
||||||
|
params_with_grad,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
has_complex=has_complex,
|
||||||
|
beta1=beta1,
|
||||||
|
beta2=beta2,
|
||||||
|
lr=group["lr"],
|
||||||
|
weight_decay=group["weight_decay"],
|
||||||
|
decoupled=group["decoupled"],
|
||||||
|
eps=group["eps"],
|
||||||
|
maximize=group["maximize"],
|
||||||
|
foreach=group["foreach"],
|
||||||
|
capturable=group["capturable"],
|
||||||
|
differentiable=group["differentiable"],
|
||||||
|
fused=group["fused"],
|
||||||
|
grad_scale=getattr(self, "grad_scale", None),
|
||||||
|
found_inf=getattr(self, "found_inf", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def _single_tensor_adopt(
|
||||||
|
params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[Tensor],
|
||||||
|
grad_scale: Optional[Tensor],
|
||||||
|
found_inf: Optional[Tensor],
|
||||||
|
*,
|
||||||
|
has_complex: bool,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: Union[float, Tensor],
|
||||||
|
weight_decay: float,
|
||||||
|
decoupled: bool,
|
||||||
|
eps: float,
|
||||||
|
maximize: bool,
|
||||||
|
capturable: bool,
|
||||||
|
differentiable: bool,
|
||||||
|
):
|
||||||
|
assert grad_scale is None and found_inf is None
|
||||||
|
|
||||||
|
if torch.jit.is_scripting():
|
||||||
|
# this assert is due to JIT being dumb and not realizing that the ops below
|
||||||
|
# have overloads to handle both float and Tensor lrs, so we just assert it's
|
||||||
|
# a float since most people using JIT are using floats
|
||||||
|
assert isinstance(lr, float)
|
||||||
|
|
||||||
|
for i, param in enumerate(params):
|
||||||
|
grad = grads[i] if not maximize else -grads[i]
|
||||||
|
exp_avg = exp_avgs[i]
|
||||||
|
exp_avg_sq = exp_avg_sqs[i]
|
||||||
|
step_t = state_steps[i]
|
||||||
|
|
||||||
|
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||||
|
if not torch._utils.is_compiling() and capturable:
|
||||||
|
capturable_supported_devices = _get_capturable_supported_devices()
|
||||||
|
assert (
|
||||||
|
param.device.type == step_t.device.type
|
||||||
|
and param.device.type in capturable_supported_devices
|
||||||
|
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||||
|
|
||||||
|
# update step
|
||||||
|
step_t += 1
|
||||||
|
|
||||||
|
if weight_decay != 0:
|
||||||
|
if decoupled:
|
||||||
|
param.add_(param, alpha=-lr * weight_decay)
|
||||||
|
else:
|
||||||
|
grad = grad.add(param, alpha=weight_decay)
|
||||||
|
|
||||||
|
if torch.is_complex(param):
|
||||||
|
grad = torch.view_as_real(grad)
|
||||||
|
if exp_avg is not None:
|
||||||
|
exp_avg = torch.view_as_real(exp_avg)
|
||||||
|
if exp_avg_sq is not None:
|
||||||
|
exp_avg_sq = torch.view_as_real(exp_avg_sq)
|
||||||
|
param = torch.view_as_real(param)
|
||||||
|
|
||||||
|
step = step_t if capturable or differentiable else _get_value(step_t)
|
||||||
|
if step == 1:
|
||||||
|
exp_avg_sq.addcmul_(grad, grad.conj())
|
||||||
|
continue
|
||||||
|
|
||||||
|
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
|
||||||
|
if step == 2:
|
||||||
|
exp_avg.addcdiv_(grad, denom)
|
||||||
|
else:
|
||||||
|
exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
|
||||||
|
|
||||||
|
param.add_(exp_avg, alpha=-lr)
|
||||||
|
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
|
||||||
|
|
||||||
|
|
||||||
|
def _multi_tensor_adopt(
|
||||||
|
params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[Tensor],
|
||||||
|
grad_scale: Optional[Tensor],
|
||||||
|
found_inf: Optional[Tensor],
|
||||||
|
*,
|
||||||
|
has_complex: bool,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: Union[float, Tensor],
|
||||||
|
weight_decay: float,
|
||||||
|
decoupled: bool,
|
||||||
|
eps: float,
|
||||||
|
maximize: bool,
|
||||||
|
capturable: bool,
|
||||||
|
differentiable: bool,
|
||||||
|
):
|
||||||
|
if len(params) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(lr, Tensor) and not capturable:
|
||||||
|
raise RuntimeError(
|
||||||
|
"lr as a Tensor is not supported for capturable=False and foreach=True"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
|
||||||
|
if not torch._utils.is_compiling() and capturable:
|
||||||
|
capturable_supported_devices = _get_capturable_supported_devices(
|
||||||
|
supports_xla=False
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
p.device.type == step.device.type
|
||||||
|
and p.device.type in capturable_supported_devices
|
||||||
|
for p, step in zip(params, state_steps)
|
||||||
|
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
|
||||||
|
|
||||||
|
assert grad_scale is None and found_inf is None
|
||||||
|
|
||||||
|
assert not differentiable, "_foreach ops don't support autograd"
|
||||||
|
|
||||||
|
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
|
||||||
|
[params, grads, exp_avgs, exp_avg_sqs, state_steps] # type: ignore[list-item]
|
||||||
|
)
|
||||||
|
for (
|
||||||
|
device_params_,
|
||||||
|
device_grads_,
|
||||||
|
device_exp_avgs_,
|
||||||
|
device_exp_avg_sqs_,
|
||||||
|
device_state_steps_,
|
||||||
|
), _ in grouped_tensors.values():
|
||||||
|
device_params = cast(List[Tensor], device_params_)
|
||||||
|
device_grads = cast(List[Tensor], device_grads_)
|
||||||
|
device_exp_avgs = cast(List[Tensor], device_exp_avgs_)
|
||||||
|
device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_)
|
||||||
|
device_state_steps = cast(List[Tensor], device_state_steps_)
|
||||||
|
|
||||||
|
# Handle complex parameters
|
||||||
|
if has_complex:
|
||||||
|
_view_as_real(
|
||||||
|
device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
|
||||||
|
)
|
||||||
|
|
||||||
|
if maximize:
|
||||||
|
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Update steps
|
||||||
|
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
|
||||||
|
# and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
|
||||||
|
# wrapped it once now. The alpha is required to assure we go to the right overload.
|
||||||
|
if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
|
||||||
|
torch._foreach_add_(
|
||||||
|
device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch._foreach_add_(device_state_steps, 1)
|
||||||
|
|
||||||
|
if weight_decay != 0:
|
||||||
|
if decoupled:
|
||||||
|
torch._foreach_add_(
|
||||||
|
device_params, device_params, alpha=-lr * weight_decay
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Re-use the intermediate memory (device_grads) already allocated for maximize
|
||||||
|
if maximize:
|
||||||
|
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
|
||||||
|
else:
|
||||||
|
device_grads = torch._foreach_add( # type: ignore[assignment]
|
||||||
|
device_grads, device_params, alpha=weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
if device_state_steps[0] == 1:
|
||||||
|
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
|
||||||
|
continue
|
||||||
|
|
||||||
|
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
|
||||||
|
exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
|
||||||
|
|
||||||
|
if device_state_steps[0] == 2:
|
||||||
|
torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
|
||||||
|
else:
|
||||||
|
torch._foreach_mul_(device_exp_avgs, beta1)
|
||||||
|
torch._foreach_addcdiv_(
|
||||||
|
device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
|
||||||
|
)
|
||||||
|
|
||||||
|
torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
|
||||||
|
torch._foreach_mul_(device_exp_avg_sqs, beta2)
|
||||||
|
torch._foreach_addcmul_(
|
||||||
|
device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)
|
||||||
|
def adopt(
|
||||||
|
params: List[Tensor],
|
||||||
|
grads: List[Tensor],
|
||||||
|
exp_avgs: List[Tensor],
|
||||||
|
exp_avg_sqs: List[Tensor],
|
||||||
|
state_steps: List[Tensor],
|
||||||
|
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
|
||||||
|
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
|
||||||
|
foreach: Optional[bool] = None,
|
||||||
|
capturable: bool = False,
|
||||||
|
differentiable: bool = False,
|
||||||
|
fused: Optional[bool] = None,
|
||||||
|
grad_scale: Optional[Tensor] = None,
|
||||||
|
found_inf: Optional[Tensor] = None,
|
||||||
|
has_complex: bool = False,
|
||||||
|
*,
|
||||||
|
beta1: float,
|
||||||
|
beta2: float,
|
||||||
|
lr: Union[float, Tensor],
|
||||||
|
weight_decay: float,
|
||||||
|
decoupled: bool,
|
||||||
|
eps: float,
|
||||||
|
maximize: bool,
|
||||||
|
):
|
||||||
|
r"""Functional API that performs ADOPT algorithm computation."""
|
||||||
|
# Respect when the user inputs False/True for foreach or fused. We only want to change
|
||||||
|
# the default when neither have been user-specified. Note that we default to foreach
|
||||||
|
# and pass False to use_fused. This is not a mistake--we want to give the fused impl
|
||||||
|
# bake-in time before making it the default, even if it is typically faster.
|
||||||
|
if fused is None and foreach is None:
|
||||||
|
_, foreach = _default_to_fused_or_foreach(
|
||||||
|
params, differentiable, use_fused=False
|
||||||
|
)
|
||||||
|
# Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
|
||||||
|
if foreach and isinstance(lr, Tensor) and not capturable:
|
||||||
|
foreach = False
|
||||||
|
if fused is None:
|
||||||
|
fused = False
|
||||||
|
if foreach is None:
|
||||||
|
foreach = False
|
||||||
|
|
||||||
|
# this check is slow during compilation, so we skip it
|
||||||
|
# if it's strictly needed we can add this check back in dynamo
|
||||||
|
if not torch._utils.is_compiling() and not all(
|
||||||
|
isinstance(t, torch.Tensor) for t in state_steps
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
"API has changed, `state_steps` argument must contain a list of singleton tensors"
|
||||||
|
)
|
||||||
|
|
||||||
|
if foreach and torch.jit.is_scripting():
|
||||||
|
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
|
||||||
|
if fused and torch.jit.is_scripting():
|
||||||
|
raise RuntimeError("torch.jit.script not supported with fused optimizers")
|
||||||
|
|
||||||
|
# if fused and not torch.jit.is_scripting():
|
||||||
|
# func = _fused_adopt
|
||||||
|
# elif foreach and not torch.jit.is_scripting():
|
||||||
|
if foreach and not torch.jit.is_scripting():
|
||||||
|
func = _multi_tensor_adopt
|
||||||
|
else:
|
||||||
|
func = _single_tensor_adopt
|
||||||
|
|
||||||
|
func(
|
||||||
|
params,
|
||||||
|
grads,
|
||||||
|
exp_avgs,
|
||||||
|
exp_avg_sqs,
|
||||||
|
state_steps,
|
||||||
|
has_complex=has_complex,
|
||||||
|
beta1=beta1,
|
||||||
|
beta2=beta2,
|
||||||
|
lr=lr,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
decoupled=decoupled,
|
||||||
|
eps=eps,
|
||||||
|
maximize=maximize,
|
||||||
|
capturable=capturable,
|
||||||
|
differentiable=differentiable,
|
||||||
|
grad_scale=grad_scale,
|
||||||
|
found_inf=found_inf,
|
||||||
|
)
|
||||||
@@ -17,6 +17,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
LOG = get_logger("axolotl")
|
LOG = get_logger("axolotl")
|
||||||
@@ -184,11 +185,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
min_sequence_len=cfg.min_sample_len or 2,
|
min_sequence_len=cfg.min_sample_len or 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.is_preprocess:
|
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
|
||||||
|
|
||||||
if cfg.model_config_type == "mamba":
|
if cfg.model_config_type == "mamba":
|
||||||
LOG.info("dropping attention_mask column")
|
LOG.info("dropping attention_mask column")
|
||||||
@@ -461,6 +461,9 @@ def setup_fsdp_envs(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_optim_env(cfg):
|
def prepare_optim_env(cfg):
|
||||||
|
if not check_cuda_p2p_ib_support():
|
||||||
|
if os.getenv("NCCL_P2P_DISABLE") is None:
|
||||||
|
os.environ["NCCL_P2P_DISABLE"] = "1"
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from .utils import with_temp_dir
|
from .utils import require_torch_2_5_1, with_temp_dir
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
os.environ["WANDB_DISABLED"] = "true"
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
@@ -65,3 +65,46 @@ class TestCustomOptimizers(unittest.TestCase):
|
|||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
@require_torch_2_5_1
|
||||||
|
def test_adopt_adamw(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "JackFram/llama-68m",
|
||||||
|
"tokenizer_type": "LlamaTokenizer",
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": True,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.05,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"val_set_size": 0.1,
|
||||||
|
"special_tokens": {
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
},
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"num_epochs": 1,
|
||||||
|
"micro_batch_size": 8,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adopt_adamw",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from importlib.metadata import version
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
# from importlib.metadata import version
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
|
||||||
def with_temp_dir(test_func):
|
def with_temp_dir(test_func):
|
||||||
@wraps(test_func)
|
@wraps(test_func)
|
||||||
@@ -43,12 +45,24 @@ def require_torch_2_3_1(test_case):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def is_min_2_3_1():
|
def is_min_2_3_1():
|
||||||
torch_version = version("torch")
|
torch_version = version.parse(torch.__version__)
|
||||||
return torch_version >= "2.3.1"
|
return torch_version >= version.parse("2.3.1")
|
||||||
|
|
||||||
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_2_5_1(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires torch >= 2.3.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
def is_min_2_5_1():
|
||||||
|
torch_version = version.parse(torch.__version__)
|
||||||
|
return torch_version >= version.parse("2.5.1")
|
||||||
|
|
||||||
|
return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def is_hopper():
|
def is_hopper():
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
return compute_capability == (9, 0)
|
return compute_capability == (9, 0)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@@ -21,6 +22,7 @@ class TestPretrainingPacking(unittest.TestCase):
|
|||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
self.tokenizer.pad_token = "</s>"
|
self.tokenizer.pad_token = "</s>"
|
||||||
|
|
||||||
|
@pytest.mark.flaky(retries=3, delay=5)
|
||||||
def test_packing_stream_dataset(self):
|
def test_packing_stream_dataset(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
|
|||||||
@@ -234,3 +234,59 @@ class TestValidationCheckDatasetConfig(BaseValidation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
_check_config()
|
_check_config()
|
||||||
|
|
||||||
|
def test_dataset_sharegpt_deprecation(self, minimal_cfg):
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"chat_template": "chatml",
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "LDJnr/Puffin",
|
||||||
|
"type": "sharegpt",
|
||||||
|
"conversation": "chatml",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check sharegpt deprecation is raised
|
||||||
|
with pytest.raises(ValueError, match=r".*type: sharegpt.*` is deprecated.*"):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
# Check that deprecation is not thrown for non-str type
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": {
|
||||||
|
"field_instruction": "instruction",
|
||||||
|
"field_output": "output",
|
||||||
|
"field_system": "system",
|
||||||
|
"format": "<|user|> {instruction} {input} <|model|>",
|
||||||
|
"no_input_format": "<|user|> {instruction} <|model|>",
|
||||||
|
"system_prompt": "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
# Check that deprecation is not thrown for non-sharegpt type
|
||||||
|
cfg = DictDefault(
|
||||||
|
minimal_cfg
|
||||||
|
| {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|||||||
Reference in New Issue
Block a user