Compare commits

37 Commits

Author SHA1 Message Date
Wing Lian
3b432346e3 WIP 2024-03-07 08:30:13 -05:00
Wing Lian
58b0d4b0d8 update flash attention for gemma support: (#1368) 2024-03-06 10:08:54 -05:00
Hamel Husain
ed70a08348 add docs for input_output format (#1367) [skip ci]
* add docs

* add docs

* run linter
2024-03-06 09:09:49 -05:00
Wing Lian
0cfdb2c90c support for DoRA w/ PEFT (#1363) 2024-03-05 21:20:15 -05:00
Nicolas Rojas
37657473c8 Remove unsupported python version 3.9 from README (#1364) [skip ci] 2024-03-05 21:19:36 -05:00
Eric Hartford
e0f1895408 add starcoder2 (#1349)
* add starcoder2

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* chore: lint

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2024-03-05 19:49:17 -05:00
Sebastian Raschka
8984bf1722 Update tinyllama lora.yml to fix eval packing issue (#1362) 2024-03-05 14:36:29 -05:00
Wing Lian
2598c9f045 allow the sharegpt handler to also better handle datasets destined for openai finetuning (#1361)
* allow the sharegpt handler to also better handle datasets destined for openai finetuning

* make sure to support system role
2024-03-05 11:43:33 -05:00
Wing Lian
decb66e170 lora+ support (#1352)
* lora+ support

* optimizer should default to None

* include mit license
2024-03-05 07:29:23 -05:00
Wing Lian
4d09b42ee3 plain input/output prompt strategy w/o chat templates (#1346)
* plain input/output prompt strategy w/o chat templates

* disable duplicate code check

* make sure to add an eos/eot token to the end of the output so it will stop

* multi turn segment support and test
2024-03-04 16:25:16 -05:00
Chirag Jain
b5b44925ec Fix validation for early stopping (#1358) 2024-03-03 22:15:18 -05:00
NanoCode012
170d4d7092 chore: enable sample_packing for Gemma (#1351) 2024-03-01 21:56:22 -05:00
Wing Lian
00018629e7 run tests again on Modal (#1289) [skip ci]
* run tests again on Modal

* make sure to run the full suite of tests on modal

* run cicd steps via shell script

* run tests in different runs

* increase timeout

* split tests into steps on modal

* increase workflow timeout

* retry doing this with only a single script

* fix yml launch for modal ci

* reorder tests to run on modal

* skip dpo tests on modal

* run on L4s, A10G takes too long

* increase CPU and RAM for modal test

* run modal tests on A100s

* skip phi test on modal

* env not arg in modal dockerfile

* upgrade pydantic and fastapi for modal tests

* cleanup stray character

* use A10s instead of A100 for modal
2024-02-29 14:26:26 -05:00
Wing Lian
6b3b271925 fix for protected model_ namespace w pydantic (#1345) 2024-02-28 15:07:49 -05:00
Chirag Jain
3a5a2d2f34 Fix use_mlflow to be bool instead of str (#1344) 2024-02-28 12:58:29 -05:00
Wing Lian
6d4bbb877f deprecate py 3.9 support, set min pytorch version (#1343) [skip ci] 2024-02-28 12:58:05 -05:00
Wing Lian
0f985e12fe more fixes 20240228 (#1342) [skip ci]
* add missing evals_per_epoch setting

* more pydantic fixes

* more fixes

* move test from normalization to validation

* increase eval size for sample packing tests
2024-02-28 12:57:45 -05:00
Wing Lian
c1a7b3dd69 add gemma instruct chat template (#1341)
* add gemma instruct chat template

* support for chat template strategy too
2024-02-27 17:20:01 -05:00
Ikko Eltociear Ashimine
2b9687f341 Update fastchat_conversation_turns.py (#1294) [skip ci]
seperated -> separated
2024-02-27 09:06:10 -05:00
Wing Lian
2c9c88b32a fix steps check for anneal on first cycle (#1316) 2024-02-27 08:56:08 -05:00
Hamel Husain
5265cd6b2c Update debugging.md (#1339) [skip ci] 2024-02-27 15:47:31 +09:00
NanoCode012
5be8b555a0 fix: checkpoint saving with deepspeed (#1321) 2024-02-27 15:46:44 +09:00
Maxime
0f6af36d50 Mps mistral lora (#1292) [skip ci]
* Lora example for Mistral on MPS backend

* Add some MPS documentation

* Update examples/mistral/lora-mps.yml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update examples/mistral/lora-mps.yml

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* Update README.md

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-26 22:39:57 -05:00
Wing Lian
3f69571943 more pydantic fixes (#1338) 2024-02-26 22:39:13 -05:00
nopperl
1e3d5305d3 Support user-defined prompt processing strategies for dpo (#1248)
* support user-defined prompt processing strategies for dpo

* interpret dict dataset types as user-defined

* fix lint errors

* setup pydantic config for validation of User defined DPO

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-26 18:49:34 -05:00
Maxime
16482796b0 add lion-pytorch optimizer (#1299) [skip ci]
* add lion-pytorch optimizer

* update pydantic to support lion optimizer

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-26 18:45:14 -05:00
Nathan Cooper
f30d062b48 Add StableLM 2 Example Scripts (#1327) [skip ci]
* Add StableLM examples and configurations

* Add FFT and LORA configuration files and modify readme with usage
2024-02-26 18:44:25 -05:00
Wing Lian
269c5436ea hotfix to exclude_unset from pydantic config when converting back to a dict (#1334) 2024-02-26 15:06:25 -05:00
Wing Lian
e7eed203d8 hotfix for missing outputs params (#1333) 2024-02-26 14:36:37 -05:00
Wing Lian
cf002312e0 hotfix for lora rank (#1332) 2024-02-26 14:28:43 -05:00
Wing Lian
7de912e097 hotfix for capabilities loading (#1331) 2024-02-26 14:24:28 -05:00
JohanWork
d75653407c ADD: push checkpoints to mlflow artifact registry (#1295) [skip ci]
* Add checkpoint logging to mlflow artifact registry

* clean up

* Update README.md

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* update pydantic config from rebase

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-02-26 13:32:39 -05:00
NanoCode012
c6b01e0f4a chore: update readme to be more clear (#1326) [skip ci] 2024-02-26 13:32:13 -05:00
Wing Lian
cc3cebfa70 Pydantic 2.x cfg (#1239)
* WIP conversion to use pydantic for config validation

* wip, more fields, add capabilities

* wip

* update pydantic validation to match existing tests

* tweak requirements

* setup deprecated params pydantic model

* more validations

* wrap up rest of the validations

* flesh out the rest of the options from the readme into pydantic

* fix model validators as class methods

remember to return in validator
missing return
add missing relora attributes
fix test for DictDefault change
fix sys template for mistral from fastchat change in PR 2872
fix test for batch size warning

* more missing attributes for cfg

* updates from PR feedback

* fix validation for datasets and pretrain datasets

* fix test for lora check
2024-02-26 12:24:14 -05:00
Wing Lian
5894f0e57e make mlflow optional (#1317)
* make mlflow optional

* fix xformers

don't patch swiglu if xformers not working
fix the check for xformers swiglu

* fix install of xformers with extra index url for docker builds

* fix docker build arg quoting
2024-02-26 11:41:33 -05:00
kallewoof
5cf226e177 Use yaml codeblock for config.yaml field (#1303) [skip ci] 2024-02-24 21:59:16 +09:00
NanoCode012
2ed52bd568 fix(readme): Clarify doc for tokenizer_config (#1323) [skip ci] 2024-02-24 21:55:04 +09:00
65 changed files with 3383 additions and 620 deletions

@@ -59,6 +59,7 @@ body:
label: Config yaml
description: |
Please attach the config yaml!
render: yaml
- type: textarea
id: possible-solution

@@ -17,6 +17,6 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0

@@ -18,6 +18,7 @@ jobs:
python_version: "3.10"
pytorch: 2.1.2
axolotl_extras:
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
is_latest: true
- cuda: 121
cuda_version: 12.1.0
@@ -54,6 +55,7 @@ jobs:
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_ARGS=${{ matrix.axolotl_args }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |

@@ -23,7 +23,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.10"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
@@ -33,7 +33,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python_version: ["3.9", "3.10", "3.11"]
python_version: ["3.10", "3.11"]
timeout-minutes: 10
steps:
@@ -58,8 +58,8 @@ jobs:
docker-e2e-tests:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, gpu, docker]
timeout-minutes: 30
runs-on: [self-hosted, modal]
timeout-minutes: 60
needs: [pre-commit, pytest]
strategy:
@@ -70,43 +70,31 @@ jobs:
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.2
axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
num_gpus: 1
- cuda: 121
cuda_version: 12.1.0
python_version: "3.10"
pytorch: 2.1.2
num_gpus: 1
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Docker metadata
id: metadata
uses: docker/metadata-action@v5
- name: Install Python
uses: actions/setup-python@v5
with:
images: winglian/axolotl-tests
- name: Build Docker image
python-version: "3.10"
- name: Install Modal
run: |
# Set up build arguments
BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
CUDA="${{ matrix.cuda }}"
PYTORCH_VERSION="${{ matrix.pytorch }}"
# Build the Docker image
docker build . \
--file ./docker/Dockerfile-tests \
--build-arg BASE_TAG=$BASE_TAG \
--build-arg CUDA=$CUDA \
--build-arg GITHUB_REF=$GITHUB_REF \
--build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
--tag ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} \
--no-cache
- name: Unit Tests w docker image
python -m pip install --upgrade pip
pip install modal jinja2
- name: Update env vars
run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: GPU Unit Tests w docker image
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
- name: GPU Unit Tests monkeypatched w docker image
run: |
docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
- name: Prune image from docker
if: github.ref != 'refs/heads/main'
run: |
docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
modal run cicd.tests
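
The `Update env vars` step above builds `BASE_TAG` from the matrix values before handing off to Modal. As a rough illustration of that string composition only, here is a minimal Python sketch; the function name `base_tag` and its `ref` parameter are hypothetical and do not appear in the repository.

```python
def base_tag(python_version: str, cuda: str, pytorch: str, ref: str = "main") -> str:
    """Compose a tag shaped like `<ref>-base-py<python>-cu<cuda>-<pytorch>`.

    Mirrors the pattern used in the workflow's `BASE_TAG=...` line; the
    helper itself is illustrative, not part of the workflow.
    """
    return f"{ref}-base-py{python_version}-cu{cuda}-{pytorch}"


# The two matrix entries from the workflow (cuda 118 and 121, pytorch 2.1.2).
for cuda in ("118", "121"):
    print(base_tag("3.10", cuda, "2.1.2"))
```

Each matrix entry therefore selects a different `winglian/axolotl-base` image, which is what `cicd/Dockerfile.jinja` consumes through `{{ BASE_TAG }}`.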

5
.gitignore vendored

@@ -167,3 +167,8 @@ cython_debug/
# WandB
# wandb creates a folder to store logs for training runs
wandb
# Runs
lora-out/*
qlora-out/*
mlruns/*

@@ -1,5 +1,5 @@
[mypy]
plugins = pydantic.mypy
exclude = venv
[mypy-alpaca_lora_4bit.*]

@@ -31,6 +31,7 @@ repos:
additional_dependencies:
[
'types-PyYAML',
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5

185
README.md

@@ -22,7 +22,7 @@ Features:
- [Introduction](#axolotl)
- [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-)
- [Installation](#installation)
- [Environment](#environment)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
@@ -87,15 +87,17 @@ Features:
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| Gemma | ✅ | ✅ | ✅ | ❓ | ❓ | ✅ | ❓ |
✅: supported
❌: not supported
❓: untested
## Quickstart ⚡
Get started with Axolotl in just a few steps! This quickstart guide will walk you through setting up and running a basic fine-tuning task.
**Requirements**: Python >=3.9 and Pytorch >=2.0.
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
**Requirements**: Python >=3.10 and Pytorch >=2.1.1.
### For developers
```bash
@@ -103,9 +105,18 @@ git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
```
General case:
```
pip3 install -e '.[flash-attn,deepspeed]'
```
Mac: see https://github.com/OpenAccess-AI-Collective/axolotl/blob/13199f678b9aab39e92961323bdbce3234ee4b2b/docs/mac.md
```
pip3 install -e '.'
```
### Usage
```bash
# preprocess datasets - optional but recommended
@@ -127,13 +138,14 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
```
## Installation
## Advanced Setup
### Environment
#### Docker
```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
docker run --gpus '"all"' --rm -it winglian/axolotl:main-latest
```
Or run on the current files for development:
@@ -152,7 +164,7 @@ accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAcc
A more powerful Docker command to run would be this:
```bash
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=bind,src="${PWD}",target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-latest
```
It additionally:
@@ -167,7 +179,7 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
</details>
#### Conda/Pip venv
1. Install python >=**3.9**
1. Install python >=**3.10**
2. Install pytorch stable https://pytorch.org/get-started/locally/
@@ -200,11 +212,11 @@ For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud
1. Install python
```bash
sudo apt update
sudo apt install -y python3.9
sudo apt install -y python3.10
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
sudo update-alternatives --config python # pick 3.9 if given option
python -V # should be 3.9
sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
sudo update-alternatives --config python # pick 3.10 if given option
python -V # should be 3.10
```
@@ -242,15 +254,18 @@ Please use WSL or Docker!
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
@@ -260,32 +275,33 @@ HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
Have dataset(s) in one of the following format (JSONL recommended):
- `alpaca`: instruction; input(optional)
```json
{"instruction": "...", "input": "...", "output": "..."}
```
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
```yml
datasets:
- path: <your-path>
type: sharegpt
conversation: llama-2
```
#### Pretraining
- `completion`: raw corpus
```json
{"text": "..."}
```
Note: Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```yaml
pretraining_dataset: # hf path only
```
#### Supervised finetuning
##### Instruction
- `alpaca`: instruction; input(optional)
```json
{"instruction": "...", "input": "...", "output": "..."}
```
<details>
<summary>See other formats</summary>
@@ -362,14 +378,37 @@ Have dataset(s) in one of the following format (JSONL recommended):
```json
{"scores": "...", "critiques": "...", "instruction": "...", "answer": "...", "revision": "..."}
```
- `pygmalion`: pygmalion
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `metharme`: instruction, adds additional eos tokens
```json
{"prompt": "...", "generation": "..."}
```
</details>
##### Template-Free
- `input_output`: template-free prompt construction
```json
{"segments": [{"label": true|false, "text": "..."}]}
```
This is a special format that allows you to construct prompts without using templates. This is for advanced users who want more freedom with prompt construction. See [these docs](docs/input_output.md) for more details.
##### Conversation
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)
```json
{"conversations": [{"from": "...", "value": "..."}]}
```
<details>
<summary>See other formats</summary>
- `pygmalion`: pygmalion
```json
{"conversations": [{"role": "...", "value": "..."}]}
```
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
```json
{"conversations": [{"role": "...", "value": "..."}]}
@@ -385,6 +424,8 @@ Have dataset(s) in one of the following format (JSONL recommended):
</details>
Note: `type: sharegpt` opens a special config `conversation:` that enables conversions to many Conversation types. See dataset section under [all yaml options](#all-yaml-options).
#### How to add custom prompts
For a dataset that is preprocessed for instruction purposes:
@@ -406,12 +447,16 @@ datasets:
format: "[INST] {instruction} [/INST]"
no_input_format: "[INST] {instruction} [/INST]"
```
See full config options under [all yaml options](#all-yaml-options).
#### How to use your custom pretokenized dataset
- Do not pass a `type:`
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
```yaml
- path: ...
```
### Config
@@ -425,22 +470,18 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- dataset
```yaml
sequence_len: 2048 # max token length for prompt
# huggingface repo
datasets:
# huggingface repo
- path: vicgalle/alpaca-gpt4
type: alpaca # format from earlier
type: alpaca
# huggingface repo with specific configuration/subset
datasets:
# huggingface repo with specific configuration/subset
- path: EleutherAI/pile
name: enron_emails
type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets
datasets:
# huggingface repo with multiple named configurations/subsets
- path: bigcode/commitpackft
name:
- ruby
@@ -448,34 +489,29 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript
type: ... # unimplemented custom format
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
datasets:
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
- path: ...
type: sharegpt
conversation: chatml
conversation: chatml # default: vicuna_v1.1
# local
datasets:
# local
- path: data.jsonl # or json
ds_type: json # see other options below
type: alpaca
# dataset with splits, but no train split
dataset:
# dataset with splits, but no train split
- path: knowrohit07/know_sql
type: context_qa.load_v2
train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
dataset:
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
dataset:
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
```
@@ -484,9 +520,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
```yaml
load_in_4bit: true
load_in_8bit: true
bf16: auto # require >=ampere, auto will detect if your GPU supports this and choose automatically.
fp16: # leave empty to use fp16 when bf16 is 'auto'. set to false if you want to fallback to fp32
tf32: true # require >=ampere
bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
float16: true # use instead of fp16 when you don't want AMP
```
@@ -494,7 +532,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- lora
```yaml
adapter: lora # qlora or leave blank for full finetune
adapter: lora # 'qlora' or leave blank for full finetune
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
@@ -503,9 +541,9 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- v_proj
```
<details>
<details id="all-yaml-options">
<summary>All yaml options (click me)</summary>
<summary>All yaml options (click to expand)</summary>
```yaml
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
@@ -517,8 +555,8 @@ base_model_ignore_patterns:
# You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub
model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer
revision_of_model:
# Optional tokenizer configuration path in case you want to use a different tokenizer
# than the one defined in the base model
tokenizer_config:
# If you want to specify the type of model to load, AutoModelForCausalLM is a good choice too
@@ -535,15 +573,16 @@ tokenizer_legacy:
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# (Internal use only)
# Used to identify which the model is based on
is_falcon_derived_model:
is_llama_derived_model:
is_qwen_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model:
is_qwen_derived_model:
# optional overrides to the base model configuration
model_config:
overrides_of_model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
@@ -560,8 +599,6 @@ bnb_config_kwargs:
# Whether you are training a 4-bit GPTQ quantized model
gptq: true
gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true
@@ -635,7 +672,7 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: dpo, ipo, kto_pair
# use RL training: 'dpo', 'ipo', 'kto_pair'
rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
@@ -655,7 +692,7 @@ dataset_processes: # defaults to os.cpu_count() if not set
# Only needed if cached dataset is taking too much storage
dataset_keep_in_memory:
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
hub_model_id: # private repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
@@ -751,6 +788,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# mlflow configuration if you're using it
mlflow_tracking_uri: # URI to mlflow
mlflow_experiment_name: # Your experiment name
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
# Where to save the full-finetuned model to
output_dir: ./completed-model
@@ -819,10 +857,6 @@ cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosin
# For one_cycle optim
lr_div_factor: # Learning rate div factor
# For log_sweep optim
log_sweep_min_lr:
log_sweep_max_lr:
# Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
@@ -1106,7 +1140,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
### Merge LORA to base
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
The following command will merge your LORA adapater with your base model. You can optionally pass the argument `--lora_model_dir` to specify the directory where your LORA adapter was saved, otherwhise, this will be inferred from `output_dir` in your axolotl config file. The merged model is saved in the sub-directory `{lora_model_dir}/merged`.
```bash
python3 -m axolotl.cli.merge_lora your_config.yml --lora_model_dir="./completed-model"
@@ -1167,7 +1201,7 @@ If you decode a prompt constructed by axolotl, you might see spaces between toke
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same, adjust your inference server accordingly.
4. As an additional troubleshooting step, you can look at the token ids between 1 and 2 to make sure they are identical.
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
@@ -1214,11 +1248,20 @@ PRs are **greatly welcome**!
Please run below to setup env
```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl
pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]'
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
# optional: run against all files
pre-commit run --all-files
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
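
The README changes above introduce the template-free `input_output` dataset format, whose rows have the shape `{"segments": [{"label": true|false, "text": "..."}]}`. The sketch below shows one way such a JSONL row could be produced in Python. The helper `make_row` is hypothetical, the prompt text is made up, and the reading of `label` (true meaning the segment is trained on, false meaning it is masked) is an assumption here; `docs/input_output.md` is the authoritative description.

```python
import json


def make_row(segments):
    """Build one `input_output` record from (text, label) pairs.

    `make_row` is an illustrative helper, not an Axolotl API.
    """
    return {
        "segments": [{"label": bool(label), "text": text} for text, label in segments]
    }


row = make_row(
    [
        # Assumed: label False = segment is not trained on (prompt side).
        ("<s>[INST] What is 2 + 2? [/INST]", False),
        # Assumed: label True = segment is trained on (completion side).
        (" 4</s>", True),
    ]
)

# One JSON object per line is what makes the file JSONL.
line = json.dumps(row)
print(line)
```

Because no chat template is applied, any special tokens (such as the BOS/EOS markers in this made-up example) have to be written into the segment text by hand.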

39
cicd/Dockerfile.jinja Normal file

@@ -0,0 +1,39 @@
FROM winglian/axolotl-base:{{ BASE_TAG }}
ENV TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ENV AXOLOTL_EXTRAS="{{ AXOLOTL_EXTRAS }}"
ENV AXOLOTL_ARGS="{{ AXOLOTL_ARGS }}"
ENV CUDA="{{ CUDA }}"
ENV BNB_CUDA_VERSION="{{ CUDA }}"
ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
WORKDIR /workspace
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl
RUN git fetch origin +$GITHUB_REF && \
git checkout FETCH_HEAD
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image
RUN pip install pytest
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store

5
cicd/cicd.sh Executable file

@@ -0,0 +1,5 @@
#!/bin/bash
pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest /workspace/axolotl/tests/e2e/patched/
pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/

75
cicd/tests.py Normal file

@@ -0,0 +1,75 @@
"""
modal application to run axolotl gpu tests in Modal
"""
import os
import pathlib
import tempfile
import jinja2
import modal
from jinja2 import select_autoescape
from modal import Image, Stub
cicd_path = pathlib.Path(__file__).parent.resolve()
template_loader = jinja2.FileSystemLoader(searchpath=cicd_path)
template_env = jinja2.Environment(
loader=template_loader, autoescape=select_autoescape()
)
df_template = template_env.get_template("Dockerfile.jinja")
df_args = {
"AXOLOTL_EXTRAS": os.environ.get("AXOLOTL_EXTRAS", ""),
"AXOLOTL_ARGS": os.environ.get("AXOLOTL_ARGS", ""),
"PYTORCH_VERSION": os.environ.get("PYTORCH_VERSION", "2.0.1"),
"BASE_TAG": os.environ.get("BASE_TAG", "main-base-py3.10-cu118-2.0.1"),
"CUDA": os.environ.get("CUDA", "118"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
}
dockerfile_contents = df_template.render(**df_args)
temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
stub = Stub("Axolotl CI/CD", secrets=[])
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):
import subprocess # nosec
# Propagate errors from subprocess.
if exit_code := subprocess.call(cmd.split(), cwd=run_folder): # nosec
exit(exit_code) # pylint: disable=consider-using-sys-exit
@stub.function(
image=cicd_image,
gpu=GPU_CONFIG,
timeout=45 * 60,
cpu=8.0,
memory=131072,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
@stub.local_entrypoint()
def main():
cicd_pytest.remote()
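
`run_cmd` in the file above exists so that a failing test script makes the Modal function fail too. The following standalone sketch illustrates the same exit-code propagation idea without Modal. It is a variant rather than the file's code: it returns the exit code instead of calling `exit`, and it swaps `cmd.split()` for `shlex.split` so that quoted arguments survive.

```python
import shlex
import subprocess
import sys


def run_cmd(cmd: str, run_folder: str) -> int:
    """Run `cmd` inside `run_folder` and return its exit code (0 means success)."""
    return subprocess.call(shlex.split(cmd), cwd=run_folder)


# Use the current interpreter as a stand-in for ./cicd/cicd.sh.
python = shlex.quote(sys.executable)
ok = run_cmd(f"{python} -c 'raise SystemExit(0)'", ".")
failed = run_cmd(f"{python} -c 'raise SystemExit(3)'", ".")
print(ok, failed)

# A caller that wants the original behaviour can exit with the child's code:
# if failed: sys.exit(failed)
```

Returning the code keeps the helper easy to test; exiting on a non-zero value, as the original does, is what lets the CI job report the failure.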


@@ -3,9 +3,10 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ARG PYTORCH_VERSION="2.1.2"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -20,9 +21,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image


@@ -7,8 +7,8 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.9"
ARG PYTORCH_VERSION="2.0.1"
ARG PYTHON_VERSION="3.10"
ARG PYTORCH_VERSION="2.1.2"
ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"


@@ -3,9 +3,10 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG AXOLOTL_ARGS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ARG PYTORCH_VERSION="2.1.2"
ARG GITHUB_REF="main"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
@@ -24,9 +25,9 @@ RUN git fetch origin +$GITHUB_REF && \
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
fi
# So we can test the Docker image


@@ -74,7 +74,6 @@ pip3 install -e '.[flash-attn,deepspeed]'
If you are developing on a remote host, you can easily use VSCode to debug remotely. To do so, you will need to follow this [remote - SSH guide](https://code.visualstudio.com/docs/remote/ssh). You can also see the video below on [Docker and Remote SSH debugging](#video---attaching-to-docker-on-remote-host).
```bash
### Configuration

260
docs/input_output.md Normal file

@@ -0,0 +1,260 @@
# Template-free prompt construction with the `input_output` format
<!-- TOC -->
- [Background](#background)
- [Masking Inputs](#masking-inputs)
- [You may not want prompt templates](#you-may-not-want-prompt-templates)
- [The `input_output` format](#the-input_output-format)
- [Usage](#usage)
- [1. Prepare Data](#1-prepare-data)
- [2. Use `type: input_output`](#2-use-type-input_output)
- [3. Check the prompts](#3-check-the-prompts)
<!-- /TOC -->
<a id="markdown-background" name="background"></a>
## Background
<a id="markdown-masking-inputs" name="masking-inputs"></a>
### Masking Inputs
One of the most popular features of
[axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) is
setting the following configuration value:
```yaml
train_on_inputs: false
```
If you declare a [dataset format](https://github.com/OpenAccess-AI-Collective/axolotl?tab=readme-ov-file#dataset)
such as `alpaca` or `chatml`, axolotl knows what is an input
(i.e. human) vs. an output (i.e. the assistant) and masks the input
labels so that your model can focus on predicting the outputs only.
<a id="markdown-you-may-not-want-prompt-templates" name="you-may-not-want-prompt-templates"></a>
### You may not want prompt templates
However, there are many situations where you don't want to use one of
these formats or templates (I usually don't!). This is because they can:
- Add unnecessary boilerplate to your prompts.
- Create artifacts like special delimiters `<|im_start|>` that can
quickly become footguns if you don't include them correctly at
inference time.
- Enforce a *chat* interface when you do not want one. Sometimes you
just want to fine-tune a model to a very specific task and do NOT
want multi-turn conversations, roles, etc.
- Limit you to only certain roles that the template allows.
<a id="markdown-the-inputoutput-format" name="the-inputoutput-format"></a>
### The `input_output` format
You can construct your prompts without a template by using the
`input_output` format, by setting `type: input_output` in your
configuration file like this:
**config.yml**
```yaml
train_on_inputs: false # Mask segments of your data
datasets:
- path: output.jsonl
type: input_output # use template free prompt construction
```
Unlike `type: completion`, which is also template-free,
`type: input_output` allows you to mask segments of your text. More
details on how this works are described below.
<a id="markdown-usage" name="usage"></a>
## Usage
This is how you can use the `input_output` format:
<a id="markdown-1-prepare-data" name="1-prepare-data"></a>
### 1. Prepare Data
To use the `input_output` format, collect your data into a jsonl file in
the following format (below is the first row from the file
`output.jsonl`, pretty-printed):
```bash
$ head -n1 output.jsonl | python -m json.tool
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```
Set `label:false` when you want to mask a segment of text so that the
model isn't trained on it. Some things to keep in mind:
> [!IMPORTANT]
> 1. **EOS, BOS, spaces, newlines etc. are entirely up to you. Axolotl
concatenates all the segments as-is.** The tokenizer doesn't add
anything additional. Notice how I added spaces, newlines, `<s>`
(BOS), and `</s>` (EOS) myself.
> 2. Make sure you check the materialized output to validate that the
prompt is getting assembled how you like.
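As a rough illustration of the row layout above, here is a minimal sketch that builds one row in Python and shows the text that results from joining the segments as-is. The helper names `make_row` and `to_jsonl_line` are hypothetical and are not part of axolotl; only the `segments`, `label`, and `text` keys come from the format described above.

```python
import json


def make_row(segments):
    """Build one row from (text, label) pairs; label=False masks that text."""
    return {"segments": [{"label": label, "text": text} for text, label in segments]}


def to_jsonl_line(row):
    """Serialize one row as a single JSONL line (no trailing newline)."""
    return json.dumps(row)


row = make_row(
    [
        ("<s>Hello\n", True),
        ("hi there!. ", True),
        ("goodbye ", False),  # masked: the model is not trained on this text
        ("farewell</s>", True),
    ]
)

# Segments are joined as-is, so BOS/EOS, spaces and newlines must already be in the text.
full_text = "".join(seg["text"] for seg in row["segments"])
# Text whose tokens keep their labels (every segment with label=True).
trained_text = "".join(seg["text"] for seg in row["segments"] if seg["label"])

print(to_jsonl_line(row))
print(repr(full_text))
```

Writing one such line per example to `output.jsonl` gives a file in the shape shown above.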
<a id="markdown-2-use-type-inputoutput" name="2-use-type-inputoutput"></a>
### 2. Use `type: input_output`
Let's materialize data with our `output.jsonl` file by setting
`type: input_output` in our axolotl config:
```yaml
# training_config.yaml
base_model: mistralai/Mistral-7B-v0.1
data_seed: 49
seed: 49
datasets:
- path: output.jsonl
type: input_output
val_set_size: 0.1
sequence_len: 896
sample_packing: false
micro_batch_size: 2
gradient_accumulation_steps: 3
eval_batch_size: 2
num_epochs: 1
learning_rate: 0.0002
train_on_inputs: false
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
```
You can use the following command to materialize your data. The
`--debug` flag will print the tokens, along with the labels so you can
verify that the correct items are being ignored:
```bash
$ python -m axolotl.cli.preprocess training_config.yaml --debug
...
[2024-03-05 23:36:46,969] [INFO] [axolotl.check_example_labels:35] [PID:607731] [RANK:0] <s>(1, 1) Hello(22557, 22557)
(13, 13) hi(12014, 12014) there(736, 736) !(28808, 28808) .(28723, 28723) (28705, 28705) good(-100, 1179) bye(-100, 17664) (-100, 28705) fare(19111, 19111) well(5458, 5458) </s>(2, 2)
```
The format is `decoded_token`(`label`, `token_id`). For example,
`<s>(1, 1)` means that the token is `<s>`, the label is `1`, and the
token_id is `1`. When the label is `-100`, that token is ignored for
training.
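The same check can be done programmatically. The sketch below is only an illustration: the function `masked_positions` is a hypothetical helper, not an axolotl API, and the two lists are copied by hand from the debug output above. It reports which positions are ignored and fails if any unmasked label differs from its token id.

```python
IGNORE_INDEX = -100  # label value that marks a token as ignored for training

# Token ids and labels copied from the debug output above: token(label, token_id).
input_ids = [1, 22557, 13, 12014, 736, 28808, 28723, 28705, 1179, 17664, 28705, 19111, 5458, 2]
labels = [1, 22557, 13, 12014, 736, 28808, 28723, 28705, -100, -100, -100, 19111, 5458, 2]


def masked_positions(input_ids, labels):
    """Return the indices whose label is IGNORE_INDEX.

    Raises ValueError if the lists differ in length or if an unmasked
    label does not equal its token id.
    """
    if len(input_ids) != len(labels):
        raise ValueError("input_ids and labels must have the same length")
    positions = []
    for i, (token_id, label) in enumerate(zip(input_ids, labels)):
        if label == IGNORE_INDEX:
            positions.append(i)
        elif label != token_id:
            raise ValueError(f"position {i}: label {label} != token id {token_id}")
    return positions


masked = masked_positions(input_ids, labels)
print(masked)  # [8, 9, 10] -> the tokens of the masked "goodbye " segment
```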
<a id="markdown-3-check-the-prompts" name="3-check-the-prompts"></a>
### 3. Check the prompts
Here is another way to check the materialized output:
```python
from transformers import AutoTokenizer
from datasets import load_from_disk
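import os
import yaml
# Plain-Python equivalent of the notebook shell escape `!ls last_run_prepared/`
directory = sorted(os.listdir('last_run_prepared'))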
with open('training_config.yaml', 'r') as f:
cfg = yaml.safe_load(f)
model_id = cfg['base_model']
tok = AutoTokenizer.from_pretrained(model_id)
ds = load_from_disk(f'last_run_prepared/{directory[0]}/')
```
```python
>>> row = ds[0]
>>> print(tok.decode(row['input_ids']))
<s> Hello
hi there!. goodbye farewell</s>
```
We can check that the right tokens are ignored by comparing the labels
to each token:
```python
import pandas as pd
pd.DataFrame([{'token': tok.decode(i), 'label': l, 'id':i} for i,l in
zip(row['input_ids'], row['labels'])])
```
|    | token  | label | id    |
|----|--------|-------|-------|
| 0  | \<s\>  | 1     | 1     |
| 1  | Hello  | 22557 | 22557 |
| 2  | \\n    | 13    | 13    |
| 3  | hi     | 12014 | 12014 |
| 4  | there  | 736   | 736   |
| 5  | !      | 28808 | 28808 |
| 6  | .      | 28723 | 28723 |
| 7  |        | 28705 | 28705 |
| 8  | good   | -100  | 1179  |
| 9  | bye    | -100  | 17664 |
| 10 |        | -100  | 28705 |
| 11 | fare   | 19111 | 19111 |
| 12 | well   | 5458  | 5458  |
| 13 | \</s\> | 2     | 2     |
If we look at the input data, the above table seems correct! (The jsonl
version is repeated below for reference):
```bash
$ head -n1 output.jsonl | python -m json.tool
{
"segments": [
{
"label": true,
"text": "<s>Hello\n"
},
{
"label": true,
"text": "hi there!. "
},
{
"label": false,
"text": "goodbye "
},
{
"label": true,
"text": "farewell</s>"
}
]
}
```

18
docs/mac.md Normal file

@@ -0,0 +1,18 @@
# Mac M series support
Axolotl on Mac is currently only partially usable: many of its dependencies, including PyTorch, do not support MPS or support it incompletely.
Current support:
- [x] Support for all models
- [x] Full training of models
- [x] LoRA training
- [x] Sample packing
- [ ] FP16 and BF16 (awaiting AMP support for MPS in PyTorch)
- [ ] Tri-dao's flash-attn (until it is supported, use sdp_attention as an alternative)
- [ ] xformers
- [ ] bitsandbytes (meaning no 4/8-bit loading and no bnb optimizers)
- [ ] qlora
- [ ] DeepSpeed
Untested:
- FSDP


@@ -21,8 +21,8 @@ lora_dropout: 0.05
lora_target_linear: true
sequence_len: 4096
sample_packing: false
pad_to_sequence_len: false
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:


@@ -0,0 +1,79 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./lora-out
eval_sample_packing: false
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 2
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
sdp_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:


@@ -0,0 +1,69 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:


@@ -0,0 +1,66 @@
base_model: stabilityai/stablelm-2-1_6b
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:


@@ -0,0 +1,36 @@
# StableLM 2
This repository contains examples for training and processing using StableLM-2. It also includes a section to help you estimate the GPU requirements for your specific use case.
## Estimating GPU Requirements
| type          | deepspeed | batch size | context length | GPU VRAM (GB) |
|---------------|-----------|------------|----------------|---------------|
| full finetune | N/A       | 1          | 4096           | ~21.5         |
| full finetune | zero2     | 1          | 4096           | ~20           |
| lora          | N/A       | 1          | 4096           | ~16.6         |
The above are estimates and might differ slightly depending on the setup, for example whether or not you pack your sequences (the above assumes you do, to length 4096).
This blog post from Hamel Husain was a great resource for estimating these numbers: https://hamel.dev/notes/llm/03_estimating_vram.html
## Training
We have example scripts here for both full finetuning and lora using the popular alpaca dataset:
```shell
# preprocess the dataset
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/stablelm-2/1.6b/lora.yml
```
Single GPU Training:
```shell
python -m axolotl.cli.train examples/stablelm-2/1.6b/fft.yml --deepspeed deepspeed_configs/zero2.json
# OR
python -m axolotl.cli.train examples/stablelm-2/1.6b/lora.yml
```
Multinode GPU Training with `accelerate`:
```shell
# make sure you've configured accelerate properly
accelerate launch -m axolotl.cli.train examples/stablelm-2/1.6b/fft.yml --deepspeed deepspeed_configs/zero2.json
```


@@ -0,0 +1,69 @@
base_model: bigcode/starcoder2-3b
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.2
output_dir: ./qlora
adapter: qlora
lora_model_dir:
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 20
evals_per_epoch: 4
eval_steps:
eval_table_size:
saves_per_epoch: 4
save_steps:
save_total_limit: 2
debug:
deepspeed:
weight_decay:
fsdp:
fsdp_config:
special_tokens:


@@ -15,6 +15,7 @@ output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true
adapter: lora


@@ -1,11 +1,12 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
peft==0.9.0
transformers==4.38.2
tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.26.1
deepspeed>=0.13.1
deepspeed==0.13.1
pydantic==2.6.3
addict
fire
PyYAML>=6.0
@@ -21,14 +22,13 @@ hf_transfer
colorama
numba
numpy>=1.24.4
mlflow
# qlora things
evaluate==0.4.1
scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.34
fschat==0.2.36
gradio==3.50.2
tensorboard


@@ -18,6 +18,7 @@ def parse_requirements():
or "flash-attention" in line
or "deepspeed" in line
or "mamba-ssm" in line
or "lion-pytorch" in line
)
if line.startswith("--extra-index-url"):
# Handle custom index URLs
@@ -73,7 +74,7 @@ setup(
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
],
"deepspeed": [
"deepspeed>=0.13.1",
"deepspeed==0.13.1",
"deepspeed-kernels",
],
"mamba-ssm": [
@@ -82,5 +83,11 @@ setup(
"auto-gptq": [
"auto-gptq==0.5.1",
],
"mlflow": [
"mlflow",
],
"lion-pytorch": [
"lion-pytorch==0.1.2",
],
},
)


@@ -13,7 +13,6 @@ from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import gradio as gr
import requests
import torch
import yaml
@@ -24,6 +23,7 @@ from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -214,6 +214,8 @@ def do_inference_gradio(
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
@@ -328,7 +330,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
@@ -341,7 +342,22 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
else:
cfg[k] = kwargs[k]
validate_config(cfg)
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": os.environ.get("WORLD_SIZE", 1),
"compute_capability": gpu_version,
},
)
prepare_optim_env(cfg)


@@ -5,6 +5,7 @@ Builder for the training args and trainer
import abc
import importlib
import importlib.util
import logging
import math
import sys
@@ -26,15 +27,16 @@ from transformers import (
TrainingArguments,
)
from transformers.trainer_utils import seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoMlflowCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
@@ -54,6 +56,9 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_warmup_decay_constant,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
@@ -62,6 +67,10 @@ except ImportError:
LOG = logging.getLogger("axolotl.core.trainer_builder")
def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
@@ -175,6 +184,13 @@ class AxolotlTrainingArguments(TrainingArguments):
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
class AxolotlTrainer(Trainer):
@@ -199,6 +215,33 @@ class AxolotlTrainer(Trainer):
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
@@ -648,7 +691,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_mlflow:
if self.cfg.use_mlflow and is_mlflow_available():
from axolotl.utils.callbacks.mlflow_ import (
SaveAxolotlConfigtoMlflowCallback,
)
callbacks.append(
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
)
@@ -694,7 +741,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlTrainer
def build(self, total_num_steps):
warmup_steps = None
if self.cfg.warmup_steps is not None:
warmup_steps = self.cfg.warmup_steps
elif self.cfg.warmup_ratio is not None:
@@ -907,6 +953,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
@@ -962,18 +1012,42 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"neftune_noise_alpha"
] = self.cfg.neftune_noise_alpha
trainer_kwargs = {}
if self.cfg.optimizer == "lion_pytorch":
from lion_pytorch import Lion
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
if "weight_decay" in training_arguments_kwargs:
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
if (
"adam_beta1" in training_arguments_kwargs
and "adam_beta2" in training_arguments_kwargs
):
lion_kwargs["betas"] = (
training_arguments_kwargs["adam_beta1"],
training_arguments_kwargs["adam_beta2"],
)
trainer_kwargs["optimizers"] = (
Lion(params=self.model.parameters(), **lion_kwargs),
None,
)
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
trainer_kwargs = {}
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
data_collator_kwargs = {
"padding": True, # True/"longest" is the default

133
src/axolotl/loraplus.py Normal file

@@ -0,0 +1,133 @@
"""Module for LoRA+"""
# MIT License
#
# Copyright (c) 2024 nikhil-ghosh-berkeley
# https://github.com/nikhil-ghosh-berkeley/loraplus
import logging
from functools import reduce
from peft.tuners import lora
from torch import nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names
LOG = logging.getLogger("axolotl.loraplus")
def get_module(name, opt_model):
"""
Retrieve a module from a model using its parameter name.
Args:
name (str): Full name of the parameter, typically including module path.
opt_model (torch.nn.Module): The model from which to retrieve the module.
Returns:
Module corresponding to the given name.
"""
parent_idx = 2 if "lora" in name else 1
module_names = name.split(sep=".")[:-parent_idx]
module = reduce(getattr, module_names, opt_model)
return module
def create_loraplus_optimizer(
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding=None,
):
"""
Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.
Args:
opt_model (torch.nn.Module): The model for which the optimizer is being created.
optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.
Returns:
An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
"""
assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."
if loraplus_lr_embedding is None:
loraplus_lr_embedding = 1e-6
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
param_groups = {
"groupA": {},
"groupB": {},
"groupB_no_decay": {},
"embedding": {},
}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
module = get_module(name, opt_model)
if isinstance(module, lora.Embedding):
param_groups["embedding"][name] = param
elif "lora_B" in name or param.ndim == 1:
if name in decay_parameters:
param_groups["groupB"][name] = param
else:
param_groups["groupB_no_decay"][name] = param
else:
param_groups["groupA"][name] = param
assigned_param_groups = ""
for group, group_params in param_groups.items():
assigned_param_groups += f"{group}\n {list(group_params.keys())}\n\n"
LOG.info(assigned_param_groups)
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
optimizer_grouped_parameters = [
{
"params": list(param_groups["groupA"].values()),
"weight_decay": weight_decay,
"lr": lr,
},
{
"params": list(param_groups["embedding"].values()),
"weight_decay": weight_decay,
"lr": loraplus_lr_embedding,
},
{
"params": list(param_groups["groupB"].values()),
"weight_decay": weight_decay,
"lr": lr * loraplus_lr_ratio,
},
{
"params": list(param_groups["groupB_no_decay"].values()),
"weight_decay": 0.0,
"lr": lr * loraplus_lr_ratio,
},
]
optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{p.data_ptr(): p.numel() for p in module.parameters()}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
return optimizer


@@ -106,7 +106,7 @@ def get_turns( # pylint: disable=too-many-return-statements
if self.system_message:
contains_sys_msg = True
if self.messages:
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction separated by a newline
first_role, first_msg = self.messages[0]
if first_role == self.roles[0]:
system_prompt = self.system_template.format(


@@ -44,6 +44,18 @@ except ImportError:
LOG = logging.getLogger("axolotl")
def is_xformers_swiglu_available() -> bool:
from xformers.ops.common import get_xformers_operator
try:
get_xformers_operator("swiglu_packedw")()
return True
except RuntimeError as exc:
if "No such operator xformers::swiglu_packedw " in str(exc):
return False
return True
def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):


@@ -6,7 +6,14 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"falcon",
"phi",
"gemma",
"starcoder2",
]
def patch_for_multipack(model_type):
@@ -32,3 +39,7 @@ def patch_for_multipack(model_type):
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)


@@ -267,7 +267,7 @@ class ReLoRAScheduler(LRScheduler):
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
if step < self.relora_steps - self.warmup_steps:
scale = 1
else:
per_relora_progress = step % self.relora_steps


@@ -0,0 +1,78 @@
"""
HF Chat Templates prompt strategy
"""
from typing import Any, Dict, Optional
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import Prompter
from axolotl.utils.chat_templates import chat_templates
class ChatTemplatePrompter(Prompter):
"""prompter for HF chat templates"""
def __init__(self, tokenizer, chat_template=None, max_length=2048):
self.tokenizer = tokenizer
self.chat_template = chat_template
self.max_length = max_length
def build_prompt(self, conversation, add_generation_prompt=False):
return self.tokenizer.apply_chat_template(
conversation,
truncation=True,
max_length=self.max_length,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
class ChatTemplateStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for instruction-based prompts.
"""
def tokenize_prompt(self, prompt):
turns = self.get_conversation_thread(prompt)
prompt_ids = self.prompter.build_prompt([turns[0]], add_generation_prompt=True)
input_ids = self.prompter.build_prompt(turns)
if not self.train_on_inputs:
user_prompt_len = len(prompt_ids)
labels = [-100] * user_prompt_len + input_ids[user_prompt_len:]
else:
labels = input_ids
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap roles - allow for assistant turn
role_map = {
"human": "user",
"user": "user",
"assistant": "assistant",
"gpt": "assistant",
}
turns = [
{"role": role_map[t["from"]], "content": t["value"]} for t in conversations
]
return turns
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
chat_template = (
ds_cfg["chat_template"] if ds_cfg and "chat_template" in ds_cfg else "chatml"
)
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(tokenizer, chat_templates(chat_template)),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy
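
Note on the label masking in `tokenize_prompt` above: the step can be illustrated without a tokenizer. The sketch below is a minimal stand-in, not the code in this diff. `mask_prompt_labels` is a hypothetical helper and the token ids are made up. It assumes, as the strategy does, that the tokenized first turn (with the generation prompt appended) is a prefix of the full tokenized conversation.

```python
IGNORE_INDEX = -100  # label value skipped by the loss (PyTorch cross-entropy default)


def mask_prompt_labels(prompt_ids, input_ids, train_on_inputs=False):
    """Return labels for input_ids, masking the leading prompt tokens.

    Hypothetical helper: mirrors the prefix-masking idea only.
    """
    if train_on_inputs:
        # Train on everything, prompt included.
        return list(input_ids)
    prompt_len = len(prompt_ids)
    return [IGNORE_INDEX] * prompt_len + list(input_ids[prompt_len:])


# Made-up token ids: the first four belong to the user turn.
prompt_ids = [1, 10, 11, 12]
input_ids = [1, 10, 11, 12, 20, 21, 2]
labels = mask_prompt_labels(prompt_ids, input_ids)
print(labels)
```

As written in the diff, only the first turn is used to compute the masked prefix, so in a multi-turn conversation later user turns appear to stay in the labels.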

View File

@@ -8,14 +8,13 @@ import logging
LOG = logging.getLogger("axolotl")
def load(strategy, cfg):
def load(strategy, cfg, **kwargs):
try:
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
func = getattr(mod, load_fn)
load_kwargs = {}
return func(cfg, **load_kwargs)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
LOG.warning(f"unable to load strategy {strategy}")
return None
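
Note on the loader above: it splits a dotted strategy string into a module path and a function name, imports the module, and now forwards extra keyword arguments to the function. A minimal sketch of that pattern follows. `load_strategy` is a hypothetical name, and it uses an absolute import of a stdlib module for the demo, whereas the diff imports relative to `axolotl.prompt_strategies.dpo`.

```python
import importlib


def load_strategy(dotted_path, cfg, **kwargs):
    """Resolve 'package.module.func' and call func(cfg, **kwargs).

    Hypothetical helper illustrating the lookup-and-forward pattern.
    """
    module_name, _, func_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.function', got {dotted_path!r}")
    module = importlib.import_module(module_name)
    func = getattr(module, func_name)
    # Extra keyword arguments (for example a dataset index) pass straight through.
    return func(cfg, **kwargs)


# Demo with a stdlib function standing in for a strategy.
result = load_strategy("json.dumps", {"b": 1, "a": 2}, sort_keys=True)
print(result)
```

Forwarding `**kwargs` is what lets every strategy function in the following hunks accept arguments such as `dataset_idx` without each one having to use them.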

View File

@@ -5,6 +5,7 @@ DPO strategies for chatml
def argilla(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
@@ -25,6 +26,7 @@ def argilla(
def icr(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
"""
chatml transforms for datasets with system, input, chosen, rejected
@@ -48,7 +50,7 @@ def icr(
return transform_fn
def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
For Intel Orca DPO Pairs
"""
@@ -70,7 +72,9 @@ def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
return transform_fn
def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def prompt_pairs(
cfg, **kwargs
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "system" in sample and sample["system"]:
sample["prompt"] = (
@@ -88,7 +92,7 @@ def prompt_pairs(cfg): # pylint: disable=possibly-unused-variable,unused-argume
return transform_fn
def ultra(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
"""
for ultrafeedback binarized conversations
"""

View File

@@ -0,0 +1,41 @@
"""
User-defined DPO strategies
"""
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
ds_cfg = cfg["datasets"][dataset_idx]["type"]
if not isinstance(ds_cfg, dict):
raise ValueError(
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
)
field_prompt = ds_cfg.get("field_prompt", "prompt")
field_system = ds_cfg.get("field_system", "system")
field_chosen = ds_cfg.get("field_chosen", "chosen")
field_rejected = ds_cfg.get("field_rejected", "rejected")
prompt_format = ds_cfg.get("prompt_format")
if not prompt_format:
prompt_format = "{" + field_prompt + "}"
chosen_format = ds_cfg.get("chosen_format")
if not chosen_format:
chosen_format = "{" + field_chosen + "}"
rejected_format = ds_cfg.get("rejected_format")
if not rejected_format:
rejected_format = "{" + field_rejected + "}"
def transform_fn(sample):
if (
"{" + field_system + "}" in prompt_format
and field_system in sample
and sample[field_system]
):
sample["prompt"] = prompt_format.format(
system=sample[field_system], prompt=sample[field_prompt]
)
else:
sample["prompt"] = prompt_format.format(prompt=sample[field_prompt])
sample["chosen"] = chosen_format.format(chosen=sample[field_chosen])
sample["rejected"] = rejected_format.format(rejected=sample[field_rejected])
return sample
return transform_fn
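
Note on the user-defined DPO strategy above: a dataset `type` given as a dictionary maps source columns to `prompt`, `chosen`, and `rejected` through format strings. The sketch below is a simplified re-implementation for illustration, not the code in this diff. `make_dpo_transform` is a hypothetical name, and it assumes the templates always use the fixed placeholders `{system}`, `{prompt}`, `{chosen}`, and `{rejected}` regardless of the column names. The column names and template text in the demo are invented.

```python
def make_dpo_transform(ds_cfg):
    """Build a per-row transform from a user-defined DPO dataset config."""
    field_prompt = ds_cfg.get("field_prompt", "prompt")
    field_system = ds_cfg.get("field_system", "system")
    field_chosen = ds_cfg.get("field_chosen", "chosen")
    field_rejected = ds_cfg.get("field_rejected", "rejected")

    # Assumption: placeholder names are fixed, whatever the columns are called.
    prompt_format = ds_cfg.get("prompt_format") or "{prompt}"
    chosen_format = ds_cfg.get("chosen_format") or "{chosen}"
    rejected_format = ds_cfg.get("rejected_format") or "{rejected}"

    def transform_fn(sample):
        out = dict(sample)  # copy so the input row is left untouched
        system = sample.get(field_system) or ""
        # str.format ignores unused keyword arguments, so passing `system`
        # is safe even when the template has no {system} placeholder.
        out["prompt"] = prompt_format.format(
            system=system, prompt=sample[field_prompt]
        )
        out["chosen"] = chosen_format.format(chosen=sample[field_chosen])
        out["rejected"] = rejected_format.format(rejected=sample[field_rejected])
        return out

    return transform_fn


# Invented column names and template, for illustration only.
ds_cfg = {
    "field_prompt": "question",
    "field_chosen": "good",
    "field_rejected": "bad",
    "prompt_format": "<|user|>\n{prompt}\n<|assistant|>\n",
    "chosen_format": "{chosen}</s>",
    "rejected_format": "{rejected}</s>",
}
row = {"question": "2+2?", "good": "4", "bad": "5"}
transformed = make_dpo_transform(ds_cfg)(row)
print(transformed["prompt"], transformed["chosen"], transformed["rejected"])
```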

View File

@@ -3,7 +3,7 @@ DPO strategies for zephyr
"""
def nectar(cfg): # pylint: disable=possibly-unused-variable,unused-argument
def nectar(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
data = {}
data["prompt"] = (

View File

@@ -0,0 +1,54 @@
"""Module for plain input/output prompt pairs"""
from typing import Generator, Tuple
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
class RawInputOutputStrategy(PromptTokenizingStrategy):
"""Prompt Strategy class for input/output pairs"""
def __init__(self, *args, eos_token=None, **kwargs):
super().__init__(*args, **kwargs)
self.eos_token = eos_token
if not eos_token:
self.eos_token = self.tokenizer.eos_token
def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
input_ids = []
labels = []
for label, text in self.prompter.build_prompt(prompt["segments"]):
tokenized_output = self.tokenizer(
text, add_special_tokens=False, return_tensors=None
)["input_ids"]
input_ids += tokenized_output
if label or self.train_on_inputs:
labels += tokenized_output
else:
labels += [IGNORE_TOKEN_ID] * len(tokenized_output)
tokenized_prompt = {
"input_ids": input_ids,
"labels": labels,
"attention_mask": [1] * len(input_ids),
}
return tokenized_prompt
class RawInputOutputPrompter(Prompter):
"""prompter for raw i/o data"""
def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
for segment in source:
yield segment["label"], segment["text"]
def load(tokenizer, cfg):
return RawInputOutputStrategy(
RawInputOutputPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
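
Note on the input/output strategy above: each row carries a `segments` list, and every segment has a `text` plus a boolean `label` saying whether its tokens are trained on. The sketch below shows that masking with a toy tokenizer. `toy_tokenize` and `build_example` are hypothetical names, the character-code "tokenizer" is a stand-in for a real one, and `IGNORE_TOKEN_ID = -100` is assumed to match the constant imported in the diff.

```python
IGNORE_TOKEN_ID = -100  # assumed value of the constant imported in the diff


def toy_tokenize(text):
    """Stand-in tokenizer: one token id per character."""
    return [ord(ch) for ch in text]


def build_example(segments, train_on_inputs=False):
    """Concatenate segments, masking labels of segments with label=False."""
    input_ids, labels = [], []
    for segment in segments:
        ids = toy_tokenize(segment["text"])
        input_ids += ids
        if segment["label"] or train_on_inputs:
            labels += ids
        else:
            labels += [IGNORE_TOKEN_ID] * len(ids)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": [1] * len(input_ids),
    }


row = {
    "segments": [
        {"label": False, "text": "Q:"},  # prompt text, masked out
        {"label": True, "text": "A"},    # completion text, trained on
    ]
}
example = build_example(row["segments"])
print(example)
```

Because masking is per segment, a row can alternate masked and unmasked spans freely, which is what makes multi-turn data possible without a chat template.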

View File

@@ -82,7 +82,7 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
basic sharegpt strategy to grab conversations from the sample row
"""
_strict = True
_strict = False
@property
def strict(self):
@@ -96,10 +96,25 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
conversations = prompt["conversations"]
if self.strict:
return conversations
# remap roles - allow for assistant turn
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
role_key = "from"
if "role" in conversations[0].keys():
role_key = "role"
value_key = "value"
if "text" in conversations[0].keys():
value_key = "text"
elif "content" in conversations[0].keys():
value_key = "content"
# remap roles - allow for assistant turn
role_map = {
"user": "human",
"human": "human",
"assistant": "gpt",
"gpt": "gpt",
"system": "system",
}
turns = [
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
{"from": role_map[t[role_key]], "value": t[value_key]}
for t in conversations
]
return turns
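
Note on the non-strict path above: it detects which keys a conversation uses (`from` or `role`, and `value`, `text`, or `content`) from the first turn, then rewrites every turn to the ShareGPT `from`/`value` shape. A self-contained sketch of that normalization follows. `normalize_turns` is a hypothetical name, and the role map is copied from the hunk.

```python
ROLE_MAP = {
    "user": "human",
    "human": "human",
    "assistant": "gpt",
    "gpt": "gpt",
    "system": "system",
}


def normalize_turns(conversations):
    """Rewrite OpenAI-style or ShareGPT-style turns to from/value dicts."""
    first = conversations[0]
    # Key detection looks at the first turn only, as in the diff.
    role_key = "role" if "role" in first else "from"
    if "text" in first:
        value_key = "text"
    elif "content" in first:
        value_key = "content"
    else:
        value_key = "value"
    return [
        {"from": ROLE_MAP[turn[role_key]], "value": turn[value_key]}
        for turn in conversations
    ]


openai_style = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello"},
]
turns = normalize_turns(openai_style)
print(turns)
```

An unknown role raises `KeyError` from the map lookup, both here and in the hunk, so datasets with other role names would need the map extended.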

View File

@@ -11,7 +11,7 @@ import torch
import transformers.modelcard
from accelerate.logging import get_logger
from datasets import Dataset
from peft import PeftModel
from peft import PeftModel, PeftModelForCausalLM
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
@@ -207,6 +207,20 @@ def train(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if cfg.adapter and isinstance(model, (PeftModel, PeftModelForCausalLM)):
model.to("cpu")
model = model.merge_and_unload()
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
if not cfg.hub_model_id:
try:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))

View File

@@ -9,7 +9,6 @@ from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Dict, List
import evaluate
import mlflow
import numpy as np
import pandas as pd
import torch
@@ -42,8 +41,8 @@ from axolotl.utils.distributed import (
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100
LOG = logging.getLogger("axolotl.callbacks")
class EvalFirstStepCallback(
@@ -756,31 +755,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
"""Callback to save axolotl config to mlflow"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
return control

View File

@@ -0,0 +1,44 @@
"""MLFlow module for trainer callbacks"""
import logging
from shutil import copyfile
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING
import mlflow
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
# pylint: disable=duplicate-code
"""Callback to save axolotl config to mlflow"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments", # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
with NamedTemporaryFile(
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
) as temp_file:
copyfile(self.axolotl_config_path, temp_file.name)
mlflow.log_artifact(temp_file.name, artifact_path="")
LOG.info(
"The Axolotl config has been saved to the MLflow artifacts."
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
return control

View File

@@ -22,6 +22,7 @@ def chat_templates(user_choice: str):
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
}
if user_choice in templates:

View File

@@ -3,11 +3,16 @@ import json
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import (
AxolotlConfigWCapabilities,
AxolotlInputConfig,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config
@@ -119,7 +124,7 @@ def normalize_config(cfg):
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.model_type and "llama" in cfg.model_type.lower())
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
)
# figure out if the model is falcon
@@ -135,7 +140,7 @@ def normalize_config(cfg):
)
or cfg.is_falcon_derived_model
or "falcon" in cfg.base_model.lower()
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
)
cfg.is_mistral_derived_model = (
@@ -148,7 +153,7 @@ def normalize_config(cfg):
)
or cfg.is_mistral_derived_model
or "mistral" in cfg.base_model.lower().split("/")[-1]
or (cfg.model_type and "mistral" in cfg.model_type.lower())
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
)
cfg.is_qwen_derived_model = (
@@ -159,9 +164,6 @@ def normalize_config(cfg):
]
) or cfg.is_qwen_derived_model
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)
if isinstance(cfg.pretraining_dataset, dict):
cfg.pretraining_dataset = [cfg.pretraining_dataset]
@@ -191,7 +193,21 @@ def normalize_cfg_datasets(cfg):
cfg.datasets[idx].conversation = "chatml"
def validate_config(cfg):
def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
if capabilities:
return DictDefault(
dict(
AxolotlConfigWCapabilities(
**cfg.to_dict(), capabilities=capabilities
).model_dump(exclude_unset=True)
)
)
return DictDefault(
dict(AxolotlInputConfig(**cfg.to_dict()).model_dump(exclude_unset=True))
)
def legacy_validate_config(cfg):
"""
This is a "pre-validation" step that handles the yaml configuration before we have any
information about the model architecture
@@ -363,11 +379,11 @@ def validate_config(cfg):
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
if cfg.gptq and cfg.model_revision:
if cfg.gptq and cfg.revision_of_model:
raise ValueError(
"model_revision is not supported for GPTQ models. "
"revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
+ "point to its path, and remove revision_of_model from the config."
)
# if cfg.sample_packing and cfg.sdp_attention:
@@ -480,9 +496,6 @@ def validate_config(cfg):
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be a key under `model_config`")
if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id

View File

@@ -0,0 +1,991 @@
"""
Module for pydantic models for configuration
"""
import logging
import os
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from axolotl.utils.config.models.internals import GPUCapabilities
LOG = logging.getLogger("axolotl.utils.config.models.input")
class DeprecatedParameters(BaseModel):
"""configurations that are deprecated"""
max_packed_sequence_len: Optional[int] = None
rope_scaling: Optional[Any] = None
noisy_embedding_alpha: Optional[float] = None
@field_validator("max_packed_sequence_len")
@classmethod
def validate_max_packed_sequence_len(cls, max_packed_sequence_len):
if max_packed_sequence_len:
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
return max_packed_sequence_len
@field_validator("rope_scaling")
@classmethod
def validate_rope_scaling(cls, rope_scaling):
if rope_scaling:
raise DeprecationWarning(
"`rope_scaling` is no longer supported, it should now be a key under `model_config`"
)
return rope_scaling
@field_validator("noisy_embedding_alpha")
@classmethod
def validate_noisy_embedding_alpha(cls, noisy_embedding_alpha):
if noisy_embedding_alpha:
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
return noisy_embedding_alpha
class RemappedParameters(BaseModel):
"""parameters that have been remapped to other names"""
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
default=None, alias="model_config"
)
type_of_model: Optional[str] = Field(default=None, alias="model_type")
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
class PretrainingDataset(BaseModel):
"""pretraining dataset configuration subset"""
path: Optional[str] = None
class UserDefinedPrompterType(BaseModel):
"""structure for user defined prompt types"""
system_prompt: Optional[str] = None
system_format: Optional[str] = None
field_system: Optional[str] = None
field_instruction: Optional[str] = None
field_input: Optional[str] = None
field_output: Optional[str] = None
format: Optional[str] = None
no_input_format: Optional[str] = None
field: Optional[str] = None
class SFTDataset(BaseModel):
"""SFT configuration subset"""
path: Optional[str] = None
split: Optional[str] = None
type: Optional[Union[str, UserDefinedPrompterType]] = None
shards: Optional[int] = None
conversation: Optional[str] = None
chat_template: Optional[str] = None
data_files: Optional[Union[str, List[str]]] = None
name: Optional[str] = None
ds_type: Optional[str] = None
train_on_split: Optional[str] = None
field_human: Optional[str] = None
field_model: Optional[str] = None
class UserDefinedDPOType(BaseModel):
"""User defined typing for DPO"""
field_system: Optional[str] = None
field_prompt: Optional[str] = None
field_chosen: Optional[str] = None
field_rejected: Optional[str] = None
prompt_format: Optional[str] = None
chosen_format: Optional[str] = None
rejected_format: Optional[str] = None
class DPODataset(BaseModel):
"""DPO configuration subset"""
path: Optional[str] = None
split: Optional[str] = None
type: Optional[Union[UserDefinedDPOType, str]] = None
data_files: Optional[List[str]] = None
class RLType(str, Enum):
"""RL trainer type configuration subset"""
dpo = "dpo" # pylint: disable=invalid-name
ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name
class ChatTemplate(str, Enum):
"""Chat templates configuration subset"""
chatml = "chatml" # pylint: disable=invalid-name
inst = "inst" # pylint: disable=invalid-name
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
# loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
class PeftConfig(BaseModel):
"""peft configuration subset"""
loftq_config: Optional[LoftQConfig] = None
class AutoType(str, Enum):
"""auto type string configuration subset - used for bf16"""
AUTO = "auto"
class SpecialTokensConfig(BaseModel):
"""Special tokens configuration subset"""
bos_token: Optional[str] = None
eos_token: Optional[str] = None
pad_token: Optional[str] = None
unk_token: Optional[str] = None
additional_special_tokens: Optional[List[str]] = None
class LoraConfig(BaseModel):
"""Peft / LoRA configuration subset"""
load_in_8bit: Optional[bool] = Field(default=False)
load_in_4bit: Optional[bool] = Field(default=False)
adapter: Optional[str] = None
lora_model_dir: Optional[str] = None
lora_r: Optional[int] = None
lora_alpha: Optional[int] = None
lora_fan_in_fan_out: Optional[bool] = None
lora_target_modules: Optional[List[str]] = None
lora_target_linear: Optional[bool] = None
lora_modules_to_save: Optional[List[str]] = None
lora_dropout: Optional[float] = None
peft_layers_to_transform: Optional[List[int]] = None
peft: Optional[PeftConfig] = None
peft_use_dora: Optional[bool] = None
lora_on_cpu: Optional[bool] = None
gptq: Optional[bool] = None
bnb_config_kwargs: Optional[Dict[str, Any]] = None
loraplus_lr_ratio: Optional[float] = Field(
default=None,
metadata={
"help": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
},
)
loraplus_lr_embedding: Optional[float] = Field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
merge_lora: Optional[bool] = None
@model_validator(mode="before")
@classmethod
def validate_adapter(cls, data):
if not data.get("adapter") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter. "
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
return data
@model_validator(mode="after")
def validate_qlora(self):
if self.adapter == "qlora":
if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if self.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if self.gptq:
raise ValueError("Can't merge qlora if gptq")
if self.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if self.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.gptq:
raise ValueError("Can't load qlora if gptq")
if not self.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
return self
@model_validator(mode="before")
@classmethod
def validate_quantized_dora(cls, data):
if data.get("peft_use_dora") and (
data.get("load_in_8bit") or data.get("load_in_4bit")
):
raise ValueError(
"`peft_use_dora` is not currently compatible with quantized weights."
)
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
relora_steps: Optional[int] = None
relora_warmup_steps: Optional[int] = None
relora_anneal_steps: Optional[int] = None
relora_prune_ratio: Optional[float] = None
relora_cpu_offload: Optional[bool] = None
class ModelInputConfig(BaseModel):
"""model to train on configuration subset"""
base_model: str
base_model_config: Optional[str] = None
tokenizer_config: Optional[str] = None
tokenizer_use_fast: Optional[bool] = None
tokenizer_legacy: Optional[bool] = None
tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"}
)
trust_remote_code: Optional[bool] = None
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
if trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
)
return trust_remote_code
class HyperparametersConfig(BaseModel):
"""training hyperparams configuration subset"""
gradient_accumulation_steps: Optional[int] = Field(default=1)
micro_batch_size: Optional[int] = Field(
default=1,
metadata={"help": "per gpu micro batch size for training"},
)
batch_size: Optional[int] = Field(
default=None,
metadata={
"help": "Total batch size, we do not recommend setting this manually"
},
)
eval_batch_size: Optional[int] = Field(
default=None,
metadata={
"help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
},
)
train_on_inputs: Optional[bool] = None
group_by_length: Optional[bool] = None
learning_rate: Union[str, float]
weight_decay: Optional[float] = None
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[SchedulerType] = None
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
lr_quadratic_warmup: Optional[bool] = None
cosine_min_lr_ratio: Optional[float] = None
cosine_constant_lr_ratio: Optional[float] = None
lr_div_factor: Optional[float] = None
adam_epsilon: Optional[float] = None
adam_beta1: Optional[float] = None
adam_beta2: Optional[float] = None
max_grad_norm: Optional[float] = None
num_epochs: int = Field(default=1)
@field_validator("batch_size")
@classmethod
def hint_batch_size_set(cls, batch_size):
if batch_size:
LOG.warning(
"%s\n%s",
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
)
return batch_size
@field_validator("learning_rate")
@classmethod
def convert_learning_rate(cls, learning_rate):
if learning_rate and isinstance(learning_rate, str):
learning_rate = float(learning_rate)
return learning_rate
class ModelOutputConfig(BaseModel):
"""model save configuration subset"""
output_dir: str = Field(default="./model-out")
hub_model_id: Optional[str] = None
hub_strategy: Optional[str] = None
save_safetensors: Optional[bool] = None
class MLFlowConfig(BaseModel):
"""mlflow configuration subset"""
use_mlflow: Optional[bool] = None
mlflow_tracking_uri: Optional[str] = None
mlflow_experiment_name: Optional[str] = None
hf_mlflow_log_artifacts: Optional[bool] = None
class WandbConfig(BaseModel):
"""wandb configuration subset"""
use_wandb: Optional[bool] = None
wandb_name: Optional[str] = None
wandb_run_id: Optional[str] = None
wandb_mode: Optional[str] = None
wandb_project: Optional[str] = None
wandb_entity: Optional[str] = None
wandb_watch: Optional[str] = None
wandb_log_model: Optional[str] = None
@model_validator(mode="before")
@classmethod
def check_wandb_run(cls, data):
if data.get("wandb_run_id") and not data.get("wandb_name"):
data["wandb_name"] = data.get("wandb_run_id")
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
return data
# pylint: disable=too-many-public-methods,too-many-ancestors
class AxolotlInputConfig(
ModelInputConfig,
ModelOutputConfig,
LoraConfig,
ReLoRAConfig,
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
):
"""wrapper of all config options"""
class Config:
"""Config for alias"""
populate_by_name = True
strict: Optional[bool] = Field(default=False)
resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None
resize_token_embeddings_to_32x: Optional[bool] = None
rl: Optional[RLType] = None
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
dataset_prepared_path: Optional[str] = None
dataset_shard_num: Optional[int] = None
dataset_shard_idx: Optional[int] = None
pretraining_dataset: Optional[ # type: ignore
conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
] = Field(
default=None, metadata={"help": "streaming dataset to use for pretraining"}
)
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_keep_in_memory: Optional[bool] = None
dataloader_pin_memory: Optional[bool] = None
dataloader_num_workers: Optional[int] = None
dataloader_prefetch_factor: Optional[int] = None
dataloader_drop_last: Optional[bool] = None
push_dataset_to_hub: Optional[str] = None
hf_use_auth_token: Optional[bool] = None
device: Optional[Any] = None
device_map: Optional[Any] = None
world_size: Optional[int] = None
local_rank: Optional[int] = None
ddp: Optional[bool] = None
seed: Optional[int] = None
ddp_timeout: Optional[int] = None
ddp_bucket_cap_mb: Optional[int] = None
ddp_broadcast_buffers: Optional[bool] = None
ddp_find_unused_parameters: Optional[bool] = None
eval_table_size: Optional[int] = None
eval_max_new_tokens: Optional[int] = None
do_causal_lm_eval: Optional[bool] = None
eval_causal_lm_metrics: Optional[List[str]] = None
do_bench_eval: Optional[bool] = None
bench_dataset: Optional[str] = None
metric_for_best_model: Optional[str] = None
greater_is_better: Optional[bool] = None
loss_watchdog_threshold: Optional[float] = None
loss_watchdog_patience: Optional[int] = None
bf16: Optional[Union[AutoType, bool]] = AutoType.AUTO
fp16: Optional[bool] = None
bfloat16: Optional[bool] = None # for non-AMP cases
float16: Optional[bool] = None # for non-AMP cases
tf32: Optional[bool] = None
float32: Optional[bool] = None
# torch_dtype: Optional[torch.dtype]
gradient_checkpointing: Optional[bool] = Field(default=False)
gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
unfrozen_parameters: Optional[List[str]] = None
sequence_len: int = Field(default=1024)
sample_packing: Optional[bool] = None
eval_sample_packing: Optional[bool] = None
pad_to_sequence_len: Optional[bool] = None
xformers_attention: Optional[bool] = None
sdp_attention: Optional[bool] = None
s2_attention: Optional[bool] = None
flash_attention: Optional[bool] = None
flash_attn_cross_entropy: Optional[bool] = None
flash_attn_rms_norm: Optional[bool] = None
flash_attn_fuse_qkv: Optional[bool] = None
flash_attn_fuse_mlp: Optional[bool] = None
flash_optimum: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
fsdp_config: Optional[Dict[str, Any]] = None
val_set_size: Optional[float] = Field(default=0.0)
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None
torch_compile: Optional[bool] = None
torch_compile_backend: Optional[str] = None
max_steps: Optional[int] = None
warmup_steps: Optional[int] = None
warmup_ratio: Optional[float] = None
eval_steps: Optional[Union[int, float]] = None
evals_per_epoch: Optional[int] = None
evaluation_strategy: Optional[str] = None
save_steps: Optional[Union[int, float]] = None
saves_per_epoch: Optional[int] = None
save_strategy: Optional[str] = None
save_total_limit: Optional[int] = None
logging_steps: Optional[int] = None
early_stopping_patience: Optional[int] = None
load_best_model_at_end: Optional[bool] = False
neftune_noise_alpha: Optional[float] = None
max_memory: Optional[Union[int, str]] = None
gpu_memory_limit: Optional[Union[int, str]] = None
chat_template: Optional[Union[Literal["chatml", "inst"], ChatTemplate]] = None
default_system_message: Optional[str] = None
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
total_num_tokens: Optional[int] = None
total_supervised_tokens: Optional[int] = None
sample_packing_eff_est: Optional[float] = None
axolotl_config_path: Optional[str] = None
is_falcon_derived_model: Optional[bool] = Field(default=False)
is_llama_derived_model: Optional[bool] = Field(default=False)
is_mistral_derived_model: Optional[bool] = Field(default=False)
is_qwen_derived_model: Optional[bool] = Field(default=False)
@field_validator("datasets", mode="before")
@classmethod
def fix_sharegpt_datasets(cls, datasets):
for idx, ds_cfg in enumerate(datasets):
if not ds_cfg["type"]:
continue
if ds_cfg["type"] == "sharegpt:chat":
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
datasets[idx]["type"] = "sharegpt"
if "sharegpt_simple" in ds_cfg["type"]:
LOG.warning(
PendingDeprecationWarning(
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
)
)
datasets[idx]["type"] = datasets[idx]["type"].replace(
"sharegpt_simple", "sharegpt"
)
return datasets
@model_validator(mode="before")
@classmethod
def check_batch_size_fields(cls, data):
fields = ("micro_batch_size", "gradient_accumulation_steps", "batch_size")
non_empty_count = sum(1 for field in fields if data.get(field))
if non_empty_count < 2:
raise ValueError(f"At least two of {', '.join(fields)} must be set")
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_w_max_steps(cls, data):
if data.get("pretraining_dataset") and not data.get("max_steps"):
raise ValueError(
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
)
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_w_group_by_length(cls, data):
if data.get("pretraining_dataset") and data.get("group_by_length"):
LOG.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely."
)
return data
@model_validator(mode="before")
@classmethod
def check_gptq_w_revision(cls, data):
if data.get("gptq") and data.get("revision_of_model"):
raise ValueError(
"revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove revision_of_model from the config."
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_xformers(cls, data):
if data.get("sample_packing") and data.get("xformers_attention"):
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_rl(cls, data):
if data.get("sample_packing") and data.get("rl"):
raise ValueError("`sample_packing: true` does not work with RLHF training")
return data
@model_validator(mode="before")
@classmethod
def hint_sample_packing_padding(cls, data):
if data.get("sample_packing") and not data.get("pad_to_sequence_len"):
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
return data
@model_validator(mode="before")
@classmethod
def check_gas_bsz(cls, data):
if data.get("gradient_accumulation_steps") and data.get("batch_size"):
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
)
return data
@model_validator(mode="before")
@classmethod
def hint_eval_train_mbsz(cls, data):
if (
data.get("eval_batch_size")
and data.get("micro_batch_size")
and data.get("eval_batch_size") != data.get("micro_batch_size")
):
LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
)
return data
@model_validator(mode="before")
@classmethod
def check_push_ds_auth(cls, data):
if (
data.get("push_dataset_to_hub")
and data.get("hf_use_auth_token") is not True
):
raise ValueError(
"Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
)
return data
@model_validator(mode="after")
def check_falcon_fsdp(self):
if (self.base_model and "falcon" in self.base_model.lower()) and self.fsdp:
raise ValueError("FSDP is not supported for falcon models")
return self
@model_validator(mode="after")
def check_mpt_checkpointing(self):
if (
self.base_model and "mpt" in self.base_model.lower()
) and self.gradient_checkpointing:
raise ValueError("gradient_checkpointing is not supported for MPT models")
return self
@model_validator(mode="after")
def check_better_transformers(self):
if self.flash_optimum is True:
if self.adapter:
LOG.warning(
"BetterTransformers probably doesn't work with PEFT adapters"
)
if self.fp16 or self.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if self.float16 is not True and self.bfloat16 is not True:
LOG.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
)
return self
@model_validator(mode="after")
def check_adamw_optimizer_params(self):
if any([self.adam_beta1, self.adam_beta2, self.adam_epsilon]) and (
not self.optimizer or "adamw" not in self.optimizer.value
):
LOG.warning("adamw hyperparameters found, but no adamw optimizer set")
return self
@model_validator(mode="before")
@classmethod
def check_saves(cls, data):
if (
data.get("save_strategy")
and data.get("save_steps")
and data.get("save_strategy") != "steps"
):
raise ValueError(
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
)
if data.get("saves_per_epoch") and data.get("save_steps"):
raise ValueError(
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
)
return data
@model_validator(mode="before")
@classmethod
def check_push_save(cls, data):
if data.get("hub_model_id") and not (
data.get("save_steps") or data.get("saves_per_epoch")
):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
return data
@model_validator(mode="before")
@classmethod
def check_evals(cls, data):
if (
data.get("evaluation_strategy")
and data.get("eval_steps")
and data.get("evaluation_strategy") != "steps"
):
raise ValueError(
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if (
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("evaluation_strategy"))
and not data.get("test_datasets")
):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
if data.get("evals_per_epoch") and data.get("eval_steps"):
raise ValueError(
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
)
if (
data.get("evals_per_epoch")
and data.get("evaluation_strategy")
and data.get("evaluation_strategy") != "steps"
):
raise ValueError(
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
)
return data
@model_validator(mode="before")
@classmethod
def check_eval_packing(cls, data):
if (
data.get("sample_packing")
and data.get("eval_table_size")
and data.get("eval_sample_packing") is not False
):
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
return data
@model_validator(mode="before")
@classmethod
def check_warmup(cls, data):
if data.get("warmup_steps") and data.get("warmup_ratio"):
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
return data
@model_validator(mode="before")
@classmethod
def check_neftune(cls, data):
if data.get("noisy_embedding_alpha") and not data.get("neftune_noise_alpha"):
data["neftune_noise_alpha"] = data["noisy_embedding_alpha"]
del data["noisy_embedding_alpha"]
elif data.get("noisy_embedding_alpha") and data.get("neftune_noise_alpha"):
raise ValueError(
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
)
return data
@field_validator("neftune_noise_alpha")
@classmethod
def validate_neftune_noise_alpha(cls, neftune_noise_alpha):
if neftune_noise_alpha is not None and neftune_noise_alpha <= 0.0:
raise ValueError("neftune_noise_alpha must be > 0.0")
return neftune_noise_alpha
@model_validator(mode="before")
@classmethod
def check_frozen(cls, data):
if (
data.get("adapter")
and data.get("peft_layers_to_transform")
and data.get("unfrozen_parameters")
):
raise ValueError(
"`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
)
return data
@model_validator(mode="after")
def check_fft_possible_bad_config(self):
if (
# pylint: disable=too-many-boolean-expressions
not (self.bf16 or self.bfloat16)
and (self.fp16 or self.float16)
and not self.adapter
and not self.flash_attention
and self.sample_packing
):
LOG.warning(
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
)
# ValueError: Attempting to unscale FP16 gradients.
# OR
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
return self
@model_validator(mode="after")
def check_fused_lora(self):
if self.adapter in ["lora", "qlora"] and (
self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp
):
raise ValueError("Fused modules are not supported with LoRA/QLoRA")
return self
@model_validator(mode="after")
def hint_lora_8bit(self):
loftq = (
self.peft and self.peft.loftq_config and self.peft.loftq_config.loftq_bits
)
if not self.load_in_8bit and self.adapter == "lora" and not loftq:
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
return self
@model_validator(mode="after")
def check_early_stopping(self):
if self.early_stopping_patience:
if not self.save_steps or not self.eval_steps:
raise ValueError(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
)
if self.save_steps % self.eval_steps != 0:
raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
return self
@model_validator(mode="after")
def check_relora(self):
if self.relora_steps:
if self.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
if self.fsdp:
raise ValueError("fsdp not supported with ReLoRA")
if self.deepspeed:
raise ValueError("deepspeed not supported with ReLoRA")
if self.lr_scheduler == "one_cycle":
raise ValueError(
"ReLoRA is not compatible with the one_cycle scheduler"
)
if self.flash_attn_fuse_qkv or self.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
return self
@model_validator(mode="before")
@classmethod
def check_mem_mismatch(cls, data):
if (
data.get("max_memory") is not None
and data.get("gpu_memory_limit") is not None
):
raise ValueError(
"max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
)
return data
@model_validator(mode="before")
@classmethod
def check_use_reentrant_mismatch(cls, data):
if (
data.get("unfrozen_parameters")
and data.get("gradient_checkpointing_kwargs")
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is True
):
# https://github.com/huggingface/transformers/issues/21381
raise ValueError(
"`use_reentrant` must be false when used with partially frozen model."
)
return data
@model_validator(mode="before")
@classmethod
def check_val_w_test_datasets(cls, data):
if data.get("test_datasets") and data.get("val_set_size"):
raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration"
)
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_w_8bit_optimizer(cls, data):
if data.get("fsdp") and "bnb" in (data.get("optimizer") or ""):
raise ValueError(f"FSDP not compatible with {data.get('optimizer')}")
return data
@model_validator(mode="before")
@classmethod
def check_causal_lm_evals(cls, data):
if data.get("do_causal_lm_eval") and data.get("eval_sample_packing"):
raise ValueError(
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
)
if data.get("eval_causal_lm_metrics"):
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
if not isinstance(data.get("eval_causal_lm_metrics"), list):
raise ValueError("eval_causal_lm_metrics must be a list")
# only ["sacrebleu", "comet", "ter", "chrf"] supported
if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
raise ValueError(
f"eval_causal_lm_metrics must be one of {supported_metrics}"
)
return data
@model_validator(mode="before")
@classmethod
def check_dataset_or_pretraining_dataset(cls, data):
if data.get("datasets") is None and data.get("pretraining_dataset") is None:
raise ValueError("either datasets or pretraining_dataset is required")
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to validate gpu capabilities with the configured options"""
capabilities: GPUCapabilities
@model_validator(mode="after")
def check_bf16(self):
if self.capabilities.bf16:
if not self.bf16 and not self.bfloat16:
LOG.info(
"bf16 support detected, but not enabled for this configuration."
)
else:
if (
not self.merge_lora
and not self.is_preprocess
and (self.bf16 is True or self.bfloat16 is True)
):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
return self
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_sdpa_bf16(cls, data):
is_sm_90: bool = (
data["capabilities"]
and data["capabilities"].get("compute_capability") == "sm_90"
)
if (
data.get("sample_packing")
and data.get("sdp_attention")
and (data.get("bfloat16") or data.get("bf16"))
and not is_sm_90
):
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
LOG.warning(
"sample_packing & torch sdpa with bf16 is unsupported and may result in 0.0 loss. "
"This may work on H100s."
)
return data
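
The deprecation handling in `check_neftune` above can be sketched outside pydantic as a plain function over a config dict. This is only an illustration: `migrate_neftune_alpha` is a hypothetical name, and the both-keys-set branch follows the wording of the error message rather than anything else confirmed by the diff.

```python
def migrate_neftune_alpha(data: dict) -> dict:
    """Hypothetical stand-alone version of the check_neftune logic."""
    old = data.get("noisy_embedding_alpha")
    new = data.get("neftune_noise_alpha")
    if old and not new:
        # Only the deprecated key is set: carry its value over to the new key.
        data["neftune_noise_alpha"] = old
        del data["noisy_embedding_alpha"]
    elif old and new:
        # Both keys are set: refuse to guess which value the user meant.
        raise ValueError(
            "noisy_embedding_alpha is deprecated, use neftune_noise_alpha"
        )
    return data


migrated = migrate_neftune_alpha({"noisy_embedding_alpha": 5})
```

A config that sets only the deprecated key is rewritten in place, while a config that sets both keys is rejected.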


@@ -0,0 +1,14 @@
"""module for gpu capabilities"""
from typing import Optional
from pydantic import BaseModel, Field
class GPUCapabilities(BaseModel):
"""model to manage the gpu capabilities statically"""
bf16: bool = Field(default=False)
fp8: bool = Field(default=False)
n_gpu: int = Field(default=1)
n_node: int = Field(default=1)
compute_capability: Optional[str] = Field(default=None)


@@ -114,7 +114,9 @@ def prepare_dataset(cfg, tokenizer):
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
raise ValueError(
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
"eval dataset split is too small for sample_packing. "
"You should set `eval_sample_packing: False` "
"or decrease the value of `eval_batch_size`. "
)
if cfg.max_steps:
@@ -937,7 +939,9 @@ def load_prepare_dpo_datasets(cfg):
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
if _type:
ds_transform_fn = load_dpo(_type, _cfg)
if isinstance(_type, DictDefault):
_type = "user_defined.default"
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
split_datasets[i] = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",


@@ -12,4 +12,4 @@ class DictDefault(Dict):
return None
def __or__(self, other):
return DictDefault(super().__or__(other))
return DictDefault(super().__ror__(other))
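
Read together with the updated `test_dict_or_operator` test, the switch from `__or__` to `__ror__` appears to make values from the left-hand operand win on key collisions. A minimal sketch of that behaviour, using a hypothetical `LeftWinsDict` built on the standard `dict` (Python 3.9+) instead of the real `DictDefault`:

```python
class LeftWinsDict(dict):
    """Hypothetical stand-in for DictDefault; only the | behaviour is modelled."""

    def __or__(self, other):
        # dict.__ror__(self, other) evaluates `other | self`, so keys already
        # present in `self` override the same keys coming from `other`.
        return LeftWinsDict(super().__ror__(other))


base = LeftWinsDict({"key_a": "from_left", "key_f": "only_left"})
merged = base | LeftWinsDict({"key_a": "from_right", "key_c": "only_right"})
```

Here `merged["key_a"]` keeps the left-hand value, whereas the built-in `dict` union would take the right-hand one.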


@@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault
def setup_mlflow_env_vars(cfg: DictDefault):
for key in cfg.keys():
if key.startswith("mlflow_"):
if key.startswith("mlflow_") or key.startswith("hf_mlflow_"):
value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:


@@ -86,8 +86,8 @@ def load_model_config(cfg):
model_config_name = cfg.tokenizer_config
trust_remote_code = cfg.trust_remote_code is True
config_kwargs = {}
if cfg.model_revision:
config_kwargs["revision"] = cfg.model_revision
if cfg.revision_of_model:
config_kwargs["revision"] = cfg.revision_of_model
try:
model_config = AutoConfig.from_pretrained(
@@ -104,8 +104,8 @@ def load_model_config(cfg):
)
raise err
if cfg.model_config:
for key, val in cfg.model_config.items():
if cfg.overrides_of_model_config:
for key, val in cfg.overrides_of_model_config.items():
setattr(model_config, key, val)
check_model_config(cfg, model_config)
@@ -272,7 +272,7 @@ def load_model(
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
model_type = cfg.model_type
model_type = cfg.type_of_model
model_config = load_model_config(cfg)
# TODO refactor as a kwarg
@@ -426,8 +426,8 @@ def load_model(
if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.revision_of_model:
model_kwargs["revision"] = cfg.revision_of_model
if cfg.gptq:
if not hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information")
@@ -512,11 +512,12 @@ def load_model(
if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
is_xformers_swiglu_available,
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
if cfg.flash_attn_fuse_mlp:
if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(model)
@@ -829,6 +830,8 @@ def load_lora(model, cfg, inference=False, config_only=False):
if loftq_bits:
lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
lora_config_kwargs["init_lora_weights"] = "loftq"
if cfg.peft_use_dora:
lora_config_kwargs["use_dora"] = cfg.peft_use_dora
lora_config = LoraConfig(
r=cfg.lora_r,


@@ -5,7 +5,7 @@ Multipack Batch Sampler
import logging
import math
import os
from typing import Any, Iterable, List, Union
from typing import Any, Iterable, List, Union, Optional
import numba
import numpy as np
@@ -115,12 +115,14 @@ class MultipackBatchSampler(BatchSampler):
batch_max_len: int,
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
consistent_length: Optional[bool] = False,
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.consistent_length = consistent_length
assert isinstance(self.lengths, np.ndarray)
@@ -164,11 +166,18 @@ class MultipackBatchSampler(BatchSampler):
def __iter__(self):
batches = self.generate_batches(set_stats=True)
return iter(batches)
if self.consistent_length:
length = self._len_est()
return iter(batches[:length])
else:
return iter(batches)
def num_batches(self):
batches = self.generate_batches(set_stats=True)
return len(batches)
if self.consistent_length:
return self._len_est()
else:
return len(batches)
def efficiency(self):
return self.eff_total_used / self.eff_total_slots


@@ -255,7 +255,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = len(data_loader) // batch_size
data_loader_len = len(data_loader) // cfg.batch_size
actual_eff = sampler.efficiency()
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends
@@ -277,7 +277,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
calc_sample_packing_eff_est,
)
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
math.ceil(sample_packing_actual_eff_all * 10000.0) / 10000.0
)
if update:
cfg.sample_packing_eff_est = sample_packing_eff_est
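
The change from `100.0` to `10000.0` in the hunk above alters how coarsely the efficiency estimate is rounded up: two decimal places before, four after. A small sketch with a made-up efficiency value (the `round_up` helper is hypothetical and not part of the codebase):

```python
import math


def round_up(value: float, scale: float) -> float:
    """Round `value` up to the next multiple of 1/scale."""
    return math.ceil(value * scale) / scale


measured = 0.97312  # made-up packing efficiency, for illustration only
coarse = round_up(measured, 100.0)    # old behaviour: two decimal places
fine = round_up(measured, 10000.0)    # new behaviour: four decimal places
```

With this input the coarse estimate is 0.98 and the fine estimate is 0.9732, so the finer rounding overstates the measured efficiency by far less.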


@@ -57,9 +57,9 @@ class TestFusedLlama(unittest.TestCase):
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"max_steps": 10,
"save_steps": 5,
"eval_steps": 5,
}
)
if is_torch_bf16_gpu_available():


@@ -43,7 +43,7 @@ class TestLoraLlama(unittest.TestCase):
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"val_set_size": 0.2,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",


@@ -7,6 +7,8 @@ import os
import unittest
from pathlib import Path
import pytest
from axolotl.cli import load_rl_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip(reason="doesn't seem to work on modal")
class TestDPOLlamaLora(unittest.TestCase):
"""
Test case for DPO Llama models using LoRA


@@ -7,6 +7,8 @@ import os
import unittest
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -19,6 +21,7 @@ LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@pytest.mark.skip(reason="doesn't seem to work on modal")
class TestPhi(unittest.TestCase):
"""
Test case for Phi2 models


@@ -0,0 +1,116 @@
"""
Test module for raw i/o data for prompts
"""
import pytest
from datasets import Dataset
from tokenizers import AddedToken
from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.input_output import (
RawInputOutputPrompter,
RawInputOutputStrategy,
)
@pytest.fixture(name="segments_dataset")
def fixture_sharegpt_dataset():
return Dataset.from_list(
[
{
"segments": [
{
"label": False,
"text": "<s>hello ",
},
{
"label": True,
"text": "hi there.<eot>",
},
{
"label": False,
"text": "goodbye ",
},
{
"label": True,
"text": "farewell<eot>",
},
]
}
]
)
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.add_tokens(
[
AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),
]
)
return tokenizer
class TestRawInputOutputPrompts:
"""
Test class for raw i/o prompter
"""
def test_segment_prompts(self, segments_dataset, tokenizer):
strategy = RawInputOutputStrategy(
RawInputOutputPrompter(),
tokenizer,
False, # train_on_inputs
2048, # sequence_len
)
dataset_wrapper = TokenizedPromptDataset(
strategy, segments_dataset, process_count=1
)
input_ids = dataset_wrapper[0]["input_ids"]
labels = dataset_wrapper[0]["labels"]
assert (
tokenizer.decode(input_ids)
== "<s> hello hi there.<eot> goodbye farewell<eot>"
)
# fmt: off
assert input_ids == [
1, # <s>
6312, # hell
28709, # o
28705, #
12014, # hi
736, # there
28723, # .
32000, # <eot>
1179, # good
17664, # bye
28705, #
19111, # fare
5458, # well
32000, # <eot>
]
# fmt: on
# fmt: off
assert labels == [
-100, # <s>
-100, # hell
-100, # o
-100, #
12014, # hi
736, # there
28723, # .
32000, # <eot>
-100, # good
-100, # bye
-100, #
19111, # fare
5458, # well
32000, # <eot>
]
# fmt: on
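
The expected `labels` in the test above follow a simple rule: tokens from segments with `label: False` are replaced by `-100` (the index that PyTorch's cross-entropy loss ignores by default), while tokens from `label: True` segments keep their ids. A rough sketch of that rule, assuming a toy character-level encoder instead of the real tokenizer; `build_labels` and `toy_encode` are hypothetical names, and the actual `RawInputOutputStrategy` implementation is not shown in this diff:

```python
IGNORE_INDEX = -100  # label value skipped by the loss


def toy_encode(text):
    """Toy encoder for illustration: one id per character."""
    return [ord(char) for char in text]


def build_labels(segments, encode):
    """Concatenate segment ids and mask segments whose label is False."""
    input_ids, labels = [], []
    for segment in segments:
        ids = encode(segment["text"])
        input_ids.extend(ids)
        if segment["label"]:
            labels.extend(ids)  # trained on: keep the token ids
        else:
            labels.extend([IGNORE_INDEX] * len(ids))  # masked out
    return input_ids, labels


segments = [
    {"label": False, "text": "ab"},
    {"label": True, "text": "cd"},
]
input_ids, labels = build_labels(segments, toy_encode)
```

The input ids cover both segments, while the labels mask the first segment and repeat the ids of the second.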


@@ -39,7 +39,9 @@ class DictDefaultTest(unittest.TestCase):
), "DictDefault should support in operator for existing keys in list"
def test_dict_or_operator(self):
cfg = DictDefault(
cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{
"key_a": {"key_b": "value_a"},
"key_c": "value_c",
@@ -48,10 +50,6 @@ class DictDefaultTest(unittest.TestCase):
}
)
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
{"key_a": {"key_b": "value_b"}, "key_f": "value_g"}
)
assert (
cfg.key_a.key_b == "value_b"
), "DictDefault should support OR operator for existing nested keys"


@@ -25,20 +25,6 @@ class NormalizeConfigTestCase(unittest.TestCase):
}
)
def test_lr_as_float(self):
cfg = (
self._get_base_cfg()
| DictDefault( # pylint: disable=unsupported-binary-operation
{
"learning_rate": "5e-5",
}
)
)
normalize_config(cfg)
assert cfg.learning_rate == 0.00005
def test_base_model_config_set_when_empty(self):
cfg = self._get_base_cfg()
del cfg.base_model_config


@@ -204,13 +204,13 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
# fmt: off
# System message, multi-turn conversations
mt_ids = tokenize(test_data['multi_turn_sys'])
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
assert mt_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
# System message, single-turn conversations
st_ids = tokenize(test_data['single_turn_sys'])
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
assert st_ids == [1, 518, 25580, 29962, 29871, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
# No system message, single-turn
ns_ids = tokenize(test_data['single_turn_no_sys'])

File diff suppressed because it is too large