Compare commits

...

63 Commits

Author SHA1 Message Date
mhenrichsen
9084879861 tinyllama 2023-11-16 13:36:01 +00:00
NanoCode012
3cc67d2cdd Feat: Add dataset loading from S3, GCS (#765)
* Feat: Add dataset loading from S3, GCS

* chore: update docs

* chore: add more info on cloud loading
2023-11-16 14:33:58 +09:00
Wing Lian
1bc11868eb allow overriding of model_config parameters from the YML (#853)
* allow overriding of model_config parameters from the YML

* remove old logging, update readme

* move the updating of model config to the load_model_config function

* add warning for deprecated rope_scaling in the root of the YML config
2023-11-15 23:47:08 -05:00
Wing Lian
b3a61e8ce2 add e2e tests for checking functionality of resume from checkpoint (#865)
* use tensorboard to see if resume from checkpoint works

* make sure e2e test is either fp16 or bf16

* set max_steps and save limit so we have the checkpoint when testing resuming

* fix test parameters
2023-11-15 23:05:55 -05:00
Wing Lian
8a8d1c4023 make docker command more robust (#861)
* make docker command more robust

* update readme with more info
2023-11-15 23:03:54 -05:00
Wing Lian
332984db18 lint fix that didn't get caught by linter (#866) 2023-11-15 14:36:40 -05:00
MilesQLi
48630f5b34 Update data.py for signature generation (#851)
* Update data.py

Change of conversation formatting type should also trigger updating the preprocessed dataset, so it should be part of the signature.

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-11-15 14:12:32 -05:00
Zongheng Yang
b33c1d55a2 Docs: add instructions to 1-click launching on public clouds (#862)
* Update README.md

* Update ToC
2023-11-15 14:11:27 -05:00
Wing Lian
0c2a630326 multipack len should use max, not min (#863) 2023-11-15 12:52:32 -05:00
Wing Lian
db8a8afcba adds llama and mistral dropout support (#858)
* adds llama and mistral dropout support

* gracefully handle attention dropout if not available yet
2023-11-15 12:28:50 -05:00
Wing Lian
14706504e3 various bugfixes (#856)
* various bugfixes

use latest tinyllama release
check if val_set_size is empty first
update sdp and xformers llama patches for updated upstream transformers
fix system prompt when no input
calculate total and total supervised tokens even when not sample packing

* add fix for when eval size is estimated to be too small

* should be len 1 for dataset length

* add catchall kwargs
2023-11-15 12:23:18 -05:00
NanoCode012
501b4d1379 chore(doc): Separate section on runpod (#860) 2023-11-16 01:06:51 +09:00
NanoCode012
306fe19c54 feat(doc): add more info on train_on_split (#855) 2023-11-15 23:42:26 +09:00
Fabian Preiß
614cff4107 include the suffix modified string in ascii art (#852) 2023-11-15 07:12:28 -05:00
Wing Lian
1a6309c8a6 cleanup the old multipack dataloader (#841) 2023-11-12 05:39:09 -05:00
Bryan Thornbury
105d0b350b Pin optimum package (#838) 2023-11-09 22:36:15 -05:00
Wing Lian
f544ab2bed don't compile deepspeed or bitsandbytes from source (#837) 2023-11-08 19:49:55 -05:00
Wing Lian
641e6f7e51 multipack w batch sampler (#795)
* test batch sampler w varying batch lens

* wip

* multipack batchsampler wip

* wip

* fix for prepare data loader to get correct # of steps based on gpues

* lint and clean up

* calculate len estimate

* fix total num steps calc

* add options for dataloader_num_workers and dataloader_pin_memory

* remove gitbook

* support prefetch_factor for dataloader optimization

* fix the kwarg
2023-11-07 20:27:40 -05:00
Wing Lian
6dc68a653f use temp_dir kwarg instead 2023-11-06 18:33:01 -05:00
Wing Lian
7de6a5639c missing dunder-init 2023-11-06 18:33:01 -05:00
Wing Lian
c74f045ba7 chore: lint 2023-11-06 18:33:01 -05:00
Wing Lian
0402d19759 make sure to cleanup tmp output_dir for e2e tests 2023-11-06 18:33:01 -05:00
Wing Lian
b2430ce670 use accelerate logging for zero/main loggin only 2023-11-06 18:32:26 -05:00
Wing Lian
4c834bf25d cleanup verbosity a bit 2023-11-06 18:32:26 -05:00
Fabian Preiß
8056ecd30e add deepspeed-kernels dependency for deepspeed>=0.12.0 (#827) 2023-11-05 07:52:56 -05:00
Jason Stillerman
738a057674 Feat: Added Gradio support (#812)
* Added gradio support

* queuing and title

* pre-commit run
2023-11-04 23:59:22 -04:00
Wing Lian
cdc71f73c8 update table for rwkv4 support, fix process count for dataset (#822) 2023-11-04 23:45:44 -04:00
NanoCode012
6459ac7357 fix: pin autogptq (#818) 2023-11-03 10:14:55 -04:00
Wing Lian
964d858da0 fix model parallel (#816) 2023-11-02 21:34:22 -04:00
NanoCode012
10388a8daf fix(tokenizer): update log order after update (#806) 2023-10-31 13:21:20 +09:00
NanoCode012
9f7e8a971d feat(doc): add dummyoptim faq fix (#802) 2023-10-29 23:06:06 +09:00
NanoCode012
637ed095a0 fix(config): Set eos/bos to tokenizer if different (#801)
* fix(config): Set eos/bos to tokenizer if different

* chore: fix lint
2023-10-29 21:32:37 +09:00
Wing Lian
827ec3d274 refactor neft patch to be more re-usable similar to trl's impl (#796) 2023-10-29 04:33:13 -04:00
Wing Lian
8b79ff0e94 fix eval_steps to be a sane default (#797)
* fix eval_steps to be a sane default

* update docs for fractional eval_steps
2023-10-27 22:36:30 -04:00
MilesQLi
0800885e2f Update to adapt to sharegpt datasets with "assistant" rather than "gp… (#774)
* Update to adapt to sharegpt datasets with "assistant" rather than "gpt" as the machine answers.

* use a strict option for hanedling incorrect turn data

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-27 22:00:16 -04:00
Teknium
d3193beac3 Fix Deepspeed Zero3 Config (#791)
* Update zero3.json

Take away CPU Offload by default (Slows things down horribly, better off reducing batchsize), and changes LR Scheduler to a properly decaying one

* Update zero3.json

fix something
2023-10-27 21:57:02 -04:00
Aleksa Gordić
2e71ff03a6 Add docker advanced instruction to README (#792) 2023-10-27 09:24:04 -04:00
chanvichetvong
facc49f32b GitBook: No commit message 2023-10-26 15:11:00 +00:00
Casper
e50ab072e2 Create preprocess CLI (#785)
* Create preprocess CLI

* Print prompt template if debugging

* Add print for unsupported prompters

* Formatting

* Formatting

* Refactor variables

* Formatting

* Formatting

* Formatting

* Formatting
2023-10-26 09:35:42 -04:00
Casper
05bd6f1122 Threaded MultipackDistributedDataloader with prefetched samples (#759)
* Multithreading implementation [WIP]

* Added benchmarking

* 35% increased throughput

* Memory pinning

* Start threads in init

* Correct print of samples

* Sleep if queue is full

* Remove pin_memory (worse)

* Simplify logic to one thread

* Remove benchmark

* Use deque for constant speed

* Formatting

* Formatting

* Formatting

* Formatting

* Rollback to use queue

* Fix multi-epoch training

* Add num epochs arg

* Start thread in __iter__

* Formatting

* Use is_alive correctly

* Simplify loading thread
2023-10-26 07:49:52 +02:00
NanoCode012
20aa4b57d2 chore(readme): Improve documentation on conversation field (#782)
* chore(readme): Improve documentation on conversation field

* fix: clarify where the option is
2023-10-24 12:52:32 +09:00
NanoCode012
11d1d607db chore: refactor truthy check and fix mypy (#780) 2023-10-24 12:28:40 +09:00
Wing Lian
6c81c61bc4 refactor setup trainer so we can add more hooks (#773)
* refactor setup trainer so we can add more hooks

* Remove stray comma
2023-10-23 17:38:41 -04:00
Wing Lian
9b43e7ea15 disable eval table w sample packing in examples (#778) 2023-10-23 09:18:44 -04:00
Wing Lian
2d8def68dc simplify by removing duplicate base_model_config (#772) 2023-10-23 01:42:38 -04:00
NanoCode012
44c9d0151a Fix: Warn when fullfinetune without adapter (#770) 2023-10-22 15:41:43 -04:00
Wing Lian
ca84cca2c0 convert exponential notation lr to floats (#771) 2023-10-22 15:37:03 -04:00
Casper
32eeeb5b64 Hotfix for not saving correctly (#762) 2023-10-22 13:22:32 -04:00
NanoCode012
afedc470bd Fix: Cannot tokenize with bf16 and on cpu (#766) 2023-10-23 01:32:26 +09:00
NanoCode012
9923b72649 Fix: eval table conflict with eval_sample_packing (#769) 2023-10-23 01:18:12 +09:00
Wing Lian
21cf09b608 remove lora fused packing test (#758) 2023-10-21 22:59:35 -04:00
Casper
15d3a654bf Implement fused modules (#747)
* MLP: Memory saving

* Remove RMSNorm restrictions

* Map packed weights to original

* FusedAttention module

* Simplify code

* Move fused modules

* Fix critical typo

* Split inplace

* Add FFT config

* Add validation of fused arguments

* Add fused arguments to config

* Update docs

* Fix validation logic

* Add fused modules to flash attn

* Only fuse during training

* Remove timing

* Formatting

* Formatting

* Formatting

* chore: lint

* chore: lint

* add e2e tests for fused llama

* no lora for tests

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-21 16:08:25 -04:00
Wing Lian
a21935f07a add to docs (#703) 2023-10-19 21:32:30 -04:00
NanoCode012
8966a6f566 chore: bump transformers to v4.34.1 to fix tokenizer issue (#745) 2023-10-19 20:18:22 -04:00
Motoki Wu
e4d1585c4e Fix DeepSpeed Zero 3 Saving (#709)
* Update train.py

* add zero3 check

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-19 19:18:24 -04:00
Wing Lian
70157ccb8f add a latest tag for regular axolotl image, cleanup extraneous print statement (#746) 2023-10-19 12:28:29 -04:00
seungduk.kim.2304
3a99495b05 improve: Enhance code readability of prompt_tokenizers.py (#707) 2023-10-19 08:12:17 -04:00
NanoCode012
440c3ab527 Fix(model): Linear detected and added to target module with rope linear (#738)
* Fix(model): Linear detected and added to target module with rope linear

* fix: exclude layer instead
2023-10-18 22:13:20 -04:00
Napuh
992d57f20a catch ConnectionError when checking dataset from HuggingFace (#743) 2023-10-18 22:11:54 -04:00
mhenrichsen
91a016f410 badge (#739)
* badge

* fixed text
2023-10-18 10:21:34 -04:00
Casper
a045db0214 Mistral: Sliding Window Attention with Flash Attention and Sample Packing (#732)
* Implement Mistral FA + SWA + Sample Packing

* Handle unbroadcastable tensor

* chore: lint

* Simplify _prepare_decoder_attention_mask

* Uncomment window size

* Upgrade flash-attn to minimum of 2.3.0 to support SWA

* Add original condition to avoid error during inference

* chore: lint

* use torchscript to prevent oom

* chore: pylint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-16 15:13:46 -04:00
Casper
e1b214c62b Clarify custom format example (#729)
* Clarify custom prompt format

* Simplify format
2023-10-14 09:28:12 -04:00
Wing Lian
3553172e3c fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728) 2023-10-14 09:27:07 -04:00
85 changed files with 2930 additions and 1645 deletions

View File

@@ -23,6 +23,7 @@ jobs:
 python_version: "3.10"
 pytorch: 2.0.1
 axolotl_extras:
+is_latest: true
 - cuda: 118
 cuda_version: 11.8.0
 python_version: "3.10"
@@ -54,7 +55,9 @@ jobs:
 PYTORCH_VERSION=${{ matrix.pytorch }}
 file: ./docker/Dockerfile
 push: ${{ github.event_name != 'pull_request' }}
-tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+tags: |
+  ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+  ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
 labels: ${{ steps.metadata.outputs.labels }}
 build-axolotl-runpod:
 needs: build-axolotl

178
README.md
View File

@@ -25,14 +25,15 @@ Features:
 - [Installation](#installation)
 - [Docker](#docker)
 - [Conda/Pip venv](#condapip-venv)
+- [Runpod](#runpod)
 - [LambdaLabs](#lambdalabs)
 - [Windows](#windows)
+- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
 - [Dataset](#dataset)
 - [How to Add Custom Prompts](#how-to-add-custom-prompts)
 - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
 - [Config](#config)
 - [Train](#train)
-- [Training w/ Deepspeed](#training-with-deepspeed)
 - [Inference](#inference)
 - [Merge LORA to Base](#merge-lora-to-base)
 - [Common Errors](#common-errors-)
@@ -75,6 +76,7 @@ Features:
 | gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
 | XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
 | phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
+| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
 ## Quickstart ⚡
@@ -97,6 +99,10 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
 # inference
 accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
 --lora_model_dir="./lora-out"
+# gradio
+accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
+--lora_model_dir="./lora-out" --gradio
 ```
 ## Installation
@@ -107,7 +113,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
 ```bash
 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
 ```
-- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
 Or run on the current files for development:
@@ -115,6 +120,27 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
 docker compose up -d
 ```
+<details>
+<summary>Docker advanced</summary>
+A more powerful Docker command to run would be this:
+```bash
+docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
+```
+It additionally:
+* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
+* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
+* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
+* The `--privileged` flag gives all capabilities to the container.
+* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
+[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
+</details>
 #### Conda/Pip venv
 1. Install python >=**3.9**
@@ -131,6 +157,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
 ```
 Get the token at huggingface.co/settings/tokens
+#### Runpod
+Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
 #### LambdaLabs
 <details>
@@ -178,6 +208,28 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
 #### Windows
 Please use WSL or Docker!
+#### Launching on public clouds via SkyPilot
+To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
+```bash
+pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
+sky check
+```
+Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
+```
+git clone https://github.com/skypilot-org/skypilot.git
+cd skypilot/llm/axolotl
+```
+Use one command to launch:
+```bash
+# On-demand
+HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
+# Managed spot (auto-recovery on preemption)
+HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
+```
 ### Dataset
 Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
@@ -297,25 +349,24 @@ Have dataset(s) in one of the following format (JSONL recommended):
 #### How to add custom prompts
-Using yaml. Example:
+For a dataset that is preprocessed for instruction purposes:
+```json
+{"instruction": "...", "output": "..."}
+```
+You can use this example in your YAML config:
 ```yaml
 datasets:
 - path: repo
 type:
 system_prompt: ""
-no_input_format: |-
-User: {instruction}<|end_of_turn|>
-Assistant:
-format: |-
-User: {instruction}
-{input}<|end_of_turn|>
-Assistant:
+field_system: system
+format: "[INST] {instruction} [/INST]"
+no_input_format: "[INST] {instruction} [/INST]"
 ```
-Using file:
-1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
-2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
 #### How to use your custom pretokenized dataset
 - Do not pass a `type:`
@@ -357,6 +408,13 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 - typescript
 type: ... # unimplemented custom format
+# fastchat conversation
+# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+datasets:
+- path: ...
+type: sharegpt
+conversation: chatml
 # local
 datasets:
 - path: data.jsonl # or json
@@ -368,6 +426,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 - path: knowrohit07/know_sql
 type: context_qa.load_v2
 train_on_split: validation
+# loading from s3 or gcs
+# s3 creds will be loaded from the system default and gcs only supports public access
+dataset:
+- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
+...
 ```
 - loading
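As a minimal illustrative sketch of the cloud loading documented above, assuming the usual `datasets:` list form; the bucket path and prompt type below are placeholders, not part of the change itself:

```yaml
datasets:
  - path: s3://my-bucket/prepared-alpaca  # hypothetical bucket; per the docs above, gs:// paths are handled the same way
    type: alpaca
```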
@@ -395,7 +459,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
 <details>
-<summary>All yaml options</summary>
+<summary>All yaml options (click me)</summary>
 ```yaml
 # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
@@ -431,6 +495,14 @@ is_llama_derived_model:
 # Please note that if you set this to true, `padding_side` will be set to "left" by default
 is_mistral_derived_model:
+# optional overrides to the base model configuration
+model_config:
+# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
+rope_scaling:
+type: # linear | dynamic
+factor: # float
 # Whether you are training a 4-bit GPTQ quantized model
 gptq: true
 gptq_groupsize: 128 # group size
@@ -454,7 +526,7 @@ float16: true
 # A list of one or more datasets to finetune the model with
 datasets:
-# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
+# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
 - path: vicgalle/alpaca-gpt4
 # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
 type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
@@ -462,7 +534,10 @@ datasets:
 data_files: # Optional[str] path to source data files
 shards: # Optional[int] number of shards to split data into
 name: # Optional[str] name of dataset configuration to load
-conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
+train_on_split: train # Optional[str] name of dataset split to load from
+# Optional[str] fastchat conversation type, only used with type: sharegpt
+conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 # Custom user prompt
 - path: repo
@@ -592,14 +667,14 @@ gradient_accumulation_steps: 1
 # The number of samples to include in each batch. This is the number of samples sent to each GPU.
 micro_batch_size: 2
 eval_batch_size:
-num_epochs: 3
+num_epochs: 4
 warmup_steps: 100
 learning_rate: 0.00003
 lr_quadratic_warmup:
 logging_steps:
 save_strategy: # Set to `no` to skip checkpoint saves
 save_steps: # Leave empty to save at each epoch
-eval_steps: # Leave empty to eval at each epoch
+eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
 save_total_limit: # Checkpoints saved at a time
 # Maximum number of iterations to train for. It precedes num_epochs which means that
 # if both are set, num_epochs will not be guaranteed.
@@ -685,6 +760,8 @@ xformers_attention:
 flash_attention:
 flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
 flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
+flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
+flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
 # Whether to use scaled-dot-product attention
 # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 sdp_attention:
@@ -693,10 +770,6 @@ landmark_attention:
 # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
 # LLaMA only
 xpos_rope:
-# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
-rope_scaling:
-type: # linear | dynamic
-factor: # float
 # Resume from a specific checkpoint dir
 resume_from_checkpoint:
@@ -814,14 +887,41 @@ Run
 accelerate launch -m axolotl.cli.train your_config.yml
 ```
-#### Multi-GPU
+#### Preprocess dataset
-You can optionally pre-tokenize dataset with the following before finetuning:
+You can optionally pre-tokenize dataset with the following before finetuning.
+This is recommended for large datasets.
+- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
+- Use `--debug` to see preprocessed examples.
 ```bash
-CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
+python -m axolotl.cli.preprocess your_config.yml
 ```
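A minimal sketch combining the two preprocessing options called out above (the config filename is a placeholder):

```bash
# pre-tokenize the dataset and print a few rendered examples while doing so
python -m axolotl.cli.preprocess your_config.yml --debug
```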
-##### Config
+#### Multi-GPU
+Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
+is the recommended multi-GPU option currently because FSDP may experience
+[loss instability](https://github.com/huggingface/transformers/issues/26498).
+##### DeepSpeed
+Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
+might typically be able to fit into your GPU's VRAM. More information about the various optimization types
+for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
+We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
+```yaml
+deepspeed: deepspeed/zero1.json
+```
+```shell
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
+```
+##### FSDP
 - llama FSDP
 ```yaml
@@ -846,24 +946,6 @@ wandb_run_id:
 wandb_log_model:
 ```
-### Training with Deepspeed
-Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
-might typically be able to fit into your GPU's VRAM. More information about the various optimization types
-for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
-We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
-```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
-```
-or
-```yaml
-deepspeed: deepspeed/zero1.json
-```
 ### Inference
 Pass the appropriate flag to the train command:
@@ -881,6 +963,10 @@ Pass the appropriate flag to the train command:
 cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
 --base_model="./completed-model" --prompter=None --load_in_8bit=True
 ```
+-- With gradio hosting
+```bash
+python -m axolotl.cli.inference examples/your_config.yml --gradio
+```
 Please use `--sample_packing False` if you have it on and receive the error similar to below:
@@ -902,6 +988,8 @@ CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
 ## Common Errors 🧰
+See also the [FAQ's](./docs/faq.md).
 > If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
 Please reduce any below

View File

@@ -1,14 +1,6 @@
 {
 "zero_optimization": {
 "stage": 3,
-"offload_optimizer": {
-"device": "cpu",
-"pin_memory": true
-},
-"offload_param": {
-"device": "cpu",
-"pin_memory": true
-},
 "overlap_comm": true,
 "contiguous_gradients": true,
 "sub_group_size": 0,
@@ -41,12 +33,13 @@
 }
 },
 "scheduler": {
-"type": "WarmupLR",
+"type": "WarmupDecayLR",
 "params": {
 "warmup_min_lr": "auto",
 "warmup_max_lr": "auto",
 "warmup_num_steps": "auto",
-"warmup_type": "linear"
+"warmup_type": "linear",
+"total_num_steps": "auto"
 }
 },
 "gradient_accumulation_steps": "auto",

View File

@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
+pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
 else \
-pip install -e .[flash-attn]; \
+pip install -e .[deepspeed,flash-attn]; \
 fi
 # fix so that git fetch/pull from remote works

View File

@@ -10,8 +10,10 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
 ARG PYTHON_VERSION="3.9"
 ARG PYTORCH_VERSION="2.0.1"
 ARG CUDA="118"
+ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
 ENV PYTHON_VERSION=$PYTHON_VERSION
+ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
 RUN apt-get update \
 && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
@@ -27,47 +29,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace
 RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
-python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
+python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
-FROM base-builder AS deepspeed-builder
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
-WORKDIR /workspace
-RUN git clone https://github.com/microsoft/DeepSpeed.git && \
-cd DeepSpeed && \
-MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
-FROM base-builder AS bnb-builder
-WORKDIR /workspace
-ARG CUDA="118"
-ENV CUDA=$CUDA
-ARG MAX_JOBS="-1"
-ENV MAX_JOBS=$MAX_JOBS
-RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
-cd bitsandbytes && \
-CUDA_VERSION=$CUDA make cuda11x && \
-python setup.py bdist_wheel
-FROM base-builder
-ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
-ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
-RUN mkdir -p /workspace/builds
-COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
-RUN mkdir -p /workspace/wheels/bitsandbytes
-COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
-COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
-COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
-RUN pip3 install wheels/deepspeed-*.whl
-RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
-RUN git lfs install --skip-repo
-RUN pip3 install awscli && \
+RUN git lfs install --skip-repo && \
+pip3 install awscli && \
 # The base image ships with `pydantic==1.8.2` which is not working
 pip3 install -U --no-cache-dir pydantic==1.10.10

18
docs/faq.md Normal file
View File

@@ -0,0 +1,18 @@
# Axolotl FAQ's
> The trainer stopped and hasn't progressed in several minutes.
Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md)
> Exitcode -9
This usually happens when you run out of system RAM.
> Exitcode -7 while using deepspeed
Try upgrading deepspeed w: `pip install -U deepspeed`
> AttributeError: 'DummyOptim' object has no attribute 'step'
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.

View File

@@ -1,5 +1,4 @@
 base_model: cerebras/btlm-3b-8k-base
-base_model_config: cerebras/btlm-3b-8k-base
 model_type: AutoModelForCausalLM
 tokenizer_type: GPT2Tokenizer
 trust_remote_code: true
@@ -15,7 +14,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path: last_prepared_run
-val_set_size: 0.01
+val_set_size: 0.05
 adapter:
 lora_model_dir:

View File

@@ -1,5 +1,4 @@
 base_model: cerebras/Cerebras-GPT-1.3B
-base_model_config: cerebras/Cerebras-GPT-1.3B
 load_in_8bit: false
 load_in_4bit: true
 strict: false
@@ -8,7 +7,7 @@ datasets:
 - path: teknium/GPT4-LLM-Cleaned
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: qlora
 lora_model_dir:
 sequence_len: 2048
@@ -50,7 +49,7 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 save_steps:
 debug:
 deepspeed:

View File

@@ -1,5 +1,4 @@
 base_model: codellama/CodeLlama-13b-hf
-base_model_config: codellama/CodeLlama-13b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: CodeLlamaTokenizer
 is_llama_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./lora-out
 sequence_len: 4096
@@ -35,7 +34,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 3
+num_epochs: 4
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -55,7 +54,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 save_steps:
 debug:
 deepspeed:

View File

@@ -1,5 +1,4 @@
 base_model: codellama/CodeLlama-13b-hf
-base_model_config: codellama/CodeLlama-13b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: CodeLlamaTokenizer
 is_llama_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out
 adapter: qlora
@@ -37,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 3
+num_epochs: 4
 optimizer: paged_adamw_32bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -57,7 +56,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 save_steps:
 debug:
 deepspeed:

View File

@@ -1,5 +1,4 @@
 base_model: codellama/CodeLlama-34b-hf
-base_model_config: codellama/CodeLlama-34b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: CodeLlamaTokenizer
 is_llama_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./lora-out
 sequence_len: 4096
@@ -35,7 +34,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 3
+num_epochs: 4
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -55,7 +54,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 save_steps:
 debug:
 deepspeed:

View File

@@ -1,5 +1,4 @@
 base_model: codellama/CodeLlama-34b-hf
-base_model_config: codellama/CodeLlama-34b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: CodeLlamaTokenizer
 is_llama_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out
 adapter: qlora
@@ -37,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 3
+num_epochs: 4
 optimizer: paged_adamw_32bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -57,7 +56,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 save_steps:
 debug:
 deepspeed:

View File

@@ -1,5 +1,4 @@
 base_model: codellama/CodeLlama-7b-hf
-base_model_config: codellama/CodeLlama-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: CodeLlamaTokenizer
 is_llama_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./lora-out
 sequence_len: 4096
@@ -35,7 +34,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 3
+num_epochs: 4
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -55,7 +54,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 save_steps:
 debug:
 deepspeed:

View File

@@ -1,5 +1,4 @@
 base_model: codellama/CodeLlama-7b-hf
-base_model_config: codellama/CodeLlama-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: CodeLlamaTokenizer
 is_llama_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out
 adapter: qlora
@@ -37,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 3
+num_epochs: 4
 optimizer: paged_adamw_32bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -57,7 +56,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 save_steps:
 debug:
 deepspeed:

View File

@@ -1,5 +1,4 @@
 base_model: tiiuae/falcon-7b
-base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
@@ -13,7 +12,7 @@ datasets:
 - path: teknium/GPT4-LLM-Cleaned
 type: alpaca:chat
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: lora
 lora_model_dir:
 sequence_len: 2048

View File

@@ -1,7 +1,6 @@
 # 1b: tiiuae/falcon-rw-1b
 # 40b: tiiuae/falcon-40b
 base_model: tiiuae/falcon-7b
-base_model_config: tiiuae/falcon-7b
 # required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
 trust_remote_code: true
 model_type: AutoModelForCausalLM
@@ -19,7 +18,7 @@ datasets:
 - Chain-of-Thought/formatted_cot_data/gsm8k_train.json
 type: "alpaca:chat"
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 # enable QLoRA
 adapter: qlora
 lora_model_dir:
@@ -54,7 +53,7 @@ output_dir: ./qlora-out
 # decrease if OOM, increase for max VRAM utilization
 micro_batch_size: 1
 gradient_accumulation_steps: 2
-num_epochs: 3
+num_epochs: 4
 # Optimizer for QLoRA
 optimizer: paged_adamw_32bit
 torchdistx_path:

View File

@@ -1,5 +1,4 @@
 base_model: tiiuae/falcon-7b
-base_model_config: tiiuae/falcon-7b
 trust_remote_code: true
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
@@ -13,7 +12,7 @@ datasets:
 - path: teknium/GPT4-LLM-Cleaned
 type: alpaca:chat
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter:
 lora_model_dir:
 sequence_len: 2048

View File

@@ -1,5 +1,4 @@
 base_model: EleutherAI/gpt-j-6b
-base_model_config: EleutherAI/gpt-j-6b
 load_in_8bit: false
 load_in_4bit: true
 strict: false
@@ -8,7 +7,7 @@ datasets:
 - path: teknium/GPT4-LLM-Cleaned
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: qlora
 lora_model_dir:
 sequence_len: 2048
@@ -47,7 +46,7 @@ flash_attention:
 gptq_groupsize:
 gptq_model_v1:
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 save_steps:
 debug:
 deepspeed:

View File

@@ -1,5 +1,4 @@
 base_model: huggyllama/llama-7b
-base_model_config: huggyllama/llama-7b
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: false
@@ -25,7 +24,7 @@ wandb_log_model:
 output_dir: ./jeopardy-bot-7b
 gradient_accumulation_steps: 1
 micro_batch_size: 1
-num_epochs: 3
+num_epochs: 4
 optimizer: adamw_bnb_8bit
 torchdistx_path:
 lr_scheduler: cosine

View File

@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
 micro_batch_size: 1
 ```shell
-accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
+accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
 ```
 or
 ```shell
-accelerate launch scripts/finetune.py examples/llama-2/lora.yml
+accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
 ```
+To launch a full finetuning with 16-bit precision:
+```shell
+accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
+```

View File

@@ -0,0 +1,72 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
eval_steps: 0.05
eval_table_size:
save_steps:
debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

View File

@@ -1,5 +1,4 @@
 base_model: TheBloke/Llama-2-7B-GPTQ
-base_model_config: TheBloke/Llama-2-7B-GPTQ
 is_llama_derived_model: false
 gptq: true
 gptq_disable_exllama: true
@@ -16,7 +15,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: lora
 lora_model_dir:
 sequence_len: 4096
@@ -38,7 +37,7 @@ wandb_log_model:
 output_dir: ./model-out
 gradient_accumulation_steps: 1
 micro_batch_size: 1
-num_epochs: 3
+num_epochs: 4
 optimizer: adamw_torch
 adam_beta2: 0.95
 adam_eps: 0.00001

View File

@@ -1,5 +1,4 @@
 base_model: NousResearch/Llama-2-7b-hf
-base_model_config: NousResearch/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 is_llama_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./lora-out
 sequence_len: 4096
@@ -35,7 +34,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 3
+num_epochs: 4
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -55,7 +54,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 eval_table_size:
 eval_table_max_new_tokens: 128
 save_steps:

View File

@@ -1,5 +1,4 @@
 base_model: NousResearch/Llama-2-7b-hf
-base_model_config: NousResearch/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 is_llama_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out
 adapter: qlora
@@ -37,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 3
+num_epochs: 4
 optimizer: paged_adamw_32bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -57,7 +56,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 eval_table_size:
 save_steps:
 debug:

View File

@@ -1,5 +1,4 @@
 base_model: NousResearch/Llama-2-7b-hf
-base_model_config: NousResearch/Llama-2-7b-hf
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 is_llama_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: teknium/GPT4-LLM-Cleaned
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./relora-out
 adapter: qlora
@@ -41,7 +40,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 4
-num_epochs: 3
+num_epochs: 4
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
@@ -61,7 +60,7 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
 save_steps: 50
 debug:
 deepspeed:

View File

@@ -1,23 +1,23 @@
-base_model: PY007/TinyLlama-1.1B-step-50K-105b
+base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
-base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 is_llama_derived_model: true
-load_in_8bit: true
+load_in_8bit: false
 load_in_4bit: false
 strict: false
 datasets:
-- path: mhenrichsen/alpaca_2k_test
+- path: mhenrichsen/context-aware-splits-english
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 200
-output_dir: ./lora-out
+output_dir: ./tiny-llama
-sequence_len: 4096
+sequence_len: 8192
 sample_packing: true
+pad_to_sequence_len: true
 adapter: lora
 lora_model_dir:
@@ -33,8 +33,8 @@ wandb_watch:
 wandb_run_id:
 wandb_log_model:
-gradient_accumulation_steps: 4
+gradient_accumulation_steps: 1
-micro_batch_size: 2
+micro_batch_size: 8
 num_epochs: 3
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
@@ -54,13 +54,13 @@ logging_steps: 1
 xformers_attention:
 flash_attention: true
-warmup_steps: 10
+warmup_steps: 50
-eval_steps: 20
+eval_steps: 0.05
 eval_table_size:
-save_steps:
+save_steps: 0.50
 debug:
 deepspeed:
-weight_decay: 0.0
+weight_decay: 0.1
 fsdp:
 fsdp_config:
 special_tokens:

View File

@@ -1,5 +1,4 @@
 base_model: mistralai/Mistral-7B-v0.1
-base_model_config: mistralai/Mistral-7B-v0.1
 model_type: MistralForCausalLM
 tokenizer_type: LlamaTokenizer
 is_mistral_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./out
 sequence_len: 8192
@@ -27,7 +26,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 2
-num_epochs: 3
+num_epochs: 4
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.000005
@@ -47,8 +46,8 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
-eval_table_size: 5
+eval_table_size:
 eval_table_max_new_tokens: 128
 save_steps:
 debug:

View File

@@ -1,5 +1,4 @@
 base_model: mistralai/Mistral-7B-v0.1
-base_model_config: mistralai/Mistral-7B-v0.1
 model_type: MistralForCausalLM
 tokenizer_type: LlamaTokenizer
 is_mistral_derived_model: true
@@ -12,7 +11,7 @@ datasets:
 - path: mhenrichsen/alpaca_2k_test
 type: alpaca
 dataset_prepared_path: last_run_prepared
-val_set_size: 0.01
+val_set_size: 0.05
 output_dir: ./qlora-out
 adapter: qlora
@@ -64,8 +63,8 @@ xformers_attention:
 flash_attention: true
 warmup_steps: 10
-eval_steps: 20
+eval_steps: 0.05
-eval_table_size: 5
+eval_table_size:
 eval_table_max_new_tokens: 128
 save_steps:
 debug:

View File

@@ -1,5 +1,4 @@
 base_model: mosaicml/mpt-7b
-base_model_config: mosaicml/mpt-7b
 tokenizer_type: AutoTokenizer
 trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
 load_in_8bit: false
@@ -27,7 +26,7 @@ wandb_log_model:
 output_dir: ./mpt-alpaca-7b
 gradient_accumulation_steps: 1
 micro_batch_size: 1
-num_epochs: 3
+num_epochs: 4
 optimizer: adamw_bnb_8bit
 torchdistx_path:
 lr_scheduler: cosine

View File

@@ -1,5 +1,4 @@
 base_model: openlm-research/open_llama_3b_v2
-base_model_config: openlm-research/open_llama_3b_v2
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: false

View File

@@ -1,5 +1,4 @@
 base_model: openlm-research/open_llama_3b_v2
-base_model_config: openlm-research/open_llama_3b_v2
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: true

View File

@@ -1,5 +1,4 @@
 base_model: openlm-research/open_llama_3b_v2
-base_model_config: openlm-research/open_llama_3b_v2
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: false
@@ -10,7 +9,7 @@ datasets:
 - path: teknium/GPT4-LLM-Cleaned
 type: alpaca
 dataset_prepared_path:
-val_set_size: 0.01
+val_set_size: 0.05
 adapter: qlora
 lora_model_dir:
 sequence_len: 1024

View File

@@ -1,5 +1,4 @@
base_model: microsoft/phi-1_5 base_model: microsoft/phi-1_5
base_model_config: microsoft/phi-1_5
model_type: MixFormerSequentialForCausalLM model_type: MixFormerSequentialForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_llama_derived_model: false is_llama_derived_model: false

View File

@@ -1,5 +1,4 @@
base_model: microsoft/phi-1_5 base_model: microsoft/phi-1_5
base_model_config: microsoft/phi-1_5
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_llama_derived_model: false is_llama_derived_model: false

View File

@@ -1,5 +1,4 @@
base_model: EleutherAI/pythia-12b-deduped base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer

View File

@@ -1,5 +1,4 @@
base_model: EleutherAI/pythia-1.4b-deduped base_model: EleutherAI/pythia-1.4b-deduped
base_model_config: EleutherAI/pythia-1.4b-deduped
load_in_8bit: true load_in_8bit: true
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
@@ -24,7 +23,7 @@ wandb_log_model:
output_dir: ./lora-alpaca-pythia output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 4 micro_batch_size: 4
num_epochs: 3 num_epochs: 4
learning_rate: 0.00001 learning_rate: 0.00001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
@@ -34,5 +33,5 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
weight_decay: 0.1 weight_decay: 0.1
eval_steps: 20 eval_steps: 0.05
logging_steps: 1 logging_steps: 1

View File

@@ -1,5 +1,4 @@
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1 base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
model_type: GPTNeoXForCausalLM model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
trust_remote_code: trust_remote_code:
@@ -28,7 +27,7 @@ wandb_log_model:
output_dir: ./redpajama-alpaca-3b output_dir: ./redpajama-alpaca-3b
batch_size: 4 batch_size: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 3 num_epochs: 4
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine

View File

@@ -1,5 +1,4 @@
base_model: replit/replit-code-v1-3b base_model: replit/replit-code-v1-3b
base_model_config: replit/replit-code-v1-3b
trust_remote_code: true trust_remote_code: true
load_in_8bit: false load_in_8bit: false
datasets: datasets:
@@ -27,7 +26,7 @@ wandb_log_model:
output_dir: ./lora-replit output_dir: ./lora-replit
batch_size: 8 batch_size: 8
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 3 num_epochs: 4
optimizer: optimizer:
torchdistx_path: torchdistx_path:
lr_scheduler: lr_scheduler:

View File

@@ -1,7 +1,6 @@
# An example finetuning Salesforce's XGen-7b model with 8k context using qlora # An example finetuning Salesforce's XGen-7b model with 8k context using qlora
# on Tim Dettmers' Guanaco dataset. # on Tim Dettmers' Guanaco dataset.
base_model: Salesforce/xgen-7b-8k-base base_model: Salesforce/xgen-7b-8k-base
base_model_config: Salesforce/xgen-7b-8k-base
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
@@ -17,7 +16,7 @@ datasets:
- openassistant_best_replies_train.jsonl - openassistant_best_replies_train.jsonl
type: "completion" type: "completion"
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
@@ -52,7 +51,7 @@ output_dir: ./qlora-out
# decrease if OOM, increase for max VRAM utilization # decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1 micro_batch_size: 1
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
num_epochs: 3 num_epochs: 4
# Optimizer for QLoRA # Optimizer for QLoRA
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
torchdistx_path: torchdistx_path:

BIN image/sticker_fixed.png (new binary file, 370 KiB; content not shown)

View File

@@ -1,23 +1,23 @@
--extra-index-url https://download.pytorch.org/whl/cu118 --extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
torch==2.0.1 torch==2.0.1
auto-gptq auto-gptq==0.4.2
packaging packaging
peft @ git+https://github.com/huggingface/peft.git peft==0.6.0
transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348 transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9 accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
deepspeed deepspeed
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
datasets datasets>=2.14.0
flash-attn>=2.3.0 flash-attn>=2.3.0
sentencepiece sentencepiece
wandb wandb
einops einops
xformers>=0.0.22 xformers>=0.0.22
optimum optimum==1.13.2
hf_transfer hf_transfer
colorama colorama
numba numba
@@ -31,3 +31,10 @@ scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.29 fschat==0.2.29
gradio
tensorboard
# remote filesystems
s3fs
gcsfs
# adlfs
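The new s3fs and gcsfs pins back the remote-filesystem dataset loading added in this compare. Below is a hedged sketch of how a JSONL file on S3 could be pulled in through the fsspec-backed datasets library; the bucket path and credential values are placeholders, not anything from this diff, and the storage_options plumbing is an assumption about the datasets API rather than something shown here.

```python
# Hedged sketch: loading a JSONL dataset from S3 via fsspec (s3fs) with the `datasets` library.
# The bucket/key and credentials below are hypothetical examples.
from datasets import load_dataset

storage_options = {
    "key": "YOUR_AWS_ACCESS_KEY_ID",      # or rely on the default credential chain
    "secret": "YOUR_AWS_SECRET_ACCESS_KEY",
}

ds = load_dataset(
    "json",
    data_files="s3://my-bucket/instructions.jsonl",  # placeholder path
    storage_options=storage_options,
    split="train",
)
print(ds[0])
```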

View File

@@ -45,8 +45,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
shard(cfg=parsed_cfg, cli_args=parsed_cli_args) shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else: else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -46,7 +46,7 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn>=2.2.1", "flash-attn>=2.3.0",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed", "deepspeed",

View File

@@ -6,8 +6,10 @@ import os
import random import random
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import gradio as gr
import torch import torch
import yaml import yaml
@@ -16,7 +18,7 @@ from accelerate.commands.config import config_args
from art import text2art from art import text2art
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
@@ -44,7 +46,7 @@ def print_axolotl_text_art(suffix=None):
ascii_text = " axolotl" ascii_text = " axolotl"
if suffix: if suffix:
ascii_text += f" x {suffix}" ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font) ascii_art = text2art(ascii_text, font=font)
if is_main_process(): if is_main_process():
print(ascii_art) print(ascii_art)
@@ -153,6 +155,91 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(show_api=False, share=True)
def choose_config(path: Path): def choose_config(path: Path):
yaml_files = list(path.glob("*.yml")) yaml_files = list(path.glob("*.yml"))
@@ -222,7 +309,9 @@ def load_datasets(
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer) train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg, tokenizer
)
if cli_args.debug or cfg.debug: if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
@@ -238,6 +327,10 @@ def load_datasets(
text_only=cli_args.debug_text_only, text_only=cli_args.debug_text_only,
) )
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta( return TrainDatasetMeta(
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,
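The new do_inference_gradio wires model.generate to a Gradio textbox through TextIteratorStreamer, running generation on a background thread while the main thread yields partial text. The standalone sketch below shows the same streaming pattern in isolation; the tiny checkpoint name and prompt are placeholders for illustration only.

```python
# Hedged sketch of the TextIteratorStreamer + background-thread pattern used above.
# "sshleifer/tiny-gpt2" is just a small placeholder checkpoint.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("Hello, axolotl", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread while the main thread consumes the stream
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "max_new_tokens": 20, "streamer": streamer},
)
thread.start()

text = ""
for chunk in streamer:
    text += chunk
print(text)
thread.join()
```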

View File

@@ -6,11 +6,16 @@ from pathlib import Path
import fire import fire
import transformers import transformers
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art from axolotl.cli import (
do_inference,
do_inference_gradio,
load_cfg,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
) )
parsed_cli_args.inference = True parsed_cli_args.inference = True
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) if gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__": if __name__ == "__main__":
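With the new gradio flag, the inference entry point can launch the web UI instead of the terminal prompt. A hedged example of calling it programmatically follows; the config path is a placeholder, and via fire the same toggle is exposed on the command line as a --gradio flag.

```python
# Hedged sketch: launching the inference CLI with the Gradio UI from Python.
# The config path is a placeholder; any example YAML from the repo would do.
from pathlib import Path

from axolotl.cli.inference import do_cli

if __name__ == "__main__":
    do_cli(config=Path("examples/openllama-3b/lora.yml"), gradio=True)
```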

View File

@@ -0,0 +1,53 @@
"""
CLI to run training on a model
"""
import logging
from pathlib import Path
import fire
import transformers
from colorama import Fore
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, "
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+ Fore.RESET
)
if __name__ == "__main__":
fire.Fire(do_cli)
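The new preprocess command tokenizes and caches the dataset ahead of training, warning and falling back to DEFAULT_DATASET_PREPARED_PATH when dataset_prepared_path is unset. Below is a hedged programmatic invocation; the paths are placeholders, and passing dataset_prepared_path as a kwarg override is an assumption based on how the other CLI entry points feed kwargs into load_cfg.

```python
# Hedged sketch: running dataset preprocessing ahead of a training run.
# Paths are placeholders; `dataset_prepared_path` can also live in the YAML itself.
from pathlib import Path

from axolotl.cli.preprocess import do_cli

do_cli(
    config=Path("examples/openllama-3b/qlora.yml"),
    dataset_prepared_path="last_run_prepared",
)
```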

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import fire import fire
import transformers import transformers
from colorama import Fore
from axolotl.cli import ( from axolotl.cli import (
check_accelerate_default_config, check_accelerate_default_config,
@@ -16,7 +15,6 @@ from axolotl.cli import (
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.train import train from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train") LOG = logging.getLogger("axolotl.cli.train")
@@ -32,18 +30,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "--prepare_ds_only called without dataset_prepared_path set."
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -25,11 +25,22 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=5) debug_num_examples: int = field(default=5)
inference: bool = field(default=False) inference: bool = field(default=False)
merge_lora: bool = field(default=False) merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)
shard: bool = field(default=False) shard: bool = field(default=False)
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
def load_model_and_tokenizer( def load_model_and_tokenizer(
*, *,
cfg: DictDefault, cfg: DictDefault,

View File

View File

@@ -0,0 +1,748 @@
"""
Builder for the training args and trainer
"""
import abc
import importlib
import logging
import math
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional
import torch
import transformers
from datasets import Dataset
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import seed_worker
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass
LOG = logging.getLogger("axolotl.core.trainer_builder")
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing:
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
drop_last=True,
batch_max_len=self._train_batch_size * self.args.max_seq_length,
lengths=(
self.train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
self.args.per_device_eval_batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
lengths=(
eval_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing:
train_dataset = self.train_dataset
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> DataLoader:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the ReLoRA scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
"""
_train_dataset = None
_eval_dataset = None
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@abstractmethod
def build(self, total_num_steps):
pass
@abstractmethod
def get_callbacks(self):
pass
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
"""
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer
def get_callbacks(self):
callbacks = []
callbacks.append(GPUStatsCallback(self.cfg))
callbacks.append(EvalFirstStepCallback)
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
return callbacks
def _get_trainer_cls(self):
if self.cfg.lr_scheduler == "one_cycle" and (
self.cfg.fsdp or self.cfg.adapter == "qlora"
):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
return AxolotlTrainer
def build(self, total_num_steps):
warmup_steps = (
self.cfg.warmup_steps
if self.cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
logging_steps = (
self.cfg.logging_steps
if self.cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_arguments_kwargs = {}
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = self.cfg.bf16
training_arguments_kwargs["fp16"] = (
self.cfg.fp16 and not self.cfg.bf16
) or False
training_arguments_kwargs["tf32"] = self.cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if self.cfg.seed:
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_arguments_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[
"lr_quadratic_warmup"
] = self.cfg.lr_quadratic_warmup
if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
if self.cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
if self.cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
if self.cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
if self.cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
if self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
elif self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.evaluation_strategy:
training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if self.cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_arguments_kwargs["save_strategy"] = "epoch"
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
] = self.cfg.metric_for_best_model
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
if self.cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
elif torch._dynamo: # pylint: disable=protected-access
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
# DDP Config
if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs[
"ddp_broadcast_buffers"
] = self.cfg.ddp_broadcast_buffers
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs[
"eval_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
training_arguments_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
training_arguments_kwargs["load_best_model_at_end"] = (
(
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and self.cfg.val_set_size > 0
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
training_arguments_kwargs["ddp_find_unused_parameters"] = (
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_run_id if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine"
)
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
trainer_kwargs = {}
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
self.cfg.sequence_len / 64
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(self.model, self.tokenizer)
LOG.info("Adding landmark attention tokens to dataset")
for dataset in [self.train_dataset, self.eval_dataset]:
dataset = dataset.map(
partial(
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
),
batched=False,
num_proc=32,
)
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size
return trainer
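The builder wraps trainer construction behind a small interface: instantiate it with the config, model, and tokenizer, attach the datasets through the properties defined on TrainerBuilderBase, then call build(total_num_steps). A hedged sketch of that call pattern follows; cfg, model, tokenizer, the datasets, and total_num_steps are assumed to come from axolotl's existing load/prepare helpers rather than being defined here.

```python
# Hedged sketch of how HFCausalTrainerBuilder is intended to be driven.
# `cfg`, `model`, `tokenizer`, `train_ds`, `eval_ds`, and `total_num_steps`
# are assumed to come from axolotl's existing loaders.
from axolotl.core.trainer_builder import HFCausalTrainerBuilder

def build_trainer(cfg, model, tokenizer, train_ds, eval_ds, total_num_steps):
    builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
    builder.train_dataset = train_ds  # setters defined on TrainerBuilderBase
    builder.eval_dataset = eval_ds
    return builder.build(total_num_steps)

# trainer = build_trainer(cfg, model, tokenizer, train_ds, eval_ds, total_num_steps)
# trainer.train()
```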

View File

@@ -2,7 +2,7 @@
import logging import logging
import os import os
from typing import List from typing import List, Optional
import torch import torch
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset, dataset: IterableDataset,
process_count: Optional[int] = None,
**kwargs, **kwargs,
): ):
self.prompt_tokenizer = prompt_tokenizer self.prompt_tokenizer = prompt_tokenizer
self.process_count = process_count
super().__init__(self.process(dataset).data, **kwargs) super().__init__(self.process(dataset).data, **kwargs)
def process(self, dataset): def process(self, dataset):
features = dataset.features.keys() features = dataset.features.keys()
num_proc = min(64, os.cpu_count()) num_proc = (
min(64, self.process_count)
if self.process_count
else min(64, os.cpu_count())
)
map_kwargs = {} map_kwargs = {}
if self.prompt_tokenizer.supports_batched: if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True

View File

@@ -13,12 +13,18 @@ import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer, LlamaDecoderLayer as OriginalLlamaDecoderLayer,
) )
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from transformers.models.llama.modeling_llama import (
LlamaMLP,
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
try: try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
@@ -38,6 +44,28 @@ except ImportError:
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
module.config, module.gate_proj, module.up_proj, module.down_proj
)
set_module_name(model, name, mlp)
def replace_llama_qkv_with_fused(model):
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
qkv = FusedAttention(
module.config,
module.q_proj,
module.k_proj,
module.v_proj,
module.o_proj,
)
set_module_name(model, name, qkv)
def replace_llama_attn_with_flash_attn( def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False, packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False, cross_entropy: Optional[bool] = False,
@@ -86,6 +114,92 @@ def replace_llama_attn_with_flash_attn(
) )
class FusedAttention(LlamaAttention):
"""
Fused QKV Attention layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
q: torch.nn.Linear, # pylint: disable=invalid-name
k: torch.nn.Linear, # pylint: disable=invalid-name
v: torch.nn.Linear, # pylint: disable=invalid-name
o: torch.nn.Linear, # pylint: disable=invalid-name
):
super().__init__(config)
self.config = config
self.init_device = next(iter(q.state_dict().values())).device
# define equivalent fused qkv projection
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
self.qkv_proj = torch.nn.Linear(
q.in_features, sum(self.out_features), device=self.init_device, bias=False
)
self.o_proj = o
# overwrite initialized weights with pretrained weights
self.qkv_proj.weight.data = torch.cat(
(q.weight.data, k.weight.data, v.weight.data), dim=0
)
def _post_training(self, model, name):
q_proj, k_proj, v_proj = torch.split(
self.qkv_proj.weight.data, self.out_features, dim=0
)
new_attn = LlamaAttention(self.config)
new_attn.q_proj.weight.data = q_proj
new_attn.k_proj.weight.data = k_proj
new_attn.v_proj.weight.data = v_proj
new_attn.o_proj.weight.data = self.o_proj.weight.data
set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask( def _prepare_decoder_attention_mask(
@@ -147,9 +261,14 @@ def flashattn_forward(
value_states = torch.cat(value_states, dim=-1) value_states = torch.cat(value_states, dim=-1)
else: else:
query_states = self.q_proj(hidden_states) if isinstance(self, FusedAttention):
key_states = self.k_proj(hidden_states) query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
value_states = self.v_proj(hidden_states) self.out_features, dim=-1
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view( query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim bsz, q_len, self.num_heads, self.head_dim
@@ -202,6 +321,8 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack( qkv = torch.stack(
@@ -211,7 +332,12 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True qkv,
cu_seqlens,
max_seqlen,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
@@ -234,7 +360,7 @@ def flashattn_forward(
qkv_unpad, qkv_unpad,
cu_seqlens_q, cu_seqlens_q,
max_seqlen_q, max_seqlen_q,
0.0, dropout_p=dropout_rate,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
) )
@@ -247,6 +373,7 @@ def flashattn_forward(
output = flash_attn_kvpacked_func( output = flash_attn_kvpacked_func(
query_states, query_states,
torch.stack([key_states, value_states], 2), torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal, causal=is_causal,
) )
else: else:
@@ -279,7 +406,7 @@ def flashattn_forward(
cu_seqlens_k, cu_seqlens_k,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
0.0, dropout_p=dropout_rate,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
) )
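The fused modules above are meant to be swapped in before training and swapped back out afterwards so the saved checkpoint keeps the stock LlamaAttention/LlamaMLP layout. A hedged sketch of that flow follows; it assumes the patch module keeps its usual llama_attn_hijack_flash path and that the caller decides when to apply each replacement, neither of which is spelled out in this hunk.

```python
# Hedged sketch: swapping Llama MLP/attention for the fused variants during training,
# then restoring the original module layout before saving.
# `model` is assumed to be a loaded LlamaForCausalLM.
from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_mlp_with_swiglu,
    replace_llama_qkv_with_fused,
)

def apply_fused_patches(model):
    replace_llama_mlp_with_swiglu(model)  # LlamaMLP -> xformers SwiGLU (packed w12/w3)
    replace_llama_qkv_with_fused(model)   # q/k/v projections -> a single qkv_proj
    return model

def restore_original_modules(model):
    # Each fused module exposes _post_training, which splits its packed weights
    # back into fresh LlamaMLP / LlamaAttention modules in place.
    for name, module in list(model.named_modules()):
        if hasattr(module, "_post_training"):
            module._post_training(model, name)
    return model
```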

View File

@@ -25,6 +25,8 @@ def sdp_attention_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()

View File

@@ -29,6 +29,8 @@ def xformers_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()

View File

@@ -1,40 +0,0 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
import transformers.models.llama.modeling_llama
from transformers.utils import logging
logger = logging.get_logger(__name__)
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training:
embed_init = orig_embed(input_ids)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha / torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(
-mag_norm, mag_norm
)
return orig_embed(input_ids)
return new_func
def post_init(orig_post_init):
def new_func(self):
orig_post_init(self)
self.embed_tokens.forward = noised_embed(
self.embed_tokens.forward, noise_alpha, self
)
return new_func
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
transformers.models.llama.modeling_llama.LlamaModel.post_init
)

View File

@@ -14,6 +14,9 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
) )
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
) )
@@ -42,6 +45,44 @@ def replace_mistral_attn_with_flash_attn(
) )
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask( def _prepare_decoder_attention_mask(
@@ -53,11 +94,29 @@ def _prepare_decoder_attention_mask(
sliding_window, sliding_window,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
# [bsz, seq_len] # [bsz, seq_len]
if attention_mask is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
return attention_mask return attention_mask
def flashattn_forward( def flashattn_forward(
self, self: OriginalMistralAttention,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
@@ -91,10 +150,41 @@ def flashattn_forward(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
use_sliding_windows = (
hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
)
if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)
if past_key_value is not None: if past_key_value is not None:
# reuse k, v, self_attention # Activate slicing cache only if the config has a value `sliding_windows` attribute
key_states = torch.cat([past_key_value[0], key_states], dim=2) if (
value_states = torch.cat([past_key_value[1], value_states], dim=2) hasattr(self.config, "sliding_window")
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
past_key_value = (past_key, past_value) if use_cache else None
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None
@@ -111,6 +201,8 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack( qkv = torch.stack(
@@ -120,7 +212,13 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True qkv,
cu_seqlens,
max_seqlen,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
window_size=window_size,
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
@@ -143,9 +241,10 @@ def flashattn_forward(
qkv_unpad, qkv_unpad,
cu_seqlens_q, cu_seqlens_q,
max_seqlen_q, max_seqlen_q,
0.0, dropout_p=dropout_rate,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)
else: else:
@@ -156,7 +255,9 @@ def flashattn_forward(
output = flash_attn_kvpacked_func( output = flash_attn_kvpacked_func(
query_states, query_states,
torch.stack([key_states, value_states], 2), torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
else: else:
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
@@ -188,9 +289,10 @@ def flashattn_forward(
cu_seqlens_k, cu_seqlens_k,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
0.0, dropout_p=dropout_rate,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)
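The sliding-window mask is a lower-triangular causal mask additionally banded so each position only attends to the previous sliding_window tokens; entries outside the band become -inf via torch.log of zero. A tiny standalone check of the banding logic, using a window of 3 over 5 positions:

```python
# Hedged sketch reproducing the banding step of _make_sliding_window_causal_mask
# on a tiny example: 5 query positions, sliding window of 3.
import torch

tgt_len, sliding_window = 5, 3
mask = torch.tril(torch.full((tgt_len, tgt_len), 1.0))    # causal lower triangle
mask = torch.triu(mask, diagonal=-sliding_window + 1)     # keep only the last 3 keys
print(mask)
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [0., 1., 1., 1., 0.],
#         [0., 0., 1., 1., 1.]])
print(torch.log(mask))  # allowed positions -> 0, masked positions -> -inf
```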

View File

@@ -1,40 +0,0 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
import transformers.models.mistral.modeling_mistral
from transformers.utils import logging
logger = logging.get_logger(__name__)
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training:
embed_init = orig_embed(input_ids)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha / torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(
-mag_norm, mag_norm
)
return orig_embed(input_ids)
return new_func
def post_init(orig_post_init):
def new_func(self):
orig_post_init(self)
self.embed_tokens.forward = noised_embed(
self.embed_tokens.forward, noise_alpha, self
)
return new_func
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
transformers.models.mistral.modeling_mistral.MistralModel.post_init
)

View File

@@ -0,0 +1,65 @@
"""
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
from peft import PeftModel
from transformers import PreTrainedModel
def patch_neft(alpha, model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
embeddings.noisy_embedding_alpha = alpha
old_forward = embeddings.forward
# This hack seems to be needed to properly use a custom forward pass
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
embeddings, embeddings.__class__
)
setattr(embeddings, "forward", bound_method)
embeddings._old_forward = old_forward # pylint: disable=protected-access
return model
def unpatch_neft(model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
if hasattr(embeddings, "_old_forward"):
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
del embeddings._old_forward # pylint: disable=protected-access
del embeddings.noisy_embedding_alpha
def neft_forward(self, inputs: torch.Tensor):
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
if self.training:
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
-mag_norm, mag_norm
)
return embeddings
def pretrain_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
def post_train_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
unpatch_neft(trainer.model)
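The noise added in neft_forward follows the NEFTune paper: uniform noise scaled by alpha / sqrt(L * d), where L is the sequence length and d the embedding width. A standalone sketch of the same scaling (values illustrative; alpha corresponds to noisy_embedding_alpha in the config):

import torch

alpha = 5.0
embeddings = torch.randn(2, 1024, 4096)  # (batch, seq_len, hidden_dim)
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
mag_norm = alpha / torch.sqrt(dims)
noisy = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)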

View File

@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def set_module_name(model, name, value):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
setattr(parent, child_name, value)
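set_module_name walks to the parent of a dotted module path and swaps the named child, which is how patched or fused layers get installed on a loaded model. A hedged usage sketch; the module path and PatchedAttention class are hypothetical placeholders:

import torch.nn as nn

class PatchedAttention(nn.Module):
    # stand-in replacement layer for illustration only
    def forward(self, hidden_states):
        return hidden_states

# swaps model.layers[0].self_attn for the patched module in place
set_module_name(model, "model.layers.0.self_attn", PatchedAttention())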

View File

@@ -1,6 +1,6 @@
"""Module containing the AlpacaQAPromptTokenizingStrategy class""" """Module for Alpaca prompt strategy classes"""
from typing import Tuple from typing import Any, Dict, Optional, Tuple
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
@@ -9,9 +9,13 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
prompt_style = PromptStyle.CHAT.value
if ds_cfg and "conversation" in ds_cfg:
prompt_style = ds_cfg["conversation"]
return AlpacaPromptTokenizingStrategy( return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(PromptStyle.CHAT.value), AlpacaPrompter(prompt_style),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
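With this change a dataset entry can override the Alpaca prompt style via a conversation key, falling back to PromptStyle.CHAT when unset. A hedged sketch of the call, with the dataset entry written as a plain dict (in the YML it is one item under datasets; the path is a placeholder):

ds_cfg = {"path": "my-org/alpaca-style-data", "type": "alpaca", "conversation": "chatml"}
strategy = load(tokenizer, cfg, ds_cfg=ds_cfg)  # AlpacaPrompter built with prompt_style="chatml"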

View File

@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
cfg.train_on_inputs,
cfg.sequence_len,
)
if ds_cfg and "strict" in ds_cfg:
strategy.strict = ds_cfg["strict"]
return strategy
def load_role(tokenizer, cfg):
@@ -59,8 +62,26 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
basic sharegpt strategy to grab conversations from the sample row
"""
_strict = True
@property
def strict(self):
return self._strict
@strict.setter
def strict(self, strict):
self._strict = strict
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
if self.strict:
return conversations
# remap roles - allow for assistant turn
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
turns = [
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
]
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
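When strict is turned off for a sharegpt dataset, get_conversation_thread remaps an "assistant" speaker to the "gpt" role the prompter expects instead of failing the role assertion. A small illustrative row that only loads cleanly with strict: false in the dataset config:

row = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},
        {"from": "assistant", "value": "4"},  # remapped to "gpt" when strict is false
    ]
}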

View File

@@ -45,6 +45,8 @@ class PromptTokenizingStrategy(abc.ABC):
self.prompter = prompter
self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
# TODO: Document how they are different.
self.sequence_len = sequence_len
self.max_length = sequence_len
@@ -59,34 +61,31 @@ class PromptTokenizingStrategy(abc.ABC):
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt:
LOG.warning("Empty text requested for tokenization.")
return empty
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.max_length,
padding=False,
return_tensors=None,
)
if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
return empty
if (
result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.max_length
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
@@ -122,7 +121,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True
)
@@ -246,6 +245,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
raise NotImplementedError
def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
(
instruction,
input, # pylint: disable=redefined-builtin
@@ -270,7 +270,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [
IGNORE_INDEX
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
return tokenized_full_prompt
@@ -334,6 +334,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation.
result, current_len = tokenize_prompt_default()
conversation: Conversation = (
self.prompter._conversation.copy() # pylint: disable=protected-access
@@ -355,62 +356,67 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if not isinstance(part, tuple):
LOG.warning(f"expected tuple, got {part}")
continue
user, assistant = conversation.roles
role, content = part
# Uses "in" because role contains extra characters
if user in role:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else role
)
turn = role + content
# this is still the user query, we should
if not content.strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant in role:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else role
)
turn = role + content
# this should be the assistant response, should end with an eos token
if not content.strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
elif role == "":
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
@@ -424,38 +430,6 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
if not prompt.strip():
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if (
len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.sequence_len
and add_eos_token
):
result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1)
if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:]
result["labels"] = result["input_ids"].copy()
return result
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
"""

View File

@@ -4,10 +4,12 @@ import logging
from enum import Enum
from typing import Generator, Optional, Union
from colorama import Fore
from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
class PromptStyle(Enum):
@@ -20,7 +22,13 @@ class PromptStyle(Enum):
CHATML = "chatml"
class Prompter:
"""
Base prompter class for all prompters
"""
class AlpacaPrompter(Prompter):
"""
Base class for alpaca prompters
"""
@@ -55,29 +63,38 @@ class AlpacaPrompter:
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input_text:
res = (
self.system_format.format(system=self.system_prompt)
if self.system_prompt
else ""
) + self.turn_format.format(instruction=instruction, input=input_text)
else:
res = (
self.system_format.format(system=self.system_no_input_prompt)
if self.system_no_input_prompt
else ""
) + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
return res
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
yield self._build_result(instruction, input, output)
def __repr__(self) -> str:
return REPR_TEMPLATE.format(
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
)
class UnpromptedPrompter(AlpacaPrompter):
@@ -148,7 +165,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
"""
class ReflectAlpacaPrompter(Prompter):
"""
Prompter for ReflectAlpaca
"""
@@ -191,14 +208,14 @@ class ReflectAlpacaPrompter:
)
self.response_split = "ASSISTANT:"
def _build_result(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
reflection: Union[None, str] = None,
corrected: Union[None, str] = None,
):
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input:
@@ -212,7 +229,30 @@ class ReflectAlpacaPrompter:
corrected=corrected,
)
res = f"{res}{label}"
return res
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
reflection: Union[None, str] = None,
corrected: Union[None, str] = None,
) -> Generator[str, None, None]:
# pylint: disable=duplicate-code
yield self._build_result(
instruction,
input,
output,
reflection,
corrected,
)
def __repr__(self) -> str:
return REPR_TEMPLATE.format(
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
)
SHAREGPT_ASSERTION_FAILED_ROLE = (
@@ -220,7 +260,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
)
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
"""
A prompter that generates prompts for the ShareGPT
"""
@@ -247,7 +287,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if role_key_model:
self.role_key_model = role_key_model
def _build_result(self, source):
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
@@ -282,11 +322,20 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"])
return conv.get_turns()
def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)
for part in turns:
if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}")
yield part
def __repr__(self) -> str:
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
class ShareGPTPrompterV2(ShareGPTPrompter):
"""
@@ -304,3 +353,15 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
role_key_human=role_key_human,
role_key_model=role_key_model,
)
class UnsupportedPrompter(Prompter):
"""
A dummy class for custom prompters
"""
def __init__(self) -> None:
pass
def __repr__(self):
return "Pre-tokenized or custom dataset types are unsupported for logging"

View File

@@ -1,6 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
import os import os
import signal import signal
import sys import sys
@@ -10,11 +9,14 @@ from typing import Optional
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.logging import get_logger
from datasets import Dataset from datasets import Dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
@@ -24,7 +26,7 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger("axolotl.train")
@dataclass
@@ -39,13 +41,13 @@ class TrainDatasetMeta:
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
):
# load the tokenizer first
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
train_dataset = dataset_meta.train_dataset
@@ -53,7 +55,10 @@ def train(
total_num_steps = dataset_meta.total_num_steps
# Load the model and tokenizer
msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True
@@ -109,6 +114,7 @@ def train(
if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -116,9 +122,15 @@ def train(
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# post training
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
@@ -134,6 +146,22 @@ def train(
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
@@ -144,3 +172,23 @@ def train(
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
return model, tokenizer
def pretrain_hooks(cfg, trainer):
"""
Run hooks right before kicking off the training
:param cfg:
:param trainer:
:return:
"""
neft_embeddings.pretrain_hook(cfg, trainer)
def post_train_hooks(cfg, trainer):
"""
Run hooks right after training completes
:param cfg:
:param trainer:
:return:
"""
neft_embeddings.post_train_hook(cfg, trainer)
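pretrain_hooks and post_train_hooks currently only dispatch the NEFT patches, which are gated on noisy_embedding_alpha in the config. A minimal sketch of the lifecycle, with the config written as a DictDefault rather than YML:

# Sketch only; in a real run cfg comes from the YML (noisy_embedding_alpha: 5).
cfg = DictDefault({"noisy_embedding_alpha": 5})
pretrain_hooks(cfg, trainer)    # patches the input embeddings with NEFT noise
trainer.train()
post_train_hooks(cfg, trainer)  # restores the original embedding forward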

View File

@@ -37,7 +37,7 @@ from axolotl.utils.distributed import (
)
if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100

View File

@@ -119,3 +119,30 @@ class DataCollatorForSeq2Seq:
features["decoder_input_ids"] = decoder_input_ids features["decoder_input_ids"] = decoder_input_ids
return features return features
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack specific to the using the BatchSampler
"""
def __call__(self, features, return_tensors=None):
chunked_data = {}
for feature in features[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item[feature])
for item in features
if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
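The collator expects each element handed to it by the batch sampler to be a list of already-packed samples; it concatenates their per-feature arrays into one long sequence before deferring to DataCollatorForSeq2Seq. A toy illustration of that concatenation step (values made up):

import numpy as np

features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
]
packed = np.concatenate([np.array(f["input_ids"]) for f in features])
# -> array([1, 2, 3, 4, 5]); attention_mask and labels are packed the same way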

View File

@@ -79,6 +79,9 @@ def normalize_config(cfg):
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config:
cfg.base_model_config = cfg.base_model
model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type
@@ -119,6 +122,9 @@ def normalize_config(cfg):
or (cfg.model_type and "mistral" in cfg.model_type.lower())
)
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)
log_gpu_memory_usage(LOG, "baseline", cfg.device)
@@ -189,9 +195,15 @@ def validate_config(cfg):
if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with QLoRA")
if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
raise ValueError("Fused modules are not supported with LoRA")
if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
@@ -205,6 +217,9 @@ def validate_config(cfg):
if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
if cfg.trust_remote_code:
LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
@@ -339,6 +354,24 @@ def validate_config(cfg):
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)
if (
cfg.sample_packing
and cfg.eval_table_size
and cfg.eval_sample_packing is not False
):
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
if cfg.rope_scaling:
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25
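The new validation rules reject fused flash-attention modules combined with LoRA/QLoRA/ReLoRA, quantized loading without an adapter, and eval_table_size together with sample packing unless eval_sample_packing is explicitly false. For example, a config like the following (written as a dict; in practice it comes from the YML) would now fail validation:

# Hypothetical failing config: 4-bit loading with no adapter set.
bad_cfg = DictDefault({"load_in_4bit": True, "adapter": None})
# validate_config(bad_cfg) raises:
#   "load_in_8bit and load_in_4bit are not supported without setting an adapter. ..."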

View File

@@ -34,8 +34,10 @@ from axolotl.prompters import (
JeopardyPrompter,
MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter,
Prompter,
ReflectAlpacaPrompter,
SummarizeTLDRPrompter,
UnsupportedPrompter,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
@@ -55,9 +57,10 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_main_process()):
train_dataset, eval_dataset, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
@@ -70,7 +73,7 @@ def prepare_dataset(cfg, tokenizer):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps, prompters
with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing(
@@ -78,17 +81,17 @@ def prepare_dataset(cfg, tokenizer):
)
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
return train_dataset, eval_dataset, total_num_steps, prompters
def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
) -> Tuple[DatasetDict, List[Prompter]]:
tokenizer_name = tokenizer.__class__.__name__
ds_hash = str(
md5(
@@ -96,7 +99,12 @@ def load_tokenized_prepared_datasets(
str(cfg.sequence_len)
+ "@"
+ "|".join(
sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
for d in cfg.datasets
]
)
)
+ "|"
+ tokenizer_name
@@ -109,6 +117,7 @@ def load_tokenized_prepared_datasets(
else Path(default_dataset_prepared_path) / ds_hash
)
dataset = None
prompters = []
use_auth_token = cfg.hf_use_auth_token
try:
if cfg.push_dataset_to_hub:
@@ -147,48 +156,92 @@ def load_tokenized_prepared_datasets(
yield dataset
# pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg.datasets):
ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False
try:
load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=True,
token=use_auth_token,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError):
pass
ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith(
"gs://"
) or config_dataset.path.startswith("gcs://"):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(
config_dataset.path
):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
ds = load_from_disk(config_dataset.path)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
)
@@ -198,25 +251,41 @@ def load_tokenized_prepared_datasets(
)
elif ds_from_hub:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
)
@@ -226,21 +295,27 @@ def load_tokenized_prepared_datasets(
"data_files must be either a string or list of strings" "data_files must be either a string or list of strings"
) )
ds = load_dataset( ds = load_dataset(
"json", name=d.name, data_files=fp, streaming=False, split=None "json",
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
) )
if not ds: if not ds:
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if d.shards: if config_dataset.shards:
if "train" in ds: if "train" in ds:
ds = ds.shuffle(seed=seed)["train"].shard( ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=d.shards, index=0 num_shards=config_dataset.shards, index=0
) )
else: else:
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0) ds = ds.shuffle(seed=seed).shard(
num_shards=config_dataset.shards, index=0
)
d_base_type = d_prompt_style = None d_base_type = d_prompt_style = None
d_type = d.type d_type = config_dataset.type
if isinstance(d_type, str): if isinstance(d_type, str):
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
@@ -249,108 +324,26 @@ def load_tokenized_prepared_datasets(
ds = ds["train"] ds = ds["train"]
elif ( elif (
isinstance(ds, DatasetDict) isinstance(ds, DatasetDict)
and d.train_on_split and config_dataset.train_on_split
and d.train_on_split in ds and config_dataset.train_on_split in ds
): ):
ds = ds[d.train_on_split] ds = ds[config_dataset.train_on_split]
elif isinstance(ds, DatasetDict): elif isinstance(ds, DatasetDict):
raise ValueError( raise ValueError(
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `" f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
)
if (
"input_ids" in ds.features
and "attention_mask" in ds.features
and "labels" in ds.features
):
# dataset is already tokenized, just drop it straight in
datasets.append(ds)
elif isinstance(d.type, DictDefault):
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif ds_strategy := load(d.type, tokenizer, cfg, d):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "explainchoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceExplainPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "concisechoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceConcisePrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "summarizetldr":
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
SummarizeTLDRPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "jeopardy":
ds_strategy = JeopardyPromptTokenizingStrategy(
JeopardyPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "oasst":
ds_strategy = OpenAssistantPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "gpteacher":
ds_strategy = GPTeacherPromptTokenizingStrategy(
GPTeacherPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "reflection":
ds_strategy = AlpacaReflectionPTStrategy(
ReflectAlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
suffix = ""
if ":load_" in d.type:
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
) )
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset,
dataset=ds,
tokenizer=tokenizer,
cfg=cfg,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
LOG.info("merging datasets") LOG.info("merging datasets")
dataset = concatenate_datasets(datasets) dataset = concatenate_datasets(datasets)
@@ -368,14 +361,32 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
) )
return dataset return dataset, prompters
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
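get_ds_type infers the datasets builder from the file extension when ds_type is not given, and the same inference is reused for the new s3://, gs://, and gcs:// paths. A hedged example of a dataset entry exercising both (the bucket path is a placeholder and requires s3fs plus aiobotocore to be installed):

config_dataset = DictDefault(
    {
        "path": "s3://my-bucket/data/train.parquet",  # hypothetical bucket
        "type": "alpaca",
    }
)
assert get_ds_type(config_dataset) == "parquet"  # inferred from the .parquet extension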
def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase,
cfg,
default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
)
@@ -384,6 +395,7 @@ def load_prepare_datasets(
) # make sure we don't accidentally set it larger than sequence_len
tokenizer_name = tokenizer.__class__.__name__
prompters: List[Prompter] = []
if cfg.max_packed_sequence_len is not None:
# see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -439,7 +451,7 @@ def load_prepare_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
)
else:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
@@ -481,7 +493,7 @@ def load_prepare_datasets(
private=True,
)
else:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
@@ -517,14 +529,13 @@ def load_prepare_datasets(
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
@@ -532,7 +543,144 @@ def load_prepare_datasets(
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset, prompters
def get_dataset_wrapper(
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
):
dataset_wrapper = None
dataset_prompter = None
if (
"input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
):
# dataset is already tokenized, just drop it straight in
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = dataset
elif isinstance(config_dataset.type, DictDefault):
ds_strategy = load(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "explainchoice":
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "concisechoice":
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "summarizetldr":
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "jeopardy":
dataset_prompter = JeopardyPrompter(d_prompt_style)
ds_strategy = JeopardyPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "oasst":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = OpenAssistantPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "gpteacher":
dataset_prompter = GPTeacherPrompter(d_prompt_style)
ds_strategy = GPTeacherPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
elif d_base_type == "reflection":
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaReflectionPTStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper
else:
suffix = ""
if ":load_" in config_dataset.type:
suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
LOG.error(
f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
)
raise ValueError(
f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
)
return dataset_wrapper, dataset_prompter
def encode_pretraining( def encode_pretraining(

View File

@@ -1,302 +0,0 @@
# pylint: skip-file
import hashlib
import itertools
import logging
import math
from typing import Any, Callable, List, Union
import numba
import numpy as np
from torch.utils.data import DistributedSampler, Sampler
LOG = logging.getLogger("axolotl.utils.dataloader")
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
if not_found:
return False
return True
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
indices = np.argsort(a)[::-1]
a = a[indices]
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result, len(a)
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
"""
:param lengths: array of lengths of each sample
:param lengths_cumsum: cumulative sum of consecutive lengths
:param rank: rank for this process
:param c: length of tokens per batch
:param n: number of ranks
:return:
"""
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
result_totseqs = []
while True:
# binary search [left, right)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length left
batch, tot_seqs = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
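# Worked example with made-up numbers (illustration only, not part of this module):
# with lengths = np.array([6, 4, 3, 5]), c = 8 and n = 2 ranks, the binary search
# settles on the first three samples; FFD packs them into two capacity-8 bins as
# [sample 0] (6 tokens) and [samples 1, 2] (4 + 3 tokens), so rank 0 receives
# [[0]], sample 3 is held back because the remainder cannot fill all n bins, and
# the call reports s = 13 consumed tokens against len(result) * c * n = 16 slots:
#
#     batches, totseqs, s, slots = allocate(
#         lengths, np.cumsum(lengths), rank=0, c=8, n=2
#     )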
def chunk(iterable, n):
"""
Chunk data into tuples of length n
"""
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(itertools.islice(it, n)):
yield batch
def hash_indices(lst: List[int]) -> str:
# Convert the list of integers to a string representation
concatenated = ",".join(map(str, lst))
# Generate the hash
sha256 = hashlib.sha256()
sha256.update(concatenated.encode())
return sha256.hexdigest()
class MultipackDistributedDataloader:
"""Unpadded data loading using Multipack.
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
"""
def __init__(
self,
dataset: Any,
collate_fn: Callable,
seq_max_length: int = 2048,
batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
):
# Dataset
self.dataset = dataset
self.lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
)
assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier
self.sampler = sampler
self.batch_size = batch_size
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn
self.num_replicas = 1
self.rank = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
if self.sampler:
indices = [idx for idx in self.sampler]
else:
indices = range(0, len(self.dataset))
LOG.info(hash_indices(indices))
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches, totseqs
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batches in chunk(
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
):
chunked_data = []
attn_mask_cum_idx = 0
for batch in batches:
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return
# yield a no-op for cases where we don't have any data left to pack
for i in range(0, len_remaining):
yield self.collate_fn(
[
{
"input_ids": [0],
"labels": [-100],
"attention_mask": [True],
"position_ids": [0],
}
]
)
def _len_est(self):
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // self.device_count
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return (
math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// self.seq_max_length
// self.batch_size
)
- 1
)
def __len__(self):
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
# the same share of total tokens
# if not self.eff_total_used:
# batches, _ = self.generate_batches(set_stats=True)
# LOG.info(
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
# f"actual packing efficiency: {self.efficiency()}"
# )
return max(1, self._len_est())
def len_w_stats(self):
if not self.eff_total_used:
batches, _ = self.generate_batches(set_stats=True)
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"actual packing efficiency: {self.efficiency()}"
)
return max(1, self._len_est())
def efficiency(self):
return self.eff_total_used / self.eff_total_slots

View File

@@ -50,6 +50,17 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1"))
@contextmanager
def zero_only():
"""
Context manager that only runs the enclosed block on the main rank.
"""
if is_main_process():
yield
else:
yield None
@contextmanager
def zero_first(is_main):
"""

View File

@@ -17,7 +17,6 @@ from transformers import ( # noqa: F401
AutoTokenizer,
BitsAndBytesConfig,
GPTQConfig,
LlamaConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
)
@@ -31,10 +30,15 @@ LOG = logging.getLogger("axolotl")
def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model
- trust_remote_code: bool = False or cfg.trust_remote_code
+ trust_remote_code = cfg.trust_remote_code is True
- return AutoConfig.from_pretrained(
+ model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code
)
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)
return model_config
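# Example with hypothetical keys/values (illustration only): a YML section such as
#   model_config:
#     max_position_embeddings: 4096
# arrives here as cfg.model_config == {"max_position_embeddings": 4096} and is
# applied via setattr(model_config, "max_position_embeddings", 4096) before the
# patched config is returned.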
def load_tokenizer(cfg):
@@ -51,7 +55,7 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
- tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+ tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
@@ -72,11 +76,6 @@ def load_tokenizer(cfg):
# set a pad_token, but use eos_token so we don't add a new token
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -98,6 +97,11 @@ def load_tokenizer(cfg):
]
)
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
return tokenizer
@@ -110,7 +114,6 @@ def load_model(
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type
model_config = load_model_config(cfg)
@@ -180,26 +183,6 @@ def load_model(
LOG.info("patching with flash attention") LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.llama_embeddings_hijack import (
replace_llama_embeddings_with_uniform_distribution,
)
LOG.info("patching with noisy embeddings")
replace_llama_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.mistral_embeddings_hijack import (
replace_mistral_embeddings_with_uniform_distribution,
)
LOG.info("patching with noisy embeddings")
replace_mistral_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)
if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,
@@ -258,20 +241,27 @@ def load_model(
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
from transformers import LlamaForCausalLM
config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
model = LlamaForCausalLM.from_pretrained(
base_model,
- config=config,
+ config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
if cfg.flash_attn_fuse_mlp:
LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(model)
if cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass # This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs # RuntimeError: grad can be implicitly created only for scalar outputs
@@ -311,66 +301,55 @@ def load_model(
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=cfg.trust_remote_code or False,
)
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts
if (
- hasattr(config, "max_seq_len")
+ hasattr(model_config, "max_seq_len")
- and config.max_seq_len
+ and model_config.max_seq_len
- and cfg.sequence_len > config.max_seq_len
+ and cfg.sequence_len > model_config.max_seq_len
):
- config.max_seq_len = cfg.sequence_len
+ model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif (
- hasattr(config, "max_sequence_length")
+ hasattr(model_config, "max_sequence_length")
- and config.max_sequence_length
+ and model_config.max_sequence_length
- and cfg.sequence_len > config.max_sequence_length
+ and cfg.sequence_len > model_config.max_sequence_length
):
- config.max_sequence_length = cfg.sequence_len
+ model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained(
base_model,
- config=config,
+ config=model_config,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else:
model = AutoModelForCausalLM.from_pretrained(
base_model,
- config=config,
+ config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
except Exception as err: # pylint: disable=broad-exception-caught
LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
)
LOG.exception(err)
- model = AutoModelForCausalLM.from_pretrained(
+ raise err
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
@@ -392,6 +371,20 @@ def load_model(
)
model.config.max_position_embeddings = cfg.sequence_len
if (
hasattr(model.config, "bos_token_id")
and model.config.bos_token_id
and model.config.bos_token_id != tokenizer.bos_token_id
):
model.config.bos_token_id = tokenizer.bos_token_id
if (
hasattr(model.config, "eos_token_id")
and model.config.eos_token_id
and model.config.eos_token_id != tokenizer.eos_token_id
):
model.config.eos_token_id = tokenizer.eos_token_id
if model.device.type == "cuda": if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device) log_gpu_memory_usage(LOG, "after model load", model.device)
@@ -434,14 +427,7 @@ def load_model(
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
if ( if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1
and (cfg.load_in_4bit)
):
# llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
setattr(model, "is_parallelizable", True) setattr(model, "is_parallelizable", True)
setattr(model, "model_parallel", True) setattr(model, "model_parallel", True)
@@ -507,7 +493,11 @@ def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set()
for name, module in model.named_modules():
- if isinstance(module, cls) or "Linear" in module.__class__.__name__:
+ if (
isinstance(module, cls)
or "Linear" in module.__class__.__name__
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
):
names = name.split(".") names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1]) lora_module_names.add(names[0] if len(names) == 1 else names[-1])

View File

@@ -0,0 +1,4 @@
"""
axolotl samplers module
"""
from .multipack import MultipackBatchSampler # noqa: F401

View File

@@ -0,0 +1,196 @@
# pylint: skip-file
"""
Multipack Batch Sampler
"""
import logging
import math
import os
from typing import Any, Iterable, List, Union
import numba
import numpy as np
from torch.utils.data import BatchSampler, Sampler
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
if not_found:
return False
return True
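# Example with made-up numbers (illustration only): ffd_check(np.array([6, 4, 3, 5]), c=10, n=2)
# sorts the sizes to [6, 5, 4, 3]; 6 and 4 land in the first bin (10 -> 4 -> 0) and
# 5 and 3 in the second (10 -> 5 -> 2), so it returns True, whereas c=8 returns False
# because after placing 6 and 5 the remaining capacities (2 and 3) cannot take the 4.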
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
indices = np.argsort(a)[::-1]
a = a[indices]
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
while True:
# binary search [l, r)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length l
batch = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
assert len(batch) <= n
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
return result, s, len(result) * c * n
class MultipackBatchSampler(BatchSampler):
"""
Batch Sampler class for multipack
"""
def __init__(
self,
sampler: Union[Sampler[int], Iterable[int]],
batch_size: int,
drop_last: bool,
batch_max_len: int,
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = None
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
assert isinstance(self.lengths, np.ndarray)
self.epoch = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
def set_epoch(self, epoch: int):
self.epoch = epoch
def generate_batches(self, set_stats=False):
indices = [idx for idx in self.sampler]
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches
def __iter__(self):
batches = self.generate_batches(set_stats=True)
return iter(batches)
def num_batches(self):
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self):
return self.eff_total_used / self.eff_total_slots
def __len__(self):
self.num_batches()
return self._len_est()
def _len_est(self):
world_size = int(os.getenv("WORLD_SIZE", "1"))
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // world_size
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return max(
1,
(
world_size
* math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// self.batch_max_len
)
- 1
),
)
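A minimal usage sketch for the new sampler, mirroring how it is wired up in calculate_total_num_steps later in this change; train_dataset, micro_batch_size and sequence_len are assumed placeholders, not part of this file:

from torch.utils.data import DataLoader, RandomSampler
from axolotl.utils.samplers import MultipackBatchSampler

# per-sample token counts, recovered from the last position_id of each packed row
lengths = (
    train_dataset.data.column("position_ids")
    .to_pandas()
    .apply(lambda x: x[-1] + 1)
    .values
)
sampler = MultipackBatchSampler(
    sampler=RandomSampler(train_dataset),
    batch_size=micro_batch_size,
    drop_last=True,
    batch_max_len=micro_batch_size * sequence_len,
    lengths=lengths,
)
data_loader = DataLoader(
    train_dataset.remove_columns(["length"]),
    batch_sampler=sampler,
)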

View File

@@ -34,6 +34,5 @@ def check_example_labels(example, tokenizer, text_only=False):
delimiter = "" if text_only else " " delimiter = "" if text_only else " "
LOG.info(delimiter.join(colored_tokens)) LOG.info(delimiter.join(colored_tokens))
LOG.info("\n\n\n") LOG.info("\n\n\n")
print(" ".join(colored_tokens))
return " ".join(colored_tokens) return " ".join(colored_tokens)

View File

@@ -1,51 +1,22 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import importlib
import logging
import math
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
- from pathlib import Path
+ from typing import List
from typing import List, Optional, Union
import numpy as np
import torch
import torch.cuda
- import torch.distributed as dist
+ from accelerate.logging import get_logger
- import transformers
+ from datasets import set_caching_enabled
- from datasets import Dataset, set_caching_enabled
+ from torch.utils.data import DataLoader, RandomSampler
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import (
DataLoader,
DistributedSampler,
RandomSampler,
SequentialSampler,
)
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
- from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
+ from axolotl.core.trainer_builder import HFCausalTrainerBuilder
- from axolotl.utils.callbacks import (
+ from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
- EvalFirstStepCallback,
+ from axolotl.utils.samplers import MultipackBatchSampler
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import (
is_distributed,
is_main_process,
reduce_and_broadcast,
zero_first,
)
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
LOG = logging.getLogger("axolotl") LOG = get_logger("axolotl")
@torch.jit.script @torch.jit.script
@@ -110,269 +81,6 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
return weighted_cross_entropy(logits, labels, weights)
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs):
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if (
self.args.world_size > 1
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
batch_size=self.args.per_device_eval_batch_size,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_train_dataloader()
def get_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
def add_position_ids(sample):
sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
@@ -423,7 +131,9 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
)
# Phi doesn't want the attention_mask feature when training
- if "CodeGenTokenizer" in tokenizer.__class__.__name__:
+ if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
cfg.is_mistral_derived_model and cfg.flash_attention
):
train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
@@ -431,30 +141,33 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
return train_dataset, eval_dataset
- def calculate_total_num_steps(cfg, train_dataset, tokenizer):
+ def calculate_total_num_steps(cfg, train_dataset):
if not cfg.total_num_tokens:
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True,
)
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
if not cfg.total_num_tokens:
LOG.info("calculating total_num_tokens")
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.info(f"total_num_tokens: {total_num_tokens}")
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing_eff_est:
total_num_steps = (
@@ -472,40 +185,41 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
)
* cfg.num_epochs
)
- LOG.info(
+ LOG.debug(
- f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
+ f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
main_process_only=True,
)
else:
- if cfg.world_size > 1 and is_distributed():
+ sampler = MultipackBatchSampler(
- sampler = DistributedSampler(
+ sampler=RandomSampler(train_dataset),
train_dataset,
num_replicas=cfg.world_size,
rank=dist.get_rank(),
seed=cfg.seed or 42,
)
else:
sampler = RandomSampler(train_dataset)
data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size,
- seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
+ drop_last=True,
- collate_fn=DataCollatorForSeq2Seq(
+ batch_max_len=cfg.micro_batch_size
- tokenizer,
+ * (cfg.max_packed_sequence_len or cfg.sequence_len),
- return_tensors="pt",
+ lengths=(
- padding="longest",
+ train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
data_loader_len = data_loader.len_w_stats()
- actual_eff = data_loader.efficiency()
+ data_loader = DataLoader(
- LOG.info(f"data_loader_len: {data_loader_len}")
+ train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = len(data_loader)
actual_eff = sampler.efficiency()
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
- total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
+ total_num_steps = int(
math.floor(
data_loader_len
* cfg.num_epochs
/ int(os.environ.get("WORLD_SIZE", 1))
)
)
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -519,12 +233,20 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
)
cfg.sample_packing_eff_est = sample_packing_eff_est
- LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
+ LOG.debug(
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
main_process_only=True,
)
else:
total_num_steps = int(
- math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+ math.ceil(
len(train_dataset)
* cfg.num_epochs
/ int(os.environ.get("WORLD_SIZE", 1))
/ cfg.batch_size
)
)
- LOG.info(f"total_num_steps: {total_num_steps}")
+ LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
return total_num_steps
@@ -548,245 +270,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
- warmup_steps = (
+ trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
- cfg.warmup_steps
+ trainer_builder.train_dataset = train_dataset
- if cfg.warmup_steps is not None
+ trainer_builder.eval_dataset = eval_dataset
else min(int(0.03 * total_num_steps), 100)
)
logging_steps = (
cfg.logging_steps
if cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
- training_arguments_kwargs = {}
+ return trainer_builder.build(total_num_steps)
if cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = cfg.bf16
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.seed:
training_arguments_kwargs["seed"] = cfg.seed
if cfg.gradient_checkpointing:
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
if cfg.fsdp:
training_arguments_kwargs["fsdp"] = cfg.fsdp
if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
# deepspeed
if cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = cfg.deepspeed
if cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
if cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
if cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
if cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
if cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
elif cfg.evaluation_strategy:
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
elif cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = cfg.save_steps
elif cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
else:
# default to saving each epoch if not defined
training_arguments_kwargs["save_strategy"] = "epoch"
if cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
if cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
if cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
if cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
else:
import torch._dynamo # pylint: disable=redefined-outer-name
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
if cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = cfg.torch_compile_backend
# DDP Config
if cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
if cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
load_best_model_at_end=(
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0
and cfg.save_steps
and cfg.eval_steps
and cfg.save_steps % cfg.eval_steps == 0
)
or False,
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
lr_scheduler_type=cfg.lr_scheduler
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
eval_sample_packing=cfg.eval_sample_packing,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
relora_steps=cfg.relora_steps,
relora_warmup_steps=cfg.relora_warmup_steps,
**training_arguments_kwargs,
)
trainer_kwargs = {}
if cfg.optimizer == "adamw_anyprecision":
if Path(cfg.torchdistx_path).exists():
sys.path.append(cfg.torchdistx_path)
importlib.import_module("torchdistx")
callbacks = []
callbacks.append(GPUStatsCallback(cfg))
callbacks.append(EvalFirstStepCallback)
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
if cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
cfg.sequence_len / 64
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(model, tokenizer)
LOG.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map(
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
batched=False,
num_proc=32,
)
trainer_cls = AxolotlTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
trainer_cls = OneCycleLRSchedulerTrainer
elif cfg.relora_steps:
trainer_cls = ReLoRATrainer
trainer = trainer_cls(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
data_collator=DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=callbacks,
**trainer_kwargs,
)
if cfg.use_wandb and cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg))
if cfg.use_wandb:
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
trainer.add_callback(early_stop_cb)
return trainer

tests/e2e/__init__.py Normal file
View File

View File

@@ -0,0 +1,73 @@
"""
E2E tests for lora llama
"""
import logging
import os
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestFusedLlama(unittest.TestCase):
"""
Test case for Llama models using Fused layers
"""
@with_temp_dir
def test_fft_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"flash_attention": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging
import os
import tempfile
import unittest
from pathlib import Path
@@ -14,6 +13,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -23,13 +24,12 @@ class TestLoraLlama(unittest.TestCase):
Test case for Llama models using LoRA
"""
- def test_lora(self):
+ @with_temp_dir
def test_lora(self, temp_dir):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
@@ -53,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
- "output_dir": output_dir,
+ "output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -64,15 +64,14 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
- assert (Path(output_dir) / "adapter_model.bin").exists()
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
- def test_lora_packing(self):
+ @with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
@@ -98,10 +97,11 @@ class TestLoraLlama(unittest.TestCase):
"num_epochs": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
- "output_dir": output_dir,
+ "output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"bf16": True,
}
)
normalize_config(cfg)
@@ -109,15 +109,14 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
- assert (Path(output_dir) / "adapter_model.bin").exists()
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
- def test_lora_gptq(self):
+ @with_temp_dir
def test_lora_gptq(self, temp_dir):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
@@ -147,7 +146,7 @@ class TestLoraLlama(unittest.TestCase):
"save_steps": 0.5,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
- "output_dir": output_dir,
+ "output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -158,4 +157,4 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
- assert (Path(output_dir) / "adapter_model.bin").exists()
+ assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging
import os
import tempfile
import unittest
from pathlib import Path
@@ -16,6 +15,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -25,13 +26,12 @@ class TestMistral(unittest.TestCase):
Test case for Llama models using LoRA
"""
- def test_lora(self):
+ @with_temp_dir
def test_lora(self, temp_dir):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
@@ -55,7 +55,7 @@ class TestMistral(unittest.TestCase):
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
- "output_dir": output_dir,
+ "output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -69,15 +69,14 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
- assert (Path(output_dir) / "adapter_model.bin").exists()
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
- def test_ft(self):
+ @with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.1,
@@ -95,7 +94,7 @@ class TestMistral(unittest.TestCase):
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
- "output_dir": output_dir,
+ "output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
@@ -113,4 +112,4 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
- assert (Path(output_dir) / "pytorch_model.bin").exists()
+ assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging
import os
import tempfile
import unittest
from pathlib import Path
@@ -16,6 +15,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -25,13 +26,12 @@ class TestMistral(unittest.TestCase):
Test case for Llama models using LoRA Test case for Llama models using LoRA
""" """
def test_lora_packing(self): @with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -56,7 +56,7 @@ class TestMistral(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -70,15 +70,14 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
def test_ft_packing(self): @with_temp_dir
def test_ft_packing(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -97,7 +96,7 @@ class TestMistral(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -115,4 +114,4 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists() assert (Path(temp_dir) / "pytorch_model.bin").exists()


@@ -4,8 +4,8 @@ E2E tests for lora llama
 import logging
 import os
-import tempfile
 import unittest
+from pathlib import Path
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -13,6 +13,8 @@ from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 
+from .utils import with_temp_dir
+
 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -22,12 +24,12 @@ class TestPhi(unittest.TestCase):
     Test case for Llama models using LoRA
     """
 
-    def test_ft(self):
+    @with_temp_dir
+    def test_ft(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
                 "base_model": "microsoft/phi-1_5",
-                "base_model_config": "microsoft/phi-1_5",
                 "trust_remote_code": True,
                 "model_type": "MixFormerSequentialForCausalLM",
                 "tokenizer_type": "AutoTokenizer",
@@ -53,7 +55,7 @@ class TestPhi(unittest.TestCase):
                 "num_epochs": 1,
                 "micro_batch_size": 1,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
@@ -65,13 +67,14 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
 
-    def test_ft_packed(self):
+    @with_temp_dir
+    def test_ft_packed(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
                 "base_model": "microsoft/phi-1_5",
-                "base_model_config": "microsoft/phi-1_5",
                 "trust_remote_code": True,
                 "model_type": "MixFormerSequentialForCausalLM",
                 "tokenizer_type": "AutoTokenizer",
@@ -97,7 +100,7 @@ class TestPhi(unittest.TestCase):
                 "num_epochs": 1,
                 "micro_batch_size": 1,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
@@ -109,3 +112,4 @@ class TestPhi(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()

tests/e2e/test_resume.py (new file, 95 lines)

@@ -0,0 +1,95 @@
"""
E2E tests for resuming training
"""
import logging
import os
import re
import subprocess
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestResumeLlama(unittest.TestCase):
"""
Test case for resuming training of llama models
"""
@with_temp_dir
def test_resume_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_steps": 10,
"save_total_limit": 5,
"max_steps": 40,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
resume_cfg = cfg | DictDefault(
{
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
}
)
normalize_config(resume_cfg)
cli_args = TrainerCliArgs()
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
res = subprocess.run(
cmd, shell=True, text=True, capture_output=True, check=True
)
pattern = r"first_step\s+(\d+)"
first_steps = int(re.findall(pattern, res.stdout)[0])
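        # the run saves every 10 steps up to max_steps=40 and resumes from
        # checkpoint-30, so the resumed run's TensorBoard log is expected to
        # start at global step 31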
assert first_steps == 31

tests/e2e/utils.py (new file, 33 lines)

@@ -0,0 +1,33 @@
"""
helper utils for tests
"""
import os
import shutil
import tempfile
from functools import wraps
from pathlib import Path
def with_temp_dir(test_func):
@wraps(test_func)
def wrapper(*args, **kwargs):
# Create a temporary directory
temp_dir = tempfile.mkdtemp()
try:
# Pass the temporary directory to the test function
test_func(*args, temp_dir=temp_dir, **kwargs)
finally:
# Clean up the directory after the test
shutil.rmtree(temp_dir)
return wrapper
def most_recent_subdir(path):
base_path = Path(path)
subdirectories = [d for d in base_path.iterdir() if d.is_dir()]
if not subdirectories:
return None
subdir = max(subdirectories, key=os.path.getctime)
return subdir
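
For orientation (not part of the PR), here is a minimal sketch of how the with_temp_dir helper above is meant to be consumed; the test class and method names are hypothetical, the pattern simply mirrors the e2e tests earlier in this diff:

import unittest
from pathlib import Path

from .utils import with_temp_dir


class ExampleTempDirUsage(unittest.TestCase):
    """Hypothetical test illustrating the decorator's calling convention."""

    @with_temp_dir
    def test_writes_artifact(self, temp_dir):
        # with_temp_dir creates a fresh directory, injects it as the
        # `temp_dir` keyword argument, and removes it again after the
        # test body returns or raises
        (Path(temp_dir) / "marker.txt").write_text("ok")
        assert (Path(temp_dir) / "marker.txt").exists()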


@@ -0,0 +1,46 @@
"""
Test classes for checking functionality of the cfg normalization
"""
import unittest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizeConfigTestCase(unittest.TestCase):
"""
test class for normalize_config checks
"""
def _get_base_cfg(self):
return DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
def test_lr_as_float(self):
cfg = (
self._get_base_cfg()
| DictDefault( # pylint: disable=unsupported-binary-operation
{
"learning_rate": "5e-5",
}
)
)
normalize_config(cfg)
assert cfg.learning_rate == 0.00005
def test_base_model_config_set_when_empty(self):
cfg = self._get_base_cfg()
del cfg.base_model_config
normalize_config(cfg)
assert cfg.base_model_config == cfg.base_model


@@ -565,3 +565,87 @@ class ValidationTest(unittest.TestCase):
        )
        validate_config(cfg)
def test_eval_table_size_conflict_eval_packing(self):
cfg = DictDefault(
{
"sample_packing": True,
"eval_table_size": 100,
}
)
with pytest.raises(
ValueError, match=r".*Please set 'eval_sample_packing' to false.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"sample_packing": True,
"eval_sample_packing": False,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"sample_packing": False,
"eval_table_size": 100,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"sample_packing": True,
"eval_table_size": 100,
"eval_sample_packing": False,
}
)
validate_config(cfg)
def test_load_in_x_bit_without_adapter(self):
cfg = DictDefault(
{
"load_in_4bit": True,
}
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"load_in_8bit": True,
}
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"load_in_4bit": True,
"adapter": "qlora",
}
)
validate_config(cfg)
cfg = DictDefault(
{
"load_in_8bit": True,
"adapter": "lora",
}
)
validate_config(cfg)