Compare commits

..

8 Commits

Author  SHA1        Message                      Date
sunny   bfb80a3ef9  stuff                        2024-10-30 13:44:06 -04:00
sunny   38773d661f  fixing                       2024-10-30 11:04:50 -04:00
sunny   271c2c2b82  fixed formatting             2024-10-29 15:50:56 -04:00
sunny   32b6f30947  fix attempt at issue 1991    2024-10-29 15:44:32 -04:00
sunny   fc1f275e6c  yml change                   2024-10-29 15:27:42 -04:00
sunny   46d2b4ce89  yml change                   2024-10-29 15:25:25 -04:00
sunny   88c9a7aecc  LOG for debug                2024-10-29 13:35:55 -04:00
sunny   d9a93990d1  yml                          2024-10-29 10:40:32 -04:00
144 changed files with 3551 additions and 6001 deletions

View File

@@ -1,16 +1,6 @@
name: ci-cd-base
on:
push:
branches:
- "main"
paths:
- 'Dockerfile-base'
- '.github/workflows/base.yml'
pull_request:
paths:
- 'Dockerfile-base'
- '.github/workflows/base.yml'
workflow_dispatch:
jobs:
@@ -37,7 +27,7 @@ jobs:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.10"
python_version: "3.11"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
@@ -50,25 +40,23 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.5.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
uses: docker/metadata-action@v3
with:
images: |
winglian/axolotl-base
axolotlai/axolotl-base
images: winglian/axolotl-base
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v2
- name: Build
uses: docker/build-push-action@v4
with:
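
For context, `docker/metadata-action` accepts a newline-delimited `images` list, so one side of this hunk publishes the same build to two Docker Hub namespaces while the other keeps only one. A minimal sketch of the multi-namespace form (repository names taken from the diff above):

```yaml
- name: Docker metadata
  id: metadata
  uses: docker/metadata-action@v5
  with:
    # Each line is a separate image name; the generated tags are
    # applied to every listed image.
    images: |
      winglian/axolotl-base
      axolotlai/axolotl-base
```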

View File

@@ -17,7 +17,7 @@ jobs:
- name: Set up Quarto
uses: quarto-dev/quarto-actions/setup@v2
- name: Setup Python
uses: actions/setup-python@v5
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: install dependencies

View File

@@ -15,9 +15,9 @@ jobs:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
- uses: pre-commit/action@v3.0.0

View File

@@ -4,13 +4,11 @@ on:
push:
branches:
- "main"
tags:
- "v*"
workflow_dispatch:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -34,7 +32,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -44,12 +42,7 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
images: winglian/axolotl
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
@@ -63,7 +56,7 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
@@ -77,7 +70,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -101,7 +94,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -111,25 +104,20 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
images: winglian/axolotl-cloud
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v2
- name: Build
uses: docker/build-push-action@v5
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud
push: ${{ github.event_name != 'pull_request' }}
@@ -140,7 +128,7 @@ jobs:
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -158,25 +146,20 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud-term
axolotlai/axolotl-cloud-term
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
images: winglian/axolotl-cloud-term
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v2
- name: Build
uses: docker/build-push-action@v5
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_type == 'tag' && 'main' || github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-cloud-no-tmux
push: ${{ github.event_name != 'pull_request' }}
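
The `BASE_TAG` lines above use the `cond && a || b` idiom, which is how GitHub Actions expressions emulate a ternary: tag builds fall back to the `main`-based image instead of one named after the tag. A hedged sketch of the pattern in isolation:

```yaml
# `cond && a || b` behaves like a ternary in GitHub Actions expressions
# (caveat: if `a` evaluates to an empty/falsy value, `b` wins instead).
env:
  BASE_REF: ${{ github.ref_type == 'tag' && 'main' || github.ref_name }}
```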

View File

@@ -8,14 +8,9 @@ on:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
# Cancel jobs on the same ref if a new one is triggered
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -36,7 +31,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.5.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
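
The `concurrency` block in this hunk is the standard pattern for cancelling superseded runs on non-default branches; reproduced as a standalone sketch:

```yaml
concurrency:
  # One group per workflow + ref; a newer trigger cancels the in-flight
  # run unless it targets main.
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
```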

View File

@@ -7,7 +7,7 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -31,7 +31,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -41,9 +41,7 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl
axolotlai/axolotl
images: winglian/axolotl
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
- name: Set up Docker Buildx
@@ -71,7 +69,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -95,7 +93,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.5.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
@@ -105,9 +103,7 @@ jobs:
id: metadata
uses: docker/metadata-action@v5
with:
images: |
winglian/axolotl-cloud
axolotlai/axolotl-cloud
images: winglian/axolotl-cloud
tags: |
type=raw,value={{ branch }}-{{ date 'YYYYMMDD' }}
- name: Login to Docker Hub
@@ -116,7 +112,7 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v2
- name: Build
uses: docker/build-push-action@v5
with:

View File

@@ -3,24 +3,12 @@ name: publish pypi
on:
push:
tags:
- 'v*'
workflow_dispatch:
- '*'
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Create release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
needs: [setup_release]
environment:
name: pypi
url: https://pypi.org/p/axolotl
@@ -28,10 +16,10 @@ jobs:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: "3.10"
@@ -49,9 +37,9 @@ jobs:
run: |
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a source dist
- name: Build a binary wheel
run: |
python setup.py sdist
python setup.py sdist bdist_wheel
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
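
The `pypi-publish` job relies on PyPI trusted publishing: granting `id-token: write` lets `pypa/gh-action-pypi-publish` authenticate via OIDC, so no API-token secret is needed. A condensed sketch of the essential pieces (checkout and build steps omitted; the action uploads whatever is in `dist/` by default):

```yaml
pypi-publish:
  name: Upload release to PyPI
  runs-on: ubuntu-latest
  environment:
    name: pypi
    url: https://pypi.org/p/axolotl
  permissions:
    id-token: write  # mandatory for trusted publishing
  steps:
    - name: Publish package distributions to PyPI
      uses: pypa/gh-action-pypi-publish@release/v1
```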

View File

@@ -9,12 +9,12 @@ jobs:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
- uses: pre-commit/action@v3.0.0
env:
SKIP: no-commit-to-branch
@@ -25,15 +25,15 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
@@ -48,14 +48,12 @@ jobs:
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
@@ -84,6 +82,13 @@ jobs:
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -94,7 +99,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.5.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
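
The sed calls in this file pin the Hugging Face stack to `main` for nightly runs (the `datasets` line is the one this compare differs on). A sketch of that step on its own, with an assumed step name:

```yaml
- name: Pin HF libraries to main for nightlies  # step name is illustrative
  run: |
    sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
    sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
    sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
```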

View File

@@ -8,33 +8,24 @@ on:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
pull_request:
paths:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- 'requirements-tests.txt'
- 'cicd/cicd.sh'
workflow_dispatch:
# Cancel jobs on the same ref if a new one is triggered
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
- uses: pre-commit/action@v3.0.0
env:
SKIP: no-commit-to-branch
@@ -45,15 +36,15 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.1"]
pytorch_version: ["2.3.1", "2.4.1", "2.5.0"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
@@ -71,108 +62,22 @@ jobs:
run: |
pip3 show torch
pip3 install -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
run: |
pytest -n8 --ignore=tests/e2e/ tests/
pytest --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
python3 setup.py sdist
pip3 install dist/axolotl*.tar.gz
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
run: |
pytest -n8 --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, pytest-sdist]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st]
needs: [pre-commit, pytest]
strategy:
fail-fast: false
@@ -184,10 +89,22 @@ jobs:
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.0
num_gpus: 1
axolotl_extras:
steps:
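
Both e2e jobs above drive Modal from a thin workflow: matrix values are exported through `$GITHUB_ENV`, then `modal run` launches the test entrypoint. A condensed sketch of that step pair, as it appears in the diff:

```yaml
- name: Update env vars
  run: |
    echo "PYTORCH_VERSION=${{ matrix.pytorch }}" >> $GITHUB_ENV
    echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
    echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
  run: |
    modal run cicd.tests
```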

.gitignore (vendored, 3 lines changed)
View File

@@ -182,6 +182,3 @@ submit.sh
typings/
out/
# vim
*.swp

1991.yml (new file, 295 lines)
View File

@@ -0,0 +1,295 @@
base_model: Qwen/Qwen2.5-14B-Instruct
model_type: AutoModelForCausalLM #nohup accelerate launch -m axolotl.cli.train /home/ubuntu/qwen2.5_14B.yml > training_output.log 2>&1 &
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
chat_template: chatml
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
# input_layernorm layers
- model.layers.0.input_layernorm
- model.layers.1.input_layernorm
- model.layers.2.input_layernorm
- model.layers.3.input_layernorm
- model.layers.4.input_layernorm
- model.layers.5.input_layernorm
- model.layers.6.input_layernorm
- model.layers.7.input_layernorm
- model.layers.8.input_layernorm
- model.layers.9.input_layernorm
- model.layers.10.input_layernorm
- model.layers.11.input_layernorm
- model.layers.12.input_layernorm
- model.layers.13.input_layernorm
- model.layers.14.input_layernorm
- model.layers.15.input_layernorm
- model.layers.16.input_layernorm
- model.layers.17.input_layernorm
- model.layers.18.input_layernorm
- model.layers.19.input_layernorm
- model.layers.20.input_layernorm
- model.layers.21.input_layernorm
- model.layers.22.input_layernorm
- model.layers.23.input_layernorm
# lm_head layers
# mlp.down_proj layers
- model.layers.1.mlp.down_proj
- model.layers.35.mlp.down_proj
- model.layers.38.mlp.down_proj
- model.layers.37.mlp.down_proj
- model.layers.36.mlp.down_proj
- model.layers.15.mlp.down_proj
- model.layers.11.mlp.down_proj
- model.layers.12.mlp.down_proj
- model.layers.34.mlp.down_proj
- model.layers.44.mlp.down_proj
- model.layers.45.mlp.down_proj
- model.layers.9.mlp.down_proj
- model.layers.41.mlp.down_proj
- model.layers.33.mlp.down_proj
- model.layers.43.mlp.down_proj
- model.layers.40.mlp.down_proj
- model.layers.13.mlp.down_proj
- model.layers.8.mlp.down_proj
- model.layers.39.mlp.down_proj
- model.layers.10.mlp.down_proj
- model.layers.14.mlp.down_proj
- model.layers.16.mlp.down_proj
- model.layers.31.mlp.down_proj
- model.layers.32.mlp.down_proj
# mlp.gate_proj layers
- model.layers.1.mlp.gate_proj
- model.layers.44.mlp.gate_proj
- model.layers.46.mlp.gate_proj
- model.layers.45.mlp.gate_proj
- model.layers.43.mlp.gate_proj
- model.layers.47.mlp.gate_proj
- model.layers.42.mlp.gate_proj
- model.layers.32.mlp.gate_proj
- model.layers.27.mlp.gate_proj
- model.layers.33.mlp.gate_proj
- model.layers.28.mlp.gate_proj
- model.layers.39.mlp.gate_proj
- model.layers.41.mlp.gate_proj
- model.layers.40.mlp.gate_proj
- model.layers.30.mlp.gate_proj
- model.layers.29.mlp.gate_proj
- model.layers.31.mlp.gate_proj
- model.layers.26.mlp.gate_proj
- model.layers.37.mlp.gate_proj
- model.layers.10.mlp.gate_proj
- model.layers.38.mlp.gate_proj
- model.layers.12.mlp.gate_proj
- model.layers.36.mlp.gate_proj
- model.layers.13.mlp.gate_proj
# mlp.up_proj layers
- model.layers.1.mlp.up_proj
- model.layers.13.mlp.up_proj
- model.layers.11.mlp.up_proj
- model.layers.14.mlp.up_proj
- model.layers.15.mlp.up_proj
- model.layers.12.mlp.up_proj
- model.layers.8.mlp.up_proj
- model.layers.16.mlp.up_proj
- model.layers.9.mlp.up_proj
- model.layers.19.mlp.up_proj
- model.layers.10.mlp.up_proj
- model.layers.7.mlp.up_proj
- model.layers.17.mlp.up_proj
- model.layers.20.mlp.up_proj
- model.layers.21.mlp.up_proj
- model.layers.18.mlp.up_proj
- model.layers.38.mlp.up_proj
- model.layers.37.mlp.up_proj
- model.layers.39.mlp.up_proj
- model.layers.42.mlp.up_proj
- model.layers.41.mlp.up_proj
- model.layers.27.mlp.up_proj
- model.layers.28.mlp.up_proj
- model.layers.34.mlp.up_proj
# model.norm layers
# post_attention_layernorm layers
- model.layers.0.post_attention_layernorm
- model.layers.1.post_attention_layernorm
- model.layers.2.post_attention_layernorm
- model.layers.3.post_attention_layernorm
- model.layers.4.post_attention_layernorm
- model.layers.5.post_attention_layernorm
- model.layers.6.post_attention_layernorm
- model.layers.7.post_attention_layernorm
- model.layers.8.post_attention_layernorm
- model.layers.9.post_attention_layernorm
- model.layers.10.post_attention_layernorm
- model.layers.11.post_attention_layernorm
- model.layers.12.post_attention_layernorm
- model.layers.13.post_attention_layernorm
- model.layers.14.post_attention_layernorm
- model.layers.15.post_attention_layernorm
- model.layers.16.post_attention_layernorm
- model.layers.17.post_attention_layernorm
- model.layers.18.post_attention_layernorm
- model.layers.19.post_attention_layernorm
- model.layers.20.post_attention_layernorm
- model.layers.21.post_attention_layernorm
- model.layers.22.post_attention_layernorm
- model.layers.23.post_attention_layernorm
# self_attn.k_proj layers
- model.layers.47.self_attn.k_proj
- model.layers.39.self_attn.k_proj
- model.layers.41.self_attn.k_proj
- model.layers.37.self_attn.k_proj
- model.layers.35.self_attn.k_proj
- model.layers.44.self_attn.k_proj
- model.layers.38.self_attn.k_proj
- model.layers.14.self_attn.k_proj
- model.layers.7.self_attn.k_proj
- model.layers.12.self_attn.k_proj
- model.layers.11.self_attn.k_proj
- model.layers.32.self_attn.k_proj
- model.layers.10.self_attn.k_proj
- model.layers.8.self_attn.k_proj
- model.layers.9.self_attn.k_proj
- model.layers.6.self_attn.k_proj
- model.layers.45.self_attn.k_proj
- model.layers.42.self_attn.k_proj
- model.layers.5.self_attn.k_proj
- model.layers.40.self_attn.k_proj
- model.layers.33.self_attn.k_proj
- model.layers.0.self_attn.k_proj
- model.layers.34.self_attn.k_proj
- model.layers.13.self_attn.k_proj
# self_attn.o_proj layers
- model.layers.12.self_attn.o_proj
- model.layers.5.self_attn.o_proj
- model.layers.14.self_attn.o_proj
- model.layers.16.self_attn.o_proj
- model.layers.20.self_attn.o_proj
- model.layers.13.self_attn.o_proj
- model.layers.11.self_attn.o_proj
- model.layers.4.self_attn.o_proj
- model.layers.6.self_attn.o_proj
- model.layers.19.self_attn.o_proj
- model.layers.7.self_attn.o_proj
- model.layers.18.self_attn.o_proj
- model.layers.8.self_attn.o_proj
- model.layers.38.self_attn.o_proj
- model.layers.15.self_attn.o_proj
- model.layers.17.self_attn.o_proj
- model.layers.9.self_attn.o_proj
- model.layers.10.self_attn.o_proj
- model.layers.21.self_attn.o_proj
- model.layers.28.self_attn.o_proj
- model.layers.32.self_attn.o_proj
- model.layers.35.self_attn.o_proj
- model.layers.39.self_attn.o_proj
- model.layers.3.self_attn.o_proj
# self_attn.q_proj layers
- model.layers.1.self_attn.q_proj
- model.layers.2.self_attn.q_proj
- model.layers.3.self_attn.q_proj
- model.layers.44.self_attn.q_proj
- model.layers.29.self_attn.q_proj
- model.layers.45.self_attn.q_proj
- model.layers.43.self_attn.q_proj
- model.layers.32.self_attn.q_proj
- model.layers.38.self_attn.q_proj
- model.layers.19.self_attn.q_proj
- model.layers.42.self_attn.q_proj
- model.layers.34.self_attn.q_proj
- model.layers.36.self_attn.q_proj
- model.layers.40.self_attn.q_proj
- model.layers.26.self_attn.q_proj
- model.layers.20.self_attn.q_proj
- model.layers.39.self_attn.q_proj
- model.layers.28.self_attn.q_proj
- model.layers.35.self_attn.q_proj
- model.layers.41.self_attn.q_proj
- model.layers.33.self_attn.q_proj
- model.layers.25.self_attn.q_proj
- model.layers.30.self_attn.q_proj
- model.layers.27.self_attn.q_proj
# self_attn.v_proj layers
- model.layers.0.self_attn.v_proj
- model.layers.7.self_attn.v_proj
- model.layers.39.self_attn.v_proj
- model.layers.31.self_attn.v_proj
- model.layers.15.self_attn.v_proj
- model.layers.10.self_attn.v_proj
- model.layers.32.self_attn.v_proj
- model.layers.41.self_attn.v_proj
- model.layers.6.self_attn.v_proj
- model.layers.33.self_attn.v_proj
- model.layers.42.self_attn.v_proj
- model.layers.29.self_attn.v_proj
- model.layers.14.self_attn.v_proj
- model.layers.9.self_attn.v_proj
- model.layers.35.self_attn.v_proj
- model.layers.38.self_attn.v_proj
- model.layers.13.self_attn.v_proj
- model.layers.30.self_attn.v_proj
- model.layers.5.self_attn.v_proj
- model.layers.34.self_attn.v_proj
- model.layers.28.self_attn.v_proj
- model.layers.37.self_attn.v_proj
- model.layers.27.self_attn.v_proj
- model.layers.11.self_attn.v_proj
# model.embed_tokens layers
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: linear
learning_rate: 5e-6
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
gradient_checkpointing: unsloth
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 2
saves_per_epoch: 1
save_total_limit: 4
debug:
deepspeed: deepspeed_configs/zero3_bf16.json
weight_decay: 0.05
special_tokens:
eos_token: <|im_end|>
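
This reproduction config leans on `unfrozen_parameters`, which Axolotl treats as patterns matched against model parameter names: matching tensors stay trainable and everything else is frozen. A pared-down sketch of the same mechanism (entries taken from the file above):

```yaml
unfrozen_parameters:
  - ^lm_head.weight$                 # anchored regex: exact parameter
  - ^model.embed_tokens.weight$
  - model.layers.0.input_layernorm   # unanchored: matches that layer's params
```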

View File

@@ -1,4 +0,0 @@
include requirements.txt
include README.md
include LICENSE
recursive-include axolotl *.py

View File

@@ -1,21 +1,8 @@
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="image/axolotl_logo_digital_white.svg">
<source media="(prefers-color-scheme: light)" srcset="image/axolotl_logo_digital_black.svg">
<img alt="Axolotl" src="image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
</picture>
</p>
# Axolotl
<p align="center">
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
</p>
<p align="center">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
</p>
![tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg)
![tests-nightly](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg)
![multigpu-semi-weekly tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg)
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
@@ -88,7 +75,7 @@ Features:
<td>
<div align="center">
<img src="image/axolotl_symbol_digital_white.svg" alt="axolotl" width="160">
<img src="image/axolotl.png" alt="axolotl" width="160">
<div>
<p>
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
@@ -147,7 +134,7 @@ pip3 install -e '.[flash-attn,deepspeed]'
### Usage
```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
# finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
@@ -172,7 +159,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
#### Docker
```bash
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
```
Or run on the current files for development:
@@ -191,7 +178,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-latest
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
```
It additionally:
@@ -223,7 +210,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
#### Cloud GPU
For cloud GPU providers that support docker images, use [`axolotlai/axolotl-cloud:main-latest`](https://hub.docker.com/r/axolotlai/axolotl-cloud/tags)
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on JarvisLabs.ai use this [direct link](https://jarvislabs.ai/templates/axolotl)
@@ -332,7 +319,7 @@ Write a job description in YAML as below:
# dstack.yaml
type: task
image: axolotlai/axolotl-cloud:main-latest
image: winglian/axolotl-cloud:main-20240429-py3.11-cu121-2.2.2
env:
- HUGGING_FACE_HUB_TOKEN
@@ -396,10 +383,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template
# fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template)
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: chat_template
chat_template: chatml # defaults to tokenizer's chat_template
type: sharegpt
conversation: chatml # default: vicuna_v1.1
# local
- path: data.jsonl # or json
@@ -574,8 +562,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
```
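
The dstack snippet in this README diff is truncated; a minimal complete task of the same shape looks roughly like the sketch below (the `commands` list is an assumption based on dstack's task schema, not part of this diff):

```yaml
# dstack.yaml
type: task
image: axolotlai/axolotl-cloud:main-latest
env:
  - HUGGING_FACE_HUB_TOKEN
commands:  # assumed field: what the task runs inside the image
  - accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
```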

View File

@@ -1,4 +1,4 @@
FROM axolotlai/axolotl-base:{{ BASE_TAG }}
FROM winglian/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
@@ -28,7 +28,6 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
@@ -37,9 +36,6 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN pip install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -1,6 +1,6 @@
#!/bin/bash
set -e
pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -10,7 +10,7 @@ import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
from modal import Image, Stub
cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -46,7 +46,7 @@ cicd_image = (
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
app = App("Axolotl CI/CD", secrets=[])
stub = Stub("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 2))
@@ -61,7 +61,7 @@ def run_cmd(cmd: str, run_folder: str):
exit(exit_code) # pylint: disable=consider-using-sys-exit
@app.function(
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
@@ -72,6 +72,6 @@ def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
@app.local_entrypoint()
@stub.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -2,4 +2,4 @@
set -e
# only run one test at a time so as not to OOM the GPU
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
pytest -n1 /workspace/axolotl/tests/e2e/multigpu/

View File

@@ -10,7 +10,7 @@ import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import App, Image
from modal import Image, Stub
cicd_path = pathlib.Path(__file__).parent.resolve()
@@ -40,7 +40,6 @@ with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
)
@@ -48,7 +47,7 @@ cicd_image = (
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
app = App("Axolotl CI/CD", secrets=[])
stub = Stub("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 1))
@@ -63,7 +62,7 @@ def run_cmd(cmd: str, run_folder: str):
exit(exit_code) # pylint: disable=consider-using-sys-exit
@app.function(
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=60 * 60,
@@ -74,6 +73,6 @@ def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
@app.local_entrypoint()
@stub.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -1,4 +1,4 @@
# Example config for debugging the chat_template prompt format
# Example config for debugging the sharegpt prompt format
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
@@ -7,8 +7,8 @@ load_in_8bit: true
load_in_4bit: false
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
- path: philschmid/guanaco-sharegpt-style
type: sharegpt
shards: 10
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base:$BASE_TAG
FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
@@ -26,9 +26,6 @@ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
fi
RUN python scripts/unsloth_install.py | sh
RUN python scripts/cutcrossentropy_install.py | sh
# So we can test the Docker image
RUN pip install pytest

View File

@@ -29,9 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
RUN git lfs install --skip-repo && \
pip3 install awscli && \

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG
FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main
FROM axolotlai/axolotl:$BASE_TAG
FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"

View File

@@ -1,5 +1,5 @@
ARG BASE_TAG=main-base
FROM axolotlai/axolotl-base:$BASE_TAG
FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""

View File

@@ -83,7 +83,7 @@ lora_on_cpu: true
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, gpteacher, oasst, reflection]
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # Optional[str] path to source data files
@@ -91,7 +91,15 @@ datasets:
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_input
output: # Optional[List[str]].
# Custom user instruction prompt
- path: repo
@@ -162,9 +170,6 @@ datasets:
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
# Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: true
# A list of one or more datasets to eval the model with.
# You can use either test_datasets, or val_set_size, but not both.
test_datasets:
@@ -178,8 +183,6 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
@@ -409,7 +412,6 @@ lr_div_factor: # Learning rate div factor
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adopt_adamw (an EXPERIMENTAL optimizer, only for torch version >= 2.5.1)
# - adafactor
# - adamw_anyprecision
# - sgd
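
One hunk above notes that `test_datasets` and `val_set_size` are mutually exclusive. A hedged sketch of the `test_datasets` form (the dataset id is illustrative and the `split` key is an assumption based on the surrounding schema comments):

```yaml
test_datasets:
  - path: org/eval-dataset   # illustrative dataset id
    type: alpaca
    split: train             # assumed: which split to evaluate on
val_set_size: 0              # leave unset/zero when test_datasets is used
```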

View File

@@ -6,8 +6,33 @@ order: 3
## sharegpt
IMPORTANT: ShareGPT is deprecated! Please see the `chat_template` section below.
UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below.
conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```{.json filename="data.jsonl"}
{"conversations": [{"from": "...", "value": "..."}]}
```
Note: `type: sharegpt` opens special configs:
- `conversation`: enables conversions to many Conversation types. Refer to the 'name' [here](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) for options.
- `roles`: allows you to specify the roles for input and output. This is useful for datasets with custom roles such as `tool` etc to support masking.
- `field_human`: specify the key to use instead of `human` in the conversation.
- `field_model`: specify the key to use instead of `gpt` in the conversation.
```yaml
datasets:
path: ...
type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Add additional keys from your dataset as input or output roles
roles:
input: # Optional[List[str]]. These will be masked based on train_on_input
output: # Optional[List[str]].
```
## pygmalion
@@ -15,6 +40,38 @@ IMPORTANT: ShareGPT is deprecated!. Please see `chat_template` section below.
{"conversations": [{"role": "...", "value": "..."}]}
```
## sharegpt.load_role
conversations where `role` is used instead of `from`
```{.json filename="data.jsonl"}
{"conversations": [{"role": "...", "value": "..."}]}
```
## sharegpt.load_guanaco
conversations where `from` is `prompter` `assistant` instead of default sharegpt
```{.json filename="data.jsonl"}
{"conversations": [{"from": "...", "value": "..."}]}
```
## sharegpt.load_ultrachat
conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
```{.json filename="data.jsonl"}
{"messages": [{"user": "...", "assistant": "..."}]}
```
## sharegpt_jokes
creates a chat where bot is asked to tell a joke, then explain why the joke is funny
```{.json filename="data.jsonl"}
{"conversations": [{"title": "...", "text": "...", "explanation": "..."}]}
```
## chat_template
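
Since this page now redirects users from `sharegpt` to `chat_template`, a minimal sketch of the replacement configuration (mirroring the README change earlier in this compare):

```yaml
datasets:
  - path: fozziethebeat/alpaca_messages_2k_test
    type: chat_template
    chat_template: chatml  # defaults to the tokenizer's chat_template
```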

View File

@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
### Background
The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
```yaml
datasets:
- path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
type: chat_template
- path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
type: sharegpt
```
>[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```jsonc
// .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"version": "0.2.0",
"configurations": [
{
"name": "Debug axolotl prompt - chat_template",
"name": "Debug axolotl prompt - sharegpt",
"type": "python",
"module": "accelerate.commands.launch",
"request": "launch",
"args": [
"-m", "axolotl.cli.train", "dev_chat_template.yml",
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
@@ -185,7 +185,7 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
## Debugging With Docker
Using [official Axolotl Docker images](https://hub.docker.com/r/axolotlai/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
Using [official Axolotl Docker images](https://hub.docker.com/r/winglian/axolotl/tags) is a great way to debug your code, and is a very popular way to use Axolotl. Attaching VSCode to Docker takes a few more steps.
### Setup
@@ -202,11 +202,11 @@ cd axolotl
Next, run the desired docker image and mount the current directory. Below is a docker command you can run to do this:[^2]
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface axolotlai/axolotl:main-py3.10-cu118-2.0.1
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
```
>[!Tip]
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/axolotlai/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
> To understand which containers are available, see the [Docker section of the README](../README.md#docker) and the [DockerHub repo](https://hub.docker.com/r/winglian/axolotl/tags). For details of how the Docker containers are built, see axolotl's [Docker CI builds](../.github/workflows/main.yml).
You will now be in the container. Next, perform an editable install of Axolotl:
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
</div>
<br>
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).

View File

@@ -11,10 +11,12 @@ standard industry baselines.
### Installation
The following will install the correct unsloth and extras from source.
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
to date libraries.
```bash
python scripts/unsloth_install.py | sh
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps --force-reinstall xformers==0.0.26.post1
```
### Using unsloth w Axolotl
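
For reference, the unsloth optimizations are toggled per-kernel in the Axolotl config. The flag names below are assumptions drawn from Axolotl's unsloth integration docs, not from this diff:

```yaml
# assumed flags; verify against docs/unsloth.qmd in the repo
unsloth_cross_entropy_loss: true
unsloth_lora_mlp: true
unsloth_lora_qkv: true
unsloth_lora_o: true
```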

View File

@@ -2,15 +2,19 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "AKjdG7tbTb-n"
},
"source": [
"## Setting up"
"# Example notebook for running Axolotl on google colab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"id": "RcbNpOgWRcii"
},
"outputs": [],
"source": [
"import torch\n",
@@ -18,76 +22,82 @@
"assert (torch.cuda.is_available()==True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h3nLav8oTRA5"
},
"source": [
"## Install Axolotl and dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3c3yGAwnOIdi",
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
},
"outputs": [],
"source": [
"!pip install axolotl[deepspeed]"
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "BW2MFr7HTjub"
},
"source": [
"## Hugging Face login (optional)"
"## Create an yaml config file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"id": "9pkF2dSoQEUN"
},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"# Your YAML string\n",
"yaml_string = \"\"\"\n",
"base_model: NousResearch/Meta-Llama-3.1-8B\n",
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
"model_type: LlamaForCausalLM\n",
"tokenizer_type: LlamaTokenizer\n",
"\n",
"load_in_8bit: false\n",
"load_in_4bit: true\n",
"strict: false\n",
"\n",
"datasets:\n",
" - path: tatsu-lab/alpaca\n",
" - path: mhenrichsen/alpaca_2k_test\n",
" type: alpaca\n",
"dataset_prepared_path: last_run_prepared\n",
"dataset_prepared_path:\n",
"val_set_size: 0.05\n",
"output_dir: ./outputs/lora-out\n",
"\n",
"sequence_len: 2048\n",
"sample_packing: true\n",
"eval_sample_packing: true\n",
"pad_to_sequence_len: true\n",
"output_dir: ./outputs/qlora-out\n",
"\n",
"adapter: qlora\n",
"lora_model_dir:\n",
"\n",
"sequence_len: 4096\n",
"sample_packing: true\n",
"eval_sample_packing: false\n",
"pad_to_sequence_len: true\n",
"\n",
"lora_r: 32\n",
"lora_alpha: 16\n",
"lora_dropout: 0.05\n",
"lora_target_modules:\n",
"lora_target_linear: true\n",
"lora_fan_in_fan_out:\n",
"lora_modules_to_save:\n",
" - embed_tokens\n",
" - lm_head\n",
"\n",
"wandb_project:\n",
"wandb_entity:\n",
@@ -95,12 +105,12 @@
"wandb_name:\n",
"wandb_log_model:\n",
"\n",
"gradient_accumulation_steps: 2\n",
"micro_batch_size: 1\n",
"num_epochs: 1\n",
"optimizer: paged_adamw_8bit\n",
"gradient_accumulation_steps: 4\n",
"micro_batch_size: 2\n",
"num_epochs: 4\n",
"optimizer: paged_adamw_32bit\n",
"lr_scheduler: cosine\n",
"learning_rate: 2e-5\n",
"learning_rate: 0.0002\n",
"\n",
"train_on_inputs: false\n",
"group_by_length: false\n",
@@ -111,15 +121,13 @@
"gradient_checkpointing: true\n",
"early_stopping_patience:\n",
"resume_from_checkpoint:\n",
"local_rank:\n",
"logging_steps: 1\n",
"xformers_attention:\n",
"flash_attention: false\n",
"sdp_attention: true\n",
"flash_attention: true\n",
"\n",
"warmup_steps: 1\n",
"max_steps: 25\n",
"evals_per_epoch: 1\n",
"eval_table_size:\n",
"warmup_steps: 10\n",
"evals_per_epoch: 4\n",
"saves_per_epoch: 1\n",
"debug:\n",
"deepspeed:\n",
@@ -127,9 +135,8 @@
"fsdp:\n",
"fsdp_config:\n",
"special_tokens:\n",
" pad_token: <|end_of_text|>\n",
"\"\"\"\n",
"\n",
"\"\"\"\n",
"\n",
"# Convert the YAML string to a Python dictionary\n",
"yaml_dict = yaml.safe_load(yaml_string)\n",
@@ -139,124 +146,31 @@
"\n",
"# Write the YAML file\n",
"with open(file_path, 'w') as file:\n",
" yaml.dump(yaml_dict, file)"
" yaml.dump(yaml_dict, file)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "bidoj8YLTusD"
},
"source": [
"Above we have a configuration file with base LLM model and datasets specified, among many other things. Axolotl can automatically detect whether the specified datasets are on HuggingFace repo or local machine.\n",
"\n",
"The Axolotl configuration options encompass model and dataset selection, data pre-processing, and training. Let's go through them line by line:\n",
"\n",
"* \"base model\": String value, specifies the underlying pre-trained LLM that will be used for finetuning\n",
"\n",
"Next we have options for model weights quantization. Quantization allows for reduction in occupied memory on GPUs.\n",
"\n",
"* \"load_in_8bit\": Boolean value, whether to quantize the model weights into 8-bit integer.\n",
"\n",
"* \"load_in_4bit\": Boolean value, whether to quantize the model weights into 4-bit integer.\n",
"\n",
"* \"strict\": Boolean value. If false, it allows for overriding established configuration options in the yaml file when executing in command-line interface.\n",
"\n",
"* \"datasets\": a list of dicts that contain path and type of data sets as well as other optional configurations where datasets are concerned. Supports multiple datasets.\n",
"\n",
"* \"val_set_size\": Either a float value less than one or an integer less than the total size of dataset. Sets the size of validation set from the whole dataset. If float, sets the proportion of the dataset assigned for validation. If integer, sets the direct size of validation set.\n",
"\n",
"* \"output_dir\": String value. Path of trained model.\n",
"\n",
"For data preprocessing:\n",
"\n",
"* \"sequence_len\": Integer. Specifies the maximum sequence length of the input. Typically 2048 or less.\n",
"\n",
"* \"pad_to_sequence_len\": Boolean. Padding input to maximum sequence length.\n",
"\n",
"* \"sample_packing\": Boolean. Specifies whether to use multi-packing with block diagonal attention.\n",
"\n",
"* \"special_tokens\": Python dict, optional. Allows users to specify the additional special tokens to be ignored by the tokenizer.\n",
"\n",
"For LoRA configuration and its hyperparamters:\n",
"\n",
"* \"adapter\": String. Either \"lora\" or \"qlora\", depending on user's choice.\n",
"\n",
"* \"lora_model_dir\": String, Optional. Path to directory that contains LoRA model, if there is already a trained LoRA model the user would like to use.\n",
"\n",
"* \"lora_r\": Integer. Refers to the rank of LoRA decomposition matrices. Higher value will reduce LoRA efficiency. Recommended to be set to 8.\n",
"\n",
"* \"lora_alpha\": Integer. Scale the weight matrices by $\\frac{\\text{lora_alpha}}{\\text{lora_r}}$Recommended to be fixed at 16.\n",
"\n",
"* \"lora_dropout\": Float that is 1 or less. The dropout probability of a lora layer.\n",
"\n",
"* \"lora_target_linear\": Boolean. If true, lora will target all linear modules in the transformers architecture.\n",
"\n",
"* \"lora_modules_to_save\": If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.\n",
"\n",
"See [LoRA](https://arxiv.org/abs/2106.09685) for detailed explanation of LoRA implementation.\n",
"\n",
"For the training configurations:\n",
"\n",
"* \"gradient_accumulation_steps\": Integer. The number of steps over which to accumulate gradient for batch training. E.g. if 2, backprop is performed every two steps.\n",
"\n",
"* \"micro_batch_size\": Integer. Batch size per gpu / gradient_accumulation_steps\n",
"\n",
"* \"num_epochs\": Integer. Number of epochs. One epoch is when training has looped over every batch in the whole data set once.\n",
"\n",
"* \"optimizer\": The optimizer to use for the training.\n",
"\n",
"* \"learning_rate\": The learning rate.\n",
"\n",
"* \"lr_scheduler\": The learning rate scheduler to use for adjusting learning rate during training.\n",
"\n",
"* \"train_on_inputs\": Boolean. Whether to ignore or include the user's prompt from the training labels.\n",
"\n",
"* \"group_by_length\": Boolean. Whether to group similarly sized data to minimize padding.\n",
"\n",
"* \"bf16\": Either \"auto\", \"true\", or \"false\". Whether to use CUDA bf16 floating point format. If set to \"auto\", will automatically apply bf16 should the gpu supports it.\n",
"\n",
"* \"fp16\": Optional. Specifies whether to use CUDA fp16. Automatically set to true if \"bf16\" is set to true. Otherwise false.\n",
"\n",
"* \"tf32\": Boolean. Whether to use CUDA tf32. Will override bf16.\n",
"\n",
"* \"gradient_checkpointing\": Boolean. Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing\n",
"\n",
"* \"gradient_checkpointing_kwargs\": Python Dict. Fed into the trainer.\n",
"\n",
"* \"logging_steps\": Integer. Log training information over every specified number of steps.\n",
"\n",
"* \"flash_attention\": Boolean. Whether to use the [flash attention](https://github.com/Dao-AILab/flash-attention) mechanism.\n",
"\n",
"* \"sdp_attention\": Boolean. Whether to use the Scaled Dot Product attention mechanism (the attention mechanism in the [original implementation](https://arxiv.org/abs/1706.03762) of transformers.)\n",
"\n",
"* \"warmup_steps\": Integer. The number of pre-training steps where a very low learning rate is used.\n",
"\n",
"* \"evals_per_epoch\": Integer. Number of evaluations to be performed within one training epoch.\n",
"\n",
"* \"saves_per_epoch\": Integer. Number of times the model is saved in one training epoch.\n",
"\n",
"* \"weight_decay\": Positive Float. Sets the \"strength\" of weight decay (i.e. setting the coefficient of L2 regularization)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above is but a snippet aiming to get users familiarized with the types of streamlined configuration options axolotl provides. For a full list of configuration options, see [here](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html)"
]
},
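{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the batch-size and LoRA arithmetic concrete, here is a small worked example (the variable names are ours for illustration, with typical example values):\n",
"\n",
"```python\n",
"# Effective batch size = samples contributing to each optimizer update\n",
"micro_batch_size = 2              # per-GPU batch size for a single step\n",
"gradient_accumulation_steps = 4   # steps accumulated before each update\n",
"num_gpus = 1\n",
"effective_batch_size = micro_batch_size * gradient_accumulation_steps * num_gpus\n",
"print(effective_batch_size)       # 8\n",
"\n",
"# LoRA scaling: the low-rank update is multiplied by lora_alpha / lora_r\n",
"lora_r, lora_alpha = 8, 16\n",
"print(lora_alpha / lora_r)        # 2.0\n",
"```"
]
},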
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model"
"## Launch the training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ydTI2Jk2RStU",
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
},
"outputs": [],
"source": [
"# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
@@ -264,7 +178,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict with trained model"
"## Play with inference"
]
},
{
@@ -273,85 +187,36 @@
"metadata": {},
"outputs": [],
"source": [
"# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --lora_model_dir=\"./outputs/lora-out\" --gradio"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deeper Dive"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is also helpful to gain some familiarity over some of the core inner workings of axolotl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration Normalization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Axolotl uses a custom Dict class, called ```DictDefault```\n",
"to store configurations specified in the yaml configuration file (into a Python variable named ```cfg```). The definition for this custom Dict can be found in the [utils/dict.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/dict.py)\n",
"\n",
"```DictDefault``` is amended such that calling a missing key from it will result in a ```None``` return type. This is important because if some configuration options aren't specified by the user, the ```None``` type allows Axolotl to perform boolean operations to determine the default settings for missing configurations. For more examples on how this is done, check out [utils/config/__init__.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/__init__.py)"
]
},
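{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this behavior (a toy stand-in for illustration only; the real class in utils/dict.py builds on the ```addict``` library):\n",
"\n",
"```python\n",
"class DictDefault(dict):\n",
"    \"\"\"Toy illustration: accessing a missing key returns None.\"\"\"\n",
"\n",
"    def __missing__(self, key):\n",
"        return None\n",
"\n",
"\n",
"cfg = DictDefault({\"adapter\": \"lora\"})\n",
"print(cfg[\"adapter\"])         # lora\n",
"print(cfg[\"sample_packing\"])  # None\n",
"\n",
"# None is falsy, so defaults can be chosen with simple boolean logic:\n",
"sample_packing = cfg[\"sample_packing\"] or False\n",
"```"
]
},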
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading Models, Tokenizers, and Trainer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we inspect [cli.train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/cli/train.py), we will find that most of the heavy lifting were done by the function ```train()``` which is itself imported from [src/axolotl/train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/train.py).\n",
"\n",
"```train()``` takes care of loading the appropriate tokenizer and pre-trained model through ```load_model()``` and ```load_tokenizer()``` from [src/axolotl/utils/models.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/models.py) respectively.\n",
"\n",
"```load_tokenizer()``` loads in the appropriate tokenizer given the desired model, as well as chat templates.\n",
"\n",
"```ModelLoader``` class follows after tokenizer has been selected. It will automatically discern the base model type, load in the desired model, as well as applying model-appropriate attention mechanism modifications (e.g. flash attention). Depending on which base model the user chooses in the configuration, ```ModelLoader``` will utilize the corresponding \"attention hijacking\" script. For example, if the user specified the base model to be ```NousResearch/Meta-Llama-3.1-8B```, which is of llama type, and set ```flash_attn``` to ```True```, ```ModelLoader``` will load in [llama_attn_hijack_flash.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/llama_attn_hijack_flash.py). For a list of supported attention hijacking, please refer to the directory [/src/axolotl/monkeypatch/](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch)\n",
"\n",
"Another important operation encompassed in ```train()``` is setting up the training that takes into account of user-specified traning configurations (e.g. num_epochs, optimizer) through the use of ```setup_trainer()``` from [/src/axolotl/utils/trainer.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/trainer.py), which in turn relies on modules from [/src/axolotl/core/trainer_builder.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/core/trainer_builder.py).\n",
"```trainer_builder.py``` provides a list of trainer object options bespoke for the task type (Causal or Reinforcement learning ('dpo', 'ipo', 'kto') )"
]
},
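{
"cell_type": "markdown",
"metadata": {},
"source": [
"In rough pseudocode, the orchestration inside ```train()``` looks like the following (a simplified sketch of the control flow, not Axolotl's exact function signatures):\n",
"\n",
"```python\n",
"def train(cfg, cli_args, dataset_meta):\n",
"    # 1. Tokenizer first: dataset preparation and the model both depend on it\n",
"    tokenizer = load_tokenizer(cfg)\n",
"\n",
"    # 2. Base model (plus LoRA/QLoRA adapter), with attention patches applied\n",
"    model, peft_config = load_model(cfg, tokenizer)\n",
"\n",
"    # 3. Trainer built from the user-specified training configuration\n",
"    trainer = setup_trainer(\n",
"        cfg, dataset_meta.train_dataset, dataset_meta.eval_dataset, model, tokenizer\n",
"    )\n",
"\n",
"    trainer.train()\n",
"    return model, tokenizer\n",
"```"
]
},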
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Monkey patch\n",
"\n",
"The [Monkey patch directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch) is where model architecture/optimization patching scripts are stored (these are modifications that are not implemented in the official releases, hence the name monkey patch). It includes attention jacking, ReLoRA, and unsloth optimization."
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.9.6"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}

View File

@@ -9,17 +9,14 @@ strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true
liger_glu_activation: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -11,11 +11,8 @@ chat_template: gemma
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: gemma
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -4,15 +4,11 @@ tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
use_tensorboard: true
chat_template: jamba
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: jamba
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft

View File

@@ -4,7 +4,7 @@ plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
strict: false
@@ -14,10 +14,6 @@ datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./outputs/out

View File

@@ -1,95 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: llama3
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
dataset_exact_deduplication: true
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,76 +0,0 @@
base_model: meta-llama/Llama-3.2-1B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/lora-out
dataset_exact_deduplication: true
test_value: true
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,63 +0,0 @@
base_model: llava-hf/llava-1.5-7b-hf
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llava
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,93 +0,0 @@
#Note that we are switching from the regular chat template to chatml.
#If you experience problems with the special tokens, training for more epochs can help.
#After training, merge the model before inference otherwise you might
#face problems with the special tokens.
base_model: mistralai/Mistral-7B-Instruct-v0.2
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: true
strict: false
chat_template: chatml
rl: dpo
datasets:
- path: olivermolenschot/alpaca_messages_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/dpo-qlora
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true
adapter: qlora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.2
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
lora_modules_to_save:
- embed_tokens
- lm_head
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 16
num_epochs: 6
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<|im_start|>"
eos_token: "<|im_end|>"

View File

@@ -10,6 +10,7 @@ chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: phi_3
field_messages: messages
message_field_role: role
message_field_content: content

View File

@@ -1,65 +0,0 @@
base_model: mistral-community/pixtral-12b
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: pixtral
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,63 +0,0 @@
base_model: Qwen/Qwen2-VL-7B-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: qwen2_vl
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,67 +0,0 @@
base_model: Qwen/Qwen2.5-0.5B
strict: false
chat_template: qwen_25
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/dpo-out
sequence_len: 2048
sample_packing: false
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

After

Width:  |  Height:  |  Size: 11 KiB

View File

@@ -1,19 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1113 283.5">
<path fill="#141310" d="M435,234.3l-12.1-48.8h-54.4l-12.1,48.8h-24.7l48.2-185.1h31.6l47.9,185.1h-24.5ZM417.7,164.9l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
<path fill="#141310" d="M568.2,234.3l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
<path fill="#141310" d="M658.6,236.3c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM658.6,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path fill="#141310" d="M860.6,236.3c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM860.6,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path fill="#141310" d="M773.9,234c-18,0-32.6-14.6-32.6-32.6V48.8h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
<path fill="#141310" d="M1036.2,234.3V81.4c0-4.7-3.8-8.5-8.5-8.5h-16.8v-24.1h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
<path fill="#141310" d="M978.6,234.3c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3v-45.3h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
<path fill="#141310" d="M51.5,49h12.2v-20.6h-12.2c-16,0-29,13-29,29v32.8h20.6v-32.8c0-4.7,3.8-8.4,8.4-8.4Z"/>
<path fill="#141310" d="M92.8,49h12.2v-20.6h-12.2c-16,0-29,13-29,29v12.2h20.6v-12.2c0-4.7,3.8-8.4,8.4-8.4Z"/>
<path fill="#141310" d="M249.3,57.4c0-16-13-29-29-29h-12.2v20.6h12.2c4.7,0,8.4,3.8,8.4,8.4v32.8h20.6v-32.8Z"/>
<path fill="#141310" d="M187.4,90.2v-20.6h-103.1v20.6h-41.2v20.6h-20.6v41.2c0,11.4,9.2,20.6,20.6,20.6h185.5c11.4,0,20.6-9.2,20.6-20.6v-41.2h-20.6v-20.6h-41.2ZM166.8,141.7c0-5.7-4.6-10.3-10.3-10.3s-10.3,4.6-10.3,10.3v10.3h-20.6v-20.6c0-11.4,9.2-20.6,20.6-20.6s20.6,9.2,20.6,20.6v10.3ZM228.7,141.7c0-5.7-4.6-10.3-10.3-10.3s-10.3,4.6-10.3,10.3v10.3h-20.6v-20.6c0-11.4,9.2-20.6,20.6-20.6s20.6,9.2,20.6,20.6v10.3Z"/>
<path fill="#141310" d="M208,57.4c0-16-13-29-29-29h-12.2v20.6h12.2c4.7,0,8.4,3.8,8.4,8.4v12.2h20.6v-12.2Z"/>
<rect fill="#141310" x="22.5" y="234.5" width="41.2" height="20.6"/>
<rect fill="#141310" x="84.3" y="234.5" width="164.9" height="20.6"/>
<rect fill="#141310" x="208" y="193.3" width="41.2" height="20.6"/>
<rect fill="#141310" x="22.5" y="193.3" width="164.9" height="20.6"/>
</svg>

Before

Width:  |  Height:  |  Size: 3.2 KiB

View File

@@ -1,11 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1113 283.5">
<path fill="#fff" d="M462.9,234.2l-12.1-48.8h-54.4l-12.1,48.8h-24.7l48.2-185h31.6l47.9,185h-24.4ZM445.7,164.8l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
<path fill="#fff" d="M596.1,234.2l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.5-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.3,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.1,49.3,71.6h-28.5Z"/>
<path fill="#fff" d="M686.4,236.2c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.6,14.8-41.4,9.8-9.7,23.4-14.7,40.2-14.7s30.4,4.9,40.2,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.4-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM686.4,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.8v-36.7c0-10.5-2.8-18.5-8.2-23.8-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path fill="#fff" d="M888.3,236.2c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.6,14.8-41.4,9.8-9.7,23.4-14.7,40.2-14.7s30.4,4.9,40.2,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.4-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM888.3,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.8v-36.7c0-10.5-2.8-18.5-8.2-23.8-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path fill="#fff" d="M801.7,234c-18,0-32.6-14.6-32.6-32.6V48.8h24.1v152.5c0,4.7,3.8,8.5,8.5,8.5h16.7v24.1h-16.7Z"/>
<path fill="#fff" d="M1063.8,234.2V81.4c0-4.7-3.8-8.5-8.5-8.5h-16.7v-24.1h16.7c18,0,32.6,14.6,32.6,32.6v152.8h-24.1Z"/>
<path fill="#fff" d="M1006.2,234.2c-18,0-32.6-14.6-32.6-32.6v-85h-20.3v-22.1h20.3v-45.2h24.1v45.2h30.2v22.1h-30.2v85c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
<path fill="#fff" d="M160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM277.3,57.4c0-23.8-19.3-43.1-43.1-43.1h-12.2c-3.9,0-7.6,1.6-10.2,4.4-5.9-2.9-12.3-4.4-18.9-4.4h-12.2c-7.7,0-14.1,6.3-14.1,14.1v20.6c0,2.4.6,4.6,1.6,6.6h-37c1-2,1.6-4.2,1.6-6.6v-20.6c0-7.7-6.3-14.1-14.1-14.1h-12.2c-6.5,0-13,1.5-18.9,4.4-2.6-2.8-6.3-4.4-10.2-4.4h-12.2c-23.8,0-43.1,19.3-43.1,43.1v32.8c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v41.2c0,11,5.2,20.8,13.2,27.2-7.3.4-13.2,6.6-13.2,14v20.6c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v20.6c0,7.7,6.3,14.1,14.1,14.1h41.2c4.1,0,7.7-1.7,10.3-4.5,2.6,2.8,6.2,4.5,10.3,4.5h164.9c7.7,0,14.1-6.3,14.1-14.1v-20.6c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-20.6c0-7.5-5.8-13.6-13.2-14,8-6.4,13.2-16.2,13.2-27.2v-41.2c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-32.8ZM77.8,255.1h-41.2v-20.6h41.2v20.6ZM36.5,213.9v-20.6h164.9v20.6H36.5ZM263.3,255.1H98.4v-20.6h164.9v20.6ZM263.3,213.9h-41.2v-20.6h41.2v20.6ZM263.3,90.2h-20.6v20.6h20.6v41.2c0,11.4-9.2,20.6-20.6,20.6H57.2c-11.4,0-20.6-9.2-20.6-20.6v-41.2h20.6v-20.6h-20.6v-32.8c0-16,13-29,29-29h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v32.8h41.2v-20.6h-20.6v-12.2c0-16,13-29,29-29h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v12.2h103.1v-12.2c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16,0,29,13,29,29v12.2h-20.6v20.6h41.2v-32.8c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16,0,29,13,29,29v32.8ZM201.4,152h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6s-20.6,9.2-20.6,20.6v20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0
-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6Z"/>
</svg>

Before

Width:  |  Height:  |  Size: 6.6 KiB

View File

@@ -1,26 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 283.5 283.5">
<defs>
<style>
.cls-1 {
fill: #141310;
}
</style>
</defs>
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
<g>
<g id="Layer_1">
<g>
<path class="cls-1" d="M46.9,37.4h13.7V14.2h-13.7c-18,0-32.7,14.6-32.7,32.7v36.9h23.2v-36.9c0-5.2,4.2-9.5,9.5-9.5Z"/>
<path class="cls-1" d="M93.2,37.4h13.7V14.2h-13.7c-18,0-32.7,14.6-32.7,32.7v13.7h23.2v-13.7c0-5.2,4.2-9.5,9.5-9.5Z"/>
<path class="cls-1" d="M269.3,46.9c0-18-14.6-32.7-32.7-32.7h-13.7v23.2h13.7c5.2,0,9.5,4.2,9.5,9.5v36.9h23.2v-36.9Z"/>
<path class="cls-1" d="M199.7,83.8v-23.2h-116v23.2h-46.4v23.2H14.2v46.4c0,12.8,10.4,23.2,23.2,23.2h208.7c12.8,0,23.2-10.4,23.2-23.2v-46.4h-23.2v-23.2h-46.4ZM176.5,141.7c0-6.4-5.2-11.6-11.6-11.6s-11.6,5.2-11.6,11.6v11.6h-23.2v-23.2c0-12.8,10.4-23.2,23.2-23.2s23.2,10.4,23.2,23.2v11.6ZM246.1,141.7c0-6.4-5.2-11.6-11.6-11.6s-11.6,5.2-11.6,11.6v11.6h-23.2v-23.2c0-12.8,10.4-23.2,23.2-23.2s23.2,10.4,23.2,23.2v11.6Z"/>
<path class="cls-1" d="M222.9,46.9c0-18-14.6-32.7-32.7-32.7h-13.7v23.2h13.7c5.2,0,9.5,4.2,9.5,9.5v13.7h23.2v-13.7Z"/>
<rect class="cls-1" x="14.2" y="246.1" width="46.4" height="23.2"/>
<rect class="cls-1" x="83.8" y="246.1" width="185.5" height="23.2"/>
<rect class="cls-1" x="222.9" y="199.7" width="46.4" height="23.2"/>
<rect class="cls-1" x="14.2" y="199.7" width="185.5" height="23.2"/>
</g>
</g>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 1.6 KiB

View File

@@ -1,16 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 283.5 283.5">
<defs>
<style>
.cls-1 {
fill: #fff;
}
</style>
</defs>
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
<g>
<g id="Layer_1">
<path class="cls-1" d="M152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM269.3,57.3c0-23.8-19.4-43.1-43.1-43.1h-12.2c-3.9,0-7.6,1.6-10.2,4.4-5.9-2.9-12.3-4.4-18.9-4.4h-12.2c-7.8,0-14.1,6.3-14.1,14.1v20.6c0,2.4.6,4.6,1.6,6.6h-37c1-2,1.6-4.2,1.6-6.6v-20.6c0-7.8-6.3-14.1-14.1-14.1h-12.2c-6.6,0-13,1.5-18.9,4.4-2.6-2.8-6.3-4.4-10.2-4.4h-12.2c-23.8,0-43.1,19.4-43.1,43.1v32.8c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v41.3c0,11,5.2,20.9,13.2,27.2-7.4.4-13.2,6.6-13.2,14v20.6c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v20.6c0,7.8,6.3,14.1,14.1,14.1h41.3c4.1,0,7.7-1.7,10.3-4.5,2.6,2.8,6.2,4.5,10.3,4.5h165.1c7.8,0,14.1-6.3,14.1-14.1v-20.6c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-20.6c0-7.5-5.9-13.6-13.2-14,8-6.4,13.2-16.2,13.2-27.2v-41.3c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-32.8ZM69.5,255.2H28.2v-20.6h41.3v20.6ZM28.2,214v-20.6h165.1v20.6H28.2ZM255.2,255.2H90.1v-20.6h165.1v20.6ZM255.2,214h-41.3v-20.6h41.3v20.6ZM255.2,90.1h-20.6v20.6h20.6v41.3c0,11.4-9.2,20.6-20.6,20.6H48.9c-11.4,0-20.6-9.2-20.6-20.6v-41.3h20.6v-20.6h-20.6v-32.8c0-16.1,13-29.1,29.1-29.1h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v32.8h41.3v-20.6h-20.6v-12.2c0-16.1,13-29.1,29.1-29.1h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v12.2h103.2v-12.2c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16.1,0,29.1,13,29.1,29.1v12.2h-20.6v20.6h41.3v-32.8c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16.1,0,29.1,13,29.1,29.1v32.8ZM193.3,152h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6s-20.6,9.2-20.6,20.6v20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,11
0.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6Z"/>
</g>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 5.0 KiB

View File

@@ -1,17 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 765.4 212.6">
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
<g>
<g id="Layer_1">
<g>
<path d="M121.6,198.1l-12.1-48.8h-54.4l-12.1,48.8h-24.7L66.6,12.9h31.6l47.9,185.1h-24.5ZM104.4,128.6l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
<path d="M254.9,198.1l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
<path d="M345.2,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM345.2,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path d="M547.3,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM547.3,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path d="M460.6,197.8c-18,0-32.6-14.6-32.6-32.6V12.5h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
<path d="M722.8,198.1V45.2c0-4.7-3.8-8.5-8.5-8.5h-16.8V12.5h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
<path d="M665.2,198.1c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3V12.9h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
</g>
</g>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 2.1 KiB

View File

@@ -1,24 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 765.4 212.6">
<defs>
<style>
.cls-1 {
fill: #fff;
}
</style>
</defs>
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
<g>
<g id="Layer_1">
<g>
<path class="cls-1" d="M121.6,198.1l-12.1-48.8h-54.4l-12.1,48.8h-24.7L66.6,12.9h31.6l47.9,185.1h-24.5ZM104.4,128.6l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
<path class="cls-1" d="M254.9,198.1l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
<path class="cls-1" d="M345.2,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM345.2,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path class="cls-1" d="M547.3,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM547.3,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
<path class="cls-1" d="M460.6,197.8c-18,0-32.6-14.6-32.6-32.6V12.5h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
<path class="cls-1" d="M722.8,198.1V45.2c0-4.7-3.8-8.5-8.5-8.5h-16.8V12.5h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
<path class="cls-1" d="M665.2,198.1c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3V12.9h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
</g>
</g>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 2.3 KiB

View File

@@ -2,3 +2,4 @@ pre-commit
black
mypy
types-requests
tbparse

View File

@@ -1,5 +1,2 @@
pytest
pytest-xdist
pytest-retry
pytest-sugar
tbparse

View File

@@ -1,18 +1,18 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.3
transformers==4.46.0
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.1.0
datasets==3.1.0
deepspeed==0.15.4
accelerate==1.0.1
datasets==3.0.1
deepspeed==0.15.3
pydantic==2.6.3
addict
fire
PyYAML>=6.0
requests
flash-attn==2.7.0.post2
flash-attn==2.6.3
sentencepiece
wandb
einops
@@ -26,14 +26,15 @@ numpy>=1.24.4,<=2.0.1
evaluate==0.4.1
scipy
scikit-learn==1.4.2
nvidia-ml-py==12.560.30
pynvml
art
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq==0.2.7.post2
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.4.2
liger-kernel==0.3.0
mamba-ssm==1.2.0.post1
@@ -42,7 +43,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl==0.12.0
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
zstandard==0.22.0
fastcore
@@ -53,4 +54,3 @@ immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
torchao==0.5.0
schedulefree==1.3.0

View File

@@ -2,7 +2,7 @@
# Export specific ENV variables to /etc/rp_environment
echo "Exporting environment variables..."
printenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\([^=]*\)=\(.*\)$/export \1="\2"/' | grep -v 'printenv' >> /etc/rp_environment
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
echo 'source /etc/rp_environment' >> ~/.bashrc
add_keys_to_authorized() {

View File

@@ -1,28 +0,0 @@
"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys
try:
import torch
except ImportError as exc:
raise ImportError("Install torch via `pip install torch`") from exc
from packaging.version import Version as V
v = V(torch.__version__)
# no cut-cross-entropy support for torch < 2.4.0
if v < V("2.4.0"):
print("")
sys.exit(0)
cce_spec = importlib.util.find_spec("cut_cross_entropy")
cce_spec_transformers = importlib.util.find_spec("cut_cross_entropy.transformers")
UNINSTALL_PREFIX = ""
if cce_spec and not cce_spec_transformers:
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
print(
UNINSTALL_PREFIX
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
)

View File

@@ -1,36 +0,0 @@
# noqa
# pylint: skip-file
try:
import torch
except ImportError:
raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
v = V(torch.__version__)
cuda = str(torch.version.cuda)
try:
is_ampere = torch.cuda.get_device_capability()[0] >= 8
except RuntimeError:
is_ampere = False
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
raise RuntimeError(f"Torch = {v} too old!")
elif v <= V("2.1.1"):
x = "cu{}{}-torch211"
elif v <= V("2.1.2"):
x = "cu{}{}-torch212"
elif v < V("2.3.0"):
x = "cu{}{}-torch220"
elif v < V("2.4.0"):
x = "cu{}{}-torch230"
elif v < V("2.5.0"):
x = "cu{}{}-torch240"
elif v < V("2.6.0"):
x = "cu{}{}-torch250"
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(
f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"'
)

View File

@@ -39,10 +39,7 @@ def parse_requirements():
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -57,10 +54,6 @@ def parse_requirements():
if (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:
@@ -96,19 +89,22 @@ install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
version="0.5.2",
version="0.4.1",
description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"},
packages=find_packages("src"),
packages=find_packages(),
install_requires=install_requires,
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.7.0.post2",
"flash-attn==2.6.3",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed==0.15.4",
"deepspeed==0.14.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -27,17 +27,14 @@ from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
prepare_plugins,
validate_config,
)
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
@@ -100,8 +97,8 @@ def print_dep_versions():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {version[1]: <15}")
print("*" * 40)
@@ -139,7 +136,7 @@ def check_remote_config(config: Union[str, Path]):
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
)
return output_path
@@ -193,19 +190,18 @@ def do_inference(
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -215,31 +211,13 @@ def do_inference(
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
@@ -279,6 +257,13 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
default_tokens: Dict[str, str] = {}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
@@ -287,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
chat_template_str = get_chat_template(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -426,6 +411,11 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg.axolotl_config_path = config
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
@@ -439,13 +429,8 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
},
)
prepare_plugins(cfg)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)

View File

@@ -19,7 +19,7 @@ from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs)
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -23,6 +23,10 @@ from axolotl.cli import (
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
)
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -40,6 +44,23 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
return_remaining_strings=True
)
if parsed_cfg.chat_template == "chatml":
if parsed_cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_chatml_template(parsed_cfg.default_system_message)
else:
register_chatml_template()
elif parsed_cfg.chat_template == "llama3":
if parsed_cfg.default_system_message:
LOG.info(
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
)
register_llama3_template(parsed_cfg.default_system_message)
else:
register_llama3_template()
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED

View File

@@ -19,6 +19,10 @@ from axolotl.cli import (
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.integrations.base import PluginManager
from axolotl.prompt_strategies.sharegpt import (
register_chatml_template,
register_llama3_template,
)
from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
@@ -38,6 +42,21 @@ def do_train(cfg, cli_args) -> None:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.chat_template == "chatml" and cfg.default_system_message:
LOG.info(
f"ChatML set. Adding default system message: {cfg.default_system_message}"
)
register_chatml_template(cfg.default_system_message)
else:
register_chatml_template()
if cfg.chat_template == "llama3" and cfg.default_system_message:
LOG.info(
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
)
register_llama3_template(cfg.default_system_message)
else:
register_llama3_template()
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -48,7 +48,6 @@ from trl import (
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -107,22 +106,6 @@ def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
return kwargs
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs
@dataclass
class AxolotlTrainingMixins:
"""
@@ -236,14 +219,6 @@ class AxolotlTrainingMixins:
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
@@ -410,7 +385,7 @@ class SchedulerMixin(Trainer):
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
@@ -434,12 +409,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
@@ -461,75 +434,38 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
@@ -542,13 +478,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
@@ -575,16 +504,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -937,9 +856,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@@ -979,13 +895,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, **kwargs):
def _save_checkpoint(self, model, trial):
# make sure the checkpoint dir exists, since the trainer is flaky
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
return super()._save_checkpoint(model, trial)
class AxolotlMambaTrainer(AxolotlTrainer):
@@ -1063,9 +979,8 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
def create_optimizer(self):
@@ -1104,44 +1019,28 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model to the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row(
self,
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
res = super().tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
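For intuition, a toy illustration of the bos stripping above (hypothetical token ids; bos_token_id assumed to be 1):
```python
# Hypothetical ids only: the DPO trainer prepended bos (1) to the chosen ids.
chosen_input_ids = [1, 17, 23]
if chosen_input_ids[0] == 1:  # processing_class.bos_token_id in the real code
    chosen_input_ids = chosen_input_ids[1:]
assert chosen_input_ids == [17, 23]
```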
def training_step(
@@ -1248,12 +1147,6 @@ class TrainerBuilderBase(abc.ABC):
def get_callbacks(self) -> List[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -1280,23 +1173,11 @@ class TrainerBuilderBase(abc.ABC):
return callbacks
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually because these need access to the trainer
"""
callbacks = []
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
@@ -1379,8 +1260,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
def _get_trainer_cls(self):
@@ -1498,15 +1377,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["eval_strategy"] = "no"
training_arguments_kwargs["evaluation_strategy"] = "no"
elif self.cfg.eval_steps:
training_arguments_kwargs["eval_strategy"] = "steps"
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.eval_strategy:
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
elif self.cfg.evaluation_strategy:
training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["eval_strategy"] = "epoch"
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if self.cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
@@ -1644,9 +1525,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
@@ -1717,8 +1595,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template,
tokenizer=self.tokenizer,
self.cfg.chat_template
)
if self.cfg.rl == "orpo":
@@ -1734,13 +1611,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
trainer_kwargs["max_length"] = self.cfg.sequence_len
# pylint: disable=duplicate-code
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
@@ -1831,10 +1706,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
trainer_kwargs["tokenizer"] = self.tokenizer
if (trainer_cls is not AxolotlRewardTrainer) and self.cfg.datasets is not None:
trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
@@ -1897,7 +1768,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
kwargs["chat_template_type"] = self.cfg.chat_template
else:
collator = DataCollatorForSeq2Seq
@@ -1920,7 +1790,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
return callbacks
def build_training_arguments(self, total_num_steps):
@@ -1948,10 +1818,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.eval_dataset:
training_args_kwargs["eval_strategy"] = "steps"
training_args_kwargs["evaluation_strategy"] = "steps"
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
else:
training_args_kwargs["eval_strategy"] = "no"
training_args_kwargs["evaluation_strategy"] = "no"
if self.cfg.bf16 or self.cfg.bfloat16:
training_args_kwargs["bf16"] = True
@@ -2006,18 +1876,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does an odd mapping of alpha to beta in order to reuse the beta parameter
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_cls = None
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
@@ -2026,13 +1895,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo":
if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto":
if self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
@@ -2047,17 +1916,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_completion_length"] = None
training_args_kwargs["max_prompt_length"] = self.cfg.sequence_len
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -2078,6 +1936,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset:
@@ -2091,6 +1950,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
@@ -2109,10 +1974,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
dpo_trainer_kwargs["tokenizer"] = self.tokenizer
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
dpo_trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
]
dpo_trainer = trainer_cls(
*trainer_cls_args,
args=training_args,
@@ -2138,11 +1999,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
def get_callbacks(self):
callbacks = super().get_callbacks()
callbacks = []
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
callbacks = []
return callbacks
def build(self, total_num_steps):

View File

@@ -40,7 +40,7 @@ class TRLPPOTrainer(PPOTrainer):
query_tensors,
return_prompt=False,
generate_ref_response=True,
**generation_kwargs,
**generation_kwargs
)
batch["response"] = self.tokenizer.batch_decode(response_tensors)
batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)

View File

@@ -18,10 +18,9 @@ Plugins can be used to integrate third-party models, modify the training process
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import collections
import importlib
import logging
from typing import OrderedDict
from typing import List
class BasePlugin:
@@ -48,7 +47,7 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg): # pylint: disable=unused-argument
def register(self, cfg):
"""
Registers the plugin with the given configuration.
@@ -64,7 +63,7 @@ class BasePlugin:
Returns a pydantic model for the plugin's input arguments.
"""
def pre_model_load(self, cfg): # pylint: disable=unused-argument
def pre_model_load(self, cfg):
"""
Performs actions before the model is loaded.
@@ -75,7 +74,7 @@ class BasePlugin:
None
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
def post_model_load(self, cfg, model):
"""
Performs actions after the model is loaded.
@@ -87,7 +86,7 @@ class BasePlugin:
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
def pre_lora_load(self, cfg, model):
"""
Performs actions before LoRA weights are loaded.
@@ -99,7 +98,7 @@ class BasePlugin:
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
def post_lora_load(self, cfg, model):
"""
Performs actions after LoRA weights are loaded.
@@ -111,7 +110,7 @@ class BasePlugin:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
def create_optimizer(self, cfg, trainer):
"""
Creates and returns an optimizer for training.
@@ -123,9 +122,7 @@ class BasePlugin:
object: The created optimizer.
"""
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
def create_lr_scheduler(self, cfg, trainer, optimizer):
"""
Creates and returns a learning rate scheduler.
@@ -138,9 +135,9 @@ class BasePlugin:
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
def add_callbacks_pre_trainer(self, cfg, model):
"""
setup callbacks before creating the trainer.
Adds callbacks to the trainer before training.
Parameters:
cfg (dict): The configuration for the plugin.
@@ -149,25 +146,20 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
Adds callbacks to the trainer after training.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument
def post_train(self, cfg, model):
"""
Performs actions after training is complete.
@@ -179,7 +171,7 @@ class BasePlugin:
None
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
def post_train_unload(self, cfg):
"""
Performs actions after training is complete and the model is unloaded.
@@ -235,7 +227,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
plugins: List[BasePlugin] = []
_instance = None
@@ -245,7 +237,7 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins = collections.OrderedDict()
cls._instance.plugins: List[BasePlugin] = []
return cls._instance
@staticmethod
@@ -273,7 +265,7 @@ class PluginManager:
"""
try:
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
self.plugins.append(plugin)
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -285,7 +277,7 @@ class PluginManager:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.
"""
input_args = []
for plugin in self.plugins.values():
for plugin in self.plugins:
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
@@ -301,7 +293,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model):
@@ -315,7 +307,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
@@ -329,7 +321,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
@@ -343,7 +335,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer):
@@ -357,7 +349,7 @@ class PluginManager:
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
@@ -375,7 +367,7 @@ class PluginManager:
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
@@ -393,10 +385,8 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins.values():
plugin_callbacks = plugin.add_callbacks_pre_trainer(cfg, model)
if plugin_callbacks: # if the plugin returned a list of callbacks
callbacks.extend(plugin_callbacks)
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks
def add_callbacks_post_trainer(self, cfg, trainer):
@@ -411,10 +401,8 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins.values():
plugin_callbacks = plugin.add_callbacks_post_trainer(cfg, trainer)
if plugin_callbacks:
callbacks.extend(plugin_callbacks)
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
def post_train_unload(self, cfg):
@@ -428,5 +416,5 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins.values():
for plugin in self.plugins:
plugin.post_train_unload(cfg)

View File

@@ -1,325 +0,0 @@
Acknowledgements
Portions of this Cut Cross Entropy Software may utilize the following copyrighted
material, the use of which is hereby acknowledged.
------
PyTorch
From PyTorch:
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
From Caffe2:
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.
All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.
All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain
All contributions by Cruise LLC:
Copyright (c) 2022 Cruise LLC.
All rights reserved.
All contributions by Arm:
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
Triton
/*
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
Transformers
Copyright 2018- The Hugging Face team. All rights reserved.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,47 +0,0 @@
Copyright (C) 2024 Apple Inc. All Rights Reserved.
IMPORTANT: This Apple software is supplied to you by Apple
Inc. ("Apple") in consideration of your agreement to the following
terms, and your use, installation, modification or redistribution of
this Apple software constitutes acceptance of these terms. If you do
not agree with these terms, please do not use, install, modify or
redistribute this Apple software.
In consideration of your agreement to abide by the following terms, and
subject to these terms, Apple grants you a personal, non-exclusive
license, under Apple's copyrights in this original Apple software (the
"Apple Software"), to use, reproduce, modify and redistribute the Apple
Software, with or without modifications, in source and/or binary forms;
provided that if you redistribute the Apple Software in its entirety and
without modifications, you must retain this notice and the following
text and disclaimers in all such redistributions of the Apple Software.
Neither the name, trademarks, service marks or logos of Apple Inc. may
be used to endorse or promote products derived from the Apple Software
without specific prior written permission from Apple. Except as
expressly stated in this notice, no other rights or licenses, express or
implied, are granted by Apple herein, including but not limited to any
patent rights that may be infringed by your derivative works or by other
works in which the Apple Software may be incorporated.
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
-------------------------------------------------------------------------------
SOFTWARE DISTRIBUTED WITH CUT CROSS ENTROPY:
The Cut Cross Entropy software includes a number of subcomponents with separate
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.md.
-------------------------------------------------------------------------------

View File

@@ -1,10 +0,0 @@
# Cut Cross Entropy
### Usage
```yaml
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
cut_cross_entropy: true
```
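Per the plugin's requirement checks further down in this diff, Cut Cross Entropy needs PyTorch >= 2.4.0 and half-precision training, so the config must also set `bf16: true` or `fp16: true`.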

View File

@@ -1,83 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for the Plugin for Cut Cross Entropy integration with Axolotl.
Cut Cross Entropy is an optimized implementation of cross entropy loss
from Apple's ML team.
"""
import importlib
import logging
import torch
from axolotl.integrations.base import BasePlugin
from axolotl.utils import get_pytorch_version
from ...utils.distributed import zero_only
from .args import CutCrossEntropyArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
_CCE_INSTALL_MESSAGE = (
"Please install cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
)
class CutCrossEntropyPlugin(BasePlugin):
"""
Plugin for Cut Cross Entropy integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.cut_cross_entropy.CutCrossEntropyArgs"
def _check_requirements(self):
"""Check if all requirements are met."""
# Check PyTorch version
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
raise ImportError(
"Cut Cross Entropy requires PyTorch >= 2.4.0. "
f"Current version: {torch.__version__}"
)
# Check if cut_cross_entropy is installed
cce_spec = importlib.util.find_spec("cut_cross_entropy")
if cce_spec is None:
raise ImportError(_CCE_INSTALL_MESSAGE)
cce_spec_transformers = importlib.util.find_spec(
"cut_cross_entropy.transformers"
)
if cce_spec_transformers is None:
raise ImportError(_CCE_INSTALL_MESSAGE)
def pre_model_load(self, cfg):
"""Apply cut cross entropy before model loading if enabled."""
if cfg.cut_cross_entropy:
self._check_requirements()
from cut_cross_entropy.transformers import cce_patch
with zero_only():
LOG.info(
f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
)
# The patch checks model_type internally
cce_patch(cfg.model_config_type)

View File

@@ -1,42 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for handling Cut Cross Entropy input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy.args")
class CutCrossEntropyArgs(BaseModel):
"""
Input args for Cut Cross Entropy.
"""
cut_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_dtype_is_half(cls, data):
if not (data.get("bf16") or data.get("fp16")):
raise ValueError(
"Cut Cross Entropy requires fp16/bf16 training for backward pass. "
"Please set `bf16` or `fp16` to `True`."
)
return data

View File

@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -1,13 +0,0 @@
# Grokfast Optimizer
See https://github.com/ironjr/grokfast
### Usage
```yaml
plugins:
- axolotl.integrations.grokfast.GrokfastPlugin
grokfast_alpha: 0.98
grokfast_lamb: 2.0
```
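Here `grokfast_alpha` is the EMA decay of the running gradient average and `grokfast_lamb` scales the filtered gradients added back onto the raw gradients (see `gradfilter_ema` below); 0.98 and 2.0 are also the plugin's defaults.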

View File

@@ -1,50 +0,0 @@
"""
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback
from ..base import BasePlugin
from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401
from .optimizer import gradfilter_ema
LOG = logging.getLogger("axolotl.integrations.grokfast")
class GrokfastCallbackHandler(TrainerCallback):
"""
Transformer trainer callbacks for Grokfast
"""
def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs):
super().__init__(*args_, **kwargs)
self.grads = None
self.alpha = alpha
self.lamb = lamb
def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument
self.grads = None
def on_pre_optimizer_step(
self, args_, state, control, **kwargs
): # pylint: disable=unused-argument
model = kwargs.pop("model")
self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb)
return control
class GrokfastPlugin(BasePlugin):
"""
Plugin for Grokfast optimizer integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.grokfast.GrokfastArgs"
def add_callbacks_post_trainer(self, cfg, trainer):
LOG.info("Adding Grokfast callback to the trainer")
callback = GrokfastCallbackHandler(
alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb
)
return [callback]

View File

@@ -1,15 +0,0 @@
"""
config args for grokfast plugin
"""
from typing import Optional
from pydantic import BaseModel
class GrokfastArgs(BaseModel):
"""
Input args for Grokfast optimizer.
"""
grokfast_alpha: Optional[float] = 0.98
grokfast_lamb: Optional[float] = 2.0

View File

@@ -1,63 +0,0 @@
# Copyright: MIT License (c) 2024 Jaerin Lee, Bong Gyun Kang, Kihoon Kim, Kyoung Mu Lee
# Reference: https://github.com/ironjr/grokfast
# pylint: skip-file
from collections import deque
from typing import Dict, Literal, Optional
import torch
import torch.nn as nn
def gradfilter_ma(
m: nn.Module,
grads: Optional[Dict[str, deque]] = None,
window_size: int = 100,
lamb: float = 5.0,
filter_type: Literal["mean", "sum"] = "mean",
warmup: bool = True,
trigger: bool = False, # For ablation study.
) -> Dict[str, deque]:
if grads is None:
grads = {
n: deque(maxlen=window_size)
for n, p in m.named_parameters()
if p.requires_grad and p.grad is not None
}
for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
grads[n].append(p.grad.data.detach()) # .cpu())
# Modify the gradients.
if not warmup or len(grads[n]) == window_size and not trigger:
if filter_type == "mean":
avg = sum(grads[n]) / len(grads[n])
elif filter_type == "sum":
avg = sum(grads[n])
else:
raise ValueError(f"Unrecognized filter_type {filter_type}")
p.grad.data = p.grad.data + avg * lamb
return grads
def gradfilter_ema(
m: nn.Module,
grads: Optional[Dict[str, torch.Tensor]] = None,
alpha: float = 0.98,
lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
if grads is None:
grads = {
n: p.grad.data.detach()
for n, p in m.named_parameters()
if p.requires_grad and p.grad is not None
}
for n, p in m.named_parameters():
if p.requires_grad and p.grad is not None:
grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
p.grad.data = p.grad.data + grads[n] * lamb
return grads
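For context, a minimal sketch of how `gradfilter_ema` slots into a plain training loop (hypothetical model and optimizer; the Grokfast callback above does the equivalent via `on_pre_optimizer_step`):
```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical toy model
optimizer = torch.optim.AdamW(model.parameters())
grads = None  # EMA state carried across steps

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).square().mean()
    loss.backward()
    # filter after backward and before the optimizer step
    grads = gradfilter_ema(model, grads, alpha=0.98, lamb=2.0)
    optimizer.step()
```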

View File

@@ -18,24 +18,20 @@ Module for the Plugin for LIGER integration with Axolotl.
Liger Kernel is the collection of Triton-native kernels for LLM Training.
It is designed to be performant, correct, and lightweight.
"""
import inspect
import logging
import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin):
"""
@@ -46,31 +42,59 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
kwargs = {}
if "rope" in liger_fn_sig.parameters:
kwargs["rope"] = cfg.liger_rope
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[
"fused_linear_cross_entropy"
] = cfg.liger_fused_linear_cross_entropy
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:
kwargs["layer_norm"] = cfg.liger_layer_norm
if "geglu" in liger_fn_sig.parameters:
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama
if cfg.liger_rope:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
elif cfg.liger_fused_linear_cross_entropy:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
apply_liger_fn(**kwargs)
if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
@@ -80,14 +104,30 @@ class LigerPlugin(BasePlugin):
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
if cfg.liger_swiglu:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
@@ -106,11 +146,44 @@ class LigerPlugin(BasePlugin):
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_glu_activation:
if cfg.liger_swiglu:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
# We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
# nn.CrossEntropyLoss in the forward method.
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward

View File

@@ -15,12 +15,9 @@
"""
Module for handling LIGER input arguments.
"""
import logging
from typing import Optional
from pydantic import BaseModel, model_validator
LOG = logging.getLogger("axolotl.integrations.liger.args")
from pydantic import BaseModel
class LigerArgs(BaseModel):
@@ -30,24 +27,6 @@ class LigerArgs(BaseModel):
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_layer_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_glu_activation: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def check_deprecated_swiglu(cls, data):
if data.get("liger_swiglu") is not None:
if data.get("liger_glu_activation") is not None:
raise ValueError(
"You cannot have both `liger_swiglu` and `liger_glu_activation` set."
)
LOG.warning(
"The 'liger_swiglu' argument is deprecated and will be removed in a future release. "
"Please use 'liger_glu_activation' instead."
)
data["liger_glu_activation"] = data.pop("liger_swiglu")
return data
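Taken together, the non-deprecated args above map to a config block along these lines (a sketch; the option names come from `LigerArgs`, and the plugin path follows the pattern of the other integration READMEs in this diff):
```yaml
plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true  # replaces the deprecated liger_swiglu
liger_fused_linear_cross_entropy: true
```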

View File

@@ -0,0 +1,231 @@
"""
monkeypatch to add a get_turns method
"""
import logging
from typing import Generator, Tuple
from fastchat.conversation import SeparatorStyle
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
def get_prompt(self) -> str:
ret = ""
for role, msg in self.get_turns():
ret += role + msg
return ret
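For intuition, under the CHATML branch below a short conversation yields (role, message) pairs roughly like the following, which `get_prompt` simply concatenates (hypothetical role and separator strings):
```python
# Hypothetical CHATML-style turns as produced by get_turns:
turns = [
    ("", "<|im_start|>system\nYou are helpful<|im_end|>\n"),
    ("<|im_start|>user\n", "Hi<|im_end|>\n"),
]
prompt = "".join(role + msg for role, msg in turns)
```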
def get_turns( # pylint: disable=too-many-return-statements
self,
) -> Generator[Tuple[str, str], None, None]:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message + seps[i % 2]
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ": ", "" # must be end with a space
return
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
yield "", "" if system_prompt == "" else system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role, message + self.sep
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role, message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.RWKV:
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message.replace("\r\n", "\n").replace(
"\n\n", "\n"
) + "\n\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
if self.system_message:
if self.messages:
# For llama, the system message is incorporated into the first human instruction
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
if (i % 2 == 0 and not self.system_message) or (
i % 2 != 0 and self.system_message
):
role = "<s> " + role
yield role + " ", message
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
contains_sys_msg = False
if self.system_message:
contains_sys_msg = True
if self.messages:
# There is no clear guidance on how to handle system messages in Mistral, so we just prepend the system message to the first human instruction, separated by a newline
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt = self.system_template.format(
system_message=" " + self.system_message
)
system_prompt += first_msg
self.messages.pop(0)
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message and i == 0 and not contains_sys_msg:
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
elif message:
yield role + " ", message
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.LLAMA3:
if self.system_message:
# For llama3, the system message is NOT incorporated into the first human instruction
# All messages follow <|start_header_id|>' + role + '<|end_header_id|>\n\n'+ message + '<|eot_id|>
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", f"{message.strip()}<|eot_id|>"
else:
yield f"<|start_header_id|>{role}<|end_header_id|>\n\n", ""
return
if self.sep_style == SeparatorStyle.GEMMA:
if self.system_message:
raise ValueError("Gemma chat template does not support system messages")
for i, (role, message) in enumerate(self.messages):
prefix = "<bos>" if i == 0 else ""
message_str = message if message else ""
yield prefix + "<start_of_turn>" + role + "\n", message_str + "<end_of_turn>\n"
return
if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if system_prompt:
yield "", system_prompt + self.sep
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
yield f"{role}", f"{message}{self.sep}"
else:
yield f"{role}", ""
return
if self.sep_style == SeparatorStyle.CHATML:
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep + "\n"
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.CHATGLM3:
if self.system_message:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + "\n", " " + message
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
prefix = "<s>" if i % 2 == 0 else ""
if message:
yield prefix + role + ":", message + seps[i % 2] + "\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
suffix = "\n\n" if i % 2 == 1 else ""
yield role + ":\n", message + seps[i % 2] + suffix
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.PHOENIX:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + ": ", "<s>" + message + "</s>"
else:
yield role + ": " + "<s>", ""
return
if self.sep_style == SeparatorStyle.ROBIN:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ":\n", message + self.sep
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.FALCON_CHAT:
if self.system_message:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def add_get_turns_to_conversation():
import fastchat.conversation
fastchat.conversation.Conversation.get_turns = get_turns
fastchat.conversation.Conversation.get_prompt = get_prompt
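
A standalone sketch (no fastchat dependency; names here are illustrative) of how the CHATML branch above yields (role_prefix, message) pairs that get_prompt then concatenates:

def chatml_turns(system_prompt, messages, sep="<|im_end|>"):
    # mirrors the SeparatorStyle.CHATML branch above
    yield "", "" if system_prompt == "" else system_prompt + sep + "\n"
    for role, message in messages:
        if message:
            yield role + "\n", message + sep + "\n"
        else:
            yield role + "\n", ""

turns = chatml_turns(
    "<|im_start|>system\nYou are a helpful assistant.",
    [("<|im_start|>user", "Hi!"), ("<|im_start|>assistant", "")],
)
print("".join(role + message for role, message in turns))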

View File

@@ -4,6 +4,7 @@
import logging
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
@@ -93,32 +94,13 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv)
def patch_fa_llama_cross_entropy():
LOG.info(
"patching transformers.loss.loss_utils.fixed_cross_entropy with flash_attn.ops.triton.cross_entropy"
)
from flash_attn.ops.triton.cross_entropy import (
cross_entropy_loss as flash_attn_cross_entropy_loss,
)
def patch_llama_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
def fa2_fixed_cross_entropy(
source,
target,
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
**kwargs,
): # pylint: disable=unused-argument
reduction = "sum" if num_items_in_batch is not None else "mean"
loss, _ = flash_attn_cross_entropy_loss(
source, target, ignore_index=ignore_index
)
if reduction == "sum":
loss = loss.sum() / num_items_in_batch
else:
loss = loss.sum() / (target != ignore_index).sum()
return loss
transformers.loss.loss_utils.fixed_cross_entropy = fa2_fixed_cross_entropy
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
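
The reduction logic in fa2_fixed_cross_entropy can be sanity-checked with plain torch.nn.functional.cross_entropy as a stand-in for the Triton kernel; a minimal sketch, assuming the same ignore_index semantics:

import torch
import torch.nn.functional as F

def fixed_ce(source, target, num_items_in_batch=None, ignore_index=-100):
    # per-token sum, then normalize the same way as the patched loss above
    loss = F.cross_entropy(source, target, ignore_index=ignore_index, reduction="sum")
    if num_items_in_batch is not None:
        return loss / num_items_in_batch           # "sum" reduction path
    return loss / (target != ignore_index).sum()   # mean over unmasked tokens

logits, labels = torch.randn(8, 32), torch.randint(0, 32, (8,))
print(fixed_ce(logits, labels).item())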
def patch_llama_rms_norm():
@@ -165,7 +147,7 @@ def replace_llama_attn_with_flash_attn(
# skip only if explicitly disabled
if cross_entropy:
patch_fa_llama_cross_entropy()
patch_llama_cross_entropy()
# skip only if explicitly disabled
if rms_norm:

View File

@@ -1,5 +1,4 @@
"""multipack patching for v2 of sample packing"""
import importlib
import transformers
@@ -28,28 +27,74 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
# def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
if has_remote_code:
patch_remote(model_name)
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
# elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
elif hasattr(transformers, "modeling_flash_attention_utils"):
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
if not has_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return
# retain for legacy
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "llama":
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mistral":
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2_moe":
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
def patch_remote(model_name):
def patch_remote(model_name, config_name, modeling_name):
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_* to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
parts = model_config.__class__.__module__.split(".")
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
module_name = ".".join(parts)
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
modeling_arch = importlib.import_module(module_name)
if hasattr(modeling_arch, "_get_unpad_data"):
modeling_arch._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
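
A small sketch of the module-name derivation above, using a hypothetical remote-code module path:

parts = "transformers_modules.org.model.configuration_foo".split(".")
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
print(".".join(parts))  # transformers_modules.org.model.modeling_foo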

View File

@@ -46,10 +46,9 @@ def reset_optimizer(
*,
reset_params: List[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: List[str],
optimizer_magnitude_pruning: float = 0.9,
prune_ratio: float = 0.9,
):
# pylint:disable=unused-argument
pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
pruning_fn = partial(magnitude_pruning_, prune_ratio=prune_ratio)
n_zeros = 0
n_total = 0
@@ -57,22 +56,16 @@ def reset_optimizer(
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state
for group in optimizer.param_groups:
for param in group["params"]:
state = optimizer_state[param]
for key, value in state.items():
if key not in optimizer_state_keys:
continue
if torch.is_tensor(value):
try:
pruning_fn(value)
n_total += value.numel()
n_zeros += torch.sum(value == 0).item()
except RuntimeError as exc:
if "quantile() input tensor is too large" in str(exc):
pass
else:
raise exc
for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRO optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()
_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
@@ -136,9 +129,6 @@ class ReLoRACallback(TrainerCallback):
if "adam" in args.optim.lower():
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
if "8bit" in args.optim.lower():
optimizer_state_keys.append("state1")
optimizer_state_keys.append("state2")
else:
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
@@ -170,7 +160,7 @@ class ReLoRACallback(TrainerCallback):
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
optimizer_magnitude_pruning=args.relora_prune_ratio,
prune_ratio=args.relora_prune_ratio,
)
if self.quantized:

View File

@@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
for module in layer_modules
)
mlp_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
getattr(module, "lora_magnitude_vector", None) is None
for module in layer_modules
)
@@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
qkv_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
getattr(module, "lora_magnitude_vector", None) is None
for module in layer_modules
)
@@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
o_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
getattr(module, "lora_magnitude_vector", None) is None
for module in layer_modules
)
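
Why the length-based check replaces the None check: newer PEFT versions appear to initialize lora_magnitude_vector as an empty container on non-DoRA layers, so mere presence is misleading. A sketch with a stand-in class:

class FakeLoraLayer:
    lora_magnitude_vector = {}  # present but empty => not DoRA

module = FakeLoraLayer()
print(getattr(module, "lora_magnitude_vector", None) is None)        # False (misleading)
print(len(getattr(module, "lora_magnitude_vector", []) or []) == 0)  # True (correct)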

View File

@@ -0,0 +1,33 @@
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = InstructShareGPTPromptTokenizingStrategy(
# pylint: disable=duplicate-code
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
def get_conversation_thread(self, prompt):
return [
{"from": "human", "value": prompt["instruction"]},
{"from": "gpt", "value": prompt["output"]},
]

View File

@@ -29,7 +29,7 @@ from dataclasses import dataclass, field
from typing import Generator, List, Sequence
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import ALTERNATING_ASSERTION_FAILED_ROLE, IGNORE_TOKEN_ID
from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE
@dataclass
@@ -75,7 +75,7 @@ class Llama2ChatConversation:
class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for Llama2 prompts.
Tokenizing strategy for ShareGPT prompts.
adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
"""
@@ -191,7 +191,7 @@ class Llama2ChatPrompter: # pylint: disable=too-few-public-methods
conv.messages = [] # pylint: disable=R0801
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], ALTERNATING_ASSERTION_FAILED_ROLE
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
if sentence["value"]:
conv.append_message(role, sentence["value"])
yield conv

View File

@@ -0,0 +1,223 @@
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
import logging
from typing import Any, Dict, Optional, Type
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
from axolotl.utils.tokenization import (
chatml_to_conversation,
merge_consecutive_messages,
)
LOG = logging.getLogger("axolotl")
def register_chatml_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
register_conv_template(
Conversation(
name="chatml",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)
register_conv_template(
Conversation(
name="chatml_glaive",
system_template="<|im_start|>system\n{system_message}",
system_message=system_message,
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
)
def register_llama3_template(system_message=None):
system_message = system_message or "You are a helpful assistant."
register_conv_template(
Conversation(
name="llama3",
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
system_message=system_message,
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA3,
sep="",
stop_str="<|eot_id|>",
stop_token_ids=[128001, 128009],
)
)
def build_loader(
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
prompter_cls: Type["ShareGPTPrompterV2"],
default_conversation: Optional[str] = None,
):
def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
LOG.warning(
"sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template",
)
conversation = (
ds_cfg["conversation"]
if ds_cfg and "conversation" in ds_cfg
else default_conversation
)
field_human = (
ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
)
field_model = (
ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
)
roles = ds_cfg["roles"].to_dict() if ds_cfg and "roles" in ds_cfg else None
strategy = tokenization_strategy_cls(
prompter_cls(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
roles=roles,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg and hasattr(strategy, "strict"):
strategy.strict = ds_cfg["strict"]
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
return _load
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
_strict = False
_messages = "conversations"
@property
def strict(self):
return self._strict
@strict.setter
def strict(self, strict):
self._strict = strict
@property
def messages(self):
return self._messages
@messages.setter
def messages(self, messages):
self._messages = messages
def get_conversation_thread(self, prompt):
conversations = prompt[self.messages]
if self.strict:
return conversations
role_key = "from"
if "role" in conversations[0].keys():
role_key = "role"
value_key = "value"
if "text" in conversations[0].keys():
value_key = "text"
elif "content" in conversations[0].keys():
value_key = "content"
# remap roles - allow for assistant turns
role_map = {
"user": "human",
"human": "human",
"assistant": "gpt",
"gpt": "gpt",
"system": "system",
}
turns = [
{
"from": (
role_map[t[role_key]] if t[role_key] in role_map else t[role_key]
),
"value": t[value_key],
"weight": 1
if "weight" not in t or t["weight"] is None
else t["weight"],
}
for t in conversations
]
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(
SimpleShareGPTPromptTokenizingStrategy
):
"""
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# rename the role key to "from"; role names and values pass through unchanged
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
return turns
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps oasst data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
]
return turns
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps ultrachat data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversations = prompt["messages"]
role_map = {"user": "human", "assistant": "gpt"}
turns = [
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
]
return turns
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
"""
sharegpt strategy that remaps glaive data to sharegpt format
"""
def get_conversation_thread(self, prompt):
conversation = chatml_to_conversation(prompt)
conversation = merge_consecutive_messages(conversation)
return conversation
load = build_loader(SimpleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_role = build_loader(SimpleRoleShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_ultrachat = build_loader(
UltrachatShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2
)
load_guanaco = build_loader(GuanacoShareGPTPromptTokenizingStrategy, ShareGPTPrompterV2)
load_glaive = build_loader(
GlaiveShareGPTPromptTokenizingStrategy,
ShareGPTPrompterV2,
default_conversation="chatml_glaive",
)
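
A sketch of the remapping get_conversation_thread performs on a role/content style row (weights default to 1 when absent or None):

conversations = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello", "weight": 0},
]
role_map = {"user": "human", "human": "human", "assistant": "gpt",
            "gpt": "gpt", "system": "system"}
turns = [
    {
        "from": role_map.get(t["role"], t["role"]),
        "value": t["content"],
        "weight": 1 if t.get("weight") is None else t["weight"],
    }
    for t in conversations
]
print(turns)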

View File

@@ -0,0 +1,28 @@
"""Module for Jokes prompts using sharegpt style """
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
def load(tokenizer, cfg):
return SimpleJokesShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class SimpleJokesShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
Tokenization strategy for asking the bot to tell a joke and then explain why it's funny
"""
# title, text, explanation
def get_conversation_thread(self, prompt):
title = "" if not prompt["title"] else prompt["title"] + " "
return [
{"from": "human", "value": "Tell me a joke."},
{"from": "gpt", "value": title + prompt["text"]},
{"from": "human", "value": "Why is that joke funny?"},
{"from": "gpt", "value": prompt["explanation"]},
]

View File

@@ -1,12 +1,17 @@
"""Module containing PromptTokenizingStrategy and Prompter classes"""
import abc
import copy
import logging
from typing import Dict, List, Tuple, Union
from fastchat.conversation import Conversation
from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompters import Prompter
from axolotl.monkeypatch.fastchat_conversation_turns import (
add_get_turns_to_conversation,
)
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
LOG = logging.getLogger("axolotl")
@@ -16,6 +21,8 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
add_get_turns_to_conversation()
class InvalidDataException(Exception):
"""
@@ -324,6 +331,154 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
)
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for ShareGPT prompts.
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation.
result, current_len = tokenize_prompt_default()
conversation: Conversation = (
self.prompter._conversation.copy() # pylint: disable=protected-access
)
input_roles = {conversation.roles[0]}
output_roles = {conversation.roles[1]}
if len(conversation.roles) == 3:
tool_role_label = conversation.roles[2]
input_roles.add(tool_role_label)
# Add roles from the config
if self.prompter.roles:
if "input" in self.prompter.roles and self.prompter.roles["input"]:
for role in self.prompter.roles["input"]:
input_roles.add(role)
if "output" in self.prompter.roles and self.prompter.roles["output"]:
for role in self.prompter.roles["output"]:
output_roles.add(role)
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
conversation.name == "vicuna_v1.1"
and "roles" in prompt
and len(prompt["roles"]) >= 2
):
role_remap = [
{"from": conversation.roles[0], "to": prompt["roles"][0]},
{"from": conversation.roles[1], "to": prompt["roles"][1]},
]
try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if not isinstance(part, tuple):
LOG.warning(f"expected tuple, got {part}")
continue
if len(part) <= 2:
role, content = part
weight = 1
else:
role, content, weight = part
# Uses "in" because role contains extra characters
input_turn = any(r.lower() in role.lower() for r in input_roles)
output_turn = any(r.lower() in role.lower() for r in output_roles)
empty_role = role.strip() == ""
if not any([input_turn, output_turn, empty_role]):
LOG.warning(f"unhandled role: {role}")
continue
if input_turn:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else role
)
turn = role + content
# this is still the user query; warn if it has no content
if not content.strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif output_turn:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else role
)
turn = role + content
# this should be the assistant response, should end with an eos token
if not content.strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
add_eos_token = not (
conversation.name == "chatml"
and conversation.sep == self.tokenizer.eos_token
)
res = self._tokenize(
turn,
add_eos_token=add_eos_token,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
labels = copy.deepcopy(res["input_ids"])
if not self.train_on_inputs:
# mask out role tokens from the labels
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
if weight == 0:
# everything from this is masked out from the labels
# (the role is masked out too because it makes no sense if the content is masked out)
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif empty_role:
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
if self.train_on_inputs and weight == 1:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
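
A minimal illustration of the label masking above: role tokens, and entire weight-0 turns, are set to IGNORE_TOKEN_ID so they contribute nothing to the loss:

IGNORE_TOKEN_ID = -100
input_ids = [101, 102, 5, 6, 7]  # first two tokens are the role prefix
len_role = 2
labels = list(input_ids)
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
print(labels)  # [-100, -100, 5, 6, 7]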
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
"""
Returns the default values for the tokenize prompt function

View File

@@ -5,6 +5,7 @@ from enum import Enum
from typing import Generator, Optional, Union
from colorama import Fore
from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
@@ -261,10 +262,166 @@ class ReflectAlpacaPrompter(Prompter):
)
ALTERNATING_ASSERTION_FAILED_ROLE = (
SHAREGPT_ASSERTION_FAILED_ROLE = (
"Role did not alternate between turns (gpt and human). Please check your data."
)
CONVERSATION_ROLE_FORMAT = {
"chatml": "<|im_start|>{ROLE}",
"zephyr": "<|{ROLE}|>",
"vicuna_v1.1": "{ROLE}",
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
}
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
"""
A prompter that generates prompts for the ShareGPT
"""
role_key_human = "human"
role_key_model = "gpt"
# Optional, only used for tool usage datasets.
role_key_tool: Optional[str] = None
# Optional, role input/output mapping
roles: Optional[dict] = None
def __init__(
self,
prompt_style=None, # pylint: disable=unused-argument
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
if conversation:
if isinstance(conversation, Conversation):
self._conversation = conversation
else:
self._conversation = get_conv_template(conversation)
else:
self._conversation = get_conv_template("vicuna_v1.1")
if role_key_human:
self.role_key_human = role_key_human
if role_key_model:
self.role_key_model = role_key_model
if role_key_tool:
self.role_key_tool = role_key_tool
if roles:
self.roles = roles
def _build_result(self, source):
if len(source) < 2:
# If there isn't a back-and-forth conversation, ignore it
# (this also happens when data splitting leaves empty conversations)
raise IndexError(
f"A conversation entry has fewer than 2 messages:\n{source}"
)
conv = self._conversation.copy()
original_source = source.copy()
# Add the conversation system prompt if provided, otherwise use the default one
if source[0]["from"] == "system":
conv.set_system_message(source[0]["value"])
source.pop(0)
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
if self.role_key_tool:
roles[self.role_key_tool] = conv.roles[2]
try:
# Apply prompt templates
if source[0]["from"] not in roles:
# Skip the first one if it is not from human
source = source[1:]
except IndexError as err:
# sometimes there is a Bing or system chat
raise err
conv.messages = []
for _, sentence in enumerate(source):
from_role = sentence["from"]
if from_role in roles:
role = roles[from_role]
else:
if self._conversation.name not in CONVERSATION_ROLE_FORMAT:
raise NotImplementedError(
f"Role ({role}) not in default roles, and {self._conversation.name} does not support role remapping yet."
"Please help us by creating an Issue to add support for this conversation type."
)
if self._conversation.name in ["llama3"]:
role = from_role
else:
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
if (
role != "assistant"
): # back to back assistant calls may be okay for tool calls
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])
turns = list(conv.get_turns())
original_source_length = len(original_source)
assert len(turns) in [
original_source_length - 1,
original_source_length,
original_source_length + 1,
]
if len(turns) == original_source_length + 1:
original_source = [{"weight": None}] + original_source
elif len(turns) == original_source_length - 1:
original_source = original_source[1:]
return [
(*turn, weight)
for turn, weight in zip(
turns,
[
1 if "weight" not in e or e["weight"] is None else e["weight"]
for e in original_source
],
)
]
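
A sketch of the length reconciliation at the end of _build_result: when the template emits one extra leading turn (e.g. the system turn), the source is padded with a weight-None entry before zipping:

turns = [("", "SYS "), ("USER: ", "hi "), ("ASSISTANT: ", "yo")]
original_source = [{"weight": None}, {"weight": 0}]
if len(turns) == len(original_source) + 1:
    original_source = [{"weight": None}] + original_source
print([
    (*turn, 1 if entry.get("weight") is None else entry["weight"])
    for turn, entry in zip(turns, original_source)
])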
def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)
for part in turns:
if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}")
yield part
def __repr__(self) -> str:
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
class ShareGPTPrompterV2(ShareGPTPrompter):
"""
A V2 prompter that generates prompts for the ShareGPT
"""
def __init__(
self,
conversation: Optional[Union[str, Conversation]] = None,
role_key_human: Optional[str] = None,
role_key_model: Optional[str] = None,
role_key_tool: Optional[str] = None,
roles: Optional[dict] = None,
):
super().__init__(
conversation=conversation,
role_key_human=role_key_human,
role_key_model=role_key_model,
role_key_tool=role_key_tool,
roles=roles,
)
class UnsupportedPrompter(Prompter):
"""

View File

@@ -259,31 +259,11 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
try:
# Check to make sure the base model is from HuggingFace not a local directory
hf_api = HfApi()
hf_api.model_info(cfg.base_model)
model_card_kwarg = {
"model_name": cfg.output_dir.lstrip("./")
.encode("utf-8")
.decode("utf-8")
}
if cfg.datasets is not None:
if cfg.rl is not None or cfg.reward_model:
model_card_kwarg["dataset_name"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
else:
model_card_kwarg["dataset_tags"] = [
d["path"] for d in cfg.datasets if not Path(d["path"]).is_dir()
]
trainer.create_model_card(**model_card_kwarg)
except (AttributeError, UnicodeDecodeError, RepositoryNotFoundError):
trainer.create_model_card(
model_name=cfg.output_dir.lstrip("./").encode("utf-8").decode("utf-8")
)
except (AttributeError, UnicodeDecodeError):
pass
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated

View File

@@ -1,11 +1,7 @@
"""
Basic utils for Axolotl
"""
import importlib.util
import re
import torch
def is_mlflow_available():
@@ -14,23 +10,3 @@ def is_mlflow_available():
def is_comet_available():
return importlib.util.find_spec("comet_ml") is not None
# pylint: disable=duplicate-code
def get_pytorch_version() -> tuple[int, int, int]:
"""
Get Pytorch version as a tuple of (major, minor, patch).
"""
torch_version = torch.__version__
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if not version_match:
raise ValueError("Invalid version format")
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = int(patch) if patch is not None else 0 # Default patch to 0 if not present
return major, minor, patch
# pylint: enable=duplicate-code
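
An example of the regex above on a build-tagged version string (illustrative input) — the optional patch group and the non-greedy prefix match let it ignore local suffixes like "+cu124":

import re
match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", "2.5.1+cu124")
major, minor, patch = match.groups()
print(int(major), int(minor), int(patch or 0))  # 2 5 1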

View File

@@ -1,23 +1,9 @@
"""Benchmarking and measurement utilities"""
import functools
import pynvml
import torch
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import get_device_type
try:
from pynvml import (
NVMLError,
nvmlDeviceGetHandleByIndex,
nvmlDeviceGetMemoryInfo,
nvmlInit,
)
except ImportError:
NVMLError = None
nvmlDeviceGetHandleByIndex = None
nvmlDeviceGetMemoryInfo = None
nvmlInit = None
from pynvml.nvml import NVMLError
def check_cuda_device(default_value):
@@ -67,35 +53,24 @@ def mps_memory_usage_all():
return usage, reserved - usage, 0
def npu_memory_usage_all(device=0):
usage = torch.npu.memory_allocated(device) / 1024.0**3
reserved = torch.npu.memory_reserved(device) / 1024.0**3
return usage, reserved - usage, 0
@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
device = device.index
if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:])
if not nvmlInit:
return 0.0
try:
nvmlInit()
handle = nvmlDeviceGetHandleByIndex(device)
info = nvmlDeviceGetMemoryInfo(handle)
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3
except NVMLError:
return 0.0
def log_gpu_memory_usage(log, msg, device):
cur_device = get_device_type()
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
elif "npu" in str(cur_device) and is_torch_npu_available():
usage, cache, misc = npu_memory_usage_all(device)
else:
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
@@ -104,7 +79,6 @@ def log_gpu_memory_usage(log, msg, device):
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
stacklevel=2,
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
)
return usage, cache, misc

View File

@@ -28,7 +28,6 @@ from transformers import (
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from trl.models import unwrap_model_for_generation
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.bench import log_gpu_memory_usage
@@ -47,7 +46,6 @@ from axolotl.utils.distributed import (
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")
@@ -380,10 +378,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
for metric in self.cfg.eval_causal_lm_metrics:
if metric == "perplexity":
max_seq_len = self.cfg.eval_max_new_tokens
metrics[metric] = Perplexity(
tokenizer=tokenizer,
max_seq_len=max_seq_len,
)
metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
else:
try:
metrics[metric] = evaluate.load(metric)
@@ -400,11 +395,8 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
trainer.model_wrapped.eval()
device = torch.device(
self.cfg.device
) # Use this instead of trainer.model_wrapped.device as it may return cpu if fsdp offloaded
trainer.model.eval()
device = torch.device(self.cfg.device)
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
@@ -441,10 +433,6 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
for k in metric._feature_names() # pylint: disable=protected-access
if k in kwargs
}
if isinstance(metric, Perplexity):
metric_kwargs["model"] = trainer.model_wrapped
metric_score = metric.compute(**metric_kwargs)
return (
metric_score["score"]
@@ -480,97 +468,89 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
def predict_with_generate():
eval_src, eval_pred, eval_ref = [], [], []
with unwrap_model_for_generation(
trainer.model_wrapped, trainer.accelerator
) as unwrapped_model:
for batch in tqdm(eval_dataloader, disable=not is_main_process()):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
for batch in tqdm(eval_dataloader):
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])
prompt_token_ids_list = []
completion_token_ids_list = []
for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
batch_pos_ids = [None] * len(batch["input_ids"])
pos_ranges = find_ranges(pos_ids)
prompt_token_ids_list = []
completion_token_ids_list = []
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
for input_ids_all, labels_all, pos_ids in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = (
input_ids != tokenizer.pad_token_id
)
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(device)
predictions = unwrapped_model.generate(
**prompt_encoding, generation_config=generation_config
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
del prompt_encoding
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list,
skip_special_tokens=True,
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)
eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)
eval_src.extend(prompt_texts)
eval_pred.extend(predicted_texts)
eval_ref.extend(completion_texts)
return eval_src, eval_pred, eval_ref
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))
if is_main_process():
eval_preds = predict_with_generate()
trainer.log(evaluate_preds(*eval_preds))
return control

View File

@@ -8,8 +8,6 @@ from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from axolotl.utils.distributed import is_main_process
class Perplexity:
"""
@@ -19,13 +17,16 @@ class Perplexity:
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
max_seq_len: int,
stride: int = 512,
) -> None:
self.max_seq_len = max_seq_len
self.stride = stride
self.model = model
self.tokenizer = tokenizer
self.device = model.device
self.name = "perplexity"
def _feature_names(self) -> List[str]:
@@ -33,7 +34,6 @@ class Perplexity:
def compute(
self,
model: PreTrainedModel,
references: Optional[List[str]] = None,
) -> Dict[str, float]:
"""
@@ -41,21 +41,17 @@ class Perplexity:
"""
assert references is not None, "Missing parameter: references"
model.eval()
references_tokenized = self.tokenizer(
references, return_tensors="pt", padding=True, truncation=True
)
input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
input_ids = input_ids.to(model.device)
input_ids = input_ids.to(self.device)
sequence_length = input_ids.size(1)
losses = []
prev_end_loc = 0
for begin_loc in tqdm(
range(0, sequence_length, self.stride), disable=not is_main_process()
):
for begin_loc in tqdm(range(0, sequence_length, self.stride)):
end_loc = min(begin_loc + self.max_seq_len, sequence_length)
trg_len = end_loc - prev_end_loc
input_ids_slice = input_ids[:, begin_loc:end_loc]
@@ -63,7 +59,7 @@ class Perplexity:
labels_slice[:, :-trg_len] = -100
with torch.no_grad():
outputs: CausalLMOutput = model(
outputs: CausalLMOutput = self.model(
input_ids=input_ids_slice, labels=labels_slice
)
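
A sketch of the sliding-window bookkeeping above: each window only scores the trg_len tokens not already covered by the previous window:

max_seq_len, stride, sequence_length = 8, 4, 14
prev_end_loc = 0
for begin_loc in range(0, sequence_length, stride):
    end_loc = min(begin_loc + max_seq_len, sequence_length)
    trg_len = end_loc - prev_end_loc
    print(f"window [{begin_loc}, {end_loc}) scores last {trg_len} tokens")
    prev_end_loc = end_loc
    if end_loc == sequence_length:
        break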

File diff suppressed because one or more lines are too long

View File

@@ -1,10 +1,8 @@
"""
Collators for multi-modal chat messages and packing
"""
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Dict, List, Optional, Union
from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
@@ -22,7 +20,6 @@ class MultiModalChatDataCollator(DataCollatorMixin):
processor: ProcessorMixin
return_tensors: str = "pt"
chat_template: Optional[str] = None
chat_template_type: Optional[str] = None
packing: bool = False
max_images: int = -1
padding: Union[bool, str, PaddingStrategy] = True
@@ -33,190 +30,38 @@ class MultiModalChatDataCollator(DataCollatorMixin):
raise ValueError("Packing is currently not supported.")
def torch_call(
self, examples: list[Union[list[int], Any, dict[str, Any]]]
) -> dict[str, Any]:
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
return self.__class__.process_rows(
examples,
self.processor,
self.chat_template,
self.max_images,
chat_template_type=self.chat_template_type,
examples, self.processor, self.chat_template, self.max_images
)
@staticmethod
def preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
list of dicts in OpenAI format, each with a 'messages' key
Raises:
ValueError: If the conversation format is not supported
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}
def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)
def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]
# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
return {"messages": messages, **result}
processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)
# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))
else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)
return processed_examples
@staticmethod
def process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.
Args:
examples: List of dictionaries that may contain 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)
Returns:
Either None (if no images) or List[Image objects] (if all examples have images)
Raises:
ValueError: If there's a mix of None and non-None images
"""
def get_image(example):
if "images" not in example:
return None
images = example["images"]
if isinstance(images, str):
return Image.open(images)
return images
images = [get_image(example) for example in examples]
# Count None and non-None images
none_count = sum(1 for img in images if img is None)
# All images are None
if none_count == len(images):
return None
# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)
# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]
return images
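
An example of the max_images cap applied to per-example image batches (illustrative values):

images = [["img0", "img1", "img2"], ["img3"]]
max_images = 2
print([b[:max_images] if isinstance(b, (list, tuple)) else b for b in images])
# [['img0', 'img1'], ['img3']]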
@staticmethod
def pixtral_chat_conversion(messages):
is_single_message = not isinstance(messages, list)
if is_single_message:
messages = [messages]
for i, message in enumerate(messages):
if message["role"] == "user":
for j, content in enumerate(message["content"]):
if "type" in content and content["type"] == "text":
messages[i]["content"][j] = {
"type": "text",
"content": content["text"],
}
if message["role"] == "assistant":
messages[i]["content"] = message["content"][0]["text"]
if is_single_message:
return messages[0]
return messages
@staticmethod
def process_rows(
examples,
processor,
chat_template,
max_images,
length_only=False,
chat_template_type=None,
):
def process_rows(examples, processor, chat_template, max_images, length_only=False):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
# Preprocess the examples
examples = __class__.preprocess(examples)
# Get the texts and images, and apply the chat template
if chat_template_type == "pixtral":
texts = [
processor.apply_chat_template(
__class__.pixtral_chat_conversion(example["messages"]),
chat_template=chat_template,
tokenize=False,
)
for example in examples
]
else:
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
images = [
Image.open(example["images"])
if isinstance(example["images"], str)
else example["images"]
for example in examples
]
images = __class__.process_images(examples, max_images=max_images)
if chat_template_type == "llava":
# LLava1.5 does not support multiple images
images = [image[0] for image in images]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@@ -225,12 +70,9 @@ class MultiModalChatDataCollator(DataCollatorMixin):
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
# Ignore the image token index in the loss computation (model specific)
if chat_template_type == "qwen2_vl":
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
else:
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels[labels == image_token_id] = -100
batch["labels"] = labels

View File

@@ -1,15 +1,16 @@
"""Module for working with config dicts"""
import json
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.integrations.base import PluginManager
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
@@ -31,10 +32,7 @@ def choose_device(cfg):
if torch.backends.mps.is_available():
return "mps"
if is_torch_npu_available():
return f"npu:{cfg.local_rank}"
raise SystemError("No CUDA/mps/npu device found")
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
@@ -44,8 +42,6 @@ def choose_device(cfg):
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
elif cfg.device.startswith("npu"):
cfg.device_map = {"npu": torch.npu.current_device()}
else:
cfg.device_map = {"": cfg.device}
@@ -132,7 +128,7 @@ def normalize_config(cfg):
cfg.is_multimodal = (
hasattr(model_config, "model_type")
and model_config.model_type in ["llava", "mllama", "qwen2_vl", "qwen2_5_vl"]
and model_config.model_type in ["llava", "mllama"]
or any(
multimodal_name in cfg.base_model.lower()
for multimodal_name in [
@@ -145,12 +141,7 @@ def normalize_config(cfg):
cfg.processor_config = (
cfg.processor_config or cfg.base_model_config or cfg.base_model
)
try:
model_config = model_config.text_config
except AttributeError:
# for qwen2_vl
model_config = model_config.get_text_config()
model_config = model_config.text_config
cfg.model_config_type = model_config.model_type
@@ -224,6 +215,11 @@ def normalize_cfg_datasets(cfg):
if cfg.chat_template:
if cfg.datasets:
for idx, ds_cfg in enumerate(cfg.datasets):
if ds_cfg.type == "sharegpt" and not ds_cfg.conversation:
LOG.info(
f"updating dataset {ds_cfg.path} with `conversation: {cfg.chat_template}` to match your chat_template"
)
cfg.datasets[idx].conversation = cfg.chat_template
if (
ds_cfg.type in ["orpo.chat_template", "chat_template"]
and not ds_cfg.chat_template
@@ -235,11 +231,7 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja
def validate_config(
cfg: DictDefault,
capabilities: Optional[dict] = None,
env_capabilities: Optional[dict] = None,
):
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase
@@ -249,35 +241,402 @@ def validate_config(
AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args()
if capabilities or env_capabilities:
if (capabilities and not env_capabilities) or (
env_capabilities and not capabilities
):
raise ValueError(
"Both capabilities and env_capabilities must be provided or not provided."
)
if capabilities:
return DictDefault(
dict(
AxolotlConfigWCapabilities(
**cfg.to_dict(),
capabilities=capabilities,
env_capabilities=env_capabilities,
**cfg.to_dict(), capabilities=capabilities
).model_dump(exclude_none=True)
)
)
return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
)
def prepare_plugins(cfg):
def legacy_validate_config(cfg):
"""
Prepare the plugins for the configuration
This is a "pre-validation" step that handles the yaml configuration before we have any
information about the model architecture
"""
if is_torch_bf16_gpu_available():
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
else:
if (
not cfg.merge_lora
and not cfg.is_preprocess
and (cfg.bf16 is True or cfg.bfloat16 is True)
):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
if (
# pylint: disable=too-many-boolean-expressions
not (cfg.bf16 or cfg.bfloat16)
and (cfg.fp16 or cfg.float16)
and not cfg.adapter
and not cfg.flash_attention
and cfg.sample_packing
):
LOG.warning(
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
)
# ValueError: Attempting to unscale FP16 gradients.
# OR
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
if cfg.max_packed_sequence_len:
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
if cfg.sample_packing and cfg.rl:
raise ValueError("`sample_packing: true` does not work with RLHF training")
if cfg.sample_packing and not cfg.pad_to_sequence_len:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
if cfg.batch_size:
LOG.warning(
"%s\n%s",
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
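
A worked example of the suggested conversion (hypothetical numbers):

batch_size, micro_batch_size, num_gpus = 64, 4, 2
gradient_accumulation_steps = batch_size // micro_batch_size // num_gpus
print(gradient_accumulation_steps)  # 8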
if (
cfg.eval_batch_size
and cfg.micro_batch_size
and cfg.eval_batch_size != cfg.micro_batch_size
):
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
if cfg.adapter == "qlora":
if cfg.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if cfg.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if cfg.gptq:
raise ValueError("Can't merge qlora if gptq")
if cfg.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if cfg.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if cfg.gptq:
raise ValueError("Can't load qlora if gptq")
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with QLoRA")
loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
raise ValueError("Fused modules are not supported with LoRA")
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
raise ValueError(
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
)
if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if cfg.fsdp:
raise ValueError("fsdp not supported with ReLoRA")
if cfg.deepspeed:
raise ValueError("deepspeed not supported with ReLoRA")
if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
)
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
raise ValueError(
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
raise ValueError("FSDP is not supported for falcon models")
if (
cfg.base_model and "mpt" in cfg.base_model.lower()
) and cfg.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
if cfg.flash_optimum is True:
if cfg.adapter:
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bfloat16 is not True:
LOG.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
LOG.warning("torch>=2.0.0 required")
raise ValueError(
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
)
if cfg.pretraining_dataset and cfg.group_by_length:
LOG.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
if cfg.pretraining_dataset and not cfg.max_steps:
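        # e.g. (hypothetical value): pairing pretraining_dataset with
        # max_steps: 5000 gives the Trainer a fixed horizon for the LR schedule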
raise ValueError(
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
)
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer
):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
if cfg.push_to_hub_model_id:
raise ValueError(
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
)
if cfg.gptq and cfg.revision_of_model:
raise ValueError(
"revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove revision_of_model from the config."
)
# if cfg.sample_packing and cfg.sdp_attention:
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
# raise ValueError(
# "sample_packing not compatible with sdp_attention. Use flash_attention"
# )
if cfg.sample_packing and cfg.xformers_attention:
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
LOG.warning(
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
"This may work on H100s."
)
if cfg.early_stopping_patience:
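        # e.g. with hypothetical values, eval_steps: 50 and save_steps: 100 pass
        # (100 % 50 == 0), while eval_steps: 40 and save_steps: 100 do not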
if not cfg.save_steps or not cfg.eval_steps:
raise ValueError(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
)
if cfg.save_steps % cfg.eval_steps != 0:
raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
if cfg.datasets:
for idx, ds_cfg in enumerate(cfg.datasets):
if not ds_cfg.type:
continue
if ds_cfg.type == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = "sharegpt"
if "sharegpt_simple" in ds_cfg.type:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
"sharegpt_simple", "sharegpt"
)
if cfg.saves_per_epoch and cfg.save_steps:
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
raise ValueError(
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
)
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
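        # e.g. save_strategy: epoch combined with save_steps: 500 trips this;
        # save_strategy: steps (or leaving it unset) alongside save_steps is fine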
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if cfg.evals_per_epoch and cfg.eval_steps:
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
)
if (
cfg.evals_per_epoch
and cfg.evaluation_strategy
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if (
cfg.evaluation_strategy
and cfg.eval_steps
and cfg.evaluation_strategy != "steps"
):
raise ValueError(
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if (
cfg.val_set_size == 0
and (cfg.eval_steps or cfg.evaluation_strategy)
and not cfg.test_datasets
):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
if (
cfg.sample_packing
and cfg.eval_table_size
and cfg.eval_sample_packing is not False
):
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
if cfg.noisy_embedding_alpha is not None:
# Deprecated, use neftune_noise_alpha
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
if cfg.neftune_noise_alpha is None:
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
else:
# User is providing both; bail and have them sort out their settings
raise ValueError(
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
)
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
raise ValueError("neftune_noise_alpha must be > 0.0")
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
)
if (
cfg.unfrozen_parameters
and cfg.gradient_checkpointing_kwargs
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
):
# https://github.com/huggingface/transformers/issues/21381
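        # illustrative yaml fix: gradient_checkpointing_kwargs: {use_reentrant: false}
        # when parts of the model are frozen via unfrozen_parameters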
raise ValueError(
"`use_reentrant` must be false when used with partially frozen model."
)
if cfg.deepspeed and Path(cfg.deepspeed).is_file():
with open(cfg.deepspeed, encoding="utf-8") as file:
contents = file.read()
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
if cfg.flash_attention:
if (
deepspeed_cfg.zero_optimization
and deepspeed_cfg.zero_optimization.stage == 3
):
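                    # a passing sketch (illustrative json, not a complete deepspeed
                    # config): {"zero_optimization": {"stage": 3}, "bf16": {"enabled": true}}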
if not (
(
deepspeed_cfg.bf16
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
is True
)
or (
deepspeed_cfg.fp16
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
is True
)
):
raise ValueError(
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
)
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
LOG.warning(
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
)
if cfg.test_datasets and cfg.val_set_size:
raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration"
)
if cfg.fsdp and "bnb" in cfg.optimizer:
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if cfg.eval_causal_lm_metrics:
if not isinstance(cfg.eval_causal_lm_metrics, list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
raise ValueError(
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
)
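        # e.g. a valid setting drawn from the supported list above:
        # eval_causal_lm_metrics: ["sacrebleu", "chrf"]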
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
    # no 8-bit AdamW with bf16
# GPT-NeoX
# evals broken when extending context len
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
# attention_mask = causal_mask + attention_mask
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3

View File

@@ -7,9 +7,9 @@ Module for pydantic models for configuration
import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from packaging import version
from pydantic import (
BaseModel,
Field,
@@ -20,9 +20,8 @@ from pydantic import (
)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import EnvCapabilities, GPUCapabilities
from axolotl.utils.config.models.internals import GPUCapabilities
LOG = logging.getLogger("axolotl.utils.config.models.input")
@@ -51,7 +50,6 @@ class ChatTemplate(str, Enum):
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
llava = "llava" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
@@ -59,10 +57,6 @@ class ChatTemplate(str, Enum):
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
metharme = "metharme" # pylint: disable=invalid-name
pixtral = "pixtral" # pylint: disable=invalid-name
qwen2_vl = "qwen2_vl" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
@@ -72,7 +66,6 @@ class DeprecatedParameters(BaseModel):
rope_scaling: Optional[Any] = None
noisy_embedding_alpha: Optional[float] = None
dpo_beta: Optional[float] = None
evaluation_strategy: Optional[str] = None
@field_validator("max_packed_sequence_len")
@classmethod
@@ -104,13 +97,6 @@ class DeprecatedParameters(BaseModel):
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
return dpo_beta
@field_validator("evaluation_strategy")
@classmethod
def validate_evaluation_strategy(cls, evaluation_strategy):
if evaluation_strategy is not None:
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
return evaluation_strategy
class RemappedParameters(BaseModel):
"""parameters that have been remapped to other names"""
@@ -254,10 +240,8 @@ class KTODataset(BaseModel):
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
loftq_bits: int = Field(
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
)
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
# loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
class PeftConfig(BaseModel):
@@ -300,8 +284,8 @@ class LoraConfig(BaseModel):
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
metadata={
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None
@@ -310,15 +294,13 @@ class LoraConfig(BaseModel):
loraplus_lr_ratio: Optional[float] = Field(
default=None,
json_schema_extra={
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
metadata={
"help": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
},
)
loraplus_lr_embedding: Optional[float] = Field(
default=1e-6,
json_schema_extra={
"description": "loraplus learning rate for lora embedding layers."
},
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
merge_lora: Optional[bool] = None
@@ -326,13 +308,11 @@ class LoraConfig(BaseModel):
@model_validator(mode="before")
@classmethod
def validate_adapter(cls, data):
if (
not data.get("adapter")
and not data.get("inference")
and (data.get("load_in_8bit") or data.get("load_in_4bit"))
if not data.get("adapter") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter for training."
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
return data
@@ -390,10 +370,10 @@ class ModelInputConfig(BaseModel):
tokenizer_use_fast: Optional[bool] = None
tokenizer_legacy: Optional[bool] = None
tokenizer_type: Optional[str] = Field(
default=None, json_schema_extra={"description": "transformers tokenizer class"}
default=None, metadata={"help": "transformers tokenizer class"}
)
processor_type: Optional[str] = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
default=None, metadata={"help": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
@@ -415,18 +395,18 @@ class HyperparametersConfig(BaseModel):
gradient_accumulation_steps: Optional[int] = Field(default=1)
micro_batch_size: Optional[int] = Field(
default=1,
json_schema_extra={"description": "per gpu micro batch size for training"},
metadata={"help": "per gpu micro batch size for training"},
)
batch_size: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Total batch size, we do not recommended setting this manually"
metadata={
"help": "Total batch size, we do not recommended setting this manually"
},
)
eval_batch_size: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
metadata={
"help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
},
)
@@ -436,8 +416,6 @@ class HyperparametersConfig(BaseModel):
group_by_length: Optional[bool] = None
learning_rate: Union[str, float]
embedding_lr: Optional[float] = None
embedding_lr_scale: Optional[float] = None
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[
@@ -448,18 +426,16 @@ class HyperparametersConfig(BaseModel):
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"adopt_adamw",
],
]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
)
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
default=None,
json_schema_extra={
"description": "The target modules to optimize, i.e. the module names that you would like to train."
metadata={
"help": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
torchdistx_path: Optional[str] = None
@@ -519,15 +495,15 @@ class LISAConfig(BaseModel):
lisa_n_layers: Optional[int] = Field(
default=None,
json_schema_extra={"description": "the number of activate layers in LISA"},
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = Field(
default=None,
json_schema_extra={"description": "how often to switch layers in LISA"},
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = Field(
default="model.layers",
json_schema_extra={"description": "path under the model to access the layers"},
metadata={"help": "path under the model to access the layers"},
)
@@ -611,9 +587,6 @@ class AxolotlInputConfig(
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
@@ -626,11 +599,9 @@ class AxolotlInputConfig(
pretraining_dataset: Optional[ # type: ignore
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
] = Field(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
)
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_exact_deduplication: Optional[bool] = None
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
dataloader_num_workers: Optional[int] = None
@@ -688,8 +659,7 @@ class AxolotlInputConfig(
sequence_len: int = Field(default=512)
min_sample_len: Optional[int] = None
max_prompt_len: int = Field(
default=512,
json_schema_extra={"description": "maximum prompt length for RL training"},
default=512, metadata={"help": "maximum prompt length for RL training"}
)
sample_packing: Optional[bool] = None
sample_packing_group_size: Optional[int] = 100_000
@@ -708,8 +678,8 @@ class AxolotlInputConfig(
pretrain_multipack_buffer_size: Optional[int] = 10_000
pretrain_multipack_attn: Optional[bool] = Field(
default=True,
json_schema_extra={
"description": "whether to prevent cross attention for packed sequences during pretraining",
metadata={
"help": "whether to prevent cross attention for packed sequences during pretraining",
},
)
@@ -755,7 +725,7 @@ class AxolotlInputConfig(
warmup_ratio: Optional[float] = None
eval_steps: Optional[Union[int, float]] = None
evals_per_epoch: Optional[Union[int]] = None
eval_strategy: Optional[str] = None
evaluation_strategy: Optional[str] = None
save_steps: Optional[Union[int, float]] = None
saves_per_epoch: Optional[int] = None
save_strategy: Optional[str] = None
@@ -807,25 +777,28 @@ class AxolotlInputConfig(
is_mistral_derived_model: Optional[bool] = Field(default=None)
is_qwen_derived_model: Optional[bool] = Field(default=None)
plugins: Optional[List[str]] = Field(default=None)
@field_validator("datasets", mode="before")
@classmethod
def deprecate_sharegpt_datasets(cls, datasets):
for _, ds_cfg in enumerate(datasets):
if not ds_cfg.get("type"):
def fix_sharegpt_datasets(cls, datasets):
for idx, ds_cfg in enumerate(datasets):
if not ds_cfg["type"]:
continue
ds_type = ds_cfg["type"]
# skip if it's a dict (for custom user instruction prompt)
if isinstance(ds_type, dict):
continue
if isinstance(ds_type, str) and ds_type.startswith("sharegpt"):
raise ValueError(
"`type: sharegpt.*` is deprecated. Please use `type: chat_template` instead."
if ds_cfg["type"] == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
datasets[idx]["type"] = "sharegpt"
if "sharegpt_simple" in ds_cfg["type"]:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
datasets[idx]["type"] = datasets[idx]["type"].replace(
"sharegpt_simple", "sharegpt"
)
return datasets
@model_validator(mode="before")
@@ -1057,21 +1030,21 @@ class AxolotlInputConfig(
@classmethod
def check_evals(cls, data):
if (
data.get("eval_strategy")
data.get("evaluation_strategy")
and data.get("eval_steps")
and data.get("eval_strategy") != "steps"
and data.get("evaluation_strategy") != "steps"
):
raise ValueError(
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if (
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy"))
and (data.get("eval_steps") or data.get("evaluation_strategy"))
and not data.get("test_datasets")
):
raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0"
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
if data.get("evals_per_epoch") and data.get("eval_steps"):
raise ValueError(
@@ -1079,11 +1052,11 @@ class AxolotlInputConfig(
)
if (
data.get("evals_per_epoch")
and data.get("eval_strategy")
and data.get("eval_strategy") != "steps"
and data.get("evaluation_strategy")
and data.get("evaluation_strategy") != "steps"
):
raise ValueError(
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
if data.get("do_bench_eval") and not (
@@ -1315,26 +1288,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def warn_qlora_zero3_w_use_reentrant(cls, data):
if (
data.get("adapter") == "qlora"
and data.get("gradient_checkpointing_kwargs", {})
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is False
and data.get("deepspeed", "") is not None
and "zero3" in data.get("deepspeed", "")
):
# may result in:
# torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:
# Recomputed values for the following tensors have different metadata
# than during the forward pass.
LOG.warning(
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
)
return data
@model_validator(mode="before")
@classmethod
def check_val_w_test_datasets(cls, data):
@@ -1344,19 +1297,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_eval_strategy(cls, data):
if (
data.get("evaluation_strategy") is not None
and data.get("eval_strategy") is None
):
LOG.info(
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
)
data["eval_strategy"] = data.get("evaluation_strategy")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_offload_w_8bit_optimizer(cls, data):
@@ -1435,6 +1375,21 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
xformers_version = version("xformers")
if xformers_version == "0.0.27":
raise ValueError(
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
@@ -1444,46 +1399,11 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):
if is_torch_npu_available():
# check attention config
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
for attn in attn_list:
if data.get(attn):
raise NotImplementedError(
f"{attn} is currently not supported in Ascend npu, please disable this configuration."
)
# check quant config
if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
optimizer = data.get("optimizer")
raise NotImplementedError(
f"{optimizer} is currently not supported in Ascend npu, choose another one please."
)
quant_list = ["load_in_8bit", "load_in_4bit"]
for quant in quant_list:
if data.get(quant):
raise NotImplementedError(
f"Quantification is currently not supported in Ascend npu, please disable {quant}."
)
# check dtype config
if data.get("tf32"):
raise NotImplementedError(
"tf32 dtype is currently not supported in Ascend npu, please disable this configuration"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
capabilities: GPUCapabilities
env_capabilities: EnvCapabilities
@model_validator(mode="after")
def check_bf16(self):
@@ -1558,21 +1478,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data
@model_validator(mode="before")
@classmethod
def check_adopt_torch_version(cls, data):
if (data.get("optimizer") is not None) and ("adopt" in data.get("optimizer")):
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.1"):
raise ValueError(
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data

View File

@@ -12,9 +12,3 @@ class GPUCapabilities(BaseModel):
n_gpu: int = Field(default=1)
n_node: int = Field(default=1)
compute_capability: Optional[str] = Field(default=None)
class EnvCapabilities(BaseModel):
"""model to manage the environment capabilities statically"""
torch_version: Optional[str] = Field(default=None)

View File

@@ -13,7 +13,7 @@ from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.prompt_strategies.kto import load as load_kto
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.models import load_tokenizer
@@ -64,57 +64,15 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
tokenizer = load_tokenizer(cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
data_set = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
)
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
return data_set
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
):
if rl in ("dpo", "ipo", "orpo", "simpo"):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
raise ValueError(
"Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
)
prompt = sample["prompt"]
chosen = sample["chosen"]
rejected = sample["rejected"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
if rl == "kto":
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
prompt = sample["prompt"]
completion = sample["completion"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_completion = len(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_completion) <= sequence_len
raise ValueError("Unknown RL type")
def load_prepare_dpo_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
@@ -136,7 +94,7 @@ def load_prepare_dpo_datasets(cfg):
)
split_datasets.insert(i, ds)
tokenizer = load_tokenizer(cfg)
tokenizer = None
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
@@ -163,28 +121,7 @@ def load_prepare_dpo_datasets(cfg):
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
return combined_datasets
return concatenate_datasets(split_datasets)
with zero_first(is_main_process()):
train_is_preprocessed = False
@@ -208,9 +145,4 @@ def load_prepare_dpo_datasets(cfg):
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=train_dataset, eval_dataset=eval_dataset
)
return train_dataset, eval_dataset

View File

@@ -2,11 +2,9 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -44,7 +42,7 @@ from axolotl.prompters import (
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.trainer import (
@@ -55,28 +53,6 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
if not cfg.pretraining_dataset:
@@ -136,9 +112,8 @@ def prepare_dataset(cfg, tokenizer, processor=None):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if cfg.dataset_exact_deduplication:
LOG.info("Deduplication not available for pretrained datasets")
return train_dataset, eval_dataset, cfg.max_steps, prompters
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
@@ -261,7 +236,6 @@ def load_tokenized_prepared_datasets(
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
@@ -271,7 +245,6 @@ def load_tokenized_prepared_datasets(
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
@@ -351,15 +324,7 @@ def load_tokenized_prepared_datasets(
split=None,
)
else:
try:
ds = load_from_disk(config_dataset.path)
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
)
ds = load_from_disk(config_dataset.path)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
@@ -377,7 +342,7 @@ def load_tokenized_prepared_datasets(
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
load_ds_kwargs = {"split": config_dataset.split}
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
@@ -385,7 +350,6 @@ def load_tokenized_prepared_datasets(
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
@@ -403,7 +367,6 @@ def load_tokenized_prepared_datasets(
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
@@ -414,7 +377,6 @@ def load_tokenized_prepared_datasets(
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
else:
if isinstance(config_dataset.data_files, str):
@@ -585,8 +547,7 @@ def load_prepare_datasets(
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
if cfg.dataset_exact_deduplication:
_, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
dataset = dataset.train_test_split(
test_size=val_set_size,
shuffle=False,
@@ -598,17 +559,12 @@ def load_prepare_datasets(
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
elif split == "test":
if cfg.dataset_exact_deduplication:
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
else:
eval_dataset = dataset
train_dataset = None
eval_dataset = dataset
else:
if cfg.dataset_exact_deduplication:
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
else:
train_dataset = dataset
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset, prompters

Some files were not shown because too many files have changed in this diff.