Compare commits
2 Commits
transforme
...
mm3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cdd8be7097 | ||
|
|
08143c7b0d |
24
.github/workflows/base.yml
vendored
24
.github/workflows/base.yml
vendored
@@ -28,37 +28,23 @@ jobs:
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
- name: Docker metadata
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@v3
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-base
|
||||
axolotlai/axolotl-base
|
||||
images: winglian/axolotl-base
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v2
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
|
||||
2
.github/workflows/docs.yml
vendored
2
.github/workflows/docs.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Set up Quarto
|
||||
uses: quarto-dev/quarto-actions/setup@v2
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: install dependencies
|
||||
|
||||
6
.github/workflows/lint.yml
vendored
6
.github/workflows/lint.yml
vendored
@@ -15,9 +15,9 @@ jobs:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
|
||||
44
.github/workflows/main.yml
vendored
44
.github/workflows/main.yml
vendored
@@ -4,8 +4,6 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
tags:
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
@@ -29,12 +27,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -44,12 +37,7 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl
|
||||
axolotlai/axolotl
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{version}}
|
||||
images: winglian/axolotl
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to Docker Hub
|
||||
@@ -63,7 +51,7 @@ jobs:
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
|
||||
@@ -96,12 +84,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -111,22 +94,20 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud
|
||||
axolotlai/axolotl-cloud
|
||||
images: winglian/axolotl-cloud
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
file: ./docker/Dockerfile-cloud
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
@@ -155,25 +136,20 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud-term
|
||||
axolotlai/axolotl-cloud-term
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{version}}
|
||||
images: winglian/axolotl-cloud-term
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
file: ./docker/Dockerfile-cloud-no-tmux
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
|
||||
13
.github/workflows/multi-gpu-e2e.yml
vendored
13
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -21,17 +21,10 @@ jobs:
|
||||
pytorch: 2.3.1
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.3.1
|
||||
axolotl_extras:
|
||||
num_gpus: 2
|
||||
nightly_build: "true"
|
||||
|
||||
24
.github/workflows/nightlies.yml
vendored
24
.github/workflows/nightlies.yml
vendored
@@ -26,12 +26,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -41,9 +36,7 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl
|
||||
axolotlai/axolotl
|
||||
images: winglian/axolotl
|
||||
tags: |
|
||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||
- name: Set up Docker Buildx
|
||||
@@ -90,12 +83,7 @@ jobs:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
@@ -105,9 +93,7 @@ jobs:
|
||||
id: metadata
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: |
|
||||
winglian/axolotl-cloud
|
||||
axolotlai/axolotl-cloud
|
||||
images: winglian/axolotl-cloud
|
||||
tags: |
|
||||
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
|
||||
- name: Login to Docker Hub
|
||||
@@ -116,7 +102,7 @@ jobs:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@v2
|
||||
- name: Build
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
|
||||
27
.github/workflows/pypi.yml
vendored
27
.github/workflows/pypi.yml
vendored
@@ -3,31 +3,12 @@ name: publish pypi
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
workflow_dispatch:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
setup_release:
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Get the tag version
|
||||
id: extract_branch
|
||||
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
||||
shell: bash
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
||||
release_name: ${{ steps.extract_branch.outputs.branch }}
|
||||
pypi-publish:
|
||||
name: Upload release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: [setup_release]
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/axolotl
|
||||
@@ -35,10 +16,10 @@ jobs:
|
||||
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
@@ -46,7 +27,7 @@ jobs:
|
||||
run: |
|
||||
pip3 install wheel packaging
|
||||
pip3 install -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Extract tag name
|
||||
id: tag
|
||||
|
||||
25
.github/workflows/tests-nightly.yml
vendored
25
.github/workflows/tests-nightly.yml
vendored
@@ -9,12 +9,12 @@ jobs:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
@@ -25,15 +25,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
|
||||
pytorch_version: ["2.3.1", "2.4.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
@@ -47,14 +47,13 @@ jobs:
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 install -U -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
@@ -82,17 +81,17 @@ jobs:
|
||||
num_gpus: 1
|
||||
axolotl_extras: mamba-ssm
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
pytorch: 2.3.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
axolotl_extras: mamba-ssm
|
||||
nightly_build: "true"
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
nightly_build: "true"
|
||||
|
||||
81
.github/workflows/tests.yml
vendored
81
.github/workflows/tests.yml
vendored
@@ -15,22 +15,17 @@ on:
|
||||
- '.github/workflows/*.yml'
|
||||
workflow_dispatch:
|
||||
|
||||
# Cancel jobs on the same ref if a new one is triggered
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
cache: 'pip' # caching pip dependencies
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
- uses: pre-commit/action@v3.0.0
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
@@ -41,33 +36,29 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.10", "3.11"]
|
||||
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
|
||||
pytorch_version: ["2.3.1", "2.4.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 install -U -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
@@ -77,52 +68,12 @@ jobs:
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.4.1
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Install Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==0.63.64 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
|
||||
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
|
||||
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
|
||||
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
|
||||
- name: Run tests job on Modal
|
||||
run: |
|
||||
modal run cicd.tests
|
||||
|
||||
docker-e2e-tests:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest, docker-e2e-tests-1st]
|
||||
timeout-minutes: 60
|
||||
needs: [pre-commit, pytest]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
@@ -134,10 +85,16 @@ jobs:
|
||||
pytorch: 2.3.1
|
||||
num_gpus: 1
|
||||
axolotl_extras: mamba-ssm
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.1
|
||||
num_gpus: 1
|
||||
axolotl_extras: mamba-ssm
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
pytorch: 2.4.0
|
||||
num_gpus: 1
|
||||
axolotl_extras:
|
||||
steps:
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
[settings]
|
||||
profile=black
|
||||
known_third_party=wandb,comet_ml
|
||||
known_third_party=wandb
|
||||
|
||||
38
README.md
38
README.md
@@ -14,7 +14,7 @@ Features:
|
||||
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
|
||||
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
|
||||
- Easily run with Docker locally or on the cloud
|
||||
- Log results and optionally checkpoints to wandb, mlflow or Comet
|
||||
- Log results and optionally checkpoints to wandb or mlflow
|
||||
- And more!
|
||||
|
||||
<a href="https://www.phorm.ai/query?projectId=e315ba4a-4e14-421f-ab05-38a1f9076f25">
|
||||
@@ -121,7 +121,7 @@ Features:
|
||||
|
||||
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
|
||||
|
||||
**Requirements**: Nvidia GPU (Ampere architecture or newer for `bf16` and Flash Attention), Python >=3.10 and PyTorch >=2.3.1.
|
||||
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl
|
||||
@@ -159,7 +159,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
|
||||
#### Docker
|
||||
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
|
||||
```
|
||||
|
||||
Or run on the current files for development:
|
||||
@@ -178,7 +178,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
|
||||
A more powerful Docker command to run would be this:
|
||||
|
||||
```bash
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
|
||||
```
|
||||
|
||||
It additionally:
|
||||
@@ -210,7 +210,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
|
||||
#### Cloud GPU
|
||||
|
||||
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
|
||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
||||
|
||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
|
||||
@@ -319,7 +319,7 @@ Write a job description in YAML as below:
|
||||
# dstack.yaml
|
||||
type: task
|
||||
|
||||
image: axolotlai/axolotl-cloud:main-latest
|
||||
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
|
||||
|
||||
env:
|
||||
- HUGGING_FACE_HUB_TOKEN
|
||||
@@ -383,10 +383,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
|
||||
# fastchat conversation
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
- path: ...
|
||||
type: chat_template
|
||||
chat_template: chatml # defaults to tokenizer's chat_template
|
||||
type: sharegpt
|
||||
conversation: chatml # default: vicuna_v1.1
|
||||
|
||||
# local
|
||||
- path: data.jsonl # or json
|
||||
@@ -514,22 +515,6 @@ wandb_name:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
##### Comet Logging
|
||||
|
||||
Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to wandb with `comet login`.
|
||||
|
||||
- wandb options
|
||||
```yaml
|
||||
use_comet:
|
||||
comet_api_key:
|
||||
comet_workspace:
|
||||
comet_project_name:
|
||||
comet_experiment_key:
|
||||
comet_mode:
|
||||
comet_online:
|
||||
comet_experiment_config:
|
||||
```
|
||||
|
||||
##### Special Tokens
|
||||
|
||||
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
||||
@@ -561,8 +546,7 @@ plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_layer_norm: true
|
||||
liger_swiglu: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
```
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
|
||||
FROM winglian/axolotl-base:{{ BASE_TAG }}
|
||||
|
||||
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
|
||||
@@ -23,11 +23,11 @@ RUN git fetch origin +$GITHUB_REF && \
|
||||
git checkout FETCH_HEAD
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN pip install causal_conv1d
|
||||
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
@@ -37,7 +37,7 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
fi
|
||||
|
||||
# So we can test the Docker image
|
||||
RUN pip install -r requirements-dev.txt -r requirements-tests.txt
|
||||
RUN pip install -r requirements-tests.txt
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -10,7 +10,7 @@ import tempfile
|
||||
import jinja2
|
||||
import modal
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App, Image
|
||||
from modal import Image, Stub
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
@@ -46,7 +46,7 @@ cicd_image = (
|
||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
stub = Stub("Axolotl CI/CD", secrets=[])
|
||||
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 2))
|
||||
@@ -61,10 +61,10 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
|
||||
@app.function(
|
||||
@stub.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=60 * 60,
|
||||
timeout=45 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072 * N_GPUS,
|
||||
)
|
||||
@@ -72,6 +72,6 @@ def cicd_pytest():
|
||||
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
@stub.local_entrypoint()
|
||||
def main():
|
||||
cicd_pytest.remote()
|
||||
|
||||
@@ -10,7 +10,7 @@ import tempfile
|
||||
import jinja2
|
||||
import modal
|
||||
from jinja2 import select_autoescape
|
||||
from modal import App, Image
|
||||
from modal import Image, Stub
|
||||
|
||||
cicd_path = pathlib.Path(__file__).parent.resolve()
|
||||
|
||||
@@ -47,7 +47,7 @@ cicd_image = (
|
||||
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
)
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
stub = Stub("Axolotl CI/CD", secrets=[])
|
||||
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
@@ -62,10 +62,10 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
|
||||
@app.function(
|
||||
@stub.function(
|
||||
image=cicd_image,
|
||||
gpu=GPU_CONFIG,
|
||||
timeout=60 * 60,
|
||||
timeout=45 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
)
|
||||
@@ -73,6 +73,6 @@ def cicd_pytest():
|
||||
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
||||
|
||||
|
||||
@app.local_entrypoint()
|
||||
@stub.local_entrypoint()
|
||||
def main():
|
||||
cicd_pytest.remote()
|
||||
|
||||
@@ -14,6 +14,15 @@
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
|
||||
@@ -24,6 +24,15 @@
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
|
||||
@@ -20,6 +20,15 @@
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 32,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"train_batch_size": "auto",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Example config for debugging the chat_template prompt format
|
||||
# Example config for debugging the sharegpt prompt format
|
||||
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
@@ -7,8 +7,8 @@ load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
- path: philschmid/guanaco-sharegpt-style
|
||||
type: sharegpt
|
||||
shards: 10
|
||||
val_set_size: 0
|
||||
output_dir: temp_debug/axolotl_outputs/model
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG BASE_TAG=main-base
|
||||
FROM axolotlai/axolotl-base:$BASE_TAG
|
||||
FROM winglian/axolotl-base:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
@@ -20,6 +20,7 @@ RUN git clone --depth=1 https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN pip install causal_conv1d
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl:$BASE_TAG
|
||||
FROM winglian/axolotl:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG BASE_TAG=main
|
||||
FROM axolotlai/axolotl:$BASE_TAG
|
||||
FROM winglian/axolotl:$BASE_TAG
|
||||
|
||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG BASE_TAG=main-base
|
||||
FROM axolotlai/axolotl-base:$BASE_TAG
|
||||
FROM winglian/axolotl-base:$BASE_TAG
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
|
||||
@@ -83,14 +83,22 @@ lora_on_cpu: true
|
||||
datasets:
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
|
||||
data_files: # Optional[str] path to source data files
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
field_human: # Optional[str]. Human key to use for conversation.
|
||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
# Add additional keys from your dataset as input or output roles
|
||||
roles:
|
||||
input: # Optional[List[str]]. These will be masked based on train_on_input
|
||||
output: # Optional[List[str]].
|
||||
|
||||
# Custom user instruction prompt
|
||||
- path: repo
|
||||
@@ -115,48 +123,6 @@ datasets:
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
field:
|
||||
|
||||
# Using chat template
|
||||
- path: ...
|
||||
# Set type to `chat_template` to use this strategy
|
||||
type: chat_template
|
||||
# Specify the name of the chat template to use
|
||||
# The name of the chat template to use for training, following values are supported:
|
||||
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default.
|
||||
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml.
|
||||
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||
chat_template: tokenizer_default
|
||||
# Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`).
|
||||
chat_template_jinja:
|
||||
# The key in the data example that contains the messages. Default is "messages".
|
||||
field_messages: messages
|
||||
# The key in the message turn that contains the role. Default is "role".
|
||||
message_field_role: role
|
||||
# The key in the message turn that contains the content. Default is "content".
|
||||
message_field_content: content
|
||||
# Optional[Dict[str, List]]. Roles mapping for the messages.
|
||||
roles:
|
||||
user: ["human", "user"]
|
||||
assistant: ["gpt", "assistant", "ai"]
|
||||
system: ["system"]
|
||||
|
||||
## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on.
|
||||
|
||||
# Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss.
|
||||
roles_to_train: ["gpt", "assistant"]
|
||||
# Optional[str]. Which EOS tokens to train on in the conversation. Possible values are:
|
||||
# - all: train on all EOS tokens
|
||||
# - turn: train on the EOS token at the end of each trainable turn
|
||||
# - last: train on the last EOS token in the conversation
|
||||
train_on_eos: last
|
||||
# The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`.
|
||||
message_field_training: training
|
||||
# The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn.
|
||||
# The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train).
|
||||
# See example at `docs/dataset-formats/conversation.qmd`
|
||||
message_field_training_detail: train_detail
|
||||
|
||||
|
||||
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
|
||||
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
|
||||
shuffle_merged_datasets: true
|
||||
@@ -174,19 +140,10 @@ test_datasets:
|
||||
|
||||
# use RL training: 'dpo', 'ipo', 'kto'
|
||||
rl:
|
||||
# whether to perform weighting if doing DPO training. Boolean.
|
||||
dpo_use_weighting:
|
||||
|
||||
# The name of the chat template to use for training, following values are supported:
|
||||
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
|
||||
# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py
|
||||
# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer.
|
||||
# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field.
|
||||
# The selected chat template will be saved to the tokenizer_config.json for easier inferencing
|
||||
# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template.
|
||||
chat_template: tokenizer_default
|
||||
# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null.
|
||||
chat_template_jinja: null
|
||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||
# Currently supports chatml and inst (mistral/mixtral)
|
||||
chat_template: chatml
|
||||
# Changes the default system message
|
||||
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
@@ -308,21 +265,8 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
||||
# mlflow configuration if you're using it
|
||||
mlflow_tracking_uri: # URI to mlflow
|
||||
mlflow_experiment_name: # Your experiment name
|
||||
mlflow_run_name: # Your run name
|
||||
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
||||
|
||||
# Comet configuration if you're using it
|
||||
# Make sure your `COMET_API_KEY` environment variable is set (recommended) or you login to Comet with `comet login`.
|
||||
# Check out our documentation for more details https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment-Creation/#comet_ml.start
|
||||
use_comet: # Enable or disable Comet integration.
|
||||
comet_api_key: # API key for Comet. Recommended to set via `comet login`.
|
||||
comet_workspace: # Workspace name in Comet. Defaults to the user's default workspace.
|
||||
comet_project_name: # Project name in Comet. Defaults to Uncategorized.
|
||||
comet_experiment_key: # Identifier for the experiment. Used to append data to an existing experiment or control the key of new experiments. Default to a random key.
|
||||
comet_mode: # Create a new experiment ("create") or log to an existing one ("get"). Default ("get_or_create") auto-selects based on configuration.
|
||||
comet_online: # Set to True to log data to Comet server, or False for offline storage. Default is True.
|
||||
comet_experiment_config: # Dictionary for additional configuration settings, see the doc for more details.
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
@@ -357,7 +301,7 @@ max_steps:
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", "chrf", "perplexity"]
|
||||
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
|
||||
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
@@ -405,7 +349,6 @@ lr_div_factor: # Learning rate div factor
|
||||
# - adamw_torch_fused
|
||||
# - adamw_torch_xla
|
||||
# - adamw_apex_fused
|
||||
# - adopt_adamw (only for torch version >= 2.5.1)
|
||||
# - adafactor
|
||||
# - adamw_anyprecision
|
||||
# - sgd
|
||||
|
||||
@@ -6,8 +6,31 @@ order: 3
|
||||
|
||||
## sharegpt
|
||||
|
||||
IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
|
||||
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
Note: `type: sharegpt` opens special configs:
|
||||
- `conversation`: enables conversions to many Conversation types. Refer to the 'name' [here](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) for options.
|
||||
- `roles`: allows you to specify the roles for input and output. This is useful for datasets with custom roles such as `tool` etc to support masking.
|
||||
- `field_human`: specify the key to use instead of `human` in the conversation.
|
||||
- `field_model`: specify the key to use instead of `gpt` in the conversation.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
path: ...
|
||||
type: sharegpt
|
||||
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
field_human: # Optional[str]. Human key to use for conversation.
|
||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
# Add additional keys from your dataset as input or output roles
|
||||
roles:
|
||||
input: # Optional[List[str]]. These will be masked based on train_on_input
|
||||
output: # Optional[List[str]].
|
||||
```
|
||||
|
||||
## pygmalion
|
||||
|
||||
@@ -15,137 +38,34 @@ IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
## sharegpt.load_role
|
||||
|
||||
## chat_template
|
||||
|
||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||
conversations where `role` is used instead of `from`
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"role": "...", "content": "..."}]}
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
See `config.qmd` for full configs and supported templates.
|
||||
## sharegpt.load_guanaco
|
||||
|
||||
### Migrating from sharegpt
|
||||
|
||||
Most configs can be adapted as follows:
|
||||
|
||||
```yaml
|
||||
# old
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: sharegpt
|
||||
conversation: chatml
|
||||
|
||||
# new (if using tokenizer's chat_template)
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
|
||||
# new (if setting a new chat_template like chatml, gemma, etc)
|
||||
chat_template: chatml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
```
|
||||
|
||||
We recommend checking the below examples for other usecases.
|
||||
|
||||
### Examples
|
||||
|
||||
1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
```
|
||||
|
||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: gemma # this overwrites the tokenizer's chat_template
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train: ["assistant"]
|
||||
```
|
||||
|
||||
3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train: ["assistant"]
|
||||
```
|
||||
|
||||
4. Using a custom jinja template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty
|
||||
chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
|
||||
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
roles_to_train: ["assistant"]
|
||||
```
|
||||
|
||||
5. (Advanced) Using fine-grained control over tokens and turns to train in a conversation
|
||||
|
||||
For a data sample that looks like:
|
||||
conversations where `from` is `prompter` `assistant` instead of default sharegpt
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{
|
||||
"conversations": [
|
||||
{"from": "system", "value": "You are an AI assistant.", "train": false},
|
||||
{"from": "human", "value": "Hello", "train": false},
|
||||
{"from": "assistant", "value": "Hello", "train": true},
|
||||
{"from": "human", "value": "How are you?", "train": true},
|
||||
{
|
||||
"from": "assistant",
|
||||
"value": "I'm doing very well, thank you!",
|
||||
"train_detail": [
|
||||
{"begin_offset": 0, "end_offset": 8, "train": false},
|
||||
{"begin_offset": 9, "end_offset": 18, "train": true},
|
||||
{"begin_offset": 19, "end_offset": 30, "train": false},
|
||||
],
|
||||
},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "I'm doing very well, thank you!",
|
||||
"train": true,
|
||||
},
|
||||
{"from": "assistant", "value": "Hi there!", "train": true}
|
||||
]
|
||||
}
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
|
||||
The configuration would look like:
|
||||
## sharegpt.load_ultrachat
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
chat_template: tokenizer_default
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
roles_to_train: []
|
||||
train_on_eos: turn
|
||||
message_field_training: train
|
||||
message_field_training_detail: train_detail
|
||||
conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"messages": [{"user": "...", "assistant": "..."}]}
|
||||
```
|
||||
|
||||
Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time.
|
||||
## sharegpt_jokes
|
||||
|
||||
creates a chat where bot is asked to tell a joke, then explain why the joke is funny
|
||||
|
||||
```{.json filename="data.jsonl"}
|
||||
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
|
||||
```
|
||||
|
||||
@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
|
||||
|
||||
### Background
|
||||
|
||||
The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
|
||||
The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
- path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
|
||||
type: sharegpt
|
||||
```
|
||||
|
||||
>[!Important]
|
||||
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
|
||||
|
||||
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
|
||||
|
||||
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
||||
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
|
||||
|
||||
```jsonc
|
||||
// .vscode/launch.json
|
||||
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Debug axolotl prompt - chat_template",
|
||||
"name": "Debug axolotl prompt - sharegpt",
|
||||
"type": "python",
|
||||
"module": "accelerate.commands.launch",
|
||||
"request": "launch",
|
||||
"args": [
|
||||
"-m", "axolotl.cli.train", "dev_chat_template.yml",
|
||||
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
|
||||
// The flags below simplify debugging by overriding the axolotl config
|
||||
// with the debugging tips above. Modify as needed.
|
||||
"--dataset_processes=1", // limits data preprocessing to one process
|
||||
@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
|
||||
|
||||
## Debugging With Docker
|
||||
|
||||
Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
|
||||
Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
|
||||
|
||||
### Setup
|
||||
|
||||
@@ -202,11 +202,11 @@ cd axolotl
|
||||
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
|
||||
|
||||
```bash
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
|
||||
>[!Tip]
|
||||
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
||||
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
|
||||
|
||||
You will now be in the container. Next, perform an editable install of Axolotl:
|
||||
|
||||
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
|
||||
</div>
|
||||
<br>
|
||||
|
||||
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
|
||||
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
|
||||
|
||||
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).
|
||||
|
||||
@@ -44,7 +44,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
|
||||
"!pip install flash-attn==\"2.7.0.post2\"\n",
|
||||
"!pip install flash-attn==\"2.5.0\"\n",
|
||||
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
||||
]
|
||||
},
|
||||
|
||||
@@ -9,17 +9,14 @@ strict: false
|
||||
plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_swiglu: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
chat_template: deepseek_v2
|
||||
datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
split: train
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
|
||||
@@ -11,11 +11,8 @@ chat_template: gemma
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
chat_template: gemma
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
base_model: google/gemma-2-2b
|
||||
model_type: AutoModelForSequenceClassification
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
reward_model: true
|
||||
chat_template: gemma
|
||||
datasets:
|
||||
- path: argilla/distilabel-intel-orca-dpo-pairs
|
||||
type: bradley_terry.chat_template
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
remove_unused_columns: false
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
eval_sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16:
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_ratio: 0.1
|
||||
evals_per_epoch:
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -4,15 +4,11 @@ tokenizer_type: AutoTokenizer
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
use_tensorboard: true
|
||||
chat_template: jamba
|
||||
datasets:
|
||||
- path: cgato/SlimOrcaDedupCleaned
|
||||
type: chat_template
|
||||
chat_template: jamba
|
||||
drop_system_message: true
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.0
|
||||
output_dir: jamba-large-fsdp-qlora-ft
|
||||
|
||||
@@ -4,7 +4,7 @@ plugins:
|
||||
- axolotl.integrations.liger.LigerPlugin
|
||||
liger_rope: true
|
||||
liger_rms_norm: true
|
||||
liger_glu_activation: true
|
||||
liger_swiglu: true
|
||||
liger_fused_linear_cross_entropy: true
|
||||
|
||||
strict: false
|
||||
@@ -14,10 +14,6 @@ datasets:
|
||||
- path: mlabonne/FineTome-100k
|
||||
type: chat_template
|
||||
split: train[:20%]
|
||||
field_messages: conversations
|
||||
message_field_role: from
|
||||
message_field_content: value
|
||||
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
output_dir: ./outputs/out
|
||||
|
||||
@@ -11,6 +11,7 @@ rl: dpo
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
chat_template: llama3
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
|
||||
@@ -10,6 +10,7 @@ chat_template: llama3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
chat_template: llama3
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
base_model: meta-llama/Llama-3.2-1B
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: true
|
||||
eval_sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
pad_token: "<|end_of_text|>"
|
||||
@@ -1,93 +0,0 @@
|
||||
#Note that we are switching from the regular chat template to chatml.
|
||||
#If you experience problems with the special tokens, training for more epochs can help.
|
||||
#After training, merge the model before inference otherwise you might
|
||||
#face problems with the special tokens.
|
||||
|
||||
base_model: mistralai/Mistral-7B-Instruct-v0.2
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
chat_template: chatml
|
||||
rl: dpo
|
||||
datasets:
|
||||
- path: olivermolenschot/alpaca_messages_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./outputs/dpo-qlora
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.2
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
lora_modules_to_save:
|
||||
- embed_tokens
|
||||
- lm_head
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 16
|
||||
num_epochs: 6
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0001
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: false
|
||||
s2_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<|im_start|>"
|
||||
eos_token: "<|im_end|>"
|
||||
@@ -10,6 +10,7 @@ chat_template: phi_3
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_test
|
||||
type: chat_template
|
||||
chat_template: phi_3
|
||||
field_messages: messages
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
|
||||
@@ -2,4 +2,3 @@ pre-commit
|
||||
black
|
||||
mypy
|
||||
types-requests
|
||||
tbparse
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
pytest
|
||||
pytest-xdist
|
||||
pytest-retry
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.13.2
|
||||
transformers==4.46.1
|
||||
tokenizers>=0.20.1
|
||||
bitsandbytes==0.44.1
|
||||
accelerate==1.1.0
|
||||
datasets==3.0.1
|
||||
deepspeed==0.15.3
|
||||
peft==0.13.0
|
||||
transformers==4.45.1
|
||||
tokenizers>=0.19.1
|
||||
bitsandbytes==0.44.0
|
||||
accelerate==0.34.2
|
||||
datasets==2.21.0
|
||||
deepspeed==0.14.4
|
||||
pydantic==2.6.3
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
flash-attn==2.7.0.post2
|
||||
flash-attn==2.6.3
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers>=0.0.23.post1
|
||||
xformers==0.0.27
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
colorama
|
||||
@@ -28,12 +28,13 @@ scipy
|
||||
scikit-learn==1.4.2
|
||||
pynvml
|
||||
art
|
||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
autoawq>=0.2.5
|
||||
triton>=2.3.0
|
||||
liger-kernel==0.4.1
|
||||
liger-kernel==0.3.0
|
||||
|
||||
mamba-ssm==1.2.0.post1
|
||||
|
||||
@@ -42,14 +43,6 @@ s3fs>=2024.5.0
|
||||
gcsfs>=2024.5.0
|
||||
# adlfs
|
||||
|
||||
trl==0.12.0
|
||||
trl==0.9.6
|
||||
zstandard==0.22.0
|
||||
fastcore
|
||||
|
||||
# lm eval harness
|
||||
lm_eval==0.4.4
|
||||
langdetect==1.0.9
|
||||
immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.5.0
|
||||
|
||||
@@ -1,315 +0,0 @@
|
||||
accelerate==0.34.1
|
||||
addict==2.4.0
|
||||
aiofiles==23.2.1
|
||||
aiohttp==3.9.0
|
||||
aiosignal==1.3.1
|
||||
aiostream==0.5.2
|
||||
alembic==1.13.1
|
||||
annotated-types==0.6.0
|
||||
annoy==1.17.3
|
||||
ansible==6.7.0
|
||||
ansible-core==2.13.13
|
||||
ansible-vault==2.1.0
|
||||
anyio==3.7.1
|
||||
appdirs==1.4.4
|
||||
art==6.0
|
||||
asgiref==3.7.2
|
||||
async-timeout==4.0.2
|
||||
attrdict==2.0.1
|
||||
attrs==22.2.0
|
||||
awscli==1.32.75
|
||||
-e git+ssh://git@github.com/OpenAccess-AI-Collective/axolotl.git@6e354682e3c1735d3f7fb9e362280c38e922260f#egg=axolotl
|
||||
backoff==2.2.1
|
||||
base58==2.1.1
|
||||
beartype==0.17.2
|
||||
bitnet==0.2.1
|
||||
bitsandbytes==0.42.0
|
||||
bittensor==6.7.0
|
||||
black==23.7.0
|
||||
blinker==1.7.0
|
||||
boto3==1.34.75
|
||||
botocore==1.34.75
|
||||
cachetools==5.3.3
|
||||
cachy==0.1.1
|
||||
certifi==2023.7.22
|
||||
cffi==1.16.0
|
||||
cfgv==3.3.1
|
||||
chai-guanaco==1.2.4
|
||||
charset-normalizer==3.2.0
|
||||
cleo==0.6.8
|
||||
click==8.1.7
|
||||
cloudpickle==2.0.0
|
||||
cohere==4.11.2
|
||||
colorama==0.4.4
|
||||
coloredlogs==15.0.1
|
||||
CoLT5-attention==0.10.20
|
||||
contextlib2==21.6.0
|
||||
contourpy==1.2.0
|
||||
cryptography==41.0.3
|
||||
cycler==0.12.1
|
||||
cytoolz==0.12.3
|
||||
databricks-cli==0.18.0
|
||||
dataclasses-json==0.5.7
|
||||
datasets==2.11.0
|
||||
ddt==1.6.0
|
||||
decorator==5.1.1
|
||||
deepspeed==0.15.0
|
||||
# Editable Git install with no remote (dialogpt==0.1)
|
||||
-e /Users/wing/Projects/ml/dialogpt/src
|
||||
dill==0.3.6
|
||||
distlib==0.3.6
|
||||
docker==7.0.0
|
||||
docker-pycreds==0.4.0
|
||||
docstring-parser==0.15
|
||||
docutils==0.16
|
||||
ecdsa==0.18.0
|
||||
einops==0.7.0
|
||||
einops-exts==0.0.4
|
||||
einx==0.1.3
|
||||
entrypoints==0.4
|
||||
eth-hash==0.6.0
|
||||
eth-keys==0.5.0
|
||||
eth-typing==4.0.0
|
||||
eth-utils==2.3.1
|
||||
evaluate==0.4.0
|
||||
exceptiongroup==1.1.1
|
||||
fastapi==0.109.2
|
||||
fastcore==1.5.29
|
||||
ffmpy==0.4.0
|
||||
filelock==3.12.2
|
||||
-e git+https://github.com/NousResearch/finetuning-subnet.git@24e9407d6b4430a7ca39d344692f89ce5a97d27e#egg=finetuning_subnet
|
||||
fire==0.5.0
|
||||
first==2.0.2
|
||||
flake8==7.0.0
|
||||
Flask==3.0.1
|
||||
fonttools==4.47.2
|
||||
frozendict==2.4.1
|
||||
frozenlist==1.3.3
|
||||
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
|
||||
fsspec==2023.6.0
|
||||
fuzzywuzzy==0.18.0
|
||||
gitdb==4.0.10
|
||||
GitPython==3.1.31
|
||||
google-pasta==0.2.0
|
||||
gradio==4.42.0
|
||||
gradio_client==1.3.0
|
||||
greenlet==2.0.2
|
||||
grpclib==0.4.7
|
||||
gunicorn==21.2.0
|
||||
h11==0.14.0
|
||||
h2==4.1.0
|
||||
hpack==4.0.0
|
||||
httpcore==0.17.3
|
||||
httpx==0.24.1
|
||||
huggingface-hub==0.23.4
|
||||
humanfriendly==10.0
|
||||
hyperframe==6.0.1
|
||||
identify==2.5.24
|
||||
idna==3.4
|
||||
immutables==0.20
|
||||
importlib-metadata==6.7.0
|
||||
importlib-resources==6.1.1
|
||||
inflection==0.5.1
|
||||
iniconfig==2.0.0
|
||||
itsdangerous==2.1.2
|
||||
Jinja2==3.1.2
|
||||
jmespath==1.0.1
|
||||
joblib==1.3.2
|
||||
jsonlines==3.1.0
|
||||
jsonschema==2.6.0
|
||||
kiwisolver==1.4.5
|
||||
langchain==0.0.144
|
||||
Levenshtein==0.24.0
|
||||
libcst==1.1.0
|
||||
liger-kernel==0.0.0
|
||||
lion-pytorch==0.1.2
|
||||
llama-cpp-python==0.1.36
|
||||
llvmlite==0.40.1
|
||||
local-attention==1.9.0
|
||||
loguru==0.7.0
|
||||
Mako==1.3.2
|
||||
Markdown==3.5.2
|
||||
markdown-it-py==3.0.0
|
||||
markdown2==2.4.10
|
||||
MarkupSafe==2.1.2
|
||||
marshmallow==3.19.0
|
||||
marshmallow-enum==1.5.1
|
||||
matplotlib==3.8.2
|
||||
mccabe==0.7.0
|
||||
mdurl==0.1.2
|
||||
MEGABYTE-pytorch==0.0.7
|
||||
-e git+https://github.com/cg123/mergekit.git@53c5f414774a0558b8d84858fb6374bc93a8f1c1#egg=mergekit
|
||||
mlflow==2.10.0
|
||||
modal==0.62.77
|
||||
more-itertools==10.2.0
|
||||
mpmath==1.2.1
|
||||
msgpack==1.0.7
|
||||
msgpack-numpy-opentensor==0.5.0
|
||||
multidict==6.0.4
|
||||
multiprocess==0.70.14
|
||||
munch==2.5.0
|
||||
mypy==1.3.0
|
||||
mypy-extensions==1.0.0
|
||||
nest-asyncio==1.6.0
|
||||
netaddr==0.10.1
|
||||
networkx==3.0rc1
|
||||
nh3==0.2.14
|
||||
nodeenv==1.8.0
|
||||
nomic==2.0.2
|
||||
numba==0.57.1
|
||||
numexpr==2.8.4
|
||||
numpy==1.24.4
|
||||
oauthlib==3.2.2
|
||||
openai==0.27.4
|
||||
openapi==1.1.0
|
||||
openapi-schema-pydantic==1.2.4
|
||||
optimum==1.8.6
|
||||
orjson==3.10.7
|
||||
packaging==23.1
|
||||
pandas==2.0.0
|
||||
parameterized==0.9.0
|
||||
password-strength==0.0.3.post2
|
||||
pastel==0.1.1
|
||||
pathos==0.3.0
|
||||
pathspec==0.11.1
|
||||
pathtools==0.1.2
|
||||
peft==0.11.1
|
||||
pendulum==3.0.0
|
||||
Pillow==9.5.0
|
||||
pip-tools==1.11.0
|
||||
platformdirs==3.2.0
|
||||
pluggy==1.4.0
|
||||
poetry==0.7.1
|
||||
pox==0.3.2
|
||||
ppft==1.7.6.6
|
||||
pre-commit==3.3.2
|
||||
prettytable==3.10.0
|
||||
prompt-toolkit==3.0.39
|
||||
protobuf==3.20.2
|
||||
protobuf3-to-dict==0.1.5
|
||||
psutil==5.9.5
|
||||
psycopg==3.1.18
|
||||
PuLP==2.8.0
|
||||
py==1.11.0
|
||||
py-bip39-bindings==0.1.11
|
||||
py-cpuinfo==9.0.0
|
||||
py-ed25519-zebra-bindings==1.0.1
|
||||
py-sr25519-bindings==0.2.0
|
||||
pyarrow==11.0.0
|
||||
pyasn1==0.6.0
|
||||
pycodestyle==2.11.1
|
||||
pycparser==2.21
|
||||
pycryptodome==3.20.0
|
||||
pydantic==2.5.3
|
||||
pydantic_core==2.14.6
|
||||
pydub==0.25.1
|
||||
pyfiglet==0.8.post1
|
||||
pyflakes==3.2.0
|
||||
Pygments==2.15.1
|
||||
PyJWT==2.8.0
|
||||
pylev==1.4.0
|
||||
PyNaCl==1.5.0
|
||||
pynvml==11.5.0
|
||||
pyparsing==2.4.7
|
||||
pyrsistent==0.14.11
|
||||
pytest==8.0.2
|
||||
pytest-asyncio==0.23.4
|
||||
python-dateutil==2.8.2
|
||||
python-dotenv==1.0.1
|
||||
python-Levenshtein==0.24.0
|
||||
python-multipart==0.0.9
|
||||
pytz==2023.3
|
||||
PyYAML==6.0.1
|
||||
querystring-parser==1.2.4
|
||||
rapidfuzz==3.6.1
|
||||
regex==2023.6.3
|
||||
requests==2.31.0
|
||||
requests-toolbelt==0.8.0
|
||||
resolvelib==0.8.1
|
||||
responses==0.18.0
|
||||
retry==0.9.2
|
||||
rich==13.7.0
|
||||
rsa==4.7.2
|
||||
ruff==0.6.3
|
||||
s3transfer==0.10.1
|
||||
safetensors==0.4.5
|
||||
sagemaker==2.148.0
|
||||
scalecodec==1.2.7
|
||||
schedulefree==1.2.1
|
||||
schema==0.7.5
|
||||
scikit-learn==1.4.0
|
||||
scipy==1.9.3
|
||||
seaborn==0.13.2
|
||||
semantic-version==2.10.0
|
||||
sentencepiece==0.2.0
|
||||
sentry-sdk==1.19.1
|
||||
setproctitle==1.3.2
|
||||
shellingham==1.5.4
|
||||
shortuuid==1.0.11
|
||||
shtab==1.6.5
|
||||
sigtools==4.0.1
|
||||
six==1.16.0
|
||||
skypilot==0.4.1
|
||||
smdebug-rulesconfig==1.0.1
|
||||
smmap==5.0.0
|
||||
sniffio==1.3.0
|
||||
SQLAlchemy==1.4.47
|
||||
sqlparse==0.4.4
|
||||
starlette==0.36.3
|
||||
substrate-interface==1.5.2
|
||||
svgwrite==1.4.3
|
||||
sympy==1.11.1
|
||||
synchronicity==0.6.7
|
||||
tabulate==0.9.0
|
||||
tblib==1.7.0
|
||||
tenacity==8.2.2
|
||||
tensor-parallel==2.0.0
|
||||
termcolor==2.2.0
|
||||
text2art==0.2.0
|
||||
threadpoolctl==3.2.0
|
||||
tiktoken==0.6.0
|
||||
time-machine==2.14.1
|
||||
timm==0.9.16
|
||||
tokenizers==0.19.1
|
||||
tokenmonster==1.1.12
|
||||
toml==0.9.6
|
||||
tomli==2.0.1
|
||||
tomlkit==0.12.0
|
||||
toolz==0.12.1
|
||||
torch==2.2.0
|
||||
torchdata==0.6.1
|
||||
torchdiffeq==0.2.3
|
||||
TorchFix==0.4.0
|
||||
torchtext==0.15.2
|
||||
torchvision==0.17.0
|
||||
tqdm==4.66.2
|
||||
transformers==4.44.2
|
||||
trl==0.9.6
|
||||
typer==0.12.5
|
||||
types-certifi==2021.10.8.3
|
||||
types-requests==2.31.0.20240125
|
||||
types-setuptools==69.0.0.20240125
|
||||
types-toml==0.10.8.7
|
||||
typing==3.7.4.3
|
||||
typing-inspect==0.8.0
|
||||
typing_extensions==4.9.0
|
||||
tyro==0.5.18
|
||||
tzdata==2023.3
|
||||
unique-names-generator==1.0.2
|
||||
urllib3==2.2.2
|
||||
uvicorn==0.22.0
|
||||
vector_quantize_pytorch==1.14.1
|
||||
virtualenv==20.23.0
|
||||
voyager==2.0.2
|
||||
wandb==0.16.2
|
||||
watchfiles==0.21.0
|
||||
wavedrom==2.0.3.post3
|
||||
wcwidth==0.2.6
|
||||
websocket-client==1.7.0
|
||||
websockets==12.0
|
||||
Werkzeug==3.0.1
|
||||
wonderwords==2.2.0
|
||||
xxhash==3.2.0
|
||||
yarl==1.8.2
|
||||
zetascale==2.2.7
|
||||
zipp==3.15.0
|
||||
@@ -1,60 +0,0 @@
|
||||
"""
|
||||
helper script to parse chat datasets into a usable yaml
|
||||
"""
|
||||
import click
|
||||
import yaml
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("dataset", type=str)
|
||||
@click.option("--split", type=str, default="train")
|
||||
def parse_dataset(dataset=None, split="train"):
|
||||
ds_cfg = {}
|
||||
ds_cfg["path"] = dataset
|
||||
ds_cfg["split"] = split
|
||||
ds_cfg["type"] = "chat_template"
|
||||
ds_cfg["chat_template"] = "<<<Replace based on your model>>>"
|
||||
|
||||
dataset = load_dataset(dataset, split=split)
|
||||
features = dataset.features
|
||||
feature_keys = features.keys()
|
||||
field_messages = None
|
||||
for key in ["conversation", "conversations", "messages"]:
|
||||
if key in feature_keys:
|
||||
field_messages = key
|
||||
break
|
||||
if not field_messages:
|
||||
raise ValueError(
|
||||
f'No conversation field found in dataset: {", ".join(feature_keys)}'
|
||||
)
|
||||
ds_cfg["field_messages"] = field_messages
|
||||
|
||||
message_fields = features["conversations"][0].keys()
|
||||
message_field_role = None
|
||||
for key in ["from", "role"]:
|
||||
if key in message_fields:
|
||||
message_field_role = key
|
||||
break
|
||||
if not message_field_role:
|
||||
raise ValueError(
|
||||
f'No role field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_role"] = message_field_role
|
||||
|
||||
message_field_content = None
|
||||
for key in ["content", "text", "value"]:
|
||||
if key in message_fields:
|
||||
message_field_content = key
|
||||
break
|
||||
if not message_field_content:
|
||||
raise ValueError(
|
||||
f'No content field found in messages: {", ".join(message_fields)}'
|
||||
)
|
||||
ds_cfg["message_field_content"] = message_field_content
|
||||
|
||||
print(yaml.dump({"datasets": [ds_cfg]}))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parse_dataset()
|
||||
39
setup.py
39
setup.py
@@ -30,19 +30,13 @@ def parse_requirements():
|
||||
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
||||
|
||||
if "Darwin" in platform.system():
|
||||
# don't install xformers on MacOS
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
else:
|
||||
# detect the version of torch already installed
|
||||
# and set it so dependencies don't clobber the torch version
|
||||
try:
|
||||
torch_version = version("torch")
|
||||
except PackageNotFoundError:
|
||||
torch_version = "2.5.1"
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
@@ -55,39 +49,20 @@ def parse_requirements():
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
_install_requires.append("xformers==0.0.28.post2")
|
||||
else:
|
||||
_install_requires.append("xformers==0.0.28.post3")
|
||||
_install_requires.pop(_install_requires.index(autoawq_version))
|
||||
elif (major, minor) >= (2, 4):
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
elif (major, minor) >= (2, 3):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
if (major, minor) >= (2, 3):
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.26.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.27")
|
||||
elif (major, minor) >= (2, 2):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.25.post1")
|
||||
else:
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.23.post1")
|
||||
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
@@ -96,7 +71,7 @@ install_requires, dependency_links = parse_requirements()
|
||||
|
||||
setup(
|
||||
name="axolotl",
|
||||
version="0.5.0",
|
||||
version="0.4.1",
|
||||
description="LLM Trainer",
|
||||
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
||||
package_dir={"": "src"},
|
||||
@@ -105,7 +80,10 @@ setup(
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn==2.7.0.post2",
|
||||
"flash-attn==2.6.3",
|
||||
],
|
||||
"fused-dense-lib": [
|
||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.14.4",
|
||||
@@ -113,7 +91,6 @@ setup(
|
||||
],
|
||||
"mamba-ssm": [
|
||||
"mamba-ssm==1.2.0.post1",
|
||||
"causal_conv1d",
|
||||
],
|
||||
"auto-gptq": [
|
||||
"auto-gptq==0.5.1",
|
||||
|
||||
@@ -30,8 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
@@ -55,22 +54,8 @@ LOG = logging.getLogger("axolotl.scripts")
|
||||
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
AXOLOTL_LOGO = """
|
||||
#@@ #@@ @@# @@#
|
||||
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
||||
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
||||
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
||||
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
||||
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
||||
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
||||
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
||||
@@@@ @@@@@@@@@@@@@@@@
|
||||
"""
|
||||
|
||||
|
||||
def print_legacy_axolotl_text_art(suffix=None):
|
||||
def print_axolotl_text_art(suffix=None):
|
||||
font = "nancyj"
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
@@ -83,13 +68,6 @@ def print_legacy_axolotl_text_art(suffix=None):
|
||||
print_dep_versions()
|
||||
|
||||
|
||||
def print_axolotl_text_art(
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
print(AXOLOTL_LOGO)
|
||||
|
||||
|
||||
def print_dep_versions():
|
||||
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
|
||||
max_len = max(len(pkg) for pkg in packages)
|
||||
@@ -190,15 +168,18 @@ def do_inference(
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
chat_template_str = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
@@ -208,31 +189,13 @@ def do_inference(
|
||||
instruction = get_multi_line_input()
|
||||
if not instruction:
|
||||
return
|
||||
|
||||
if prompter_module:
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
|
||||
if chat_template_str:
|
||||
batch = tokenizer.apply_chat_template(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
return_tensors="pt",
|
||||
add_special_tokens=True,
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
)
|
||||
else:
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
print("=" * 40)
|
||||
model.eval()
|
||||
@@ -272,6 +235,13 @@ def do_inference_gradio(
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
default_tokens: Dict[str, str] = {}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
chat_template_str = None
|
||||
@@ -280,7 +250,7 @@ def do_inference_gradio(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
|
||||
chat_template_str = chat_templates(cfg.chat_template)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
@@ -451,8 +421,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
|
||||
setup_mlflow_env_vars(cfg)
|
||||
|
||||
setup_comet_env_vars(cfg)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -470,12 +438,7 @@ def load_datasets(
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
if (
|
||||
cli_args.debug
|
||||
or cfg.debug
|
||||
or cli_args.debug_text_only
|
||||
or int(cli_args.debug_num_examples) > 0
|
||||
):
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
check_dataset_labels(
|
||||
train_dataset.select(
|
||||
|
||||
@@ -23,7 +23,10 @@ from axolotl.cli import (
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
@@ -40,6 +43,23 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
if parsed_cfg.chat_template == "chatml":
|
||||
if parsed_cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||
)
|
||||
register_chatml_template(parsed_cfg.default_system_message)
|
||||
else:
|
||||
register_chatml_template()
|
||||
elif parsed_cfg.chat_template == "llama3":
|
||||
if parsed_cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
||||
)
|
||||
register_llama3_template(parsed_cfg.default_system_message)
|
||||
else:
|
||||
register_llama3_template()
|
||||
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
@@ -50,11 +70,10 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
with disable_datasets_caching():
|
||||
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
|
||||
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
if parsed_cli_args.download:
|
||||
model_name = parsed_cfg.base_model
|
||||
|
||||
@@ -3,11 +3,13 @@ CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Tuple, Union
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
@@ -18,7 +20,10 @@ from axolotl.cli import (
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.prompt_strategies.sharegpt import (
|
||||
register_chatml_template,
|
||||
register_llama3_template,
|
||||
)
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
@@ -34,23 +39,32 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
return do_train(parsed_cfg, parsed_cli_args)
|
||||
|
||||
|
||||
def do_train(cfg, cli_args) -> None:
|
||||
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
if cfg.chat_template == "chatml" and cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"ChatML set. Adding default system message: {cfg.default_system_message}"
|
||||
)
|
||||
register_chatml_template(cfg.default_system_message)
|
||||
else:
|
||||
register_chatml_template()
|
||||
|
||||
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
||||
LOG.info(
|
||||
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
||||
)
|
||||
register_llama3_template(cfg.default_system_message)
|
||||
else:
|
||||
register_llama3_template()
|
||||
|
||||
if cfg.rl: # and cfg.rl != "orpo":
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
|
||||
del model
|
||||
del tokenizer
|
||||
|
||||
plugin_manager.post_train_unload(cfg)
|
||||
return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -23,7 +23,7 @@ class TrainerCliArgs:
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=0)
|
||||
debug_num_examples: int = field(default=5)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
"""
|
||||
ChatML transformation functions for MessageContents
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from ..messages import MessageContents, Messages
|
||||
from .shared import wrap_tools
|
||||
|
||||
|
||||
def format_message(
|
||||
message: Messages,
|
||||
message_index: Optional[int] = None, # pylint: disable=unused-argument
|
||||
) -> Messages:
|
||||
if message.is_chat_formatted:
|
||||
return message
|
||||
|
||||
# prepend the role prefix within a MessageContents to message.content
|
||||
message.content.insert(
|
||||
0,
|
||||
MessageContents(
|
||||
type="text",
|
||||
value=f"<|im_start|>{message.role}\n",
|
||||
weight=0,
|
||||
),
|
||||
)
|
||||
message.content.append(
|
||||
MessageContents(type="text", value="<|im_end|>", weight=message.weight)
|
||||
)
|
||||
message.content.append(MessageContents(type="text", value="\n", weight=0))
|
||||
|
||||
message = wrap_tools(message)
|
||||
|
||||
message.is_chat_formatted = True
|
||||
return message
|
||||
@@ -1,45 +0,0 @@
|
||||
"""
|
||||
Llama 3.x chat formatting functions for MessageContents
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from ..messages import MessageContents, Messages
|
||||
from .shared import wrap_tools
|
||||
|
||||
|
||||
def format_message(message: Messages, message_index: Optional[int] = None) -> Messages:
|
||||
if message.is_chat_formatted:
|
||||
return message
|
||||
|
||||
message_role = message.role
|
||||
if message.role == "tool":
|
||||
message_role = "ipython"
|
||||
|
||||
# prepend the role prefix within a MessageContents to message.content
|
||||
message.content.insert(
|
||||
0,
|
||||
MessageContents(
|
||||
type="text",
|
||||
value=f"<|start_header_id|>{message_role}<|end_header_id|>\n\n",
|
||||
weight=0,
|
||||
),
|
||||
)
|
||||
|
||||
message.content.append(
|
||||
MessageContents(type="text", value="<|eot_id|>", weight=message.weight)
|
||||
)
|
||||
|
||||
message = wrap_tools(message)
|
||||
|
||||
if message_index == 0:
|
||||
message.content.insert(
|
||||
0,
|
||||
MessageContents(
|
||||
type="text",
|
||||
value="<|begin_of_text|>",
|
||||
weight=0,
|
||||
),
|
||||
)
|
||||
|
||||
message.is_chat_formatted = True
|
||||
return message
|
||||
@@ -1,47 +0,0 @@
|
||||
"""
|
||||
shared functions for format transforms
|
||||
"""
|
||||
from axolotl.core.chat.messages import MessageContents, Messages
|
||||
|
||||
|
||||
def wrap_tools(message: Messages):
|
||||
# loop over message.content by index to find tool calls, we need to wrap each with tags,
|
||||
# so be wary of indexing issues when changing the list while iterating.
|
||||
# iterate over the range in reverse order to avoid index shifting
|
||||
for i in range(len(message.content) - 1, -1, -1):
|
||||
if message.content[i].type == "tool_call":
|
||||
# append a </tool_call> MessageContents text tag after
|
||||
message.content.insert(
|
||||
i + 1,
|
||||
MessageContents(
|
||||
type="text", value="</tool_call>\n", weight=message.weight
|
||||
),
|
||||
)
|
||||
# make sure the actual tool call content ends with a newline
|
||||
message.content[i].has_newline = True
|
||||
# prepend a <tool_call> MessageContents text tag before
|
||||
message.content.insert(
|
||||
i,
|
||||
MessageContents(
|
||||
type="text", value="<tool_call>\n", weight=message.weight
|
||||
),
|
||||
)
|
||||
elif message.content[i].type == "tool_response":
|
||||
# append a </tool_call> MessageContents text tag after
|
||||
message.content.insert(
|
||||
i + 1,
|
||||
MessageContents(
|
||||
type="text", value="</tool_response>\n", weight=message.weight
|
||||
),
|
||||
)
|
||||
# make sure the actual tool response content ends with a newline
|
||||
message.content[i].has_newline = True
|
||||
# prepend a <tool_call> MessageContents text tag before
|
||||
message.content.insert(
|
||||
i,
|
||||
MessageContents(
|
||||
type="text", value="<tool_response>\n", weight=message.weight
|
||||
),
|
||||
)
|
||||
|
||||
return message
|
||||
@@ -1,230 +0,0 @@
|
||||
"""
|
||||
internal message representations of chat messages
|
||||
"""
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
class MessageRoles(str, Enum):
|
||||
"""
|
||||
Message roles for the system, user, assistant, and tools
|
||||
"""
|
||||
|
||||
system = "system" # pylint: disable=invalid-name
|
||||
user = "user" # pylint: disable=invalid-name
|
||||
assistant = "assistant" # pylint: disable=invalid-name
|
||||
tool = "tool" # pylint: disable=invalid-name
|
||||
ipython = ( # pylint: disable=invalid-name
|
||||
# for responses from builtin tools
|
||||
"ipython"
|
||||
)
|
||||
|
||||
|
||||
class MessageContentTypes(str, Enum):
|
||||
"""
|
||||
Message content types for text, image, audio, tool calls, and tool responses
|
||||
"""
|
||||
|
||||
special_token = "special_token" # pylint: disable=invalid-name # nosec B105
|
||||
text = "text" # pylint: disable=invalid-name
|
||||
image = "image" # pylint: disable=invalid-name
|
||||
audio = "audio" # pylint: disable=invalid-name
|
||||
tool_call = "tool_call" # pylint: disable=invalid-name # to differentiate regular responses from tool calls from the assistant
|
||||
tool_response = "tool_response" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SpecialToken(str, Enum):
|
||||
"""
|
||||
Special tokens for beginning of string and end of string
|
||||
"""
|
||||
|
||||
bos_token = "bos_token" # pylint: disable=invalid-name # nosec B105
|
||||
eos_token = "eos_token" # pylint: disable=invalid-name # nosec B105
|
||||
|
||||
|
||||
class ToolCallFunction(BaseModel):
|
||||
"""
|
||||
Tool call function with name and arguments
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: dict[str, str]
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
"""
|
||||
Tool with description, function, and parameters
|
||||
"""
|
||||
|
||||
description: str
|
||||
function: ToolCallFunction
|
||||
parameters: dict[str, str] # .properties
|
||||
|
||||
|
||||
class ToolCallContents(BaseModel):
|
||||
"""
|
||||
Tool call contents with name, arguments, and optional id
|
||||
"""
|
||||
|
||||
name: str
|
||||
arguments: dict[str, Union[str, int]]
|
||||
id: Optional[str] = None # pylint: disable=invalid-name
|
||||
|
||||
def __str__(self) -> str:
|
||||
data = {"name": self.name, "arguments": self.arguments}
|
||||
if self.id is not None:
|
||||
data["id"] = self.id
|
||||
return json.dumps(data)
|
||||
|
||||
|
||||
class ToolResponseContents(BaseModel):
|
||||
"""
|
||||
Tool response contents with name, content, and optional id
|
||||
"""
|
||||
|
||||
name: str
|
||||
content: Union[str, dict[str, Union[str, int, float]]]
|
||||
id: Optional[str] = None # pylint: disable=invalid-name
|
||||
|
||||
def __str__(self) -> str:
|
||||
data = {"name": self.name, "content": self.content}
|
||||
if self.id is not None:
|
||||
data["id"] = self.id
|
||||
return json.dumps(data)
|
||||
|
||||
|
||||
class MessageContents(BaseModel):
|
||||
"""
|
||||
Message contents with type, value, metadata, weight, newline, and end of contents
|
||||
"""
|
||||
|
||||
type: Union[str, MessageContentTypes]
|
||||
value: Union[str, ToolCallContents, ToolResponseContents, SpecialToken]
|
||||
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
|
||||
weight: Optional[Union[int, float]] = None
|
||||
has_newline: bool = False
|
||||
eoc: bool = False # end of contents
|
||||
|
||||
def __str__(self) -> str:
|
||||
str_val = str(self.value)
|
||||
if self.has_newline and not str_val.endswith("\n"):
|
||||
str_val += "\n"
|
||||
return str_val
|
||||
|
||||
|
||||
class Messages(BaseModel):
|
||||
"""
|
||||
Messages with role, content, metadata, weight, and chat formatting
|
||||
"""
|
||||
|
||||
role: Union[MessageRoles, str] # allows for arbitrary roles
|
||||
content: List["MessageContents"]
|
||||
meta: Optional[dict[str, Any]] = None # support additional arbitrary metadata
|
||||
weight: Optional[Union[int, float]] = None
|
||||
is_chat_formatted: bool = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "".join(str(c) for c in self.content)
|
||||
|
||||
def tokenized(
|
||||
self, tokenizer: PreTrainedTokenizer, ignore_index=-100
|
||||
) -> dict[str, List[int]]:
|
||||
# iterate over the contents, tokenizing the concatenated string values up to the current MessageContents
|
||||
# returns a dictionary mapping w input_ids, attention_mask, and labels
|
||||
input_ids: List[int] = []
|
||||
labels: List[int] = []
|
||||
pending_input_ids: List[int] = []
|
||||
pending_weight = self.weight
|
||||
running_content = ""
|
||||
for _, msg_content in enumerate(self.content):
|
||||
# TODO also handle non-text content types
|
||||
if msg_content.type in [
|
||||
MessageContentTypes.text.value,
|
||||
MessageContentTypes.tool_call.value,
|
||||
MessageContentTypes.tool_response.value,
|
||||
]:
|
||||
running_content += str(msg_content)
|
||||
tok_results = tokenizer(running_content, add_special_tokens=False)
|
||||
tok_input_ids = tok_results["input_ids"]
|
||||
if pending_input_ids:
|
||||
new_pending_inputs = tok_input_ids[
|
||||
len(input_ids) : len(input_ids) + len(pending_input_ids)
|
||||
]
|
||||
if new_pending_inputs != pending_input_ids:
|
||||
# logging.warning("tokenization mismatch from concatenation.")
|
||||
pending_input_ids = new_pending_inputs
|
||||
input_ids.extend(pending_input_ids)
|
||||
if pending_weight:
|
||||
labels.extend(pending_input_ids)
|
||||
else:
|
||||
labels.extend([ignore_index] * len(pending_input_ids))
|
||||
pending_input_ids = tok_results["input_ids"][len(input_ids) :]
|
||||
pending_weight = self.weight and msg_content.weight not in [0, 0.0]
|
||||
input_ids.extend(pending_input_ids)
|
||||
if pending_weight:
|
||||
labels.extend(pending_input_ids)
|
||||
else:
|
||||
labels.extend([ignore_index] * len(pending_input_ids))
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
class Chats(BaseModel):
|
||||
"""
|
||||
top level data structure for chat conversations
|
||||
"""
|
||||
|
||||
conversation: List[Messages]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "".join(str(c) for c in self.conversation)
|
||||
|
||||
def tokenized(
|
||||
self, tokenizer: Callable[[str], dict[str, List[int]]], ignore_index=-100
|
||||
) -> dict[str, List[int]]:
|
||||
input_ids = []
|
||||
attention_mask = []
|
||||
labels = []
|
||||
for msg in self.conversation:
|
||||
msg_results = msg.tokenized(tokenizer, ignore_index)
|
||||
input_ids.extend(msg_results["input_ids"])
|
||||
attention_mask.extend(msg_results["attention_mask"])
|
||||
labels.extend(msg_results["labels"])
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"labels": labels,
|
||||
}
|
||||
|
||||
|
||||
class ChatFormattedChats(Chats):
|
||||
"""
|
||||
Chat formatted chats with formatter and optional train on inputs
|
||||
"""
|
||||
|
||||
formatter: Callable # [[Union[dict, Chats]], Chats]
|
||||
train_on_inputs: bool = False
|
||||
|
||||
def model_post_init(self, __context):
|
||||
for i, msg in enumerate(self.conversation):
|
||||
self.conversation[i] = self.formatter(msg, message_index=i)
|
||||
if self.train_on_inputs:
|
||||
self.conversation[i].weight = 1
|
||||
|
||||
|
||||
class PreferenceChats(BaseModel):
|
||||
"""
|
||||
representation for preference data for chat
|
||||
"""
|
||||
|
||||
prompt: List[Messages]
|
||||
chosen: Messages
|
||||
rejected: Messages
|
||||
@@ -1,55 +0,0 @@
|
||||
"""
|
||||
chat dataset module
|
||||
"""
|
||||
import os
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from datasets import Dataset
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from axolotl.core.chat.messages import ChatFormattedChats
|
||||
|
||||
|
||||
class TokenizedChatDataset(Dataset):
|
||||
"""
|
||||
Tokenized chat dataset
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Dataset,
|
||||
model_transform: Union[PreTrainedTokenizer, Callable],
|
||||
*args,
|
||||
message_transform: Optional[Callable] = None,
|
||||
formatter=None,
|
||||
process_count: Optional[int] = None,
|
||||
keep_in_memory: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
def map_fn(ex):
|
||||
if message_transform is not None:
|
||||
ex = message_transform(ex)
|
||||
if formatter is not None:
|
||||
ex = ChatFormattedChats(
|
||||
formatter=formatter,
|
||||
**ex,
|
||||
)
|
||||
else:
|
||||
ex = ChatFormattedChats(
|
||||
**ex,
|
||||
)
|
||||
return ex.tokenized(model_transform)
|
||||
|
||||
process_or_cpu_count: int = (
|
||||
process_count or os.cpu_count() # type: ignore[assignment]
|
||||
)
|
||||
num_proc = min(64, process_or_cpu_count)
|
||||
features = data.features.keys()
|
||||
tokenized_data = data.map(
|
||||
map_fn,
|
||||
num_proc=num_proc,
|
||||
keep_in_memory=keep_in_memory,
|
||||
remove_columns=features,
|
||||
desc="Tokenizing Chats",
|
||||
)
|
||||
super().__init__(tokenized_data.data, *args, **kwargs)
|
||||
@@ -1,150 +0,0 @@
|
||||
"""
|
||||
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
|
||||
"""
|
||||
from typing import Any, Mapping, Union
|
||||
|
||||
|
||||
def chat_message_transform_builder( # pylint: disable=dangerous-default-value
|
||||
train_on_inputs=False,
|
||||
conversations_field: str = "conversations",
|
||||
message_field_role: Union[str, list[str]] = ["role", "from"], # commonly "role"
|
||||
message_field_content: Union[str, list[str]] = [
|
||||
"value",
|
||||
"text",
|
||||
"content",
|
||||
], # commonly "content"
|
||||
message_field_training: Union[str, list[str]] = [
|
||||
"train",
|
||||
"weight",
|
||||
], # commonly "weight"
|
||||
):
|
||||
"""Builds a transform that takes a row from the dataset and converts it to a Chat
|
||||
|
||||
Args:
|
||||
train_on_inputs (bool, optional):
|
||||
If True, the transform will train on the inputs. If False, the transform will train on the targets.
|
||||
Defaults to False.
|
||||
conversations_field (str, optional):
|
||||
The field name of the conversations. Defaults to "conversations".
|
||||
message_field_role (str | list[str], optional):
|
||||
The field name of the role. Defaults to "role".
|
||||
message_field_content (str | list[str], optional):
|
||||
The field name of the message content. Defaults to "content".
|
||||
message_field_training (str | list[str], optional):
|
||||
The field name of the train/weight. Defaults to "weight".
|
||||
|
||||
Returns:
|
||||
Callable:
|
||||
A function that takes a list of conversations and returns a list of messages.
|
||||
"""
|
||||
|
||||
message_field_role = (
|
||||
[message_field_role]
|
||||
if isinstance(message_field_role, str)
|
||||
else message_field_role
|
||||
)
|
||||
message_field_content = (
|
||||
[message_field_content]
|
||||
if isinstance(message_field_content, str)
|
||||
else message_field_content
|
||||
)
|
||||
message_weight_fields = (
|
||||
[message_field_training]
|
||||
if isinstance(message_field_training, str)
|
||||
else message_field_training
|
||||
)
|
||||
|
||||
role_value_mappings = {
|
||||
"system": "system",
|
||||
"user": "user",
|
||||
"human": "user",
|
||||
"assistant": "assistant",
|
||||
"gpt": "assistant",
|
||||
"tool": "tool",
|
||||
"ipython": "ipython",
|
||||
}
|
||||
if train_on_inputs:
|
||||
role_default_weights_mappings = {
|
||||
"system": 1,
|
||||
"user": 1,
|
||||
"assistant": 1,
|
||||
"tool": 1,
|
||||
"ipython": 1,
|
||||
}
|
||||
else:
|
||||
role_default_weights_mappings = {
|
||||
"system": 0,
|
||||
"user": 0,
|
||||
"assistant": 1,
|
||||
"tool": 0,
|
||||
"ipython": 0,
|
||||
}
|
||||
|
||||
def transform_builder(sample: Mapping[str, Any]):
|
||||
if conversations_field not in sample:
|
||||
raise ValueError(f"Field '{conversations_field}' not found in sample.")
|
||||
# if none of the role fields are in the message, raise an error
|
||||
if not any(
|
||||
role in sample[conversations_field][0] for role in message_field_role
|
||||
):
|
||||
raise ValueError("No role field found in message.")
|
||||
role_field = next(
|
||||
role
|
||||
for role in message_field_role
|
||||
if role in sample[conversations_field][0]
|
||||
)
|
||||
if not any(
|
||||
field in sample[conversations_field][0] for field in message_field_content
|
||||
):
|
||||
raise ValueError("No message_content field found in message.")
|
||||
message_content_field = next(
|
||||
field
|
||||
for field in message_field_content
|
||||
if field in sample[conversations_field][0]
|
||||
)
|
||||
if not any(
|
||||
field in sample[conversations_field][0] for field in message_field_training
|
||||
):
|
||||
message_weight_field = None
|
||||
else:
|
||||
message_weight_field = next(
|
||||
field
|
||||
for field in message_weight_fields
|
||||
if field in sample[conversations_field][0]
|
||||
)
|
||||
|
||||
messages = []
|
||||
for message in sample[conversations_field]:
|
||||
role = role_value_mappings[message[role_field]]
|
||||
weight = (
|
||||
int(message[message_weight_field])
|
||||
if message_weight_field
|
||||
else role_default_weights_mappings[role]
|
||||
)
|
||||
|
||||
# TODO if "tool_calls" in message[message_content_field]: then convert tool call to ToolCallContents
|
||||
if isinstance(message[message_content_field], str):
|
||||
messages.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"value": message[message_content_field],
|
||||
}
|
||||
],
|
||||
"weight": weight,
|
||||
}
|
||||
)
|
||||
else:
|
||||
messages.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": message[message_content_field],
|
||||
"weight": weight,
|
||||
}
|
||||
)
|
||||
|
||||
return {"conversation": messages}
|
||||
|
||||
return transform_builder
|
||||
@@ -7,7 +7,6 @@ import abc
|
||||
import gc
|
||||
import importlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -28,6 +27,7 @@ from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
EarlyStoppingCallback,
|
||||
PreTrainedModel,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
@@ -43,15 +43,12 @@ from trl import (
|
||||
KTOTrainer,
|
||||
ORPOConfig,
|
||||
ORPOTrainer,
|
||||
RewardConfig,
|
||||
RewardTrainer,
|
||||
)
|
||||
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
|
||||
from trl.trainer.utils import pad_to_length
|
||||
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils import is_mlflow_available
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
@@ -64,7 +61,7 @@ from axolotl.utils.callbacks import (
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
from axolotl.utils.collators import (
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
@@ -304,13 +301,6 @@ class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
|
||||
"""
|
||||
Reward config for Reward training
|
||||
"""
|
||||
|
||||
|
||||
class SchedulerMixin(Trainer):
|
||||
"""
|
||||
Mixin class for scheduler setup in CausalTrainer.
|
||||
@@ -408,10 +398,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
*_args,
|
||||
num_epochs=1,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_epochs = num_epochs
|
||||
self.bench_data_collator = bench_data_collator
|
||||
self.eval_data_collator = eval_data_collator
|
||||
super().__init__(*_args, **kwargs)
|
||||
@@ -436,13 +428,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
if (
|
||||
self.args.loraplus_lr_ratio is None
|
||||
and self.args.alternate_optimizer
|
||||
not in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]
|
||||
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
||||
):
|
||||
return super().create_optimizer()
|
||||
|
||||
@@ -511,14 +497,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
|
||||
)
|
||||
elif self.args.alternate_optimizer == "adopt_adamw":
|
||||
from axolotl.utils.optimizers.adopt import ADOPT
|
||||
|
||||
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||
ADOPT(
|
||||
optimizer_grouped_parameters, decoupled=True, **optimizer_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
|
||||
@@ -681,9 +659,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
@@ -691,18 +667,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
if self.args.orpo_alpha:
|
||||
return self.orpo_compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
return super().compute_loss(
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=return_outputs,
|
||||
num_items_in_batch=num_items_in_batch,
|
||||
)
|
||||
return self.orpo_compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
@staticmethod
|
||||
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
|
||||
@@ -798,13 +764,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
).squeeze(2)
|
||||
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
|
||||
|
||||
def orpo_compute_loss(
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
def orpo_compute_loss(self, model, inputs, return_outputs=False):
|
||||
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
|
||||
inputs,
|
||||
label_pad_token=-100,
|
||||
@@ -910,13 +870,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def _save_checkpoint(self, model, trial, **kwargs):
|
||||
def _save_checkpoint(self, model, trial, metrics=None):
|
||||
# make sure the checkpoint dir exists, since trainer is flakey
|
||||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
|
||||
run_dir = self._get_output_dir(trial=trial)
|
||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
return super()._save_checkpoint(model, trial, **kwargs)
|
||||
return super()._save_checkpoint(model, trial, metrics=metrics)
|
||||
|
||||
|
||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
@@ -931,7 +891,6 @@ class AxolotlMambaTrainer(AxolotlTrainer):
|
||||
model,
|
||||
inputs,
|
||||
return_outputs=False, # pylint: disable=unused-argument
|
||||
num_items_in_batch=None, # pylint: disable=unused-argument
|
||||
):
|
||||
input_ids = inputs.pop("input_ids")
|
||||
lm_logits = model(input_ids).logits
|
||||
@@ -1039,32 +998,18 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
def tokenize_row(
|
||||
self,
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
self, feature, model: Optional[Union[PreTrainedModel, torch.nn.Module]] = None
|
||||
) -> Dict:
|
||||
res = super().tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||
res = super().tokenize_row(feature, model=model)
|
||||
if self.tokenizer.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
num_items_in_batch=None,
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
|
||||
) -> torch.Tensor:
|
||||
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
|
||||
loss: torch.Tensor = super().training_step(model, inputs)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
@@ -1094,14 +1039,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
Extend the base RewardTrainer for axolotl helpers
|
||||
"""
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
@@ -1162,49 +1099,26 @@ class TrainerBuilderBase(abc.ABC):
|
||||
|
||||
def get_callbacks(self) -> List[TrainerCallback]:
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
|
||||
)
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
if self.cfg.use_mlflow and is_mlflow_available():
|
||||
from transformers.integrations.integration_utils import MLflowCallback
|
||||
|
||||
from axolotl.utils.callbacks.mlflow_ import (
|
||||
SaveAxolotlConfigtoMlflowCallback,
|
||||
)
|
||||
|
||||
callbacks.extend(
|
||||
[
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path),
|
||||
MLflowCallback,
|
||||
]
|
||||
)
|
||||
if self.cfg.use_comet and is_comet_available():
|
||||
from axolotl.utils.callbacks.comet_ import SaveAxolotlConfigtoCometCallback
|
||||
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
|
||||
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
@abstractmethod
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
"""
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
# TODO
|
||||
@@ -1250,7 +1164,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
callbacks = []
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "wandb"
|
||||
@@ -1265,11 +1179,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
trainer, self.tokenizer, "mlflow"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
if self.cfg.use_comet and is_comet_available() and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "comet_ml"
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||
@@ -1287,18 +1196,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
callbacks.append(lisa_callback_factory(trainer))
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
[
|
||||
cb
|
||||
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||
self.cfg, trainer
|
||||
)
|
||||
if cb
|
||||
]
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
@@ -1306,8 +1203,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
if self.cfg.reward_model:
|
||||
return AxolotlRewardTrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
@@ -1535,16 +1430,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
report_to.append("mlflow")
|
||||
if self.cfg.use_tensorboard:
|
||||
report_to.append("tensorboard")
|
||||
if self.cfg.use_comet:
|
||||
report_to.append("comet_ml")
|
||||
|
||||
training_arguments_kwargs["report_to"] = report_to
|
||||
if self.cfg.use_wandb:
|
||||
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
|
||||
elif self.cfg.use_mlflow:
|
||||
training_arguments_kwargs["run_name"] = self.cfg.mlflow_run_name
|
||||
else:
|
||||
training_arguments_kwargs["run_name"] = None
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||
)
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
@@ -1633,9 +1523,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
|
||||
if self.cfg.chat_template:
|
||||
training_arguments_kwargs["chat_template"] = get_chat_template(
|
||||
self.cfg.chat_template,
|
||||
tokenizer=self.tokenizer,
|
||||
training_arguments_kwargs["chat_template"] = chat_templates(
|
||||
self.cfg.chat_template
|
||||
)
|
||||
|
||||
if self.cfg.rl == "orpo":
|
||||
@@ -1648,16 +1537,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.reward_model:
|
||||
trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
if self.cfg.optimizer in [
|
||||
"optimi_adamw",
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
]:
|
||||
# Set default so transformers doesn't throw
|
||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||
@@ -1696,13 +1580,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"accelerator_config"
|
||||
] = self.cfg.accelerator_config
|
||||
|
||||
training_args_cls = (
|
||||
AxolotlTrainingArguments
|
||||
if not self.cfg.reward_model
|
||||
else AxolotlRewardConfig
|
||||
)
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
|
||||
@@ -1724,37 +1605,27 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if self.cfg.reward_model:
|
||||
data_collator_kwargs["max_length"] = self.cfg.sequence_len
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
if eval_data_collator := self.build_collator(
|
||||
training_args, is_eval=True, **data_collator_kwargs
|
||||
):
|
||||
if not self.cfg.reward_model:
|
||||
trainer_kwargs["eval_data_collator"] = eval_data_collator
|
||||
if not self.cfg.reward_model:
|
||||
trainer_kwargs["bench_data_collator"] = transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
)
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters.keys():
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
else:
|
||||
trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
data_collator=self.build_collator(training_args, **data_collator_kwargs),
|
||||
eval_data_collator=self.build_collator(
|
||||
training_args, is_eval=True, **data_collator_kwargs
|
||||
),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=self.get_callbacks(),
|
||||
num_epochs=self.cfg.num_epochs,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
@@ -1788,14 +1659,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
BatchSamplerDataCollatorForSeq2Seq,
|
||||
DataCollatorForSeq2Seq,
|
||||
RewardDataCollatorWithPadding,
|
||||
]
|
||||
]
|
||||
if self.cfg.reward_model:
|
||||
collator = RewardDataCollatorWithPadding
|
||||
if "max_length" in kwargs:
|
||||
kwargs.pop("max_length")
|
||||
elif use_batch_sampler_collator:
|
||||
if use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
@@ -1832,7 +1698,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
callbacks = []
|
||||
return callbacks
|
||||
|
||||
def build_training_arguments(self, total_num_steps):
|
||||
@@ -1918,18 +1784,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
|
||||
if self.cfg.rl_beta:
|
||||
training_args_kwargs["beta"] = self.cfg.rl_beta
|
||||
if self.cfg.orpo_alpha:
|
||||
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
|
||||
training_args_kwargs["beta"] = self.cfg.orpo_alpha
|
||||
|
||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rpo_alpha is not None:
|
||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||
|
||||
training_args_cls = None
|
||||
if self.cfg.rl == "simpo":
|
||||
training_args_cls = AxolotlCPOConfig
|
||||
training_args_kwargs["loss_type"] = "simpo"
|
||||
@@ -1938,13 +1803,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.cpo_alpha is not None:
|
||||
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
|
||||
|
||||
elif self.cfg.rl == "orpo":
|
||||
if self.cfg.rl == "orpo":
|
||||
training_args_cls = AxolotlORPOConfig
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
elif self.cfg.rl == "kto":
|
||||
if self.cfg.rl == "kto":
|
||||
training_args_cls = AxolotlKTOConfig
|
||||
|
||||
training_args_kwargs["desirable_weight"] = (
|
||||
@@ -1959,17 +1824,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.max_prompt_len:
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
||||
|
||||
else:
|
||||
training_args_cls = AxolotlDPOConfig
|
||||
if self.cfg.rl == "ipo":
|
||||
training_args_kwargs["loss_type"] = "ipo"
|
||||
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["max_completion_length"] = None
|
||||
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
|
||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||
output_dir=self.cfg.output_dir,
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
@@ -1990,6 +1844,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args = self.build_training_arguments(total_num_steps)
|
||||
dpo_trainer_kwargs = {}
|
||||
if self.cfg.rl == "ipo":
|
||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||
if self.cfg.dpo_label_smoothing:
|
||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||
if self.eval_dataset:
|
||||
@@ -2003,6 +1858,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.rl in ["dpo", "ipo"]:
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
trainer_cls_args = [self.model, self.model_ref]
|
||||
|
||||
# these aren't used for the ORPO trainer
|
||||
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["max_target_length"] = None
|
||||
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
|
||||
dpo_trainer_kwargs["generate_during_eval"] = True
|
||||
elif self.cfg.rl == "orpo":
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
trainer_cls_args = [self.model]
|
||||
@@ -2014,17 +1875,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
trainer_cls_args = [self.model]
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "processing_class" in sig.parameters.keys():
|
||||
dpo_trainer_kwargs["processing_class"] = self.tokenizer
|
||||
else:
|
||||
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
|
||||
|
||||
dpo_trainer = trainer_cls(
|
||||
*trainer_cls_args,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
tokenizer=self.tokenizer,
|
||||
callbacks=self.get_callbacks(),
|
||||
**dpo_trainer_kwargs,
|
||||
)
|
||||
@@ -2046,11 +1901,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = super().get_callbacks()
|
||||
callbacks = []
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
callbacks = []
|
||||
return callbacks
|
||||
|
||||
def build(self, total_num_steps):
|
||||
|
||||
@@ -18,10 +18,9 @@ Plugins can be used to integrate third-party models, modify the training process
|
||||
|
||||
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
|
||||
"""
|
||||
import collections
|
||||
import importlib
|
||||
import logging
|
||||
from typing import OrderedDict
|
||||
from typing import List
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
@@ -48,7 +47,7 @@ class BasePlugin:
|
||||
Initializes the BasePlugin.
|
||||
"""
|
||||
|
||||
def register(self, cfg): # pylint: disable=unused-argument
|
||||
def register(self, cfg):
|
||||
"""
|
||||
Registers the plugin with the given configuration.
|
||||
|
||||
@@ -64,7 +63,7 @@ class BasePlugin:
|
||||
Returns a pydantic model for the plugin's input arguments.
|
||||
"""
|
||||
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
def pre_model_load(self, cfg):
|
||||
"""
|
||||
Performs actions before the model is loaded.
|
||||
|
||||
@@ -75,7 +74,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_model_load(self, cfg, model):
|
||||
"""
|
||||
Performs actions after the model is loaded.
|
||||
|
||||
@@ -87,7 +86,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
def pre_lora_load(self, cfg, model):
|
||||
"""
|
||||
Performs actions before LoRA weights are loaded.
|
||||
|
||||
@@ -99,7 +98,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_lora_load(self, cfg, model):
|
||||
"""
|
||||
Performs actions after LoRA weights are loaded.
|
||||
|
||||
@@ -111,7 +110,7 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
"""
|
||||
Creates and returns an optimizer for training.
|
||||
|
||||
@@ -123,9 +122,7 @@ class BasePlugin:
|
||||
object: The created optimizer.
|
||||
"""
|
||||
|
||||
def create_lr_scheduler(
|
||||
self, cfg, trainer, optimizer
|
||||
): # pylint: disable=unused-argument
|
||||
def create_lr_scheduler(self, cfg, trainer, optimizer):
|
||||
"""
|
||||
Creates and returns a learning rate scheduler.
|
||||
|
||||
@@ -138,9 +135,9 @@ class BasePlugin:
|
||||
object: The created learning rate scheduler.
|
||||
"""
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||
def add_callbacks_pre_trainer(self, cfg, model):
|
||||
"""
|
||||
setup callbacks before creating the trainer.
|
||||
Adds callbacks to the trainer before training.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
@@ -149,45 +146,17 @@ class BasePlugin:
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
return []
|
||||
|
||||
def add_callbacks_post_trainer(
|
||||
self, cfg, trainer
|
||||
): # pylint: disable=unused-argument
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
"""
|
||||
Adds callbacks to the trainer after creating the trainer.
|
||||
This is useful for callbacks that require access to the model or trainer.
|
||||
Adds callbacks to the trainer after training.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added
|
||||
"""
|
||||
return []
|
||||
|
||||
def post_train(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The axolotl configuration
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete and the model is unloaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
|
||||
Returns:
|
||||
None
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
|
||||
|
||||
@@ -235,7 +204,7 @@ class PluginManager:
|
||||
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
|
||||
"""
|
||||
|
||||
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
|
||||
plugins: List[BasePlugin] = []
|
||||
|
||||
_instance = None
|
||||
|
||||
@@ -245,7 +214,7 @@ class PluginManager:
|
||||
"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super(PluginManager, cls).__new__(cls)
|
||||
cls._instance.plugins = collections.OrderedDict()
|
||||
cls._instance.plugins: List[BasePlugin] = []
|
||||
return cls._instance
|
||||
|
||||
@staticmethod
|
||||
@@ -273,7 +242,7 @@ class PluginManager:
|
||||
"""
|
||||
try:
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins[plugin_name] = plugin
|
||||
self.plugins.append(plugin)
|
||||
except ImportError:
|
||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||
|
||||
@@ -285,7 +254,7 @@ class PluginManager:
|
||||
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
|
||||
"""
|
||||
input_args = []
|
||||
for plugin in self.plugins.values():
|
||||
for plugin in self.plugins:
|
||||
input_args_from_plugin = plugin.get_input_args()
|
||||
if input_args_from_plugin is not None:
|
||||
input_args.append(input_args_from_plugin)
|
||||
@@ -301,7 +270,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
for plugin in self.plugins:
|
||||
plugin.pre_model_load(cfg)
|
||||
|
||||
def post_model_load(self, cfg, model):
|
||||
@@ -315,7 +284,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
for plugin in self.plugins:
|
||||
plugin.post_model_load(cfg, model)
|
||||
|
||||
def pre_lora_load(self, cfg, model):
|
||||
@@ -329,7 +298,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
for plugin in self.plugins:
|
||||
plugin.pre_lora_load(cfg, model)
|
||||
|
||||
def post_lora_load(self, cfg, model):
|
||||
@@ -343,7 +312,7 @@ class PluginManager:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
for plugin in self.plugins:
|
||||
plugin.post_lora_load(cfg, model)
|
||||
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
@@ -357,7 +326,7 @@ class PluginManager:
|
||||
Returns:
|
||||
object: The created optimizer, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
for plugin in self.plugins:
|
||||
optimizer = plugin.create_optimizer(cfg, trainer)
|
||||
if optimizer is not None:
|
||||
return optimizer
|
||||
@@ -375,7 +344,7 @@ class PluginManager:
|
||||
Returns:
|
||||
object: The created learning rate scheduler, or None if none was found.
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
for plugin in self.plugins:
|
||||
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
|
||||
if scheduler is not None:
|
||||
return scheduler
|
||||
@@ -393,10 +362,8 @@ class PluginManager:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins.values():
|
||||
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
|
||||
if plugin_callbacks: # if the plugin returned a list of callbacks
|
||||
callbacks.extend(plugin_callbacks)
|
||||
for plugin in self.plugins:
|
||||
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
|
||||
return callbacks
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
@@ -411,22 +378,6 @@ class PluginManager:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs.
|
||||
"""
|
||||
callbacks = []
|
||||
for plugin in self.plugins.values():
|
||||
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
|
||||
if plugin_callbacks:
|
||||
callbacks.extend(plugin_callbacks)
|
||||
for plugin in self.plugins:
|
||||
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
|
||||
return callbacks
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
"""
|
||||
Calls the post_train_unload method of all registered plugins.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_train_unload(cfg)
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -1,13 +0,0 @@
|
||||
# Grokfast Optimizer
|
||||
|
||||
See https://github.com/ironjr/grokfast
|
||||
|
||||
### Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.grokfast.GrokfastPlugin
|
||||
|
||||
grokfast_alpha: 2.0
|
||||
grokfast_lamb: 0.98
|
||||
```
|
||||
@@ -1,50 +0,0 @@
|
||||
"""
|
||||
Grokfast plugin for Axolotl
|
||||
"""
|
||||
import logging
|
||||
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
from ..base import BasePlugin
|
||||
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
|
||||
from .optimizer import gradfilter_ema
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.grokfast")
|
||||
|
||||
|
||||
class GrokfastCallbackHandler(TrainerCallback):
|
||||
"""
|
||||
Transformer trainer callbacks for Grokfast
|
||||
"""
|
||||
|
||||
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
|
||||
super().__init__(*args_, **kwargs)
|
||||
self.grads = None
|
||||
self.alpha = alpha
|
||||
self.lamb = lamb
|
||||
|
||||
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
|
||||
self.grads = None
|
||||
|
||||
def on_pre_optimizer_step(
|
||||
self, args_, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
model = kwargs.pop("model")
|
||||
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
|
||||
return control
|
||||
|
||||
|
||||
class GrokfastPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for Grokfast optimizer integraton with Axolotl.
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.grokfast.GrokfastArgs"
|
||||
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
LOG.info("Adding Grokfast callback to the trainer")
|
||||
callback = GrokfastCallbackHandler(
|
||||
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
|
||||
)
|
||||
return [callback]
|
||||
@@ -1,15 +0,0 @@
|
||||
"""
|
||||
config args for grokfast plugin
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GrokfastArgs(BaseModel):
|
||||
"""
|
||||
Input args for Grokfast optimizer.
|
||||
"""
|
||||
|
||||
grokfast_alpha: Optional[float] = 0.98
|
||||
grokfast_lamb: Optional[float] = 2.0
|
||||
@@ -1,63 +0,0 @@
|
||||
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
|
||||
# Reference: https://github.com/ironjr/grokfast
|
||||
|
||||
# pylint: skip-file
|
||||
from collections import deque
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def gradfilter_ma(
|
||||
m: nn.Module,
|
||||
grads: Optional[Dict[str, deque]] = None,
|
||||
window_size: int = 100,
|
||||
lamb: float = 5.0,
|
||||
filter_type: Literal["mean", "sum"] = "mean",
|
||||
warmup: bool = True,
|
||||
trigger: bool = False, # For ablation study.
|
||||
) -> Dict[str, deque]:
|
||||
if grads is None:
|
||||
grads = {
|
||||
n: deque(maxlen=window_size)
|
||||
for n, p in m.named_parameters()
|
||||
if p.requires_grad and p.grad is not None
|
||||
}
|
||||
|
||||
for n, p in m.named_parameters():
|
||||
if p.requires_grad and p.grad is not None:
|
||||
grads[n].append(p.grad.data.detach()) # .cpu())
|
||||
|
||||
# Modify the gradients.
|
||||
if not warmup or len(grads[n]) == window_size and not trigger:
|
||||
if filter_type == "mean":
|
||||
avg = sum(grads[n]) / len(grads[n])
|
||||
elif filter_type == "sum":
|
||||
avg = sum(grads[n])
|
||||
else:
|
||||
raise ValueError(f"Unrecognized filter_type {filter_type}")
|
||||
p.grad.data = p.grad.data + avg * lamb
|
||||
|
||||
return grads
|
||||
|
||||
|
||||
def gradfilter_ema(
|
||||
m: nn.Module,
|
||||
grads: Optional[Dict[str, torch.Tensor]] = None,
|
||||
alpha: float = 0.98,
|
||||
lamb: float = 2.0,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if grads is None:
|
||||
grads = {
|
||||
n: p.grad.data.detach()
|
||||
for n, p in m.named_parameters()
|
||||
if p.requires_grad and p.grad is not None
|
||||
}
|
||||
|
||||
for n, p in m.named_parameters():
|
||||
if p.requires_grad and p.grad is not None:
|
||||
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
|
||||
p.grad.data = p.grad.data + grads[n] * lamb
|
||||
|
||||
return grads
|
||||
@@ -18,24 +18,20 @@ Module for the Plugin for LIGER integraton with Axolotl.
|
||||
Liger Kernel is the collection of Triton-native kernels for LLM Training.
|
||||
It is designed to be performant, correct, and light-weight.
|
||||
"""
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.geglu import LigerGEGLUMLP
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.liger")
|
||||
|
||||
|
||||
class LigerPlugin(BasePlugin):
|
||||
"""
|
||||
@@ -46,31 +42,59 @@ class LigerPlugin(BasePlugin):
|
||||
return "axolotl.integrations.liger.LigerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||
kwargs = {}
|
||||
if "rope" in liger_fn_sig.parameters:
|
||||
kwargs["rope"] = cfg.liger_rope
|
||||
if "cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
||||
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs[
|
||||
"fused_linear_cross_entropy"
|
||||
] = cfg.liger_fused_linear_cross_entropy
|
||||
if "rms_norm" in liger_fn_sig.parameters:
|
||||
kwargs["rms_norm"] = cfg.liger_rms_norm
|
||||
if "layer_norm" in liger_fn_sig.parameters:
|
||||
kwargs["layer_norm"] = cfg.liger_layer_norm
|
||||
if "geglu" in liger_fn_sig.parameters:
|
||||
kwargs["geglu"] = cfg.liger_glu_activation
|
||||
elif "swiglu" in liger_fn_sig.parameters:
|
||||
kwargs["swiglu"] = cfg.liger_glu_activation
|
||||
with zero_only():
|
||||
LOG.info(
|
||||
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
|
||||
if cfg.model_config_type == "llama":
|
||||
from liger_kernel.transformers.model.llama import (
|
||||
lce_forward as llama_lce_forward,
|
||||
)
|
||||
from transformers.models.llama import modeling_llama
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_llama.LlamaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_llama.LlamaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
elif cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "mistral":
|
||||
from liger_kernel.transformers.model.mistral import (
|
||||
lce_forward as mistral_lce_forward,
|
||||
)
|
||||
from transformers.models.mistral import modeling_mistral
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mistral.MistralRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_mistral.MistralMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "gemma":
|
||||
from liger_kernel.transformers.model.gemma import (
|
||||
lce_forward as gemma_lce_forward,
|
||||
)
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_gemma.GemmaRMSNorm = partial(
|
||||
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||
)
|
||||
apply_liger_fn(**kwargs)
|
||||
if cfg.liger_swiglu:
|
||||
modeling_gemma.GemmaMLP = LigerGEGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "jamba":
|
||||
from transformers.models.jamba import modeling_jamba
|
||||
|
||||
@@ -80,14 +104,30 @@ class LigerPlugin(BasePlugin):
|
||||
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_jamba.JambaRMSNorm = LigerRMSNorm
|
||||
if cfg.liger_glu_activation:
|
||||
if cfg.liger_swiglu:
|
||||
modeling_jamba.JambaMLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
from transformers.loss.loss_utils import nn
|
||||
|
||||
nn.functional.cross_entropy = liger_cross_entropy
|
||||
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "qwen2":
|
||||
from liger_kernel.transformers.model.qwen2 import (
|
||||
lce_forward as qwen2_lce_forward,
|
||||
)
|
||||
from transformers.models.qwen2 import modeling_qwen2
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "deepseek_v2":
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoModelForCausalLM
|
||||
@@ -106,11 +146,44 @@ class LigerPlugin(BasePlugin):
|
||||
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_glu_activation:
|
||||
if cfg.liger_swiglu:
|
||||
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
|
||||
if cfg.liger_cross_entropy:
|
||||
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
|
||||
# nn.CrossEntropyLoss in the forward method.
|
||||
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
|
||||
|
||||
elif cfg.model_config_type == "gemma2":
|
||||
from transformers.models.gemma2 import modeling_gemma2
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_gemma2.Gemma2RMSNorm = partial(
|
||||
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
|
||||
)
|
||||
if cfg.liger_swiglu:
|
||||
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
logging.warning(
|
||||
"Fused linear cross entropy is not supported for Gemma 2."
|
||||
)
|
||||
|
||||
elif cfg.model_config_type == "phi3":
|
||||
from liger_kernel.transformers.model.phi3 import (
|
||||
lce_forward as phi3_lce_forward,
|
||||
)
|
||||
from transformers.models.phi3 import modeling_phi3
|
||||
|
||||
if cfg.liger_rope:
|
||||
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
|
||||
if cfg.liger_rms_norm:
|
||||
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
|
||||
if cfg.liger_swiglu:
|
||||
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
|
||||
if cfg.liger_cross_entropy:
|
||||
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
|
||||
if cfg.liger_fused_linear_cross_entropy:
|
||||
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
|
||||
|
||||
@@ -15,12 +15,9 @@
|
||||
"""
|
||||
Module for handling LIGER input arguments.
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
LOG = logging.getLogger("axolotl.integrations.liger.args")
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LigerArgs(BaseModel):
|
||||
@@ -30,24 +27,6 @@ class LigerArgs(BaseModel):
|
||||
|
||||
liger_rope: Optional[bool] = None
|
||||
liger_rms_norm: Optional[bool] = None
|
||||
liger_layer_norm: Optional[bool] = None
|
||||
liger_swiglu: Optional[bool] = None
|
||||
liger_glu_activation: Optional[bool] = None
|
||||
liger_cross_entropy: Optional[bool] = None
|
||||
liger_fused_linear_cross_entropy: Optional[bool] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_deprecated_swiglu(cls, data):
|
||||
if data.get("liger_swiglu") is not None:
|
||||
if data.get("liger_glu_activation") is not None:
|
||||
raise ValueError(
|
||||
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
|
||||
)
|
||||
|
||||
LOG.warning(
|
||||
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
|
||||
"Please use 'liger_glu_activation' instead."
|
||||
)
|
||||
data["liger_glu_activation"] = data.pop("liger_swiglu")
|
||||
return data
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
# LM Eval Harness
|
||||
|
||||
### Usage
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.lm_eval.LMEvalPlugin
|
||||
|
||||
lm_eval_tasks:
|
||||
- gsm8k
|
||||
- hellaswag
|
||||
- arc_easy
|
||||
```
|
||||
@@ -1,42 +0,0 @@
|
||||
"""
|
||||
Module for the Plugin for LM Eval Harness
|
||||
"""
|
||||
import subprocess # nosec
|
||||
from datetime import datetime
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
|
||||
class LMEvalPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for LM Evaluation Harness integraton with Axolotl.
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
tasks = ",".join(cfg.lm_eval_tasks)
|
||||
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
|
||||
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
|
||||
output_path = cfg.output_dir
|
||||
output_path += "" if cfg.output_dir.endswith("/") else "/"
|
||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
subprocess.run( # nosec
|
||||
[
|
||||
"lm_eval",
|
||||
"--model",
|
||||
"hf",
|
||||
"--model_args",
|
||||
f"pretrained={cfg.output_dir}{fa2}{dtype}",
|
||||
"--tasks",
|
||||
tasks,
|
||||
"--batch_size",
|
||||
str(cfg.lm_eval_batch_size),
|
||||
"--output_path",
|
||||
output_path,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
@@ -1,15 +0,0 @@
|
||||
"""
|
||||
Module for handling lm eval harness input arguments.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LMEvalArgs(BaseModel):
|
||||
"""
|
||||
Input args for lm eval harness
|
||||
"""
|
||||
|
||||
lm_eval_tasks: List[str] = []
|
||||
lm_eval_batch_size: Optional[int] = 8
|
||||
231
src/axolotl/monkeypatch/fastchat_conversation_turns.py
Normal file
231
src/axolotl/monkeypatch/fastchat_conversation_turns.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""
|
||||
monkeypatch to add a get_turns method
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from fastchat.conversation import SeparatorStyle
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
|
||||
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
ret = ""
|
||||
for role, msg in self.get_turns():
|
||||
ret += role + msg
|
||||
return ret
|
||||
|
||||
|
||||
def get_turns( # pylint: disable=too-many-return-statements
|
||||
self,
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
"""Get the prompt for generation."""
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message + seps[i % 2]
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ": ", "" # must be end with a space
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role, message + self.sep
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role, message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.RWKV:
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message.replace("\r\n", "\n").replace(
|
||||
"\n\n", "\n"
|
||||
) + "\n\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
|
||||
if self.system_message:
|
||||
if self.messages:
|
||||
# For llama, the system message is incorporated into the first human instruction
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
if (i % 2 == 0 and not self.system_message) or (
|
||||
i % 2 != 0 and self.system_message
|
||||
):
|
||||
role = "<s> " + role
|
||||
yield role + " ", message
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
|
||||
contains_sys_msg = False
|
||||
if self.system_message:
|
||||
contains_sys_msg = True
|
||||
if self.messages:
|
||||
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
|
||||
first_role, first_msg = self.messages[0]
|
||||
if first_role == self.roles[0]:
|
||||
system_prompt = self.system_template.format(
|
||||
system_message=" " + self.system_message
|
||||
)
|
||||
system_prompt += first_msg
|
||||
self.messages.pop(0)
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message and i == 0 and not contains_sys_msg:
|
||||
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
|
||||
elif message:
|
||||
yield role + " ", message
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA3:
|
||||
if self.system_message:
|
||||
# For llama3, the system message is NOT incorporated into the first human instruction
|
||||
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
|
||||
else:
|
||||
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.GEMMA:
|
||||
if self.system_message:
|
||||
raise ValueError("Gemma chat template does not support system messages")
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<bos>" if i == 0 else ""
|
||||
message_str = message if message else ""
|
||||
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
round_add_n = 1 if self.name == "chatglm2" else 0
|
||||
if system_prompt:
|
||||
yield "", system_prompt + self.sep
|
||||
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||
|
||||
if message:
|
||||
yield f"{role}:", f"{message}{self.sep}"
|
||||
else:
|
||||
yield f"{role}:", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATML:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep + "\n"
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM3:
|
||||
if self.system_message:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", " " + message
|
||||
else:
|
||||
yield role
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATINTERN:
|
||||
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<s>" if i % 2 == 0 else ""
|
||||
if message:
|
||||
yield prefix + role + ":", message + seps[i % 2] + "\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.DOLLY:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
suffix = "\n\n" if i % 2 == 1 else ""
|
||||
yield role + ":\n", message + seps[i % 2] + suffix
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.PHOENIX:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", "<s>" + message + "</s>"
|
||||
else:
|
||||
yield role + ": " + "<s>", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ROBIN:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ":\n", message + self.sep
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.FALCON_CHAT:
|
||||
if self.system_message:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
|
||||
def add_get_turns_to_conversation():
|
||||
import fastchat.conversation
|
||||
|
||||
fastchat.conversation.Conversation.get_turns = get_turns
|
||||
fastchat.conversation.Conversation.get_prompt = get_prompt
|
||||
@@ -22,6 +22,7 @@ from transformers.models.llama.modeling_llama import (
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
|
||||
@@ -43,19 +44,7 @@ except ImportError:
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def is_xformers_available() -> bool:
|
||||
try:
|
||||
import xformers # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def is_xformers_swiglu_available() -> bool:
|
||||
if not is_xformers_available():
|
||||
return False
|
||||
|
||||
from xformers.ops.common import get_xformers_operator
|
||||
|
||||
try:
|
||||
@@ -68,11 +57,6 @@ def is_xformers_swiglu_available() -> bool:
|
||||
|
||||
|
||||
def replace_llama_mlp_with_swiglu(model):
|
||||
if is_xformers_swiglu_available():
|
||||
from axolotl.monkeypatch.xformers_ import FusedMLP
|
||||
else:
|
||||
raise RuntimeError("xformers SwiGLU not available for this environment")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaMLP):
|
||||
mlp = FusedMLP(
|
||||
@@ -197,6 +181,49 @@ class FusedAttention(LlamaAttention):
|
||||
set_module_name(model, name, new_attn)
|
||||
|
||||
|
||||
class FusedMLP(torch.nn.Module):
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
gate_proj: torch.nn.Linear,
|
||||
up_proj: torch.nn.Linear,
|
||||
down_proj: torch.nn.Linear,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.swiglu = SwiGLU(
|
||||
in_features=config.hidden_size,
|
||||
hidden_features=config.intermediate_size,
|
||||
bias=False,
|
||||
_pack_weights=True,
|
||||
)
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.swiglu.w12.weight.data = torch.cat(
|
||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||
)
|
||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||
|
||||
def _post_training(self, model, name):
|
||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||
)
|
||||
|
||||
# Assign the split weights back to the original layers
|
||||
new_mlp = LlamaMLP(self.config)
|
||||
new_mlp.gate_proj.weight.data = w1
|
||||
new_mlp.up_proj.weight.data = w2
|
||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||
|
||||
set_module_name(model, name, new_mlp)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||
return self.swiglu(x)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
|
||||
@@ -44,8 +44,8 @@ def magnitude_pruning_(tensor, prune_ratio):
|
||||
def reset_optimizer(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
*,
|
||||
reset_params: List[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: List[str],
|
||||
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: list[str],
|
||||
prune_ratio: float = 0.9,
|
||||
):
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
|
||||
|
||||
@@ -16,6 +16,26 @@ from transformers.models.llama.modeling_llama import (
|
||||
|
||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||
|
||||
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
"""
|
||||
|
||||
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
loss = fast_cross_entropy_loss(
|
||||
logits = shift_logits,
|
||||
labels = shift_labels,
|
||||
)
|
||||
"""
|
||||
|
||||
ORIGINAL_QKV_CODE = """
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
@@ -60,6 +80,12 @@ def get_forward_code() -> str:
|
||||
return forward
|
||||
|
||||
|
||||
def check_cel_is_patchable() -> bool:
|
||||
forward = get_forward_code()
|
||||
forward, _ = detab_code(forward)
|
||||
return ORIGINAL_CEL_CODE in forward
|
||||
|
||||
|
||||
def get_self_attn_code() -> str:
|
||||
forward = inspect.getsource(LlamaFlashAttention2.forward)
|
||||
return forward
|
||||
@@ -72,31 +98,48 @@ def check_self_attn_is_patchable() -> bool:
|
||||
|
||||
|
||||
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||
from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss
|
||||
|
||||
def UnslothForCausalLMLoss( # pylint: disable=invalid-name
|
||||
logits,
|
||||
labels,
|
||||
vocab_size: int, # pylint: disable=unused-argument
|
||||
num_items_in_batch: int = None,
|
||||
ignore_index: int = -100, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
loss = fast_cross_entropy_loss(
|
||||
logits=shift_logits, labels=shift_labels, n_items=num_items_in_batch
|
||||
)
|
||||
return loss
|
||||
|
||||
if model_type == "llama":
|
||||
from transformers.loss import loss_utils
|
||||
forward = get_forward_code()
|
||||
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||
forward, _ = detab_code(forward)
|
||||
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
||||
|
||||
loss_utils.ForCausalLMLoss = UnslothForCausalLMLoss # type: ignore[assignment]
|
||||
forward = forward.replace(
|
||||
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
||||
)
|
||||
forward = forward.replace(
|
||||
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
||||
"",
|
||||
)
|
||||
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
||||
forward = forward.replace(
|
||||
"def forward(",
|
||||
"def fast_cross_entropy_loss_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
# load imports necessary
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(transformers.models.llama.modeling_llama):
|
||||
if item in forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
||||
globals(),
|
||||
)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from transformers.models.llama.modeling_llama import ("
|
||||
+ ", ".join(x for x in items_to_import)
|
||||
+ ")",
|
||||
globals(),
|
||||
)
|
||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
|
||||
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
else:
|
||||
raise ValueError("Unsupported model type")
|
||||
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
from axolotl.monkeypatch.utils import set_module_name
|
||||
|
||||
|
||||
class FusedMLP(torch.nn.Module):
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
gate_proj: torch.nn.Linear,
|
||||
up_proj: torch.nn.Linear,
|
||||
down_proj: torch.nn.Linear,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.swiglu = SwiGLU(
|
||||
in_features=config.hidden_size,
|
||||
hidden_features=config.intermediate_size,
|
||||
bias=False,
|
||||
_pack_weights=True,
|
||||
)
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.swiglu.w12.weight.data = torch.cat(
|
||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||
)
|
||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||
|
||||
def _post_training(self, model, name):
|
||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||
)
|
||||
|
||||
# Assign the split weights back to the original layers
|
||||
new_mlp = LlamaMLP(self.config)
|
||||
new_mlp.gate_proj.weight.data = w1
|
||||
new_mlp.up_proj.weight.data = w2
|
||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||
|
||||
set_module_name(model, name, new_mlp)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||
return self.swiglu(x)
|
||||
@@ -11,10 +11,6 @@ LOG = logging.getLogger("axolotl.prompt_strategies")
|
||||
|
||||
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||
try:
|
||||
if strategy == "messages":
|
||||
from .messages import load as messages_load
|
||||
|
||||
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
|
||||
load_fn = "load"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
@@ -35,5 +31,4 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
|
||||
return None
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
raise exc
|
||||
return None
|
||||
return None
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
### example yaml
|
||||
|
||||
```yaml
|
||||
chat_template: gemma
|
||||
datasets:
|
||||
- path: argilla/distilabel-intel-orca-dpo-pairs
|
||||
type: bradley_terry.chat_template
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/out
|
||||
```
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Module to load prompt strategies."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
||||
|
||||
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry")
|
||||
|
||||
|
||||
def load(strategy, tokenizer, cfg, ds_cfg):
|
||||
# pylint: disable=duplicate-code
|
||||
try:
|
||||
load_fn = "load"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(
|
||||
f".{strategy}", "axolotl.prompt_strategies.bradley_terry"
|
||||
)
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
if strategy == "user_defined":
|
||||
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
||||
else:
|
||||
sig = inspect.signature(func)
|
||||
if "ds_cfg" in sig.parameters:
|
||||
load_kwargs["ds_cfg"] = ds_cfg
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
return None
|
||||
@@ -1,102 +0,0 @@
|
||||
"""
|
||||
Bradley-Terry model with chat template prompt strategy.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import (
|
||||
ChatTemplatePrompter,
|
||||
ChatTemplateStrategy,
|
||||
)
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
|
||||
# Configure the logger
|
||||
LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template")
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class BTChatTemplateStrategy(ChatTemplateStrategy):
|
||||
"""
|
||||
Bradley-Terry reward model pairwise chat template prompt strategy.
|
||||
"""
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
"""
|
||||
|
||||
:param prompt: the actual row of data from the underlying dataset
|
||||
:return:
|
||||
"""
|
||||
|
||||
self.messages = "chosen_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append(
|
||||
{"role": "system", "content": prompt["system"]}
|
||||
)
|
||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
|
||||
chosen_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
self.messages = "rejected_messages"
|
||||
# pylint: disable=duplicate-code
|
||||
prompt[self.messages] = []
|
||||
if prompt["system"]:
|
||||
prompt[self.messages].append(
|
||||
{"role": "system", "content": prompt["system"]}
|
||||
)
|
||||
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
|
||||
prompt[self.messages].append(
|
||||
{"role": "assistant", "content": prompt["rejected"]}
|
||||
)
|
||||
rejected_tokenized = super().tokenize_prompt(prompt)
|
||||
|
||||
return {
|
||||
"input_ids_chosen": chosen_tokenized["input_ids"],
|
||||
"attention_mask_chosen": chosen_tokenized["attention_mask"],
|
||||
"labels_chosen": 1.0,
|
||||
"input_ids_rejected": rejected_tokenized["input_ids"],
|
||||
"attention_mask_rejected": rejected_tokenized["attention_mask"],
|
||||
"labels_rejected": 0.0,
|
||||
}
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ds_cfg = ds_cfg or {}
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
"message_field_training_detail": ds_cfg.get(
|
||||
"message_field_training_detail", None
|
||||
),
|
||||
"roles": ds_cfg.get("roles"),
|
||||
"drop_system_message": ds_cfg.get("drop_system_message", False),
|
||||
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
|
||||
"max_length": cfg.sequence_len + 1
|
||||
if not cfg.reward_model
|
||||
else cfg.sequence_len,
|
||||
}
|
||||
|
||||
strategy_params = {
|
||||
"train_on_inputs": cfg.train_on_inputs,
|
||||
"sequence_len": cfg.sequence_len,
|
||||
"roles_to_train": ds_cfg.get("roles_to_train", []),
|
||||
"train_on_eos": ds_cfg.get("train_on_eos", None),
|
||||
}
|
||||
|
||||
strategy = BTChatTemplateStrategy(
|
||||
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
|
||||
)
|
||||
|
||||
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
|
||||
return strategy
|
||||
@@ -1,27 +0,0 @@
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected to match llama3 chat template
|
||||
"""
|
||||
|
||||
|
||||
def icr(
|
||||
cfg,
|
||||
**kwargs,
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
if "system" in sample and sample["system"]:
|
||||
prompt = (
|
||||
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["chosen"] = prompt + f"{sample['chosen']}<|eot_id|>"
|
||||
sample["rejected"] = prompt + f"{sample['rejected']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
return transform_fn
|
||||
@@ -9,7 +9,7 @@ from transformers import ProcessorMixin
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
# Configure the logger
|
||||
LOG = logging.getLogger("axolotl")
|
||||
@@ -403,16 +403,11 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
|
||||
# pylint: disable=duplicate-code
|
||||
ds_cfg = ds_cfg or {}
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
|
||||
|
||||
prompter_params = {
|
||||
"tokenizer": tokenizer,
|
||||
"chat_template": chat_template_string,
|
||||
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
|
||||
"message_field_role": ds_cfg.get("message_field_role", "role"),
|
||||
"message_field_content": ds_cfg.get("message_field_content", "content"),
|
||||
"message_field_training": ds_cfg.get("message_field_training", None),
|
||||
|
||||
@@ -2,16 +2,15 @@
|
||||
DPO prompt strategies for using tokenizer chat templates.
|
||||
"""
|
||||
|
||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
|
||||
def default(
|
||||
cfg, dataset_idx=0, **kwargs
|
||||
): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
ds_cfg = cfg["datasets"][dataset_idx]
|
||||
chat_template_choice, chat_template_jinja = extract_chat_template_args(
|
||||
cfg=cfg, ds_cfg=ds_cfg
|
||||
)
|
||||
chat_template_str = chat_templates(cfg.chat_template)
|
||||
|
||||
field_messages = ds_cfg.get("field_messages", "messages")
|
||||
field_chosen = ds_cfg.get("field_chosen", "chosen")
|
||||
field_rejected = ds_cfg.get("field_rejected", "rejected")
|
||||
@@ -31,12 +30,6 @@ def default(
|
||||
role_map[source] = target
|
||||
|
||||
def transform_fn(sample, tokenizer=None):
|
||||
chat_template_string = get_chat_template(
|
||||
user_choice=chat_template_choice,
|
||||
jinja_template=chat_template_jinja,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
messages = sample[field_messages]
|
||||
messages = [
|
||||
{
|
||||
@@ -53,29 +46,28 @@ def default(
|
||||
"role": role_map[sample[field_rejected][field_message_role]],
|
||||
"content": sample[field_rejected][field_message_content],
|
||||
}
|
||||
dummy_user_message = {"role": "user", "content": "[[dummy_message]]"}
|
||||
|
||||
result = {}
|
||||
result["prompt"] = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_string,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
result["chosen"] = tokenizer.apply_chat_template(
|
||||
[dummy_user_message, chosen],
|
||||
[chosen],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_string,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)
|
||||
chosen_strip_index = result["chosen"].find(chosen["content"])
|
||||
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
|
||||
|
||||
result["rejected"] = tokenizer.apply_chat_template(
|
||||
[dummy_user_message, rejected],
|
||||
[rejected],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_string,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)
|
||||
rejected_strip_index = result["rejected"].find(rejected["content"])
|
||||
|
||||
33
src/axolotl/prompt_strategies/instruct.py
Normal file
33
src/axolotl/prompt_strategies/instruct.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
strategy = InstructShareGPTPromptTokenizingStrategy(
|
||||
# pylint: disable=duplicate-code
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
return strategy
|
||||
|
||||
|
||||
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return [
|
||||
{"from": "human", "value": prompt["instruction"]},
|
||||
{"from": "gpt", "value": prompt["output"]},
|
||||
]
|
||||
@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Generator, List, Sequence
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -75,7 +75,7 @@ class Llama2ChatConversation:
|
||||
|
||||
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for Llama2 prompts.
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
|
||||
"""
|
||||
|
||||
@@ -191,7 +191,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
|
||||
conv.messages = [] # pylint: disable=R0801
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
if sentence["value"]:
|
||||
conv.append_message(role, sentence["value"])
|
||||
yield conv
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
"""Module to load message prompt strategies."""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
LOG = logging.getLogger("axolotl.prompt_strategies.messages")
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg, processor=None):
|
||||
try:
|
||||
strategy = ds_cfg.get("input_transform", "chat")
|
||||
# pylint: disable=duplicate-code
|
||||
load_fn = "load"
|
||||
if strategy.split(".")[-1].startswith("load_"):
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(
|
||||
f".{strategy}", "axolotl.prompt_strategies.messages"
|
||||
)
|
||||
func = getattr(mod, load_fn)
|
||||
load_kwargs = {}
|
||||
sig = inspect.signature(func)
|
||||
if "ds_cfg" in sig.parameters:
|
||||
load_kwargs["ds_cfg"] = ds_cfg
|
||||
if "processor" in sig.parameters:
|
||||
load_kwargs["processor"] = processor
|
||||
return func(tokenizer, cfg, **load_kwargs)
|
||||
except ModuleNotFoundError:
|
||||
return None
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
|
||||
raise exc
|
||||
return None
|
||||
@@ -1,84 +0,0 @@
|
||||
"""
|
||||
Chat dataset wrapping strategy for new internal messages representations
|
||||
"""
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from axolotl.core.datasets.chat import TokenizedChatDataset
|
||||
from axolotl.core.datasets.transforms.chat_builder import chat_message_transform_builder
|
||||
from axolotl.prompt_tokenizers import DatasetWrappingStrategy
|
||||
|
||||
|
||||
class ChatMessageDatasetWrappingStrategy(DatasetWrappingStrategy):
|
||||
"""
|
||||
Chat dataset wrapping strategy for new internal messages representations
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
processor,
|
||||
message_transform=None,
|
||||
formatter=None,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
"""
|
||||
:param processor: tokenizer or image processor
|
||||
:param kwargs:
|
||||
"""
|
||||
self.processor = processor
|
||||
self.dataset = None
|
||||
self.message_transform = message_transform
|
||||
self.formatter = formatter
|
||||
|
||||
def wrap_dataset(
|
||||
self,
|
||||
dataset,
|
||||
process_count: Optional[int] = None,
|
||||
keep_in_memory: Optional[bool] = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
self.dataset = TokenizedChatDataset(
|
||||
dataset,
|
||||
message_transform=self.message_transform,
|
||||
model_transform=self.processor,
|
||||
formatter=self.formatter,
|
||||
process_count=process_count,
|
||||
keep_in_memory=keep_in_memory,
|
||||
)
|
||||
return self.dataset
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
ds_cfg = ds_cfg or {}
|
||||
|
||||
field_messages = ds_cfg.get("field_messages")
|
||||
message_field_role = ds_cfg.get("message_field_role")
|
||||
message_field_content = ds_cfg.get("message_field_content")
|
||||
message_field_training = ds_cfg.get("message_field_training")
|
||||
|
||||
builder_kwargs = {}
|
||||
if field_messages:
|
||||
builder_kwargs["conversations_field"] = field_messages
|
||||
if message_field_role:
|
||||
builder_kwargs["message_field_role"] = message_field_role
|
||||
if message_field_content:
|
||||
builder_kwargs["message_field_content"] = message_field_content
|
||||
if message_field_training:
|
||||
builder_kwargs["message_field_training"] = message_field_training
|
||||
|
||||
chat_template = ds_cfg.get("chat_template", cfg.get("chat_template", "chatml"))
|
||||
format_message = (
|
||||
lambda x: x # noqa E731 # pylint: disable=unnecessary-lambda-assignment
|
||||
)
|
||||
if chat_template == "chatml":
|
||||
from axolotl.core.chat.format.chatml import format_message # noqa F811
|
||||
if chat_template.startswith("llama3"):
|
||||
from axolotl.core.chat.format.llama3x import format_message # noqa F811
|
||||
message_transform: Callable = chat_message_transform_builder(
|
||||
train_on_inputs=ds_cfg.get("train_on_inputs", False),
|
||||
**builder_kwargs,
|
||||
)
|
||||
strategy = ChatMessageDatasetWrappingStrategy(
|
||||
tokenizer, message_transform=message_transform, formatter=format_message
|
||||
)
|
||||
|
||||
return strategy
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel
|
||||
|
||||
from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
@@ -28,13 +28,18 @@ def load(
|
||||
"""
|
||||
chatml transforms for datasets with system, input, chosen, rejected
|
||||
"""
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
|
||||
)
|
||||
tokenizer.chat_template = chat_template_string
|
||||
|
||||
chat_template = chat_templates("chatml")
|
||||
if ds_cfg and "chat_template" in ds_cfg:
|
||||
chat_template = ds_cfg["chat_template"]
|
||||
try:
|
||||
chat_template = chat_templates(chat_template)
|
||||
except ValueError:
|
||||
pass
|
||||
tokenizer.chat_template = chat_template
|
||||
|
||||
return ORPOTokenizingStrategy(
|
||||
ORPOPrompter(chat_template_string, tokenizer),
|
||||
ORPOPrompter(chat_template, tokenizer),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
@@ -243,30 +248,28 @@ class ORPOPrompter(Prompter):
|
||||
def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
||||
dataset_parser = ORPODatasetParsingStrategy()
|
||||
|
||||
chat_template_str = chat_templates(cfg.chat_template)
|
||||
|
||||
def transform_fn(sample, tokenizer=None):
|
||||
res = {}
|
||||
|
||||
chat_template_string = get_chat_template_from_config(
|
||||
cfg=cfg, tokenizer=tokenizer
|
||||
)
|
||||
|
||||
res["prompt"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
|
||||
add_generation_prompt=True,
|
||||
chat_template=chat_template_string,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)
|
||||
prompt_str_len = len(res["prompt"])
|
||||
res["chosen"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_string,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)[prompt_str_len:]
|
||||
res["rejected"] = tokenizer.apply_chat_template(
|
||||
[msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
|
||||
add_generation_prompt=False,
|
||||
chat_template=chat_template_string,
|
||||
chat_template=chat_template_str,
|
||||
tokenize=False,
|
||||
)[prompt_str_len:]
|
||||
|
||||
|
||||
220
src/axolotl/prompt_strategies/sharegpt.py
Normal file
220
src/axolotl/prompt_strategies/sharegpt.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
from axolotl.utils.tokenization import (
|
||||
chatml_to_conversation,
|
||||
merge_consecutive_messages,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def register_chatml_template(system_message=None):
|
||||
system_message = system_message or "You are a helpful assistant."
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml_glaive",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message=system_message,
|
||||
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def register_llama3_template(system_message=None):
|
||||
system_message = system_message or "You are a helpful assistant."
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="llama3",
|
||||
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
||||
system_message=system_message,
|
||||
roles=("user", "assistant"),
|
||||
sep_style=SeparatorStyle.LLAMA3,
|
||||
sep="",
|
||||
stop_str="<|eot_id|>",
|
||||
stop_token_ids=[128001, 128009],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_loader(
|
||||
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
||||
prompter_cls: Type["ShareGPTPrompterV2"],
|
||||
default_conversation: Optional[str] = None,
|
||||
):
|
||||
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"]
|
||||
if ds_cfg and "conversation" in ds_cfg
|
||||
else default_conversation
|
||||
)
|
||||
field_human = (
|
||||
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
)
|
||||
field_model = (
|
||||
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
)
|
||||
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
|
||||
strategy = tokenization_strategy_cls(
|
||||
prompter_cls(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
roles=roles,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
|
||||
strategy.messages = ds_cfg["field_messages"]
|
||||
return strategy
|
||||
|
||||
return _load
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
_strict = False
|
||||
_messages = "conversations"
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
return self._strict
|
||||
|
||||
@strict.setter
|
||||
def strict(self, strict):
|
||||
self._strict = strict
|
||||
|
||||
@property
|
||||
def messages(self):
|
||||
return self._messages
|
||||
|
||||
@messages.setter
|
||||
def messages(self, messages):
|
||||
self._messages = messages
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt[self.messages]
|
||||
if self.strict:
|
||||
return conversations
|
||||
role_key = "from"
|
||||
if "role" in conversations[0].keys():
|
||||
role_key = "role"
|
||||
value_key = "value"
|
||||
if "text" in conversations[0].keys():
|
||||
value_key = "text"
|
||||
elif "content" in conversations[0].keys():
|
||||
value_key = "content"
|
||||
# remap roles - allow for assistant turn"
|
||||
role_map = {
|
||||
"user": "human",
|
||||
"human": "human",
|
||||
"assistant": "gpt",
|
||||
"gpt": "gpt",
|
||||
"system": "system",
|
||||
}
|
||||
turns = [
|
||||
{
|
||||
"from": (
|
||||
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
|
||||
),
|
||||
"value": t[value_key],
|
||||
"weight": 1
|
||||
if "weight" not in t or t["weight"] is None
|
||||
else t["weight"],
|
||||
}
|
||||
for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
SimpleShareGPTPromptTokenizingStrategy
|
||||
):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
||||
return turns
|
||||
|
||||
|
||||
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps oasst data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps ultrachat data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["messages"]
|
||||
role_map = {"user": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
sharegpt strategy that remaps glaive data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversation = chatml_to_conversation(prompt)
|
||||
conversation = merge_consecutive_messages(conversation)
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_ultrachat = build_loader(
|
||||
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
|
||||
)
|
||||
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
|
||||
load_glaive = build_loader(
|
||||
GlaiveShareGPTPromptTokenizingStrategy,
|
||||
ShareGPTPrompterV2,
|
||||
default_conversation="chatml_glaive",
|
||||
)
|
||||
28
src/axolotl/prompt_strategies/sharegpt_jokes.py
Normal file
28
src/axolotl/prompt_strategies/sharegpt_jokes.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Module for Jokes prompts using sharegpt style """
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization strategy for asking bot to tell a joke and then explain why its funny
|
||||
"""
|
||||
|
||||
# title, text, explanation
|
||||
def get_conversation_thread(self, prompt):
|
||||
title = "" if not prompt["title"] else prompt["title"] + " "
|
||||
return [
|
||||
{"from": "human", "value": "Tell me a joke."},
|
||||
{"from": "gpt", "value": title + prompt["text"]},
|
||||
{"from": "human", "value": "Why is that joke funny?"},
|
||||
{"from": "gpt", "value": prompt["explanation"]},
|
||||
]
|
||||
@@ -1,12 +1,17 @@
|
||||
"""Module containing PromptTokenizingStrategy and Prompter classes"""
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
|
||||
from axolotl.prompters import Prompter
|
||||
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||
add_get_turns_to_conversation,
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -16,6 +21,8 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
||||
|
||||
add_get_turns_to_conversation()
|
||||
|
||||
|
||||
class InvalidDataException(Exception):
|
||||
"""
|
||||
@@ -23,12 +30,6 @@ class InvalidDataException(Exception):
|
||||
"""
|
||||
|
||||
|
||||
class DatasetWrappingStrategy(abc.ABC):
|
||||
"""
|
||||
Abstract class for wrapping datasets for Chat Messages
|
||||
"""
|
||||
|
||||
|
||||
class PromptTokenizingStrategy(abc.ABC):
|
||||
"""
|
||||
Abstract class for tokenizing strategies
|
||||
@@ -324,6 +325,154 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# Initial values. We will append to these as we go through the conversation.
|
||||
result, current_len = tokenize_prompt_default()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
input_roles = {conversation.roles[0]}
|
||||
output_roles = {conversation.roles[1]}
|
||||
|
||||
if len(conversation.roles) == 3:
|
||||
tool_role_label = conversation.roles[2]
|
||||
input_roles.add(tool_role_label)
|
||||
|
||||
# Add roles from the config
|
||||
if self.prompter.roles:
|
||||
if "input" in self.prompter.roles and self.prompter.roles["input"]:
|
||||
for role in self.prompter.roles["input"]:
|
||||
input_roles.add(role)
|
||||
|
||||
if "output" in self.prompter.roles and self.prompter.roles["output"]:
|
||||
for role in self.prompter.roles["output"]:
|
||||
output_roles.add(role)
|
||||
|
||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||
role_remap = []
|
||||
if (
|
||||
conversation.name == "vicuna_v1.1"
|
||||
and "roles" in prompt
|
||||
and len(prompt["roles"]) >= 2
|
||||
):
|
||||
role_remap = [
|
||||
{"from": conversation.roles[0], "to": prompt["roles"][0]},
|
||||
{"from": conversation.roles[1], "to": prompt["roles"][1]},
|
||||
]
|
||||
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if not isinstance(part, tuple):
|
||||
LOG.warning(f"expected tuple, got {part}")
|
||||
continue
|
||||
|
||||
if len(part) <= 2:
|
||||
role, content = part
|
||||
weight = 1
|
||||
else:
|
||||
role, content, weight = part
|
||||
|
||||
# Uses "in" because role contains extra characters
|
||||
input_turn = any(r.lower() in role.lower() for r in input_roles)
|
||||
output_turn = any(r.lower() in role.lower() for r in output_roles)
|
||||
empty_role = role.strip() == ""
|
||||
|
||||
if not any([input_turn, output_turn, empty_role]):
|
||||
LOG.warning(f"unhandled role: {role}")
|
||||
continue
|
||||
|
||||
if input_turn:
|
||||
role = (
|
||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this is still the user query, we should
|
||||
if not content.strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if self.train_on_inputs and weight == 1:
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif output_turn:
|
||||
role = (
|
||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not content.strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
add_eos_token = not (
|
||||
conversation.name == "chatml"
|
||||
and conversation.sep == self.tokenizer.eos_token
|
||||
)
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=add_eos_token,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
role_res = self._tokenize(
|
||||
role.rstrip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
if not self.train_on_inputs:
|
||||
# mask out role tokens from the labels
|
||||
len_role = len(role_res["input_ids"])
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||
len_role, len(labels)
|
||||
)
|
||||
if weight == 0:
|
||||
# everything from this is masked out from the labels
|
||||
# (role is masked out too because it makes no sense if contents is masked out)
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
|
||||
elif empty_role:
|
||||
turn = content
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
if self.train_on_inputs and weight == 1:
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
else:
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
|
||||
|
||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
Returns the default values for the tokenize prompt function
|
||||
|
||||
@@ -5,6 +5,7 @@ from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from colorama import Fore
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
@@ -261,10 +262,166 @@ class ReflectAlpacaPrompter(Prompter):
|
||||
)
|
||||
|
||||
|
||||
ALTERNATING_ASSERTION_FAILED_ROLE = (
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
"Role did not alternate between turns (gpt and human). Please check your data."
|
||||
)
|
||||
|
||||
CONVERSATION_ROLE_FORMAT = {
|
||||
"chatml": "<|im_start|>{ROLE}",
|
||||
"zephyr": "<|{ROLE}|>",
|
||||
"vicuna_v1.1": "{ROLE}",
|
||||
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
|
||||
}
|
||||
|
||||
|
||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
role_key_human = "human"
|
||||
role_key_model = "gpt"
|
||||
# Optional, only used for tool usage datasets.
|
||||
role_key_tool: Optional[str] = None
|
||||
# Optional, role input/output mapping
|
||||
roles: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_style=None, # pylint: disable=unused-argument
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
role_key_tool: Optional[str] = None,
|
||||
roles: Optional[dict] = None,
|
||||
):
|
||||
if conversation:
|
||||
if isinstance(conversation, Conversation):
|
||||
self._conversation = conversation
|
||||
else:
|
||||
self._conversation = get_conv_template(conversation)
|
||||
else:
|
||||
self._conversation = get_conv_template("vicuna_v1.1")
|
||||
if role_key_human:
|
||||
self.role_key_human = role_key_human
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
if role_key_tool:
|
||||
self.role_key_tool = role_key_tool
|
||||
if roles:
|
||||
self.roles = roles
|
||||
|
||||
def _build_result(self, source):
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
raise IndexError(
|
||||
f"A conversation entry has less than 2 messages :\n{source}"
|
||||
)
|
||||
|
||||
conv = self._conversation.copy()
|
||||
|
||||
original_source = source.copy()
|
||||
# Add the conversation system prompt if provided, otherwise use the default one
|
||||
if source[0]["from"] == "system":
|
||||
conv.set_system_message(source[0]["value"])
|
||||
source.pop(0)
|
||||
|
||||
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||
if self.role_key_tool:
|
||||
roles[self.role_key_tool] = conv.roles[2]
|
||||
|
||||
try:
|
||||
# Apply prompt templates
|
||||
if source[0]["from"] not in roles:
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
except IndexError as err:
|
||||
# sometimes there is a bing or system chat
|
||||
raise err
|
||||
|
||||
conv.messages = []
|
||||
for _, sentence in enumerate(source):
|
||||
from_role = sentence["from"]
|
||||
if from_role in roles:
|
||||
role = roles[from_role]
|
||||
else:
|
||||
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
|
||||
raise NotImplementedError(
|
||||
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
|
||||
"Please help us by creating an Issue to add support for this conversation type."
|
||||
)
|
||||
|
||||
if self._conversation.name in ["llama3"]:
|
||||
role = from_role
|
||||
else:
|
||||
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
|
||||
ROLE=from_role
|
||||
)
|
||||
|
||||
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
|
||||
if (
|
||||
role != "assistant"
|
||||
): # back to back assistant calls may be okay for tool calls
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
|
||||
conv.append_message(role, sentence["value"])
|
||||
turns = list(conv.get_turns())
|
||||
original_source_length = len(original_source)
|
||||
assert len(turns) in [
|
||||
original_source_length - 1,
|
||||
original_source_length,
|
||||
original_source_length + 1,
|
||||
]
|
||||
if len(turns) == original_source_length + 1:
|
||||
original_source = [{"weight": None}] + original_source
|
||||
elif len(turns) == original_source_length - 1:
|
||||
original_source = original_source[1:]
|
||||
return [
|
||||
(*turn, weight)
|
||||
for turn, weight in zip(
|
||||
turns,
|
||||
[
|
||||
1 if "weight" not in e or e["weight"] is None else e["weight"]
|
||||
for e in original_source
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
turns = self._build_result(source)
|
||||
|
||||
for part in turns:
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
yield part
|
||||
|
||||
def __repr__(self) -> str:
|
||||
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
A V2 prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
role_key_tool: Optional[str] = None,
|
||||
roles: Optional[dict] = None,
|
||||
):
|
||||
super().__init__(
|
||||
conversation=conversation,
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
role_key_tool=role_key_tool,
|
||||
roles=roles,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedPrompter(Prompter):
|
||||
"""
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import save_fsdp_model
|
||||
from datasets import Dataset
|
||||
@@ -96,11 +97,12 @@ def train(
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
# we wait unitl the last possible moment to setup Accelerator
|
||||
Accelerator()
|
||||
model, peft_config = load_model(
|
||||
cfg, tokenizer, processor=processor, inference=cli_args.inference
|
||||
)
|
||||
if model.generation_config is not None:
|
||||
model.generation_config.do_sample = True
|
||||
model.generation_config.do_sample = True
|
||||
|
||||
model_ref = None
|
||||
if cfg.rl and cfg.rl != "orpo":
|
||||
@@ -260,10 +262,8 @@ def train(
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
try:
|
||||
trainer.create_model_card(
|
||||
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
|
||||
)
|
||||
except (AttributeError, UnicodeDecodeError):
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
except AttributeError:
|
||||
pass
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
"""
|
||||
Basic utils for Axolotl
|
||||
"""
|
||||
import importlib.util
|
||||
import importlib
|
||||
|
||||
|
||||
def is_mlflow_available():
|
||||
return importlib.util.find_spec("mlflow") is not None
|
||||
|
||||
|
||||
def is_comet_available():
|
||||
return importlib.util.find_spec("comet_ml") is not None
|
||||
|
||||
@@ -29,7 +29,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
||||
|
||||
from axolotl.utils import is_comet_available, is_mlflow_available
|
||||
from axolotl.utils import is_mlflow_available
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.callbacks.perplexity import Perplexity
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
@@ -462,7 +462,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
||||
references=[[r] for r in references],
|
||||
predictions=predictions,
|
||||
)
|
||||
scores["eval_" + metric_name] = score
|
||||
scores[metric_name] = score
|
||||
return scores
|
||||
|
||||
def predict_with_generate():
|
||||
@@ -747,15 +747,6 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
|
||||
artifact_file="PredictionsVsGroundTruth.json",
|
||||
tracking_uri=tracking_uri,
|
||||
)
|
||||
elif logger == "comet_ml" and is_comet_available():
|
||||
import comet_ml
|
||||
|
||||
experiment = comet_ml.get_running_experiment()
|
||||
if experiment:
|
||||
experiment.log_table(
|
||||
f"{name} - Predictions vs Ground Truth.csv",
|
||||
pd.DataFrame(table_data),
|
||||
)
|
||||
|
||||
if is_main_process():
|
||||
log_table_from_dataloader("Eval", eval_dataloader)
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
"""Comet module for trainer callbacks"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import comet_ml
|
||||
from transformers import TrainerCallback, TrainerControl, TrainerState
|
||||
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoCometCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to comet"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
comet_experiment = comet_ml.start(source="axolotl")
|
||||
comet_experiment.log_other("Created from", "axolotl")
|
||||
comet_experiment.log_asset(
|
||||
self.axolotl_config_path,
|
||||
file_name="axolotl-config",
|
||||
)
|
||||
LOG.info(
|
||||
"The Axolotl config has been saved to the Comet Experiment under assets."
|
||||
)
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to Comet: {err}")
|
||||
return control
|
||||
File diff suppressed because one or more lines are too long
@@ -4,7 +4,6 @@ Collators for multi-modal chat messages and packing
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizerBase, ProcessorMixin
|
||||
from transformers.data.data_collator import DataCollatorMixin
|
||||
from transformers.utils import PaddingStrategy
|
||||
@@ -21,6 +20,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
return_tensors: str = "pt"
|
||||
chat_template: Optional[str] = None
|
||||
packing: bool = False
|
||||
sequence_length: Optional[int] = None
|
||||
max_images: int = -1
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
@@ -33,11 +33,112 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
|
||||
) -> Dict[str, Any]:
|
||||
# Handle dict or lists with proper padding and conversion to tensor.
|
||||
if self.packing:
|
||||
return self.__class__.process_rows_packing(
|
||||
examples,
|
||||
self.processor,
|
||||
self.chat_template,
|
||||
self.max_images,
|
||||
self.sequence_length,
|
||||
)
|
||||
|
||||
return self.__class__.process_rows(
|
||||
examples, self.processor, self.chat_template, self.max_images
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def process_rows_packing(
|
||||
examples,
|
||||
processor,
|
||||
chat_template,
|
||||
max_images,
|
||||
sequence_length,
|
||||
length_only=False,
|
||||
):
|
||||
import torch
|
||||
|
||||
# Perform sample packing within a batch
|
||||
|
||||
if processor.tokenizer.sep_token is None:
|
||||
sep_token = "[SEP]"
|
||||
processor.tokenizer.add_tokens([sep_token])
|
||||
processor.tokenizer.sep_token = sep_token
|
||||
sep_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.tokenizer.sep_token
|
||||
)
|
||||
pad_token_id = processor.tokenizer.pad_token_id
|
||||
|
||||
texts = [
|
||||
processor.apply_chat_template(
|
||||
example["messages"], chat_template=chat_template, tokenize=False
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
images = [example["images"] for example in examples]
|
||||
|
||||
if max_images > 0:
|
||||
images = [img_batch[:max_images] for img_batch in images]
|
||||
|
||||
batch = processor(text=texts, images=images, padding=False)
|
||||
|
||||
n_sequence = len(examples)
|
||||
n_seq_in_batch = 0
|
||||
pack_len = 0
|
||||
features_pack = {}
|
||||
packed = {}
|
||||
features = list[batch.keys()]
|
||||
for feature in features:
|
||||
features_pack[feature] = []
|
||||
packed[feature] = []
|
||||
features.remove("input_ids")
|
||||
|
||||
for seq_in_batch_id in range(n_sequence):
|
||||
next_seq_len = len(batch["input_ids"][seq_in_batch_id])
|
||||
if not pack_len + next_seq_len + 1 < sequence_length:
|
||||
n_seq_in_batch += 1
|
||||
pack_len += next_seq_len + 1
|
||||
features_pack["input_ids"] += batch["input_ids"][seq_in_batch_id] + [
|
||||
sep_token_id
|
||||
]
|
||||
|
||||
"""
|
||||
Do something with attention mask and cross-attention
|
||||
"""
|
||||
|
||||
for feature in features:
|
||||
features_pack[feature] += batch[feature][seq_in_batch_id]
|
||||
|
||||
else:
|
||||
for _ in range(sequence_length - pack_len):
|
||||
features_pack["input_ids"] += [pad_token_id]
|
||||
|
||||
packed["input_ids"].append(
|
||||
torch.tensor(features_pack["input_ids"].copy())
|
||||
)
|
||||
|
||||
for feature in features:
|
||||
packed[feature].append(torch.tensor(features_pack[feature].copy()))
|
||||
features_pack[feature] = []
|
||||
|
||||
pack_len = 0
|
||||
|
||||
image_token_id = processor.tokenizer.convert_tokens_to_ids(
|
||||
processor.image_token
|
||||
)
|
||||
labels = [pack.clone() for pack in packed["input_ids"]]
|
||||
for label_id, label in enumerate(labels):
|
||||
labels[label_id][label == processor.tokenizer.pad_token_id] = -100 #
|
||||
# Ignore the image token index in the loss computation (model specific)
|
||||
|
||||
labels[label_id][label == image_token_id] = -100
|
||||
packed["labels"] = labels
|
||||
|
||||
if length_only:
|
||||
return {
|
||||
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
|
||||
}
|
||||
return packed
|
||||
|
||||
@staticmethod
|
||||
def process_rows(examples, processor, chat_template, max_images, length_only=False):
|
||||
# HINT: use `_torch_collate_batch` to stack and pad tensors
|
||||
@@ -53,12 +154,7 @@ class MultiModalChatDataCollator(DataCollatorMixin):
|
||||
)
|
||||
for example in examples
|
||||
]
|
||||
images = [
|
||||
Image.open(example["images"])
|
||||
if isinstance(example["images"], str)
|
||||
else example["images"]
|
||||
for example in examples
|
||||
]
|
||||
images = [example["images"] for example in examples]
|
||||
|
||||
if max_images > 0:
|
||||
images = [img_batch[:max_images] for img_batch in images]
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
"""Module for wandb utilities"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.comet_")
|
||||
|
||||
COMET_ENV_MAPPING_OVERRIDE = {
|
||||
"comet_mode": "COMET_START_MODE",
|
||||
"comet_online": "COMET_START_ONLINE",
|
||||
}
|
||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE = {
|
||||
"auto_histogram_activation_logging": "COMET_AUTO_LOG_HISTOGRAM_ACTIVATIONS",
|
||||
"auto_histogram_epoch_rate": "COMET_AUTO_LOG_HISTOGRAM_EPOCH_RATE",
|
||||
"auto_histogram_gradient_logging": "COMET_AUTO_LOG_HISTOGRAM_GRADIENTS",
|
||||
"auto_histogram_tensorboard_logging": "COMET_AUTO_LOG_HISTOGRAM_TENSORBOARD",
|
||||
"auto_histogram_weight_logging": "COMET_AUTO_LOG_HISTOGRAM_WEIGHTS",
|
||||
"auto_log_co2": "COMET_AUTO_LOG_CO2",
|
||||
"auto_metric_logging": "COMET_AUTO_LOG_METRICS",
|
||||
"auto_metric_step_rate": "COMET_AUTO_LOG_METRIC_STEP_RATE",
|
||||
"auto_output_logging": "COMET_AUTO_LOG_OUTPUT_LOGGER",
|
||||
"auto_param_logging": "COMET_AUTO_LOG_PARAMETERS",
|
||||
"comet_disabled": "COMET_AUTO_LOG_DISABLE",
|
||||
"display_summary_level": "COMET_DISPLAY_SUMMARY_LEVEL",
|
||||
"distributed_node_identifier": "COMET_DISTRIBUTED_NODE_IDENTIFIER",
|
||||
"log_code": "COMET_AUTO_LOG_CODE",
|
||||
"log_env_cpu": "COMET_AUTO_LOG_ENV_CPU",
|
||||
"log_env_details": "COMET_AUTO_LOG_ENV_DETAILS",
|
||||
"log_env_disk": "COMET_AUTO_LOG_ENV_DISK",
|
||||
"log_env_gpu": "COMET_AUTO_LOG_ENV_GPU",
|
||||
"log_env_host": "COMET_AUTO_LOG_ENV_HOST",
|
||||
"log_env_network": "COMET_AUTO_LOG_ENV_NETWORK",
|
||||
"log_git_metadata": "COMET_AUTO_LOG_GIT_METADATA",
|
||||
"log_git_patch": "COMET_AUTO_LOG_GIT_PATCH",
|
||||
"log_graph": "COMET_AUTO_LOG_GRAPH",
|
||||
"name": "COMET_START_EXPERIMENT_NAME",
|
||||
"offline_directory": "COMET_OFFLINE_DIRECTORY",
|
||||
"parse_args": "COMET_AUTO_LOG_CLI_ARGUMENTS",
|
||||
"tags": "COMET_START_EXPERIMENT_TAGS",
|
||||
}
|
||||
|
||||
|
||||
def python_value_to_environ_value(python_value):
|
||||
if isinstance(python_value, bool):
|
||||
if python_value is True:
|
||||
return "true"
|
||||
|
||||
return "false"
|
||||
|
||||
if isinstance(python_value, int):
|
||||
return str(python_value)
|
||||
|
||||
if isinstance(python_value, list): # Comet only have one list of string parameter
|
||||
return ",".join(map(str, python_value))
|
||||
|
||||
return python_value
|
||||
|
||||
|
||||
def setup_comet_env_vars(cfg: DictDefault):
|
||||
# TODO, we need to convert Axolotl configuration to environment variables
|
||||
# as Transformers integration are call first and would create an
|
||||
# Experiment first
|
||||
|
||||
for key in cfg.keys():
|
||||
if key.startswith("comet_") and key != "comet_experiment_config":
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value is not None and value != "":
|
||||
env_variable_name = COMET_ENV_MAPPING_OVERRIDE.get(key, key.upper())
|
||||
final_value = python_value_to_environ_value(value)
|
||||
os.environ[env_variable_name] = final_value
|
||||
|
||||
if cfg.comet_experiment_config:
|
||||
for key, value in cfg.comet_experiment_config.items():
|
||||
if value is not None and value != "":
|
||||
config_env_variable_name = (
|
||||
COMET_EXPERIMENT_CONFIG_ENV_MAPPING_OVERRIDE.get(key)
|
||||
)
|
||||
|
||||
if config_env_variable_name is None:
|
||||
LOG.warning(
|
||||
f"Unknown Comet Experiment Config name {key}, ignoring it"
|
||||
)
|
||||
continue
|
||||
|
||||
final_value = python_value_to_environ_value(value)
|
||||
os.environ[config_env_variable_name] = final_value
|
||||
|
||||
# Enable comet if project name is present
|
||||
if cfg.comet_project_name and len(cfg.comet_project_name) > 0:
|
||||
cfg.use_comet = True
|
||||
@@ -215,6 +215,11 @@ def normalize_cfg_datasets(cfg):
|
||||
if cfg.chat_template:
|
||||
if cfg.datasets:
|
||||
for idx, ds_cfg in enumerate(cfg.datasets):
|
||||
if ds_cfg.type == "sharegpt" and not ds_cfg.conversation:
|
||||
LOG.info(
|
||||
f"updating dataset {ds_cfg.path} with `conversation: {cfg.chat_template}` to match your chat_template"
|
||||
)
|
||||
cfg.datasets[idx].conversation = cfg.chat_template
|
||||
if (
|
||||
ds_cfg.type in ["orpo.chat_template", "chat_template"]
|
||||
and not ds_cfg.chat_template
|
||||
@@ -223,7 +228,6 @@ def normalize_cfg_datasets(cfg):
|
||||
f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template"
|
||||
)
|
||||
cfg.datasets[idx].chat_template = cfg.chat_template
|
||||
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
|
||||
|
||||
|
||||
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
@@ -456,6 +460,27 @@ def legacy_validate_config(cfg):
|
||||
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||
)
|
||||
|
||||
if cfg.datasets:
|
||||
for idx, ds_cfg in enumerate(cfg.datasets):
|
||||
if not ds_cfg.type:
|
||||
continue
|
||||
if ds_cfg.type == "sharegpt:chat":
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
cfg.datasets[idx].type = "sharegpt"
|
||||
if "sharegpt_simple" in ds_cfg.type:
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
)
|
||||
|
||||
if cfg.saves_per_epoch and cfg.save_steps:
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
|
||||
@@ -8,16 +8,9 @@ import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from importlib.metadata import version
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Field,
|
||||
StringConstraints,
|
||||
conlist,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
|
||||
from transformers import SchedulerType
|
||||
from transformers.training_args import OptimizerNames
|
||||
|
||||
@@ -28,39 +21,6 @@ LOG = logging.getLogger("axolotl.utils.config.models.input")
|
||||
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
||||
|
||||
|
||||
class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
dpo = "dpo" # pylint: disable=invalid-name
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
kto = "kto" # pylint: disable=invalid-name
|
||||
simpo = "simpo" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
"""Chat templates configuration subset"""
|
||||
|
||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||
chatml = "chatml" # pylint: disable=invalid-name
|
||||
mistral_v1 = "mistral_v1" # pylint: disable=invalid-name
|
||||
mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name
|
||||
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name
|
||||
gemma = "gemma" # pylint: disable=invalid-name
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||
jamba = "jamba" # pylint: disable=invalid-name
|
||||
jinja = "jinja" # pylint: disable=invalid-name
|
||||
qwen_25 = "qwen_25" # pylint: disable=invalid-name
|
||||
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
|
||||
exaone = "exaone" # pylint: disable=invalid-name
|
||||
metharme = "metharme" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DeprecatedParameters(BaseModel):
|
||||
"""configurations that are deprecated"""
|
||||
|
||||
@@ -142,22 +102,14 @@ class SFTDataset(BaseModel):
|
||||
path: Optional[str] = None
|
||||
split: Optional[str] = None
|
||||
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
||||
input_transform: Optional[str] = None
|
||||
shards: Optional[int] = None
|
||||
conversation: Optional[str] = None
|
||||
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||
chat_template: Optional[
|
||||
Union[
|
||||
ChatTemplate,
|
||||
str,
|
||||
]
|
||||
] = None
|
||||
chat_template_jinja: Optional[str] = None
|
||||
chat_template: Optional[str] = None
|
||||
data_files: Optional[Union[str, List[str]]] = None
|
||||
input_format: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
ds_type: Optional[str] = None
|
||||
train_on_split: Optional[str] = None
|
||||
|
||||
field: Optional[str] = None
|
||||
field_human: Optional[str] = None
|
||||
field_model: Optional[str] = None
|
||||
@@ -168,31 +120,11 @@ class SFTDataset(BaseModel):
|
||||
message_field_training_detail: Optional[str] = None
|
||||
roles_to_train: Optional[List[str]] = None
|
||||
train_on_eos: Optional[str] = None
|
||||
|
||||
roles: Optional[Dict[str, List[str]]] = None
|
||||
drop_system_message: Optional[bool] = None
|
||||
|
||||
trust_remote_code: Optional[bool] = False
|
||||
revision: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_chat_template_config(cls, data):
|
||||
# Set chat_template to tokenizer_default if not set
|
||||
if data.get("type") == "chat_template" and not data.get("chat_template"):
|
||||
data["chat_template"] = ChatTemplate.tokenizer_default
|
||||
|
||||
# if chat_template is set to jinja, chat_template_jinja is required
|
||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||
"chat_template_jinja"
|
||||
):
|
||||
raise ValueError(
|
||||
"chat_template_jinja is required when chat_template is set to jinja"
|
||||
)
|
||||
|
||||
# If chat_template_jinja is set, set chat_template to jinja
|
||||
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||
data["chat_template"] = ChatTemplate.jinja
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class UserDefinedDPOType(BaseModel):
|
||||
@@ -214,7 +146,6 @@ class DPODataset(BaseModel):
|
||||
split: Optional[str] = None
|
||||
type: Optional[Union[UserDefinedDPOType, str]] = None
|
||||
data_files: Optional[List[str]] = None
|
||||
revision: Optional[str] = None
|
||||
|
||||
|
||||
class UserDefinedKTOType(BaseModel):
|
||||
@@ -236,7 +167,32 @@ class KTODataset(BaseModel):
|
||||
type: Optional[Union[UserDefinedKTOType, str]] = None
|
||||
data_files: Optional[List[str]] = None
|
||||
trust_remote_code: Optional[bool] = False
|
||||
revision: Optional[str] = None
|
||||
|
||||
|
||||
class RLType(str, Enum):
|
||||
"""RL trainer type configuration subset"""
|
||||
|
||||
dpo = "dpo" # pylint: disable=invalid-name
|
||||
ipo = "ipo" # pylint: disable=invalid-name
|
||||
orpo = "orpo" # pylint: disable=invalid-name
|
||||
kto = "kto" # pylint: disable=invalid-name
|
||||
simpo = "simpo" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class ChatTemplate(str, Enum):
|
||||
"""Chat templates configuration subset"""
|
||||
|
||||
alpaca = "alpaca" # pylint: disable=invalid-name
|
||||
chatml = "chatml" # pylint: disable=invalid-name
|
||||
inst = "inst" # pylint: disable=invalid-name
|
||||
gemma = "gemma" # pylint: disable=invalid-name
|
||||
cohere = "cohere" # pylint: disable=invalid-name
|
||||
llama3 = "llama3" # pylint: disable=invalid-name
|
||||
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
|
||||
phi_3 = "phi_3" # pylint: disable=invalid-name
|
||||
phi_35 = "phi_35" # pylint: disable=invalid-name
|
||||
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
|
||||
jamba = "jamba" # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LoftQConfig(BaseModel):
|
||||
@@ -428,7 +384,6 @@ class HyperparametersConfig(BaseModel):
|
||||
"ao_adamw_4bit",
|
||||
"ao_adamw_8bit",
|
||||
"ao_adamw_fp8",
|
||||
"adopt_adamw",
|
||||
],
|
||||
]
|
||||
] = OptimizerNames.ADAMW_HF.value
|
||||
@@ -489,7 +444,6 @@ class MLFlowConfig(BaseModel):
|
||||
use_mlflow: Optional[bool] = None
|
||||
mlflow_tracking_uri: Optional[str] = None
|
||||
mlflow_experiment_name: Optional[str] = None
|
||||
mlflow_run_name: Optional[str] = None
|
||||
hf_mlflow_log_artifacts: Optional[bool] = None
|
||||
|
||||
|
||||
@@ -535,19 +489,6 @@ class WandbConfig(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class CometConfig(BaseModel):
|
||||
"""Comet configuration subset"""
|
||||
|
||||
use_comet: Optional[bool] = None
|
||||
comet_api_key: Optional[str] = None
|
||||
comet_workspace: Optional[str] = None
|
||||
comet_project_name: Optional[str] = None
|
||||
comet_experiment_key: Optional[str] = None
|
||||
comet_mode: Optional[str] = None
|
||||
comet_online: Optional[bool] = None
|
||||
comet_experiment_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class GradioConfig(BaseModel):
|
||||
"""Gradio configuration subset"""
|
||||
|
||||
@@ -568,7 +509,6 @@ class AxolotlInputConfig(
|
||||
HyperparametersConfig,
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
CometConfig,
|
||||
LISAConfig,
|
||||
GradioConfig,
|
||||
RemappedParameters,
|
||||
@@ -586,13 +526,8 @@ class AxolotlInputConfig(
|
||||
resume_from_checkpoint: Optional[str] = None
|
||||
auto_resume_from_checkpoints: Optional[bool] = None
|
||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||
mean_resizing_embeddings: Optional[bool] = False
|
||||
|
||||
rl: Optional[RLType] = None
|
||||
reward_model: Optional[bool] = None
|
||||
dpo_use_weighting: Optional[
|
||||
bool
|
||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||
|
||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||
@@ -759,13 +694,7 @@ class AxolotlInputConfig(
|
||||
gpu_memory_limit: Optional[Union[int, str]] = None
|
||||
low_cpu_mem_usage: Optional[bool] = None
|
||||
|
||||
chat_template: Optional[
|
||||
Union[
|
||||
ChatTemplate,
|
||||
Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")],
|
||||
]
|
||||
] = None
|
||||
chat_template_jinja: Optional[str] = None
|
||||
chat_template: Optional[ChatTemplate] = None
|
||||
default_system_message: Optional[str] = None
|
||||
|
||||
fix_untrained_tokens: Optional[bool] = None
|
||||
@@ -783,25 +712,28 @@ class AxolotlInputConfig(
|
||||
is_mistral_derived_model: Optional[bool] = Field(default=None)
|
||||
is_qwen_derived_model: Optional[bool] = Field(default=None)
|
||||
|
||||
plugins: Optional[List[str]] = Field(default=None)
|
||||
|
||||
@field_validator("datasets", mode="before")
|
||||
@classmethod
|
||||
def deprecate_sharegpt_datasets(cls, datasets):
|
||||
for _, ds_cfg in enumerate(datasets):
|
||||
if not ds_cfg.get("type"):
|
||||
def fix_sharegpt_datasets(cls, datasets):
|
||||
for idx, ds_cfg in enumerate(datasets):
|
||||
if not ds_cfg["type"]:
|
||||
continue
|
||||
|
||||
ds_type = ds_cfg["type"]
|
||||
# skip if it's a dict (for custom user instruction prompt)
|
||||
if isinstance(ds_type, dict):
|
||||
continue
|
||||
|
||||
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
|
||||
raise ValueError(
|
||||
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
|
||||
if ds_cfg["type"] == "sharegpt:chat":
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
datasets[idx]["type"] = "sharegpt"
|
||||
if "sharegpt_simple" in ds_cfg["type"]:
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
datasets[idx]["type"] = datasets[idx]["type"].replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
)
|
||||
|
||||
return datasets
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -871,23 +803,6 @@ class AxolotlInputConfig(
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_chat_template_config(cls, data):
|
||||
# if chat_template is set to jinja, chat_template_jinja is required
|
||||
if data.get("chat_template") == ChatTemplate.jinja and not data.get(
|
||||
"chat_template_jinja"
|
||||
):
|
||||
raise ValueError(
|
||||
"chat_template_jinja is required when chat_template is set to jinja"
|
||||
)
|
||||
|
||||
# If chat_template_jinja is set, set chat_template to jinja
|
||||
if data.get("chat_template_jinja") and not data.get("chat_template"):
|
||||
data["chat_template"] = ChatTemplate.jinja
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sample_packing_wo_flash(cls, data):
|
||||
@@ -918,17 +833,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def hint_reward_model_pad(cls, data):
|
||||
if data.get("reward_model") and not data.get("pad_to_sequence_len"):
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using reward_model"
|
||||
)
|
||||
if data.get("pad_to_sequence_len") is None:
|
||||
data["pad_to_sequence_len"] = True
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gas_bsz(cls, data):
|
||||
@@ -1062,26 +966,6 @@ class AxolotlInputConfig(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
|
||||
if data.get("do_bench_eval") and not (
|
||||
data.get("evals_per_epoch") or data.get("eval_steps")
|
||||
):
|
||||
raise ValueError(
|
||||
"do_bench_eval requires evals_per_epoch or eval_steps to be set."
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_test_datasets_bench(cls, data):
|
||||
if (
|
||||
data.get("do_bench_eval")
|
||||
and not data.get("test_datasets")
|
||||
and not data.get("val_set_size")
|
||||
):
|
||||
LOG.warning(
|
||||
"`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset."
|
||||
)
|
||||
data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}]
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -1402,17 +1286,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_grad_accum_4_46_2(cls, data):
|
||||
if data.get("fsdp") and data.get("gradient_accumulation_steps") > 1:
|
||||
if version("transformers") == "4.46.2":
|
||||
raise ValueError(
|
||||
"FSDP w/ gradient_accumulation_steps > 1 is broken with transformers==4.46.2. "
|
||||
"Please use a lower value or switch to an older version of transformers."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
@@ -90,7 +90,6 @@ def load_prepare_dpo_datasets(cfg):
|
||||
ds = load_dataset( # pylint: disable=invalid-name
|
||||
ds_cfg["path"],
|
||||
split=ds_cfg["split"],
|
||||
revision=ds_cfg.get("revision", None),
|
||||
)
|
||||
split_datasets.insert(i, ds)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user