Compare commits
38 Commits
transforme
...
sageattent
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8acc72dd8 | ||
|
|
51c9e1a035 | ||
|
|
45c0825587 | ||
|
|
94fc223f6c | ||
|
|
151abb7a67 | ||
|
|
bf416bdfd0 | ||
|
|
838b74d05b | ||
|
|
2e99bb303e | ||
|
|
68a26f1005 | ||
|
|
db51a9e4cb | ||
|
|
8961364bc9 | ||
|
|
e9c3a2aec0 | ||
|
|
02ca3f93b0 | ||
|
|
5f6f9186e4 | ||
|
|
6679e20f47 | ||
|
|
ec59d4cb83 | ||
|
|
a77c8a71cf | ||
|
|
775311f98f | ||
|
|
f007c38e49 | ||
|
|
d9b71edf84 | ||
|
|
c07bd2fa65 | ||
|
|
ed079d434a | ||
|
|
8403c67156 | ||
|
|
9871fa060b | ||
|
|
70cf79ef52 | ||
|
|
c06b8f0243 | ||
|
|
0c8b1d824a | ||
|
|
fd70eec577 | ||
|
|
d42f202046 | ||
|
|
0dabde1962 | ||
|
|
15f1462ccd | ||
|
|
521e62daf1 | ||
|
|
c16ec398d7 | ||
|
|
2f20cb7ebf | ||
|
|
71d4030b79 | ||
|
|
f3a5d119af | ||
|
|
ba219b51a5 | ||
|
|
5be8e13d35 |
2
.github/workflows/base.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
- cuda: "124"
|
||||
cuda_version: 12.4.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
python_version: "3.10"
|
||||
pytorch: 2.4.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "124"
|
||||
|
||||
13
.github/workflows/main.yml
vendored
@@ -10,7 +10,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -49,7 +49,7 @@ jobs:
|
||||
axolotlai/axolotl
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{version}}
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to Docker Hub
|
||||
@@ -77,7 +77,7 @@ jobs:
|
||||
|
||||
build-axolotl-cloud:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -114,6 +114,9 @@ jobs:
|
||||
images: |
|
||||
winglian/axolotl-cloud
|
||||
axolotlai/axolotl-cloud
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
@@ -137,7 +140,7 @@ jobs:
|
||||
|
||||
build-axolotl-cloud-no-tmux:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -160,7 +163,7 @@ jobs:
|
||||
axolotlai/axolotl-cloud-term
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=semver,pattern={{version}}
|
||||
type=pep440,pattern={{version}}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
|
||||
7
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -8,9 +8,14 @@ on:
|
||||
schedule:
|
||||
- cron: '0 0 * * 1,4' # Runs at 00:00 UTC every monday & thursday
|
||||
|
||||
# Cancel jobs on the same ref if a new one is triggered
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
|
||||
|
||||
jobs:
|
||||
test-axolotl-multigpu:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
||||
4
.github/workflows/nightlies.yml
vendored
@@ -7,7 +7,7 @@ on:
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -71,7 +71,7 @@ jobs:
|
||||
|
||||
build-axolotl-cloud:
|
||||
needs: build-axolotl
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
matrix:
|
||||
|
||||
19
.github/workflows/pypi.yml
vendored
@@ -10,20 +10,13 @@ jobs:
|
||||
setup_release:
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Get the tag version
|
||||
id: extract_branch
|
||||
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
|
||||
shell: bash
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: actions/create-release@v1
|
||||
- name: Create release
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tag_name: ${{ steps.extract_branch.outputs.branch }}
|
||||
release_name: ${{ steps.extract_branch.outputs.branch }}
|
||||
run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
|
||||
pypi-publish:
|
||||
name: Upload release to PyPI
|
||||
runs-on: ubuntu-latest
|
||||
@@ -56,9 +49,9 @@ jobs:
|
||||
run: |
|
||||
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
|
||||
|
||||
- name: Build a binary wheel
|
||||
- name: Build a source dist
|
||||
run: |
|
||||
python setup.py sdist bdist_wheel
|
||||
python setup.py sdist
|
||||
|
||||
- name: Publish package distributions to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
|
||||
1
.github/workflows/tests-nightly.yml
vendored
@@ -48,6 +48,7 @@ jobs:
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
|
||||
50
.github/workflows/tests.yml
vendored
@@ -71,18 +71,62 @@ jobs:
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest --ignore=tests/e2e/ tests/
|
||||
pytest -n8 --ignore=tests/e2e/ tests/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
pytest-sdist:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.4.1", "2.5.1"]
|
||||
timeout-minutes: 20
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
python3 setup.py sdist
|
||||
pip3 install dist/axolotl*.tar.gz
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -n8 --ignore=tests/e2e/ tests/
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
if: github.repository_owner == 'axolotl-ai-cloud'
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: [self-hosted, modal]
|
||||
timeout-minutes: 90
|
||||
needs: [pre-commit, pytest]
|
||||
needs: [pre-commit, pytest, pytest-sdist]
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
3
.gitignore
vendored
@@ -182,3 +182,6 @@ submit.sh
|
||||
|
||||
typings/
|
||||
out/
|
||||
|
||||
# vim
|
||||
*.swp
|
||||
|
||||
4
MANIFEST.in
Normal file
@@ -0,0 +1,4 @@
|
||||
include requirements.txt
|
||||
include README.md
|
||||
include LICENSE
|
||||
recursive-include axolotl *.py
|
||||
23
README.md
@@ -1,8 +1,21 @@
|
||||
# Axolotl
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="image/axolotl_logo_digital_white.svg">
|
||||
<source media="(prefers-color-scheme: light)" srcset="image/axolotl_logo_digital_black.svg">
|
||||
<img alt="Axolotl" src="image/axolotl_logo_digital_black.svg" width="400" height="104" style="max-width: 100%;">
|
||||
</picture>
|
||||
</p>
|
||||
|
||||

|
||||

|
||||

|
||||
<p align="center">
|
||||
<img src="https://img.shields.io/github/license/axolotl-ai-cloud/axolotl.svg?color=blue" alt="GitHub License">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests.yml/badge.svg" alt="tests">
|
||||
<a href="https://github.com/axolotl-ai-cloud/axolotl/releases"><img src="https://img.shields.io/github/release/axolotl-ai-cloud/axolotl.svg" alt="Releases"></a>
|
||||
<img src="https://img.shields.io/github/stars/axolotl-ai-cloud/axolotl" alt="GitHub Repo stars">
|
||||
</p>
|
||||
<p align="center">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
|
||||
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
|
||||
</p>
|
||||
|
||||
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
|
||||
|
||||
@@ -75,7 +88,7 @@ Features:
|
||||
<td>
|
||||
|
||||
<div align="center">
|
||||
<img src="image/axolotl.png" alt="axolotl" width="160">
|
||||
<img src="image/axolotl_symbol_digital_white.svg" alt="axolotl" width="160">
|
||||
<div>
|
||||
<p>
|
||||
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
|
||||
|
||||
@@ -28,6 +28,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^peft.*#peft @ git+https://github.com/huggingface/peft.git@main#' requirements.txt; \
|
||||
sed -i 's#^accelerate.*#accelerate @ git+https://github.com/huggingface/accelerate.git@main#' requirements.txt; \
|
||||
sed -i 's#^trl.*#trl @ git+https://github.com/huggingface/trl.git@main#' requirements.txt; \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
pytest -n4 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
pytest -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
|
||||
pytest --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
|
||||
|
||||
@@ -2,4 +2,4 @@
|
||||
set -e
|
||||
|
||||
# only run one test at a time so as not to OOM the GPU
|
||||
pytest -n1 /workspace/axolotl/tests/e2e/multigpu/
|
||||
pytest -v -n2 /workspace/axolotl/tests/e2e/multigpu/
|
||||
|
||||
@@ -91,6 +91,7 @@ datasets:
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
revision: # Optional[str] The specific revision of the dataset to use when loading from the Hugging Face Hub. This can be a commit hash, tag, or branch name. If not specified, the latest version will be used. This parameter is ignored for local datasets.
|
||||
trust_remote_code: # Optional[bool] Trust remote code for untrusted source
|
||||
|
||||
# Custom user instruction prompt
|
||||
- path: repo
|
||||
|
||||
@@ -11,12 +11,10 @@ standard industry baselines.
|
||||
|
||||
### Installation
|
||||
|
||||
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
|
||||
to date libraries.
|
||||
The following will install the correct unsloth and extras from source.
|
||||
|
||||
```bash
|
||||
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
|
||||
pip install --no-deps --force-reinstall xformers==0.0.26.post1
|
||||
python scripts/unsloth_install.py | sh
|
||||
```
|
||||
|
||||
### Using unsloth w Axolotl
|
||||
|
||||
@@ -2,19 +2,15 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "AKjdG7tbTb-n"
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Example notebook for running Axolotl on google colab"
|
||||
"## Setting up"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "RcbNpOgWRcii"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
@@ -22,82 +18,76 @@
|
||||
"assert (torch.cuda.is_available()==True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "h3nLav8oTRA5"
|
||||
},
|
||||
"source": [
|
||||
"## Install Axolotl and dependencies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "3c3yGAwnOIdi",
|
||||
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
|
||||
"!pip install flash-attn==\"2.7.0.post2\"\n",
|
||||
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
|
||||
"!pip install axolotl[deepspeed]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "BW2MFr7HTjub"
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Create an yaml config file"
|
||||
"## Hugging Face login (optional)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9pkF2dSoQEUN"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Example configuration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import yaml\n",
|
||||
"\n",
|
||||
"# Your YAML string\n",
|
||||
"yaml_string = \"\"\"\n",
|
||||
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
|
||||
"model_type: LlamaForCausalLM\n",
|
||||
"tokenizer_type: LlamaTokenizer\n",
|
||||
"base_model: NousResearch/Meta-Llama-3.1-8B\n",
|
||||
"\n",
|
||||
"load_in_8bit: false\n",
|
||||
"load_in_4bit: true\n",
|
||||
"strict: false\n",
|
||||
"\n",
|
||||
"datasets:\n",
|
||||
" - path: mhenrichsen/alpaca_2k_test\n",
|
||||
" - path: tatsu-lab/alpaca\n",
|
||||
" type: alpaca\n",
|
||||
"dataset_prepared_path:\n",
|
||||
"dataset_prepared_path: last_run_prepared\n",
|
||||
"val_set_size: 0.05\n",
|
||||
"output_dir: ./outputs/qlora-out\n",
|
||||
"output_dir: ./outputs/lora-out\n",
|
||||
"\n",
|
||||
"sequence_len: 2048\n",
|
||||
"sample_packing: true\n",
|
||||
"eval_sample_packing: true\n",
|
||||
"pad_to_sequence_len: true\n",
|
||||
"\n",
|
||||
"adapter: qlora\n",
|
||||
"lora_model_dir:\n",
|
||||
"\n",
|
||||
"sequence_len: 4096\n",
|
||||
"sample_packing: true\n",
|
||||
"eval_sample_packing: false\n",
|
||||
"pad_to_sequence_len: true\n",
|
||||
"\n",
|
||||
"lora_r: 32\n",
|
||||
"lora_alpha: 16\n",
|
||||
"lora_dropout: 0.05\n",
|
||||
"lora_target_modules:\n",
|
||||
"lora_target_linear: true\n",
|
||||
"lora_fan_in_fan_out:\n",
|
||||
"lora_modules_to_save:\n",
|
||||
" - embed_tokens\n",
|
||||
" - lm_head\n",
|
||||
"\n",
|
||||
"wandb_project:\n",
|
||||
"wandb_entity:\n",
|
||||
@@ -105,12 +95,12 @@
|
||||
"wandb_name:\n",
|
||||
"wandb_log_model:\n",
|
||||
"\n",
|
||||
"gradient_accumulation_steps: 4\n",
|
||||
"micro_batch_size: 2\n",
|
||||
"num_epochs: 4\n",
|
||||
"optimizer: paged_adamw_32bit\n",
|
||||
"gradient_accumulation_steps: 2\n",
|
||||
"micro_batch_size: 1\n",
|
||||
"num_epochs: 1\n",
|
||||
"optimizer: paged_adamw_8bit\n",
|
||||
"lr_scheduler: cosine\n",
|
||||
"learning_rate: 0.0002\n",
|
||||
"learning_rate: 2e-5\n",
|
||||
"\n",
|
||||
"train_on_inputs: false\n",
|
||||
"group_by_length: false\n",
|
||||
@@ -121,13 +111,15 @@
|
||||
"gradient_checkpointing: true\n",
|
||||
"early_stopping_patience:\n",
|
||||
"resume_from_checkpoint:\n",
|
||||
"local_rank:\n",
|
||||
"logging_steps: 1\n",
|
||||
"xformers_attention:\n",
|
||||
"flash_attention: true\n",
|
||||
"flash_attention: false\n",
|
||||
"sdp_attention: true\n",
|
||||
"\n",
|
||||
"warmup_steps: 10\n",
|
||||
"evals_per_epoch: 4\n",
|
||||
"warmup_steps: 1\n",
|
||||
"max_steps: 25\n",
|
||||
"evals_per_epoch: 1\n",
|
||||
"eval_table_size:\n",
|
||||
"saves_per_epoch: 1\n",
|
||||
"debug:\n",
|
||||
"deepspeed:\n",
|
||||
@@ -135,9 +127,10 @@
|
||||
"fsdp:\n",
|
||||
"fsdp_config:\n",
|
||||
"special_tokens:\n",
|
||||
"\n",
|
||||
" pad_token: <|end_of_text|>\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Convert the YAML string to a Python dictionary\n",
|
||||
"yaml_dict = yaml.safe_load(yaml_string)\n",
|
||||
"\n",
|
||||
@@ -146,31 +139,124 @@
|
||||
"\n",
|
||||
"# Write the YAML file\n",
|
||||
"with open(file_path, 'w') as file:\n",
|
||||
" yaml.dump(yaml_dict, file)\n"
|
||||
" yaml.dump(yaml_dict, file)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "bidoj8YLTusD"
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Launch the training"
|
||||
"Above we have a configuration file with base LLM model and datasets specified, among many other things. Axolotl can automatically detect whether the specified datasets are on HuggingFace repo or local machine.\n",
|
||||
"\n",
|
||||
"The Axolotl configuration options encompass model and dataset selection, data pre-processing, and training. Let's go through them line by line:\n",
|
||||
"\n",
|
||||
"* \"base model\": String value, specifies the underlying pre-trained LLM that will be used for finetuning\n",
|
||||
"\n",
|
||||
"Next we have options for model weights quantization. Quantization allows for reduction in occupied memory on GPUs.\n",
|
||||
"\n",
|
||||
"* \"load_in_8bit\": Boolean value, whether to quantize the model weights into 8-bit integer.\n",
|
||||
"\n",
|
||||
"* \"load_in_4bit\": Boolean value, whether to quantize the model weights into 4-bit integer.\n",
|
||||
"\n",
|
||||
"* \"strict\": Boolean value. If false, it allows for overriding established configuration options in the yaml file when executing in command-line interface.\n",
|
||||
"\n",
|
||||
"* \"datasets\": a list of dicts that contain path and type of data sets as well as other optional configurations where datasets are concerned. Supports multiple datasets.\n",
|
||||
"\n",
|
||||
"* \"val_set_size\": Either a float value less than one or an integer less than the total size of dataset. Sets the size of validation set from the whole dataset. If float, sets the proportion of the dataset assigned for validation. If integer, sets the direct size of validation set.\n",
|
||||
"\n",
|
||||
"* \"output_dir\": String value. Path of trained model.\n",
|
||||
"\n",
|
||||
"For data preprocessing:\n",
|
||||
"\n",
|
||||
"* \"sequence_len\": Integer. Specifies the maximum sequence length of the input. Typically 2048 or less.\n",
|
||||
"\n",
|
||||
"* \"pad_to_sequence_len\": Boolean. Padding input to maximum sequence length.\n",
|
||||
"\n",
|
||||
"* \"sample_packing\": Boolean. Specifies whether to use multi-packing with block diagonal attention.\n",
|
||||
"\n",
|
||||
"* \"special_tokens\": Python dict, optional. Allows users to specify the additional special tokens to be ignored by the tokenizer.\n",
|
||||
"\n",
|
||||
"For LoRA configuration and its hyperparamters:\n",
|
||||
"\n",
|
||||
"* \"adapter\": String. Either \"lora\" or \"qlora\", depending on user's choice.\n",
|
||||
"\n",
|
||||
"* \"lora_model_dir\": String, Optional. Path to directory that contains LoRA model, if there is already a trained LoRA model the user would like to use.\n",
|
||||
"\n",
|
||||
"* \"lora_r\": Integer. Refers to the rank of LoRA decomposition matrices. Higher value will reduce LoRA efficiency. Recommended to be set to 8.\n",
|
||||
"\n",
|
||||
"* \"lora_alpha\": Integer. Scale the weight matrices by $\\frac{\\text{lora_alpha}}{\\text{lora_r}}$Recommended to be fixed at 16.\n",
|
||||
"\n",
|
||||
"* \"lora_dropout\": Float that is 1 or less. The dropout probability of a lora layer.\n",
|
||||
"\n",
|
||||
"* \"lora_target_linear\": Boolean. If true, lora will target all linear modules in the transformers architecture.\n",
|
||||
"\n",
|
||||
"* \"lora_modules_to_save\": If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.\n",
|
||||
"\n",
|
||||
"See [LoRA](https://arxiv.org/abs/2106.09685) for detailed explanation of LoRA implementation.\n",
|
||||
"\n",
|
||||
"For the training configurations:\n",
|
||||
"\n",
|
||||
"* \"gradient_accumulation_steps\": Integer. The number of steps over which to accumulate gradient for batch training. E.g. if 2, backprop is performed every two steps.\n",
|
||||
"\n",
|
||||
"* \"micro_batch_size\": Integer. Batch size per gpu / gradient_accumulation_steps\n",
|
||||
"\n",
|
||||
"* \"num_epochs\": Integer. Number of epochs. One epoch is when training has looped over every batch in the whole data set once.\n",
|
||||
"\n",
|
||||
"* \"optimizer\": The optimizer to use for the training.\n",
|
||||
"\n",
|
||||
"* \"learning_rate\": The learning rate.\n",
|
||||
"\n",
|
||||
"* \"lr_scheduler\": The learning rate scheduler to use for adjusting learning rate during training.\n",
|
||||
"\n",
|
||||
"* \"train_on_inputs\": Boolean. Whether to ignore or include the user's prompt from the training labels.\n",
|
||||
"\n",
|
||||
"* \"group_by_length\": Boolean. Whether to group similarly sized data to minimize padding.\n",
|
||||
"\n",
|
||||
"* \"bf16\": Either \"auto\", \"true\", or \"false\". Whether to use CUDA bf16 floating point format. If set to \"auto\", will automatically apply bf16 should the gpu supports it.\n",
|
||||
"\n",
|
||||
"* \"fp16\": Optional. Specifies whether to use CUDA fp16. Automatically set to true if \"bf16\" is set to true. Otherwise false.\n",
|
||||
"\n",
|
||||
"* \"tf32\": Boolean. Whether to use CUDA tf32. Will override bf16.\n",
|
||||
"\n",
|
||||
"* \"gradient_checkpointing\": Boolean. Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing\n",
|
||||
"\n",
|
||||
"* \"gradient_checkpointing_kwargs\": Python Dict. Fed into the trainer.\n",
|
||||
"\n",
|
||||
"* \"logging_steps\": Integer. Log training information over every specified number of steps.\n",
|
||||
"\n",
|
||||
"* \"flash_attention\": Boolean. Whether to use the [flash attention](https://github.com/Dao-AILab/flash-attention) mechanism.\n",
|
||||
"\n",
|
||||
"* \"sdp_attention\": Boolean. Whether to use the Scaled Dot Product attention mechanism (the attention mechanism in the [original implementation](https://arxiv.org/abs/1706.03762) of transformers.)\n",
|
||||
"\n",
|
||||
"* \"warmup_steps\": Integer. The number of pre-training steps where a very low learning rate is used.\n",
|
||||
"\n",
|
||||
"* \"evals_per_epoch\": Integer. Number of evaluations to be performed within one training epoch.\n",
|
||||
"\n",
|
||||
"* \"saves_per_epoch\": Integer. Number of times the model is saved in one training epoch.\n",
|
||||
"\n",
|
||||
"* \"weight_decay\": Positive Float. Sets the \"strength\" of weight decay (i.e. setting the coefficient of L2 regularization)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The above is but a snippet aiming to get users familiarized with the types of streamlined configuration options axolotl provides. For a full list of configuration options, see [here](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Train the model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "ydTI2Jk2RStU",
|
||||
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# By using the ! the comand will be executed as a bash command\n",
|
||||
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
|
||||
]
|
||||
},
|
||||
@@ -178,7 +264,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Play with inference"
|
||||
"Predict with trained model"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -187,36 +273,85 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# By using the ! the comand will be executed as a bash command\n",
|
||||
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
|
||||
" --qlora_model_dir=\"./qlora-out\" --gradio"
|
||||
" --lora_model_dir=\"./outputs/lora-out\" --gradio"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Deeper Dive"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"It is also helpful to gain some familiarity over some of the core inner workings of axolotl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Configuration Normalization"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Axolotl uses a custom Dict class, called ```DictDefault```\n",
|
||||
"to store configurations specified in the yaml configuration file (into a Python variable named ```cfg```). The definition for this custom Dict can be found in the [utils/dict.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/dict.py)\n",
|
||||
"\n",
|
||||
"```DictDefault``` is amended such that calling a missing key from it will result in a ```None``` return type. This is important because if some configuration options aren't specified by the user, the ```None``` type allows Axolotl to perform boolean operations to determine the default settings for missing configurations. For more examples on how this is done, check out [utils/config/__init__.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/__init__.py)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Loading Models, Tokenizers, and Trainer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we inspect [cli.train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/cli/train.py), we will find that most of the heavy lifting were done by the function ```train()``` which is itself imported from [src/axolotl/train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/train.py).\n",
|
||||
"\n",
|
||||
"```train()``` takes care of loading the appropriate tokenizer and pre-trained model through ```load_model()``` and ```load_tokenizer()``` from [src/axolotl/utils/models.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/models.py) respectively.\n",
|
||||
"\n",
|
||||
"```load_tokenizer()``` loads in the appropriate tokenizer given the desired model, as well as chat templates.\n",
|
||||
"\n",
|
||||
"```ModelLoader``` class follows after tokenizer has been selected. It will automatically discern the base model type, load in the desired model, as well as applying model-appropriate attention mechanism modifications (e.g. flash attention). Depending on which base model the user chooses in the configuration, ```ModelLoader``` will utilize the corresponding \"attention hijacking\" script. For example, if the user specified the base model to be ```NousResearch/Meta-Llama-3.1-8B```, which is of llama type, and set ```flash_attn``` to ```True```, ```ModelLoader``` will load in [llama_attn_hijack_flash.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/llama_attn_hijack_flash.py). For a list of supported attention hijacking, please refer to the directory [/src/axolotl/monkeypatch/](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch)\n",
|
||||
"\n",
|
||||
"Another important operation encompassed in ```train()``` is setting up the training that takes into account of user-specified traning configurations (e.g. num_epochs, optimizer) through the use of ```setup_trainer()``` from [/src/axolotl/utils/trainer.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/trainer.py), which in turn relies on modules from [/src/axolotl/core/trainer_builder.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/core/trainer_builder.py).\n",
|
||||
"```trainer_builder.py``` provides a list of trainer object options bespoke for the task type (Causal or Reinforcement learning ('dpo', 'ipo', 'kto') )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Monkey patch\n",
|
||||
"\n",
|
||||
"The [Monkey patch directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch) is where model architecture/optimization patching scripts are stored (these are modifications that are not implemented in the official releases, hence the name monkey patch). It includes attention jacking, ReLoRA, and unsloth optimization."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "T4",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.1"
|
||||
"version": "3.9.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
|
||||
67
examples/qwen2/dpo.yaml
Normal file
@@ -0,0 +1,67 @@
|
||||
base_model: Qwen/Qwen2.5-0.5B
|
||||
|
||||
strict: false
|
||||
|
||||
chat_template: qwen_25
|
||||
rl: dpo
|
||||
datasets:
|
||||
- path: fozziethebeat/alpaca_messages_2k_dpo_test
|
||||
type: chat_template.default
|
||||
field_messages: conversation
|
||||
field_chosen: chosen
|
||||
field_rejected: rejected
|
||||
message_field_role: role
|
||||
message_field_content: content
|
||||
roles:
|
||||
system:
|
||||
- system
|
||||
user:
|
||||
- user
|
||||
assistant:
|
||||
- assistant
|
||||
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./outputs/dpo-out
|
||||
|
||||
sequence_len: 2048
|
||||
sample_packing: false
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16:
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 4
|
||||
eval_table_size:
|
||||
eval_max_new_tokens: 128
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
BIN
image/axolotl-badge-web-legacy.png
Normal file
|
After Width: | Height: | Size: 11 KiB |
|
Before Width: | Height: | Size: 11 KiB After Width: | Height: | Size: 24 KiB |
19
image/axolotl_logo_digital_black.svg
Normal file
@@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1113 283.5">
|
||||
<path fill="#141310" d="M435,234.3l-12.1-48.8h-54.4l-12.1,48.8h-24.7l48.2-185.1h31.6l47.9,185.1h-24.5ZM417.7,164.9l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
|
||||
<path fill="#141310" d="M568.2,234.3l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
|
||||
<path fill="#141310" d="M658.6,236.3c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM658.6,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path fill="#141310" d="M860.6,236.3c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM860.6,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path fill="#141310" d="M773.9,234c-18,0-32.6-14.6-32.6-32.6V48.8h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
|
||||
<path fill="#141310" d="M1036.2,234.3V81.4c0-4.7-3.8-8.5-8.5-8.5h-16.8v-24.1h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
|
||||
<path fill="#141310" d="M978.6,234.3c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3v-45.3h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
|
||||
<path fill="#141310" d="M51.5,49h12.2v-20.6h-12.2c-16,0-29,13-29,29v32.8h20.6v-32.8c0-4.7,3.8-8.4,8.4-8.4Z"/>
|
||||
<path fill="#141310" d="M92.8,49h12.2v-20.6h-12.2c-16,0-29,13-29,29v12.2h20.6v-12.2c0-4.7,3.8-8.4,8.4-8.4Z"/>
|
||||
<path fill="#141310" d="M249.3,57.4c0-16-13-29-29-29h-12.2v20.6h12.2c4.7,0,8.4,3.8,8.4,8.4v32.8h20.6v-32.8Z"/>
|
||||
<path fill="#141310" d="M187.4,90.2v-20.6h-103.1v20.6h-41.2v20.6h-20.6v41.2c0,11.4,9.2,20.6,20.6,20.6h185.5c11.4,0,20.6-9.2,20.6-20.6v-41.2h-20.6v-20.6h-41.2ZM166.8,141.7c0-5.7-4.6-10.3-10.3-10.3s-10.3,4.6-10.3,10.3v10.3h-20.6v-20.6c0-11.4,9.2-20.6,20.6-20.6s20.6,9.2,20.6,20.6v10.3ZM228.7,141.7c0-5.7-4.6-10.3-10.3-10.3s-10.3,4.6-10.3,10.3v10.3h-20.6v-20.6c0-11.4,9.2-20.6,20.6-20.6s20.6,9.2,20.6,20.6v10.3Z"/>
|
||||
<path fill="#141310" d="M208,57.4c0-16-13-29-29-29h-12.2v20.6h12.2c4.7,0,8.4,3.8,8.4,8.4v12.2h20.6v-12.2Z"/>
|
||||
<rect fill="#141310" x="22.5" y="234.5" width="41.2" height="20.6"/>
|
||||
<rect fill="#141310" x="84.3" y="234.5" width="164.9" height="20.6"/>
|
||||
<rect fill="#141310" x="208" y="193.3" width="41.2" height="20.6"/>
|
||||
<rect fill="#141310" x="22.5" y="193.3" width="164.9" height="20.6"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 3.2 KiB |
11
image/axolotl_logo_digital_white.svg
Normal file
@@ -0,0 +1,11 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 1113 283.5">
|
||||
<path fill="#fff" d="M462.9,234.2l-12.1-48.8h-54.4l-12.1,48.8h-24.7l48.2-185h31.6l47.9,185h-24.4ZM445.7,164.8l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
|
||||
<path fill="#fff" d="M596.1,234.2l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.5-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.3,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.1,49.3,71.6h-28.5Z"/>
|
||||
<path fill="#fff" d="M686.4,236.2c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.6,14.8-41.4,9.8-9.7,23.4-14.7,40.2-14.7s30.4,4.9,40.2,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.4-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM686.4,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.8v-36.7c0-10.5-2.8-18.5-8.2-23.8-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path fill="#fff" d="M888.3,236.2c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.6,14.8-41.4,9.8-9.7,23.4-14.7,40.2-14.7s30.4,4.9,40.2,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.4-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM888.3,114.1c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.8v-36.7c0-10.5-2.8-18.5-8.2-23.8-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path fill="#fff" d="M801.7,234c-18,0-32.6-14.6-32.6-32.6V48.8h24.1v152.5c0,4.7,3.8,8.5,8.5,8.5h16.7v24.1h-16.7Z"/>
|
||||
<path fill="#fff" d="M1063.8,234.2V81.4c0-4.7-3.8-8.5-8.5-8.5h-16.7v-24.1h16.7c18,0,32.6,14.6,32.6,32.6v152.8h-24.1Z"/>
|
||||
<path fill="#fff" d="M1006.2,234.2c-18,0-32.6-14.6-32.6-32.6v-85h-20.3v-22.1h20.3v-45.2h24.1v45.2h30.2v22.1h-30.2v85c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
|
||||
<path fill="#fff" d="M160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM277.3,57.4c0-23.8-19.3-43.1-43.1-43.1h-12.2c-3.9,0-7.6,1.6-10.2,4.4-5.9-2.9-12.3-4.4-18.9-4.4h-12.2c-7.7,0-14.1,6.3-14.1,14.1v20.6c0,2.4.6,4.6,1.6,6.6h-37c1-2,1.6-4.2,1.6-6.6v-20.6c0-7.7-6.3-14.1-14.1-14.1h-12.2c-6.5,0-13,1.5-18.9,4.4-2.6-2.8-6.3-4.4-10.2-4.4h-12.2c-23.8,0-43.1,19.3-43.1,43.1v32.8c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v41.2c0,11,5.2,20.8,13.2,27.2-7.3.4-13.2,6.6-13.2,14v20.6c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v20.6c0,7.7,6.3,14.1,14.1,14.1h41.2c4.1,0,7.7-1.7,10.3-4.5,2.6,2.8,6.2,4.5,10.3,4.5h164.9c7.7,0,14.1-6.3,14.1-14.1v-20.6c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-20.6c0-7.5-5.8-13.6-13.2-14,8-6.4,13.2-16.2,13.2-27.2v-41.2c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-32.8ZM77.8,255.1h-41.2v-20.6h41.2v20.6ZM36.5,213.9v-20.6h164.9v20.6H36.5ZM263.3,255.1H98.4v-20.6h164.9v20.6ZM263.3,213.9h-41.2v-20.6h41.2v20.6ZM263.3,90.2h-20.6v20.6h20.6v41.2c0,11.4-9.2,20.6-20.6,20.6H57.2c-11.4,0-20.6-9.2-20.6-20.6v-41.2h20.6v-20.6h-20.6v-32.8c0-16,13-29,29-29h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v32.8h41.2v-20.6h-20.6v-12.2c0-16,13-29,29-29h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v12.2h103.1v-12.2c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16,0,29,13,29,29v12.2h-20.6v20.6h41.2v-32.8c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16,0,29,13,29,29v32.8ZM201.4,152h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6s-20.6,9.2-20.6,20.6v20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM222,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM160.2,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6Z"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 6.6 KiB |
26
image/axolotl_symbol_digital_black.svg
Normal file
@@ -0,0 +1,26 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 283.5 283.5">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: #141310;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
|
||||
<g>
|
||||
<g id="Layer_1">
|
||||
<g>
|
||||
<path class="cls-1" d="M46.9,37.4h13.7V14.2h-13.7c-18,0-32.7,14.6-32.7,32.7v36.9h23.2v-36.9c0-5.2,4.2-9.5,9.5-9.5Z"/>
|
||||
<path class="cls-1" d="M93.2,37.4h13.7V14.2h-13.7c-18,0-32.7,14.6-32.7,32.7v13.7h23.2v-13.7c0-5.2,4.2-9.5,9.5-9.5Z"/>
|
||||
<path class="cls-1" d="M269.3,46.9c0-18-14.6-32.7-32.7-32.7h-13.7v23.2h13.7c5.2,0,9.5,4.2,9.5,9.5v36.9h23.2v-36.9Z"/>
|
||||
<path class="cls-1" d="M199.7,83.8v-23.2h-116v23.2h-46.4v23.2H14.2v46.4c0,12.8,10.4,23.2,23.2,23.2h208.7c12.8,0,23.2-10.4,23.2-23.2v-46.4h-23.2v-23.2h-46.4ZM176.5,141.7c0-6.4-5.2-11.6-11.6-11.6s-11.6,5.2-11.6,11.6v11.6h-23.2v-23.2c0-12.8,10.4-23.2,23.2-23.2s23.2,10.4,23.2,23.2v11.6ZM246.1,141.7c0-6.4-5.2-11.6-11.6-11.6s-11.6,5.2-11.6,11.6v11.6h-23.2v-23.2c0-12.8,10.4-23.2,23.2-23.2s23.2,10.4,23.2,23.2v11.6Z"/>
|
||||
<path class="cls-1" d="M222.9,46.9c0-18-14.6-32.7-32.7-32.7h-13.7v23.2h13.7c5.2,0,9.5,4.2,9.5,9.5v13.7h23.2v-13.7Z"/>
|
||||
<rect class="cls-1" x="14.2" y="246.1" width="46.4" height="23.2"/>
|
||||
<rect class="cls-1" x="83.8" y="246.1" width="185.5" height="23.2"/>
|
||||
<rect class="cls-1" x="222.9" y="199.7" width="46.4" height="23.2"/>
|
||||
<rect class="cls-1" x="14.2" y="199.7" width="185.5" height="23.2"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.6 KiB |
16
image/axolotl_symbol_digital_white.svg
Normal file
@@ -0,0 +1,16 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 283.5 283.5">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: #fff;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
|
||||
<g>
|
||||
<g id="Layer_1">
|
||||
<path class="cls-1" d="M152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM269.3,57.3c0-23.8-19.4-43.1-43.1-43.1h-12.2c-3.9,0-7.6,1.6-10.2,4.4-5.9-2.9-12.3-4.4-18.9-4.4h-12.2c-7.8,0-14.1,6.3-14.1,14.1v20.6c0,2.4.6,4.6,1.6,6.6h-37c1-2,1.6-4.2,1.6-6.6v-20.6c0-7.8-6.3-14.1-14.1-14.1h-12.2c-6.6,0-13,1.5-18.9,4.4-2.6-2.8-6.3-4.4-10.2-4.4h-12.2c-23.8,0-43.1,19.4-43.1,43.1v32.8c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v41.3c0,11,5.2,20.9,13.2,27.2-7.4.4-13.2,6.6-13.2,14v20.6c0,4.1,1.7,7.7,4.5,10.3-2.8,2.6-4.5,6.2-4.5,10.3v20.6c0,7.8,6.3,14.1,14.1,14.1h41.3c4.1,0,7.7-1.7,10.3-4.5,2.6,2.8,6.2,4.5,10.3,4.5h165.1c7.8,0,14.1-6.3,14.1-14.1v-20.6c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-20.6c0-7.5-5.9-13.6-13.2-14,8-6.4,13.2-16.2,13.2-27.2v-41.3c0-4.1-1.7-7.7-4.5-10.3,2.8-2.6,4.5-6.2,4.5-10.3v-32.8ZM69.5,255.2H28.2v-20.6h41.3v20.6ZM28.2,214v-20.6h165.1v20.6H28.2ZM255.2,255.2H90.1v-20.6h165.1v20.6ZM255.2,214h-41.3v-20.6h41.3v20.6ZM255.2,90.1h-20.6v20.6h20.6v41.3c0,11.4-9.2,20.6-20.6,20.6H48.9c-11.4,0-20.6-9.2-20.6-20.6v-41.3h20.6v-20.6h-20.6v-32.8c0-16.1,13-29.1,29.1-29.1h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v32.8h41.3v-20.6h-20.6v-12.2c0-16.1,13-29.1,29.1-29.1h12.2v20.6h-12.2c-4.7,0-8.4,3.8-8.4,8.4v12.2h103.2v-12.2c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16.1,0,29.1,13,29.1,29.1v12.2h-20.6v20.6h41.3v-32.8c0-4.7-3.8-8.4-8.4-8.4h-12.2v-20.6h12.2c16.1,0,29.1,13,29.1,29.1v32.8ZM193.3,152h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6s-20.6,9.2-20.6,20.6v20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM214,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6ZM152,110.8c-11.4,0-20.6,9.2-20.6,20.6v20.6h20.6v-10.3c0-5.7,4.6-10.3,10.3-10.3s10.3,4.6,10.3,10.3v-10.3c0-11.4-9.2-20.6-20.6-20.6Z"/>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 5.0 KiB |
17
image/axolotl_wordmark_digital_black.svg
Normal file
@@ -0,0 +1,17 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 765.4 212.6">
|
||||
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
|
||||
<g>
|
||||
<g id="Layer_1">
|
||||
<g>
|
||||
<path d="M121.6,198.1l-12.1-48.8h-54.4l-12.1,48.8h-24.7L66.6,12.9h31.6l47.9,185.1h-24.5ZM104.4,128.6l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
|
||||
<path d="M254.9,198.1l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
|
||||
<path d="M345.2,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM345.2,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path d="M547.3,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM547.3,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path d="M460.6,197.8c-18,0-32.6-14.6-32.6-32.6V12.5h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
|
||||
<path d="M722.8,198.1V45.2c0-4.7-3.8-8.5-8.5-8.5h-16.8V12.5h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
|
||||
<path d="M665.2,198.1c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3V12.9h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.1 KiB |
24
image/axolotl_wordmark_digital_white.svg
Normal file
@@ -0,0 +1,24 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" version="1.1" viewBox="0 0 765.4 212.6">
|
||||
<defs>
|
||||
<style>
|
||||
.cls-1 {
|
||||
fill: #fff;
|
||||
}
|
||||
</style>
|
||||
</defs>
|
||||
<!-- Generator: Adobe Illustrator 28.7.1, SVG Export Plug-In . SVG Version: 1.2.0 Build 142) -->
|
||||
<g>
|
||||
<g id="Layer_1">
|
||||
<g>
|
||||
<path class="cls-1" d="M121.6,198.1l-12.1-48.8h-54.4l-12.1,48.8h-24.7L66.6,12.9h31.6l47.9,185.1h-24.5ZM104.4,128.6l-13.8-55.6c-2.7-10.7-4.8-19.7-6.3-26.9-.9-4.2-1.5-7.5-2-9.9-.5,2.5-1.2,5.8-2,9.9-1.5,7.1-3.6,16.1-6.3,26.7l-13.8,55.9h44.3Z"/>
|
||||
<path class="cls-1" d="M254.9,198.1l-29.9-45.6c-1.2-1.9-2.4-4.1-3.5-6.5-.8-1.7-1.5-3.3-2.1-4.5-.6,1.3-1.4,2.8-2.3,4.5-1.3,2.4-2.6,4.6-4,6.5l-29.9,45.6h-28.5l49.6-71.9-46.5-67.9h28.5l27.6,43.1c1.2,1.9,2.3,3.9,3.4,6.1.7,1.4,1.4,2.7,1.9,3.8.5-1.1,1.1-2.4,1.8-3.8,1.1-2.2,2.2-4.2,3.4-6.1l27.6-43.1h28.5l-46.5,68.2,49.3,71.7h-28.5Z"/>
|
||||
<path class="cls-1" d="M345.2,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM345.2,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path class="cls-1" d="M547.3,200.1c-16.7,0-30.2-5-40.1-14.8-9.9-9.8-14.9-23.7-14.9-41.3v-31.7c0-17.7,5-31.7,14.8-41.4,9.8-9.7,23.4-14.7,40.3-14.7s30.4,4.9,40.3,14.7c9.8,9.7,14.8,23.7,14.8,41.4v31.7c0,17.6-5,31.5-14.9,41.3-9.9,9.8-23.4,14.8-40.1,14.8ZM547.3,77.8c-9.5,0-17.1,2.7-22.6,8.1-5.5,5.4-8.3,13.4-8.3,23.8v36.7c0,10.5,2.8,18.5,8.3,23.8,5.5,5.4,13.1,8.1,22.6,8.1s17.3-2.7,22.7-8.1c5.4-5.4,8.2-13.4,8.2-23.9v-36.7c0-10.5-2.8-18.5-8.2-23.9-5.4-5.4-13.1-8.1-22.7-8.1Z"/>
|
||||
<path class="cls-1" d="M460.6,197.8c-18,0-32.6-14.6-32.6-32.6V12.5h24.1v152.6c0,4.7,3.8,8.5,8.5,8.5h16.8v24.1h-16.8Z"/>
|
||||
<path class="cls-1" d="M722.8,198.1V45.2c0-4.7-3.8-8.5-8.5-8.5h-16.8V12.5h16.8c18,0,32.6,14.6,32.6,32.6v152.9h-24.1Z"/>
|
||||
<path class="cls-1" d="M665.2,198.1c-18,0-32.6-14.6-32.6-32.6v-85.1h-20.3v-22.1h20.3V12.9h24.1v45.3h30.2v22.1h-30.2v85.1c0,4.7,3.8,8.5,8.5,8.5h21.7v24.1h-21.7Z"/>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.3 KiB |
@@ -1,12 +1,12 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft==0.13.2
|
||||
transformers==4.46.1
|
||||
transformers==4.46.3
|
||||
tokenizers>=0.20.1
|
||||
bitsandbytes==0.44.1
|
||||
accelerate==1.1.0
|
||||
datasets==3.0.1
|
||||
deepspeed==0.15.3
|
||||
datasets==3.1.0
|
||||
deepspeed==0.15.4
|
||||
pydantic==2.6.3
|
||||
addict
|
||||
fire
|
||||
@@ -31,9 +31,9 @@ art
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
python-dotenv==1.0.1
|
||||
autoawq>=0.2.5
|
||||
autoawq==0.2.7.post2
|
||||
triton>=2.3.0
|
||||
liger-kernel==0.4.1
|
||||
liger-kernel==0.4.2
|
||||
|
||||
mamba-ssm==1.2.0.post1
|
||||
|
||||
@@ -53,3 +53,4 @@ immutabledict==4.2.0
|
||||
antlr4-python3-runtime==4.13.2
|
||||
|
||||
torchao==0.5.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
# Export specific ENV variables to /etc/rp_environment
|
||||
echo "Exporting environment variables..."
|
||||
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
|
||||
printenv | grep -E '^HF_|^BNB_|^CUDA_|^NCCL_|^NV|^RUNPOD_|^PATH=|^_=' | sed 's/^\([^=]*\)=\(.*\)$/export \1="\2"/' | grep -v 'printenv' >> /etc/rp_environment
|
||||
echo 'source /etc/rp_environment' >> ~/.bashrc
|
||||
|
||||
add_keys_to_authorized() {
|
||||
|
||||
33
scripts/unsloth_install.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# noqa
|
||||
# pylint: skip-file
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("Install torch via `pip install torch`")
|
||||
from packaging.version import Version as V
|
||||
|
||||
v = V(torch.__version__)
|
||||
cuda = str(torch.version.cuda)
|
||||
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
||||
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
|
||||
raise RuntimeError(f"CUDA = {cuda} not supported!")
|
||||
if v <= V("2.1.0"):
|
||||
raise RuntimeError(f"Torch = {v} too old!")
|
||||
elif v <= V("2.1.1"):
|
||||
x = "cu{}{}-torch211"
|
||||
elif v <= V("2.1.2"):
|
||||
x = "cu{}{}-torch212"
|
||||
elif v < V("2.3.0"):
|
||||
x = "cu{}{}-torch220"
|
||||
elif v < V("2.4.0"):
|
||||
x = "cu{}{}-torch230"
|
||||
elif v < V("2.5.0"):
|
||||
x = "cu{}{}-torch240"
|
||||
elif v < V("2.6.0"):
|
||||
x = "cu{}{}-torch250"
|
||||
else:
|
||||
raise RuntimeError(f"Torch = {v} too new!")
|
||||
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
||||
print(
|
||||
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
|
||||
)
|
||||
6
setup.py
@@ -96,11 +96,11 @@ install_requires, dependency_links = parse_requirements()
|
||||
|
||||
setup(
|
||||
name="axolotl",
|
||||
version="0.5.0",
|
||||
version="0.5.2",
|
||||
description="LLM Trainer",
|
||||
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages(),
|
||||
packages=find_packages("src"),
|
||||
install_requires=install_requires,
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
@@ -108,7 +108,7 @@ setup(
|
||||
"flash-attn==2.7.0.post2",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.14.4",
|
||||
"deepspeed==0.15.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -30,7 +30,10 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.chat_templates import get_chat_template
|
||||
from axolotl.utils.chat_templates import (
|
||||
get_chat_template,
|
||||
get_chat_template_from_config,
|
||||
)
|
||||
from axolotl.utils.comet_ import setup_comet_env_vars
|
||||
from axolotl.utils.config import (
|
||||
normalize_cfg_datasets,
|
||||
@@ -199,6 +202,10 @@ def do_inference(
|
||||
)
|
||||
elif cfg.chat_template:
|
||||
chat_template_str = get_chat_template(cfg.chat_template)
|
||||
elif cfg.datasets[0].type == "chat_template":
|
||||
chat_template_str = get_chat_template_from_config(
|
||||
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||
|
||||
|
||||
@@ -1038,24 +1038,37 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
self,
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
) -> Dict:
|
||||
res = super().tokenize_row(
|
||||
res = DPOTrainer.tokenize_row(
|
||||
features,
|
||||
processing_class,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
add_special_tokens,
|
||||
)
|
||||
if processing_class.bos_token_id is None and res["prompt_input_ids"][0] is None:
|
||||
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
|
||||
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
|
||||
for key in res.keys():
|
||||
res[key] = res[key][1:]
|
||||
|
||||
if processing_class.bos_token and processing_class.bos_token_id is not None:
|
||||
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
|
||||
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
|
||||
res["chosen_labels"] = res["chosen_labels"][1:]
|
||||
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
|
||||
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
|
||||
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
|
||||
res["rejected_labels"] = res["rejected_labels"][1:]
|
||||
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
|
||||
|
||||
return res
|
||||
|
||||
def training_step(
|
||||
@@ -1199,11 +1212,17 @@ class TrainerBuilderBase(abc.ABC):
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
callbacks = []
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
|
||||
)
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
[
|
||||
cb
|
||||
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||
self.cfg, trainer
|
||||
)
|
||||
if cb
|
||||
]
|
||||
)
|
||||
return callbacks
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
@@ -1250,7 +1269,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
callbacks = []
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer, "wandb"
|
||||
@@ -1288,17 +1307,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
callbacks.append(lisa_callback_factory(trainer))
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
callbacks.extend(
|
||||
[
|
||||
cb
|
||||
for cb in plugin_manager.add_callbacks_post_trainer(
|
||||
self.cfg, trainer
|
||||
)
|
||||
if cb
|
||||
]
|
||||
)
|
||||
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
@@ -1416,17 +1425,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
training_arguments_kwargs["eval_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.evaluation_strategy:
|
||||
training_arguments_kwargs[
|
||||
"evaluation_strategy"
|
||||
] = self.cfg.evaluation_strategy
|
||||
elif self.cfg.eval_strategy:
|
||||
training_arguments_kwargs["eval_strategy"] = self.cfg.eval_strategy
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
training_arguments_kwargs["eval_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
@@ -1860,10 +1867,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.eval_dataset:
|
||||
training_args_kwargs["evaluation_strategy"] = "steps"
|
||||
training_args_kwargs["eval_strategy"] = "steps"
|
||||
training_args_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
else:
|
||||
training_args_kwargs["evaluation_strategy"] = "no"
|
||||
training_args_kwargs["eval_strategy"] = "no"
|
||||
|
||||
if self.cfg.bf16 or self.cfg.bfloat16:
|
||||
training_args_kwargs["bf16"] = True
|
||||
|
||||
0
src/axolotl/integrations/sageattention/__init__.py
Normal file
361
src/axolotl/integrations/sageattention/lib/core.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Copyright (c) 2024 by SageAttention team.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
|
||||
from .triton.attn_qk_int8_per_block_causal_varlen import (
|
||||
backward as sageattn_varlen_backward,
|
||||
)
|
||||
from .triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen
|
||||
from .triton.quant_per_block_varlen import (
|
||||
per_block_int8 as per_block_int8_varlen_triton,
|
||||
)
|
||||
|
||||
|
||||
def get_cuda_arch_versions():
|
||||
cuda_archs = []
|
||||
for i in range(torch.cuda.device_count()):
|
||||
major, minor = torch.cuda.get_device_capability(i)
|
||||
cuda_archs.append(f"sm{major}{minor}")
|
||||
return cuda_archs
|
||||
|
||||
|
||||
def sageattn_varlen(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
sm_scale: Optional[float] = None,
|
||||
smooth_k: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
q : torch.Tensor
|
||||
The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
|
||||
|
||||
k : torch.Tensor
|
||||
The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
|
||||
|
||||
v : torch.Tensor
|
||||
The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
|
||||
|
||||
cu_seqlens_q : torch.Tensor
|
||||
The cumulative sequence lengths for the query sequences in the batch, used to index into `q`.
|
||||
Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
|
||||
|
||||
cu_seqlens_k : torch.Tensor
|
||||
The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`.
|
||||
Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
|
||||
|
||||
max_seqlen_q : int
|
||||
The maximum sequence length for the query tensor in the batch.
|
||||
|
||||
max_seqlen_k : int
|
||||
The maximum sequence length for the key and value tensors in the batch.
|
||||
|
||||
is_causal : bool
|
||||
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len for each sequence.
|
||||
Default: False.
|
||||
|
||||
sm_scale : Optional[float]
|
||||
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
|
||||
|
||||
smooth_k : bool
|
||||
Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
|
||||
Default: True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
|
||||
|
||||
Note
|
||||
----
|
||||
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
|
||||
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``.
|
||||
- The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``.
|
||||
- All tensors must be on the same cuda device.
|
||||
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
|
||||
"""
|
||||
|
||||
dtype = q.dtype
|
||||
assert q.is_cuda, "Input tensors must be on cuda."
|
||||
assert dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
|
||||
assert q.device == k.device == v.device, "All tensors must be on the same device."
|
||||
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
|
||||
|
||||
head_dim = q.size(-1)
|
||||
assert head_dim in [64, 128], "varlen only support head_dim [64, 128]."
|
||||
|
||||
assert (
|
||||
q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
|
||||
), "Last dim of qkv must be contiguous."
|
||||
assert (
|
||||
cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous()
|
||||
), "cu_seqlens_q and cu_seqlens_k must be contiguous."
|
||||
|
||||
if dtype == torch.bfloat16 or dtype == torch.float32:
|
||||
v = v.to(torch.float16)
|
||||
|
||||
if smooth_k:
|
||||
km = k.mean(
|
||||
dim=0, keepdim=True
|
||||
) # ! km is calculated on the all the batches. Calculate over each individual sequence requires dedicated kernel.
|
||||
k -= km
|
||||
|
||||
(
|
||||
q_int8,
|
||||
q_scale,
|
||||
k_int8,
|
||||
k_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
) = per_block_int8_varlen_triton(
|
||||
q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale
|
||||
)
|
||||
|
||||
o = attn_true_varlen(
|
||||
q_int8,
|
||||
k_int8,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
q_scale,
|
||||
k_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
output_dtype=dtype,
|
||||
)
|
||||
|
||||
return o
|
||||
|
||||
|
||||
class SageAttentionFunction(Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
):
|
||||
"""
|
||||
query: Tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
|
||||
key: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
|
||||
value: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
|
||||
attn_mask: Optional[Tensor], mask tensor
|
||||
dropout_p: float, dropout probability
|
||||
is_causal: bool, whether to apply causal masking
|
||||
scale: Optional[float], scaling factor for attention scores
|
||||
"""
|
||||
# Ensure inputs are contiguous
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
|
||||
# Handle default scale
|
||||
if scale is None:
|
||||
scale = 1.0 / (query.size(-1) ** 0.5)
|
||||
|
||||
# Save parameters needed for backward
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.attn_mask = attn_mask
|
||||
|
||||
# Prepare cumulative sequence lengths and max sequence lengths
|
||||
# Assuming batch sizes are consistent across query, key, and value
|
||||
batch_size, num_heads, seq_len_q, head_dim = query.shape
|
||||
seq_len_k = key.shape[2]
|
||||
|
||||
# Flatten batch and head dimensions
|
||||
q = query.view(
|
||||
-1, seq_len_q, head_dim
|
||||
) # [batch_size * num_heads, seq_len_q, head_dim]
|
||||
k = key.view(-1, seq_len_k, head_dim)
|
||||
v = value.view(-1, seq_len_k, head_dim)
|
||||
|
||||
# Create cumulative sequence lengths
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
(batch_size * num_heads + 1) * seq_len_q,
|
||||
seq_len_q,
|
||||
dtype=torch.int32,
|
||||
device=query.device,
|
||||
)
|
||||
cu_seqlens_k = torch.arange(
|
||||
0,
|
||||
(batch_size * num_heads + 1) * seq_len_k,
|
||||
seq_len_k,
|
||||
dtype=torch.int32,
|
||||
device=key.device,
|
||||
)
|
||||
max_seqlen_q = seq_len_q
|
||||
max_seqlen_k = seq_len_k
|
||||
|
||||
# Call your custom per-block int8 quantization function
|
||||
(
|
||||
q_int8,
|
||||
q_scale,
|
||||
k_int8,
|
||||
k_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
) = per_block_int8_varlen_triton(
|
||||
q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=scale
|
||||
)
|
||||
|
||||
# Call your custom attention function
|
||||
if is_causal:
|
||||
output = attn_true_varlen(
|
||||
q_int8,
|
||||
k_int8,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
q_scale,
|
||||
k_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
output_dtype=query.dtype,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Non-causal attention is not implemented yet.")
|
||||
|
||||
# Reshape output to match the expected shape
|
||||
output = output.view(batch_size, num_heads, seq_len_q, head_dim)
|
||||
|
||||
# Save tensors for backward
|
||||
ctx.save_for_backward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
q_int8,
|
||||
k_int8,
|
||||
q_scale,
|
||||
k_scale,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
output,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
q_int8,
|
||||
k_int8,
|
||||
q_scale,
|
||||
k_scale,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
output,
|
||||
) = ctx.saved_tensors
|
||||
|
||||
scale = ctx.scale
|
||||
is_causal = ctx.is_causal
|
||||
dropout_p = ctx.dropout_p
|
||||
attn_mask = ctx.attn_mask
|
||||
|
||||
# Flatten batch and head dimensions
|
||||
batch_size, num_heads, seq_len_q, head_dim = query.shape
|
||||
seq_len_k = key.shape[2]
|
||||
grad_output = grad_output.contiguous()
|
||||
do = grad_output.view(-1, seq_len_q, head_dim)
|
||||
|
||||
# Compute gradients w.r.t. q, k, v
|
||||
dq, dk, dv = sageattn_varlen_backward(
|
||||
do,
|
||||
query.view(-1, seq_len_q, head_dim),
|
||||
key.view(-1, seq_len_k, head_dim),
|
||||
value.view(-1, seq_len_k, head_dim),
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
seq_len_q,
|
||||
seq_len_k,
|
||||
q_int8,
|
||||
k_int8,
|
||||
q_scale,
|
||||
k_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
scale,
|
||||
is_causal,
|
||||
)
|
||||
|
||||
# Reshape gradients to match the input shapes
|
||||
dq = dq.view(batch_size, num_heads, seq_len_q, head_dim)
|
||||
dk = dk.view(batch_size, num_heads, seq_len_k, head_dim)
|
||||
dv = dv.view(batch_size, num_heads, seq_len_k, head_dim)
|
||||
|
||||
# Handle optional arguments
|
||||
d_attn_mask = None # Assuming attn_mask does not require gradients
|
||||
d_dropout_p = (
|
||||
None # Dropout probability is a hyperparameter, typically not optimized
|
||||
)
|
||||
d_is_causal = None # Not differentiable
|
||||
d_scale = None # If scale is a tensor and requires grad, compute its gradient
|
||||
|
||||
return dq, dk, dv, d_attn_mask, d_dropout_p, d_is_causal, d_scale
|
||||
|
||||
|
||||
def scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
scale=None,
|
||||
):
|
||||
"""
|
||||
Custom scaled dot product attention using SageAttentionFunction.
|
||||
"""
|
||||
return SageAttentionFunction.apply(
|
||||
query, key, value, attn_mask, dropout_p, is_causal, scale
|
||||
)
|
||||
|
||||
|
||||
def monkeypatch_sdp_w_sage_attention():
|
||||
"""
|
||||
Replace torch.nn.functional.scaled_dot_product_attention with custom scaled dot product attention using SageAttentionFunction.
|
||||
"""
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
@@ -0,0 +1,622 @@
|
||||
"""
|
||||
Copyright (c) 2024 by SageAttention team.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
q_scale,
|
||||
kv_len,
|
||||
K_ptrs,
|
||||
K_scale_ptr,
|
||||
V_ptrs,
|
||||
stride_kn,
|
||||
stride_vn,
|
||||
start_m,
|
||||
H: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
STAGE: tl.constexpr,
|
||||
offs_m: tl.constexpr,
|
||||
offs_n: tl.constexpr,
|
||||
):
|
||||
if STAGE == 1:
|
||||
lo, hi = 0, start_m * BLOCK_M
|
||||
elif STAGE == 2:
|
||||
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
||||
lo = tl.multiple_of(lo, BLOCK_M)
|
||||
K_scale_ptr += (lo // BLOCK_N) * H
|
||||
K_ptrs += stride_kn * lo
|
||||
V_ptrs += stride_vn * lo
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
k_mask = offs_n[None, :] < (kv_len - start_n)
|
||||
k = tl.load(K_ptrs, mask=k_mask)
|
||||
k_scale = tl.load(K_scale_ptr)
|
||||
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
|
||||
|
||||
if STAGE == 2:
|
||||
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
|
||||
qk = qk + tl.where(mask, 0, -1.0e6)
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
qk -= m_ij[:, None]
|
||||
else:
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
qk = qk - m_ij[:, None]
|
||||
|
||||
p = tl.math.exp2(qk)
|
||||
l_ij = tl.sum(p, 1)
|
||||
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
l_i = l_i * alpha + l_ij
|
||||
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
|
||||
p = p.to(tl.float16)
|
||||
|
||||
acc += tl.dot(p, v, out_dtype=tl.float16)
|
||||
m_i = m_ij
|
||||
K_ptrs += BLOCK_N * stride_kn
|
||||
K_scale_ptr += H
|
||||
V_ptrs += BLOCK_N * stride_vn
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
Q_scale,
|
||||
K_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
Out,
|
||||
stride_qh,
|
||||
stride_qn,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_vh,
|
||||
stride_vn,
|
||||
stride_oh,
|
||||
stride_on,
|
||||
H: tl.constexpr,
|
||||
num_kv_groups: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
STAGE: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
|
||||
off_z = tl.program_id(2).to(tl.int64)
|
||||
off_h = tl.program_id(1).to(tl.int64)
|
||||
|
||||
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
|
||||
qo_len = cu_seqlens_q_end - cu_seqlens_q_start
|
||||
|
||||
if (start_m * BLOCK_M) >= qo_len:
|
||||
return
|
||||
|
||||
cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
|
||||
cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)
|
||||
|
||||
q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
|
||||
k_scale_offset = (
|
||||
cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
|
||||
)
|
||||
|
||||
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
|
||||
kv_len = cu_seqlens_k_end - cu_seqlens_k_start
|
||||
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, HEAD_DIM)
|
||||
Q_ptrs = (
|
||||
Q
|
||||
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
|
||||
+ offs_m[:, None] * stride_qn
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
Q_scale_ptr = Q_scale + q_scale_offset
|
||||
K_ptrs = (
|
||||
K
|
||||
+ (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
|
||||
+ offs_n[None, :] * stride_kn
|
||||
+ offs_k[:, None]
|
||||
)
|
||||
K_scale_ptr = K_scale + k_scale_offset
|
||||
V_ptrs = (
|
||||
V
|
||||
+ (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
|
||||
+ offs_n[:, None] * stride_vn
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
O_block_ptr = (
|
||||
Out
|
||||
+ (cu_seqlens_q_start * stride_on + off_h * stride_oh)
|
||||
+ offs_m[:, None] * stride_on
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
|
||||
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
|
||||
|
||||
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
|
||||
q_scale = tl.load(Q_scale_ptr)
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
q_scale,
|
||||
kv_len,
|
||||
K_ptrs,
|
||||
K_scale_ptr,
|
||||
V_ptrs,
|
||||
stride_kn,
|
||||
stride_vn,
|
||||
start_m,
|
||||
H // num_kv_groups,
|
||||
BLOCK_M,
|
||||
HEAD_DIM,
|
||||
BLOCK_N,
|
||||
4 - STAGE,
|
||||
offs_m,
|
||||
offs_n,
|
||||
)
|
||||
|
||||
acc, l_i, _ = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
q_scale,
|
||||
kv_len,
|
||||
K_ptrs,
|
||||
K_scale_ptr,
|
||||
V_ptrs,
|
||||
stride_kn,
|
||||
stride_vn,
|
||||
start_m,
|
||||
H // num_kv_groups,
|
||||
BLOCK_M,
|
||||
HEAD_DIM,
|
||||
BLOCK_N,
|
||||
2,
|
||||
offs_m,
|
||||
offs_n,
|
||||
)
|
||||
acc = acc / l_i[:, None]
|
||||
tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_bwd_inner(
|
||||
dq_acc,
|
||||
dk_acc,
|
||||
dv_acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
do,
|
||||
q_scale,
|
||||
k_scale,
|
||||
kv_len,
|
||||
stride_kn,
|
||||
stride_vn,
|
||||
start_m,
|
||||
H,
|
||||
BLOCK_M: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
STAGE: tl.constexpr,
|
||||
offs_m: tl.constexpr,
|
||||
offs_n: tl.constexpr,
|
||||
):
|
||||
if STAGE == 1:
|
||||
lo, hi = 0, start_m * BLOCK_M
|
||||
elif STAGE == 2:
|
||||
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
||||
lo = tl.multiple_of(lo, BLOCK_M)
|
||||
k += stride_kn * lo
|
||||
v += stride_vn * lo
|
||||
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
k_mask = offs_n[None, :] < (kv_len - start_n)
|
||||
k_curr = tl.load(k, mask=k_mask)
|
||||
v_curr = tl.load(v, mask=k_mask)
|
||||
k_scale_curr = tl.load(k_scale)
|
||||
s = tl.dot(q, k_curr, trans_b=True).to(tl.float32) * q_scale * k_scale_curr
|
||||
|
||||
if STAGE == 2:
|
||||
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
|
||||
s = s + tl.where(mask, 0.0, -float("inf"))
|
||||
m_ij = tl.maximum(m_i, tl.max(s, 1))
|
||||
s = s - m_ij[:, None]
|
||||
else:
|
||||
m_ij = tl.maximum(m_i, tl.max(s, 1))
|
||||
s = s - m_ij[:, None]
|
||||
|
||||
p = tl.math.exp2(s)
|
||||
l_ij = tl.sum(p, 1)
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
l_i = l_i * alpha + l_ij
|
||||
m_i = m_ij
|
||||
|
||||
p = p / l_i[:, None] # Normalize probabilities
|
||||
|
||||
# Compute gradients
|
||||
# Compute softmax gradient
|
||||
do_scaled = do / l_i[:, None]
|
||||
dv_contrib = tl.dot(p.to(tl.float16).T, do_scaled.to(tl.float16))
|
||||
dv_acc += dv_contrib
|
||||
|
||||
dp = tl.dot(do_scaled.to(tl.float16), v_curr.to(tl.float16).T)
|
||||
|
||||
# Compute ds (gradient w.r.t. logits s)
|
||||
p_dp = p * dp
|
||||
sum_p_dp = tl.sum(p_dp, axis=1)
|
||||
ds = (p_dp - p * sum_p_dp[:, None]) * tl.math.log(2.0) # Adjust for exp2
|
||||
|
||||
# Compute gradients w.r.t q and k
|
||||
dq_contrib = tl.dot(ds.to(tl.float16), k_curr.to(tl.float16))
|
||||
dk_contrib = tl.dot(ds.to(tl.float16).T, q.to(tl.float16))
|
||||
|
||||
dq_acc += dq_contrib * (q_scale * k_scale_curr)
|
||||
dk_acc += dk_contrib * (q_scale * k_scale_curr)
|
||||
|
||||
k += BLOCK_N * stride_kn
|
||||
k_scale += H
|
||||
v += BLOCK_N * stride_vn
|
||||
|
||||
return dq_acc, dk_acc, dv_acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_bwd(
|
||||
DO,
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
Q_scale,
|
||||
K_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
L,
|
||||
M,
|
||||
DQ,
|
||||
DK,
|
||||
DV,
|
||||
stride_qh,
|
||||
stride_qn,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_vh,
|
||||
stride_vn,
|
||||
H: tl.constexpr,
|
||||
num_kv_groups: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
STAGE: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_z = tl.program_id(2).to(tl.int64)
|
||||
off_h = tl.program_id(1).to(tl.int64)
|
||||
|
||||
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
qo_len = cu_seqlens_q_end - cu_seqlens_q_start
|
||||
|
||||
if (start_m * BLOCK_M) >= qo_len:
|
||||
return
|
||||
|
||||
cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
|
||||
cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)
|
||||
|
||||
q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
|
||||
k_scale_offset = (
|
||||
cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
|
||||
)
|
||||
|
||||
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
kv_len = cu_seqlens_k_end - cu_seqlens_k_start
|
||||
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_k = tl.arange(0, HEAD_DIM)
|
||||
Q_ptrs = (
|
||||
Q
|
||||
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
|
||||
+ offs_m[:, None] * stride_qn
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
DO_ptrs = (
|
||||
DO
|
||||
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
|
||||
+ offs_m[:, None] * stride_qn
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
Q_scale_ptr = Q_scale + q_scale_offset
|
||||
K_ptrs = (
|
||||
K
|
||||
+ (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
|
||||
+ offs_n[None, :] * stride_kn
|
||||
+ offs_k[:, None]
|
||||
)
|
||||
K_scale_ptr = K_scale + k_scale_offset
|
||||
V_ptrs = (
|
||||
V
|
||||
+ (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
|
||||
+ offs_n[:, None] * stride_vn
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
DQ_ptrs = (
|
||||
DQ
|
||||
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
|
||||
+ offs_m[:, None] * stride_qn
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
DK_ptrs = (
|
||||
DK
|
||||
+ (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
|
||||
+ offs_n[None, :] * stride_kn
|
||||
+ offs_k[:, None]
|
||||
)
|
||||
DV_ptrs = (
|
||||
DV
|
||||
+ (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
|
||||
+ offs_n[:, None] * stride_vn
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
L_ptrs = L + (cu_seqlens_q_start + offs_m)
|
||||
M_ptrs = M + (cu_seqlens_q_start + offs_m)
|
||||
|
||||
m_i = tl.load(M_ptrs, mask=offs_m < qo_len, other=float("-inf"))
|
||||
l_i = tl.load(L_ptrs, mask=offs_m < qo_len, other=1.0)
|
||||
|
||||
dq_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
|
||||
dk_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
|
||||
dv_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
|
||||
|
||||
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
|
||||
do = tl.load(DO_ptrs, mask=offs_m[:, None] < qo_len)
|
||||
q_scale = tl.load(Q_scale_ptr)
|
||||
|
||||
dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
|
||||
dq_acc,
|
||||
dk_acc,
|
||||
dv_acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_ptrs,
|
||||
V_ptrs,
|
||||
do,
|
||||
q_scale,
|
||||
K_scale_ptr,
|
||||
kv_len,
|
||||
stride_kn,
|
||||
stride_vn,
|
||||
start_m,
|
||||
H // num_kv_groups,
|
||||
BLOCK_M,
|
||||
HEAD_DIM,
|
||||
BLOCK_N,
|
||||
4 - STAGE,
|
||||
offs_m,
|
||||
offs_n,
|
||||
)
|
||||
|
||||
dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
|
||||
dq_acc,
|
||||
dk_acc,
|
||||
dv_acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_ptrs,
|
||||
V_ptrs,
|
||||
do,
|
||||
q_scale,
|
||||
K_scale_ptr,
|
||||
kv_len,
|
||||
stride_kn,
|
||||
stride_vn,
|
||||
start_m,
|
||||
H // num_kv_groups,
|
||||
BLOCK_M,
|
||||
HEAD_DIM,
|
||||
BLOCK_N,
|
||||
2,
|
||||
offs_m,
|
||||
offs_n,
|
||||
)
|
||||
|
||||
tl.store(DQ_ptrs, dq_acc.to(DQ.dtype.element_ty), mask=offs_m[:, None] < qo_len)
|
||||
tl.store(DK_ptrs, dk_acc.to(DK.dtype.element_ty), mask=offs_n[None, :] < kv_len)
|
||||
tl.store(DV_ptrs, dv_acc.to(DV.dtype.element_ty), mask=offs_n[:, None] < kv_len)
|
||||
|
||||
|
||||
def forward(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
q_scale,
|
||||
k_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
output_dtype=torch.float16,
|
||||
):
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64
|
||||
stage = 3
|
||||
|
||||
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
|
||||
|
||||
b = cu_seqlens_q.shape[0] - 1
|
||||
_, h_qo, head_dim = q.shape
|
||||
_, h_kv, _ = k.shape
|
||||
|
||||
HEAD_DIM_K = head_dim
|
||||
num_kv_groups = h_qo // h_kv
|
||||
|
||||
grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b)
|
||||
_attn_fwd[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
q_scale,
|
||||
k_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
o,
|
||||
q.stride(1),
|
||||
q.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(0),
|
||||
o.stride(1),
|
||||
o.stride(0),
|
||||
h_qo,
|
||||
num_kv_groups,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
HEAD_DIM=HEAD_DIM_K,
|
||||
STAGE=stage,
|
||||
num_warps=4 if head_dim == 64 else 8,
|
||||
num_stages=4,
|
||||
)
|
||||
return o
|
||||
|
||||
|
||||
def backward(
|
||||
do,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
q_scale,
|
||||
k_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
l,
|
||||
m,
|
||||
output_dtype=torch.float16,
|
||||
):
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64
|
||||
stage = 3
|
||||
|
||||
device = q.device
|
||||
dtype = q.dtype
|
||||
b = cu_seqlens_q.shape[0] - 1
|
||||
_, h_qo, head_dim = q.shape
|
||||
_, h_kv, _ = k.shape
|
||||
num_kv_groups = h_qo // h_kv
|
||||
|
||||
dq = torch.zeros_like(q, dtype=output_dtype)
|
||||
dk = torch.zeros_like(k, dtype=output_dtype)
|
||||
dv = torch.zeros_like(v, dtype=output_dtype)
|
||||
|
||||
grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b)
|
||||
_attn_bwd[grid](
|
||||
do,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
q_scale,
|
||||
k_scale,
|
||||
cu_seqlens_q_scale,
|
||||
cu_seqlens_k_scale,
|
||||
l,
|
||||
m,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
q.stride(1),
|
||||
q.stride(0),
|
||||
k.stride(1),
|
||||
k.stride(0),
|
||||
v.stride(1),
|
||||
v.stride(0),
|
||||
h_qo,
|
||||
num_kv_groups,
|
||||
HEAD_DIM=head_dim,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
STAGE=stage,
|
||||
num_warps=4 if head_dim == 64 else 8,
|
||||
num_stages=4,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
|
||||
# class TritonAttentionFunction(torch.autograd.Function):
|
||||
# @staticmethod
|
||||
# def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale):
|
||||
# l = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
|
||||
# m = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
|
||||
# output = forward(q, k, v, cu_seqlens_q, cu_seqlens_k, q.shape[0], q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
|
||||
# ctx.save_for_backward(q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
|
||||
# return output
|
||||
#
|
||||
# @staticmethod
|
||||
# def backward(ctx, do):
|
||||
# q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m = ctx.saved_tensors
|
||||
# dq, dk, dv = backward(
|
||||
# do, q, k, v,
|
||||
# cu_seqlens_q, cu_seqlens_k,
|
||||
# q.shape[0], q_scale, k_scale,
|
||||
# cu_seqlens_q_scale, cu_seqlens_k_scale,
|
||||
# l, m,
|
||||
# )
|
||||
# return dq, dk, dv, None, None, None, None, None, None
|
||||
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Copyright (c) 2024 by SageAttention team.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def quant_per_block_int8_kernel(
|
||||
Input,
|
||||
Output,
|
||||
Scale,
|
||||
cu_seqlens_input,
|
||||
cu_seqlens_scale,
|
||||
stride_ih,
|
||||
stride_in,
|
||||
stride_oh,
|
||||
stride_on,
|
||||
sm_scale,
|
||||
H: tl.constexpr,
|
||||
C: tl.constexpr,
|
||||
BLK: tl.constexpr,
|
||||
):
|
||||
off_blk = tl.program_id(0)
|
||||
off_h = tl.program_id(1)
|
||||
off_b = tl.program_id(2)
|
||||
|
||||
cu_seqlens_input_start = tl.load(cu_seqlens_input + off_b)
|
||||
cu_seqlens_input_end = tl.load(cu_seqlens_input + off_b + 1)
|
||||
|
||||
L = cu_seqlens_input_end - cu_seqlens_input_start
|
||||
|
||||
if (off_blk * BLK) >= L:
|
||||
return
|
||||
|
||||
cu_seqlens_scale_start = tl.load(cu_seqlens_scale + off_b)
|
||||
|
||||
offs_n = off_blk * BLK + tl.arange(0, BLK)
|
||||
offs_k = tl.arange(0, C)
|
||||
|
||||
input_ptrs = (
|
||||
Input
|
||||
+ cu_seqlens_input_start * stride_in
|
||||
+ off_h * stride_ih
|
||||
+ offs_n[:, None] * stride_in
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
output_ptrs = (
|
||||
Output
|
||||
+ cu_seqlens_input_start * stride_on
|
||||
+ off_h * stride_oh
|
||||
+ offs_n[:, None] * stride_on
|
||||
+ offs_k[None, :]
|
||||
)
|
||||
scale_ptrs = Scale + cu_seqlens_scale_start * H + off_h + off_blk * H
|
||||
|
||||
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
|
||||
x = x.to(tl.float32)
|
||||
x *= sm_scale
|
||||
scale = tl.max(tl.abs(x)) / 127.0
|
||||
x_int8 = x / scale
|
||||
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
|
||||
x_int8 = x_int8.to(tl.int8)
|
||||
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
|
||||
tl.store(scale_ptrs, scale)
|
||||
|
||||
|
||||
def per_block_int8(
|
||||
q,
|
||||
k,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
BLKQ=128,
|
||||
BLKK=64,
|
||||
sm_scale=None,
|
||||
):
|
||||
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
|
||||
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
|
||||
|
||||
h_qo = q.shape[1]
|
||||
h_kv = k.shape[1]
|
||||
head_dim = q.shape[-1]
|
||||
|
||||
b = cu_seqlens_q.shape[0] - 1
|
||||
q_batch_len = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
|
||||
k_batch_len = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
|
||||
|
||||
q_scale_len = (q_batch_len + BLKQ - 1) // BLKQ
|
||||
k_scale_len = (k_batch_len + BLKK - 1) // BLKK
|
||||
|
||||
cu_seqlens_q_scale = torch.nn.functional.pad(
|
||||
torch.cumsum(q_scale_len, dim=0), (1, 0), value=0
|
||||
)
|
||||
cu_seqlens_k_scale = torch.nn.functional.pad(
|
||||
torch.cumsum(k_scale_len, dim=0), (1, 0), value=0
|
||||
)
|
||||
|
||||
q_scale = torch.empty(
|
||||
(cu_seqlens_q_scale[-1], h_qo), device=q.device, dtype=torch.float32
|
||||
)
|
||||
k_scale = torch.empty(
|
||||
(cu_seqlens_k_scale[-1], h_kv), device=k.device, dtype=torch.float32
|
||||
)
|
||||
|
||||
if sm_scale is None:
|
||||
sm_scale = head_dim**-0.5
|
||||
|
||||
grid = ((max_seqlen_q + BLKQ - 1) // BLKQ, h_qo, b)
|
||||
quant_per_block_int8_kernel[grid](
|
||||
q,
|
||||
q_int8,
|
||||
q_scale,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_q_scale,
|
||||
q.stride(1),
|
||||
q.stride(0),
|
||||
q_int8.stride(1),
|
||||
q_int8.stride(0),
|
||||
sm_scale=(sm_scale * 1.44269504),
|
||||
H=h_qo,
|
||||
C=head_dim,
|
||||
BLK=BLKQ,
|
||||
)
|
||||
|
||||
grid = ((max_seqlen_k + BLKK - 1) // BLKK, h_kv, b)
|
||||
quant_per_block_int8_kernel[grid](
|
||||
k,
|
||||
k_int8,
|
||||
k_scale,
|
||||
cu_seqlens_k,
|
||||
cu_seqlens_k_scale,
|
||||
k.stride(1),
|
||||
k.stride(0),
|
||||
k_int8.stride(1),
|
||||
k_int8.stride(0),
|
||||
sm_scale=1.0,
|
||||
H=h_kv,
|
||||
C=head_dim,
|
||||
BLK=BLKK,
|
||||
)
|
||||
|
||||
return q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale
|
||||
0
src/axolotl/monkeypatch/__init__.py
Normal file
0
src/axolotl/monkeypatch/attention/__init__.py
Normal file
@@ -1,4 +1,5 @@
|
||||
"""multipack patching for v2 of sample packing"""
|
||||
|
||||
import importlib
|
||||
|
||||
import transformers
|
||||
@@ -27,71 +28,28 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type, model_name=None, is_remote_code=False):
|
||||
if model_type == "gemmoe":
|
||||
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
||||
elif model_type == "deepseek_v2":
|
||||
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
||||
elif hasattr(transformers, "modeling_flash_attention_utils") and not is_remote_code:
|
||||
def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
||||
if has_remote_code:
|
||||
patch_remote(model_name)
|
||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
return
|
||||
|
||||
# retain for legacy
|
||||
if model_type == "mixtral":
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
elif model_type == "llama":
|
||||
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
|
||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "mistral":
|
||||
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
|
||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "qwen2":
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "qwen2_moe":
|
||||
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "falcon":
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "phi":
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemma":
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemma2":
|
||||
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "starcoder2":
|
||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
|
||||
|
||||
def patch_remote(model_name, config_name, modeling_name):
|
||||
def patch_remote(model_name):
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_* to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
|
||||
parts = model_config.__class__.__module__.split(".")
|
||||
parts[-1] = parts[-1].replace("configuration_", "modeling_", 1)
|
||||
module_name = ".".join(parts)
|
||||
modeling_arch = importlib.import_module(module_name)
|
||||
modeling_arch._get_unpad_data = get_unpad_data # pylint: disable=protected-access
|
||||
if hasattr(modeling_arch, "_get_unpad_data"):
|
||||
modeling_arch._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
@@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
||||
for module in layer_modules
|
||||
)
|
||||
mlp_not_dora = all(
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
@@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||
for module in layer_modules
|
||||
)
|
||||
qkv_not_dora = all(
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
@@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||
for module in layer_modules
|
||||
)
|
||||
o_not_dora = all(
|
||||
getattr(module, "lora_magnitude_vector", None) is None
|
||||
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
|
||||
for module in layer_modules
|
||||
)
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ import functools
|
||||
import pynvml
|
||||
import torch
|
||||
from pynvml.nvml import NVMLError
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.distributed import get_device_type
|
||||
|
||||
|
||||
def check_cuda_device(default_value):
|
||||
@@ -53,6 +56,12 @@ def mps_memory_usage_all():
|
||||
return usage, reserved - usage, 0
|
||||
|
||||
|
||||
def npu_memory_usage_all(device=0):
|
||||
usage = torch.npu.memory_allocated(device) / 1024.0**3
|
||||
reserved = torch.npu.memory_reserved(device) / 1024.0**3
|
||||
return usage, reserved - usage, 0
|
||||
|
||||
|
||||
@check_cuda_device(0.0)
|
||||
def gpu_memory_usage_smi(device=0):
|
||||
if isinstance(device, torch.device):
|
||||
@@ -69,8 +78,11 @@ def gpu_memory_usage_smi(device=0):
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
cur_device = get_device_type()
|
||||
if torch.backends.mps.is_available():
|
||||
usage, cache, misc = mps_memory_usage_all()
|
||||
elif "npu" in str(cur_device) and is_torch_npu_available():
|
||||
usage, cache, misc = npu_memory_usage_all(device)
|
||||
else:
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
extras = []
|
||||
@@ -79,6 +91,7 @@ def log_gpu_memory_usage(log, msg, device):
|
||||
if misc > 0:
|
||||
extras.append(f"+{misc:.03f}GB misc")
|
||||
log.info(
|
||||
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
|
||||
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
|
||||
stacklevel=2,
|
||||
)
|
||||
return usage, cache, misc
|
||||
|
||||
@@ -64,10 +64,7 @@ class EvalFirstStepCallback(
|
||||
control: TrainerControl,
|
||||
**kwargs,
|
||||
):
|
||||
if (
|
||||
args.evaluation_strategy == IntervalStrategy.STEPS
|
||||
and state.global_step == 1
|
||||
):
|
||||
if args.eval_strategy == IntervalStrategy.STEPS and state.global_step == 1:
|
||||
control.should_evaluate = True
|
||||
return control
|
||||
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
"""Module for working with config dicts"""
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.integrations.config import merge_input_args
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.config.models.input.v0_4_1 import SUPPORTED_METRICS
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
)
|
||||
@@ -32,7 +30,10 @@ def choose_device(cfg):
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
|
||||
raise SystemError("No CUDA/mps device found")
|
||||
if is_torch_npu_available():
|
||||
return f"npu:{cfg.local_rank}"
|
||||
|
||||
raise SystemError("No CUDA/mps/npu device found")
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
return "cpu"
|
||||
|
||||
@@ -42,6 +43,8 @@ def choose_device(cfg):
|
||||
else:
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": torch.cuda.current_device()}
|
||||
elif cfg.device.startswith("npu"):
|
||||
cfg.device_map = {"npu": torch.npu.current_device()}
|
||||
else:
|
||||
cfg.device_map = {"": cfg.device}
|
||||
|
||||
@@ -247,370 +250,3 @@ def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
|
||||
return DictDefault(
|
||||
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_none=True))
|
||||
)
|
||||
|
||||
|
||||
def legacy_validate_config(cfg):
|
||||
"""
|
||||
This is a "pre-validation" step that handles the yaml configuration before we have any
|
||||
information about the model architecture
|
||||
"""
|
||||
if is_torch_bf16_gpu_available():
|
||||
if not cfg.bf16 and not cfg.bfloat16:
|
||||
LOG.info("bf16 support detected, but not enabled for this configuration.")
|
||||
else:
|
||||
if (
|
||||
not cfg.merge_lora
|
||||
and not cfg.is_preprocess
|
||||
and (cfg.bf16 is True or cfg.bfloat16 is True)
|
||||
):
|
||||
raise ValueError(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||
)
|
||||
if (
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
not (cfg.bf16 or cfg.bfloat16)
|
||||
and (cfg.fp16 or cfg.float16)
|
||||
and not cfg.adapter
|
||||
and not cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
LOG.warning(
|
||||
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
||||
)
|
||||
# ValueError: Attempting to unscale FP16 gradients.
|
||||
# OR
|
||||
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
||||
if cfg.max_packed_sequence_len:
|
||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||
|
||||
if cfg.sample_packing and cfg.rl:
|
||||
raise ValueError("`sample_packing: true` does not work with RLHF training")
|
||||
|
||||
if cfg.sample_packing and not cfg.pad_to_sequence_len:
|
||||
LOG.warning(
|
||||
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||
)
|
||||
|
||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||
raise ValueError(
|
||||
"please set only one of gradient_accumulation_steps or batch_size"
|
||||
)
|
||||
if cfg.batch_size:
|
||||
LOG.warning(
|
||||
"%s\n%s",
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if (
|
||||
cfg.eval_batch_size
|
||||
and cfg.micro_batch_size
|
||||
and cfg.eval_batch_size != cfg.micro_batch_size
|
||||
):
|
||||
LOG.warning(
|
||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||
)
|
||||
|
||||
if cfg.adapter == "qlora":
|
||||
if cfg.merge_lora:
|
||||
# can't merge qlora if loaded in 8bit or 4bit
|
||||
if cfg.load_in_8bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||
|
||||
if cfg.gptq:
|
||||
raise ValueError("Can't merge qlora if gptq")
|
||||
|
||||
if cfg.load_in_4bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||
|
||||
else:
|
||||
if cfg.load_in_8bit:
|
||||
raise ValueError("Can't load qlora in 8bit")
|
||||
|
||||
if cfg.gptq:
|
||||
raise ValueError("Can't load qlora if gptq")
|
||||
|
||||
if not cfg.load_in_4bit:
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with QLoRA")
|
||||
|
||||
loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||
raise ValueError("Fused modules are not supported with LoRA")
|
||||
|
||||
if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
|
||||
raise ValueError(
|
||||
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
|
||||
)
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
|
||||
if cfg.fsdp:
|
||||
raise ValueError("fsdp not supported with ReLoRA")
|
||||
|
||||
if cfg.deepspeed:
|
||||
raise ValueError("deepspeed not supported with ReLoRA")
|
||||
|
||||
if cfg.lr_scheduler == "one_cycle":
|
||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
)
|
||||
|
||||
if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
|
||||
raise ValueError(
|
||||
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
|
||||
)
|
||||
|
||||
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
||||
raise ValueError("FSDP is not supported for falcon models")
|
||||
|
||||
if (
|
||||
cfg.base_model and "mpt" in cfg.base_model.lower()
|
||||
) and cfg.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing is not supported for MPT models")
|
||||
|
||||
if cfg.flash_optimum is True:
|
||||
if cfg.adapter:
|
||||
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bfloat16 is not True:
|
||||
LOG.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
)
|
||||
if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
|
||||
LOG.warning("torch>=2.0.0 required")
|
||||
raise ValueError(
|
||||
f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
|
||||
)
|
||||
|
||||
if cfg.pretraining_dataset and cfg.group_by_length:
|
||||
LOG.warning(
|
||||
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
|
||||
)
|
||||
if cfg.pretraining_dataset and not cfg.max_steps:
|
||||
raise ValueError(
|
||||
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
|
||||
)
|
||||
|
||||
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
|
||||
not cfg.optimizer or "adamw" not in cfg.optimizer
|
||||
):
|
||||
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
|
||||
|
||||
if cfg.push_to_hub_model_id:
|
||||
raise ValueError(
|
||||
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
|
||||
)
|
||||
|
||||
if cfg.hub_model_id and cfg.save_strategy not in ["steps", "epoch", None]:
|
||||
LOG.warning(
|
||||
"hub_model_id is set without any models being saved. To save a model, set save_strategy to steps, epochs or leave empty."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
raise ValueError(
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
# raise ValueError(
|
||||
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
# )
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||
LOG.warning(
|
||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||
"This may work on H100s."
|
||||
)
|
||||
|
||||
if cfg.early_stopping_patience:
|
||||
if not cfg.save_steps or not cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
||||
)
|
||||
if cfg.save_steps % cfg.eval_steps != 0:
|
||||
raise ValueError(
|
||||
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||
)
|
||||
|
||||
if cfg.saves_per_epoch and cfg.save_steps:
|
||||
raise ValueError(
|
||||
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if cfg.save_strategy and cfg.saves_per_epoch and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
)
|
||||
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||
raise ValueError(
|
||||
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||
)
|
||||
if (
|
||||
cfg.evals_per_epoch
|
||||
and cfg.evaluation_strategy
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
if (
|
||||
cfg.evaluation_strategy
|
||||
and cfg.eval_steps
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.val_set_size == 0
|
||||
and (cfg.eval_steps or cfg.evaluation_strategy)
|
||||
and not cfg.test_datasets
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.sample_packing
|
||||
and cfg.eval_table_size
|
||||
and cfg.eval_sample_packing is not False
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||
)
|
||||
|
||||
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||
|
||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||
cfg.wandb_name = cfg.wandb_run_id
|
||||
|
||||
LOG.warning(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
)
|
||||
|
||||
if cfg.noisy_embedding_alpha is not None:
|
||||
# Deprecated, use neftune_noise_alpha
|
||||
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||
if cfg.neftune_noise_alpha is None:
|
||||
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||
else:
|
||||
# User is providing both; bail and have them sort out their settings
|
||||
raise ValueError(
|
||||
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||
)
|
||||
|
||||
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||
|
||||
if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
|
||||
raise ValueError(
|
||||
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.unfrozen_parameters
|
||||
and cfg.gradient_checkpointing_kwargs
|
||||
and cfg.gradient_checkpointing_kwargs.use_reentrant is True
|
||||
):
|
||||
# https://github.com/huggingface/transformers/issues/21381
|
||||
raise ValueError(
|
||||
"`use_reentrant` must be false when used with partially frozen model."
|
||||
)
|
||||
|
||||
if cfg.deepspeed and Path(cfg.deepspeed).is_file():
|
||||
with open(cfg.deepspeed, encoding="utf-8") as file:
|
||||
contents = file.read()
|
||||
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
|
||||
if cfg.flash_attention:
|
||||
if (
|
||||
deepspeed_cfg.zero_optimization
|
||||
and deepspeed_cfg.zero_optimization.stage == 3
|
||||
):
|
||||
if not (
|
||||
(
|
||||
deepspeed_cfg.bf16
|
||||
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
|
||||
is True
|
||||
)
|
||||
or (
|
||||
deepspeed_cfg.fp16
|
||||
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
|
||||
is True
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
|
||||
)
|
||||
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
|
||||
LOG.warning(
|
||||
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
|
||||
)
|
||||
|
||||
if cfg.test_datasets and cfg.val_set_size:
|
||||
raise ValueError(
|
||||
"non-zero val_set_size should not be used with test_datasets configuration"
|
||||
)
|
||||
|
||||
if cfg.fsdp and "bnb" in cfg.optimizer:
|
||||
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
||||
|
||||
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
|
||||
raise ValueError(
|
||||
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
||||
)
|
||||
|
||||
if cfg.eval_causal_lm_metrics:
|
||||
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
||||
raise ValueError("eval_causal_lm_metrics must be a list")
|
||||
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
||||
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
|
||||
raise ValueError(
|
||||
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
# no 8bit adaAmw w bf16
|
||||
|
||||
# GPT-NeoX
|
||||
# evals broken when extending context len
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
||||
# File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
|
||||
# attention_mask = causal_mask + attention_mask
|
||||
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
|
||||
|
||||
@@ -7,7 +7,6 @@ Module for pydantic models for configuration
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from importlib.metadata import version
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import (
|
||||
@@ -20,6 +19,7 @@ from pydantic import (
|
||||
)
|
||||
from transformers import SchedulerType
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils.import_utils import is_torch_npu_available
|
||||
|
||||
from axolotl.utils.config.models.internals import GPUCapabilities
|
||||
|
||||
@@ -68,6 +68,7 @@ class DeprecatedParameters(BaseModel):
|
||||
rope_scaling: Optional[Any] = None
|
||||
noisy_embedding_alpha: Optional[float] = None
|
||||
dpo_beta: Optional[float] = None
|
||||
evaluation_strategy: Optional[str] = None
|
||||
|
||||
@field_validator("max_packed_sequence_len")
|
||||
@classmethod
|
||||
@@ -99,6 +100,13 @@ class DeprecatedParameters(BaseModel):
|
||||
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
||||
return dpo_beta
|
||||
|
||||
@field_validator("evaluation_strategy")
|
||||
@classmethod
|
||||
def validate_evaluation_strategy(cls, evaluation_strategy):
|
||||
if evaluation_strategy is not None:
|
||||
LOG.warning("evaluation_strategy is deprecated, use eval_strategy instead")
|
||||
return evaluation_strategy
|
||||
|
||||
|
||||
class RemappedParameters(BaseModel):
|
||||
"""parameters that have been remapped to other names"""
|
||||
@@ -242,8 +250,10 @@ class KTODataset(BaseModel):
|
||||
class LoftQConfig(BaseModel):
|
||||
"""LoftQ configuration subset"""
|
||||
|
||||
loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
|
||||
# loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
|
||||
loftq_bits: int = Field(
|
||||
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
|
||||
)
|
||||
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
|
||||
|
||||
|
||||
class PeftConfig(BaseModel):
|
||||
@@ -286,8 +296,8 @@ class LoraConfig(BaseModel):
|
||||
|
||||
qlora_sharded_model_loading: Optional[bool] = Field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||
json_schema_extra={
|
||||
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
|
||||
},
|
||||
)
|
||||
lora_on_cpu: Optional[bool] = None
|
||||
@@ -296,13 +306,15 @@ class LoraConfig(BaseModel):
|
||||
|
||||
loraplus_lr_ratio: Optional[float] = Field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
|
||||
json_schema_extra={
|
||||
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
|
||||
},
|
||||
)
|
||||
loraplus_lr_embedding: Optional[float] = Field(
|
||||
default=1e-6,
|
||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||
json_schema_extra={
|
||||
"description": "loraplus learning rate for lora embedding layers."
|
||||
},
|
||||
)
|
||||
|
||||
merge_lora: Optional[bool] = None
|
||||
@@ -372,10 +384,10 @@ class ModelInputConfig(BaseModel):
|
||||
tokenizer_use_fast: Optional[bool] = None
|
||||
tokenizer_legacy: Optional[bool] = None
|
||||
tokenizer_type: Optional[str] = Field(
|
||||
default=None, metadata={"help": "transformers tokenizer class"}
|
||||
default=None, json_schema_extra={"description": "transformers tokenizer class"}
|
||||
)
|
||||
processor_type: Optional[str] = Field(
|
||||
default=None, metadata={"help": "transformers processor class"}
|
||||
default=None, json_schema_extra={"description": "transformers processor class"}
|
||||
)
|
||||
trust_remote_code: Optional[bool] = None
|
||||
|
||||
@@ -397,18 +409,18 @@ class HyperparametersConfig(BaseModel):
|
||||
gradient_accumulation_steps: Optional[int] = Field(default=1)
|
||||
micro_batch_size: Optional[int] = Field(
|
||||
default=1,
|
||||
metadata={"help": "per gpu micro batch size for training"},
|
||||
json_schema_extra={"description": "per gpu micro batch size for training"},
|
||||
)
|
||||
batch_size: Optional[int] = Field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Total batch size, we do not recommended setting this manually"
|
||||
json_schema_extra={
|
||||
"description": "Total batch size, we do not recommended setting this manually"
|
||||
},
|
||||
)
|
||||
eval_batch_size: Optional[int] = Field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
||||
json_schema_extra={
|
||||
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
|
||||
},
|
||||
)
|
||||
|
||||
@@ -433,12 +445,13 @@ class HyperparametersConfig(BaseModel):
|
||||
]
|
||||
] = OptimizerNames.ADAMW_HF.value
|
||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
||||
default=None,
|
||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||
)
|
||||
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||
json_schema_extra={
|
||||
"description": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||
},
|
||||
)
|
||||
torchdistx_path: Optional[str] = None
|
||||
@@ -498,15 +511,15 @@ class LISAConfig(BaseModel):
|
||||
|
||||
lisa_n_layers: Optional[int] = Field(
|
||||
default=None,
|
||||
metadata={"help": "the number of activate layers in LISA"},
|
||||
json_schema_extra={"description": "the number of activate layers in LISA"},
|
||||
)
|
||||
lisa_step_interval: Optional[int] = Field(
|
||||
default=None,
|
||||
metadata={"help": "how often to switch layers in LISA"},
|
||||
json_schema_extra={"description": "how often to switch layers in LISA"},
|
||||
)
|
||||
lisa_layers_attribute: Optional[str] = Field(
|
||||
default="model.layers",
|
||||
metadata={"help": "path under the model to access the layers"},
|
||||
json_schema_extra={"description": "path under the model to access the layers"},
|
||||
)
|
||||
|
||||
|
||||
@@ -605,7 +618,8 @@ class AxolotlInputConfig(
|
||||
pretraining_dataset: Optional[ # type: ignore
|
||||
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
|
||||
] = Field(
|
||||
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
|
||||
default=None,
|
||||
json_schema_extra={"description": "streaming dataset to use for pretraining"},
|
||||
)
|
||||
dataset_processes: Optional[int] = Field(default=os.cpu_count())
|
||||
dataset_keep_in_memory: Optional[bool] = None
|
||||
@@ -665,7 +679,8 @@ class AxolotlInputConfig(
|
||||
sequence_len: int = Field(default=512)
|
||||
min_sample_len: Optional[int] = None
|
||||
max_prompt_len: int = Field(
|
||||
default=512, metadata={"help": "maximum prompt length for RL training"}
|
||||
default=512,
|
||||
json_schema_extra={"description": "maximum prompt length for RL training"},
|
||||
)
|
||||
sample_packing: Optional[bool] = None
|
||||
sample_packing_group_size: Optional[int] = 100_000
|
||||
@@ -684,8 +699,8 @@ class AxolotlInputConfig(
|
||||
pretrain_multipack_buffer_size: Optional[int] = 10_000
|
||||
pretrain_multipack_attn: Optional[bool] = Field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "whether to prevent cross attention for packed sequences during pretraining",
|
||||
json_schema_extra={
|
||||
"description": "whether to prevent cross attention for packed sequences during pretraining",
|
||||
},
|
||||
)
|
||||
|
||||
@@ -731,7 +746,7 @@ class AxolotlInputConfig(
|
||||
warmup_ratio: Optional[float] = None
|
||||
eval_steps: Optional[Union[int, float]] = None
|
||||
evals_per_epoch: Optional[Union[int]] = None
|
||||
evaluation_strategy: Optional[str] = None
|
||||
eval_strategy: Optional[str] = None
|
||||
save_steps: Optional[Union[int, float]] = None
|
||||
saves_per_epoch: Optional[int] = None
|
||||
save_strategy: Optional[str] = None
|
||||
@@ -1033,21 +1048,21 @@ class AxolotlInputConfig(
|
||||
@classmethod
|
||||
def check_evals(cls, data):
|
||||
if (
|
||||
data.get("evaluation_strategy")
|
||||
data.get("eval_strategy")
|
||||
and data.get("eval_steps")
|
||||
and data.get("evaluation_strategy") != "steps"
|
||||
and data.get("eval_strategy") != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
"eval_strategy and eval_steps mismatch. Please set eval_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
data.get("val_set_size") == 0
|
||||
and (data.get("eval_steps") or data.get("evaluation_strategy"))
|
||||
and (data.get("eval_steps") or data.get("eval_strategy"))
|
||||
and not data.get("test_datasets")
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
"eval_steps and eval_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
if data.get("evals_per_epoch") and data.get("eval_steps"):
|
||||
raise ValueError(
|
||||
@@ -1055,11 +1070,11 @@ class AxolotlInputConfig(
|
||||
)
|
||||
if (
|
||||
data.get("evals_per_epoch")
|
||||
and data.get("evaluation_strategy")
|
||||
and data.get("evaluation_strategy") != "steps"
|
||||
and data.get("eval_strategy")
|
||||
and data.get("eval_strategy") != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
"eval_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||
)
|
||||
|
||||
if data.get("do_bench_eval") and not (
|
||||
@@ -1291,6 +1306,26 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def warn_qlora_zero3_w_use_reentrant(cls, data):
|
||||
if (
|
||||
data.get("adapter") == "qlora"
|
||||
and data.get("gradient_checkpointing_kwargs", {})
|
||||
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
|
||||
is False
|
||||
and data.get("deepspeed", "") is not None
|
||||
and "zero3" in data.get("deepspeed", "")
|
||||
):
|
||||
# may result in:
|
||||
# torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint:
|
||||
# Recomputed values for the following tensors have different metadata
|
||||
# than during the forward pass.
|
||||
LOG.warning(
|
||||
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_val_w_test_datasets(cls, data):
|
||||
@@ -1300,6 +1335,19 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_eval_strategy(cls, data):
|
||||
if (
|
||||
data.get("evaluation_strategy") is not None
|
||||
and data.get("eval_strategy") is None
|
||||
):
|
||||
LOG.info(
|
||||
"explicitly setting `eval_strategy` from the `evaluation_strategy`"
|
||||
)
|
||||
data["eval_strategy"] = data.get("evaluation_strategy")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_offload_w_8bit_optimizer(cls, data):
|
||||
@@ -1378,21 +1426,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_unsloth_xformers_version(cls, data):
|
||||
if (
|
||||
data.get("unsloth_lora_mlp")
|
||||
or data.get("unsloth_lora_qkv")
|
||||
or data.get("unsloth_lora_o")
|
||||
):
|
||||
xformers_version = version("xformers")
|
||||
if xformers_version == "0.0.27":
|
||||
raise ValueError(
|
||||
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_torch_compile_deepspeed(cls, data):
|
||||
@@ -1402,6 +1435,40 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_npu_config(cls, data):
|
||||
if is_torch_npu_available():
|
||||
# check attention config
|
||||
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
|
||||
for attn in attn_list:
|
||||
if data.get(attn):
|
||||
raise NotImplementedError(
|
||||
f"{attn} is currently not supported in Ascend npu, please disable this configuration."
|
||||
)
|
||||
|
||||
# check quant config
|
||||
if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
|
||||
optimizer = data.get("optimizer")
|
||||
raise NotImplementedError(
|
||||
f"{optimizer} is currently not supported in Ascend npu, choose another one please."
|
||||
)
|
||||
|
||||
quant_list = ["load_in_8bit", "load_in_4bit"]
|
||||
for quant in quant_list:
|
||||
if data.get(quant):
|
||||
raise NotImplementedError(
|
||||
f"Quantification is currently not supported in Ascend npu, please disable {quant}."
|
||||
)
|
||||
|
||||
# check dtype config
|
||||
if data.get("tf32"):
|
||||
raise NotImplementedError(
|
||||
"tf32 dtype is currently not supported in Ascend npu, please disable this configuration"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
||||
"""wrapper to valdiate gpu capabilities with the configured options"""
|
||||
|
||||
@@ -64,15 +64,57 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
||||
|
||||
if isinstance(data_set, DatasetDict):
|
||||
data_set = data_set["train"]
|
||||
|
||||
data_set = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
)
|
||||
if isinstance(data_set, DatasetDict):
|
||||
data_set = data_set["train"]
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
def drop_long_rl_seq(
|
||||
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
|
||||
):
|
||||
if rl in ("dpo", "ipo", "orpo", "simpo"):
|
||||
if not (
|
||||
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||
):
|
||||
raise ValueError(
|
||||
"Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
|
||||
)
|
||||
|
||||
prompt = sample["prompt"]
|
||||
chosen = sample["chosen"]
|
||||
rejected = sample["rejected"]
|
||||
|
||||
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
|
||||
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||
|
||||
return (len_prompt + len_chosen) <= sequence_len and (
|
||||
len_prompt + len_rejected
|
||||
) <= sequence_len
|
||||
|
||||
if rl == "kto":
|
||||
if not (sample.get("prompt") and sample.get("completion")):
|
||||
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
||||
|
||||
prompt = sample["prompt"]
|
||||
completion = sample["completion"]
|
||||
|
||||
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
|
||||
len_completion = len(
|
||||
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
||||
)
|
||||
|
||||
return (len_prompt + len_completion) <= sequence_len
|
||||
|
||||
raise ValueError("Unknown RL type")
|
||||
|
||||
|
||||
def load_prepare_dpo_datasets(cfg):
|
||||
def load_split(dataset_cfgs, _cfg):
|
||||
split_datasets: List[Any] = []
|
||||
@@ -94,7 +136,7 @@ def load_prepare_dpo_datasets(cfg):
|
||||
)
|
||||
split_datasets.insert(i, ds)
|
||||
|
||||
tokenizer = None
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
for i, data_set in enumerate(split_datasets):
|
||||
_type = dataset_cfgs[i]["type"]
|
||||
@@ -121,7 +163,28 @@ def load_prepare_dpo_datasets(cfg):
|
||||
# "prompt", "chosen" and "rejected" already preprocessed
|
||||
split_datasets[i] = data_set
|
||||
|
||||
return concatenate_datasets(split_datasets)
|
||||
drop_long = partial(
|
||||
drop_long_rl_seq,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
)
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
||||
|
||||
combined_datasets = concatenate_datasets(split_datasets)
|
||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||
|
||||
return combined_datasets
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
train_is_preprocessed = False
|
||||
|
||||
@@ -260,6 +260,7 @@ def load_tokenized_prepared_datasets(
|
||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||
ds: Optional[Union[Dataset, DatasetDict]] = None
|
||||
ds_from_hub = False
|
||||
ds_trust_remote_code = config_dataset.trust_remote_code
|
||||
try:
|
||||
# this is just a basic check to see if the path is a
|
||||
# valid HF dataset that's loadable
|
||||
@@ -269,6 +270,7 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=True,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=ds_trust_remote_code,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
|
||||
@@ -348,7 +350,15 @@ def load_tokenized_prepared_datasets(
|
||||
split=None,
|
||||
)
|
||||
else:
|
||||
ds = load_from_disk(config_dataset.path)
|
||||
try:
|
||||
ds = load_from_disk(config_dataset.path)
|
||||
except FileNotFoundError:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
@@ -366,7 +376,7 @@ def load_tokenized_prepared_datasets(
|
||||
elif ds_from_hub:
|
||||
load_ds_kwargs = {}
|
||||
if config_dataset.split:
|
||||
load_ds_kwargs = {"split": config_dataset.split}
|
||||
load_ds_kwargs["split"] = config_dataset.split
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
@@ -374,6 +384,7 @@ def load_tokenized_prepared_datasets(
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
revision=config_dataset.revision,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
@@ -391,6 +402,7 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
@@ -401,6 +413,7 @@ def load_tokenized_prepared_datasets(
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=config_dataset.trust_remote_code,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
|
||||
@@ -9,10 +9,44 @@ from datetime import timedelta
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from accelerate import PartialState
|
||||
from transformers.utils.import_utils import (
|
||||
is_torch_cuda_available,
|
||||
is_torch_mps_available,
|
||||
is_torch_npu_available,
|
||||
)
|
||||
|
||||
distributed_state = None # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def get_device_type():
|
||||
device = torch.device("cpu")
|
||||
if is_torch_cuda_available():
|
||||
device = torch.device("cuda")
|
||||
elif is_torch_mps_available():
|
||||
device = torch.device("mps")
|
||||
elif is_torch_npu_available():
|
||||
device = torch.device("npu")
|
||||
return device
|
||||
|
||||
|
||||
def get_device_count():
|
||||
cur_device = get_device_type()
|
||||
if "cuda" in str(cur_device):
|
||||
return torch.cuda.device_count()
|
||||
if "npu" in str(cur_device):
|
||||
return torch.npu.device_count()
|
||||
return 1
|
||||
|
||||
|
||||
def get_current_device():
|
||||
cur_device = get_device_type()
|
||||
if "cuda" in str(cur_device):
|
||||
return torch.cuda.current_device()
|
||||
if "npu" in str(cur_device):
|
||||
return torch.npu.current_device()
|
||||
return 0
|
||||
|
||||
|
||||
def is_distributed():
|
||||
"""
|
||||
Check if distributed training is initialized.
|
||||
@@ -91,7 +125,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
||||
if not is_distributed():
|
||||
return [value_scalar]
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=torch.cuda.current_device()
|
||||
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
|
||||
).float()
|
||||
|
||||
if not is_main_process():
|
||||
@@ -115,13 +149,14 @@ def broadcast_dict(vals: dict):
|
||||
if not is_distributed():
|
||||
return vals
|
||||
|
||||
cur_device = get_device_type()
|
||||
if is_main_process():
|
||||
data_byte = pickle.dumps(vals)
|
||||
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
|
||||
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
|
||||
data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)
|
||||
data_size = torch.IntTensor([len(data_byte)]).to(cur_device)
|
||||
else:
|
||||
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
|
||||
data_size = torch.IntTensor([0]).to("cuda")
|
||||
data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)
|
||||
data_size = torch.IntTensor([0]).to(cur_device)
|
||||
|
||||
dist.broadcast(data_size, 0)
|
||||
if not is_main_process():
|
||||
@@ -150,14 +185,15 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
|
||||
Returns:
|
||||
- The computed value (int or float).
|
||||
"""
|
||||
cur_device = f"{get_device_type()}:{get_current_device()}"
|
||||
if is_main_process():
|
||||
value_scalar = fn()
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
|
||||
value_scalar, device=cur_device, dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
value_tensor = torch.tensor(
|
||||
0.0, device=torch.cuda.current_device(), dtype=torch.float32
|
||||
0.0, device=cur_device, dtype=torch.float32
|
||||
) # Placeholder tensor
|
||||
|
||||
# Broadcast the tensor to all processes.
|
||||
@@ -184,7 +220,7 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
|
||||
"""
|
||||
value_scalar = fn()
|
||||
value_tensor = torch.tensor(
|
||||
value_scalar, device=torch.cuda.current_device()
|
||||
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
|
||||
).float()
|
||||
|
||||
# Placeholder tensor for gathering results
|
||||
|
||||
@@ -46,6 +46,7 @@ from transformers.integrations.deepspeed import (
|
||||
)
|
||||
|
||||
from axolotl.common.architectures import MOE_ARCH_BLOCK
|
||||
from axolotl.integrations.sageattention.lib.core import monkeypatch_sdp_w_sage_attention
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
@@ -55,7 +56,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import zero_only
|
||||
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
|
||||
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
|
||||
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
|
||||
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
|
||||
@@ -238,6 +239,7 @@ def load_tokenizer(cfg):
|
||||
x in cfg.lora_modules_to_save for x in lora_modules_to_save
|
||||
)
|
||||
)
|
||||
and k != "pad_token"
|
||||
):
|
||||
lora_modules_to_save = ", ".join(
|
||||
[f"`{x}`" for x in lora_modules_to_save]
|
||||
@@ -394,10 +396,17 @@ class ModelLoader:
|
||||
and self.cfg.flash_attention
|
||||
and self.cfg.sample_packing
|
||||
):
|
||||
has_remote_code = (
|
||||
"auto_map" in self.model_config
|
||||
and "AutoModelForCausalLM" in self.model_config["auto_map"]
|
||||
)
|
||||
if has_remote_code and self.cfg.trust_remote_code is False:
|
||||
# if explicitly set in the YAML, we should prefer that, for example if explicitly disabled
|
||||
has_remote_code = self.cfg.trust_remote_code
|
||||
patch_for_multipack(
|
||||
self.cfg.model_config_type,
|
||||
model_name=self.cfg.base_model,
|
||||
is_remote_code=self.cfg.trust_remote_code,
|
||||
has_remote_code=has_remote_code,
|
||||
)
|
||||
|
||||
if self.cfg.is_llama_derived_model:
|
||||
@@ -562,7 +571,8 @@ class ModelLoader:
|
||||
)
|
||||
|
||||
max_memory = {}
|
||||
for i in range(torch.cuda.device_count()):
|
||||
num_device = get_device_count()
|
||||
for i in range(num_device):
|
||||
max_memory[i] = gpu_memory_limit
|
||||
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
|
||||
|
||||
@@ -587,8 +597,11 @@ class ModelLoader:
|
||||
self.model_kwargs["device_map"] = device_map
|
||||
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
cur_device = get_device_type()
|
||||
if "mps" in str(cur_device):
|
||||
self.model_kwargs["device_map"] = "mps:0"
|
||||
elif "npu" in str(cur_device):
|
||||
self.model_kwargs["device_map"] = "npu:0"
|
||||
|
||||
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
|
||||
# if cfg.rl:
|
||||
@@ -695,6 +708,7 @@ class ModelLoader:
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"sdpa"
|
||||
)
|
||||
monkeypatch_sdp_w_sage_attention()
|
||||
elif self.cfg.eager_attention:
|
||||
self.model_kwargs["attn_implementation"] = "eager"
|
||||
self.model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
@@ -1042,7 +1056,11 @@ class ModelLoader:
|
||||
self.ajust_model_config()
|
||||
|
||||
# log device memory usage
|
||||
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
|
||||
if hasattr(self.model, "device") and self.model.device.type in (
|
||||
"cuda",
|
||||
"mps",
|
||||
"npu",
|
||||
):
|
||||
log_gpu_memory_usage(LOG, "after model load", self.model.device)
|
||||
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
@@ -1110,9 +1128,9 @@ class ModelLoader:
|
||||
and not skip_move_to_device
|
||||
):
|
||||
# TODO revaldate this conditional
|
||||
self.model.to(f"cuda:{self.cfg.local_rank}")
|
||||
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
|
||||
|
||||
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
setattr(self.model, "is_parallelizable", True)
|
||||
setattr(self.model, "model_parallel", True)
|
||||
|
||||
|
||||
@@ -66,28 +66,47 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
|
||||
|
||||
|
||||
def check_rl_example_labels(example, tokenizer, text_only=False):
|
||||
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
|
||||
field_prompt, field_chosen, field_rejected, field_completion = (
|
||||
"prompt",
|
||||
"chosen",
|
||||
"rejected",
|
||||
"completion",
|
||||
)
|
||||
|
||||
input_tokens = example[field_prompt]
|
||||
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
|
||||
|
||||
labels_chosen = example.get(field_chosen)
|
||||
labels_rejected = example.get(field_rejected)
|
||||
labels_completion = example.get(field_completion)
|
||||
|
||||
# Create a delimiter based on text_only flag
|
||||
delimiter = "" if text_only else " "
|
||||
|
||||
# Process and color each type of token
|
||||
colored_tokens = process_tokens_for_rl_debug(
|
||||
input_tokens, "yellow", tokenizer, text_only
|
||||
)
|
||||
colored_chosens = process_tokens_for_rl_debug(
|
||||
labels_chosen, "green", tokenizer, text_only
|
||||
)
|
||||
colored_rejecteds = process_tokens_for_rl_debug(
|
||||
labels_rejected, "red", tokenizer, text_only
|
||||
)
|
||||
|
||||
# Create a delimiter based on text_only flag
|
||||
delimiter = "" if text_only else " "
|
||||
# Process tokens
|
||||
if labels_completion is None:
|
||||
colored_chosens = process_tokens_for_rl_debug(
|
||||
labels_chosen, "green", tokenizer, text_only
|
||||
)
|
||||
colored_rejecteds = process_tokens_for_rl_debug(
|
||||
labels_rejected, "red", tokenizer, text_only
|
||||
)
|
||||
else:
|
||||
colored_completion = process_tokens_for_rl_debug(
|
||||
labels_completion, "green", tokenizer, text_only
|
||||
)
|
||||
|
||||
# Logging information
|
||||
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
|
||||
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
|
||||
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
|
||||
|
||||
if labels_completion is None:
|
||||
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
|
||||
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
|
||||
else:
|
||||
LOG.info(f"COMPLETION RESPONSE: {delimiter.join(colored_completion)}\n\n\n")
|
||||
|
||||
return delimiter.join(colored_tokens)
|
||||
|
||||
@@ -203,37 +203,59 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
||||
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
||||
|
||||
prior_len = len(train_dataset)
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(train_dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from train dataset")
|
||||
|
||||
if eval_dataset:
|
||||
prior_len = len(eval_dataset)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
drop_long,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(eval_dataset)
|
||||
if dropped:
|
||||
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
|
||||
|
||||
# drop samples with where the number of elements with labels not equal to -100 is zero
|
||||
def drop_no_trainable_tokens(sample):
|
||||
return np.sum(np.array(sample["labels"]) != -100) > 0
|
||||
|
||||
prior_len = len(train_dataset)
|
||||
train_dataset = train_dataset.filter(
|
||||
drop_no_trainable_tokens,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Drop Samples with Zero Trainable Tokens",
|
||||
)
|
||||
dropped = prior_len - len(train_dataset)
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} samples with no trainable tokens from train dataset"
|
||||
)
|
||||
|
||||
if eval_dataset:
|
||||
prior_len = len(eval_dataset)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
drop_no_trainable_tokens,
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Drop Samples with Zero Trainable Tokens",
|
||||
)
|
||||
dropped = prior_len - len(eval_dataset)
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
|
||||
)
|
||||
|
||||
if cfg.group_by_length:
|
||||
train_dataset = train_dataset.map(
|
||||
@@ -493,7 +515,7 @@ def prepare_opinionated_env(cfg):
|
||||
def setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
|
||||
):
|
||||
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
|
||||
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
|
||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
|
||||
trainer_builder.model_ref = model[1]
|
||||
trainer_builder.peft_config = model[2]
|
||||
|
||||
35
tests/e2e/conftest.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
shared pytest fixtures
|
||||
"""
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_smollm2_135m_model():
|
||||
# download the model
|
||||
snapshot_download("HuggingFaceTB/SmolLM2-135M")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_tatsu_lab_alpaca_dataset():
|
||||
# download the model
|
||||
snapshot_download("tatsu-lab/alpaca", repo_type="dataset")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_mhenrichsen_alpaca_2k_dataset():
|
||||
# download the model
|
||||
snapshot_download("mhenrichsen/alpaca_2k_test", repo_type="dataset")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
# Create a temporary directory
|
||||
_temp_dir = tempfile.mkdtemp()
|
||||
yield _temp_dir
|
||||
# Clean up the directory after the test
|
||||
shutil.rmtree(_temp_dir)
|
||||
@@ -3,28 +3,25 @@ E2E tests for multigpu eval
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
||||
|
||||
class TestMultiGPUEval(unittest.TestCase):
|
||||
class TestMultiGPUEval:
|
||||
"""
|
||||
Test case for MultiGPU Eval Sample Packing
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_eval_sample_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -83,13 +80,14 @@ class TestMultiGPUEval(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_eval(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -148,6 +146,8 @@ class TestMultiGPUEval(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
|
||||
@@ -4,17 +4,17 @@ E2E tests for multigpu lora tinyllama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import is_hopper, with_temp_dir
|
||||
from ..utils import is_hopper
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -25,21 +25,19 @@ AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def download_model():
|
||||
# download the model
|
||||
snapshot_download("TinyLlama/TinyLlama_v1.1")
|
||||
snapshot_download("HuggingFaceTB/SmolLM2-135M")
|
||||
|
||||
|
||||
class TestMultiGPULlama(unittest.TestCase):
|
||||
class TestMultiGPULlama:
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 2048,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
@@ -48,9 +46,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -81,19 +77,23 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_ddp_packed(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 4],
|
||||
)
|
||||
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
@@ -105,9 +105,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -118,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
@@ -138,6 +136,8 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
@@ -145,13 +145,11 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
|
||||
@with_temp_dir
|
||||
def test_dpo_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": False,
|
||||
"eval_sample_packing": False,
|
||||
@@ -164,12 +162,10 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"rl": "dpo",
|
||||
"chat_template": "llama3",
|
||||
"chat_template": "chatml",
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
@@ -210,18 +206,19 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_qlora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 2048,
|
||||
"sample_packing": False,
|
||||
"eval_sample_packing": False,
|
||||
@@ -278,25 +275,27 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_fsdp(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 4],
|
||||
)
|
||||
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"val_set_size": 0.01,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -305,9 +304,9 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"max_steps": 10,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
@@ -324,7 +323,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"fsdp_use_orig_params": False,
|
||||
"fsdp_cpu_ram_efficient_loading": False,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
"fsdp_state_dict_type": "FULL_STATE_DICT",
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
}
|
||||
@@ -341,28 +340,29 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_fsdp_packed(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"fsdp_state_dict_type",
|
||||
["FULL_STATE_DICT", "SHARDED_STATE_DICT"],
|
||||
)
|
||||
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -390,7 +390,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"fsdp_use_orig_params": False,
|
||||
"fsdp_cpu_ram_efficient_loading": False,
|
||||
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
|
||||
"fsdp_state_dict_type": "SHARDED_STATE_DICT",
|
||||
"fsdp_state_dict_type": fsdp_state_dict_type,
|
||||
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
|
||||
},
|
||||
}
|
||||
@@ -407,19 +407,19 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
||||
"adapter": "qlora",
|
||||
"mean_resizing_embeddings": True,
|
||||
"load_in_4bit": True,
|
||||
@@ -427,17 +427,17 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"lora_modules_to_save": [
|
||||
"embed_tokens",
|
||||
"lm_head",
|
||||
],
|
||||
# "lora_modules_to_save": [
|
||||
# "embed_tokens",
|
||||
# "lm_head",
|
||||
# ],
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"pad_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -483,28 +483,29 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ds_zero3_packed(self, temp_dir):
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 4],
|
||||
)
|
||||
def test_ds_zero3_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -515,7 +516,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 4,
|
||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
@@ -536,19 +537,19 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ds_zero3_qlora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TinyLlama/TinyLlama_v1.1",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
@@ -561,9 +562,7 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -595,6 +594,8 @@ class TestMultiGPULlama(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
|
||||
@@ -4,31 +4,30 @@ E2E tests for multigpu qwen2
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMultiGPUQwen2(unittest.TestCase):
|
||||
class TestMultiGPUQwen2:
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_fsdp_dpo(self, temp_dir):
|
||||
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
||||
def test_qlora_fsdp_dpo(self, base_model, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2-1.5B",
|
||||
"base_model": base_model,
|
||||
"load_in_4bit": True,
|
||||
"rl": "dpo",
|
||||
"chat_template": "chatml",
|
||||
@@ -47,9 +46,9 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 15,
|
||||
"max_steps": 5,
|
||||
"warmup_steps": 20,
|
||||
"micro_batch_size": 4,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
@@ -91,6 +90,8 @@ class TestMultiGPUQwen2(unittest.TestCase):
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
|
||||
@@ -66,6 +66,8 @@ class TestFAXentropyLlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 10,
|
||||
"save_steps": 10,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
|
||||
@@ -56,6 +56,8 @@ class TestLoraLlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
@@ -109,6 +111,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"max_steps": 20,
|
||||
"save_steps": 0.5,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
|
||||
66
tests/e2e/test_llama.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
E2E tests for llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_trust_remote_code(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"trust_remote_code": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"bf16": True,
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
@@ -108,3 +108,37 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "schedule_free_adamw",
|
||||
"lr_scheduler": "constant",
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
|
||||
@@ -31,7 +31,7 @@ class TestPackedLlama(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM-135M",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
|
||||
85
tests/e2e/test_qwen.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
E2E tests for qwen
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from transformers.testing_utils import get_torch_dist_unique_port
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.qwen")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestE2eQwen:
|
||||
"""
|
||||
Test cases for qwen models
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
||||
def test_dpo(self, base_model, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": base_model,
|
||||
"rl": "dpo",
|
||||
"chat_template": "qwen_25",
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.0,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
"split": "train",
|
||||
"type": "chat_template.default",
|
||||
"field_messages": "conversation",
|
||||
"field_chosen": "chosen",
|
||||
"field_rejected": "rejected",
|
||||
"message_field_role": "role",
|
||||
"message_field_content": "content",
|
||||
"roles": {
|
||||
"system": ["system"],
|
||||
"user": ["user"],
|
||||
"assistant": ["assistant"],
|
||||
},
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 5,
|
||||
"warmup_steps": 20,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"bf16": "auto",
|
||||
"tf32": True,
|
||||
"gradient_checkpointing": True,
|
||||
}
|
||||
)
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||
|
||||
execute_subprocess_async(
|
||||
[
|
||||
"accelerate",
|
||||
"launch",
|
||||
"--num-processes",
|
||||
"2",
|
||||
"--main_process_port",
|
||||
f"{get_torch_dist_unique_port()}",
|
||||
"-m",
|
||||
"axolotl.cli.train",
|
||||
str(Path(temp_dir) / "config.yaml"),
|
||||
]
|
||||
)
|
||||
@@ -371,44 +371,79 @@ class TestDatasetPreparation(unittest.TestCase):
|
||||
def test_load_local_hub_with_revision(self):
|
||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir2:
|
||||
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
|
||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||
snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
repo_type="dataset",
|
||||
local_dir=tmp_ds_path,
|
||||
revision="d05c1cb",
|
||||
)
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||
snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
repo_type="dataset",
|
||||
local_dir=tmp_ds_path,
|
||||
revision="d05c1cb",
|
||||
)
|
||||
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"ds_type": "parquet",
|
||||
"type": "alpaca",
|
||||
"data_files": [
|
||||
f"{tmp_ds_path}/alpaca_2000.parquet",
|
||||
],
|
||||
"revision": "d05c1cb",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"ds_type": "parquet",
|
||||
"type": "alpaca",
|
||||
"data_files": [
|
||||
f"{tmp_ds_path}/alpaca_2000.parquet",
|
||||
],
|
||||
"revision": "d05c1cb",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
def test_loading_local_dataset_folder(self):
|
||||
"""Verify that a dataset downloaded to a local folder can be loaded"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||
snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
repo_type="dataset",
|
||||
local_dir=tmp_ds_path,
|
||||
)
|
||||
|
||||
prepared_path = Path(tmp_dir) / "prepared"
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"sequence_len": 1024,
|
||||
"datasets": [
|
||||
{
|
||||
"path": str(tmp_ds_path),
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
dataset, _ = load_tokenized_prepared_datasets(
|
||||
self.tokenizer, cfg, prepared_path
|
||||
)
|
||||
|
||||
assert len(dataset) == 2000
|
||||
assert "input_ids" in dataset.features
|
||||
assert "attention_mask" in dataset.features
|
||||
assert "labels" in dataset.features
|
||||
shutil.rmtree(tmp_ds_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -32,16 +32,19 @@ class TestCosineConstantLr(unittest.TestCase):
|
||||
def test_schedulers(self):
|
||||
self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0)
|
||||
for _ in range(self.warmup_steps):
|
||||
self.optimizer.step()
|
||||
self.lr_scheduler.step()
|
||||
self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr)
|
||||
constant_step = int(self.train_steps * self.constant_lr_ratio)
|
||||
remaining_step = self.train_steps - constant_step
|
||||
for _ in range(constant_step):
|
||||
self.optimizer.step()
|
||||
self.lr_scheduler.step()
|
||||
self.assertEqual(
|
||||
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
|
||||
)
|
||||
for _ in range(remaining_step):
|
||||
self.optimizer.step()
|
||||
self.lr_scheduler.step()
|
||||
self.assertEqual(
|
||||
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
|
||||
|
||||
@@ -68,6 +68,53 @@ class TestValidation(BaseValidation):
|
||||
assert cfg.train_on_inputs is False
|
||||
assert cfg.weight_decay is None
|
||||
|
||||
def test_zero3_qlora_use_reentrant_false(self, minimal_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
}
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(test_cfg)
|
||||
assert (
|
||||
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_deepspeed_empty(self, minimal_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"deepspeed": "",
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
}
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
_ = validate_config(test_cfg)
|
||||
|
||||
def test_deepspeed_not_set(self, minimal_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"deepspeed": None,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
}
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
_ = validate_config(test_cfg)
|
||||
|
||||
def test_datasets_min_length(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
@@ -726,7 +773,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"eval_strategy": "epoch",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
@@ -734,14 +781,14 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||
ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "no",
|
||||
"eval_strategy": "no",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
@@ -749,14 +796,14 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||
ValueError, match=r".*eval_strategy and eval_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
"eval_strategy": "steps",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
@@ -767,7 +814,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
"eval_strategy": "steps",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
@@ -790,7 +837,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "no",
|
||||
"eval_strategy": "no",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
@@ -801,7 +848,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"eval_strategy": "epoch",
|
||||
"val_set_size": 0,
|
||||
}
|
||||
)
|
||||
@@ -810,7 +857,7 @@ class TestValidation(BaseValidation):
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||
match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
@@ -826,7 +873,7 @@ class TestValidation(BaseValidation):
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||
match=r".*eval_steps and eval_strategy are not supported with val_set_size == 0.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
@@ -856,7 +903,7 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"eval_strategy": "epoch",
|
||||
"val_set_size": 0.01,
|
||||
}
|
||||
)
|
||||
@@ -1095,6 +1142,24 @@ class TestValidation(BaseValidation):
|
||||
assert new_cfg["dpo_beta"] is None
|
||||
assert len(self._caplog.records) == 1
|
||||
|
||||
def test_eval_strategy_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.eval_strategy == "steps"
|
||||
assert (
|
||||
"evaluation_strategy is deprecated, use eval_strategy instead"
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
|
||||
class TestValidationCheckModelConfig(BaseValidation):
|
||||
"""
|
||||
|
||||