Compare commits

...

61 Commits

Author SHA1 Message Date
Wing Lian
ca476d7f8e don't load the actual model when pre-loading to load modeling code 2023-09-20 13:37:32 -04:00
Javier
ec0958f4f8 Update requirements.txt (#610) 2023-09-20 08:40:49 -04:00
Wing Lian
faecff9798 support to disable exllama for gptq (#604)
* support to disable exllama for gptq

* update property instead of item

* fix config key
2023-09-19 17:51:08 -04:00
bofeng huang
aa656e04bd Delete duplicate lines (#606) 2023-09-19 16:40:05 -04:00
Wing Lian
b53e77775b update dockerfile to not build evoformer since it fails the build (#607) 2023-09-19 16:28:29 -04:00
Wing Lian
674c57692d more sane defaults for openllama 3b used for quickstarts (#602)
* more sane defaults for openllama 3b used for quickstarts

* don't use bf16 for quickstart to simplify gpu compatibility

* use the update openlm-research/open_llama_3b_v2 models
2023-09-19 09:15:10 -04:00
Wing Lian
1eebbd09c3 improve handling for empty text on the tokenization step (#502) 2023-09-19 08:09:56 -04:00
Wing Lian
62a774140b Fix for check with cfg and merge_lora (#600) 2023-09-18 21:14:32 -04:00
Wing Lian
31b9e0c6e8 minor tweaks to simplify (#597) 2023-09-18 11:45:44 -04:00
Wing Lian
6b9b229356 btlm and falcon monkey patches for flash attn (#566) 2023-09-17 13:49:18 -04:00
Wing Lian
131afdbd89 add bf16 check (#587) 2023-09-17 13:49:03 -04:00
NanoCode012
00dce35fb2 Feat(data): Allow loading local csv and text (#594)
* Feat(data): Allow loading local csv and text

* chore: update readme for loading data
2023-09-17 11:32:27 -04:00
Wing Lian
b15b19eb8d gather/broadcast the max value of the packing efficiency automatically (#463) 2023-09-17 11:08:18 -04:00
Wing Lian
ab534d75ba don't add position_ids for evals (#591) 2023-09-16 16:11:57 -04:00
Wing Lian
21ec195c9f optionally configure sample packing for evals (#589) 2023-09-16 00:09:48 -04:00
Wing Lian
62eaee7649 make phi training work with Loras (#588)
* valdiation for phi loras

* fix model config class check

* update readme for phi traiing
2023-09-15 20:51:55 -04:00
Jan Philipp Harries
be75668400 set fsdp state dict (#584)
Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
2023-09-15 17:47:36 -04:00
Wing Lian
aeec7c4688 pop block_cls since it's not an actual kwarg 2023-09-15 15:54:06 -04:00
Wing Lian
360788296a don't resize embeddings if it's already large enough (#577)
* don't resize embeddings if it's already large enough

* make sure to tie weights, even if we aren't resizing
2023-09-15 15:47:09 -04:00
Wing Lian
12a2dbbc2c Support Sample packing for phi arch (#586)
* phi sequence packing

* sample packing fixes

* fix linting

* fix inference and phi e2e tests

* update phi example now that sample packing works

* wandb import keeps getting moved around
2023-09-15 15:46:54 -04:00
NanoCode012
3a2edc85c3 Feat(doc): Add features to doc (#583) 2023-09-16 01:14:15 +09:00
Wing Lian
f7a22632d7 support custom field for completion from yml (#580)
* support custom field for completion from yml

* remove legacy completion check and add doc

* update README docs
2023-09-15 07:48:21 -04:00
Doan Minh Phuong
1aa400721e Fix Codellama examples (#582)
* Fix seq_len

* Update lora.yml

* Update qlora.yml

* Update lora.yml

* Update lora.yml

* Update qlora.yml
2023-09-15 04:19:13 -04:00
Wing Lian
8dcd40ac78 prevent cli functions from getting fired on import (#581) 2023-09-15 04:03:32 -04:00
Wing Lian
a5a625f47e update support matrix with btlm and phi (#579) 2023-09-15 02:46:15 -04:00
Wing Lian
861cecac2a refactor scripts/finetune.py into new cli modules (#550)
* refactor scripts/finetune.py into new cli modules

* continue to support scripts/finetune.py

* update readme with updated cli commands

* Update scripts/finetune.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-09-15 01:43:52 -04:00
Wing Lian
1078d3eae7 E2e passing tests (#576)
* run e2e tests after all other checks have passed

* tweak tests so they get run on PRs or push to main

* change dependent action for chcecking

* one test workflow to rule them all

* no need for custom action, just use needs

* whoops, python version should be a string

* e2e tests can run on any available gpu
2023-09-15 01:03:49 -04:00
Wing Lian
24146733db E2e device cuda (#575)
* use torch.cuda.current_device() instead of local_rank

* ignore NVML errors for gpu stats

* llama lora packing e2e tests
2023-09-14 22:49:27 -04:00
Wing Lian
9218ebecd2 e2e testing (#574) 2023-09-14 21:56:11 -04:00
Wing Lian
228420972e Phi examples (#569)
* add phi full ft example

* Add readme to point out that deepspeed should be used

* zero1 is better than zero2 for phi
2023-09-14 11:17:47 -04:00
Wing Lian
c6d870b91d mypy wandb ignore (#572)
* mypy wandb ignore

* fix isort for wandb
2023-09-14 11:17:30 -04:00
Wing Lian
115795079d remove columns after tokenizing for pretraining (#571) 2023-09-14 11:08:22 -04:00
Wing Lian
3b18c963cc set auto for other params that hf trainer sets for ds. include zero1 json (#570) 2023-09-14 11:04:37 -04:00
Wing Lian
3fbde762ab fix save_steps so it doesn't get duplicated (#567) 2023-09-13 20:40:33 -04:00
Wing Lian
f6060a664e Model parallel (#538)
* model-parallel for single process

* fix device/device_map

* fix handling for device
2023-09-13 11:45:30 -04:00
Wing Lian
a4e1bb6606 let hf trainer handle torch compile (#516)
* let hf trainer handle torch compile

* remove torch compile checks, include option for backend

* suppress torch errors to get further

* require min torch version of 2.1.0 for torch compile to work

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
2023-09-13 11:42:12 -04:00
Wing Lian
36e53c7442 improve how we setup eval/save strategies and steps (#547)
* setup save end eval strategies to be consistent with trainer logic

* add comments

* better eval handling
2023-09-13 11:37:23 -04:00
Wing Lian
e7aa7b1a1e gracefully handle length feature used for group by (#565) 2023-09-13 11:23:30 -04:00
Wing Lian
e5bb22a56b add optimization for group-by-len (#563) 2023-09-13 10:57:12 -04:00
Wing Lian
fdb777bc06 check for the existence of the default accelerate config that can create headaches (#561) 2023-09-13 10:38:28 -04:00
Wing Lian
bf0804447c fix wandb so mypy doesn't complain (#562)
* fix wandb so mypy doesn't complain

* fix wandb so mypy doesn't complain

* no need for mypy override anymore
2023-09-13 10:36:16 -04:00
Glavin Wiechert
5b67ea98a6 Add training callback to send predictions to WandB table (#521)
* WIP Add training callback to send predictions to WandB table

* WIP improve wandb table reporting callback

* WIP improve wandb table reporting callback (cont)

* Add VSCode launching for debugging

* Add tiny llama example

* WIP attempt to improve post-eval prediction generation for table

* WIP attempt to improve post-eval prediction generation for table - part 2

* WIP batch generation

* WIP attempt to handle sample_packing using position_ids for wandb prediction table

* WIP add code for debugging

* Fix sample_packing support for wandb prediction table

* Clean up code for PR review

* Add eval_table_size, eval_table_max_new_tokens configs & clean up code

* Clean up PR, delete VSCode config, add tiny-llama example

* Add eval_table_size, eval_table_max_new_tokens documentation. Fix linting/formatting
2023-09-13 09:51:08 -04:00
Jan Philipp Harries
2f586d18db Fix pretraining with iterable/streaming Dataset (#556)
* return without packing prep/len

* fix remove columns

* fix encode arguments

* add error when max steps not set

* fix test

---------

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
2023-09-13 00:16:40 -04:00
Wing Lian
9845c5e12d document that packaging needs to be installed before flash-attn (#559) 2023-09-12 12:18:30 -04:00
Wing Lian
772cd870d4 fix the sed command to replace the version w the tag
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
publish pypi / Upload release to PyPI (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-09-11 13:44:19 -04:00
Wing Lian
6c5fbe6223 add long_description for pypi push (#555) 2023-09-11 13:34:29 -04:00
Wing Lian
bcbc9597e9 replace tags, build dist for pypi publish (#553)
* replace tags, build dist for pypi publish

* missing trailing comma
2023-09-11 13:25:41 -04:00
The Objective Dad
6d57f2f0f0 ergonomic update to optimizer config doc (#548) 2023-09-11 12:35:45 -04:00
Wing Lian
20ed4c1f9e pypi on tag push (#552) 2023-09-11 10:33:42 -04:00
Wing Lian
c5dedb17ad remove with section, doesn't seem to work (#551) 2023-09-11 10:27:17 -04:00
Wing Lian
b56503d423 publish to pypi workflow on tagged release (#549) 2023-09-11 09:44:47 -04:00
Wing Lian
a94f9cb99e fix for quant config from model (#540) 2023-09-10 12:40:52 -04:00
dongxiaolong
c1921c9acb Update requirements.txt (#543)
fix fsdp
2023-09-08 16:07:11 -04:00
Wing Lian
0b4cf5bc8c workaround for md5 variations (#533)
* workaround for md5 variations

* refactor the prepared hash too
2023-09-08 16:01:05 -04:00
SlapDrone
78ee2cdab2 add git environment variables to compose: avoid checkout failure error 128 on build (#534) 2023-09-08 15:59:49 -04:00
Wing Lian
34c0a86a11 update readme to point to direct link to runpod template, cleanup install instrucitons (#532)
* update readme to point to direct link to runpod template, cleanup install instrucitons

* default install flash-attn and auto-gptq now too

* update readme w flash-attn extra

* fix version in setup
2023-09-08 11:58:54 -04:00
The Objective Dad
5e2d8a42d9 Adding NCCL Timeout Guide (#536)
* fixes NCCL_P2P_LEVEL=NVL #429

* adding more insights into verious values of NCCL_P2P_LEVEL
2023-09-08 11:57:47 -04:00
Wing Lian
e30f1e3cf7 Early stopping metric (#537)
* set early stopping metric to check

* tweak how load_best_model_at_end gets set for early stopping

* add validation for earl;y stopping patience

* remove negation

* save results to metrics in callback

* move early stopping callback after the benchmark evals

* broadcast metrics so early stopping works
2023-09-08 11:57:02 -04:00
Wing Lian
343714972b recommend padding when using sample packing (#531) 2023-09-06 17:00:21 -04:00
Wing Lian
245c5c41e2 log rank too (#527) 2023-09-06 08:37:51 -04:00
Wing Lian
a546ca2813 misc fixes/improvements (#513)
fix per pr feedback
2023-09-05 16:40:13 -04:00
66 changed files with 3193 additions and 482 deletions

View File

@@ -23,7 +23,7 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
runs-on: self-hosted runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3
@@ -68,7 +68,7 @@ jobs:
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
runs-on: self-hosted runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3

View File

@@ -1,16 +0,0 @@
name: pre-commit
on:
pull_request:
push:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0

45
.github/workflows/pypi.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
name: publish pypi
on:
push:
tags:
- '*'
jobs:
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/axolotl
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
pip3 install wheel
pip3 install -e .
pip3 install -r requirements-tests.txt
- name: Extract tag name
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in setup.py
run: >-
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a binary wheel
run: >-
python setup.py sdist bdist_wheel
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

View File

@@ -1,10 +1,26 @@
name: PyTest name: Tests
on: on:
# check on push/merge to main, PRs, and manual triggers
push: push:
branches:
- "main"
pull_request: pull_request:
workflow_dispatch:
jobs: jobs:
test: pre-commit:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
pytest:
name: PyTest
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false
@@ -24,9 +40,35 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip install -e . pip3 install -e .
pip install -r requirements-tests.txt pip3 install -r requirements-tests.txt
- name: Run tests - name: Run tests
run: | run: |
pytest tests/ pytest --ignore=tests/e2e/ tests/
e2e-test:
name: E2E Tests
runs-on: [self-hosted, gpu]
timeout-minutes: 20
needs: [pre-commit, pytest]
steps:
- name: Check out repository code
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
# cache: 'pip' # caching pip dependencies
- name: Install dependencies
run: |
pip3 install -e .
pip3 install flash-attn
pip3 install -r requirements-tests.txt
- name: Run e2e tests
run: |
pytest tests/e2e/

View File

@@ -8,6 +8,9 @@ ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*] [mypy-axolotl.monkeypatch.*]
ignore_errors = True ignore_errors = True
[mypy-axolotl.models.phi.*]
ignore_errors = True
[mypy-flash_attn.*] [mypy-flash_attn.*]
ignore_missing_imports = True ignore_missing_imports = True
@@ -20,6 +23,9 @@ ignore_missing_imports = True
[mypy-peft] [mypy-peft]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-wandb]
ignore_missing_imports = True
[mypy-bitsandbytes] [mypy-bitsandbytes]
ignore_missing_imports = True ignore_missing_imports = True

122
README.md
View File

@@ -2,6 +2,18 @@
Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures. Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.
Features:
- Train various Huggingface models such as llama, pythia, falcon, mpt
- Supports fullfinetune, lora, qlora, relora, and gptq
- Customize configurations using a simple yaml file or CLI overwrite
- Load different dataset formats, use custom formats, or bring your own tokenized datasets
- Integrated with xformer, flash attention, rope scaling, and multipacking
- Works with single GPU or multiple GPUs via FSDP or Deepspeed
- Easily run with Docker locally or on the cloud
- Log results and optionally checkpoints to wandb
- And more!
<table> <table>
<tr> <tr>
<td> <td>
@@ -51,14 +63,16 @@ Axolotl is a tool designed to streamline the fine-tuning of various AI models, o
## Axolotl supports ## Axolotl supports
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn | | | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|----------|:----------|:-----|-------|------|-------------------|------------|---------------| |----------|:----------|:-----|-------|------|-------------------|------------|--------------|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ | | Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ | | cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
| mpt | ✅ | | | ❌ | ❌ | ❌ | ❓ | | btlm | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| falcon | ✅ | | | ❌ | ❌ | ❌ | ❓ | | mpt | ✅ | | | ❌ | ❌ | ❌ | ❓ |
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | | ❓ | | falcon | ✅ | ✅ | ✅ | ❌ | ❌ | | ❓ |
| XGen | ✅ | | ✅ | | | ❓ | | | gpt-j | ✅ | | ✅ | | | ❓ | |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
## Quickstart ⚡ ## Quickstart ⚡
@@ -71,15 +85,16 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
git clone https://github.com/OpenAccess-AI-Collective/axolotl git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl cd axolotl
pip3 install packaging
pip3 install -e .[flash-attn] pip3 install -e .[flash-attn]
pip3 install -U git+https://github.com/huggingface/peft.git pip3 install -U git+https://github.com/huggingface/peft.git
# finetune lora # finetune lora
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference # inference
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--inference --lora_model_dir="./lora-out" --lora_model_dir="./lora-out"
``` ```
## Installation ## Installation
@@ -90,8 +105,7 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
```bash ```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
``` ```
- `winglian/axolotl-runpod:main-py3.10-cu118-2.0.1`: for runpod - `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.1-gptq`: for gptq
Or run on the current files for development: Or run on the current files for development:
@@ -104,19 +118,10 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
2. Install pytorch stable https://pytorch.org/get-started/locally/ 2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install python dependencies with ONE of the following: 3. Install axolotl along with python dependencies
- Recommended, supports QLoRA, NO gptq/int4 support
```bash ```bash
pip3 install -e . pip3 install packaging
pip3 install -U git+https://github.com/huggingface/peft.git pip3 install -e .[flash-attn]
```
- gptq/int4 support, NO QLoRA
```bash
pip3 install -e .[gptq]
```
- same as above but not recommended
```bash
pip3 install -e .[gptq_triton]
``` ```
- LambdaLabs - LambdaLabs
@@ -151,10 +156,10 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
git clone https://github.com/OpenAccess-AI-Collective/axolotl git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl cd axolotl
pip3 install -e . # change depend on needs pip3 install packaging
pip3 install -e .[flash-attn]
pip3 install protobuf==3.20.3 pip3 install protobuf==3.20.3
pip3 install -U --ignore-installed requests Pillow psutil scipy pip3 install -U --ignore-installed requests Pillow psutil scipy
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
``` ```
5. Set path 5. Set path
@@ -329,6 +334,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: EleutherAI/pile - path: EleutherAI/pile
name: enron_emails name: enron_emails
type: completion # format from earlier type: completion # format from earlier
field: text # Optional[str] default: text, field to use for completion data
# huggingface repo with multiple named configurations/subsets # huggingface repo with multiple named configurations/subsets
datasets: datasets:
@@ -428,10 +434,10 @@ datasets:
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn> type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file
data_files: # path to source data files data_files: # Optional[str] path to source data files
shards: # number of shards to split data into shards: # Optional[int] number of shards to split data into
name: # name of dataset configuration to load name: # Optional[str] name of dataset configuration to load
# custom user prompt # custom user prompt
- path: repo - path: repo
@@ -451,6 +457,9 @@ datasets:
# 'no_input_format' cannot include {input} # 'no_input_format' cannot include {input}
no_input_format: "{instruction} " no_input_format: "{instruction} "
# for completions datsets, uses the provided field if not `text`
field:
# axolotl attempts to save the dataset as an arrow after packing the data together so # axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path # subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared dataset_prepared_path: data/last_run_prepared
@@ -528,6 +537,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
# where to save the finished model to # where to save the finished model to
output_dir: ./completed-model output_dir: ./completed-model
# whether to use torch.compile and which backend to use
torch_compile: # bool
torch_compile_backend: # Optional[str]
# training hyperparameters # training hyperparameters
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 2 micro_batch_size: 2
@@ -543,6 +556,9 @@ eval_steps: # leave empty to eval at each epoch
save_total_limit: # checkpoints saved at a time save_total_limit: # checkpoints saved at a time
max_steps: max_steps:
eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
# save model as safetensors (require safetensors package) # save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
@@ -572,6 +588,30 @@ log_sweep_min_lr:
log_sweep_max_lr: log_sweep_max_lr:
# specify optimizer # specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
#
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
# in the examples/ for your model and fine-tuning use case.
#
# Valid values for 'optimizer' include:
# - adamw_hf
# - adamw_torch
# - adamw_torch_fused
# - adamw_torch_xla
# - adamw_apex_fused
# - adafactor
# - adamw_anyprecision
# - sgd
# - adagrad
# - adamw_bnb_8bit
# - lion_8bit
# - lion_32bit
# - paged_adamw_32bit
# - paged_adamw_8bit
# - paged_lion_32bit
# - paged_lion_8bit
optimizer: optimizer:
# specify weight decay # specify weight decay
weight_decay: weight_decay:
@@ -652,14 +692,14 @@ strict:
Run Run
```bash ```bash
accelerate launch scripts/finetune.py your_config.yml accelerate launch -m axolotl.cli.train your_config.yml
``` ```
#### Multi-GPU #### Multi-GPU
You can optionally pre-tokenize dataset with the following before finetuning: You can optionally pre-tokenize dataset with the following before finetuning:
```bash ```bash
CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
``` ```
##### Config ##### Config
@@ -698,16 +738,16 @@ Pass the appropriate flag to the train command:
- Pretrained LORA: - Pretrained LORA:
```bash ```bash
--inference --lora_model_dir="./lora-output-dir" python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
``` ```
- Full weights finetune: - Full weights finetune:
```bash ```bash
--inference --base_model="./completed-model" python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
``` ```
- Full weights finetune w/ a prompt from a text file: - Full weights finetune w/ a prompt from a text file:
```bash ```bash
cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \ cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
--base_model="./completed-model" --inference --prompter=None --load_in_8bit=True --base_model="./completed-model" --prompter=None --load_in_8bit=True
``` ```
### Merge LORA to base ### Merge LORA to base
@@ -715,13 +755,13 @@ Pass the appropriate flag to the train command:
Add below flag to train command above Add below flag to train command above
```bash ```bash
--merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
``` ```
If you run out of CUDA memory, you can try to merge in system RAM with If you run out of CUDA memory, you can try to merge in system RAM with
```bash ```bash
CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ... CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
``` ```
## Common Errors 🧰 ## Common Errors 🧰
@@ -752,6 +792,10 @@ Try to turn off xformers.
It's safe to ignore it. It's safe to ignore it.
> NCCL Timeouts during training
See the [NCCL](docs/nccl.md) guide.
## Need help? 🙋♂️ ## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you

39
deepspeed/zero1.json Normal file
View File

@@ -0,0 +1,39 @@
{
"zero_optimization": {
"stage": 1,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -23,11 +23,8 @@
"type": "AdamW", "type": "AdamW",
"params": { "params": {
"lr": "auto", "lr": "auto",
"betas": [ "betas": "auto",
0.9, "eps": "auto",
0.999
],
"eps": 1e-8,
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },

View File

@@ -36,7 +36,7 @@
"params": { "params": {
"lr": "auto", "lr": "auto",
"betas": "auto", "betas": "auto",
"eps": 1e-8, "eps": "auto",
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },

View File

@@ -9,6 +9,11 @@ services:
- ~/.cache/huggingface/:/root/.cache/huggingface/ - ~/.cache/huggingface/:/root/.cache/huggingface/
# set environment variables # set environment variables
environment: environment:
# Set environment variables
- GIT_AUTHOR_NAME=${GIT_AUTHOR_NAME}
- GIT_AUTHOR_EMAIL=${GIT_AUTHOR_EMAIL}
- GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME}
- GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL}
- WANDB_API_KEY=${WANDB_API_KEY} - WANDB_API_KEY=${WANDB_API_KEY}
deploy: deploy:
resources: resources:

View File

@@ -15,9 +15,9 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \ RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \ pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
else \ else \
pip install -e .[flash-attn,gptq]; \ pip install -e .[flash-attn]; \
fi fi
# fix so that git fetch/pull from remote works # fix so that git fetch/pull from remote works

View File

@@ -39,7 +39,7 @@ WORKDIR /workspace
RUN git clone https://github.com/microsoft/DeepSpeed.git && \ RUN git clone https://github.com/microsoft/DeepSpeed.git && \
cd DeepSpeed && \ cd DeepSpeed && \
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 python3 setup.py bdist_wheel MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
FROM base-builder AS bnb-builder FROM base-builder AS bnb-builder

46
docs/nccl.md Normal file
View File

@@ -0,0 +1,46 @@
# NCCL
NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out causing the training process to abort:
```text
Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.
```
Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you.
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
```shell
nvidia-smi nvlink --status
```
To force NCCL to use NVLink, simply set this in the environment:
```shell
export NCCL_P2P_LEVEL=NVL
```
If NVLink is not available in your environment there are other options for ``NCCL_P2P_LEVEL`` in the table below:
| NCCL_P2P_LEVEL | Description |
| -------------- | ----------- |
| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates vs to paths involving multiple bridges, but slower compared to direct GPU-to-GPU communication. |
| PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. |
| PHB | P2P data transfers occur over the PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (ex PIX, NVL) |
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
```shell
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
```
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
```shell
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
export TORCH_DISTRIBUTED_DEBUG=INFO
export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log
```
Finally, if you believe your training job needs more time you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value.

View File

@@ -0,0 +1,90 @@
base_model: cerebras/btlm-3b-8k-base
base_model_config: cerebras/btlm-3b-8k-base
model_type: AutoModelForCausalLM
tokenizer_type: GPT2Tokenizer
trust_remote_code: true
tokenizer_use_fast: true
tokenizer_legacy: true
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_prepared_run
val_set_size: 0.01
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
sample_packing: false
sample_packing_eff_est:
sample_packing_seq_len_multiplier:
total_num_tokens:
lora_r:
lora_alpha:
lora_dropout:
lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.000000001
max_grad_norm: 1.0
torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000085
train_on_inputs: true
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
sdp_attention:
flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 32
eval_steps:
save_steps:
save_total_limit:
debug:
deepspeed:
weight_decay: 0.1
special_tokens:
pad_token: "<|endoftext|>"
fsdp:
# - full_shard
# - auto_wrap
fsdp_config:
# fsdp_state_dict_type: FULL_STATE_DICT
# fsdp_transformer_layer_cls_to_wrap: BTLMBlock

View File

@@ -15,8 +15,9 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01 val_set_size: 0.01
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 100000 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:

View File

@@ -18,8 +18,9 @@ output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 100000 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16

View File

@@ -15,8 +15,9 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01 val_set_size: 0.01
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 100000 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:

View File

@@ -18,8 +18,9 @@ output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 100000 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16

View File

@@ -15,8 +15,9 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01 val_set_size: 0.01
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 100000 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:

View File

@@ -18,8 +18,9 @@ output_dir: ./qlora-out
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 100000 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16

View File

@@ -2,7 +2,7 @@ base_model: TheBloke/Llama-2-7B-GPTQ
base_model_config: TheBloke/Llama-2-7B-GPTQ base_model_config: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false is_llama_derived_model: false
gptq: true gptq: true
gptq_bits: 4 gptq_disable_exllama: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
tokenizer_use_fast: true tokenizer_use_fast: true
@@ -62,8 +62,6 @@ xformers_attention:
flash_attention: flash_attention:
sdp_attention: sdp_attention:
flash_optimum: flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 100 warmup_steps: 100
eval_steps: eval_steps:
save_steps: save_steps:

View File

@@ -17,6 +17,7 @@ output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -55,6 +56,8 @@ flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 20 eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -20,6 +20,7 @@ lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
@@ -57,6 +58,7 @@ flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 20 eval_steps: 20
eval_table_size: 5
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -20,6 +20,7 @@ lora_model_dir:
sequence_len: 4096 sequence_len: 4096
sample_packing: true sample_packing: true
pad_to_sequence_len: true
lora_r: 8 lora_r: 8
lora_alpha: 16 lora_alpha: 16

View File

@@ -0,0 +1,69 @@
base_model: PY007/TinyLlama-1.1B-step-50K-105b
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,5 +1,5 @@
base_model: openlm-research/open_llama_3b base_model: openlm-research/open_llama_3b_v2
base_model_config: openlm-research/open_llama_3b base_model_config: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: false load_in_8bit: false
@@ -13,8 +13,8 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
adapter: adapter:
lora_model_dir: lora_model_dir:
sequence_len: 256 sequence_len: 1024
max_packed_sequence_len: sample_packing: true
lora_r: lora_r:
lora_alpha: lora_alpha:
lora_dropout: lora_dropout:
@@ -29,11 +29,11 @@ wandb_log_model:
output_dir: ./openllama-out output_dir: ./openllama-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 3 num_epochs: 4
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.00001 learning_rate: 0.000003
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
float16: true float16: true
@@ -45,12 +45,12 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: true xformers_attention:
flash_attention: flash_attention: true
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 20
eval_steps: 50 eval_steps: 0.05
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:

View File

@@ -1,5 +1,5 @@
base_model: openlm-research/open_llama_3b base_model: openlm-research/open_llama_3b_v2
base_model_config: openlm-research/open_llama_3b base_model_config: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: true load_in_8bit: true
@@ -13,8 +13,8 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.02 val_set_size: 0.02
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
sequence_len: 256 sequence_len: 1024
max_packed_sequence_len: sample_packing: true
lora_r: 8 lora_r: 8
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.0 lora_dropout: 0.0
@@ -33,9 +33,9 @@ wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./lora-out output_dir: ./lora-out
batch_size: 16 gradient_accumulation_steps: 1
micro_batch_size: 4 micro_batch_size: 2
num_epochs: 3 num_epochs: 4
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
@@ -50,16 +50,16 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: true xformers_attention:
flash_attention: flash_attention: true
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 20
eval_steps: 50 eval_steps: 0.05
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.1
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:

View File

@@ -1,5 +1,5 @@
base_model: openlm-research/open_llama_3b base_model: openlm-research/open_llama_3b_v2
base_model_config: openlm-research/open_llama_3b base_model_config: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: false load_in_8bit: false
@@ -13,8 +13,8 @@ dataset_prepared_path: last_run_prepared
val_set_size: 0.01 val_set_size: 0.01
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 1024
max_packed_sequence_len: 2048 sample_packing: true
lora_r: 8 lora_r: 8
lora_alpha: 32 lora_alpha: 32
lora_dropout: 0.05 lora_dropout: 0.05
@@ -27,33 +27,33 @@ wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
batch_size: 4 gradient_accumulation_steps: 1
micro_batch_size: 4 micro_batch_size: 2
num_epochs: 2 num_epochs: 4
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
bf16: true bf16: false
fp16: false fp16: true
tf32: true tf32: false
gradient_checkpointing: true gradient_checkpointing: true
early_stopping_patience: early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: true xformers_attention:
flash_attention: flash_attention: true
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 20
eval_steps: 20 eval_steps: 0.05
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:
weight_decay: 0.0 weight_decay: 0.1
fsdp: fsdp:
fsdp_config: fsdp_config:
special_tokens: special_tokens:

11
examples/phi/README.md Normal file
View File

@@ -0,0 +1,11 @@
# Phi
Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.
```shell
accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
# OR
python -m axolotl.cli.train examples/phi/phi-qlora.yml
```

75
examples/phi/phi-ft.yml Normal file
View File

@@ -0,0 +1,75 @@
base_model: microsoft/phi-1_5
base_model_config: microsoft/phi-1_5
model_type: MixFormerSequentialForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./phi-sft-out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len:
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.000003
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing:
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
unk_token: "<|endoftext|>"
pad_token: "<|endoftext|>"

View File

@@ -0,0 +1,75 @@
base_model: microsoft/phi-1_5
base_model_config: microsoft/phi-1_5
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./phi-sft-out
sequence_len: 1024
sample_packing: false # not CURRENTLY compatible with LoRAs
pad_to_sequence_len:
adapter: qlora
lora_model_dir:
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.000003
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing:
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
unk_token: "<|endoftext|>"
pad_token: "<|endoftext|>"

View File

@@ -6,13 +6,13 @@ packaging
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b accelerate @ git+https://github.com/huggingface/accelerate
addict addict
evaluate evaluate
fire fire
PyYAML>=6.0 PyYAML>=6.0
datasets datasets
flash-attn>=2.0.8 flash-attn>=2.2.1
sentencepiece sentencepiece
wandb wandb
einops einops

View File

@@ -1,262 +1,36 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging import logging
import os
import random
import sys
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fire import fire
import torch
import transformers import transformers
import yaml
# add src to the pythonpath so we don't need to pip install this from axolotl.cli import (
from art import text2art check_accelerate_default_config,
from transformers import GenerationConfig, TextStreamer do_inference,
do_merge_lora,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.cli.shard import shard
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer LOG = logging.getLogger("axolotl.scripts.finetune")
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta, train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font)
if is_main_process():
print(ascii_art)
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to finish): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info("saving merged model")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = yaml_files[choice - 1]
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
return cfg
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art() print_axolotl_text_art()
LOG.warning(
str(
PendingDeprecationWarning(
"scripts/finetune.py will be replaced with calling axolotl.cli.train"
)
)
)
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True

View File

@@ -7,9 +7,7 @@ def parse_requirements():
_install_requires = [] _install_requires = []
_dependency_links = [] _dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file: with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [ lines = [r.strip() for r in requirements_file.readlines()]
r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r
]
for line in lines: for line in lines:
if line.startswith("--extra-index-url"): if line.startswith("--extra-index-url"):
# Handle custom index URLs # Handle custom index URLs
@@ -26,18 +24,16 @@ install_requires, dependency_links = parse_requirements()
setup( setup(
name="axolotl", name="axolotl",
version="0.1", version="0.3.0",
description="You know you're going to axolotl questions", description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"}, package_dir={"": "src"},
packages=find_packages(), packages=find_packages(),
install_requires=install_requires, install_requires=install_requires,
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"gptq": [
"auto-gptq",
],
"flash-attn": [ "flash-attn": [
"flash-attn==2.0.8", "flash-attn>=2.2.1",
], ],
"extras": [ "extras": [
"deepspeed", "deepspeed",

249
src/axolotl/cli/__init__.py Normal file
View File

@@ -0,0 +1,249 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging
import os
import random
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from transformers import GenerationConfig, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font)
if is_main_process():
print(ascii_art)
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to finish): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = yaml_files[choice - 1]
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
return cfg
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def check_accelerate_default_config():
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)

View File

@@ -0,0 +1,27 @@
"""
CLI to run inference on a trained model
"""
from pathlib import Path
import fire
import transformers
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.inference = True
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(do_cli)

View File

@@ -0,0 +1,27 @@
"""
CLI to run merge a trained LoRA into a base model
"""
from pathlib import Path
import fire
import transformers
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(do_cli)

42
src/axolotl/cli/shard.py Normal file
View File

@@ -0,0 +1,42 @@
"""
CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
import fire
import transformers
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.scripts")
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.shard = True
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(do_cli)

36
src/axolotl/cli/train.py Normal file
View File

@@ -0,0 +1,36 @@
"""
CLI to run training on a model
"""
from pathlib import Path
import fire
import transformers
from axolotl.cli import (
check_accelerate_default_config,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":
fire.Fire(do_cli)

View File

@@ -23,6 +23,7 @@ class ColorfulFormatter(Formatter):
} }
def format(self, record): def format(self, record):
record.rank = int(os.getenv("LOCAL_RANK", "0"))
log_message = super().format(record) log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
@@ -35,7 +36,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
}, },
"colorful": { "colorful": {
"()": ColorfulFormatter, "()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s", "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
}, },
}, },
"filters": {}, "filters": {},

View File

View File

@@ -0,0 +1,6 @@
"""
MixFormers model architecture used for phi models
"""
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa

View File

@@ -0,0 +1,63 @@
# pylint: skip-file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Any, Dict, List, Optional, Union
from transformers import PretrainedConfig
class MixFormerSequentialConfig(PretrainedConfig):
"""MixFormer (sequential for DeepSpeed) configuration."""
model_type = "mixformer-sequential"
attribute_map = {
"max_position_embeddings": "n_positions",
"hidden_size": "n_embd",
"num_attention_heads": "n_head",
"num_hidden_layers": "n_layer",
"input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
"blocks": "architecture", # `blocks` key is for backward compatibility
}
def __init__(
self,
vocab_size: Optional[int] = 50304,
n_positions: Optional[int] = 2048,
n_embd: Optional[int] = 1024,
n_layer: Optional[int] = 20,
n_inner: Optional[int] = None,
n_head: Optional[int] = 16,
rotary_dim: Optional[int] = 32,
activation_function: Optional[str] = "gelu_new",
embd_layer: Optional[str] = "default",
architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
embd_pdrop: Optional[float] = 0.0,
resid_pdrop: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-5,
initializer_range: Optional[float] = 0.02,
tie_word_embeddings: Optional[bool] = False,
pad_vocab_size_multiple: Optional[int] = 64,
**kwargs
) -> None:
self.vocab_size = int(
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = min(rotary_dim, n_embd // n_head)
self.activation_function = activation_function
self.embd_layer = embd_layer
self.architecture = architecture
self.embd_pdrop = embd_pdrop
self.resid_pdrop = resid_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

View File

@@ -0,0 +1,934 @@
# pylint: skip-file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
import copy
import inspect
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.flash_attn_interface import (
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
from .configuration_mixformer_sequential import MixFormerSequentialConfig
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference.
Adapted from https://github.com/Dao-AILab/flash-attention."""
max_sequence_len: int
max_batch_size: int
sequence_len_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
fused_ft_kernel: bool = False
lengths_per_sample: Optional[torch.Tensor] = None
class Embedding(nn.Module):
"""Token embedding with dropout."""
def __init__(self, config: PretrainedConfig) -> None:
super().__init__()
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.wte(input_ids)
hidden_states = self.drop(hidden_states)
return hidden_states
class RotaryEmbedding(nn.Module):
"""PyTorch implementation of `flash-attn` RotaryEmbedding layer.
Adapted from https://github.com/Dao-AILab/flash-attention."""
def __init__(
self,
dim: int,
base: Optional[int] = 10000,
scale_base: Optional[float] = None,
device: Optional[str] = None,
**kwargs,
) -> None:
super().__init__()
if scale_base is not None:
raise NotImplementedError
# Generate and save the inverse frequency buffer (non-trainable)
self.dim = dim
self.base = base
self.scale_base = scale_base
self.device = device
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
self.register_buffer("inv_freq", inv_freq)
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _update_cos_sin_cache(
self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0
) -> None:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
seqlen = x.shape[1] + seqlen_offset
# Re-generate the inverse frequency buffer if it's not fp32
# (for instance if model.half() was called)
if self.inv_freq.dtype != "torch.float32":
self.inv_freq = 1.0 / (
self.base
** (
torch.arange(
0, self.dim, 2, device=self.device, dtype=torch.float32
)
/ self.dim
)
)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(
t, self.inv_freq.to(device=t.device, dtype=torch.float32)
)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(x.dtype)
self._sin_cached = torch.sin(freqs).to(x.dtype)
else:
power = (
torch.arange(
seqlen, dtype=self.scale.dtype, device=self.scale.device
)
- seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(
power, "s -> s 1"
)
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
def apply_rotary_emb_qkv(
self,
qkv: torch.FloatTensor,
sin: torch.FloatTensor,
cos: torch.FloatTensor,
sin_k: Optional[torch.FloatTensor] = None,
cos_k: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
_, seqlen, three, _, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert (
sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
)
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
# Splits the queries and keys in half
q1, q2 = q_rot.chunk(2, dim=-1)
k1, k2 = k_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
sin[:seqlen], "s d -> s 1 d"
)
# Casts to fp32 are necessary to prevent fp16 overflow issues
q1, q2, k1, k2, c, s = [
t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]
]
# Computes the new keys and queries, recasting to original dtype
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
return torch.cat(
[
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
qkv[:, :, 2:3, :, :],
],
axis=2,
)
def forward(
self, qkv: torch.Tensor, seqlen_offset: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform the forward pass.
Args:
qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
Returns:
New `qkv` and the cached sinusoids.
"""
self._update_cos_sin_cache(qkv, seqlen_offset)
return self.apply_rotary_emb_qkv(
qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]
)
def _update_kv_cache(kv, inference_params, layer_idx):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
Adapted from https://github.com/Dao-AILab/flash-attention."""
# Pre-allocate memory for key-values for inference.
num_heads, head_dim = kv.shape[-2:]
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size,
inference_params.max_sequence_len,
2,
num_heads,
head_dim,
dtype=kv.dtype,
device=kv.device,
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (
kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] # noqa
)
assert sequence_end <= (
kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] # noqa
)
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
class MLP(nn.Module):
"""Multi-Layer Perceptron.
Reference:
Attention Is All You Need.
https://arxiv.org/pdf/1706.03762.pdf.
"""
def __init__(
self,
config: PretrainedConfig,
n_inner: Optional[int] = None,
act_fn: Optional[str] = None,
) -> None:
super().__init__()
act_fn = config.activation_function if act_fn is None else act_fn
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
self.fc1 = nn.Linear(config.n_embd, n_inner)
self.fc2 = nn.Linear(n_inner, config.n_embd)
self.act = ACT2FN[act_fn]
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
old_keys = [
prefix + "fc_in.weight",
prefix + "fc_out.weight",
prefix + "fc_in.bias",
prefix + "fc_out.bias",
]
new_keys = [
prefix + "fc1.weight",
prefix + "fc2.weight",
prefix + "fc1.bias",
prefix + "fc2.bias",
]
if all(k in state_dict for k in old_keys) and not all(
k in state_dict for k in new_keys
):
# Older version of `MLP` saved with different key names.
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
return super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class FusedMLP(nn.Module):
"""Fused Multi-Layer Perceptron from `flash-attn`.
Reference:
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
"""
def __init__(
self,
config: PretrainedConfig,
n_inner: Optional[int] = None,
act_fn: Optional[str] = None,
raise_on_missing: bool = False,
) -> None:
super().__init__()
act_fn = config.activation_function if act_fn is None else act_fn
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] # noqa
activation = "gelu_approx" if act_fn in gelu_activations else "relu" # noqa
self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
return self.mlp(hidden_states)
class SelfAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Adapted from https://github.com/Dao-AILab/flash-attention.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(
self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, S)
"""
causal = self.causal if causal is None else causal
if cu_seqlens is not None:
return flash_attn_varlen_qkvpacked_func(
qkv.squeeze(0),
cu_seqlens,
max_seqlen,
dropout_p=self.drop.p,
softmax_scale=self.softmax_scale,
causal=causal,
)
else:
return flash_attn_qkvpacked_func(
qkv,
dropout_p=self.drop.p,
softmax_scale=self.softmax_scale,
causal=causal,
)
class CrossAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Adapted from https://github.com/Dao-AILab/flash-attention.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(self, q, kv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, Sk)
"""
causal = self.causal if causal is None else causal
return flash_attn_kvpacked_func(
q,
kv,
dropout_p=self.drop.p,
softmax_scale=self.softmax_scale,
causal=causal,
)
def find_mha_dims(
config: PretrainedConfig,
n_head: Optional[int] = None,
head_dim: Optional[int] = None,
) -> Tuple[int, int]:
"""Validate and return the number of heads and head dimension for multi-head attention.
Args:
config: Model configuration.
n_head: Number of heads.
head_dim: Head dimension.
Returns:
Number of heads and head dimension.
"""
assert all(
hasattr(config, attr) for attr in ["n_embd", "n_head"]
), "`config` must have `n_embd` and `n_head` attributes."
if head_dim is None:
assert (
config.n_embd % config.n_head == 0
), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
if n_head is None and head_dim is None:
head_dim = config.n_embd // config.n_head
n_head = config.n_head
elif n_head is None or head_dim is None:
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
return n_head, head_dim
class MHA(nn.Module):
"""Multi-head attention layer.
Adapted from https://github.com/Dao-AILab/flash-attention."""
def __init__(
self,
config: PretrainedConfig,
rotary_dim: Optional[int] = None,
n_head: Optional[int] = None,
head_dim: Optional[int] = None,
bias: Optional[bool] = True,
dropout: Optional[float] = 0.0,
softmax_scale: Optional[float] = None,
causal: Optional[bool] = True,
layer_idx: Optional[int] = None,
rotary_emb_scale_base: Optional[float] = None,
return_residual: Optional[bool] = False,
checkpointing: Optional[bool] = False,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
fused_dense: Optional[bool] = True,
flash_attn: Optional[bool] = True,
cutlass_attn: Optional[bool] = False,
flash_rotary: Optional[bool] = True,
raise_on_missing: Optional[bool] = False,
) -> None:
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
n_head, head_dim = find_mha_dims(config, n_head, head_dim)
self.hidden_size = config.n_embd
self.n_head = n_head
self.head_dim = head_dim
self.op_size = n_head * head_dim
self.causal = causal
self.layer_idx = layer_idx
self.rotary_emb_dim = (
rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
)
self.fused_dense = fused_dense
self.flash_attn = flash_attn
self.cutlass_attn = cutlass_attn
self.flash_rotary = flash_rotary
self.return_residual = return_residual
self.checkpointing = checkpointing
if self.rotary_emb_dim > 0:
rotary_kwargs = {"device": device}
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
rotary_kwargs["scale_base"] = rotary_emb_scale_base
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
else:
pass
self.Wqkv = nn.Linear(
self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs
)
self.out_proj = nn.Linear(
self.op_size, self.hidden_size, bias=bias, **factory_kwargs
)
self.inner_attn = SelfAttention(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
self.inner_cross_attn = CrossAttention(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
def _update_kv_cache(
self, kv: torch.FloatTensor, inference_params: InferenceParams
) -> None:
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
Adapted from https://github.com/Dao-AILab/flash-attention."""
assert (
self.layer_idx is not None
), "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def forward(
self,
x: torch.FloatTensor,
x_kv: Optional[torch.FloatTensor] = None,
key_padding_mask: Optional[torch.BoolTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
mixer_subset: Optional[torch.LongTensor] = None,
past_cache: Optional[InferenceParams] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Perform the forward pass.
Args:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
is the is the sum of the sequence lengths in the batch.
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
key_padding_mask: boolean mask, True means to keep, False means to mask out.
(batch, seqlen). Only applicable when not using FlashAttention.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into x. Only applicable when using
FlashAttention.
max_seqlen: int. Maximum sequence length in the batch.
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
past_cache: For generation only.
Returns:
(batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
else (total, hidden_dim) where total is the is the sum of the sequence lengths
in the batch.
"""
if cu_seqlens is not None:
assert max_seqlen is not None
assert key_padding_mask is None
assert self.flash_attn
# assert self.rotary_emb_dim == 0
if key_padding_mask is not None:
assert cu_seqlens is None
assert max_seqlen is None
assert not self.flash_attn
if past_cache is not None:
assert key_padding_mask is None
assert cu_seqlens is None and max_seqlen is None
attn_kwargs = {"key_padding_mask": key_padding_mask}
assert x_kv is None and mixer_subset is None
qkv = self.Wqkv(x)
qkv = rearrange(
qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
)
if past_cache is None:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
context = self.inner_attn(
qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs
)
else:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
q = qkv[:, :, 0]
kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if past_cache.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
out = rearrange(context, "... h d -> ... (h d)")
out = self.out_proj(out)
return out if not self.return_residual else (out, x)
class ParallelBlock(nn.Module):
"""Parallel block.
This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
"""
def __init__(
self,
config: PretrainedConfig,
mixer: Optional[Dict[str, Any]] = None,
mlp: Optional[Dict[str, Any]] = None,
block_idx: Optional[int] = None,
) -> None:
super().__init__()
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.block_idx = block_idx
self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
mlp_cls = mlp.pop("mlp_cls")
if mlp_cls == "fused_mlp":
self.mlp = FusedMLP(config=config, **mlp)
else:
self.mlp = MLP(config=config, **mlp)
def forward(
self,
hidden_states: torch.FloatTensor,
past_cache: Optional[torch.FloatTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
) -> torch.FloatTensor:
residual = hidden_states
hidden_states = self.ln(hidden_states)
attn_outputs = self.mixer(
hidden_states,
past_cache=past_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if isinstance(attn_outputs, tuple):
attn_outputs = attn_outputs[0]
attn_outputs = self.resid_dropout(attn_outputs)
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
hidden_states = attn_outputs + feed_forward_hidden_states + residual
return hidden_states
class CausalLMHead(nn.Module):
"""Causal Language Modeling head.
Reference:
Improving Language Understanding by Generative Pre-Training.
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
"""
def __init__(self, config: PretrainedConfig) -> None:
super().__init__()
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.linear = nn.Linear(config.n_embd, config.vocab_size)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.ln(hidden_states)
logits = self.linear(hidden_states).to(torch.float32)
return logits
class CausalLMLoss(nn.Module):
"""Causal Language Modeling loss.
Reference:
Improving Language Understanding by Generative Pre-Training.
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
"""
def __init__(self, shift_labels: Optional[bool] = True) -> None:
super().__init__()
self.shift_labels = shift_labels
self.loss_fct = nn.CrossEntropyLoss()
def forward(
self, logits: torch.FloatTensor, labels: torch.LongTensor
) -> torch.FloatTensor:
if self.shift_labels:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
return loss
class MixFormerSequentialPreTrainedModel(PreTrainedModel):
"""MixFormer (sequential for DeepSpeed) pre-trained model."""
config_class = MixFormerSequentialConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
def __init__(self, *inputs, **kwargs) -> None:
super().__init__(*inputs, **kwargs)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, **kwargs
) -> Dict[str, Any]:
if "use_cache" in kwargs and not kwargs["use_cache"]:
return {"input_ids": input_ids}
if past_key_values is None or not (
isinstance(past_key_values, InferenceParams)
):
past_key_values = InferenceParams(
max_batch_size=input_ids.shape[0],
max_sequence_len=self.config.n_positions,
sequence_len_offset=0,
batch_size_offset=0,
fused_ft_kernel=False,
key_value_memory_dict={},
)
else:
# assume past_key_values has cached all but last token in input_ids
past_key_values.sequence_len_offset = len(input_ids[0]) - 1
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
class PackedSequential(nn.Sequential):
def forward(
self,
input,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
):
for module in self:
sig = inspect.signature(module.forward)
if "cu_seqlens" in sig.parameters:
input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
else:
input = module(input)
return input
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
"""MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
_keys_to_ignore_on_load_missing = [""]
_keys_to_ignore_on_load_unexpected = [
r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
]
_no_split_modules = ["ParallelBlock"]
def __init__(self, config: MixFormerSequentialConfig) -> None:
super().__init__(config)
modules = [Embedding(config)]
block_config = config.architecture
if not isinstance(block_config, list):
block_config = [block_config for _ in range(config.n_layer)]
if config.n_layer != len(block_config):
config.n_layer = len(block_config)
for block_idx, block in enumerate(block_config):
# `block_cls` with `legacy` value is for backward compatibility
# `path` key is for backward compatibility
block = copy.deepcopy(block) or {"block_cls": "parallel"}
block.pop("path", None) or block.pop("block_cls", None)
block["block_idx"] = block_idx
modules.append(ParallelBlock(config, **block))
modules.append(CausalLMHead(config))
self.layers = PackedSequential(*modules)
self.loss = CausalLMLoss()
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.layers[0].wte
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
self.layers[0].wte = new_embeddings
def get_output_embeddings(self) -> nn.Linear:
return self.layers[-1].linear
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.layers[-1].linear = new_embeddings
def forward(
self,
input_ids: torch.LongTensor,
labels: Optional[torch.LongTensor] = None,
past_key_values: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
cu_seqlens: Optional[torch.LongTensor] = None
max_seqlen: Optional[int] = None
if position_ids is not None:
batch_size, seq_length = input_ids.shape
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if not past_key_values:
lm_logits = self.layers(
input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
)
else:
hidden_layer = self.layers[0](input_ids)
for module in self.layers[1:-1]:
hidden_layer = module(
hidden_layer,
past_cache=past_key_values,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
lm_logits = self.layers[-1](hidden_layer)
loss = None
if labels is not None:
loss = self.loss(lm_logits, labels)
return CausalLMOutputWithPast(
loss=loss, logits=lm_logits, past_key_values=past_key_values
)

View File

@@ -0,0 +1,66 @@
"""
Flash attention monkey patch for cerebras btlm model
"""
import importlib
import logging
from typing import Optional, Tuple
import accelerate
import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM
LOG = logging.getLogger("axolotl")
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_btlm to be available
with accelerate.init_empty_weights():
AutoModelForCausalLM(model_config)
module_name = model_config.__class__.__module__.replace(
".configuration_btlm", ".modeling_btlm"
)
modeling_btlm = importlib.import_module(module_name)
modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access
flashattn_attn
)
def flashattn_attn(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
head_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
softmax_scale = (
1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
# Perform Flash attention
attn_output = flash_attn_func(
query,
key,
value,
dropout_p=0.0, # Assuming you have this attribute
softmax_scale=softmax_scale, # Set this if you have specific scaling in mind
causal=not self.is_cross_attention, # Assuming you have this attribute
return_attn_probs=False, # Set this based on your needs
)
# Optional: Apply head mask if it's not None
if head_mask is not None:
attn_output *= head_mask
attn_output = attn_output.permute(0, 2, 1, 3)
return attn_output, None # We don't have explicit attn_weights in Flash attention

View File

@@ -0,0 +1,101 @@
"""
Flash Attention monkey patch for Falcon
copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
"""
from typing import Optional, Tuple
import torch
import transformers
from flash_attn import flash_attn_func
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor, # pylint: disable=unused-argument
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
use_cache: bool = False,
output_attentions: bool = False, # pylint: disable=unused-argument
):
fused_qkv = self.query_key_value(
hidden_states
) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = (
self.num_heads if self.new_decoder_architecture else self.num_kv_heads
)
# 3 x [batch_size, seq_length, num_heads, head_dim]
(
query_layer,
key_layer,
value_layer,
) = self._split_heads( # pylint: disable=protected-access
fused_qkv
)
batch_size, query_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(
batch_size * self.num_heads, query_length, self.head_dim
)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads, query_length, self.head_dim
)
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
# unused
# _, kv_length, _ = key_layer.shape
if use_cache:
present = (key_layer, value_layer)
else:
present = None
# unused
# attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
query_layer_ = (
query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
key_layer_ = (
key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
value_layer_ = (
value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
.transpose(1, 2)
.to(torch.bfloat16)
)
if alibi is not None:
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
# below output will have shape (batch_size, seqlen, nheads, headdim)
attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
attn_output = attn_output.reshape(
batch_size, query_length, self.num_heads * self.head_dim
)
output_tensor = self.dense(attn_output)
return output_tensor, present
def replace_falcon_attn_with_flash_attn():
transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward

View File

@@ -193,7 +193,7 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape is_causal = key_states.shape == query_states.shape
if cu_seqlens is not None and max_seqlen is not None: if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack( qkv = torch.stack(
[query_states, key_states, value_states], dim=2 [query_states, key_states, value_states], dim=2
@@ -261,6 +261,8 @@ def flashattn_forward(
if attention_mask is not None if attention_mask is not None
else None, else None,
) )
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func( output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad, q_unpad,
kv_unpad, kv_unpad,

View File

@@ -1,6 +1,7 @@
"""Module to load prompt strategies.""" """Module to load prompt strategies."""
import importlib import importlib
import inspect
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
@@ -16,6 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg):
load_kwargs = {} load_kwargs = {}
if strategy == "user_defined": if strategy == "user_defined":
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg) load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
else:
sig = inspect.signature(func)
if "ds_cfg" in sig.parameters:
load_kwargs["ds_cfg"] = ds_cfg
return func(tokenizer, cfg, **load_kwargs) return func(tokenizer, cfg, **load_kwargs)
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
return None return None

View File

@@ -0,0 +1,20 @@
"""
Basic completion text
"""
from typing import Any, Dict, Optional
from axolotl.prompt_tokenizers import CompletionPromptTokenizingStrategy
from axolotl.prompters import CompletionPrompter
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
strat = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "field" in ds_cfg:
strat.field = ds_cfg["field"]
return strat

View File

@@ -6,7 +6,7 @@ import functools
import logging import logging
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
from transformers import PreTrainedTokenizer from transformers import BatchEncoding, PreTrainedTokenizer
from axolotl.prompters import IGNORE_TOKEN_ID from axolotl.prompters import IGNORE_TOKEN_ID
@@ -66,14 +66,21 @@ class PromptTokenizingStrategy(abc.ABC):
pass pass
return False return False
def _tokenize(self, prompt: str, add_eos_token=True, strip_bos_token=False): def _tokenize(
result = self.tokenizer( self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
prompt, ) -> BatchEncoding:
truncation=True, result: BatchEncoding
max_length=self.sequence_len, if not prompt.strip():
padding=False, LOG.warning("Empty text requested for tokenization.")
return_tensors=None, result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
) else:
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0: if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset") LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
if ( if (
@@ -245,8 +252,31 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
Tokenizing strategy for Completion prompts. Tokenizing strategy for Completion prompts.
""" """
_field: str = "text"
@property
def field(self) -> str:
return self._field
@field.setter
def field(self, new_field: str):
self._field = new_field
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt[self.field],
"",
"",
)
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
full_prompt = self._build_full_prompt(prompt["text"], None, None) (
instruction,
_,
_,
) = self.parse_instruction_fields(prompt)
full_prompt = self._build_full_prompt(instruction, None, None)
tokenized_full_prompt = self._tokenize(full_prompt) tokenized_full_prompt = self._tokenize(full_prompt)
return tokenized_full_prompt return tokenized_full_prompt

View File

@@ -80,14 +80,15 @@ def train(
model.config.use_cache = False model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
LOG.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect # go ahead and presave, so we have the adapter config available to inspect
if peft_config: if peft_config:
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}") LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir) peft_config.save_pretrained(cfg.output_dir)
# additionally presave the tokenizer and model configs
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
model.config.save_pretrained(str(Path(cfg.output_dir)))
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0: if cfg.local_rank == 0:
@@ -106,9 +107,6 @@ def train(
if cfg.group_by_length: if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length") LOG.info("hang tight... sorting dataset for group_by_length")
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
tokenizer.save_pretrained(cfg.output_dir)
if cfg.flash_optimum: if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel( with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -119,6 +117,10 @@ def train(
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit): if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
model = model.merge_and_unload() model = model.merge_and_unload()

View File

@@ -2,6 +2,7 @@
import pynvml import pynvml
import torch import torch
from pynvml.nvml import NVMLError
def gpu_memory_usage(device=0): def gpu_memory_usage(device=0):
@@ -20,15 +21,17 @@ def gpu_memory_usage_smi(device=0):
device = device.index device = device.index
if isinstance(device, str) and device.startswith("cuda:"): if isinstance(device, str) and device.startswith("cuda:"):
device = int(device[5:]) device = int(device[5:])
try:
pynvml.nvmlInit() pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(device) handle = pynvml.nvmlDeviceGetHandleByIndex(device)
info = pynvml.nvmlDeviceGetMemoryInfo(handle) info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return info.used / 1024.0**3 return info.used / 1024.0**3
except NVMLError:
return 0.0
def log_gpu_memory_usage(log, msg, device): def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available(): if not torch.cuda.is_available() or device == "auto":
return (0, 0, 0) return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device) usage, cache, misc = gpu_memory_usage_all(device)

View File

@@ -11,10 +11,13 @@ import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import wandb
from datasets import load_dataset from datasets import load_dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from tqdm import tqdm from tqdm import tqdm
from transformers import ( from transformers import (
GenerationConfig,
Trainer,
TrainerCallback, TrainerCallback,
TrainerControl, TrainerControl,
TrainerState, TrainerState,
@@ -25,6 +28,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.distributed import ( from axolotl.utils.distributed import (
barrier, barrier,
broadcast_dict,
gather_scalar_from_all_ranks, gather_scalar_from_all_ranks,
get_world_size, get_world_size,
is_distributed, is_distributed,
@@ -271,6 +275,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
lambda: len(data_loader), get_world_size() lambda: len(data_loader), get_world_size()
) )
results = {}
if is_distributed() and not is_main_process(): if is_distributed() and not is_main_process():
dist.gather_object(local_bench_names, dst=0) dist.gather_object(local_bench_names, dst=0)
else: else:
@@ -316,4 +321,196 @@ def bench_eval_callback_factory(trainer, tokenizer):
)["accuracy"] )["accuracy"]
trainer.log(results) trainer.log(results)
results = broadcast_dict(results)
for key, val in results.items():
metrics[key] = val
return BenchEvalCallback return BenchEvalCallback
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
class LogPredictionCallback(TrainerCallback):
"""Callback to log prediction values during each evaluation"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
def on_evaluate(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState,
control: TrainerControl,
train_dataloader, # pylint: disable=unused-argument
eval_dataloader,
**kwargs, # pylint: disable=unused-argument
):
eval_table_size = self.cfg.eval_table_size
if eval_table_size <= 0:
return control
trainer.model.eval()
device = torch.device(self.cfg.device)
# pylint: disable=duplicate-code
generation_config = GenerationConfig(
max_new_tokens=self.cfg.eval_table_max_new_tokens,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
def logits_to_tokens(logits) -> torch.Tensor:
probabilities = torch.softmax(logits, dim=-1)
# Get the predicted token ids (the ones with the highest probability)
predicted_token_ids = torch.argmax(probabilities, dim=-1)
return predicted_token_ids
def find_ranges(lst):
ranges = []
start = 0
for i in range(1, len(lst)):
if lst[i] == 0:
ranges.append((start, i - 1))
start = i
end = len(lst) - 1
ranges.append((start, end))
return ranges
def log_table_from_dataloader(name: str, table_dataloader):
table = wandb.Table( # type: ignore[attr-defined]
columns=[
"id",
"Prompt",
"Correct Completion",
"Predicted Completion (model.generate)",
"Predicted Completion (trainer.prediction_step)",
]
)
row_index = 0
for batch in tqdm(table_dataloader):
if row_index > eval_table_size:
break
batch_labels = batch["labels"].to(device)
batch_input_ids = batch["input_ids"].to(device)
if "position_ids" in batch:
batch_pos_ids = batch["position_ids"].tolist()
else:
batch_pos_ids = [None] * len(batch["input_ids"])
(_, batch_logits, _) = trainer.prediction_step(
trainer.model,
batch,
prediction_loss_only=False,
)
prompt_token_ids_list = []
pred_step_token_ids_list = []
completion_token_ids_list = []
for input_ids_all, labels_all, pos_ids, logits in zip(
batch_input_ids,
batch_labels,
batch_pos_ids,
batch_logits,
):
if pos_ids is None:
pos_ranges = [(0, len(input_ids_all) - 1)]
else:
pos_ranges = find_ranges(pos_ids)
for pos_range in pos_ranges:
start, end = pos_range
if start == end:
continue
input_ids = input_ids_all[start : end + 1]
labels = labels_all[start : end + 1]
tokens_without_loss = labels == IGNORE_INDEX
tokens_with_loss = labels != IGNORE_INDEX
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
prompt_token_includes = (
tokens_without_loss & tokens_exclude_padding
)
prompt_token_ids = input_ids[prompt_token_includes]
prompt_token_ids_list.append(prompt_token_ids)
completion_token_ids = input_ids[tokens_with_loss]
completion_token_ids_list.append(completion_token_ids)
pred_step_token_ids = logits_to_tokens(
logits[start : end + 1]
)[tokens_with_loss]
pred_step_token_ids_list.append(pred_step_token_ids)
prompt_texts = tokenizer.batch_decode(
prompt_token_ids_list, skip_special_tokens=True
)
completion_texts = tokenizer.batch_decode(
completion_token_ids_list, skip_special_tokens=True
)
pred_step_texts = tokenizer.batch_decode(
pred_step_token_ids_list, skip_special_tokens=True
)
with torch.no_grad():
prompt_encoding = tokenizer(
prompt_texts, padding=True, return_tensors="pt"
).to(self.cfg.device)
predictions = trainer.model.generate(
**prompt_encoding, generation_config=generation_config
)
prediction_all_tokens = predictions["sequences"].cpu().tolist()
prediction_without_prompt_tokens_list = []
for prompt_token_ids, prediction_tokens in zip(
prompt_token_ids_list, prediction_all_tokens
):
prediction_without_prompt_tokens = prediction_tokens[
len(prompt_token_ids) :
]
prediction_without_prompt_tokens_list.append(
prediction_without_prompt_tokens
)
predicted_texts = tokenizer.batch_decode(
prediction_without_prompt_tokens_list, skip_special_tokens=True
)
for (
prompt_text,
completion_text,
prediction_text,
pred_step_text,
) in zip(
prompt_texts, completion_texts, predicted_texts, pred_step_texts
):
table.add_data(
row_index,
prompt_text,
completion_text,
prediction_text,
pred_step_text,
)
row_index += 1
wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
if is_main_process():
log_table_from_dataloader("Eval", eval_dataloader)
return control
return LogPredictionCallback

View File

@@ -4,6 +4,7 @@ import logging
import os import os
import torch import torch
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.models import load_model_config from axolotl.utils.models import load_model_config
@@ -25,9 +26,11 @@ def choose_device(cfg):
return "cpu" return "cpu"
cfg.device = get_device() cfg.device = get_device()
if cfg.device_map != "auto": if cfg.world_size == 1:
cfg.device_map = "auto"
else:
if cfg.device.startswith("cuda"): if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank} cfg.device_map = {"": torch.cuda.current_device()}
else: else:
cfg.device_map = {"": cfg.device} cfg.device_map = {"": cfg.device}
@@ -48,6 +51,8 @@ def normalize_config(cfg):
) )
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1)) cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0)) cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
cfg.eval_table_size = cfg.eval_table_size or 0
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
choose_device(cfg) choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1 cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp: if cfg.ddp:
@@ -71,6 +76,7 @@ def normalize_config(cfg):
cfg.torch_dtype = torch.float32 cfg.torch_dtype = torch.float32
model_config = load_model_config(cfg) model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
# figure out if the model is llama # figure out if the model is llama
cfg.is_llama_derived_model = ( cfg.is_llama_derived_model = (
@@ -84,6 +90,14 @@ def normalize_config(cfg):
def validate_config(cfg): def validate_config(cfg):
if is_torch_bf16_gpu_available():
if not cfg.bf16 and not cfg.bfloat16:
LOG.info("bf16 support detected, but not enabled for this configuration.")
else:
if cfg.bf16 or cfg.bfloat16:
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
if cfg.max_packed_sequence_len and cfg.sample_packing: if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError( raise ValueError(
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing" "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
@@ -97,6 +111,11 @@ def validate_config(cfg):
) )
) )
if cfg.sample_packing and not cfg.pad_to_sequence_len:
LOG.warning(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
)
if cfg.gradient_accumulation_steps and cfg.batch_size: if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError( raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size" "please set only one of gradient_accumulation_steps or batch_size"
@@ -186,6 +205,10 @@ def validate_config(cfg):
LOG.warning( LOG.warning(
"You probably want to disable group_by_length as it will force a streamed dataset to download completely." "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
) )
if cfg.pretraining_dataset and not cfg.max_steps:
raise ValueError(
"max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
)
if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and ( if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
not cfg.optimizer or "adamw" not in cfg.optimizer not cfg.optimizer or "adamw" not in cfg.optimizer
@@ -215,6 +238,30 @@ def validate_config(cfg):
"sample_packing not compatible with xformers_attention. Use flash_attention" "sample_packing not compatible with xformers_attention. Use flash_attention"
) )
if cfg.early_stopping_patience:
if not cfg.save_steps or not cfg.eval_steps:
raise ValueError(
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
)
if cfg.save_steps % cfg.eval_steps != 0:
raise ValueError(
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
)
if cfg.model_type == "MixFormerSequentialForCausalLM" and cfg.adapter is not None:
LOG.warning("Use AutoModelForCausalLM for phi/MixFormer models with qLoRA")
if cfg.model_config_type == "mixformer-sequential":
if cfg.sample_packing:
if cfg.adapter is not None:
LOG.warning(
"phi/MixFormer models are not currently compatible with LoRA and sample_packing"
)
if cfg.model_type == "AutoModelForCausalLM":
raise ValueError(
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -2,9 +2,8 @@
import functools import functools
import hashlib import hashlib
import logging import logging
from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import Dict, List, Tuple, Union
import torch import torch
from datasets import ( from datasets import (
@@ -23,7 +22,6 @@ from axolotl.prompt_tokenizers import (
AlpacaMultipleChoicePromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
AlpacaReflectionPTStrategy, AlpacaReflectionPTStrategy,
CompletionPromptTokenizingStrategy,
GPTeacherPromptTokenizingStrategy, GPTeacherPromptTokenizingStrategy,
JeopardyPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
OpenAssistantPromptTokenizingStrategy, OpenAssistantPromptTokenizingStrategy,
@@ -32,7 +30,6 @@ from axolotl.prompt_tokenizers import (
) )
from axolotl.prompters import ( from axolotl.prompters import (
AlpacaPrompter, AlpacaPrompter,
CompletionPrompter,
GPTeacherPrompter, GPTeacherPrompter,
JeopardyPrompter, JeopardyPrompter,
MultipleChoiceConcisePrompter, MultipleChoiceConcisePrompter,
@@ -52,6 +49,13 @@ LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def md5(to_hash: str, encoding: str = "utf-8") -> str:
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with zero_first(is_main_process()):
@@ -68,6 +72,7 @@ def prepare_dataset(cfg, tokenizer):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
@@ -88,7 +93,7 @@ def load_tokenized_prepared_datasets(
) -> DatasetDict: ) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
ds_hash = str( ds_hash = str(
md5( # nosec md5(
( (
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@"
@@ -97,8 +102,8 @@ def load_tokenized_prepared_datasets(
) )
+ "|" + "|"
+ tokenizer_name + tokenizer_name
).encode("utf-8") )
).hexdigest() )
) )
prepared_ds_path = ( prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash Path(cfg.dataset_prepared_path) / ds_hash
@@ -178,6 +183,10 @@ def load_tokenized_prepared_datasets(
ds_type = "parquet" ds_type = "parquet"
elif ".arrow" in d.path: elif ".arrow" in d.path:
ds_type = "arrow" ds_type = "arrow"
elif ".csv" in d.path:
ds_type = "csv"
elif ".txt" in d.path:
ds_type = "text"
ds = load_dataset( ds = load_dataset(
ds_type, ds_type,
name=d.name, name=d.name,
@@ -320,15 +329,6 @@ def load_tokenized_prepared_datasets(
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper) datasets.append(ds_wrapper)
elif d_base_type == "completion":
ds_strategy = CompletionPromptTokenizingStrategy(
CompletionPrompter(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else: else:
suffix = "" suffix = ""
if ":load_" in d.type: if ":load_" in d.type:
@@ -374,7 +374,7 @@ def load_prepare_datasets(
# see if we can go ahead and load the stacked dataset # see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else "" seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str( ds_hash = str(
md5( # nosec md5(
( (
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@"
@@ -385,8 +385,8 @@ def load_prepare_datasets(
) )
+ "|" + "|"
+ tokenizer_name + tokenizer_name
).encode("utf-8") )
).hexdigest() )
) )
prepared_ds_path = ( prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash Path(cfg.dataset_prepared_path) / ds_hash
@@ -500,12 +500,8 @@ def load_prepare_datasets(
+ "|" + "|"
+ str(cfg.seed or 42) + str(cfg.seed or 42)
) )
train_fingerprint = hashlib.md5( train_fingerprint = md5(to_hash_train)
to_hash_train.encode(), usedforsecurity=False test_fingerprint = md5(to_hash_test)
).hexdigest()
test_fingerprint = hashlib.md5(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
with zero_first(is_main_process()): with zero_first(is_main_process()):
dataset = dataset.train_test_split( dataset = dataset.train_test_split(
@@ -525,9 +521,11 @@ def load_prepare_datasets(
return train_dataset, eval_dataset return train_dataset, eval_dataset
def encode_pretraining(tokenizer, max_tokens, examples): def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
res = tokenizer( res = tokenizer(
examples["text"], examples,
truncation=True, truncation=True,
max_length=max_tokens - 2, max_length=max_tokens - 2,
add_special_tokens=True, add_special_tokens=True,
@@ -635,6 +633,12 @@ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train") dataset = load_dataset(path, streaming=True, split="train")
dataset = dataset.shuffle(seed=seed, buffer_size=10_000) dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
# TODO dynamically figure out which columns/features to remove dataset = dataset.map(
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"]) encode,
batched=True,
input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
remove_columns=dataset.features.keys(),
)
return dataset return dataset

View File

@@ -223,6 +223,8 @@ class MultipackDistributedDataloader:
concatenated = {} concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch] batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features: for feature in features:
if feature == "length":
continue
if feature == "attention_mask": if feature == "attention_mask":
arrays = [ arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature]) (attn_mask_cum_idx + idx + 1) * np.array(item[feature])

View File

@@ -2,6 +2,7 @@
utility helpers for distributed checks utility helpers for distributed checks
""" """
import os import os
import pickle # nosec
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
@@ -93,3 +94,118 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
gathered_values.append(float(tensor.item())) gathered_values.append(float(tensor.item()))
return gathered_values return gathered_values
return None return None
def broadcast_dict(vals: dict):
if not is_distributed():
return vals
if is_main_process():
data_byte = pickle.dumps(vals)
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
else:
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
data_size = torch.IntTensor([0]).to("cuda")
dist.broadcast(data_size, 0)
if not is_main_process():
# resize
data_tensor = data_tensor.new_empty([data_size.item()])
dist.broadcast(data_tensor, 0)
if not is_main_process():
data_list = data_tensor.cpu().tolist()
data_byte = bytes(data_list[: data_size.item()])
vals = pickle.loads(data_byte) # nosec
return vals
def compute_and_broadcast(fn): # pylint: disable=invalid-name
"""
Compute a value using the function 'fn' only on the specified rank (default is 0).
The value is then broadcasted to all other ranks.
Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that computes the value. Default is 0.
Returns:
- The computed value (int or float).
"""
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
else:
value_tensor = torch.tensor(0.0, device=dist.get_rank()) # Placeholder tensor
# Broadcast the tensor to all processes.
barrier()
dist.broadcast(value_tensor, src=0)
# Convert the tensor back to its original type (int or float)
if value_tensor == value_tensor.int():
return int(value_tensor.item())
return float(value_tensor.item())
def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
Run a callable 'fn' on all ranks and gather the results on the specified rank.
Args:
- fn (callable): A function that computes the value. This should not have any side effects.
- rank (int, optional): The rank that gathers the values. Default is 0.
- world_size (int, optional): Total number of processes in the current distributed setup.
Returns:
- A list of computed values from all ranks if on the gathering rank, otherwise None.
"""
value_scalar = fn()
value_tensor = torch.tensor(value_scalar, device=dist.get_rank()).float()
# Placeholder tensor for gathering results
if is_main_process():
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
else:
gathered_tensors = None
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)
if is_main_process():
# Convert tensors back to their original type (int or float)
gathered_values = []
for tensor in gathered_tensors:
if tensor == tensor.int():
gathered_values.append(int(tensor.item()))
else:
gathered_values.append(float(tensor.item()))
return gathered_values
return None
def reduce_and_broadcast(fn1, fn2):
"""
Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2',
and then broadcast the reduced result to all ranks.
Args:
- fn1 (callable): A function that computes the value on each rank.
- fn2 (callable): A reduction function that takes a list of values and returns a single value.
- world_size (int, optional): Total number of processes in the current distributed setup.
Returns:
- The reduced and broadcasted value.
"""
# Gather values from all ranks using fn1
if not is_distributed():
return fn2([fn1()])
gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())
# Use compute_and_broadcast to compute the reduced value on the main process
# and then broadcast it to all ranks
return compute_and_broadcast(lambda: fn2(gathered_values))

View File

@@ -1,6 +1,5 @@
"""Module for models and model loading""" """Module for models and model loading"""
import importlib
import logging import logging
import math import math
import os import os
@@ -101,10 +100,31 @@ def load_model(
base_model = cfg.base_model base_model = cfg.base_model
base_model_config = cfg.base_model_config base_model_config = cfg.base_model_config
model_type = cfg.model_type model_type = cfg.model_type
model_config = load_model_config(cfg)
# TODO refactor as a kwarg # TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit load_in_8bit = cfg.load_in_8bit
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
if cfg.flash_attention:
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
replace_btlm_attn_with_flash_attn,
)
replace_btlm_attn_with_flash_attn(cfg.base_model)
if hasattr(model_config, "model_type") and model_config.model_type in [
"falcon",
"RefinedWebModel",
"RefinedWeb",
]:
if cfg.flash_attention:
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
replace_falcon_attn_with_flash_attn,
)
replace_falcon_attn_with_flash_attn()
if cfg.is_llama_derived_model and cfg.flash_attention: if cfg.is_llama_derived_model and cfg.flash_attention:
if cfg.device not in ["mps", "cpu"] and not inference: if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import (
@@ -155,14 +175,31 @@ def load_model(
LOG.info("patching _expand_mask") LOG.info("patching _expand_mask")
hijack_expand_mask() hijack_expand_mask()
# special handling b/c remote MixFormers code doesn't have _no_split_modules set
if (
"MixFormerSequentialConfig" in model_config.__class__.__name__
and cfg.model_type == "AutoModelForCausalLM"
):
module_name = model_config.__class__.__module__.replace(
".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
)
modeling_phi = importlib.import_module(module_name)
# pylint:disable=protected-access
modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
"ParallelBlock"
]
model_kwargs = {} model_kwargs = {}
if cfg.model_revision: if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision model_kwargs["revision"] = cfg.model_revision
if cfg.gptq: if cfg.gptq:
model_config = load_model_config(cfg) if not hasattr(model_config, "quantization_config"):
if hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information") LOG.warning("model config does not contain quantization_config information")
else: else:
if cfg.gptq_disable_exllama is not None:
model_config.quantization_config[
"disable_exllama"
] = cfg.gptq_disable_exllama
model_kwargs["quantization_config"] = GPTQConfig( model_kwargs["quantization_config"] = GPTQConfig(
**model_config.quantization_config **model_config.quantization_config
) )
@@ -221,6 +258,17 @@ def load_model(
# device=cfg.device, # device=cfg.device,
# ) # )
# model.train() # sets to train instead of eval mode # model.train() # sets to train instead of eval mode
elif model_type == "MixFormerSequentialForCausalLM":
from axolotl.models.phi import MixFormerSequentialForCausalLM
model = MixFormerSequentialForCausalLM.from_pretrained(
base_model,
device_map=cfg.device_map,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=cfg.torch_dtype,
**model_kwargs,
)
elif model_type and not cfg.trust_remote_code: elif model_type and not cfg.trust_remote_code:
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
@@ -291,15 +339,18 @@ def load_model(
if cfg.resize_token_embeddings_to_32x if cfg.resize_token_embeddings_to_32x
else len(tokenizer) else len(tokenizer)
) )
model.resize_token_embeddings(embeddings_len) if model.get_input_embeddings().num_embeddings < embeddings_len:
model.resize_token_embeddings(embeddings_len)
else:
model.tie_weights()
if ( if (
hasattr(model.config, "max_position_embeddings") hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings and cfg.sequence_len > model.config.max_position_embeddings
): ):
LOG.warning( LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}" f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
) )
model.config.max_position_embeddings = cfg.sequence_len model.config.max_position_embeddings = cfg.sequence_len
@@ -310,6 +361,9 @@ def load_model(
for name, module in model.named_modules(): for name, module in model.named_modules():
if "norm" in name: if "norm" in name:
module.to(torch.float32) module.to(torch.float32)
if model_config.model_type == "btlm":
# don't upcast lm_head for btlm
continue
if "lm_head" in name or "embed_tokens" in name: if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(torch.float32) module.to(torch.float32)

View File

@@ -18,21 +18,16 @@ def check_example_labels(example, tokenizer, text_only=False):
# Get the input_ids, labels, and attention_mask from the dataset # Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"] input_ids = example["input_ids"]
labels = example["labels"] labels = example["labels"]
attention_mask = example["attention_mask"]
# You can compare the input_ids and labels element-wise # You can compare the input_ids and labels element-wise
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
colored_tokens = [] colored_tokens = []
for _, (input_id, label_id, mask) in enumerate( for _, (input_id, label_id) in enumerate(zip(input_ids, labels)):
zip(input_ids, labels, attention_mask)
):
decoded_input_token = tokenizer.decode(input_id) decoded_input_token = tokenizer.decode(input_id)
# Choose the color based on whether the label has the ignore value or not # Choose the color based on whether the label has the ignore value or not
color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green") color = "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
colored_token = colored(decoded_input_token, color) + ( colored_token = colored(decoded_input_token, color) + (
not text_only not text_only and colored(f"({label_id}, {input_id})", "white") or ""
and colored(f"({label_id}, {mask}, {input_id})", "white")
or ""
) )
colored_tokens.append(colored_token) colored_tokens.append(colored_token)

View File

@@ -8,10 +8,12 @@ from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import List, Optional, Union
import numpy as np import numpy as np
import torch
import torch.cuda import torch.cuda
import torch.distributed as dist
import transformers import transformers
from datasets import Dataset, set_caching_enabled from datasets import Dataset, set_caching_enabled
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
@@ -30,9 +32,16 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
SavePeftModelCallback, SavePeftModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
log_prediction_callback_factory,
) )
from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import (
is_distributed,
is_main_process,
reduce_and_broadcast,
zero_first,
)
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -114,6 +123,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False, default=False,
metadata={"help": "Use sample packing for efficient training."}, metadata={"help": "Use sample packing for efficient training."},
) )
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field( sample_packing_efficiency: float = field(
default=1.0, default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."}, metadata={"help": "Sample packing efficiency for calculating batch length."},
@@ -209,7 +222,11 @@ class AxolotlTrainer(Trainer):
def _get_eval_sampler( def _get_eval_sampler(
self, eval_dataset: Dataset self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]: ) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing: if (
self.args.world_size > 1
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
return SequentialDistributedSampler( return SequentialDistributedSampler(
eval_dataset, eval_dataset,
num_replicas=self.args.world_size, num_replicas=self.args.world_size,
@@ -238,7 +255,7 @@ class AxolotlTrainer(Trainer):
def get_eval_dataloader( def get_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]: ) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing: if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = ( eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset if eval_dataset is not None else self.eval_dataset
) )
@@ -356,7 +373,14 @@ class ReLoRATrainer(AxolotlTrainer):
def add_position_ids(sample): def add_position_ids(sample):
sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"])) sample["position_ids"] = torch.arange(len(sample["input_ids"]))
sample["length"] = sample_len
return sample
def add_length(sample):
sample["length"] = len(sample["input_ids"])
return sample return sample
@@ -375,14 +399,21 @@ def disable_datasets_caching():
def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()) with zero_first(is_main_process()):
if eval_dataset: train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map(add_position_ids, num_proc=os.cpu_count()) eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
if cfg.group_by_length:
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
if cfg.sample_packing:
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids, num_proc=os.cpu_count()
)
return train_dataset, eval_dataset return train_dataset, eval_dataset
@@ -398,7 +429,7 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values .values
) )
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`") LOG.info(f"total_num_tokens: {total_num_tokens}")
cfg.total_num_tokens = total_num_tokens cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens: if not cfg.total_supervised_tokens:
@@ -431,7 +462,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
) )
else: else:
sampler = RandomSampler(train_dataset) if cfg.world_size > 1 and is_distributed():
sampler = DistributedSampler(
train_dataset,
num_replicas=cfg.world_size,
rank=dist.get_rank(),
seed=cfg.seed or 42,
)
else:
sampler = RandomSampler(train_dataset)
data_loader = MultipackDistributedDataloader( data_loader = MultipackDistributedDataloader(
train_dataset, train_dataset,
batch_size=cfg.micro_batch_size, batch_size=cfg.micro_batch_size,
@@ -449,18 +489,23 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
data_loader_len = data_loader.len_w_stats() data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency() actual_eff = data_loader.efficiency()
LOG.info(f"data_loader_len: {data_loader_len}") LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int( # FIXME: is there a bug here somewhere? the total num steps depends
math.floor( # on the agreed on value for sample_packing_eff_est
data_loader_len total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
* cfg.micro_batch_size
* cfg.num_epochs def calc_sample_packing_eff_est(estimates: List[float]):
// cfg.batch_size LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
) return max(estimates)
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: actual_eff,
calc_sample_packing_eff_est,
) )
LOG.info( sample_packing_eff_est = (
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`" math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
) )
cfg.sample_packing_eff_est = math.ceil(actual_eff * 100.0) / 100.0 cfg.sample_packing_eff_est = sample_packing_eff_est
LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
else: else:
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -552,26 +597,57 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency" "sample_packing_efficiency"
] = cfg.sample_packing_eff_est ] = cfg.sample_packing_eff_est
if cfg.val_set_size == 0: if cfg.eval_steps and cfg.evaluation_strategy:
# assume if the user set both, they know what they're doing
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
elif cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no" training_arguments_kwargs["evaluation_strategy"] = "no"
elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
# if explicitly set for epoch, just set, and eval steps don't matter
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
elif cfg.eval_steps: elif cfg.eval_steps:
# steps isn't used w/ epochs
training_arguments_kwargs["evaluation_strategy"] = "steps" training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps training_arguments_kwargs["eval_steps"] = cfg.eval_steps
else: else:
# we have an eval set, but no steps defined, use epoch # we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch" training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_strategy: if cfg.save_steps:
# save_steps implies save_strategy of steps
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = cfg.save_steps
elif cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = cfg.save_strategy training_arguments_kwargs["save_strategy"] = cfg.save_strategy
else: else:
training_arguments_kwargs["save_strategy"] = ( # default to saving each epoch if not defined
"steps" if cfg.save_steps else "epoch" training_arguments_kwargs["save_strategy"] = "epoch"
)
if cfg.do_bench_eval: if cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset: if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
if cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
if cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
if cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
else:
import torch._dynamo # pylint: disable=redefined-outer-name
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
if cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = cfg.torch_compile_backend
# DDP Config # DDP Config
if cfg.ddp_timeout: if cfg.ddp_timeout:
@@ -593,15 +669,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps, eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs, num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate, learning_rate=cfg.learning_rate,
save_steps=cfg.save_steps,
output_dir=cfg.output_dir, output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4, save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
load_best_model_at_end=( load_best_model_at_end=(
cfg.load_best_model_at_end is not False (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0 and cfg.val_set_size > 0
and cfg.save_steps and cfg.save_steps
and cfg.save_steps % cfg.eval_steps == 0 and cfg.save_steps % cfg.eval_steps == 0
and cfg.load_in_8bit is not True
) )
or False, or False,
ddp_find_unused_parameters=False if cfg.ddp else None, ddp_find_unused_parameters=False if cfg.ddp else None,
@@ -614,6 +688,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
else "cosine", else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0, weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False, sample_packing=cfg.sample_packing if cfg.sample_packing else False,
eval_sample_packing=cfg.eval_sample_packing,
sample_packing_seq_len_multiplier=cfg.micro_batch_size, sample_packing_seq_len_multiplier=cfg.micro_batch_size,
relora_steps=cfg.relora_steps, relora_steps=cfg.relora_steps,
relora_warmup_steps=cfg.relora_warmup_steps, relora_warmup_steps=cfg.relora_warmup_steps,
@@ -633,13 +708,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if cfg.relora_steps: if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg)) callbacks.append(ReLoRACallback(cfg))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
if cfg.local_rank == 0 and cfg.adapter in [ if cfg.local_rank == 0 and cfg.adapter in [
"lora", "lora",
"qlora", "qlora",
@@ -703,7 +771,18 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
**trainer_kwargs, **trainer_kwargs,
) )
if cfg.use_wandb and cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg))
if cfg.do_bench_eval: if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer)) trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
trainer.add_callback(early_stop_cb)
return trainer return trainer

1
tests/e2e/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
last_run_prepared

View File

@@ -0,0 +1,107 @@
"""
E2E tests for lora llama
"""
import logging
import os
import tempfile
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestLoraLlama(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
def test_lora(self):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
def test_lora_packing(self):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

109
tests/e2e/test_phi.py Normal file
View File

@@ -0,0 +1,109 @@
"""
E2E tests for lora llama
"""
import logging
import os
import tempfile
import unittest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestPhi(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
def test_ft(self):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model_config": "microsoft/phi-1_5",
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
"sample_packing": False,
"load_in_8bit": True,
"adapter": None,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<|endoftext|>",
"bos_token": "<|endoftext|>",
"eos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
def test_ft_packed(self):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "microsoft/phi-1_5",
"base_model_config": "microsoft/phi-1_5",
"trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
"sample_packing": True,
"load_in_8bit": True,
"adapter": None,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<|endoftext|>",
"bos_token": "<|endoftext|>",
"eos_token": "<|endoftext|>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"dataset_shard_num": 10,
"dataset_shard_idx": 0,
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(),
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

64
tests/test_data.py Normal file
View File

@@ -0,0 +1,64 @@
"""
test module for the axolotl.utis.data module
"""
import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5
class TestEncodePretraining(unittest.TestCase):
"""
test class for encode pretraining and md5 helper
"""
def setUp(self):
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"eos_token": "</s>",
"bos_token": "<s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
}
)
self.max_tokens = 15 # set a small number for easy inspection
def test_encode_pretraining(self):
examples = {
"text": [
"Hello, world!",
"Nice to meet you.",
"lorem ipsum dolor sit amet.",
"Nice to meet you again!.",
"hello, hello",
]
}
result = encode_pretraining(self.tokenizer, self.max_tokens, examples["text"])
self.assertEqual(len(result["input_ids"]), 3)
# Assert the length of input_ids and attention_mask is correct
self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
# Assert EOS and PAD tokens are correctly added
# hello world! is 4 tokens
self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
# second part, 5 tokens
self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
def test_md5(self):
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
self.assertEqual(
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
)
if __name__ == "__main__":
unittest.main()

View File

@@ -328,6 +328,20 @@ class ValidationTest(unittest.TestCase):
for record in self._caplog.records for record in self._caplog.records
) )
cfg = DictDefault(
{
"sample_packing": True,
"pad_to_sequence_len": None,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"`pad_to_sequence_len: true` is recommended when using sample_packing"
in record.message
for record in self._caplog.records
)
cfg = DictDefault( cfg = DictDefault(
{ {
"max_packed_sequence_len": 2048, "max_packed_sequence_len": 2048,