Compare commits


1 Commit

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Wing Lian | 20d0427ac9 | update llama3 example base models to use nous | 2024-07-15 17:19:00 -04:00 |
114 changed files with 961 additions and 7198 deletions

View File

@@ -12,24 +12,36 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.3.1
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.1
cudnn_version: 8
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.4.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -55,7 +67,6 @@ jobs:
labels: ${{ steps.metadata.outputs.labels }}
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}
CUDNN_VERSION=${{ matrix.cudnn_version }}
CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }}

View File

@@ -6,7 +6,7 @@ on:
- '**.py'
- 'requirements.txt'
- '.github/workflows/*.yml'
- "*.[q]md"
- "*.md"
- "examples/**/*.y[a]?ml"
workflow_dispatch:

View File

@@ -13,22 +13,28 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -59,7 +65,6 @@ jobs:
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
@@ -70,22 +75,27 @@ jobs:
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.3.1
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -124,7 +134,7 @@ jobs:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:

View File

@@ -1,55 +0,0 @@
name: docker-multigpu-tests-biweekly
on:
pull_request:
paths:
- 'tests/e2e/multigpu/*.py'
workflow_dispatch:
schedule:
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
jobs:
test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu

View File

@@ -12,22 +12,28 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.3.1
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -69,22 +75,27 @@ jobs:
strategy:
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.3.1
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.1
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
axolotl_extras:
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -1,120 +0,0 @@
name: Tests Nightly against upstream main
on:
workflow_dispatch:
schedule:
- cron: '0 0 * * *' # Runs at 00:00 UTC every day
jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
env:
SKIP: no-commit-to-branch
pytest:
name: PyTest
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.0"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
- name: Update requirements.txt
run: |
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install -U -e .
pip3 install -r requirements-tests.txt
- name: Run tests
run: |
pytest --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 60
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
python_version: "3.10"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests

View File

@@ -26,8 +26,6 @@ jobs:
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
env:
SKIP: no-commit-to-branch
pytest:
name: PyTest
@@ -36,7 +34,6 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.3.1", "2.4.0"]
timeout-minutes: 20
steps:
@@ -49,10 +46,6 @@ jobs:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
- name: Install dependencies
run: |
pip3 install --upgrade pip
@@ -64,10 +57,6 @@ jobs:
run: |
pytest --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
@@ -79,24 +68,27 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 121
cuda_version: 12.1.1
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.3.1
pytorch: 2.1.2
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
num_gpus: 1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.2.2
num_gpus: 1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -107,13 +99,12 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal

View File

@@ -11,9 +11,6 @@ ignore_errors = True
[mypy-axolotl.models.mixtral.*]
ignore_errors = True
[mypy-axolotl.integrations.liger.models.*]
ignore_errors = True
[mypy-axolotl.models.phi.*]
ignore_errors = True

View File

@@ -8,8 +8,6 @@ repos:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:

README.md
View File

@@ -1,9 +1,5 @@
# Axolotl
![tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg)
![tests-nightly](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg)
![multigpu-semi-weekly tests](https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg)
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Features:
@@ -11,7 +7,7 @@ Features:
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with xformer, flash attention, [liger kernel](https://github.com/linkedin/Liger-Kernel), rope scaling, and multipacking
- Integrated with xformer, flash attention, rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb or mlflow
@@ -26,50 +22,38 @@ Features:
<td>
## Table of Contents
- [Axolotl](#axolotl)
- [Table of Contents](#table-of-contents)
- [Axolotl supports](#axolotl-supports)
- [Quickstart ⚡](#quickstart-)
- [Usage](#usage)
- [Advanced Setup](#advanced-setup)
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu)
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [LambdaLabs](#lambdalabs)
- [GCP](#gcp)
- [Windows](#windows)
- [Mac](#mac)
- [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset)
- [Config](#config)
- [All Config Options](#all-config-options)
- [Train](#train)
- [Preprocess dataset](#preprocess-dataset)
- [Multi-GPU](#multi-gpu)
- [DeepSpeed](#deepspeed)
- [FSDP](#fsdp)
- [FSDP + QLoRA](#fsdp--qlora)
- [Weights \& Biases Logging](#weights--biases-logging)
- [Special Tokens](#special-tokens)
- [Liger Kernel](#liger-kernel)
- [Inference Playground](#inference-playground)
- [Merge LORA to base](#merge-lora-to-base)
- [Common Errors 🧰](#common-errors-)
- [Tokenization Mismatch b/w Inference \& Training](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
- [Need help? 🙋](#need-help-)
- [Badge ❤🏷️](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing 🤝](#contributing-)
- [Sponsors 🤝❤](#sponsors-)
- [💎 Diamond Sponsors - Contact directly](#-diamond-sponsors---contact-directly)
- [🥇 Gold Sponsors - $5000/mo](#-gold-sponsors---5000mo)
- [🥈 Silver Sponsors - $1000/mo](#-silver-sponsors---1000mo)
- [🥉 Bronze Sponsors - $500/mo](#-bronze-sponsors---500mo)
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, JarvisLabs, RunPod
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [Windows](#windows)
- [Mac](#mac)
- [Google Colab](#google-colab)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Launching on public clouds via dstack](#launching-on-public-clouds-via-dstack)
- [Dataset](#dataset)
- [Config](#config)
- [Train](#train)
- [Inference](#inference-playground)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- [All Config Options](#all-config-options)
- Advanced Topics
- [Multipack](./docs/multipack.qmd)
- [RLHF & DPO](./docs/rlhf.qmd)
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
- [Need Help?](#need-help-)
- [Badge](#badge-)
- [Community Showcase](#community-showcase)
- [Contributing](#contributing-)
- [Sponsors](#sponsors-)
</td>
<td>
@@ -111,7 +95,6 @@ Features:
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
| Jamba | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
@@ -350,7 +333,7 @@ For further and fine-grained use cases, please refer to the official [dstack doc
Axolotl supports a variety of dataset formats. It is recommended to use a JSONL. The schema of the JSONL depends upon the task and the prompt template you wish to use. Instead of a JSONL, you can also use a HuggingFace dataset with columns for each JSONL field.
See [the documentation](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
See [these docs](https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/) for more information on how to use different dataset formats.
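For illustration only (this example is not part of the linked docs), a minimal alpaca-style JSONL file has one JSON object per line with `instruction`, `input`, and `output` fields:
```json
{"instruction": "Summarize the text.", "input": "Axolotl streamlines fine-tuning of AI models.", "output": "Axolotl makes fine-tuning easier."}
```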
### Config
@@ -531,25 +514,6 @@ tokens: # these are delimiters
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
##### Liger Kernel
Liger Kernel: Efficient Triton Kernels for LLM Training
https://github.com/linkedin/Liger-Kernel
Liger (LinkedIn GPU Efficient Runtime) Kernel is a collection of Triton kernels designed specifically for LLM training.
It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. The Liger Kernel
composes well and is compatible with both FSDP and Deepspeed.
```yaml
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
```
### Inference Playground
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.

View File

@@ -36,8 +36,6 @@ website:
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-node.qmd
- docs/unsloth.qmd
- docs/amd_hpc.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"

View File

@@ -8,7 +8,6 @@ ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
@@ -24,16 +23,10 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^transformers.*#transformers @ git+https://github.com/huggingface/transformers.git@main#' requirements.txt; \
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
fi
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -2,5 +2,5 @@
set -e
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
pytest /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/

View File

@@ -1,77 +0,0 @@
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
)
def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
@stub.local_entrypoint()
def main():
cicd_pytest.remote()

View File

@@ -1,5 +0,0 @@
#!/bin/bash
set -e
# only run one test at a time so as not to OOM the GPU
pytest -n1 /workspace/axolotl/tests/e2e/multigpu/

View File

@@ -1,8 +1,6 @@
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
import pathlib
import tempfile
@@ -23,12 +21,11 @@ df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.3.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.11-cu121-2.3.1"),
"CUDA": os.environ.get("CUDA", "121"),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
"CUDA": os.environ.get("CUDA", "118"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
}
dockerfile_contents = df_template.render(**df_args)

View File

@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN pip install causal_conv1d
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image

View File

@@ -3,7 +3,7 @@ ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"

View File

@@ -3,6 +3,7 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -3,6 +3,7 @@ FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
ENV HF_HUB_ENABLE_HF_TRANSFER="1"

View File

@@ -1,108 +0,0 @@
---
title: Training with AMD GPUs on HPC Systems
description: A comprehensive guide for using Axolotl on distributed systems with AMD GPUs
---
This guide provides step-by-step instructions for installing and configuring Axolotl on a High-Performance Computing (HPC) environment equipped with AMD GPUs.
## Setup
### 1. Install Python
We recommend using Miniforge, a minimal conda-based Python distribution:
```bash
curl -L -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
```
### 2. Configure Python Environment
Add Python to your PATH and ensure it's available at login:
```bash
echo 'export PATH=~/miniforge3/bin:$PATH' >> ~/.bashrc
echo 'if [ -f ~/.bashrc ]; then . ~/.bashrc; fi' >> ~/.bash_profile
```
### 3. Load AMD GPU Software
Load the ROCm module:
```bash
module load rocm/5.7.1
```
Note: The specific module name and version may vary depending on your HPC system. Consult your system documentation for the correct module name.
### 4. Install PyTorch
Install PyTorch with ROCm support:
```bash
pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7 --force-reinstall
```
### 5. Install Flash Attention
Clone and install the Flash Attention repository:
```bash
git clone --recursive https://github.com/ROCmSoftwarePlatform/flash-attention.git
export GPU_ARCHS="gfx90a"
cd flash-attention
export PYTHON_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])')
patch "${PYTHON_SITE_PACKAGES}/torch/utils/hipify/hipify_python.py" hipify_patch.patch
pip install .
```
### 6. Install Axolotl
Clone and install Axolotl:
```bash
git clone https://github.com/axolotl-ai-cloud/axolotl
cd axolotl
pip install packaging ninja
pip install -e .
```
### 7. Apply xformers Workaround
xformers appears to be incompatible with ROCm. Apply the following workarounds:
- Edit `$HOME/packages/axolotl/src/axolotl/monkeypatch/llama_attn_hijack_flash.py`, modifying the code to always return `False` for SwiGLU availability from xformers.
- Edit `$HOME/miniforge3/lib/python3.10/site-packages/xformers/ops/swiglu_op.py`, replacing the "SwiGLU" function with a pass statement.
### 8. Prepare Job Submission Script
Create a script for job submission using your HPC's scheduling software (e.g. Slurm or PBS); a sketch is shown below. Include the necessary environment setup and the command to run Axolotl training. If the compute nodes do not have internet access, it is recommended to include
```bash
export TRANSFORMERS_OFFLINE=1
export HF_DATASETS_OFFLINE=1
```
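As a sketch only, a Slurm submission script for this step might look like the following; the job name, resource requests, and module version are assumptions to adapt to your site:
```bash
#!/bin/bash
#SBATCH --job-name=axolotl-train   # placeholder job name
#SBATCH --nodes=1                  # adjust nodes/GPUs for your allocation
#SBATCH --gpus-per-node=8
#SBATCH --time=12:00:00

# load the ROCm stack as in step 3 (module name/version varies by system)
module load rocm/5.7.1

# run offline if the compute nodes have no internet access
export TRANSFORMERS_OFFLINE=1
export HF_DATASETS_OFFLINE=1

# launch training with the config prepared in the later steps
accelerate launch -m axolotl.cli.train /path/to/your/config.yaml
```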
### 9. Download Base Model
Download a base model using the Hugging Face CLI:
```bash
huggingface-cli download meta-llama/Meta-Llama-3.1-8B --local-dir ~/hfdata/llama3.1-8B
```
### 10. Create Axolotl Configuration
Create an Axolotl configuration file (YAML format) tailored to your specific training requirements and dataset. Use FSDP for multi-node training.
Note: Deepspeed did not work at the time of testing. However, if anyone managed to get it working, please let us know.
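As a rough sketch drawn from the FSDP settings used in the example configs elsewhere in this repository (the transformer layer class to wrap is model-specific and shown here for a Llama model):
```yaml
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
```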
### 11. Preprocess Data
Run preprocessing on the login node:
```bash
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess /path/to/your/config.yaml
```
### 12. Train
You are now ready to submit your previously prepared job script. 🚂

View File

@@ -54,14 +54,6 @@ conversations where `from` is `prompter` `assistant` instead of default sharegpt
{"conversations": [{"from": "...", "value": "..."}]}
```
## sharegpt.load_ultrachat
conversations where the turns field is 'messages', human is 'user' and gpt is 'assistant'.
```{.json filename="data.jsonl"}
{"messages": [{"user": "...", "assistant": "..."}]}
```
## sharegpt_jokes
creates a chat where bot is asked to tell a joke, then explain why the joke is funny

View File

@@ -7,7 +7,7 @@ order: 5
- Pass an empty `type:` in your axolotl config.
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
- To indicate that a token should be ignored during training, set its corresponding label to `-100`.
- You must add BOS and EOS, and make sure that you are training on EOS by not setting its label to -100.
- Do not add BOS/EOS. Axolotl will add them for you based on the default tokenizer for the model you're using.
- For pretraining, do not truncate/pad documents to the context window length.
- For instruction training, documents must be truncated/padded as desired.
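As an illustrative sketch (the token IDs below are made up), one row of such a pre-tokenized JSONL carries exactly these three columns, with `-100` in `labels` marking tokens to ignore during training:
```{.json filename="data.jsonl"}
{"input_ids": [3923, 374, 264, 1296], "attention_mask": [1, 1, 1, 1], "labels": [-100, -100, 264, 1296]}
```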

View File

@@ -205,7 +205,7 @@ ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ignored by comparing the labels
We can check that the right tokens are ingored by comparing the labels
to each token:
```python

View File

@@ -1,28 +0,0 @@
# MultiModal / Vision Language Models (BETA)
### Supported Models
- Mllama, i.e. llama with vision models
### Usage
Currently, multimodal support is limited and doesn't have full feature parity. To finetune a multimodal Llama with LoRA,
you'll need to use the following YAML in combination with the rest of the required hyperparameters.
```yaml
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
skip_prepare_dataset: true
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
remove_unused_columns: false
sample_packing: false
# only finetune the Language model, leave the vision model and vision tower frozen
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
```
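As a usage note (the config filename below is a placeholder), a config like the one above is launched the same way as a text-only finetune:
```bash
# placeholder filename; dataset preparation is skipped because skip_prepare_dataset is set
accelerate launch -m axolotl.cli.train your_vision_lora.yaml
```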

View File

@@ -1,19 +0,0 @@
---
title: "PyTorch ao"
description: "Custom data types and layouts for training and inference"
---
### Installation
Stable Release from the PyTorch index
```bash
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
```
Nightly release
```bash
pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
```

View File

@@ -1,49 +0,0 @@
---
title: "Unsloth"
description: "Hyper-optimized QLoRA finetuning for single GPUs"
---
### Overview
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and reduce VRAM usage
compared with standard industry baselines.
### Installation
The following installs unsloth from source and downgrades xformers, since unsloth is incompatible with the most
up-to-date libraries.
```bash
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps --force-reinstall xformers==0.0.26.post1
```
### Using unsloth with Axolotl
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
Our unsloth integration is currently limited to the following model architectures:
- llama
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
```yaml
unsloth_lora_mlp: true
unsloth_lora_qkv: true
unsloth_lora_o: true
```
These options are composable and can be used with multi-gpu finetuning
```yaml
unsloth_cross_entropy_loss: true
unsloth_rms_norm: true
unsloth_rope: true
```
### Limitations
- Single GPU only, i.e. no multi-gpu support
- No deepspeed or FSDP support (requires multi-gpu)
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
- No MoE support.

View File

@@ -43,6 +43,7 @@
},
"outputs": [],
"source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""

View File

@@ -1,67 +0,0 @@
base_model: deepseek-ai/DeepSeek-V2-Lite
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD

View File

@@ -1,83 +0,0 @@
base_model: axolotl-quants/DeepSeek-V2.5-bnb-nf4-bf16
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
chat_template: deepseek_v2
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
adapter: qlora
lora_r: 256
lora_alpha: 256
lora_target_linear: true
peft_use_rslora: true
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD

View File

@@ -6,5 +6,5 @@
- ✅ qlora w/ deepspeed Zero-3 needs at least 2x GPUs and 67GiB VRAM (wtf?)
- ✅ qlora single-gpu, ~51GiB VRAM
- ✅ multipack
- FSDP
- FSDP
- ❓ 8-bit LoRA

View File

@@ -1,61 +0,0 @@
base_model: ai21labs/AI21-Jamba-1.5-Large
tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
use_tensorboard: true
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: jamba
drop_system_message: true
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: [down_proj,gate_proj,in_proj,k_proj,o_proj,out_proj,q_proj,up_proj,v_proj,x_proj]
lora_target_linear: false
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: JambaAttentionDecoderLayer,JambaMambaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD

View File

@@ -1,63 +0,0 @@
base_model: alpindale/Llama-3.2-11B-Vision-Instruct
processor_type: AutoProcessor
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,76 +0,0 @@
base_model: NousResearch/Meta-Llama-3.1-8B
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_swiglu: true
liger_fused_linear_cross_entropy: true
strict: false
chat_template: llama3
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./outputs/out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_backward_prefetch: BACKWARD_PRE
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>

View File

@@ -1,4 +1,6 @@
base_model: NousResearch/Meta-Llama-3.1-8B
base_model: NousResearch/Meta-Llama-3-8B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false

View File

@@ -1,81 +0,0 @@
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: llama3
rl: dpo
datasets:
- path: fozziethebeat/alpaca_messages_2k_dpo_test
type: chat_template.default
chat_template: llama3
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
roles:
system:
- system
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,4 +1,4 @@
base_model: NousResearch/Meta-Llama-3-8B-Instruct
base_model: meta-llama/Meta-Llama-3-8B-Instruct
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer
@@ -74,5 +74,3 @@ deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,63 +0,0 @@
base_model: hugging-quants/Meta-Llama-3.1-405B-BNB-NF4-BF16
tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out/qlora-llama3_1-405b
save_safetensors: true
adapter: qlora
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:
pad_token: <|finetune_right_pad_id|>

View File

@@ -1,4 +1,4 @@
base_model: casperhansen/llama-3-70b-fp16
base_model: NousResearch/Meta-Llama-3-70B
model_type: LlamaForCausalLM
tokenizer_type: AutoTokenizer # PreTrainedTokenizerFast

View File

@@ -7,7 +7,7 @@ load_in_4bit: true
strict: false
datasets:
- path: aaditya/alpaca_subset_1
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path:
val_set_size: 0

View File

@@ -1,76 +0,0 @@
base_model: microsoft/Phi-3.5-mini-instruct
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: true
load_in_4bit: false
strict: false
chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: phi_3
field_messages: messages
message_field_role: role
message_field_content: content
roles:
user:
- user
assistant:
- assistant
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/lora-out
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 2
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bfloat16: true
bf16: true
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 4
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -1,65 +0,0 @@
base_model: mistral-community/pixtral-12b
processor_type: AutoProcessor
load_in_8bit: true
strict: false
# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false
chat_template: llama3_2_vision
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out
adapter: lora
lora_model_dir:
sequence_len: 8192
pad_to_sequence_len: false
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true
gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:

View File

@@ -72,5 +72,4 @@ fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
special_tokens:

View File

@@ -1,4 +1,4 @@
base_model: TinyLlama/TinyLlama_v1.1
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

View File

@@ -1,5 +1,6 @@
base_model: TinyLlama/TinyLlama_v1.1
tokenizer_type: AutoTokenizer
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
load_in_4bit: false

View File

@@ -9,9 +9,9 @@ strict: false
max_steps: 200
pretraining_dataset:
- path: allenai/c4
name: en
type: pretrain
path: c4
name: en
type: pretrain
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/model-out

View File

@@ -1,4 +1,4 @@
base_model: TinyLlama/TinyLlama_v1.1
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

View File

@@ -1,18 +1,18 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.0
transformers==4.45.1
tokenizers>=0.19.1
bitsandbytes==0.44.0
accelerate==0.34.2
datasets==2.21.0
deepspeed==0.14.4
peft==0.11.1
transformers==4.42.3
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.32.0
deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b
pydantic==2.6.3
addict
fire
PyYAML>=6.0
requests
flash-attn==2.6.3
datasets==2.19.1
flash-attn==2.6.1
sentencepiece
wandb
einops
@@ -21,26 +21,23 @@ optimum==1.16.2
hf_transfer
colorama
numba
numpy>=1.24.4,<=2.0.1
numpy>=1.24.4
# qlora things
evaluate==0.4.1
scipy
scikit-learn==1.4.2
scikit-learn==1.2.2
pynvml
art
fschat @ git+https://github.com/lm-sys/FastChat.git@27a05b04a35510afb1d767ae7e5990cbd278f8fe
gradio==3.50.2
tensorboard
python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.3.0
mamba-ssm==1.2.0.post1
# remote filesystems
s3fs>=2024.5.0
gcsfs>=2024.5.0
s3fs
gcsfs
# adlfs
trl==0.9.6

View File

@@ -80,13 +80,13 @@ setup(
dependency_links=dependency_links,
extras_require={
"flash-attn": [
"flash-attn==2.6.3",
"flash-attn==2.6.1",
],
"fused-dense-lib": [
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.2#subdirectory=csrc/fused_dense_lib",
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed==0.14.4",
"deepspeed @ git+https://github.com/microsoft/DeepSpeed.git@bc48371c5e1fb8fd70fc79285e66201dbb65679b",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -27,10 +27,8 @@ from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
@@ -40,9 +38,9 @@ from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -235,8 +233,7 @@ def do_inference_gradio(
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
# default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
default_tokens: Dict[str, str] = {}
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
@@ -244,13 +241,10 @@ def do_inference_gradio(
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = chat_templates(cfg.chat_template)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -264,24 +258,7 @@ def do_inference_gradio(
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
@@ -304,7 +281,6 @@ def do_inference_gradio(
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
@@ -389,11 +365,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg.axolotl_config_path = config
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
@@ -404,15 +375,13 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"n_gpu": os.environ.get("WORLD_SIZE", 1),
"compute_capability": gpu_version,
},
)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
@@ -430,12 +399,9 @@ def load_datasets(
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
cfg, tokenizer
)
if cli_args.debug or cfg.debug:

View File

@@ -1,204 +0,0 @@
"""
This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
"""
import json
import logging
import os
import shutil
from pathlib import Path
from typing import Dict, Union
import fire
import torch
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.format_utils as dist_cp_format_utils
import transformers
from accelerate.utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
is_torch_version,
)
from dotenv import load_dotenv
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""
A custom planner to cast tensors to bfloat16 on the fly during loading.
"""
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
tensor.copy_(tensor.to(torch.bfloat16))
def _distributed_checkpoint_to_merged_weights(
checkpoint_dir: Union[str, Path],
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
):
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
"""
state_dict: Dict = {}
save_path_ = Path(save_path)
save_path_.mkdir(exist_ok=True)
dist_cp_format_utils._load_state_dict( # pylint: disable=protected-access
state_dict,
storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
planner=BFloat16CastPlanner(), # pylint: disable=protected-access
no_dist=True,
)
# To handle if state is a dict like {model: {...}}
if len(state_dict.keys()) == 1:
state_dict = state_dict[list(state_dict)[0]]
# Ensure all tensors are in bfloat16
for key, value in state_dict.items():
if isinstance(value, torch.Tensor) and value.dtype != torch.bfloat16:
state_dict[key] = value.to(torch.bfloat16)
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
# Save index if sharded
index = None
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
# Save the model
filename_to_tensors = state_dict_split.filename_to_tensors.items()
for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors}
if safe_serialization:
safe_save_file(
shard, os.path.join(save_path_, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_path_, shard_file))
if index is not None:
save_index_file = (
SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
)
save_index_file = os.path.join(save_path_, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as fout:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
fout.write(content)
return save_path_
def merge_fsdp_weights(
checkpoint_dir: str,
output_path: str,
safe_serialization: bool = False,
remove_checkpoint_dir: bool = False,
):
"""
Merge the weights from sharded FSDP model checkpoints into a single combined checkpoint. Should be used if
`SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}/model.safetensors` if
`safe_serialization` else `pytorch_model.bin`.
Note: this is a CPU-bound process.
Args:
checkpoint_dir (`str`):
The directory containing the FSDP checkpoints (can be either the model or optimizer).
output_path (`str`):
The path to save the merged checkpoint.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
"""
checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState
if not is_torch_version(">=", "2.3.0"):
raise ValueError("`merge_fsdp_weights` requires PyTorch >= 2.3.0")
# Verify that the checkpoint directory exists
if not checkpoint_dir_.exists():
model_path_exists = (checkpoint_dir_ / "pytorch_model_fsdp_0").exists()
optimizer_path_exists = (checkpoint_dir_ / "optimizer_0").exists()
err = f"Tried to load from {checkpoint_dir_} but couldn't find a valid metadata file."
if model_path_exists and optimizer_path_exists:
err += (
" However, potential model and optimizer checkpoint directories exist."
)
err += f"Please pass in either {checkpoint_dir_}/pytorch_model_fsdp_0 or {checkpoint_dir_}/optimizer_0"
err += "instead."
elif model_path_exists:
err += " However, a potential model checkpoint directory exists."
err += (
f"Please try passing in {checkpoint_dir_}/pytorch_model_fsdp_0 instead."
)
elif optimizer_path_exists:
err += " However, a potential optimizer checkpoint directory exists."
err += f"Please try passing in {checkpoint_dir_}/optimizer_0 instead."
raise ValueError(err)
# Set up distributed state so the save runs only on the main process
state = PartialState()
if state.is_main_process:
LOG.info(f"Merging FSDP weights from {checkpoint_dir_}")
save_path = _distributed_checkpoint_to_merged_weights(
checkpoint_dir_, output_path, safe_serialization
)
LOG.info(f"Successfully merged FSDP weights and saved to {save_path}")
if remove_checkpoint_dir:
LOG.info(f"Removing old checkpoint directory {checkpoint_dir_}")
shutil.rmtree(checkpoint_dir_)
state.wait_for_everyone()
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(
config,
**kwargs,
)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
merge_fsdp_weights(
checkpoint_dir=str(fsdp_dir),
output_path=str(Path(parsed_cfg.output_dir) / "merged"),
safe_serialization=True,
)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)
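For reference, a hedged usage sketch (the output-directory paths below are placeholders, not taken from this diff): the merge can also be driven directly from Python rather than through the fire CLI by calling `merge_fsdp_weights` with the FSDP shard directory and a destination path.

from axolotl.cli.merge_sharded_fsdp_weights import merge_fsdp_weights

# Merge shards written with SHARDED_STATE_DICT into a single checkpoint;
# with safe_serialization=True this produces model.safetensors (plus an
# index file if the result is split into shards) under output_path.
merge_fsdp_weights(
    checkpoint_dir="outputs/run-1/pytorch_model_fsdp_0",
    output_path="outputs/run-1/merged",
    safe_serialization=True,
)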

View File

@@ -2,7 +2,6 @@
CLI to run training on a model
"""
import logging
import warnings
from pathlib import Path
from typing import Union
@@ -77,19 +76,8 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
if parsed_cli_args.download:
model_name = parsed_cfg.base_model
with warnings.catch_warnings():
# there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
warnings.simplefilter("ignore")
with init_empty_weights(include_buffers=True):
# fmt: off
try:
AutoModelForCausalLM.from_pretrained(
model_name, trust_remote_code=True
)
except Exception as exc: # pylint: disable=broad-exception-caught,unused-variable # nosec B110 # noqa F841
pass
# fmt: on
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
LOG.info(
Fore.GREEN

View File

@@ -1,15 +0,0 @@
"""
Common architecture-specific constants
"""
MOE_ARCH_BLOCK = {
"dbrx": "DbrxFFN",
"jamba": "JambaSparseMoeBlock",
"jetmoe": [
"JetMoeMoA",
"JetMoeMoE",
],
"mixtral": "MixtralSparseMoeBlock",
"qwen2_moe": "Qwen2MoeSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
}
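A minimal sketch of how a map like this is typically consumed (illustrative only; the exact consumer of these constants is outside this diff): look up the MoE block class name(s) for a model type and normalize the result to a list, e.g. to decide which module classes deserve special handling during sharding or checkpointing.

# "jetmoe" maps to a list of block names, other keys map to a single string.
moe_blocks = MOE_ARCH_BLOCK.get("jetmoe")
if moe_blocks is not None and not isinstance(moe_blocks, list):
    moe_blocks = [moe_blocks]
# moe_blocks is now ["JetMoeMoA", "JetMoeMoE"]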

View File

@@ -1,150 +0,0 @@
"""
helper functions for fixing the embeddings/tokenizer
"""
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import itertools
import numpy as np
import torch
@torch.inference_mode
def fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16):
"""
Many newer models ship reserved tokens whose embeddings were never trained; re-initialize those rows to a frequency-scaled mean of the trained embeddings so they do not destabilize fine-tuning.
"""
embedding_matrix = model.get_input_embeddings().weight
lm_head_matrix = model.get_output_embeddings().weight
# Get untrained tokens
indicator_untrained = torch.amax(embedding_matrix, axis=1) <= eps
where_untrained = torch.where(indicator_untrained)[0]
n_untrained = where_untrained.shape[0]
n_trained = embedding_matrix.shape[0] - n_untrained
# Get set and actual tokens
where_untrained = where_untrained.tolist()
if len(where_untrained) == 0:
return False
# Build a lookup set of untrained ids and their corresponding token strings
where_untrained_set = frozenset(where_untrained)
actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
# Remove None items in actual_bad_tokens
actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
# Check if tokenizer and training datasets have bad tokens
if_bad_first = False
if_bad_second = False
# Check tokenizer's chat template for any untrained tokens
chat_template = getattr(tokenizer, "chat_template", None)
if chat_template is not None:
if_bad_first = any(x in chat_template for x in actual_bad_tokens)
# Check the first 250, last 250 input_ids
size_dataset = len(train_dataset)
size = min(size_dataset, 250)
for j in range(size):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
if_bad = any(item in where_untrained_set for item in input_ids)
if if_bad:
if_bad_second = True
break
# Check last 250
if not if_bad_second:
left = max(size_dataset - 250, 0)
for j in range(left, size_dataset):
input_ids = train_dataset[j]
if "input_ids" in input_ids:
input_ids = input_ids["input_ids"]
if_bad = any(item in where_untrained_set for item in input_ids)
if if_bad:
if_bad_second = True
break
# Bail out early if no bad tokens appear in the chat template or the dataset
if not if_bad_first and not if_bad_second:
return False
# Count all the possible bad tokens
final_counts = np.zeros(
max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
)
def mapping(examples):
input_ids = examples["input_ids"]
counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32)
np.add.at(final_counts, counter, 1)
train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
# Get sum of all items
sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
# Remove bad tokens
sum_embedding -= torch.sum(
embedding_matrix[where_untrained], dtype=torch.float32, axis=0
)
sum_lm_head -= torch.sum(
lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
)
# Find correct average by dividing by sum of trained tokens
mean_embedding = sum_embedding / n_trained
mean_lm_head = sum_lm_head / n_trained
# Scale each mean by the token's observed frequency relative to the most frequent token; tokens never seen get 0
scaling = final_counts[where_untrained] / max(final_counts.max(), 1)
scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
mean_embedding = (
mean_embedding.repeat(
(
n_untrained,
1,
)
)
* scaling
)
mean_lm_head = (
mean_lm_head.repeat(
(
n_untrained,
1,
)
)
* scaling
)
where_null = scaling.ravel() == 0
mean_embedding[where_null] = 0
mean_lm_head[where_null] = 0
# Set them to the mean
embedding_matrix[where_untrained] = mean_embedding.to(embedding_matrix.dtype)
lm_head_matrix[where_untrained] = mean_lm_head.to(lm_head_matrix.dtype)
# Clean up
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
return True
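A hedged usage sketch (the `model`, `tokenizer`, and `train_dataset` objects are assumed to be loaded elsewhere; only the call signature comes from the function above): the helper is intended to run once before training, on a tokenized dataset that carries an "input_ids" column.

# model: a causal LM with input/output embeddings; tokenizer: its tokenizer;
# train_dataset: a tokenized datasets.Dataset with an "input_ids" column.
patched = fix_untrained_tokens(model, tokenizer, train_dataset, eps=1e-16)
if patched:
    print("Re-initialized untrained token embeddings to a frequency-scaled mean")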

View File

@@ -4,25 +4,21 @@ Builder for the training args and trainer
"""
import abc
import gc
import importlib
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union
from typing import Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -32,20 +28,12 @@ from transformers import (
TrainerCallback,
TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import (
CPOConfig,
CPOTrainer,
DPOConfig,
DPOTrainer,
KTOConfig,
KTOTrainer,
ORPOConfig,
ORPOTrainer,
)
from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
from trl.trainer.utils import pad_to_length
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_mlflow_available
@@ -61,14 +49,12 @@ from axolotl.utils.callbacks import (
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.chat_templates import chat_templates
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.models import ensure_dtype
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
@@ -246,16 +232,6 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
@dataclass
@@ -289,24 +265,92 @@ class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
"""
@dataclass
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
class AxolotlTrainer(Trainer):
"""
CPO config for CPO training
"""
simpo_gamma: Optional[float] = field(
default=None,
metadata={"help": "simpo gamma parameter"},
)
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
tag_names = ["axolotl"]
def __init__(
self,
*_args,
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer != "optimi_adamw"
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", None
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
@@ -332,23 +376,7 @@ class SchedulerMixin(Trainer):
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif use_cosine_quadratic:
if use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
@@ -386,125 +414,6 @@ class SchedulerMixin(Trainer):
return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
tag_names = ["axolotl"]
def __init__(
self,
*_args,
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs,
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
@@ -512,10 +421,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
batch_max_len = (
self.args.per_device_train_batch_size * self.args.max_seq_length
)
batch_max_len = train_batch_size * self.args.max_seq_length
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset),
@@ -870,14 +778,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, metrics=None):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, metrics=metrics)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -907,6 +807,37 @@ class AxolotlMambaTrainer(AxolotlTrainer):
return lm_loss
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
tag_names = ["axolotl", "onecycle"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the ReLoRA scheduler
@@ -946,7 +877,7 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
class AxolotlDPOTrainer(DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
@@ -975,9 +906,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if is_sagemaker_mp_enabled():
@@ -1006,16 +937,8 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
res[key] = res[key][1:]
return res
def training_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs)
gc.collect()
torch.cuda.empty_cache()
return loss
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
class AxolotlORPOTrainer(ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
@@ -1023,7 +946,7 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
class AxolotlKTOTrainer(KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
@@ -1031,14 +954,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
@@ -1049,11 +964,10 @@ class TrainerBuilderBase(abc.ABC):
_model_ref = None
_peft_config = None
def __init__(self, cfg, model, tokenizer, processor=None):
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
self.processor = processor
# in case the model supports tagging, add the axolotl tag.
# This makes sure the tag is correctly pushed even if a user calls
@@ -1199,6 +1113,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def _get_trainer_cls(self):
if self.cfg.lr_scheduler == "one_cycle" and (
self.cfg.fsdp or self.cfg.adapter == "qlora"
):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
@@ -1248,9 +1166,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = {
k.lstrip("fsdp_"): v for k, v in dict(self.cfg.fsdp_config).items()
}
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True
@@ -1359,10 +1275,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
if self.cfg.torch_compile_mode:
training_arguments_kwargs[
"torch_compile_mode"
] = self.cfg.torch_compile_mode
# DDP Config
if self.cfg.ddp_timeout:
@@ -1387,10 +1299,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs[
"auto_find_batch_size"
] = self.cfg.auto_find_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
@@ -1424,8 +1332,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
report_to = []
if self.cfg.use_wandb:
report_to.append("wandb")
if self.cfg.wandb_name:
training_arguments_kwargs["run_name"] = self.cfg.wandb_name
if self.cfg.use_mlflow:
report_to.append("mlflow")
if self.cfg.use_tensorboard:
@@ -1454,15 +1360,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
] = self.cfg.lr_scheduler
else:
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
)
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine"
)
training_arguments_kwargs["lr_scheduler_kwargs"] = (
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
@@ -1475,9 +1378,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs["sample_packing"] = bool(self.cfg.sample_packing)
training_arguments_kwargs["multipack_real_batches"] = (
not self.cfg.flash_attention or self.cfg.multipack_real_batches
)
training_arguments_kwargs[
"multipack_real_batches"
] = not self.cfg.flash_attention
training_arguments_kwargs["eval_sample_packing"] = bool(
self.cfg.eval_sample_packing
)
@@ -1522,10 +1425,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = chat_templates(
self.cfg.chat_template
)
if self.cfg.rl == "orpo":
training_arguments_kwargs["orpo_alpha"] = self.cfg.orpo_alpha
@@ -1537,12 +1436,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.optimizer in [
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
]:
if self.cfg.optimizer == "optimi_adamw":
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
@@ -1575,11 +1469,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"
] = self.cfg.accelerator_config
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
@@ -1587,12 +1476,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_args = self.hook_post_create_training_args(training_args)
# unset run_name so wandb sets up experiment names
if self.cfg.use_wandb and training_args.run_name == training_args.output_dir:
training_args.run_name = ( # pylint: disable=attribute-defined-outside-init
None
)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
@@ -1672,12 +1555,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
if self.cfg.processor_type and self.processor:
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
else:
collator = DataCollatorForSeq2Seq
collator = DataCollatorForSeq2Seq
return collator(
self.tokenizer,
@@ -1784,27 +1662,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["simpo_gamma"] = self.cfg.simpo_gamma
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
@@ -1812,6 +1679,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
training_args_kwargs["desirable_weight"] = (
self.cfg.kto_desirable_weight or 1.0
)
@@ -1857,6 +1725,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
@@ -1864,15 +1733,14 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = True
if self.cfg.rl == "dpo":
dpo_trainer_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl in ["kto"]:
trainer_cls = AxolotlKTOTrainer
trainer_cls_args = [self.model]
elif self.cfg.rl in ["simpo"]:
trainer_cls = AxolotlCPOTrainer
trainer_cls_args = [self.model]
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
dpo_trainer = trainer_cls(
@@ -1885,8 +1753,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
)
if self.cfg.fsdp:
ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
if self.cfg.rl in ["dpo", "ipo"] and dpo_trainer.ref_model:
ensure_dtype(dpo_trainer.ref_model, dtype=self.cfg.torch_dtype)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):

View File

@@ -1,58 +0,0 @@
### AXOLOTL COMMUNITY LICENSE AGREEMENT
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
and conditions set forth in this Agreement.
1. Definitions
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
which may be licensed separately by their respective authors and/or licensors.
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
permits Plugin Integrations to integrate with the Axolotl service.
2. Grant of License
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
- Licensee must comply with all the terms and conditions of this Agreement.
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
portions of the Software.
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
3. Restrictions
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
third parties to fine-tune artificial intelligence models.
3.2 Licensee shall not:
- Use the Software for any illegal or unauthorized purpose.
- Reverse engineer, decompile, or disassemble the Software.
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
Software or interfere with any third-party use of the Software.
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
4. Intellectual Property Rights
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
Licensee.
5. Disclaimer of Warranty
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
6. Termination
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
copies in its possession.
7. Governing Law
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
without regard to conflicts of laws provisions thereof.
8. Entire Agreement
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
Licensee's continued use of the Software after any such updates shall constitute acceptance of updated terms
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
bound by the terms and conditions of this Agreement.
This Agreement was last updated on August 23, 2024.

View File

@@ -1,383 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
Base class for all plugins.
A plugin is a reusable, modular, and self-contained piece of code that extends the functionality of Axolotl.
Plugins can be used to integrate third-party models, modify the training process, or add new features.
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import importlib
import logging
from typing import List
class BasePlugin:
"""
Base class for all plugins. Defines the interface for plugin methods.
Attributes:
None
Methods:
register(cfg): Registers the plugin with the given configuration.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_load(cfg, model): Performs actions after the model is loaded.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
add_callbacks_post_trainer(cfg, trainer): Adds callbacks to the trainer after training.
"""
def __init__(self):
"""
Initializes the BasePlugin.
"""
def register(self, cfg):
"""
Registers the plugin with the given configuration.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def get_input_args(self):
"""
Returns the dotted import path of a pydantic model describing the plugin's input arguments.
"""
def pre_model_load(self, cfg):
"""
Performs actions before the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
Returns:
None
"""
def post_model_load(self, cfg, model):
"""
Performs actions after the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def pre_lora_load(self, cfg, model):
"""
Performs actions before LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def post_lora_load(self, cfg, model):
"""
Performs actions after LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
"""
def create_optimizer(self, cfg, trainer):
"""
Creates and returns an optimizer for training.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
"""
def create_lr_scheduler(self, cfg, trainer, optimizer):
"""
Creates and returns a learning rate scheduler.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
Returns:
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model):
"""
Adds callbacks to the trainer before training.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Adds callbacks to the trainer after training.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
def load_plugin(plugin_name: str) -> BasePlugin:
"""
Loads a plugin based on the given plugin name.
The plugin name should be in the format "module_name.class_name".
This function splits the plugin name into module and class, imports the module,
retrieves the class from the module, and creates an instance of the class.
Parameters:
plugin_name (str): The name of the plugin to be loaded. The name should be in the format "module_name.class_name".
Returns:
BasePlugin: An instance of the loaded plugin.
Raises:
ImportError: If the plugin module cannot be imported.
"""
# split the plugin name into module and class
module_name, class_name = plugin_name.rsplit(".", 1)
# import the module
module = importlib.import_module(module_name)
# retrieve the class from the module
plugin_class = getattr(module, class_name)
# create an instance of the class
plugin = plugin_class()
return plugin
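An illustrative call, assuming a hypothetical plugin class importable as `my_pkg.plugins.MyPlugin`:

plugin = load_plugin("my_pkg.plugins.MyPlugin")  # "module_name.class_name" format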
class PluginManager:
"""
The PluginManager class is responsible for loading and managing plugins.
It should be a singleton so it can be accessed from anywhere in the codebase.
Attributes:
plugins (List[BasePlugin]): A list of loaded plugins.
Methods:
get_instance(): Static method to get the singleton instance of PluginManager.
register(plugin_name: str): Registers a new plugin by its name.
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: List[BasePlugin] = []
_instance = None
def __new__(cls):
"""
Creates a new instance of PluginManager if it doesn't exist yet.
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: List[BasePlugin] = []
return cls._instance
@staticmethod
def get_instance() -> "PluginManager":
"""
Returns the singleton instance of PluginManager.
If the instance doesn't exist, it creates a new one.
"""
if PluginManager._instance is None:
PluginManager()
return PluginManager._instance # type: ignore
def register(self, plugin_name: str):
"""
Registers a new plugin by its name.
Parameters:
plugin_name (str): The name of the plugin to be registered.
Returns:
None
Raises:
ImportError: If the plugin module cannot be imported.
"""
try:
plugin = load_plugin(plugin_name)
self.plugins.append(plugin)
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
def get_input_args(self):
"""
Returns the input-argument definitions of all registered plugins.
Returns:
list[str]: dotted "module_name.ClassName" import paths of the plugins' pydantic input-argument models.
"""
input_args = []
for plugin in self.plugins:
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
return input_args
def pre_model_load(self, cfg):
"""
Calls the pre_model_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
None
"""
for plugin in self.plugins:
plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model):
"""
Calls the post_model_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
"""
Calls the pre_lora_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
"""
Calls the post_lora_load method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
None
"""
for plugin in self.plugins:
plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins:
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
return None
def create_lr_scheduler(self, cfg, trainer, optimizer):
"""
Calls the create_lr_scheduler method of all registered plugins and returns the first non-None scheduler.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins:
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
return None
def add_callbacks_pre_trainer(self, cfg, model):
"""
Calls the add_callbacks_pre_trainer method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
model (object): The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks
def add_callbacks_post_trainer(self, cfg, trainer):
"""
Calls the add_callbacks_post_trainer method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
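To make the plugin lifecycle concrete, a minimal hedged sketch of a custom plugin and its registration (the `my_pkg.plugins.LoggingPlugin` dotted path is hypothetical; only hooks defined on BasePlugin above are overridden, and the cfg accessor assumes a dict-like config):

import logging

from axolotl.integrations.base import BasePlugin, PluginManager


class LoggingPlugin(BasePlugin):
    """Toy plugin that logs around model load and adds no callbacks."""

    def pre_model_load(self, cfg):
        logging.info("about to load base model: %s", cfg.get("base_model"))

    def add_callbacks_post_trainer(self, cfg, trainer):
        return []


# Registration goes through the singleton manager, by dotted path.
PluginManager.get_instance().register("my_pkg.plugins.LoggingPlugin")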

View File

@@ -1,65 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
module to handle merging the plugins' input arguments with the base configurations.
this was moved here to prevent circular imports
"""
from typing import Any, Dict, List
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
)
def merge_input_args():
"""
Merges input arguments from registered plugins with the base configurations.
This function retrieves the input arguments from registered plugins using the PluginManager.
It then dynamically creates new classes, AxolotlConfigWCapabilities and AxolotlInputConfig,
that inherit from the base configurations and include the input arguments from the plugins.
Returns:
tuple: A tuple containing the newly created classes, AxolotlConfigWCapabilities and AxolotlInputConfig.
"""
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = []
dynamic_input = ""
for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, globals(), namespace
)
AxolotlInputConfig = namespace[ # pylint: disable=invalid-name
"AxolotlInputConfig"
]
AxolotlConfigWCapabilities = namespace[ # pylint: disable=invalid-name
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
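Downstream usage is then a single call; a hedged sketch assuming any plugins have already been registered with the PluginManager:

# Returns the base config classes when no plugins are registered, otherwise
# dynamically built subclasses that also include each plugin's input args.
AxolotlConfigWCapabilities, AxolotlInputConfig = merge_input_args()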

View File

@@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,189 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for the LIGER integration plugin for Axolotl.
Liger Kernel is a collection of Triton-native kernels for LLM training.
It is designed to be performant, correct, and light-weight.
"""
import logging
import sys
from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
class LigerPlugin(BasePlugin):
"""
Plugin for LIGER integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
if cfg.model_config_type == "llama":
from liger_kernel.transformers.model.llama import (
lce_forward as llama_lce_forward,
)
from transformers.models.llama import modeling_llama
if cfg.liger_rope:
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_llama.LlamaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_llama.LlamaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
elif cfg.liger_fused_linear_cross_entropy:
modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
elif cfg.model_config_type == "mistral":
from liger_kernel.transformers.model.mistral import (
lce_forward as mistral_lce_forward,
)
from transformers.models.mistral import modeling_mistral
if cfg.liger_rope:
modeling_mistral.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_mistral.MistralRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mistral.MistralMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward
elif cfg.model_config_type == "gemma":
from liger_kernel.transformers.model.gemma import (
lce_forward as gemma_lce_forward,
)
from transformers.models.gemma import modeling_gemma
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba
from .models.jamba import lce_forward as jamba_lce_forward
if cfg.liger_rope:
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
elif cfg.model_config_type == "qwen2":
from liger_kernel.transformers.model.qwen2 import (
lce_forward as qwen2_lce_forward,
)
from transformers.models.qwen2 import modeling_qwen2
if cfg.liger_rope:
modeling_qwen2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_qwen2.Qwen2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_qwen2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
elif cfg.model_config_type == "deepseek_v2":
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(
cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
)
modeling_mod = sys.modules[model.__class__.__module__]
from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward
if cfg.liger_rope:
# The DeepseekV2 version of RoPE is different from upstream LLaMA's.
# See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
logging.warning("Fused liger_rope is not supported for DeepseekV2.")
if cfg.liger_rms_norm:
modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cfg.liger_cross_entropy:
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type == "gemma2":
from transformers.models.gemma2 import modeling_gemma2
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_gemma2.Gemma2RMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
logging.warning(
"Fused linear cross entropy is not supported for Gemma 2."
)
elif cfg.model_config_type == "phi3":
from liger_kernel.transformers.model.phi3 import (
lce_forward as phi3_lce_forward,
)
from transformers.models.phi3 import modeling_phi3
if cfg.liger_rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm
if cfg.liger_swiglu:
modeling_phi3.Phi3MLP = LigerSwiGLUMLP
if cfg.liger_cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
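Every branch above follows the same pattern: swap the stock `transformers` symbols for Liger's Triton kernels before the model is instantiated, so any model built afterwards picks them up. A minimal standalone sketch of that pattern for Llama, using only imports that appear in the plugin (whether your installed `liger-kernel`/`transformers` versions expose these exact symbols is an assumption):

```python
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from transformers.models.llama import modeling_llama

# Patch module-level symbols; this must run before LlamaForCausalLM is constructed.
modeling_llama.apply_rotary_pos_emb = liger_rotary_pos_emb  # Triton RoPE
modeling_llama.LlamaRMSNorm = LigerRMSNorm                  # Triton RMSNorm
modeling_llama.LlamaMLP = LigerSwiGLUMLP                    # Triton SwiGLU MLP
```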

View File

@@ -1,32 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for handling LIGER input arguments.
"""
from typing import Optional
from pydantic import BaseModel
class LigerArgs(BaseModel):
"""
Input args for LIGER.
"""
liger_rope: Optional[bool] = None
liger_rms_norm: Optional[bool] = None
liger_swiglu: Optional[bool] = None
liger_cross_entropy: Optional[bool] = None
liger_fused_linear_cross_entropy: Optional[bool] = None

View File

@@ -1,127 +0,0 @@
"""
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
# @replace_return_docstrings(
# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
# )
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM
>>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
loss = None
logits = None
if self.training:
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# flatten tokens
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
shift_labels = shift_labels.view(-1)
lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
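The training branch above never materializes the full `[tokens, vocab]` logits tensor: it hands the `lm_head` weight and the shifted hidden states straight to the fused kernel. A minimal sketch of that call in isolation (shapes are illustrative; the Triton kernel assumes a CUDA device):

```python
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)

hidden_size, vocab_size, num_tokens = 64, 1024, 8
hidden = torch.randn(num_tokens, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")
lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", requires_grad=True)

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(lm_head_weight, hidden, labels)  # logits never exist as a full tensor
loss.backward()
```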

View File

@@ -1,173 +0,0 @@
"""
Jamba model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union
import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.jamba.modeling_jamba import (
_CONFIG_FOR_DOC,
JAMBA_INPUTS_DOCSTRING,
HybridMambaAttentionDynamicCache,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: Optional[Union[int, None]] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int` or `None`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
`input_ids`. Only last token logits are needed for generation, and calculating them only for that token
can save memory, which becomes pretty significant for long sequences.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, JambaForCausalLM
>>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
return_dict=return_dict,
)
hidden_states = outputs[0]
loss = None
logits = None
if self.training:
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# flatten tokens
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
shift_labels = shift_labels.view(-1)
lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
if num_logits_to_keep is None:
logits = self.lm_head(hidden_states)
else:
logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
logits = logits.float()
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits if return_dict else outputs[-1],
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(
loss.device
) # make sure to reside in the same device
if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output
return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
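Relative to the DeepseekV2 variant above, the Jamba forward adds one MoE-specific step: the router load-balancing loss is scaled by `router_aux_loss_coef` and folded into the token loss before backprop. A minimal numeric sketch of that final step (values are illustrative):

```python
token_loss = 1.8432            # fused linear cross-entropy over the shifted tokens
aux_loss = 0.0125              # load_balancing_loss_func over the router logits
router_aux_loss_coef = 0.001   # taken from the Jamba config

loss = token_loss + router_aux_loss_coef * aux_loss  # the value that is backpropagated
```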

View File

@@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -1,21 +0,0 @@
## Spectrum: Targeted Training on Signal to Noise Ratio
by Eric Hartford, Lucas Atkins, Fernando Fernandes, David Golchinfar
This plugin contains code to freeze the bottom fraction of modules in a model, based on the Signal-to-Noise Ratio (SNR).
### Overview
Spectrum is a tool for scanning and evaluating the Signal-to-Noise Ratio (SNR) of layers in large language models.
By identifying the top n% of layers with the highest SNR, you can optimize training efficiency.
### Usage
```yaml
plugins:
- axolotl.integrations.spectrum.SpectrumPlugin
spectrum_top_fraction: 0.5
# Optional if using a pre-scanned model as your base_model. Useful if using a model mirror
spectrum_model_name: meta-llama/Meta-Llama-3.1-8B
```
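The selection itself is simple: group layers by type, rank each group by SNR, keep the top fraction, and always leave `lm_head` and the embeddings trainable (the plugin implementation follows below). An illustrative sketch of the data it consumes, assuming SNR results shaped like the published `snr_results_*.json` files:

```python
snr_data = {  # layer names and SNR values are illustrative
    "model.layers.0.mlp.down_proj": {"type": "mlp.down_proj", "snr": 1.2},
    "model.layers.1.mlp.down_proj": {"type": "mlp.down_proj", "snr": 3.4},
    "model.layers.0.self_attn.q_proj": {"type": "self_attn.q_proj", "snr": 0.7},
    "model.layers.1.self_attn.q_proj": {"type": "self_attn.q_proj", "snr": 2.1},
}
# With spectrum_top_fraction: 0.5 the plugin keeps the higher-SNR half of each layer
# type (here the two layer-1 projections), adds ^lm_head.weight$ and
# ^model.embed_tokens.weight$, and writes the list to cfg["unfrozen_parameters"].
```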

View File

@@ -1,102 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
"""
import json
import logging
import requests
from axolotl.integrations.base import BasePlugin
from .args import SpectrumArgs # pylint: disable=unused-import. # noqa: F401
def _generate_unfrozen_params_yaml(snr_data, top_fraction=0.5):
unfrozen_parameters = {}
for layer_name, info in snr_data.items():
layer_type = info["type"]
if layer_type not in unfrozen_parameters:
unfrozen_parameters[layer_type] = []
unfrozen_parameters[layer_type].append((layer_name, info["snr"]))
top_layers_by_type = {}
for layer_type, layers in unfrozen_parameters.items():
layers_sorted = sorted(layers, key=lambda x: x[1], reverse=True)
num_top_layers = int(len(layers) * top_fraction)
top_layers_by_type[layer_type] = [
layer[0] for layer in layers_sorted[:num_top_layers]
]
unfrozen_parameters = [
"^lm_head.weight$",
"^model.embed_tokens.weight$",
]
for layer_type, layer_names in top_layers_by_type.items():
for layer_name in layer_names:
unfrozen_parameters.append(layer_name)
return unfrozen_parameters
class SpectrumPlugin(BasePlugin):
"""
Spectrum Plugin to automatically generate unfrozen parameters based on SNR data.
"""
base_url = "https://raw.githubusercontent.com/cognitivecomputations/spectrum/main/model_snr_results/"
base_path = "./model_snr_results/"
snr_file_template = "snr_results_{model_name_slug}.json"
def get_input_args(self):
return "axolotl.integrations.spectrum.SpectrumArgs"
def pre_model_load(self, cfg):
if cfg.get("spectrum_model_name"):
model_name = cfg["spectrum_model_name"]
else:
model_name = cfg["base_model"]
top_fraction = cfg.get("spectrum_top_fraction", 0.5)
model_slug = model_name.replace("/", "-").replace("_", "-")
snr_url = self.base_url + self.snr_file_template.format(
model_name_slug=model_slug
)
snr_path = self.base_path + self.snr_file_template.format(
model_name_slug=model_slug
)
# first check if the files exist locally and read the json
snr_data = None
try:
with open(snr_path, "r", encoding="utf-8") as fin:
snr_data = json.load(fin)
except FileNotFoundError:
pass
except Exception as exc: # pylint: disable=broad-exception-caught
logging.warning(f"Failed to read SNR data from {snr_path}: {exc}")
if not snr_data:
try:
snr_data = requests.get(snr_url, timeout=60).json()
except requests.exceptions.RequestException as exc:
logging.warning(f"Failed to fetch SNR data from {snr_url}: {exc}")
return
# also catch json parsing errors
except json.JSONDecodeError as exc:
logging.warning(f"Failed to parse SNR data from {snr_url}: {exc}")
return
unfrozen_parameters = _generate_unfrozen_params_yaml(
snr_data, top_fraction=top_fraction
)
cfg["unfrozen_parameters"] = unfrozen_parameters

View File

@@ -1,29 +0,0 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module for handling Spectrum input arguments.
"""
from typing import Optional
from pydantic import BaseModel
class SpectrumArgs(BaseModel):
"""
Input args for Spectrum.
"""
spectrum_top_fraction: Optional[float] = 0.5
spectrum_model_name: Optional[str] = None

133
src/axolotl/loraplus.py Normal file
View File

@@ -0,0 +1,133 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer
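A minimal sketch of calling the helper above from user code; the small GPT-2 + LoRA setup and the hyperparameter values are illustrative assumptions, only the `axolotl.loraplus` import follows the file path shown:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

from axolotl.loraplus import create_loraplus_optimizer

base = AutoModelForCausalLM.from_pretrained("gpt2")  # small model for illustration
model = get_peft_model(base, LoraConfig(r=8, target_modules=["c_attn"]))

optimizer = create_loraplus_optimizer(
    model,
    torch.optim.AdamW,
    {"lr": 2e-4, "weight_decay": 0.01},
    loraplus_lr_ratio=16.0,      # lora_B groups train at 16x the base lr
    loraplus_lr_embedding=1e-6,  # matches the default used above
)
```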

View File

@@ -1,229 +0,0 @@
"""
Monkeypatch for Vision Llama for FA2 support
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple
import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.mllama.configuration_mllama import MllamaTextConfig
from transformers.models.mllama.modeling_mllama import (
MllamaTextCrossAttention,
MllamaTextSelfAttention,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import is_flash_attn_greater_or_equal_2_10
class MllamaTextCrossFlashAttention2(MllamaTextCrossAttention):
"""
Mllama flash cross-attention module. This module inherits from `MllamaTextCrossAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: Optional[ # pylint: disable=unused-argument
torch.Tensor
] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
query_states = self.q_norm(query_states)
if cross_attention_states is not None:
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, -1, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
key_states = self.k_norm(key_states)
if past_key_value is not None:
key_states, value_states = past_key_value.update(
key_states,
value_states,
self.layer_idx,
{"cache_position": cache_position},
)
elif cache_position[0] != 0:
key_states, value_states = (
past_key_value.key_cache[self.layer_idx],
past_key_value.value_cache[self.layer_idx],
)
else:
raise ValueError(
"Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# Apply Flash Attention
dropout_rate = self.dropout if self.training else 0.0
output = flash_attn_func(
query_states,
key_states,
value_states,
dropout_p=dropout_rate,
softmax_scale=None,
causal=False,
return_attn_probs=output_attentions,
)
attn_output = output.contiguous().view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class MllamaTextSelfFlashAttention2(MllamaTextSelfAttention):
"""
Mllama flash self-attention module. This module inherits from `MllamaTextSelfAttention` and
implements the forward pass using Flash Attention for improved performance.
"""
def __init__(self, config: MllamaTextConfig, layer_idx: int, *args, **kwargs):
super().__init__(config, layer_idx, *args, **kwargs)
# Check if flash attention version is greater or equal to 2.1
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
past_key_value=None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Transpose to get the expected layout for flash attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
dropout_rate = self.dropout if self.training else 0.0
# Handle potential silent casting to float32
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = (
self.config._pre_quantization_dtype # pylint: disable=protected-access
)
else:
target_dtype = self.q_proj.weight.dtype
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=True,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def patch_mllama():
from transformers.models.mllama.modeling_mllama import (
MLLAMA_TEXT_ATTENTION_CLASSES,
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES,
MLLAMA_VISION_ATTENTION_CLASSES,
MllamaPreTrainedModel,
)
MllamaPreTrainedModel._supports_flash_attn_2 = ( # pylint: disable=protected-access
True
)
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
"flash_attention_2"
] = MllamaTextCrossFlashAttention2
# fallback to SDPA
MLLAMA_VISION_ATTENTION_CLASSES[
"flash_attention_2"
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
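A minimal sketch of how the patch would be applied before loading a Mllama checkpoint; the module path for `patch_mllama`, the checkpoint id, and the dtype are illustrative assumptions, and flash-attn 2 must be installed:

```python
import torch
from transformers import MllamaForConditionalGeneration

from axolotl.monkeypatch.attention.mllama import patch_mllama  # import path assumed

patch_mllama()  # register the FA2 attention classes before the model is built
model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```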

View File

@@ -78,33 +78,6 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv)
def patch_llama_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
def patch_llama_rms_norm():
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
@@ -131,11 +104,30 @@ def replace_llama_attn_with_flash_attn(
# skip only if explicitly disabled
if cross_entropy:
patch_llama_cross_entropy()
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
# skip only if explicitly disabled
if rms_norm:
patch_llama_rms_norm()
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
class FusedAttention(LlamaAttention):

View File

@@ -9,18 +9,18 @@ from axolotl.monkeypatch.utils import (
def hijack_llama_prepare_4d_mask():
from transformers import modeling_attn_mask_utils
from transformers.models.llama import modeling_llama
import transformers.modeling_attn_mask_utils
import transformers.models.llama.modeling_llama
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)

View File

@@ -2,7 +2,6 @@
# pylint: disable=duplicate-code
import logging
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
@@ -46,15 +45,6 @@ def replace_mistral_attn_with_flash_attn(
)
def patch_mistral_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,

View File

@@ -10,15 +10,11 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mllama_text_model",
"llama",
"mistral",
"mixtral",
"qwen2",
"qwen2_moe",
"falcon",
"phi",
"phi3",
"gemma",
"gemma2",
"gemmoe",
@@ -27,36 +23,13 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
]
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return
# retain for legacy
def patch_for_multipack(model_type, model_name=None):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "llama":
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "mistral":
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
transformers.models.mistral.modeling_mistral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
@@ -85,6 +58,12 @@ def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "jamba":
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
def patch_remote(model_name, config_name, modeling_name):

View File

@@ -16,7 +16,6 @@
# This code is based off the following work:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """
import importlib
import math

View File

@@ -1,22 +1,21 @@
"""module for patching with unsloth optimizations"""
import inspect
import logging
import re
import types
from typing import Tuple
import torch
from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
LOG = get_logger("axolotl.monkeypatch.unsloth")
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
ORIGINAL_CEL_CODE = """ if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
@@ -28,7 +27,8 @@ ORIGINAL_CEL_CODE = """# Shift so that tokens < n predict n
loss = loss_fct(shift_logits, shift_labels)
"""
PATCHED_CEL_CODE = """shift_logits = logits[..., :-1, :].contiguous()
PATCHED_CEL_CODE = """ if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = fast_cross_entropy_loss(
logits = shift_logits,
@@ -97,51 +97,48 @@ def check_self_attn_is_patchable() -> bool:
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
if model_type == "llama":
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
def integrate_cross_entropy_loss_patch():
forward = get_forward_code()
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
)
forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"",
)
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace(
"def forward(",
"def fast_cross_entropy_loss_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(),
)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
else:
raise ValueError("Unsupported model type")
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
print("patching unsloth fast_cross_entropy_loss")
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
def detab_code(code: str) -> Tuple[str, str]:
@@ -182,30 +179,12 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth attn lora", main_process_only=True)
print("patching unsloth attn lora")
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
def integrate_rope_embeddings():
import transformers.models.llama.modeling_llama
from unsloth.kernels.rope_embedding import fast_rope_embedding
def apply_rotary_pos_emb( # pylint: disable=unused-argument
q, # pylint: disable=invalid-name
k, # pylint: disable=invalid-name
cos,
sin,
position_ids=None,
unsqueeze_dim=1,
):
return fast_rope_embedding(q, k, cos, sin)
LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
from unsloth.kernels import apply_lora_mlp_swiglu
@@ -238,7 +217,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else:
LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -264,7 +243,9 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_qkv = apply_lora_qkv
else:
layer.self_attn.apply_qkv = original_apply_qkv
LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
logging.warning(
"unable to apply unsloth lora qkv patch to layer %d", idx
)
if cfg.unsloth_lora_o:
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -283,33 +264,6 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_o = apply_lora_o
else:
layer.self_attn.apply_o = original_apply_o
LOG.warning(
logging.warning(
"unable to apply unsloth lora o_proj patch to layer %d", idx
)
def patch_unsloth_layernorm():
try:
import transformers.models.llama.modeling_llama
from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
class LlamaRMSNorm(nn.Module):
"""LlamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
return Fast_RMS_Layernorm.apply(
hidden_states, self.weight, self.variance_epsilon, False
)
LOG.info("patching with unsloth.kernels.rms_layernorm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning("missing unsloth library")

View File

@@ -17,9 +17,11 @@ def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
max_num = int(torch.max(attention_mask).item())
batch_size, _ = attention_mask.shape
counts = torch.zeros((batch_size, max_num), dtype=torch.int32)
for i in range(1, max_num + 1):
mask = attention_mask == i
counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)
result = counts.flatten()
nonzero_indices = torch.nonzero(result).squeeze(-1)
return result[nonzero_indices]
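# A worked sketch of the counting above, assuming a multipack attention mask where each
# packed sequence in a row is labelled 1, 2, 3, ... and 0 marks padding.
import torch

attention_mask = torch.tensor([[1, 1, 1, 2, 2, 0]], dtype=torch.int32)
max_num = int(attention_mask.max())
counts = torch.stack(
    [(attention_mask == i).sum(dim=-1) for i in range(1, max_num + 1)], dim=1
)
seqlens = counts.flatten()
seqlens = seqlens[seqlens.nonzero().squeeze(-1)]
# seqlens -> tensor([3, 2]): the per-sequence lengths that varlen attention kernels expect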

View File

@@ -9,7 +9,7 @@ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
LOG = logging.getLogger("axolotl.prompt_strategies")
def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
def load(strategy, tokenizer, cfg, ds_cfg):
try:
load_fn = "load"
if strategy.split(".")[-1].startswith("load_"):
@@ -24,8 +24,6 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor
return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None

View File

@@ -5,30 +5,23 @@ HF Chat Templates prompt strategy
import logging
from typing import Any, Dict, List, Optional
from transformers import ProcessorMixin
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
# Configure the logger
LOG = logging.getLogger("axolotl")
LOG.setLevel(logging.INFO)
class ChatTemplatePrompter(Prompter):
"""Prompter for HF chat templates"""
"""prompter for HF chat templates"""
def __init__(
self,
tokenizer,
processor=None,
chat_template=None,
max_length=2048,
message_field_role: str = "from",
message_field_content: str = "value",
message_field_training: Optional[str] = None,
message_field_training_detail: Optional[str] = None,
roles: Optional[Dict[str, List[str]]] = None,
drop_system_message: bool = False,
):
@@ -44,20 +37,16 @@ class ChatTemplatePrompter(Prompter):
}
self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.message_field_training = message_field_training
self.message_field_training_detail = message_field_training_detail
self.tokenizer = tokenizer
self.processor: ProcessorMixin = processor
self.chat_template = chat_template
self.max_length = max_length
self.drop_system_message = drop_system_message
def build_prompt(self, conversation, add_generation_prompt=False, images=None):
def build_prompt(self, conversation, add_generation_prompt=False):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
}
for t in conversation
]
@@ -65,28 +54,6 @@ class ChatTemplatePrompter(Prompter):
if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]
if self.processor:
text = self.processor.apply_chat_template(
turns,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
)
batch = self.processor(
text=text,
images=images,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
)
# workaround since processor works in batches instead of single examples
for k, val in batch.items():
if k in ["pixel_values"]:
batch[k] = val.tolist()
else:
batch[k] = val.squeeze().tolist()
return batch
return self.tokenizer.apply_chat_template(
turns,
truncation=True,
@@ -95,108 +62,6 @@ class ChatTemplatePrompter(Prompter):
chat_template=self.chat_template,
)
def get_offsets_for_train_detail(
self, text: str, train_details: List[Dict], mask_untrainable: bool = True
) -> List[int]:
tokenized_output = self.tokenizer(
text, return_offsets_mapping=True, add_special_tokens=False
)
tokens = tokenized_output.tokens()
token_offsets = tokenized_output["offset_mapping"]
LOG.debug(f"Tokenizing text: {text}")
LOG.debug(f"Tokens: {tokens}")
# Adjust the end offsets. For some reason by default they are set to the same value as the start offsets.
for i in range(len(token_offsets) - 1):
token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
# Ensure the last token's end offset is set correctly
token_offsets[-1] = (token_offsets[-1][0], len(text) - 1)
LOG.debug(f"Token offsets: {token_offsets}")
# Initialize all offsets as IGNORE_TOKEN_ID (not trained)
result = [IGNORE_TOKEN_ID] * len(token_offsets)
# Adjust train_details to align with token boundaries
adjusted_train_details = self.adjust_train_details(train_details, token_offsets)
for idx, (start, end) in enumerate(token_offsets):
for detail in adjusted_train_details:
# Check if the token is completely within the detail's range
if start >= detail["begin_offset"] and end <= detail["end_offset"]:
if detail["train"] or not mask_untrainable:
result[idx] = start
LOG.debug(f"Token {idx} ({tokens[idx]}) marked for training")
else:
LOG.debug(
f"Token {idx} ({tokens[idx]}) marked as non-trainable"
)
elif start < detail["end_offset"] and end > detail["begin_offset"]:
# Token partially overlaps with detail, always mark as non-trainable
LOG.debug(
f"Token {idx} ({tokens[idx]}) partially overlaps detail, marked as non-trainable"
)
LOG.debug(f"Final result: {result}")
return result
def adjust_train_details(
self, train_details: List[Dict], token_offsets: List[tuple]
) -> List[Dict]:
adjusted_details = []
for detail in train_details:
begin_offset = detail["begin_offset"]
end_offset = detail["end_offset"]
# Find the first token that starts after or at the begin_offset
begin_token = next(
(
i
for i, (t_start, t_end) in enumerate(token_offsets)
if t_start >= begin_offset
),
len(token_offsets),
)
if begin_token > 0 and token_offsets[begin_token - 1][1] > begin_offset:
begin_token -= 1
# Find the last token that ends before or at the end_offset
end_token = next(
(
i
for i in range(len(token_offsets) - 1, -1, -1)
if token_offsets[i][1] <= end_offset
),
-1,
)
if (
end_token < len(token_offsets) - 1
and token_offsets[end_token + 1][0] < end_offset
):
end_token += 1
if begin_token <= end_token:
adjusted_begin = token_offsets[begin_token][0]
adjusted_end = token_offsets[end_token][1]
if adjusted_begin != begin_offset or adjusted_end != end_offset:
LOG.warning(
f"Adjusting detail offsets: ({begin_offset}, {end_offset}) -> ({adjusted_begin}, {adjusted_end})"
)
adjusted_details.append(
{
"begin_offset": adjusted_begin,
"end_offset": adjusted_end,
"train": detail["train"],
}
)
else:
LOG.warning(
f"Could not adjust detail offsets: ({begin_offset}, {end_offset}). Skipping this detail."
)
return adjusted_details
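# A small sketch of the end-offset adjustment performed in get_offsets_for_train_detail
# above (the offsets and text length are hypothetical).
token_offsets = [(0, 0), (6, 6), (12, 12)]
text_len = 17
for i in range(len(token_offsets) - 1):
    token_offsets[i] = (token_offsets[i][0], token_offsets[i + 1][0] - 1)
token_offsets[-1] = (token_offsets[-1][0], text_len - 1)
# token_offsets -> [(0, 5), (6, 11), (12, 16)], so character-level train_details can be
# mapped onto whole tokens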
class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
@@ -205,20 +70,6 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
_messages = "conversations"
def __init__(
self,
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []
self.train_on_eos = train_on_eos
self.images = "images"
@property
def messages(self):
return self._messages
@@ -228,212 +79,62 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
self._messages = messages
def tokenize_prompt(self, prompt):
# Old simple legacy behavior that works reliably.
if (
not self.roles_to_train
and not self.train_on_eos
and not self.prompter.message_field_training
and not self.prompter.message_field_training_detail
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt(
turns[:-1],
add_generation_prompt=True,
images=images,
)
tokenized_res = self.prompter.build_prompt(turns, images=images)
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
tokenized_prompt["input_ids"] = input_ids
tokenized_prompt["attention_mask"] = [1] * len(input_ids)
else:
input_ids = tokenized_res["input_ids"]
tokenized_prompt = tokenized_res
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
tokenized_prompt["labels"] = labels
return tokenized_prompt
turns = prompt[self.messages]
turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt(turns[:-1], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1
for index, turn in enumerate(turns):
role = turn.get(self.prompter.message_field_role)
content = turn.get(self.prompter.message_field_content)
train_turn = turn.get(self.prompter.message_field_training)
train_detail = turn.get(self.prompter.message_field_training_detail)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
)
should_train = (
train_turn
if train_turn is not None
else (
bool(train_detail is not None)
if train_detail is not None
else self.train_on_inputs or role in self.roles_to_train
)
)
LOG.debug(f"Should train: {should_train}")
turn_start_idx, turn_end_idx = self.find_turn(
conversation_ids=input_ids, turn=index, turn_content=turn
)
LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
token_offsets = self.prompter.get_offsets_for_train_detail(
content, train_detail
)
LOG.debug(f"Token offsets: {token_offsets}")
for i, offset in enumerate(token_offsets):
if offset != IGNORE_TOKEN_ID and turn_start_idx + i < len(
input_ids
):
labels[turn_start_idx + i] = input_ids[turn_start_idx + i]
LOG.debug(
f"Label set at index {turn_start_idx + i}: {input_ids[turn_start_idx + i]}"
)
else:
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
LOG.debug(f"Labels after processing turn {index}: {labels}")
# Handle EOS token
eos_idx = self.find_eos_token(input_ids, turn_end_idx)
if eos_idx == turn_end_idx:
last_eos_idx = eos_idx
if self.train_on_eos == "all" or (
self.train_on_eos == "turn" and should_train
):
labels[eos_idx] = input_ids[eos_idx]
LOG.debug(f"EOS token set for training at index {eos_idx}")
else:
LOG.debug(
f"EOS token missing after turn {turn}. eos_idx: {eos_idx}, turn_end_idx: {turn_end_idx}"
)
# Handle 'last' option for train_on_eos
if self.train_on_eos == "last" and last_eos_idx != -1:
labels[last_eos_idx] = input_ids[last_eos_idx]
LOG.debug(f"Last EOS token set for training at index {last_eos_idx}")
LOG.debug(f"Final labels: {labels}")
return {
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
def find_eos_token(self, input_ids, start_idx):
eos_token_id = self.tokenizer.eos_token_id
for i in range(start_idx, len(input_ids)):
if input_ids[i] == eos_token_id:
return i
return -1
def find_turn(self, conversation_ids, turn, turn_content):
"""
Locate the starting and ending indices of the specified turn in a conversation.
Args:
conversation_ids (list[int]): Token IDs representing the conversation.
turn (int): The turn number to locate (based on EOS tokens).
turn_content (str): String containing the content of the turn.
Returns:
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
Returns (-1, -1) if the turn content is not found.
"""
content = turn_content.get(self.prompter.message_field_content, "")
content_ids = self.tokenizer.encode(content, add_special_tokens=False)
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0
# Locate the starting index after the specified number of EOS tokens
for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn:
start_search_idx = (
i + 1
) # Start searching after the specified turn's EOS token
break
# Find the start index of the content within the conversation
start_idx = -1
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
if conversation_ids[i : i + len(content_ids)] == content_ids:
start_idx = i
break
if start_idx != -1:
end_idx = start_idx + len(content_ids)
else:
end_idx = -1
return start_idx, end_idx
return tokenized_prompt
def get_conversation_thread(self, prompt):
return prompt[self.messages]
def get_images(self, prompt):
return prompt.get(self.images, None)
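# A standalone sketch of the EOS-count based turn search used by find_turn above
# (token ids are hypothetical; 2 plays the role of eos_token_id).
def find_turn_sketch(conversation_ids, turn, content_ids, eos_token_id=2):
    eos_count = 0
    start_search_idx = 0
    # skip past `turn` EOS tokens, then look for the turn's content ids
    for i, token_id in enumerate(conversation_ids):
        if token_id == eos_token_id:
            eos_count += 1
            if eos_count == turn:
                start_search_idx = i + 1
                break
    for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
        if conversation_ids[i : i + len(content_ids)] == content_ids:
            return i, i + len(content_ids)
    return -1, -1

# find_turn_sketch([5, 9, 2, 7, 8, 2], turn=1, content_ids=[7, 8]) -> (3, 5)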
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
ds_cfg = ds_cfg or {}
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")),
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", []),
"train_on_eos": ds_cfg.get("train_on_eos", None),
}
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
)
message_field_role = (
ds_cfg["message_field_role"]
if ds_cfg and "message_field_role" in ds_cfg
else "from"
)
message_field_content = (
ds_cfg["message_field_content"]
if ds_cfg and "message_field_content" in ds_cfg
else "value"
)
roles = ds_cfg["roles"] if ds_cfg and "roles" in ds_cfg else None
drop_system_message = (
ds_cfg["drop_system_message"]
if ds_cfg and "drop_system_message" in ds_cfg
else False
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
tokenizer,
chat_templates(chat_template),
message_field_role=message_field_role,
message_field_content=message_field_content,
roles=roles,
drop_system_message=drop_system_message,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy

View File

@@ -1,78 +0,0 @@
"""
DPO prompt strategies for using tokenizer chat templates.
"""
from axolotl.utils.chat_templates import chat_templates
def default(
cfg, dataset_idx=0, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
ds_cfg = cfg["datasets"][dataset_idx]
chat_template_str = chat_templates(cfg.chat_template)
field_messages = ds_cfg.get("field_messages", "messages")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
field_message_role = ds_cfg.get("message_field_role", "role")
field_message_content = ds_cfg.get("message_field_content", "content")
role_map_inv = ds_cfg.get(
"roles",
{
"user": ["user"],
"assistant": ["assistant"],
"system": ["system"],
},
)
role_map = {}
for target, sources in role_map_inv.items():
for source in sources:
role_map[source] = target
def transform_fn(sample, tokenizer=None):
messages = sample[field_messages]
messages = [
{
"role": role_map[m[field_message_role]],
"content": m[field_message_content],
}
for m in messages
]
chosen = {
"role": role_map[sample[field_chosen][field_message_role]],
"content": sample[field_chosen][field_message_content],
}
rejected = {
"role": role_map[sample[field_rejected][field_message_role]],
"content": sample[field_rejected][field_message_content],
}
result = {}
result["prompt"] = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=False,
)
result["chosen"] = tokenizer.apply_chat_template(
[chosen],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)
chosen_strip_index = result["chosen"].find(chosen["content"])
result["chosen"] = result["chosen"][chosen_strip_index:].rstrip()
result["rejected"] = tokenizer.apply_chat_template(
[rejected],
add_generation_prompt=False,
chat_template=chat_template_str,
tokenize=False,
)
rejected_strip_index = result["rejected"].find(rejected["content"])
result["rejected"] = result["rejected"][rejected_strip_index:].rstrip()
return result
return transform_fn
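# A worked sketch of the prefix stripping above (the rendered string is illustrative,
# not a real chat template output).
rendered = "<|im_start|>assistant\nHello there!<|im_end|>\n"
content = "Hello there!"
strip_index = rendered.find(content)
stripped = rendered[strip_index:].rstrip()
# stripped -> "Hello there!<|im_end|>", i.e. the completion without the role header,
# which is what the chosen/rejected fields should contain for DPO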

View File

@@ -65,10 +65,8 @@ class AlpacaPrompter(Prompter):
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
elif self.prompt_style == PromptStyle.PHI.value:
self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
self.turn_no_input_format = (
"<|user|>\n{instruction}<|end|>\n<|assistant|>\n"
)
self.system_format = "<|system|>\n{system}<|end|>\n"
self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
self.system_format = "<|system|>{system}\n"
def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
@@ -352,12 +350,9 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
"Please help us by creating an Issue to add support for this conversation type."
)
if self._conversation.name in ["llama3"]:
role = from_role
else:
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)
role = CONVERSATION_ROLE_FORMAT[self._conversation.name].format(
ROLE=from_role
)
if len(conv.messages) > 0 and ((role == conv.messages[-1][0])):
if (

View File

@@ -12,7 +12,6 @@ import torch
import transformers.modelcard
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
@@ -20,11 +19,10 @@ from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.core.tokenizer_utils import fix_untrained_tokens
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
try:
@@ -54,24 +52,12 @@ class TrainDatasetMeta:
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# enable expandable segments for cuda allocation to improve VRAM usage
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"
# load the tokenizer first
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
@@ -99,9 +85,7 @@ def train(
LOG.debug(msg)
# we wait until the last possible moment to set up Accelerator
Accelerator()
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model.generation_config.do_sample = True
model_ref = None
@@ -127,17 +111,9 @@ def train(
eval_dataset,
(model, model_ref, peft_config),
tokenizer,
processor,
total_num_steps,
)
if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
@@ -201,12 +177,9 @@ def train(
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
state_dict_type = "FULL_STATE_DICT"
if trainer.is_fsdp_enabled:
if cfg.fsdp_final_state_dict_type:
state_dict_type = cfg.fsdp_final_state_dict_type
trainer.accelerator.state.fsdp_plugin.set_state_dict_type(state_dict_type)
LOG.info(f"Set FSDP state dict type to {state_dict_type} for saving.")
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
@@ -218,38 +191,30 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
if (
state_dict_type == "SHARDED_STATE_DICT"
and cfg.fsdp_config.fsdp_state_dict_type == "SHARDED_STATE_DICT"
):
save_fsdp_model(
trainer.accelerator.state.fsdp_plugin,
trainer.accelerator,
trainer.model,
cfg.output_dir,
)
elif state_dict_type == "FULL_STATE_DICT":
trainer.save_model(cfg.output_dir)
trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
trainer.save_model(cfg.output_dir)
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
# the trainer saved a model.safetensors file in the output directory,
# but it is most likely a proxy model and if so, should be deleted
maybe_proxy = os.path.exists(os.path.join(cfg.output_dir, "model.safetensors"))
maybe_sharded = os.path.exists(
os.path.join(cfg.output_dir, "model.safetensors.index.json")
)
if maybe_proxy and maybe_sharded:
# but it is a proxy model and should be deleted
if os.path.exists(os.path.join(cfg.output_dir, "model.safetensors")):
LOG.info(f"Deleting {os.path.join(cfg.output_dir, 'model.safetensors')}")
LOG.info("This is a proxy model and should be deleted")
try:
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
except FileNotFoundError:
pass
os.remove(os.path.join(cfg.output_dir, "model.safetensors"))
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
model = BetterTransformer.reverse(model)

File diff suppressed because one or more lines are too long

View File

@@ -1,14 +1,17 @@
"""
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Sequence, Union
import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
IGNORE_INDEX = -100
@dataclass
class DataCollatorForSeq2Seq:
@@ -180,6 +183,34 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
return super().__call__(out_features, return_tensors=return_tensors)
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}
@dataclass
class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""

View File

@@ -1,10 +0,0 @@
"""
shared axolotl collators for multipack, mamba, multimodal
"""
from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from .mamba import MambaDataCollator # noqa: F401

View File

@@ -1,4 +0,0 @@
"""
basic shared collator constants
"""
IGNORE_INDEX = -100

View File

@@ -1,38 +0,0 @@
"""
collators for Mamba
"""
from dataclasses import dataclass
from typing import Dict, Sequence
import torch
import transformers
from axolotl.utils.collators.core import IGNORE_INDEX
@dataclass
class MambaDataCollator:
"""
Collator for State Space Models (Mamba)
"""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple(
[torch.LongTensor(instance[key]) for instance in instances]
for key in ("input_ids", "labels")
)
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id,
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
return {
"input_ids": input_ids,
"labels": labels,
}
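# A minimal sketch of the padding behaviour above, assuming pad_token_id=0.
import torch

input_ids = [torch.LongTensor([1, 2, 3]), torch.LongTensor([4, 5])]
padded = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
# padded -> tensor([[1, 2, 3], [4, 5, 0]]); labels are padded with IGNORE_INDEX (-100)
# instead, so the padded positions are ignored by the loss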

View File

@@ -1,77 +0,0 @@
"""
Collators for multi-modal chat messages and packing
"""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy
@dataclass
class MultiModalChatDataCollator(DataCollatorMixin):
"""
Collator for multi-modal chat messages
"""
tokenizer: PreTrainedTokenizerBase
processor: ProcessorMixin
return_tensors: str = "pt"
chat_template: Optional[str] = None
packing: bool = False
max_images: int = -1
padding: Union[bool, str, PaddingStrategy] = True
pad_to_multiple_of: Optional[int] = None
def __post_init__(self):
if self.packing:
raise ValueError("Packing is currently not supported.")
def torch_call(
self, examples: List[Union[List[int], Any, Dict[str, Any]]]
) -> Dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
return self.__class__.process_rows(
examples, self.processor, self.chat_template, self.max_images
)
@staticmethod
def process_rows(examples, processor, chat_template, max_images, length_only=False):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator
# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point
# Get the texts and images, and apply the chat template
texts = [
processor.apply_chat_template(
example["messages"], chat_template=chat_template, tokenize=False
)
for example in examples
]
images = [example["images"] for example in examples]
if max_images > 0:
images = [img_batch[:max_images] for img_batch in images]
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100 #
# Ignore the image token index in the loss computation (model specific)
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels[labels == image_token_id] = -100
batch["labels"] = labels
if length_only:
return {
"length": [len(sample["input_ids"]) for sample in batch["input_ids"]]
}
return batch
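# A small sketch of the label masking above (pad_token_id=0 and image_token_id=32000
# are assumed values for illustration).
import torch

input_ids = torch.tensor([[32000, 15, 16, 0, 0]])
labels = input_ids.clone()
labels[labels == 0] = -100       # ignore padding in the loss
labels[labels == 32000] = -100   # ignore the image token (model specific)
# labels -> tensor([[-100, 15, 16, -100, -100]])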

View File

@@ -8,14 +8,11 @@ from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
)
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlInputConfig as AxolotlInputConfigBase,
SUPPORTED_METRICS,
AxolotlConfigWCapabilities,
AxolotlInputConfig,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
@@ -121,36 +118,15 @@ def normalize_config(cfg):
cfg.base_model_config = cfg.base_model
model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
cfg.tokenizer_config = (
cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
)
cfg.is_multimodal = (
hasattr(model_config, "model_type")
and model_config.model_type in ["llava", "mllama"]
or any(
multimodal_name in cfg.base_model.lower()
for multimodal_name in [
"pixtral",
]
)
or cfg.is_multimodal
)
if cfg.is_multimodal:
cfg.processor_config = (
cfg.processor_config or cfg.base_model_config or cfg.base_model
)
model_config = model_config.text_config
cfg.model_config_type = model_config.model_type
# figure out if the model is llama
cfg.is_llama_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type == ["llama", "mllama_text_model"]
)
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
@@ -231,15 +207,6 @@ def normalize_cfg_datasets(cfg):
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
AxolotlConfigWCapabilities = AxolotlConfigWCapabilitiesBase
AxolotlInputConfig = AxolotlInputConfigBase
if cfg.plugins:
(
AxolotlConfigWCapabilities, # pylint: disable=invalid-name
AxolotlInputConfig, # pylint: disable=invalid-name
) = merge_input_args()
if capabilities:
return DictDefault(
dict(

View File

@@ -7,7 +7,6 @@ Module for pydantic models for configuration
import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
@@ -78,7 +77,6 @@ class PretrainingDataset(BaseModel):
split: Optional[str] = "train"
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
class UserDefinedPrompterType(BaseModel):
@@ -116,16 +114,10 @@ class SFTDataset(BaseModel):
field_messages: Optional[str] = None
message_field_role: Optional[str] = None
message_field_content: Optional[str] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
drop_system_message: Optional[bool] = None
trust_remote_code: Optional[bool] = False
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
@@ -166,7 +158,6 @@ class KTODataset(BaseModel):
split: Optional[str] = None
type: Optional[Union[UserDefinedKTOType, str]] = None
data_files: Optional[List[str]] = None
trust_remote_code: Optional[bool] = False
class RLType(str, Enum):
@@ -176,7 +167,6 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name
kto = "kto" # pylint: disable=invalid-name
simpo = "simpo" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
@@ -188,11 +178,7 @@ class ChatTemplate(str, Enum):
gemma = "gemma" # pylint: disable=invalid-name
cohere = "cohere" # pylint: disable=invalid-name
llama3 = "llama3" # pylint: disable=invalid-name
llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name
phi_3 = "phi_3" # pylint: disable=invalid-name
phi_35 = "phi_35" # pylint: disable=invalid-name
deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name
jamba = "jamba" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
@@ -229,23 +215,16 @@ class LoraConfig(BaseModel):
lora_r: Optional[int] = None
lora_alpha: Optional[int] = None
lora_fan_in_fan_out: Optional[bool] = None
lora_target_modules: Optional[Union[str, List[str]]] = None
lora_target_modules: Optional[List[str]] = None
lora_target_linear: Optional[bool] = None
lora_modules_to_save: Optional[List[str]] = None
lora_dropout: Optional[float] = 0.0
peft_layers_to_transform: Optional[List[int]] = None
peft_layers_pattern: Optional[List[str]] = None
peft: Optional[PeftConfig] = None
peft_use_dora: Optional[bool] = None
peft_use_rslora: Optional[bool] = None
peft_layer_replication: Optional[List[Tuple[int, int]]] = None
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
metadata={
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
bnb_config_kwargs: Optional[Dict[str, Any]] = None
@@ -300,13 +279,6 @@ class LoraConfig(BaseModel):
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@field_validator("loraplus_lr_embedding")
@classmethod
def convert_loraplus_lr_embedding(cls, loraplus_lr_embedding):
if loraplus_lr_embedding and isinstance(loraplus_lr_embedding, str):
loraplus_lr_embedding = float(loraplus_lr_embedding)
return loraplus_lr_embedding
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
@@ -330,13 +302,8 @@ class ModelInputConfig(BaseModel):
tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"}
)
processor_type: Optional[str] = Field(
default=None, metadata={"help": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
@@ -368,24 +335,13 @@ class HyperparametersConfig(BaseModel):
},
)
auto_find_batch_size: Optional[bool] = None
train_on_inputs: Optional[bool] = False
group_by_length: Optional[bool] = None
learning_rate: Union[str, float]
weight_decay: Optional[float] = 0.0
optimizer: Optional[
Union[
OptimizerNames,
Literal[
"lion_pytorch",
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
],
]
Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
@@ -397,7 +353,7 @@ class HyperparametersConfig(BaseModel):
},
)
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler: Optional[SchedulerType] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None
@@ -535,7 +491,6 @@ class AxolotlInputConfig(
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None
dataset_shard_idx: Optional[int] = None
skip_prepare_dataset: Optional[bool] = False
pretraining_dataset: Optional[ # type: ignore
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
@@ -549,8 +504,6 @@ class AxolotlInputConfig(
dataloader_prefetch_factor: Optional[int] = None
dataloader_drop_last: Optional[bool] = None
accelerator_config: Optional[Dict[str, Any]] = None
remove_unused_columns: Optional[bool] = None
push_dataset_to_hub: Optional[str] = None
@@ -608,7 +561,6 @@ class AxolotlInputConfig(
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None
# for PoSE context length extension
use_pose: Optional[bool] = None
@@ -634,21 +586,14 @@ class AxolotlInputConfig(
flash_attn_fuse_mlp: Optional[bool] = None
flash_optimum: Optional[bool] = None
eager_attention: Optional[bool] = None
unsloth_cross_entropy_loss: Optional[bool] = None
unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None
unsloth_lora_o: Optional[bool] = None
unsloth_rms_norm: Optional[bool] = None
unsloth_rope: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
fsdp_final_state_dict_type: Optional[
Literal["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
] = None
val_set_size: Optional[float] = Field(default=0.0)
@@ -657,9 +602,6 @@ class AxolotlInputConfig(
torch_compile: Optional[bool] = None
torch_compile_backend: Optional[str] = None
torch_compile_mode: Optional[
Literal["default", "reduce-overhead", "max-autotune"]
] = None
max_steps: Optional[int] = None
warmup_steps: Optional[int] = None
@@ -681,8 +623,6 @@ class AxolotlInputConfig(
orpo_alpha: Optional[float] = None
rpo_alpha: Optional[float] = None
simpo_gamma: Optional[float] = None
cpo_alpha: Optional[float] = None
kto_desirable_weight: Optional[float] = None
kto_undesirable_weight: Optional[float] = None
@@ -697,8 +637,6 @@ class AxolotlInputConfig(
chat_template: Optional[ChatTemplate] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
@@ -764,24 +702,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_split_batches_accelerate(cls, data):
# alternatively set ACCELERATE_SPLIT_BATCHES=False
if data.get("pretraining_dataset"):
accelerator_config = data.get("accelerator_config", {})
if not accelerator_config:
data["accelerator_config"] = {
"split_batches": False,
"dispatch_batches": False,
}
else:
if accelerator_config.get("split_batches") is None:
data["accelerator_config"]["split_batches"] = False
if accelerator_config.get("dispatch_batches") is None:
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before")
@classmethod
def check_gptq_w_revision(cls, data):
@@ -900,7 +820,7 @@ class AxolotlInputConfig(
@model_validator(mode="after")
def check_adamw_optimizer_params(self):
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
not self.optimizer or "adamw" not in str(self.optimizer).lower()
not self.optimizer or "adamw" not in self.optimizer.value
):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@@ -971,8 +891,6 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_eval_packing(cls, data):
# TODO also should check test_datasets and val_set_size as we can skip
# if there are no eval datasets/splits
if (
data.get("sample_packing")
and data.get("eval_table_size")
@@ -1003,18 +921,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_mm_prepare(cls, data):
if data.get("skip_prepare_dataset"):
if data.get("remove_unused_columns") is None:
LOG.info(
"setting `remove_unused_columns: false` for skip_prepare_dataset"
)
data["remove_unused_columns"] = False
return data
@model_validator(mode="before")
@classmethod
def check_warmup(cls, data):
@@ -1042,20 +948,12 @@ class AxolotlInputConfig(
return neftune_noise_alpha
@model_validator(mode="after")
def check_rl_beta(self):
def check(self):
if self.dpo_beta and not self.rl_beta:
self.rl_beta = self.dpo_beta
del self.dpo_beta
return self
@model_validator(mode="after")
def check_simpo_warmup(self):
if self.rl == "simpo" and self.warmup_ratio:
raise ValueError(
"warmup_ratio is not supported with the simpo trainer. Please use `warmup_steps` instead"
)
return self
@model_validator(mode="before")
@classmethod
def check_frozen(cls, data):
@@ -1070,15 +968,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_peft_layers_pattern(cls, data):
if data.get("peft_layers_pattern") and not data.get("peft_layers_to_transform"):
raise ValueError(
"peft_layers_pattern requires peft_layers_to_transform to be set"
)
return data
@model_validator(mode="after")
def check_fft_possible_bad_config(self):
if (
@@ -1198,20 +1087,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_sharded_state_dict_w_safetensors(cls, data):
if (
data.get("fsdp")
and data.get("save_safetensors")
and data.get("fsdp_config")
and data["fsdp_config"].get("fsdp_state_dict_type") == "SHARDED_STATE_DICT"
):
raise ValueError(
"FSDP SHARDED_STATE_DICT not compatible with save_safetensors"
)
return data
@model_validator(mode="before")
@classmethod
def check_causal_lm_evals(cls, data):
@@ -1237,55 +1112,6 @@ class AxolotlInputConfig(
raise ValueError("either datasets or pretraining_dataset is required")
return data
@model_validator(mode="before")
@classmethod
def check_xentropy_patch_conflicts(cls, data):
if data.get("flash_attn_cross_entropy") and data.get(
"unsloth_cross_entropy_loss"
):
raise ValueError(
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
)
return data
@model_validator(mode="before")
@classmethod
def check_qlora_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
)
return data
@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
xformers_version = version("xformers")
if xformers_version == "0.0.27":
raise ValueError(
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
if data.get("deepspeed") and data.get("torch_compile"):
raise ValueError(
"torch_compile should be set within your deepspeed config file"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
@@ -1331,37 +1157,9 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
return data
@model_validator(mode="before")
@classmethod
def check_hopper_8bit_lora(cls, data):
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
# see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_deepspeed(cls, data):
if data.get("deepspeed") and data.get("fsdp"):
raise ValueError("deepspeed and fsdp cannot be used together.")
return data
@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
capabilities = data.get("capabilities")
if capabilities and capabilities.get("n_gpu", 0) > 1:
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data

View File

@@ -18,10 +18,10 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
res = tokenizer(
examples["text"],
examples,
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,

View File

@@ -1,5 +1,4 @@
"""data handling specific to DPO"""
import inspect
import logging
from functools import partial

View File

@@ -42,7 +42,7 @@ from axolotl.prompters import (
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.utils import md5
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_local_main_process, zero_first
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
@@ -51,31 +51,20 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def prepare_dataset(cfg, tokenizer, processor=None):
def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_local_main_process()):
with zero_first(is_main_process()):
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="train",
processor=processor,
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="test"
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
processor=processor,
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
path = cfg.pretraining_dataset
@@ -134,7 +123,6 @@ def load_tokenized_prepared_datasets(
cfg,
default_dataset_prepared_path,
split="train",
processor=None,
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = cfg.tokenizer_config
@@ -172,12 +160,8 @@ def load_tokenized_prepared_datasets(
use_auth_token = cfg.hf_use_auth_token
try:
if cfg.push_dataset_to_hub:
LOG.info(
f"Attempting to load prepared dataset from Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
)
dataset = load_dataset(
cfg.push_dataset_to_hub,
ds_hash,
f"{cfg.push_dataset_to_hub}/{ds_hash}",
token=use_auth_token,
)
dataset = dataset[split]
@@ -186,26 +170,17 @@ def load_tokenized_prepared_datasets(
# pylint: disable=duplicate-code
if dataset:
# This is for the case where we already loaded a pretokenized dataset from the hub
...
elif (
cfg.dataset_prepared_path
and any(prepared_ds_path.glob("*"))
and not cfg.is_preprocess
and not cfg.skip_prepare_dataset
):
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
LOG.info("Prepared dataset loaded from disk...")
else:
if cfg.push_dataset_to_hub:
LOG.info("Unable to find prepared dataset in Huggingface hub")
if cfg.is_preprocess:
LOG.info(
f"Skipping prepared dataset in {prepared_ds_path} for pre-processing..."
)
else:
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
LOG.info("Loading raw datasets...")
if not cfg.is_preprocess:
LOG.warning(
@@ -223,8 +198,6 @@ def load_tokenized_prepared_datasets(
def for_d_in_datasets(dataset_configs):
for dataset in dataset_configs:
if dataset.name and isinstance(dataset.name, list):
# load_dataset doesn't properly handle multiple named configurations
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
else:
@@ -235,8 +208,6 @@ def load_tokenized_prepared_datasets(
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,
@@ -436,16 +407,12 @@ def load_tokenized_prepared_datasets(
dataset=ds,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
processor=processor,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
if len(datasets) == 1:
dataset = datasets[0]
else:
LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)
LOG.info("merging datasets")
dataset = concatenate_datasets(datasets)
if len(datasets) > 1:
if cfg.shuffle_merged_datasets:
@@ -454,20 +421,17 @@ def load_tokenized_prepared_datasets(
else:
LOG.debug("NOT shuffling merged datasets")
if not cfg.skip_prepare_dataset:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
)
dataset.push_to_hub(
cfg.push_dataset_to_hub,
ds_hash,
private=True,
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
)
return dataset, prompters
@@ -496,14 +460,9 @@ def load_prepare_datasets(
cfg,
default_dataset_prepared_path,
split="train",
processor=None,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer,
cfg,
default_dataset_prepared_path,
split=split,
processor=processor,
tokenizer, cfg, default_dataset_prepared_path, split=split
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
@@ -569,7 +528,6 @@ def get_dataset_wrapper(
d_base_type,
dataset,
d_prompt_style=None,
processor=None,
):
dataset_wrapper = None
dataset_prompter = None
@@ -602,11 +560,7 @@ def get_dataset_wrapper(
dataset,
**ds_kwargs,
)
elif cfg.skip_prepare_dataset:
dataset_wrapper = dataset
elif ds_strategy := load(
config_dataset.type, tokenizer, cfg, config_dataset, processor=processor
):
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,

View File

@@ -44,10 +44,6 @@ def is_main_process():
return dist.get_rank() == 0
def is_local_main_process():
return PartialState().is_main_process
def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
@@ -153,11 +149,11 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
)
value_scalar, device=torch.cuda.current_device()
).float()
else:
value_tensor = torch.tensor(
0.0, device=torch.cuda.current_device(), dtype=torch.float32
0.0, device=torch.cuda.current_device()
) # Placeholder tensor
# Broadcast the tensor to all processes.

View File

@@ -13,7 +13,6 @@ from fastcore.parallel import parallel
from torch import Tensor, nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers.quantizers import AutoHfQuantizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
@@ -174,7 +173,6 @@ def load_sharded_model_quant(
low_memory=True,
verbose=False,
loading_workers=2,
quantization_config=None,
):
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
@@ -188,26 +186,15 @@ def load_sharded_model_quant(
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
)
else:
# this is the more common case with HF transformers
# TODO can we detect the model arch and dynamically set skip_modules
model.model = _replace_linear(
model.model,
Linear4bit,
compute_dtype=compute_dtype,
quant_type="nf4",
quant_storage=quant_storage,
compress_statistics=True, # bnb_4bit_use_double_quant
skip_modules=[
"lm_head",
"embed_out",
],
)
model.is_loaded_in_4bit = True
@@ -264,11 +251,6 @@ def load_sharded_model_quant(
quant_method=quant_method,
)
# these attributes are needed to inform transformers/peft of the quantization
model.is_quantized = True
model.quantization_method = "bitsandbytes"
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading

View File

@@ -1,7 +1,7 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines
import gc
import logging
import math
import os
@@ -28,21 +28,14 @@ from transformers import ( # noqa: F401
AddedToken,
AutoConfig,
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
AwqConfig,
BitsAndBytesConfig,
GPTQConfig,
LlavaForConditionalGeneration,
MllamaForConditionalGeneration,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
)
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -85,9 +78,6 @@ def get_module_class_from_name(module, name):
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
if cfg.is_multimodal:
model_config = model_config.text_config
quant_config_exists = (
hasattr(model_config, "quantization_config")
and model_config.quantization_config
@@ -104,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
"Please make sure to point to a GPTQ model."
)
if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit:
if not cfg.gptq and quant_config_exists:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
@@ -307,63 +297,25 @@ def load_tokenizer(cfg):
return tokenizer
def load_processor(cfg: DictDefault, tokenizer: PreTrainedTokenizerBase):
processor_kwargs: Dict[str, Any] = {} # do we actually need this?
processor_cls = AutoProcessor
if cfg.processor_type:
processor_cls = getattr(transformers, cfg.processor_type)
processor = processor_cls.from_pretrained(
cfg.processor_config,
trust_remote_code=cfg.trust_remote_code or False,
tokenizer=tokenizer,
**processor_kwargs,
)
return processor
def load_model(
cfg: DictDefault,
tokenizer: PreTrainedTokenizerBase,
*,
processor: ProcessorMixin = None, # pylint: disable=unused-argument
inference: bool = False,
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
model_type = cfg.type_of_model
model_config = load_model_config(cfg)
# load any patches from plugins
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(cfg)
if cfg.is_multimodal:
text_model_config = model_config.text_config
else:
text_model_config = model_config
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
if cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if hasattr(model_config, "model_type") and model_config.model_type == "mllama":
if cfg.flash_attention:
from axolotl.monkeypatch.attention.mllama import patch_mllama
patch_mllama()
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
if cfg.flash_attention:
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
@@ -394,36 +346,7 @@ def load_model(
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(
cfg.model_config_type,
model_name=cfg.base_model,
is_remote_code=cfg.trust_remote_code,
)
if cfg.is_llama_derived_model:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_cross_entropy,
patch_llama_rms_norm,
)
if cfg.flash_attn_cross_entropy:
patch_llama_cross_entropy()
if cfg.flash_attn_rms_norm:
patch_llama_rms_norm()
elif cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import (
integrate_cross_entropy_loss_patch,
)
integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
@@ -476,7 +399,7 @@ def load_model(
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch(model_type="llama")
integrate_cross_entropy_loss_patch()
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -484,12 +407,23 @@ def load_model(
patch_self_attn_lora()
# Modify mistral derived models
if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
if (
cfg.model_config_type == "mistral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
patch_mistral_cross_entropy,
replace_mistral_attn_with_flash_attn,
)
patch_mistral_cross_entropy()
LOG.info("patching mistral with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
model_kwargs: Dict[str, Any] = {}
@@ -500,19 +434,6 @@ def load_model(
max_memory = cfg.max_memory
device_map = cfg.device_map
AutoModelLoader = AutoModelForCausalLM # pylint: disable=invalid-name
if cfg.is_multimodal:
if model_config.model_type == "llava":
AutoModelLoader = ( # pylint: disable=invalid-name
LlavaForConditionalGeneration
)
elif model_config.model_type == "mllama":
AutoModelLoader = ( # pylint: disable=invalid-name
MllamaForConditionalGeneration
)
else:
AutoModelLoader = AutoModelForVision2Seq # pylint: disable=invalid-name
if cfg.gpu_memory_limit:
gpu_memory_limit = (
str(cfg.gpu_memory_limit) + "GiB"
@@ -530,7 +451,7 @@ def load_model(
from accelerate import infer_auto_device_map
with init_empty_weights():
model_canvas = AutoModelLoader.from_config(
model_canvas = AutoModelForCausalLM.from_config(
model_config, trust_remote_code=cfg.trust_remote_code or False
)
model_canvas.tie_weights()
@@ -575,25 +496,7 @@ def load_model(
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
if (
cfg.adapter in ["qlora", "lora"]
and hasattr(model_config, "quantization_config")
and model_config.quantization_config["quant_method"]
in ["gptq", "awq", "bitsandbytes"]
):
if model_config.quantization_config["quant_method"] == "gptq":
model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "awq":
model_kwargs["quantization_config"] = AwqConfig(
**model_config.quantization_config
)
elif model_config.quantization_config["quant_method"] == "bitsandbytes":
model_kwargs["quantization_config"] = BitsAndBytesConfig(
**model_config.quantization_config
)
elif cfg.adapter == "qlora" and cfg.load_in_4bit:
if cfg.adapter == "qlora" and cfg.load_in_4bit:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -603,9 +506,7 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
cfg.deepspeed or cfg.fsdp
):
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed still needs this to be in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
@@ -641,12 +542,25 @@ def load_model(
# sample packing uses custom FA2 patch
if cfg.flash_attention:
if not cfg.sample_packing and cfg.s2_attention:
pass
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
if not cfg.sample_packing:
if cfg.s2_attention:
pass
# most other models support flash attention; we can define exceptions as they come up
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
else:
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif cfg.sdp_attention:
model_kwargs["attn_implementation"] = "sdpa"
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
@@ -676,23 +590,14 @@ def load_model(
elif (
qlora_fsdp
and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
and (cfg.model_config_type == "dbrx" or cfg.qlora_sharded_model_loading)
and cfg.model_config_type == "dbrx"
):
quant_storage = cfg.torch_dtype
quantization_config = hasattr(
model_config, "quantization_config"
) and getattr(model_config, "quantization_config")
quantization_config = (
quantization_config or model_kwargs["quantization_config"]
)
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = load_sharded_model_quant(
base_model,
model_config,
cfg,
quant_storage=quant_storage,
quantization_config=quantization_config,
)
skip_move_to_device = True
elif (
@@ -700,14 +605,12 @@ def load_model(
and not cfg.trust_remote_code
and not cfg.gptq
):
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
**model_kwargs,
@@ -746,17 +649,13 @@ def load_model(
and not cfg.trust_remote_code
):
if cfg.gptq:
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
@@ -767,38 +666,34 @@ def load_model(
# Shouldn't be a problem most of the time; will obviously error when training starts
# if the model doesn't support this
if (
hasattr(text_model_config, "max_seq_len")
and text_model_config.max_seq_len
hasattr(model_config, "max_seq_len")
and model_config.max_seq_len
and cfg.sequence_len > model_config.max_seq_len
):
text_model_config.max_seq_len = cfg.sequence_len
model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif (
hasattr(text_model_config, "max_sequence_length")
and text_model_config.max_sequence_length
and cfg.sequence_len > text_model_config.max_sequence_length
hasattr(model_config, "max_sequence_length")
and model_config.max_sequence_length
and cfg.sequence_len > model_config.max_sequence_length
):
text_model_config.max_sequence_length = cfg.sequence_len
model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq:
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
if cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
# disabling either of these two still leads to a VRAM spike before settling back down
skip_move_to_device = True
if "device_map" in model_kwargs:
del model_kwargs["device_map"]
if cfg.is_multimodal:
model_config.text_config = text_model_config
model = AutoModelLoader.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
@@ -876,16 +771,12 @@ def load_model(
set_z3_leaf_modules,
)
if cfg.model_config_type in MOE_ARCH_BLOCK:
moe_blocks = MOE_ARCH_BLOCK[cfg.model_config_type]
moe_blocks = [moe_blocks] if isinstance(moe_blocks, str) else moe_blocks
set_z3_leaf_modules(
model,
[
get_module_class_from_name(model, module_name)
for module_name in moe_blocks
],
)
if cfg.model_config_type == "mixtral":
moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
set_z3_leaf_modules(model, [moe_block])
elif cfg.model_config_type == "dbrx":
moe_block = get_module_class_from_name(model, "DbrxFFN")
set_z3_leaf_modules(model, [moe_block])
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
@@ -899,9 +790,6 @@ def load_model(
# make sure everything is in the same dtype
skip_prepare_model_for_kbit_training = True
if is_deepspeed_zero3_enabled():
skip_prepare_model_for_kbit_training = True
if cfg.adapter in ["lora", "qlora"]:
if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable(
@@ -936,9 +824,6 @@ def load_model(
else:
model, lora_config = load_adapter(model, cfg, cfg.adapter)
if is_deepspeed_zero3_enabled():
skip_move_to_device = True
if (
cfg.ddp
and not load_in_8bit
@@ -978,15 +863,6 @@ def load_model(
integrate_lora_patch(model, cfg)
if cfg.unsloth_rope:
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
integrate_rope_embeddings()
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
# TODO resume_from_checkpoint handling
return model, lora_config
@@ -1080,17 +956,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
from peft import LoraConfig, get_peft_model
lora_target_modules = cfg.lora_target_modules or []
lora_target_modules = list(cfg.lora_target_modules or [])
if cfg.lora_target_linear:
linear_names = find_all_linear_names(model)
LOG.info(f"found linear modules: {repr(sorted(linear_names))}")
lora_target_modules_as_list = (
lora_target_modules
if isinstance(lora_target_modules, list)
else [lora_target_modules]
)
lora_target_modules = list(set(lora_target_modules_as_list + linear_names))
LOG.info(f"found linear modules: {repr(linear_names)}")
lora_target_modules = list(set(lora_target_modules + linear_names))
lora_config_kwargs = {}
loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
@@ -1109,13 +980,11 @@ def load_lora(model, cfg, inference=False, config_only=False):
lora_alpha=cfg.lora_alpha,
target_modules=lora_target_modules,
layers_to_transform=cfg.peft_layers_to_transform,
layers_pattern=cfg.peft_layers_pattern,
lora_dropout=cfg.lora_dropout,
fan_in_fan_out=cfg.lora_fan_in_fan_out,
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
bias="none",
# task_type="CAUSAL_LM",
task_type="CONDITIONAL_GENERATION" if cfg.is_multimodal else "CAUSAL_LM",
task_type="CAUSAL_LM",
**lora_config_kwargs,
)
@@ -1167,20 +1036,9 @@ def load_lora(model, cfg, inference=False, config_only=False):
def ensure_dtype(model, dtype=torch.bfloat16):
for name, module in model.named_modules():
weight_mismatch = False
bias_mismatch = False
try:
weight_mismatch = module.weight.dtype != dtype
if module.weight.dtype != dtype:
print(f"Converting module {name}: {module.weight.dtype} -> {dtype}")
module.to(dtype)
except AttributeError:
pass
try:
bias_mismatch = module.bias.dtype != dtype
except AttributeError:
pass
if weight_mismatch:
print(f"Converting module {name}.weight: {module.weight.dtype} -> {dtype}")
if bias_mismatch:
print(f"Converting module {name}.bias: {module.bias.dtype} -> {dtype}")
if weight_mismatch or bias_mismatch:
module.to(dtype)
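A minimal usage sketch for ensure_dtype as defined above; the model id is illustrative and the import path is assumed.

# Sketch: coerce a freshly loaded fp32 model to bfloat16 using the helper above.
import torch
from transformers import AutoModelForCausalLM

from axolotl.utils.models import ensure_dtype  # assumed import path

model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m", torch_dtype=torch.float32)
ensure_dtype(model, dtype=torch.bfloat16)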

View File

@@ -11,8 +11,6 @@ import numba
import numpy as np
from torch.utils.data import BatchSampler, Sampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
@@ -176,46 +174,16 @@ class MultipackBatchSampler(BatchSampler):
def efficiency(self):
return self.eff_total_used / self.eff_total_slots
def gather_efficiency(self):
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
return math.floor(0.997 * max(estimates))
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est,
)
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
)
return sample_packing_eff_est
def gather_len_batches(self, num):
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(0.998 * min(estimates))
min_len_batches = reduce_and_broadcast(
lambda: num,
calc_min_len,
)
return min_len_batches
def __len__(self):
len_batches = self.num_batches()
return self.gather_len_batches(len_batches)
self.num_batches()
return self._len_est()
def _len_est(self):
efficiency = (
self.packing_efficiency_estimate
if self.packing_efficiency_estimate
else self.gather_efficiency()
)
world_size = int(os.getenv("WORLD_SIZE", "1"))
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // world_size
LOG.info(
f"packing_efficiency_estimate: {efficiency} "
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
@@ -227,7 +195,7 @@ class MultipackBatchSampler(BatchSampler):
* math.floor(
0.99
* lengths_sum_per_device
/ efficiency
/ self.packing_efficiency_estimate
// (self.batch_max_len * self.batch_size)
)
- 1
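As a rough worked example of the estimate above (all numbers are invented, and the factor in front of the floor is taken to be world_size, as the surrounding code suggests):

# Illustrative arithmetic for the batch-count estimate above; numbers are made up.
import math

world_size = 2
lengths_sum_per_device = 5_000_000  # ~10M tokens split across two ranks
packing_efficiency_estimate = 1.0
batch_max_len, batch_size = 4096, 1

est = world_size * math.floor(
    0.99 * lengths_sum_per_device / packing_efficiency_estimate
    // (batch_max_len * batch_size)
) - 1
print(est)  # 2415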

View File

@@ -62,7 +62,7 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
"""Helper function to process and color tokens."""
colored_tokens = [
color_token_for_rl_debug(tokenizer.decode(token), token, color, text_only)
for token in tokenizer.encode(tokens, add_special_tokens=False)
for token in tokenizer.encode(tokens)
]
return colored_tokens

View File

@@ -1,5 +1,4 @@
"""Module containing the Trainer class and related functions"""
import json
import math
import os
import random
@@ -16,7 +15,7 @@ from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger("axolotl")
@@ -183,106 +182,90 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
with zero_first(is_main_process()):
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.is_preprocess:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if (
cfg.is_mistral_derived_model and cfg.flash_attention
) or cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
# drop samples where the number of elements with labels not equal to -100 is zero
def drop_no_trainable_tokens(sample):
return np.sum(np.array(sample["labels"]) != -100) > 0
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
)
if eval_dataset:
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
)
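A tiny illustrative check of the filter above: labels of -100 are ignored by the loss, so a sample whose labels are all -100 contributes no trainable tokens and is dropped.

# Illustrative: which samples the drop_no_trainable_tokens filter keeps.
import numpy as np

all_masked = {"labels": [-100, -100, -100]}
has_targets = {"labels": [-100, 42, 7]}
assert np.sum(np.array(all_masked["labels"]) != -100) == 0   # filtered out
assert np.sum(np.array(has_targets["labels"]) != -100) > 0   # kept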
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Group By Length",
)
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
if cfg.use_pose:
pose_kwargs = {}
if cfg.pose_num_chunks is not None:
pose_kwargs["chunks"] = cfg.pose_num_chunks
pose_fn = partial(
add_pose_position_ids,
max_context_len=cfg.pose_max_context_len,
split_on_token_ids=cfg.pose_split_on_token_ids,
**pose_kwargs,
)
train_dataset = train_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
train_dataset = train_dataset.sort("sequence_len")
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
pose_fn,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
)
return train_dataset, eval_dataset
@@ -306,7 +289,7 @@ def process_pretraining_datasets_for_packing(
def calculate_total_num_steps(cfg, train_dataset, update=True):
if not cfg.total_num_tokens and not cfg.skip_prepare_dataset:
if not cfg.total_num_tokens:
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
@@ -319,11 +302,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
skip_estimates = cfg.model_config_type == "mamba"
if (
not skip_estimates
and not cfg.total_supervised_tokens
and not cfg.skip_prepare_dataset
):
if not skip_estimates and not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
@@ -361,7 +340,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
main_process_only=True,
)
else:
if cfg.flash_attention and not cfg.multipack_real_batches:
if cfg.flash_attention:
sampler_batch_size = 1
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
else:
@@ -412,27 +391,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
return total_num_steps
def setup_torch_compile_env(cfg):
if cfg.torch_compile:
if not cfg.torch_compile_backend:
os.environ["ACCELERATE_DYNAMO_BACKEND"] = "INDUCTOR"
else:
os.environ["ACCELERATE_DYNAMO_BACKEND"] = cfg.torch_compile_backend.upper()
def setup_deepspeed_env(cfg, stage=None):
from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if stage:
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
# If we don't assign this, it doesn't actually get set in the accelerate weakref
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
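For reference, a sketch of the stage detection that prepare_optim_env performs further down: it reads zero_optimization.stage from the DeepSpeed JSON config and, for stage 3, enables ACCELERATE_DEEPSPEED_ZERO3_INIT via setup_deepspeed_env. The config contents here are a minimal, illustrative ZeRO-3 example.

# Sketch: what the stage detection reads from a (minimal, illustrative) DeepSpeed config.
deepspeed_config = {
    "zero_optimization": {"stage": 3},
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
}
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
assert stage == 3  # setup_deepspeed_env(cfg, stage=3) would then request ZeRO-3 init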
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_activation_checkpointing:
@@ -459,16 +417,8 @@ def prepare_optim_env(cfg):
if cfg.fsdp:
setup_fsdp_envs(cfg)
elif cfg.deepspeed:
stage = None
# check if the cfg.deepspeed is a file
if os.path.isfile(cfg.deepspeed):
# parse with json
with open(cfg.deepspeed, "r", encoding="utf-8") as fin:
deepspeed_config = json.load(fin)
stage = deepspeed_config.get("zero_optimization", {}).get("stage", None)
setup_deepspeed_env(cfg, stage=stage)
setup_torch_compile_env(cfg)
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
@@ -476,21 +426,13 @@ def prepare_optim_env(cfg):
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
def prepare_opinionated_env(cfg):
if cfg.qlora_sharded_model_loading:
# model loading is forked after the tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "orpo", "kto"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]
else:
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset

View File

@@ -1,110 +0,0 @@
"""
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
class LigerIntegrationTestCase(unittest.TestCase):
"""
e2e tests for liger integration with Axolotl
"""
@with_temp_dir
def test_llama_wo_flce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_cross_entropy": True,
"liger_fused_linear_cross_entropy": False,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@with_temp_dir
def test_llama_w_flce(self, temp_dir):
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

Some files were not shown because too many files have changed in this diff