Compare commits

..

51 Commits

Author SHA1 Message Date
Casper
4f9b172c47 Remove FP32 cast 2023-12-07 16:28:25 +01:00
Casper Hansen
8671ed5a0c Fix import 2023-12-06 20:26:31 +00:00
Casper Hansen
538c004080 Fix shapes 2023-12-06 20:26:25 +00:00
Casper
add3b139ed Mistral with fast cross entropy 2023-12-06 20:17:42 +01:00
NanoCode012
a581e9f8f6 feat: add check for quantized model (#913)
* feat: add check for quantized model

* chore: refactor and add another check

* Update src/axolotl/utils/models.py

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-12-05 01:20:06 +09:00
Bryan Thornbury
992e742cdc Support device_map=sequential & max_memory config parameters (#903)
* Support device_map sequential (and others). Support max_memory in cfg.

* Update documentation in README accordingly.

* Update README.md

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-12-04 09:29:21 -05:00
NanoCode012
a1da39cd48 Feat(wandb): Refactor to be more flexible (#767)
* Feat: Update to handle wandb env better

* chore: rename wandb_run_id to wandb_name

* feat: add new recommendation and update config

* fix: indent and pop disabled env if project passed

* feat: test env set for wandb and recommendation

* feat: update to use wandb_name and allow id

* chore: add info to readme
2023-12-04 22:17:25 +09:00
kallewoof
58ec8b1113 feature: loss watchdog for terminating training runs that are failing (#899)
Co-authored-by: Karl-Johan Alm <kalle@gmail.com>
2023-12-04 07:54:34 -05:00
Haoxiang Wang
476a205cea Remove learning rate scheduler in deepspeed config to avoid conflict (#909) 2023-12-04 05:17:38 -05:00
Wing Lian
3e3229e2d9 fix for qwen w lora (#906) 2023-11-30 12:45:50 -05:00
Wing Lian
1d21aa6b0a ensure merged model matches the training dtype (#902)
* ensure merged model matches the training dtype

* Update src/axolotl/cli/__init__.py

* Update src/axolotl/cli/__init__.py
2023-11-29 09:55:19 -05:00
kallewoof
71b7ea3c05 Determine FSDP/deepspeed settings on device select. (#883)
* Determine FSDP/deepspeed settings on device select.

Without this, the OS env check for accelerate will fail.

* rename and move env setup call

* chore: lint

---------

Co-authored-by: Karl-Johan Alm <kalle@gmail.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-11-29 08:36:35 -05:00
NanoCode012
a48dbf6561 fix: remove FA for qwen examples (#900)
* fix: remove FA for qwen lora

* fix: remove FA for qlora
2023-11-27 21:23:54 +09:00
Wing Lian
6a4562ac08 update datasets version to cut down the warnings due to pyarrow arg change (#897)
* update datasets to cut down the warnings

* set versions for tokenizers and gradio

* upgrade transformers to latest version
2023-11-25 16:30:00 -05:00
NanoCode012
1115c501b8 Feat: Add Qwen (#894)
* Feat: Add Qwen

* feat: add qwen lora example

* feat: update matrix

* fix: add trust_remote_code

* fix: disable gradient checkpointing

* chore: add warning about gradient checkpointing

* fix: config

* fix: turn off sample packing for this example and reduce seq len

* chore: add comment on seq len
2023-11-26 00:05:01 +09:00
NanoCode012
7ee3c4cacb fix: warning should not show if eval_batch_size not provided (#896) 2023-11-25 16:04:00 +09:00
NanoCode012
fb12895a17 Feat: Add warmup_ratio (#893)
* Feat: Add warmup_ratio

* fix: update readme with more details on conflict
2023-11-25 12:15:43 +09:00
NanoCode012
9fc29e082b chore(doc): Add info on changing role in sharegpt (#886) 2023-11-22 15:32:50 +09:00
NanoCode012
575a082aae fix: revert local dir dataset load (#878) 2023-11-18 22:50:41 +09:00
Mark Saroufim
ddf815022a Install from git url (#874)
* Install from git url

* Update README.md
2023-11-17 12:50:51 -05:00
Wing Lian
9bf854e59c Phi update 202311 (#876)
* add phi modeling from hf

* update for packing and use new modeling class for phi

* update e2e tests for phi to use new model name

* update example phi to also use new phi model name

* use AutoModelForCausalLM for phi lora since sample packing isn't supported
2023-11-17 12:47:17 -05:00
Wing Lian
797f3dd1de don't train if eval split is too small (#873)
* allow zero len dataset

* better handling and warning of small eval splits

* raise error if eval split is too small

* don't mess with calculating total num steps in distributed context

* fix eval_sample_packing training args logic
2023-11-16 11:35:42 -05:00
Wing Lian
0de1457189 try #2: pin hf transformers and accelerate to latest release, don't reinstall pytorch (#867)
* isolate torch from the requirements.txt

* fix typo for removed line ending

* pin transformers and accelerate to latest releases

* try w auto-gptq==0.5.1

* update README to remove manual peft install

* pin xformers to 0.0.22

* bump flash-attn to 2.3.3

* pin flash attn to exact version
2023-11-16 10:42:36 -05:00
NanoCode012
3cc67d2cdd Feat: Add dataset loading from S3, GCS (#765)
* Feat: Add dataset loading from S3, GCS

* chore: update docs

* chore: add more info on cloud loading
2023-11-16 14:33:58 +09:00
Wing Lian
1bc11868eb allow overriding of model_config parameters from the YML (#853)
* allow overriding of model_config parameters from the YML

* remove old logging, update readme

* move the updating of model config to the load_model_config function

* add warning for deprecated rope_scaling in the root of the YML config
2023-11-15 23:47:08 -05:00
Wing Lian
b3a61e8ce2 add e2e tests for checking functionality of resume from checkpoint (#865)
* use tensorboard to see if resume from checkpoint works

* make sure e2e test is either fp16 or bf16

* set max_steps and save limit so we have the checkpoint when testing resuming

* fix test parameters
2023-11-15 23:05:55 -05:00
Wing Lian
8a8d1c4023 make docker command more robust (#861)
* make docker command more robust

* update readme with more info
2023-11-15 23:03:54 -05:00
Wing Lian
332984db18 lint fix that didn't get caught by linter (#866) 2023-11-15 14:36:40 -05:00
MilesQLi
48630f5b34 Update data.py for signature generation (#851)
* Update data.py

Change of conversation formatting type should also trigger updating the preprocessed dataset, so it should be part of the signature.

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-11-15 14:12:32 -05:00
Zongheng Yang
b33c1d55a2 Docs: add instructions to 1-click launching on public clouds (#862)
* Update README.md

* Update ToC
2023-11-15 14:11:27 -05:00
Wing Lian
0c2a630326 multipack len should use max, not min (#863) 2023-11-15 12:52:32 -05:00
Wing Lian
db8a8afcba adds llama and mistral dropout support (#858)
* adds llama and mistral dropout support

* gracefully handle attention dropout if not available yet
2023-11-15 12:28:50 -05:00
Wing Lian
14706504e3 various bugfixes (#856)
* various bugfixes

use latest tinyllama release
check if val_set_size is empty first
update sdp and xformers llama patches for updated upstream transformers
fix system prompt when no input
calculate total and total supervised tokens even when not sample packing

* add fix for when eval size is estimated to be too small

* should be len 1 for dataset length

* add catchall kwargs
2023-11-15 12:23:18 -05:00
NanoCode012
501b4d1379 chore(doc): Separate section on runpod (#860) 2023-11-16 01:06:51 +09:00
NanoCode012
306fe19c54 feat(doc): add more info on train_on_split (#855) 2023-11-15 23:42:26 +09:00
Fabian Preiß
614cff4107 include the suffix modified string in ascii art (#852) 2023-11-15 07:12:28 -05:00
Wing Lian
1a6309c8a6 cleanup the old multipack dataloader (#841) 2023-11-12 05:39:09 -05:00
Bryan Thornbury
105d0b350b Pin optimum package (#838) 2023-11-09 22:36:15 -05:00
Wing Lian
f544ab2bed don't compile deepspeed or bitsandbytes from source (#837) 2023-11-08 19:49:55 -05:00
Wing Lian
641e6f7e51 multipack w batch sampler (#795)
* test batch sampler w varying batch lens

* wip

* multipack batchsampler wip

* wip

* fix for prepare data loader to get correct # of steps based on gpues

* lint and clean up

* calculate len estimate

* fix total num steps calc

* add options for dataloader_num_workers and dataloader_pin_memory

* remove gitbook

* support prefetch_factor for dataloader optimization

* fix the kwarg
2023-11-07 20:27:40 -05:00
Wing Lian
6dc68a653f use temp_dir kwarg instead 2023-11-06 18:33:01 -05:00
Wing Lian
7de6a5639c missing dunder-init 2023-11-06 18:33:01 -05:00
Wing Lian
c74f045ba7 chore: lint 2023-11-06 18:33:01 -05:00
Wing Lian
0402d19759 make sure to cleanup tmp output_dir for e2e tests 2023-11-06 18:33:01 -05:00
Wing Lian
b2430ce670 use accelerate logging for zero/main loggin only 2023-11-06 18:32:26 -05:00
Wing Lian
4c834bf25d cleanup verbosity a bit 2023-11-06 18:32:26 -05:00
Fabian Preiß
8056ecd30e add deepspeed-kernels dependency for deepspeed>=0.12.0 (#827) 2023-11-05 07:52:56 -05:00
Jason Stillerman
738a057674 Feat: Added Gradio support (#812)
* Added gradio support

* queuing and title

* pre-commit run
2023-11-04 23:59:22 -04:00
Wing Lian
cdc71f73c8 update table for rwkv4 support, fix process count for dataset (#822) 2023-11-04 23:45:44 -04:00
NanoCode012
6459ac7357 fix: pin autogptq (#818) 2023-11-03 10:14:55 -04:00
Wing Lian
964d858da0 fix model parallel (#816) 2023-11-02 21:34:22 -04:00
80 changed files with 2880 additions and 888 deletions

View File

@@ -71,6 +71,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
pip3 uninstall -y transformers accelerate pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn] pip3 install -U -e .[flash-attn]
pip3 install -r requirements-tests.txt pip3 install -r requirements-tests.txt

View File

@@ -25,8 +25,10 @@ Features:
- [Installation](#installation) - [Installation](#installation)
- [Docker](#docker) - [Docker](#docker)
- [Conda/Pip venv](#condapip-venv) - [Conda/Pip venv](#condapip-venv)
- [Runpod](#runpod)
- [LambdaLabs](#lambdalabs) - [LambdaLabs](#lambdalabs)
- [Windows](#windows) - [Windows](#windows)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Dataset](#dataset) - [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts) - [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
@@ -74,6 +76,8 @@ Features:
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ | | gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ | | XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ | | phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
## Quickstart ⚡ ## Quickstart ⚡
@@ -82,20 +86,29 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
**Requirements**: Python >=3.9 and Pytorch >=2.0. **Requirements**: Python >=3.9 and Pytorch >=2.0.
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
### For developers
```bash ```bash
git clone https://github.com/OpenAccess-AI-Collective/axolotl git clone https://github.com/OpenAccess-AI-Collective/axolotl
cd axolotl cd axolotl
pip3 install packaging pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]' pip3 install -e '.[flash-attn,deepspeed]'
pip3 install -U git+https://github.com/huggingface/peft.git ```
### Usage
```bash
# finetune lora # finetune lora
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference # inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --lora_model_dir="./lora-out"
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --gradio
``` ```
## Installation ## Installation
@@ -106,7 +119,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
```bash ```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
``` ```
- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
Or run on the current files for development: Or run on the current files for development:
@@ -121,13 +133,15 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
A more powerful Docker command to run would be this: A more powerful Docker command to run would be this:
```bash ```bash
docker run --gpus '"all"' --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1 docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
``` ```
It additionally: It additionally:
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args. * Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args. * Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal. * The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
* The `--privileged` flag gives all capabilities to the container.
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem) [More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
@@ -149,6 +163,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
``` ```
Get the token at huggingface.co/settings/tokens Get the token at huggingface.co/settings/tokens
#### Runpod
Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### LambdaLabs #### LambdaLabs
<details> <details>
@@ -196,6 +214,28 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
#### Windows #### Windows
Please use WSL or Docker! Please use WSL or Docker!
#### Launching on public clouds via SkyPilot
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
```bash
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
sky check
```
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
```
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/axolotl
```
Use one command to launch:
```bash
# On-demand
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
# Managed spot (auto-recovery on preemption)
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
```
### Dataset ### Dataset
Axolotl supports a variety of dataset formats. Below are some of the formats you can use. Axolotl supports a variety of dataset formats. Below are some of the formats you can use.
@@ -392,6 +432,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- path: knowrohit07/know_sql - path: knowrohit07/know_sql
type: context_qa.load_v2 type: context_qa.load_v2
train_on_split: validation train_on_split: validation
# loading from s3 or gcs
# s3 creds will be loaded from the system default and gcs only supports public access
dataset:
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
``` ```
- loading - loading
@@ -454,6 +500,15 @@ is_falcon_derived_model:
is_llama_derived_model: is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default # Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model: is_mistral_derived_model:
is_qwen_derived_model:
# optional overrides to the base model configuration
model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# Whether you are training a 4-bit GPTQ quantized model # Whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
@@ -478,7 +533,7 @@ float16: true
# A list of one or more datasets to finetune the model with # A list of one or more datasets to finetune the model with
datasets: datasets:
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn> type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
@@ -486,9 +541,12 @@ datasets:
data_files: # Optional[str] path to source data files data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
# Optional[str] fastchat conversation type, only used with type: sharegpt # Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
field_human: # Optional[str]. Human key to use for conversation.
field_model: # Optional[str]. Assistant key to use for conversation.
# Custom user prompt # Custom user prompt
- path: repo - path: repo
@@ -554,6 +612,12 @@ eval_sample_packing:
sample_packing_eff_est: sample_packing_eff_est:
total_num_tokens: total_num_tokens:
# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
max_memory:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
# If you already have a lora model trained that you want to load, put that here. # If you already have a lora model trained that you want to load, put that here.
@@ -601,7 +665,8 @@ wandb_mode: # "offline" to save run metadata locally and not sync to the server,
wandb_project: # Your wandb project name wandb_project: # Your wandb project name
wandb_entity: # A wandb Team name if using a Team wandb_entity: # A wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_run_id: # Set the name of your wandb run wandb_name: # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# Where to save the full-finetuned model to # Where to save the full-finetuned model to
@@ -619,7 +684,8 @@ gradient_accumulation_steps: 1
micro_batch_size: 2 micro_batch_size: 2
eval_batch_size: eval_batch_size:
num_epochs: 4 num_epochs: 4
warmup_steps: 100 warmup_steps: 100 # cannot use with warmup_ratio
warmup_ratio: 0.05 # cannot use with warmup_steps
learning_rate: 0.00003 learning_rate: 0.00003
lr_quadratic_warmup: lr_quadratic_warmup:
logging_steps: logging_steps:
@@ -635,6 +701,9 @@ max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
# Save model as safetensors (require safetensors package) # Save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
@@ -721,10 +790,6 @@ landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# LLaMA only # LLaMA only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling:
type: # linear | dynamic
factor: # float
# Resume from a specific checkpoint dir # Resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
@@ -897,7 +962,7 @@ wandb_mode:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
``` ```
@@ -918,6 +983,10 @@ Pass the appropriate flag to the train command:
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \ cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
--base_model="./completed-model" --prompter=None --load_in_8bit=True --base_model="./completed-model" --prompter=None --load_in_8bit=True
``` ```
-- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```
Please use `--sample_packing False` if you have it on and receive the error similar to below: Please use `--sample_packing False` if you have it on and receive the error similar to below:

View File

@@ -24,16 +24,6 @@
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -28,16 +28,6 @@
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -32,16 +32,6 @@
"weight_decay": "auto" "weight_decay": "auto"
} }
}, },
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \ pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
else \ else \
pip install -e .[flash-attn]; \ pip install -e .[deepspeed,flash-attn]; \
fi fi
# fix so that git fetch/pull from remote works # fix so that git fetch/pull from remote works

View File

@@ -10,8 +10,10 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.9" ARG PYTHON_VERSION="3.9"
ARG PYTORCH_VERSION="2.0.1" ARG PYTORCH_VERSION="2.0.1"
ARG CUDA="118" ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
@@ -27,47 +29,9 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \ RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
FROM base-builder AS deepspeed-builder RUN git lfs install --skip-repo && \
pip3 install awscli && \
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
WORKDIR /workspace
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
cd DeepSpeed && \
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
FROM base-builder AS bnb-builder
WORKDIR /workspace
ARG CUDA="118"
ENV CUDA=$CUDA
ARG MAX_JOBS="-1"
ENV MAX_JOBS=$MAX_JOBS
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
cd bitsandbytes && \
CUDA_VERSION=$CUDA make cuda11x && \
python setup.py bdist_wheel
FROM base-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
RUN mkdir -p /workspace/wheels/bitsandbytes
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
RUN pip3 install wheels/deepspeed-*.whl
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo
RUN pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working # The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10 pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -14,7 +14,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: last_prepared_run dataset_prepared_path: last_prepared_run
val_set_size: 0.01 val_set_size: 0.05
adapter: adapter:
lora_model_dir: lora_model_dir:
@@ -35,7 +35,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: btlm-out output_dir: btlm-out

View File

@@ -7,7 +7,7 @@ datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
@@ -24,7 +24,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
batch_size: 4 batch_size: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -12,7 +12,7 @@ datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat type: alpaca:chat
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./falcon-7b output_dir: ./falcon-7b
batch_size: 2 batch_size: 2

View File

@@ -18,7 +18,7 @@ datasets:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json - Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat" type: "alpaca:chat"
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
@@ -40,7 +40,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out

View File

@@ -12,7 +12,7 @@ datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat type: alpaca:chat
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
adapter: adapter:
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
@@ -26,7 +26,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./falcon-7b output_dir: ./falcon-7b
batch_size: 2 batch_size: 2

View File

@@ -7,7 +7,7 @@ datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 2048 sequence_len: 2048
@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2

View File

@@ -19,7 +19,7 @@ lora_fan_in_fan_out: false
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./jeopardy-bot-7b output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./out output_dir: ./out
sequence_len: 4096 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -15,7 +15,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
@@ -32,7 +32,7 @@ lora_target_linear:
lora_fan_in_fan_out: lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./model-out output_dir: ./model-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./relora-out output_dir: ./relora-out
adapter: qlora adapter: qlora
@@ -35,7 +35,7 @@ relora_cpu_offload: false
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -1,4 +1,4 @@
base_model: PY007/TinyLlama-1.1B-step-50K-105b base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
@@ -12,7 +12,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./out output_dir: ./out
sequence_len: 8192 sequence_len: 8192
@@ -21,7 +21,7 @@ pad_to_sequence_len: true
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4

View File

@@ -11,7 +11,7 @@ datasets:
- path: mhenrichsen/alpaca_2k_test - path: mhenrichsen/alpaca_2k_test
type: alpaca type: alpaca
dataset_prepared_path: last_run_prepared dataset_prepared_path: last_run_prepared
val_set_size: 0.01 val_set_size: 0.05
output_dir: ./qlora-out output_dir: ./qlora-out
adapter: qlora adapter: qlora
@@ -38,7 +38,7 @@ lora_target_modules:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
@@ -62,6 +62,9 @@ logging_steps: 1
xformers_attention: xformers_attention:
flash_attention: true flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 0.05
eval_table_size: eval_table_size:

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b wandb_project: mpt-alpaca-7b
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./mpt-alpaca-7b output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./openllama-out output_dir: ./openllama-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -29,7 +29,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./lora-out output_dir: ./lora-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -9,7 +9,7 @@ datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
type: alpaca type: alpaca
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 1024 sequence_len: 1024
@@ -23,7 +23,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -1,5 +1,5 @@
base_model: microsoft/phi-1_5 base_model: microsoft/phi-1_5
model_type: MixFormerSequentialForCausalLM model_type: PhiForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_llama_derived_model: false is_llama_derived_model: false
trust_remote_code: true trust_remote_code: true
@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -31,7 +31,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -24,7 +24,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./pythia-12b output_dir: ./pythia-12b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

View File

@@ -18,7 +18,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./lora-alpaca-pythia output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1

68
examples/qwen/lora.yml Normal file
View File

@@ -0,0 +1,68 @@
base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 2048 # supports up to 8192
sample_packing: false
pad_to_sequence_len:
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

68
examples/qwen/qlora.yml Normal file
View File

@@ -0,0 +1,68 @@
base_model: Qwen/Qwen-7B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_qwen_derived_model: true
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 2048 # supports up to 8192
sample_packing: false
pad_to_sequence_len:
adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -22,7 +22,7 @@ lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b wandb_project: redpajama-alpaca-3b
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./redpajama-alpaca-3b output_dir: ./redpajama-alpaca-3b
batch_size: 4 batch_size: 4

View File

@@ -21,7 +21,7 @@ lora_fan_in_fan_out:
wandb_project: lora-replit wandb_project: lora-replit
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./lora-replit output_dir: ./lora-replit
batch_size: 8 batch_size: 8

View File

@@ -16,7 +16,7 @@ datasets:
- openassistant_best_replies_train.jsonl - openassistant_best_replies_train.jsonl
type: "completion" type: "completion"
dataset_prepared_path: dataset_prepared_path:
val_set_size: 0.01 val_set_size: 0.05
# enable QLoRA # enable QLoRA
adapter: qlora adapter: qlora
lora_model_dir: lora_model_dir:
@@ -38,7 +38,7 @@ lora_fan_in_fan_out:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
wandb_watch: wandb_watch:
wandb_run_id: wandb_name:
wandb_log_model: wandb_log_model:
output_dir: ./qlora-out output_dir: ./qlora-out

View File

@@ -1 +0,0 @@
# Page

View File

@@ -1,4 +0,0 @@
# Table of contents
* [Page](README.md)
* [Small dev details](small-dev-details.md)

View File

@@ -1,3 +0,0 @@
# Small dev details
/

View File

@@ -1,23 +1,22 @@
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
torch==2.0.1 auto-gptq==0.5.1
auto-gptq
packaging packaging
peft @ git+https://github.com/huggingface/peft.git peft==0.6.0
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697 transformers==4.35.2
tokenizers==0.15.0
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9 accelerate==0.24.1
deepspeed deepspeed
addict addict
fire fire
PyYAML>=6.0 PyYAML>=6.0
datasets datasets>=2.15.0
flash-attn>=2.3.0 flash-attn==2.3.3
sentencepiece sentencepiece
wandb wandb
einops einops
xformers>=0.0.22 xformers==0.0.22
optimum optimum==1.13.2
hf_transfer hf_transfer
colorama colorama
numba numba
@@ -31,4 +30,10 @@ scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.29 fschat==0.2.29
tensor_parallel gradio==3.50.2
tensorboard
# remote filesystems
s3fs
gcsfs
# adlfs

View File

@@ -6,8 +6,10 @@ import os
import random import random
import sys import sys
from pathlib import Path from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
import gradio as gr
import torch import torch
import yaml import yaml
@@ -16,7 +18,7 @@ from accelerate.commands.config import config_args
from art import text2art from art import text2art
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
@@ -27,6 +29,7 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -44,7 +47,7 @@ def print_axolotl_text_art(suffix=None):
ascii_text = " axolotl" ascii_text = " axolotl"
if suffix: if suffix:
ascii_text += f" x {suffix}" ascii_text += f" x {suffix}"
ascii_art = text2art(" axolotl", font=font) ascii_art = text2art(ascii_text, font=font)
if is_main_process(): if is_main_process():
print(ascii_art) print(ascii_art)
@@ -69,7 +72,7 @@ def do_merge_lora(
LOG.info("running merge of LoRA with base model") LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload() model = model.merge_and_unload()
model.to(dtype=torch.float16) model.to(dtype=cfg.torch_dtype)
if cfg.local_rank == 0: if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}") LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
@@ -153,6 +156,91 @@ def do_inference(
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0])) print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(show_api=False, share=True)
def choose_config(path: Path): def choose_config(path: Path):
yaml_files = list(path.glob("*.yml")) yaml_files = list(path.glob("*.yml"))
@@ -209,6 +297,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
validate_config(cfg) validate_config(cfg)
prepare_optim_env(cfg)
normalize_config(cfg) normalize_config(cfg)
setup_wandb_env_vars(cfg) setup_wandb_env_vars(cfg)

View File

@@ -6,11 +6,16 @@ from pathlib import Path
import fire import fire
import transformers import transformers
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art from axolotl.cli import (
do_inference,
do_inference_gradio,
load_cfg,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
) )
parsed_cli_args.inference = True parsed_cli_args.inference = True
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args) if gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -6,35 +6,33 @@ import abc
import importlib import importlib
import logging import logging
import math import math
import os
import sys import sys
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional
import tensor_parallel as tp
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler from transformers.trainer_utils import seed_worker
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback, SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
) )
from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.distributed import is_distributed
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
try: try:
@@ -104,8 +102,9 @@ class AxolotlTrainingArguments(TrainingArguments):
bench_source_max_len: int = field( bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."} default=2048, metadata={"help": "Maximum source sequence length for bench."}
) )
tensor_parallel: bool = field( dataloader_prefetch_factor: Optional[int] = field(
default=False, metadata={"help": "Use tensor parallelism to train"} default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
) )
@@ -150,70 +149,102 @@ class AxolotlTrainer(Trainer):
return self.lr_scheduler return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing: if self.args.sample_packing:
return DistributedSampler( return MultipackBatchSampler(
self.train_dataset, RandomSampler(self.train_dataset),
num_replicas=self.args.world_size, self.args.train_batch_size,
rank=self.args.process_index, drop_last=True,
seed=self.args.seed, batch_max_len=self._train_batch_size * self.args.max_seq_length,
lengths=(
self.train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
) )
return super()._get_train_sampler() return super()._get_train_sampler()
def _get_eval_sampler( def _get_eval_sampler(
self, eval_dataset: Dataset self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]: ) -> Optional[torch.utils.data.Sampler]:
if ( if self.args.sample_packing and self.args.eval_sample_packing is not False:
self.args.world_size > 1 return MultipackBatchSampler(
and self.args.sample_packing SequentialSampler(eval_dataset),
and self.args.eval_sample_packing is not False self.args.per_device_eval_batch_size,
): drop_last=True,
return SequentialDistributedSampler( batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
eval_dataset, lengths=(
num_replicas=self.args.world_size, eval_dataset.data.column("position_ids")
rank=self.args.process_index, .to_pandas()
batch_size=self.args.per_device_eval_batch_size, .apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
) )
return super()._get_eval_sampler(eval_dataset) return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]: def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing: if self.args.sample_packing:
train_sampler = self._get_train_sampler() train_dataset = self.train_dataset
return self.accelerator.prepare( train_dataset = train_dataset.remove_columns(["length"])
MultipackDistributedDataloader( data_collator = self.data_collator
self.train_dataset, dataloader_params = {
batch_size=self._train_batch_size, "batch_size": self._train_batch_size,
seq_max_length=self.args.max_seq_length, "collate_fn": data_collator,
collate_fn=self.data_collator, "num_workers": self.args.dataloader_num_workers,
sampler=train_sampler, "pin_memory": self.args.dataloader_pin_memory,
packing_efficiency_estimate=self.args.sample_packing_efficiency, }
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, if self.args.dataloader_prefetch_factor:
device_count=int(os.environ.get("WORLD_SIZE", 1)), dataloader_params[
num_epochs=self.num_epochs, "prefetch_factor"
) ] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
) )
return super().get_train_dataloader() return super().get_train_dataloader()
def get_eval_dataloader( def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing and self.args.eval_sample_packing is not False: if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = ( eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset if eval_dataset is not None else self.eval_dataset
) )
eval_sampler = self._get_eval_sampler(eval_dataset) eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare( eval_dataset = eval_dataset.remove_columns(["length"])
MultipackDistributedDataloader( data_collator = self.data_collator
eval_dataset, dataloader_params = {
batch_size=self.args.eval_batch_size, "batch_size": self.args.eval_batch_size,
seq_max_length=self.args.max_seq_length, "collate_fn": data_collator,
collate_fn=self.data_collator, "num_workers": self.args.dataloader_num_workers,
sampler=eval_sampler, "pin_memory": self.args.dataloader_pin_memory,
packing_efficiency_estimate=self.args.sample_packing_efficiency, }
sample_packing_seq_len_multiplier=self.args.eval_batch_size, if self.args.dataloader_prefetch_factor:
device_count=int(os.environ.get("WORLD_SIZE", 1)), dataloader_params[
num_epochs=self.num_epochs, "prefetch_factor"
) ] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
@@ -227,13 +258,15 @@ class AxolotlTrainer(Trainer):
def get_bench_dataloader( def get_bench_dataloader(
self, self,
bench_dataset: Dataset, bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]: ) -> DataLoader:
dataloader_params = { dataloader_params = {
"batch_size": self.args.eval_batch_size, "batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator, "collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers, "num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory, "pin_memory": self.args.dataloader_pin_memory,
} }
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset): if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
@@ -251,14 +284,6 @@ class AxolotlTrainer(Trainer):
# return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs) return super().compute_loss(model, inputs, return_outputs=return_outputs)
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.tensor_parallel:
model = tp.tensor_parallel(model, distributed=is_distributed())
model.hf_device_map = tp.infer_sharded_device_map(model)
else:
model = super()._wrap_model(model, training=training, dataloader=dataloader)
return model
class OneCycleLRSchedulerTrainer(AxolotlTrainer): class OneCycleLRSchedulerTrainer(AxolotlTrainer):
""" """
@@ -384,10 +409,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return trainer_kwargs, trainer_cls return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer): def hook_post_create_trainer(self, trainer):
if self.cfg.tensor_parallel: # TODO
trainer.model = trainer.accelerator.prepare_model(
trainer.model, device_placement=True
)
return trainer return trainer
def get_callbacks(self): def get_callbacks(self):
@@ -409,6 +431,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
) )
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
@@ -440,11 +465,14 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlTrainer return AxolotlTrainer
def build(self, total_num_steps): def build(self, total_num_steps):
warmup_steps = ( warmup_steps = None
self.cfg.warmup_steps if self.cfg.warmup_steps is not None:
if self.cfg.warmup_steps is not None warmup_steps = self.cfg.warmup_steps
else min(int(0.03 * total_num_steps), 100) elif self.cfg.warmup_ratio is not None:
) warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
else:
warmup_steps = min(int(0.03 * total_num_steps), 100)
logging_steps = ( logging_steps = (
self.cfg.logging_steps self.cfg.logging_steps
if self.cfg.logging_steps is not None if self.cfg.logging_steps is not None
@@ -509,16 +537,29 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
"sample_packing_efficiency" "sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est ] = self.cfg.sample_packing_eff_est
if self.cfg.eval_steps: if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
if self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
elif self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps" training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.evaluation_strategy: elif self.cfg.evaluation_strategy:
training_arguments_kwargs[ training_arguments_kwargs[
"evaluation_strategy" "evaluation_strategy"
] = self.cfg.evaluation_strategy ] = self.cfg.evaluation_strategy
elif self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
else: else:
# we have an eval set, but no steps defined, default to use epoch # we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch" training_arguments_kwargs["evaluation_strategy"] = "epoch"
@@ -606,7 +647,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = ( training_arguments_kwargs["run_name"] = (
self.cfg.wandb_run_id if self.cfg.use_wandb else None self.cfg.wandb_name if self.cfg.use_wandb else None
) )
training_arguments_kwargs["optim"] = ( training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf" self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
@@ -624,15 +665,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.sample_packing if self.cfg.sample_packing else False self.cfg.sample_packing if self.cfg.sample_packing else False
) )
training_arguments_kwargs["eval_sample_packing"] = ( training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False self.cfg.sample_packing
if self.cfg.eval_sample_packing is not False
else False
) )
training_arguments_kwargs[ training_arguments_kwargs[
"sample_packing_seq_len_multiplier" "sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size ] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
training_arguments_kwargs = self.hook_pre_create_training_args( training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs training_arguments_kwargs
) )
@@ -690,7 +731,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
train_dataset=self.train_dataset, train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset, eval_dataset=self.eval_dataset,
args=training_args, args=training_args,
data_collator=DataCollatorForSeq2Seq( data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer, self.tokenizer,
return_tensors="pt", return_tensors="pt",
**data_collator_kwargs, **data_collator_kwargs,
@@ -708,4 +749,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
for callback in self.get_post_trainer_create_callbacks(trainer): for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback) trainer.add_callback(callback)
if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size
return trainer return trainer

View File

@@ -2,7 +2,7 @@
import logging import logging
import os import os
from typing import List from typing import List, Optional
import torch import torch
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
@@ -30,14 +30,20 @@ class TokenizedPromptDataset(Dataset):
self, self,
prompt_tokenizer: PromptTokenizingStrategy, prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset, dataset: IterableDataset,
process_count: Optional[int] = None,
**kwargs, **kwargs,
): ):
self.prompt_tokenizer = prompt_tokenizer self.prompt_tokenizer = prompt_tokenizer
self.process_count = process_count
super().__init__(self.process(dataset).data, **kwargs) super().__init__(self.process(dataset).data, **kwargs)
def process(self, dataset): def process(self, dataset):
features = dataset.features.keys() features = dataset.features.keys()
num_proc = min(64, os.cpu_count()) num_proc = (
min(64, self.process_count)
if self.process_count
else min(64, os.cpu_count())
)
map_kwargs = {} map_kwargs = {}
if self.prompt_tokenizer.supports_batched: if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True

View File

@@ -3,4 +3,6 @@ MixFormers model architecture used for phi models
""" """
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
from .configuration_phi import PhiConfig # noqa
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
from .modeling_phi import PhiForCausalLM # noqa

View File

@@ -0,0 +1,65 @@
# pylint: skip-file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Optional
from transformers import PretrainedConfig
class PhiConfig(PretrainedConfig):
"""Phi configuration."""
model_type = "phi"
attribute_map = {
"max_position_embeddings": "n_positions",
"hidden_size": "n_embd",
"num_attention_heads": "n_head",
"num_hidden_layers": "n_layer",
}
def __init__(
self,
vocab_size: int = 50304,
n_positions: int = 2048,
n_embd: int = 1024,
n_layer: int = 20,
n_inner: Optional[int] = None,
n_head: int = 16,
n_head_kv: Optional[int] = None,
rotary_dim: Optional[int] = 32,
activation_function: Optional[str] = "gelu_new",
flash_attn: bool = False,
flash_rotary: bool = False,
fused_dense: bool = False,
attn_pdrop: float = 0.0,
embd_pdrop: float = 0.0,
resid_pdrop: float = 0.0,
layer_norm_epsilon: float = 1e-5,
initializer_range: float = 0.02,
tie_word_embeddings: bool = False,
pad_vocab_size_multiple: int = 64,
**kwargs
) -> None:
self.vocab_size = int(
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.n_head_kv = n_head_kv
self.rotary_dim = min(rotary_dim, n_embd // n_head)
self.activation_function = activation_function
self.flash_attn = flash_attn
self.flash_rotary = flash_rotary
self.fused_dense = fused_dense
self.attn_pdrop = attn_pdrop
self.embd_pdrop = embd_pdrop
self.resid_pdrop = resid_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,168 @@
# Adapted from Unsloth
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py
import triton
import triton.language as tl
import torch
MAX_FUSED_SIZE = 65536
def calculate_settings(n):
BLOCK_SIZE = triton.next_power_of_2(n)
# CUDA only supports 65536 - 2^16 threads per block
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
num_warps = 4
if BLOCK_SIZE >= 32768: num_warps = 32
elif BLOCK_SIZE >= 8192: num_warps = 16
elif BLOCK_SIZE >= 2048: num_warps = 8
return BLOCK_SIZE, num_warps
pass
@triton.jit
def _cross_entropy_forward(logits_ptr, logits_row_stride,
loss_ptr,
lse_ptr,
labels_ptr,
n_cols,
BLOCK_SIZE: tl.constexpr,):
"""
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
Pi = exp(xi) / sum(exp(xi))
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
= -y [ x - log[sum(exp(x))] ]
= y * (log[sum(exp(x))] - x)
If y == 0: CE_i = 0
If y == 1: CE_i = logsumexp - x
"""
row_idx = tl.program_id(0)
logits_ptr += row_idx * logits_row_stride
loss_ptr += row_idx
lse_ptr += row_idx
labels_ptr += row_idx
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# TODO: Fixup int32 locations to int64
label_idx = tl.load(labels_ptr).to(tl.int32)
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
max_logits = tl.max(logits, 0)
# Maximum stops overflow
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
tl.store(lse_ptr, lse)
if label_idx != -100:
logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
loss = lse - logits_label
else:
loss = 0.0
tl.store(loss_ptr, loss)
pass
@triton.jit
def _cross_entropy_backward(logits_ptr, logits_row_stride,
dloss_ptr, dloss_row_stride,
lse_ptr,
labels_ptr,
n_cols,
BLOCK_SIZE: tl.constexpr,):
"""
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
From https://en.wikipedia.org/wiki/LogSumExp
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
If y == 0: dC/dx = 0
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
"""
row_idx = tl.program_id(0)
logits_ptr += row_idx * logits_row_stride
dloss_ptr += row_idx * dloss_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# TODO: Fixup int32 locations to int64
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
if label_idx != -100:
dloss = tl.load(dloss_ptr)
else:
dloss = 0.0
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
lse = tl.load(lse_ptr + row_idx)
probs = tl.exp(logits - lse)
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
class CrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels):
n_rows, n_cols = logits.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
_cross_entropy_forward[(n_rows,)](
logits, logits.stride(0),
losses,
logsumexp,
labels,
n_cols,
BLOCK_SIZE = BLOCK_SIZE,
num_warps = num_warps,
)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.save_for_backward(logits, logsumexp, labels)
return losses
pass
@staticmethod
def backward(ctx, dlosses):
logits, logsumexp, labels = ctx.saved_tensors
n_rows, n_cols = logits.shape
_cross_entropy_backward[(n_rows,)](
logits, logits.stride(0),
dlosses, dlosses.stride(0),
logsumexp,
labels,
n_cols,
BLOCK_SIZE = ctx.BLOCK_SIZE,
num_warps = ctx.num_warps,
)
return logits, None, None,
pass
pass
def fast_cross_entropy_loss(logits, labels):
"""
Arguments:
logits: (batch, seq_len, vocab_size)
labels: (batch, seq_len,)
Returns:
losses: float
"""
batch, seq_len, d = logits.shape
assert(labels.shape == (batch, seq_len))
loss = CrossEntropyLoss.apply(
logits.view(batch*seq_len, d),
labels.view(-1),
)
n_items = torch.count_nonzero(labels != -100)
return loss.sum() / n_items
pass

View File

@@ -321,6 +321,8 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack( qkv = torch.stack(
@@ -330,7 +332,12 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True qkv,
cu_seqlens,
max_seqlen,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
@@ -353,7 +360,7 @@ def flashattn_forward(
qkv_unpad, qkv_unpad,
cu_seqlens_q, cu_seqlens_q,
max_seqlen_q, max_seqlen_q,
0.0, dropout_p=dropout_rate,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
) )
@@ -366,6 +373,7 @@ def flashattn_forward(
output = flash_attn_kvpacked_func( output = flash_attn_kvpacked_func(
query_states, query_states,
torch.stack([key_states, value_states], 2), torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal, causal=is_causal,
) )
else: else:
@@ -398,7 +406,7 @@ def flashattn_forward(
cu_seqlens_k, cu_seqlens_k,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
0.0, dropout_p=dropout_rate,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
) )

View File

@@ -25,6 +25,8 @@ def sdp_attention_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()

View File

@@ -29,6 +29,8 @@ def xformers_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False, output_attentions: bool = False,
use_cache: bool = False, use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()

View File

@@ -13,16 +13,20 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
flash_attn_varlen_kvpacked_func, flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
) )
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention, MistralAttention as OriginalMistralAttention,
) )
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
) )
from transformers.models.mistral.modeling_mistral import (
MistralForCausalLM as OriginalMistralForCausalLM,
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss
LOG = logging.getLogger("axolotl.monkeypatch.mistral") LOG = logging.getLogger("axolotl.monkeypatch.mistral")
@@ -36,6 +40,9 @@ def replace_mistral_attn_with_flash_attn(
transformers.models.mistral.modeling_mistral.MistralAttention.forward = ( transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
flashattn_forward flashattn_forward
) )
transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
mistral_causallm_forward
)
if packed: if packed:
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = ( transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
MistralDecoderLayer MistralDecoderLayer
@@ -201,6 +208,8 @@ def flashattn_forward(
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1: if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack( qkv = torch.stack(
@@ -213,7 +222,7 @@ def flashattn_forward(
qkv, qkv,
cu_seqlens, cu_seqlens,
max_seqlen, max_seqlen,
0.0, dropout_p=dropout_rate,
softmax_scale=None, softmax_scale=None,
causal=True, causal=True,
window_size=window_size, window_size=window_size,
@@ -239,7 +248,7 @@ def flashattn_forward(
qkv_unpad, qkv_unpad,
cu_seqlens_q, cu_seqlens_q,
max_seqlen_q, max_seqlen_q,
0.0, dropout_p=dropout_rate,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size, window_size=window_size,
@@ -253,6 +262,7 @@ def flashattn_forward(
output = flash_attn_kvpacked_func( output = flash_attn_kvpacked_func(
query_states, query_states,
torch.stack([key_states, value_states], 2), torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal, causal=is_causal,
window_size=window_size, window_size=window_size,
) )
@@ -286,7 +296,7 @@ def flashattn_forward(
cu_seqlens_k, cu_seqlens_k,
max_seqlen_q, max_seqlen_q,
max_seqlen_k, max_seqlen_k,
0.0, dropout_p=dropout_rate,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size, window_size=window_size,
@@ -638,3 +648,71 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
outputs += (present_key_value,) outputs += (present_key_value,)
return outputs return outputs
def mistral_causallm_forward(
self: OriginalMistralForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args, **kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = logits
if not hasattr(self, "extra_ignored_labels"):
self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
shift_labels = shift_labels.to(shift_logits.device)
# FAST CROSS ENTROPY
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

View File

@@ -22,7 +22,13 @@ class PromptStyle(Enum):
CHATML = "chatml" CHATML = "chatml"
class AlpacaPrompter: class Prompter:
"""
Base prompter class for all prompters
"""
class AlpacaPrompter(Prompter):
""" """
Base class for alpaca prompters Base class for alpaca prompters
""" """
@@ -69,7 +75,7 @@ class AlpacaPrompter:
else: else:
res = ( res = (
self.system_format.format(system=self.system_no_input_prompt) self.system_format.format(system=self.system_no_input_prompt)
if self.system_prompt if self.system_no_input_prompt
else "" else ""
) + self.turn_no_input_format.format(instruction=instruction) ) + self.turn_no_input_format.format(instruction=instruction)
if output: if output:
@@ -159,7 +165,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
""" """
class ReflectAlpacaPrompter: class ReflectAlpacaPrompter(Prompter):
""" """
Prompter for ReflectAlpaca Prompter for ReflectAlpaca
""" """
@@ -254,7 +260,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
) )
class ShareGPTPrompter: # pylint: disable=too-few-public-methods class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
""" """
A prompter that generates prompts for the ShareGPT A prompter that generates prompts for the ShareGPT
""" """
@@ -349,7 +355,7 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
) )
class UnsupportedPrompter: class UnsupportedPrompter(Prompter):
""" """
A dummy class for custom prompters A dummy class for custom prompters
""" """

View File

@@ -1,6 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora""" """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
import os import os
import signal import signal
import sys import sys
@@ -10,6 +9,7 @@ from typing import Optional
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.logging import get_logger
from datasets import Dataset from datasets import Dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers.deepspeed import is_deepspeed_zero3_enabled from transformers.deepspeed import is_deepspeed_zero3_enabled
@@ -26,7 +26,7 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
configure_logging() configure_logging()
LOG = logging.getLogger("axolotl.train") LOG = get_logger("axolotl.train")
@dataclass @dataclass
@@ -44,7 +44,10 @@ def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
): ):
# load the tokenizer first # load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_dataset = dataset_meta.train_dataset train_dataset = dataset_meta.train_dataset
@@ -52,7 +55,10 @@ def train(
total_num_steps = dataset_meta.total_num_steps total_num_steps = dataset_meta.total_num_steps
# Load the model and tokenizer # Load the model and tokenizer
LOG.info("loading model and (optionally) peft_config...") msg = "loading model"
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True

View File

@@ -1,13 +1,10 @@
"""Benchmarking and measurement utilities""" """Benchmarking and measurement utilities"""
import functools import functools
import logging
import pynvml import pynvml
import torch import torch
from pynvml.nvml import NVMLError from pynvml.nvml import NVMLError
LOG = logging.getLogger("axolotl.utils.bench")
def check_cuda_device(default_value): def check_cuda_device(default_value):
""" """
@@ -65,14 +62,7 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device): def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available(): usage, cache, misc = gpu_memory_usage_all(device)
return (0, 0, 0)
try:
usage, cache, misc = gpu_memory_usage_all(device)
except ValueError as exc:
LOG.exception(exc)
return (0, 0, 0)
extras = [] extras = []
if cache > 0: if cache > 0:
extras.append(f"+{cache:.03f}GB cache") extras.append(f"+{cache:.03f}GB cache")

View File

@@ -124,6 +124,36 @@ class GPUStatsCallback(
return control return control
class LossWatchDogCallback(TrainerCallback):
"""Callback to track loss and stop training if loss is too high"""
def __init__(self, cfg):
self.cfg = cfg
self.logged = False
self.violations = 0
self.threshold = cfg.loss_watchdog_threshold
self.patience = cfg.loss_watchdog_patience or 3
def on_step_end(
self,
_args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**_kwargs,
):
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
if state.log_history[-1]["loss"] > self.threshold:
self.violations += 1
if self.violations >= self.patience:
LOG.warning(
"Loss is too high, stopping training (loss_watchdog_threshold)"
)
control.should_training_stop = True
else:
self.violations = 0
return control
def bench_eval_callback_factory(trainer, tokenizer): def bench_eval_callback_factory(trainer, tokenizer):
accuracy = evaluate.load("accuracy") accuracy = evaluate.load("accuracy")
abcd_idx = [ abcd_idx = [

View File

@@ -119,3 +119,30 @@ class DataCollatorForSeq2Seq:
features["decoder_input_ids"] = decoder_input_ids features["decoder_input_ids"] = decoder_input_ids
return features return features
@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack specific to the using the BatchSampler
"""
def __call__(self, features, return_tensors=None):
chunked_data = {}
for feature in features[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item[feature])
for item in features
if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)

View File

@@ -27,7 +27,7 @@ def choose_device(cfg):
cfg.device = get_device() cfg.device = get_device()
if cfg.world_size == 1: if cfg.world_size == 1:
cfg.device_map = "auto" cfg.device_map = cfg.device_map or "auto"
else: else:
if cfg.device.startswith("cuda"): if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()} cfg.device_map = {"": torch.cuda.current_device()}
@@ -122,6 +122,19 @@ def normalize_config(cfg):
or (cfg.model_type and "mistral" in cfg.model_type.lower()) or (cfg.model_type and "mistral" in cfg.model_type.lower())
) )
cfg.is_qwen_derived_model = (
(
hasattr(model_config, "model_type")
and model_config.model_type
in [
"qwen",
]
)
or cfg.is_qwen_derived_model
or "qwen" in cfg.base_model.lower()
or (cfg.model_type and "qwen" in cfg.model_type.lower())
)
if isinstance(cfg.learning_rate, str): if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate) cfg.learning_rate = float(cfg.learning_rate)
@@ -165,7 +178,11 @@ def validate_config(cfg):
"batch_size is not recommended. Please use gradient_accumulation_steps instead.", "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.", "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
) )
if cfg.eval_batch_size != cfg.micro_batch_size: if (
cfg.eval_batch_size
and cfg.micro_batch_size
and cfg.eval_batch_size != cfg.micro_batch_size
):
LOG.warning( LOG.warning(
"eval_batch_size != micro_batch_size. This can lead to VRAM instability." "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
) )
@@ -369,10 +386,24 @@ def validate_config(cfg):
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit." "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
) )
if cfg.tensor_parallel and cfg.gradient_checkpointing: if cfg.rope_scaling:
raise ValueError( LOG.warning("`rope_scaling` should now be be a key under `model_config`")
"TensorParallelPreTrainedModel does not support gradient checkpointing"
if cfg.warmup_steps and cfg.warmup_ratio:
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
LOG.warning(
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
) )
if cfg.wandb_run_id and not cfg.wandb_name:
cfg.wandb_name = cfg.wandb_run_id
LOG.warning(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -3,7 +3,7 @@ import functools
import hashlib import hashlib
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
import torch import torch
from datasets import ( from datasets import (
@@ -34,6 +34,7 @@ from axolotl.prompters import (
JeopardyPrompter, JeopardyPrompter,
MultipleChoiceConcisePrompter, MultipleChoiceConcisePrompter,
MultipleChoiceExplainPrompter, MultipleChoiceExplainPrompter,
Prompter,
ReflectAlpacaPrompter, ReflectAlpacaPrompter,
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
UnsupportedPrompter, UnsupportedPrompter,
@@ -78,19 +79,27 @@ def prepare_dataset(cfg, tokenizer):
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset, tokenizer cfg, train_dataset, eval_dataset, tokenizer
) )
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
raise ValueError(
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
)
if cfg.max_steps: if cfg.max_steps:
total_num_steps = min( total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
) )
LOG.info(f"Maximum number of steps set at {total_num_steps}") LOG.info(f"Maximum number of steps set at {total_num_steps}")
else: else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) total_num_steps = calculate_total_num_steps(cfg, train_dataset)
return train_dataset, eval_dataset, total_num_steps, prompters return train_dataset, eval_dataset, total_num_steps, prompters
def load_tokenized_prepared_datasets( def load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
) -> DatasetDict: ) -> Tuple[DatasetDict, List[Prompter]]:
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
ds_hash = str( ds_hash = str(
md5( md5(
@@ -98,7 +107,12 @@ def load_tokenized_prepared_datasets(
str(cfg.sequence_len) str(cfg.sequence_len)
+ "@" + "@"
+ "|".join( + "|".join(
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]) sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
for d in cfg.datasets
]
)
) )
+ "|" + "|"
+ tokenizer_name + tokenizer_name
@@ -164,6 +178,66 @@ def load_tokenized_prepared_datasets(
except (FileNotFoundError, ConnectionError): except (FileNotFoundError, ConnectionError):
pass pass
ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith(
"gs://"
) or config_dataset.path.startswith("gcs://"):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(
config_dataset.path
):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
local_path = Path(config_dataset.path) local_path = Path(config_dataset.path)
if local_path.exists(): if local_path.exists():
@@ -177,17 +251,8 @@ def load_tokenized_prepared_datasets(
split=None, split=None,
) )
elif local_path.is_file(): elif local_path.is_file():
ds_type = "json" ds_type = get_ds_type(config_dataset)
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
ds = load_dataset( ds = load_dataset(
ds_type, ds_type,
name=config_dataset.name, name=config_dataset.name,
@@ -207,6 +272,22 @@ def load_tokenized_prepared_datasets(
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
token=use_auth_token, token=use_auth_token,
) )
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
)
else: else:
if isinstance(config_dataset.data_files, str): if isinstance(config_dataset.data_files, str):
fp = hf_hub_download( fp = hf_hub_download(
@@ -298,11 +379,29 @@ def load_tokenized_prepared_datasets(
return dataset, prompters return dataset, prompters
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def load_prepare_datasets( def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
cfg, cfg,
default_dataset_prepared_path, default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset, List[Any]]: ) -> Tuple[Dataset, Dataset, List[Prompter]]:
max_packed_sequence_len = ( max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
) )
@@ -311,7 +410,7 @@ def load_prepare_datasets(
) # make sure we don't accidentally set it larger than sequence_len ) # make sure we don't accidentally set it larger than sequence_len
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
prompters = [] prompters: List[Prompter] = []
if cfg.max_packed_sequence_len is not None: if cfg.max_packed_sequence_len is not None:
# see if we can go ahead and load the stacked dataset # see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else "" seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -445,14 +544,13 @@ def load_prepare_datasets(
train_fingerprint = md5(to_hash_train) train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test) test_fingerprint = md5(to_hash_test)
with zero_first(is_main_process()): dataset = dataset.train_test_split(
dataset = dataset.train_test_split( test_size=cfg.val_set_size,
test_size=cfg.val_set_size, shuffle=False,
shuffle=False, seed=cfg.seed or 42,
seed=cfg.seed or 42, train_new_fingerprint=train_fingerprint,
train_new_fingerprint=train_fingerprint, test_new_fingerprint=test_fingerprint,
test_new_fingerprint=test_fingerprint, )
)
train_dataset = dataset["train"] train_dataset = dataset["train"]
eval_dataset = dataset["test"] eval_dataset = dataset["test"]
@@ -482,10 +580,14 @@ def get_dataset_wrapper(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict() "user_defined", tokenizer, cfg, config_dataset.type.to_dict()
) )
dataset_prompter = UnsupportedPrompter() dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset) dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset): elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter() dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset) dataset_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
elif d_base_type == "alpaca": elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style) dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy( ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -494,7 +596,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper dataset_wrapper = ds_wrapper
elif d_base_type == "explainchoice": elif d_base_type == "explainchoice":
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style) dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
@@ -504,7 +608,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper dataset_wrapper = ds_wrapper
elif d_base_type == "concisechoice": elif d_base_type == "concisechoice":
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style) dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
@@ -514,7 +620,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper dataset_wrapper = ds_wrapper
elif d_base_type == "summarizetldr": elif d_base_type == "summarizetldr":
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style) dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
@@ -524,7 +632,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper dataset_wrapper = ds_wrapper
elif d_base_type == "jeopardy": elif d_base_type == "jeopardy":
dataset_prompter = JeopardyPrompter(d_prompt_style) dataset_prompter = JeopardyPrompter(d_prompt_style)
@@ -534,7 +644,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper dataset_wrapper = ds_wrapper
elif d_base_type == "oasst": elif d_base_type == "oasst":
dataset_prompter = AlpacaPrompter(d_prompt_style) dataset_prompter = AlpacaPrompter(d_prompt_style)
@@ -544,7 +656,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper dataset_wrapper = ds_wrapper
elif d_base_type == "gpteacher": elif d_base_type == "gpteacher":
dataset_prompter = GPTeacherPrompter(d_prompt_style) dataset_prompter = GPTeacherPrompter(d_prompt_style)
@@ -554,7 +668,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper dataset_wrapper = ds_wrapper
elif d_base_type == "reflection": elif d_base_type == "reflection":
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style) dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
@@ -564,7 +680,9 @@ def get_dataset_wrapper(
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset) ds_wrapper = TokenizedPromptDataset(
ds_strategy, dataset, process_count=cfg.dataset_processes
)
dataset_wrapper = ds_wrapper dataset_wrapper = ds_wrapper
else: else:
suffix = "" suffix = ""

View File

@@ -1,342 +0,0 @@
# pylint: skip-file
import hashlib
import itertools
import logging
import math
import time
from queue import Queue
from threading import Thread
from typing import Any, Callable, List, Union
import numba
import numpy as np
from torch.utils.data import DistributedSampler, Sampler
LOG = logging.getLogger("axolotl.utils.dataloader")
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
if not_found:
return False
return True
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
indices = np.argsort(a)[::-1]
a = a[indices]
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result, len(a)
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
"""
:param lengths: array of lengths of each sample
:param lengths_cumsum: cumulative sum of consecutive lengths
:param rank: rank for this process
:param c: length of tokens per batch
:param n: number of ranks
:return:
"""
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
result_totseqs = []
while True:
# binary search [left, right)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length left
batch, tot_seqs = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
def chunk(iterable, n):
"""
Chunk data into tuples of length n
"""
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(itertools.islice(it, n)):
yield batch
def hash_indices(lst: List[int]) -> str:
# Convert the list of integers to a string representation
concatenated = ",".join(map(str, lst))
# Generate the hash
sha256 = hashlib.sha256()
sha256.update(concatenated.encode())
return sha256.hexdigest()
class MultipackDistributedDataloader:
"""Unpadded data loading using Multipack.
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
"""
def __init__(
self,
dataset: Any,
collate_fn: Callable,
seq_max_length: int = 2048,
batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
prefetch_max: int = 1000,
num_epochs: int = 1,
):
# Dataset
self.dataset = dataset
self.lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
)
assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier
self.sampler = sampler
self.batch_size = batch_size
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn
self.num_epochs = num_epochs
self.num_replicas = 1
self.rank = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
# maxsize is maximum number of samples in queue
self.prefetch_max = prefetch_max
self.queue: Queue = Queue(maxsize=prefetch_max)
self.thread = None
def _worker(self):
LOG.info(
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
)
for epoch in range(self.num_epochs):
for sample in self._internal_batch_generator():
while True:
if self.queue.full():
time.sleep(1)
else:
break
self.queue.put(sample)
# stop the queue when epoch is done
self.queue.put(None)
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
if self.thread is None:
self.thread = Thread(target=self._worker, daemon=True)
self.thread.start()
while True:
item = self.queue.get()
if item is None:
break
yield item
def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
if self.sampler:
indices = [idx for idx in self.sampler]
else:
indices = range(0, len(self.dataset))
LOG.info(hash_indices(indices))
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, totseqs, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches, totseqs
def _internal_batch_generator(self):
all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys()
len_remaining = self._len_est()
for batches in chunk(
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
):
chunked_data = []
attn_mask_cum_idx = 0
for batch in batches:
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
return
# yield a no-op for cases where we don't have any data left to pack
for i in range(0, len_remaining):
yield self.collate_fn(
[
{
"input_ids": [0],
"labels": [-100],
"attention_mask": [True],
"position_ids": [0],
}
]
)
def _len_est(self):
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // self.device_count
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return (
math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// self.seq_max_length
// self.batch_size
)
- 1
)
def __len__(self):
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
# the same share of total tokens
# if not self.eff_total_used:
# batches, _ = self.generate_batches(set_stats=True)
# LOG.info(
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
# f"actual packing efficiency: {self.efficiency()}"
# )
return max(1, self._len_est())
def len_w_stats(self):
if not self.eff_total_used:
batches, _ = self.generate_batches(set_stats=True)
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"actual packing efficiency: {self.efficiency()}"
)
return max(1, self._len_est())
def efficiency(self):
return self.eff_total_used / self.eff_total_slots

View File

@@ -50,6 +50,17 @@ def get_world_size():
return int(os.getenv("WORLD_SIZE", "1")) return int(os.getenv("WORLD_SIZE", "1"))
@contextmanager
def zero_only():
"""
Context manager that only runs the enclosed block on the main rank.
"""
if is_main_process():
yield
else:
yield None
@contextmanager @contextmanager
def zero_first(is_main): def zero_first(is_main):
""" """

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
import transformers.utils.bitsandbytes
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear from peft.tuners.lora import QuantLinear
@@ -18,7 +17,6 @@ from transformers import ( # noqa: F401
AutoTokenizer, AutoTokenizer,
BitsAndBytesConfig, BitsAndBytesConfig,
GPTQConfig, GPTQConfig,
LlamaConfig,
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
@@ -30,12 +28,40 @@ from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
quant_config_exists = hasattr(model_config, "quantization_config")
quant_config_method_is_gptq = (
quant_config_exists
and "quant_method" in model_config.quantization_config
and model_config.quantization_config["quant_method"] == "gptq"
)
if cfg.gptq and not quant_config_method_is_gptq:
raise ValueError(
"model_config.quantization_config is not set or quant_method is not set to gptq. "
"Please make sure to point to a GPTQ model."
)
if not cfg.gptq and quant_config_exists:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
)
def load_model_config(cfg): def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True trust_remote_code = cfg.trust_remote_code is True
return AutoConfig.from_pretrained( model_config = AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code model_config_name, trust_remote_code=trust_remote_code
) )
if cfg.model_config:
for key, val in cfg.model_config.items():
setattr(model_config, key, val)
check_model_config(cfg, model_config)
return model_config
def load_tokenizer(cfg): def load_tokenizer(cfg):
@@ -52,7 +78,7 @@ def load_tokenizer(cfg):
if cfg.tokenizer_type: if cfg.tokenizer_type:
tokenizer_cls = getattr(transformers, cfg.tokenizer_type) tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
tokenizer = tokenizer_cls.from_pretrained( tokenizer = tokenizer_cls.from_pretrained(
tokenizer_config, tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
@@ -81,6 +107,18 @@ def load_tokenizer(cfg):
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing: if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left" tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens
if cfg.is_qwen_derived_model:
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, tokenizer.eod_id)
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
for attr_name in token_names:
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
if cfg.special_tokens: if cfg.special_tokens:
for k, val in cfg.special_tokens.items(): for k, val in cfg.special_tokens.items():
tokenizer.add_special_tokens( tokenizer.add_special_tokens(
@@ -111,7 +149,6 @@ def load_model(
Load a model for a given configuration and tokenizer. Load a model for a given configuration and tokenizer.
""" """
base_model = cfg.base_model base_model = cfg.base_model
base_model_config = cfg.base_model_config
model_type = cfg.model_type model_type = cfg.model_type
model_config = load_model_config(cfg) model_config = load_model_config(cfg)
@@ -202,6 +239,7 @@ def load_model(
model_kwargs = {} model_kwargs = {}
model_kwargs["device_map"] = cfg.device_map model_kwargs["device_map"] = cfg.device_map
model_kwargs["max_memory"] = cfg.max_memory
model_kwargs["torch_dtype"] = cfg.torch_dtype model_kwargs["torch_dtype"] = cfg.torch_dtype
if cfg.model_revision: if cfg.model_revision:
@@ -222,7 +260,7 @@ def load_model(
load_in_4bit=True, load_in_4bit=True,
llm_int8_threshold=6.0, llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False, llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16, bnb_4bit_compute_dtype=cfg.torch_dtype,
bnb_4bit_use_double_quant=True, bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4", bnb_4bit_quant_type="nf4",
) )
@@ -236,24 +274,12 @@ def load_model(
model_kwargs["use_flash_attention_2"] = True model_kwargs["use_flash_attention_2"] = True
try: try:
if ( if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
cfg.is_llama_derived_model
and not cfg.trust_remote_code
and not cfg.gptq
and not cfg.tensor_parallel
):
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
)
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs, **model_kwargs,
@@ -298,103 +324,78 @@ def load_model(
# device=cfg.device, # device=cfg.device,
# ) # )
# model.train() # sets to train instead of eval mode # model.train() # sets to train instead of eval mode
elif model_type == "MixFormerSequentialForCausalLM": elif model_type == "PhiForCausalLM":
from axolotl.models.phi import MixFormerSequentialForCausalLM from axolotl.models.phi import PhiForCausalLM
model = MixFormerSequentialForCausalLM.from_pretrained( model = PhiForCausalLM.from_pretrained(
base_model, base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs, **model_kwargs,
) )
elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel: elif model_type and not cfg.trust_remote_code:
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=model_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
else: else:
model = getattr(transformers, model_type).from_pretrained( model = getattr(transformers, model_type).from_pretrained(
base_model, base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
elif cfg.tensor_parallel:
model_kwargs.pop("device_map")
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
low_cpu_mem_usage=True,
offload_state_dict=True,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else: else:
config = AutoConfig.from_pretrained(
base_model,
trust_remote_code=cfg.trust_remote_code or False,
)
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
# when training starts # when training starts
if ( if (
hasattr(config, "max_seq_len") hasattr(model_config, "max_seq_len")
and config.max_seq_len and model_config.max_seq_len
and cfg.sequence_len > config.max_seq_len and cfg.sequence_len > model_config.max_seq_len
): ):
config.max_seq_len = cfg.sequence_len model_config.max_seq_len = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}") LOG.warning(f"increasing context length to {cfg.sequence_len}")
elif ( elif (
hasattr(config, "max_sequence_length") hasattr(model_config, "max_sequence_length")
and config.max_sequence_length and model_config.max_sequence_length
and cfg.sequence_len > config.max_sequence_length and cfg.sequence_len > model_config.max_sequence_length
): ):
config.max_sequence_length = cfg.sequence_len model_config.max_sequence_length = cfg.sequence_len
LOG.warning(f"increasing context length to {cfg.sequence_len}") LOG.warning(f"increasing context length to {cfg.sequence_len}")
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=model_config,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
else: else:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
config=config, config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
except Exception as err: # pylint: disable=broad-exception-caught except Exception as err: # pylint: disable=broad-exception-caught
LOG.error(
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
)
LOG.exception(err) LOG.exception(err)
model = AutoModelForCausalLM.from_pretrained( raise err
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
try: embeddings_len = (
embeddings_len = ( math.ceil(len(tokenizer) / 32) * 32
math.ceil(len(tokenizer) / 32) * 32 if cfg.resize_token_embeddings_to_32x
if cfg.resize_token_embeddings_to_32x else len(tokenizer)
else len(tokenizer) )
) if model.get_input_embeddings().num_embeddings < embeddings_len:
if model.get_input_embeddings().num_embeddings < embeddings_len: model.resize_token_embeddings(embeddings_len)
model.resize_token_embeddings(embeddings_len) else:
else: model.tie_weights()
model.tie_weights()
except NotImplementedError:
LOG.warning("`resize_token_embeddings` not implemented on model")
if ( if (
hasattr(model.config, "max_position_embeddings") hasattr(model.config, "max_position_embeddings")
@@ -435,15 +436,22 @@ def load_model(
module.to(torch.float32) module.to(torch.float32)
needs_fa2_dtype = cfg.adapter or cfg.fsdp needs_fa2_dtype = cfg.adapter or cfg.fsdp
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
skip_prepare_model_for_kbit_training = True
if (cfg.adapter == "lora" and load_in_8bit) or ( if (cfg.adapter == "lora" and load_in_8bit) or (
cfg.adapter == "qlora" and cfg.load_in_4bit cfg.adapter == "qlora" and cfg.load_in_4bit
): ):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training") LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
if cfg.gradient_checkpointing: if cfg.gradient_checkpointing:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training( if not skip_prepare_model_for_kbit_training:
model, use_gradient_checkpointing=cfg.gradient_checkpointing model = prepare_model_for_kbit_training(
) model, use_gradient_checkpointing=cfg.gradient_checkpointing
)
needs_fa2_dtype = True needs_fa2_dtype = True
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
@@ -462,14 +470,7 @@ def load_model(
if cfg.ddp and not load_in_8bit: if cfg.ddp and not load_in_8bit:
model.to(f"cuda:{cfg.local_rank}") model.to(f"cuda:{cfg.local_rank}")
if ( if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
torch.cuda.device_count() > 1
and int(os.getenv("WORLD_SIZE", "1")) > 1
and (cfg.load_in_4bit)
):
# llama is PROBABLY model parallelizable, but the default isn't that it is
# so let's only set it for the 4bit, see
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
setattr(model, "is_parallelizable", True) setattr(model, "is_parallelizable", True)
setattr(model, "model_parallel", True) setattr(model, "model_parallel", True)
@@ -497,12 +498,7 @@ def load_adapter(model, cfg, adapter, inference=False):
if adapter is None: if adapter is None:
return model, None return model, None
if hasattr(model, "enable_input_require_grads"): if hasattr(model, "enable_input_require_grads"):
try: model.enable_input_require_grads()
model.enable_input_require_grads()
except NotImplementedError:
LOG.warning("enable_input_require_grads not implemented on model")
if adapter == "qlora" and cfg.tensor_parallel:
model, _ = load_tp_qlora(model)
if adapter in ["lora", "qlora"]: if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference) return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter": if adapter == "llama-adapter":
@@ -554,25 +550,6 @@ def find_all_linear_names(model):
return list(lora_module_names) return list(lora_module_names)
def load_tp_qlora(model):
from transformers.utils.bitsandbytes import replace_with_bnb_linear
model = replace_with_bnb_linear(
model,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
),
)
model.is_loaded_in_4bit = True
return model, None
def load_lora(model, cfg, inference=False): def load_lora(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

View File

@@ -0,0 +1,4 @@
"""
axolotl samplers module
"""
from .multipack import MultipackBatchSampler # noqa: F401

View File

@@ -0,0 +1,196 @@
# pylint: skip-file
"""
Multipack Batch Sampler
"""
import logging
import math
import os
from typing import Any, Iterable, List, Union
import numba
import numpy as np
from torch.utils.data import BatchSampler, Sampler
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
if not_found:
return False
return True
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
indices = np.argsort(a)[::-1]
a = a[indices]
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
while True:
# binary search [l, r)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length l
batch = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
assert len(batch) <= n
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
return result, s, len(result) * c * n
class MultipackBatchSampler(BatchSampler):
"""
Batch Sampler class for multipack
"""
def __init__(
self,
sampler: Union[Sampler[int], Iterable[int]],
batch_size: int,
drop_last: bool,
batch_max_len: int,
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = None
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
assert isinstance(self.lengths, np.ndarray)
self.epoch = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
def set_epoch(self, epoch: int):
self.epoch = epoch
def generate_batches(self, set_stats=False):
indices = [idx for idx in self.sampler]
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
return batches
def __iter__(self):
batches = self.generate_batches(set_stats=True)
return iter(batches)
def num_batches(self):
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self):
return self.eff_total_used / self.eff_total_slots
def __len__(self):
self.num_batches()
return self._len_est()
def _len_est(self):
world_size = int(os.getenv("WORLD_SIZE", "1"))
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // world_size
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return max(
0,
(
world_size
* math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// self.batch_max_len
)
- 1
),
)

View File

@@ -1,5 +1,4 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import logging
import math import math
import os import os
from contextlib import contextmanager from contextlib import contextmanager
@@ -9,21 +8,15 @@ from typing import List
import numpy as np import numpy as np
import torch import torch
import torch.cuda import torch.cuda
import torch.distributed as dist from accelerate.logging import get_logger
from datasets import set_caching_enabled from datasets import set_caching_enabled
from torch.utils.data import DistributedSampler, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from axolotl.core.trainer_builder import HFCausalTrainerBuilder from axolotl.core.trainer_builder import HFCausalTrainerBuilder
from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.distributed import (
is_distributed,
is_main_process,
reduce_and_broadcast,
zero_first,
)
LOG = logging.getLogger("axolotl") LOG = get_logger("axolotl")
@torch.jit.script @torch.jit.script
@@ -148,30 +141,35 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
return train_dataset, eval_dataset return train_dataset, eval_dataset
def calculate_total_num_steps(cfg, train_dataset, tokenizer): def calculate_total_num_steps(cfg, train_dataset, update=True):
if not cfg.total_num_tokens:
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
if update:
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.debug(
f"`total_supervised_tokens: {total_supervised_tokens}`",
main_process_only=True,
)
if update:
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing: if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise # we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails # flash attention with position ids fails
if not cfg.total_num_tokens:
LOG.info("calculating total_num_tokens")
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values
)
LOG.info(f"total_num_tokens: {total_num_tokens}")
cfg.total_num_tokens = total_num_tokens
if not cfg.total_supervised_tokens:
total_supervised_tokens = (
train_dataset.data.column("labels")
.to_pandas()
.apply(lambda x: np.sum(np.array(x) != -100))
.sum()
)
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
cfg.total_supervised_tokens = total_supervised_tokens
if cfg.sample_packing_eff_est: if cfg.sample_packing_eff_est:
total_num_steps = ( total_num_steps = (
@@ -189,41 +187,41 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
) )
* cfg.num_epochs * cfg.num_epochs
) )
LOG.info( LOG.debug(
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}" f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
main_process_only=True,
) )
else: else:
if cfg.world_size > 1 and is_distributed(): sampler = MultipackBatchSampler(
sampler = DistributedSampler( sampler=RandomSampler(train_dataset),
train_dataset,
num_replicas=cfg.world_size,
rank=dist.get_rank(),
seed=cfg.seed or 42,
)
else:
sampler = RandomSampler(train_dataset)
data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size, batch_size=cfg.micro_batch_size,
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len, drop_last=True,
collate_fn=DataCollatorForSeq2Seq( batch_max_len=cfg.micro_batch_size
tokenizer, * (cfg.max_packed_sequence_len or cfg.sequence_len),
return_tensors="pt", lengths=(
padding="longest", train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
), ),
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=cfg.num_epochs,
) )
data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency() data_loader = DataLoader(
LOG.info(f"data_loader_len: {data_loader_len}") train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = len(data_loader)
actual_eff = sampler.efficiency()
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends # FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est # on the agreed on value for sample_packing_eff_est
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs)) total_num_steps = int(
math.floor(
data_loader_len
* cfg.num_epochs
/ int(os.environ.get("WORLD_SIZE", 1))
)
)
def calc_sample_packing_eff_est(estimates: List[float]): def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -236,13 +234,22 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
sample_packing_eff_est = ( sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0 math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
) )
cfg.sample_packing_eff_est = sample_packing_eff_est if update:
LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}") cfg.sample_packing_eff_est = sample_packing_eff_est
LOG.debug(
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
main_process_only=True,
)
else: else:
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(
len(train_dataset)
* cfg.num_epochs
/ int(os.environ.get("WORLD_SIZE", 1))
/ cfg.batch_size
)
) )
LOG.info(f"total_num_steps: {total_num_steps}") LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
return total_num_steps return total_num_steps
@@ -260,12 +267,14 @@ def setup_fsdp_envs(cfg):
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap ] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def prepare_optim_env(cfg):
if cfg.fsdp: if cfg.fsdp:
setup_fsdp_envs(cfg) setup_fsdp_envs(cfg)
elif cfg.deepspeed: elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
trainer_builder.train_dataset = train_dataset trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset trainer_builder.eval_dataset = eval_dataset

View File

@@ -2,20 +2,20 @@
import os import os
from axolotl.utils.dict import DictDefault
def setup_wandb_env_vars(cfg):
if cfg.wandb_mode and cfg.wandb_mode == "offline": def setup_wandb_env_vars(cfg: DictDefault):
os.environ["WANDB_MODE"] = cfg.wandb_mode for key in cfg.keys():
elif cfg.wandb_project and len(cfg.wandb_project) > 0: if key.startswith("wandb_"):
os.environ["WANDB_PROJECT"] = cfg.wandb_project value = cfg.get(key, "")
if value and isinstance(value, str) and len(value) > 0:
os.environ[key.upper()] = value
# Enable wandb if project name is present
if cfg.wandb_project and len(cfg.wandb_project) > 0:
cfg.use_wandb = True cfg.use_wandb = True
if cfg.wandb_entity and len(cfg.wandb_entity) > 0: os.environ.pop("WANDB_DISABLED", None) # Remove if present
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
os.environ["WANDB_WATCH"] = cfg.wandb_watch
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
else: else:
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"

0
tests/e2e/__init__.py Normal file
View File

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging import logging
import os import os
import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
@@ -16,6 +15,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -25,9 +26,9 @@ class TestFusedLlama(unittest.TestCase):
Test case for Llama models using Fused layers Test case for Llama models using Fused layers
""" """
def test_fft_packing(self): @with_temp_dir
def test_fft_packing(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "JackFram/llama-68m",
@@ -51,7 +52,7 @@ class TestFusedLlama(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -69,4 +70,4 @@ class TestFusedLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists() assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging import logging
import os import os
import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
@@ -14,6 +13,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -23,9 +24,9 @@ class TestLoraLlama(unittest.TestCase):
Test case for Llama models using LoRA Test case for Llama models using LoRA
""" """
def test_lora(self): @with_temp_dir
def test_lora(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "JackFram/llama-68m",
@@ -52,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 8, "micro_batch_size": 8,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -63,11 +64,11 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
def test_lora_packing(self): @with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "JackFram/llama-68m",
@@ -96,10 +97,11 @@ class TestLoraLlama(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 8, "micro_batch_size": 8,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"bf16": True,
} }
) )
normalize_config(cfg) normalize_config(cfg)
@@ -107,11 +109,11 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
def test_lora_gptq(self): @with_temp_dir
def test_lora_gptq(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ", "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
@@ -144,7 +146,7 @@ class TestLoraLlama(unittest.TestCase):
"save_steps": 0.5, "save_steps": 0.5,
"micro_batch_size": 8, "micro_batch_size": 8,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -155,4 +157,4 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging import logging
import os import os
import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
@@ -16,6 +15,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -25,9 +26,9 @@ class TestMistral(unittest.TestCase):
Test case for Llama models using LoRA Test case for Llama models using LoRA
""" """
def test_lora(self): @with_temp_dir
def test_lora(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -54,7 +55,7 @@ class TestMistral(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -68,11 +69,11 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
def test_ft(self): @with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -93,7 +94,7 @@ class TestMistral(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -111,4 +112,4 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists() assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -4,7 +4,6 @@ E2E tests for lora llama
import logging import logging
import os import os
import tempfile
import unittest import unittest
from pathlib import Path from pathlib import Path
@@ -16,6 +15,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -25,9 +26,9 @@ class TestMistral(unittest.TestCase):
Test case for Llama models using LoRA Test case for Llama models using LoRA
""" """
def test_lora_packing(self): @with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -55,7 +56,7 @@ class TestMistral(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -69,11 +70,11 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
def test_ft_packing(self): @with_temp_dir
def test_ft_packing(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
@@ -95,7 +96,7 @@ class TestMistral(unittest.TestCase):
"num_epochs": 2, "num_epochs": 2,
"micro_batch_size": 2, "micro_batch_size": 2,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": output_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -113,4 +114,4 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists() assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -4,8 +4,8 @@ E2E tests for lora llama
import logging import logging
import os import os
import tempfile
import unittest import unittest
from pathlib import Path
from axolotl.cli import load_datasets from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
@@ -13,6 +13,8 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e") LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_DISABLED"] = "true"
@@ -22,13 +24,14 @@ class TestPhi(unittest.TestCase):
Test case for Llama models using LoRA Test case for Llama models using LoRA
""" """
def test_ft(self): @with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "microsoft/phi-1_5", "base_model": "microsoft/phi-1_5",
"trust_remote_code": True, "trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM", "model_type": "PhiForCausalLM",
"tokenizer_type": "AutoTokenizer", "tokenizer_type": "AutoTokenizer",
"sequence_len": 512, "sequence_len": 512,
"sample_packing": False, "sample_packing": False,
@@ -52,7 +55,7 @@ class TestPhi(unittest.TestCase):
"num_epochs": 1, "num_epochs": 1,
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(), "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit", "optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -64,14 +67,16 @@ class TestPhi(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
def test_ft_packed(self): @with_temp_dir
def test_ft_packed(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "microsoft/phi-1_5", "base_model": "microsoft/phi-1_5",
"trust_remote_code": True, "trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM", "model_type": "PhiForCausalLM",
"tokenizer_type": "AutoTokenizer", "tokenizer_type": "AutoTokenizer",
"sequence_len": 512, "sequence_len": 512,
"sample_packing": True, "sample_packing": True,
@@ -95,7 +100,7 @@ class TestPhi(unittest.TestCase):
"num_epochs": 1, "num_epochs": 1,
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"output_dir": tempfile.mkdtemp(), "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_bnb_8bit", "optimizer": "adamw_bnb_8bit",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
@@ -107,3 +112,4 @@ class TestPhi(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

95
tests/e2e/test_resume.py Normal file
View File

@@ -0,0 +1,95 @@
"""
E2E tests for resuming training
"""
import logging
import os
import re
import subprocess
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import most_recent_subdir, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestResumeLlama(unittest.TestCase):
"""
Test case for resuming training of llama models
"""
@with_temp_dir
def test_resume_qlora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 32,
"lora_alpha": 64,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_steps": 10,
"save_total_limit": 5,
"max_steps": 40,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
resume_cfg = cfg | DictDefault(
{
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
}
)
normalize_config(resume_cfg)
cli_args = TrainerCliArgs()
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
res = subprocess.run(
cmd, shell=True, text=True, capture_output=True, check=True
)
pattern = r"first_step\s+(\d+)"
first_steps = int(re.findall(pattern, res.stdout)[0])
assert first_steps == 31

33
tests/e2e/utils.py Normal file
View File

@@ -0,0 +1,33 @@
"""
helper utils for tests
"""
import os
import shutil
import tempfile
from functools import wraps
from pathlib import Path
def with_temp_dir(test_func):
@wraps(test_func)
def wrapper(*args, **kwargs):
# Create a temporary directory
temp_dir = tempfile.mkdtemp()
try:
# Pass the temporary directory to the test function
test_func(*args, temp_dir=temp_dir, **kwargs)
finally:
# Clean up the directory after the test
shutil.rmtree(temp_dir)
return wrapper
def most_recent_subdir(path):
base_path = Path(path)
subdirectories = [d for d in base_path.iterdir() if d.is_dir()]
if not subdirectories:
return None
subdir = max(subdirectories, key=os.path.getctime)
return subdir

View File

@@ -1,6 +1,7 @@
"""Module for testing the validation module""" """Module for testing the validation module"""
import logging import logging
import os
import unittest import unittest
from typing import Optional from typing import Optional
@@ -8,6 +9,7 @@ import pytest
from axolotl.utils.config import validate_config from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.wandb_ import setup_wandb_env_vars
class ValidationTest(unittest.TestCase): class ValidationTest(unittest.TestCase):
@@ -649,3 +651,113 @@ class ValidationTest(unittest.TestCase):
) )
validate_config(cfg) validate_config(cfg)
def test_warmup_step_no_conflict(self):
cfg = DictDefault(
{
"warmup_steps": 10,
"warmup_ratio": 0.1,
}
)
with pytest.raises(
ValueError,
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
):
validate_config(cfg)
cfg = DictDefault(
{
"warmup_steps": 10,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"warmup_ratio": 0.1,
}
)
validate_config(cfg)
class ValidationWandbTest(ValidationTest):
"""
Validation test for wandb
"""
def test_wandb_set_run_id_to_name(self):
cfg = DictDefault(
{
"wandb_run_id": "foo",
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
in record.message
for record in self._caplog.records
)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
cfg = DictDefault(
{
"wandb_name": "foo",
}
)
validate_config(cfg)
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
def test_wandb_sets_env(self):
cfg = DictDefault(
{
"wandb_project": "foo",
"wandb_name": "bar",
"wandb_run_id": "bat",
"wandb_entity": "baz",
"wandb_mode": "online",
"wandb_watch": "false",
"wandb_log_model": "checkpoint",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_PROJECT", "") == "foo"
assert os.environ.get("WANDB_NAME", "") == "bar"
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
assert os.environ.get("WANDB_ENTITY", "") == "baz"
assert os.environ.get("WANDB_MODE", "") == "online"
assert os.environ.get("WANDB_WATCH", "") == "false"
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
assert os.environ.get("WANDB_DISABLED", "") != "true"
def test_wandb_set_disabled(self):
cfg = DictDefault({})
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") == "true"
cfg = DictDefault(
{
"wandb_project": "foo",
}
)
validate_config(cfg)
setup_wandb_env_vars(cfg)
assert os.environ.get("WANDB_DISABLED", "") != "true"