Compare commits

...

496 Commits

Author SHA1 Message Date
mhenrichsen
9084879861 tinyllama 2023-11-16 13:36:01 +00:00
NanoCode012
3cc67d2cdd Feat: Add dataset loading from S3, GCS (#765)
* Feat: Add dataset loading from S3, GCS

* chore: update docs

* chore: add more info on cloud loading
2023-11-16 14:33:58 +09:00
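A minimal sketch of the dataset config this enables (bucket and file names are placeholders; a gs:// path should work analogously per the PR title):

```yaml
datasets:
  - path: s3://my-bucket/train.jsonl  # placeholder bucket; gs:// for GCS
    ds_type: json                     # file format hint for the loader
    type: alpaca
```
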
Wing Lian
1bc11868eb allow overriding of model_config parameters from the YML (#853)
* allow overriding of model_config parameters from the YML

* remove old logging, update readme

* move the updating of model config to the load_model_config function

* add warning for deprecated rope_scaling in the root of the YML config
2023-11-15 23:47:08 -05:00
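A sketch of the override described above; the top-level key is assumed from the PR title, and the nested placement of rope_scaling reflects the deprecation warning this PR adds for rope_scaling at the YML root:

```yaml
model_config:        # assumed key, per the PR title
  rope_scaling:      # root-level rope_scaling is deprecated by this PR
    type: linear
    factor: 2.0
```
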
Wing Lian
b3a61e8ce2 add e2e tests for checking functionality of resume from checkpoint (#865)
* use tensorboard to see if resume from checkpoint works

* make sure e2e test is either fp16 or bf16

* set max_steps and save limit so we have the checkpoint when testing resuming

* fix test parameters
2023-11-15 23:05:55 -05:00
Wing Lian
8a8d1c4023 make docker command more robust (#861)
* make docker command more robust

* update readme with more info
2023-11-15 23:03:54 -05:00
Wing Lian
332984db18 lint fix that didn't get caught by linter (#866) 2023-11-15 14:36:40 -05:00
MilesQLi
48630f5b34 Update data.py for signature generation (#851)
* Update data.py

Change of conversation formatting type should also trigger updating the preprocessed dataset, so it should be part of the signature.

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-11-15 14:12:32 -05:00
Zongheng Yang
b33c1d55a2 Docs: add instructions to 1-click launching on public clouds (#862)
* Update README.md

* Update ToC
2023-11-15 14:11:27 -05:00
Wing Lian
0c2a630326 multipack len should use max, not min (#863) 2023-11-15 12:52:32 -05:00
Wing Lian
db8a8afcba adds llama and mistral dropout support (#858)
* adds llama and mistral dropout support

* gracefully handle attention dropout if not available yet
2023-11-15 12:28:50 -05:00
Wing Lian
14706504e3 various bugfixes (#856)
* various bugfixes

use latest tinyllama release
check if val_set_size is empty first
update sdp and xformers llama patches for updated upstream transformers
fix system prompt when no input
calculate total and total supervised tokens even when not sample packing

* add fix for when eval size is estimated to be too small

* should be len 1 for dataset length

* add catchall kwargs
2023-11-15 12:23:18 -05:00
NanoCode012
501b4d1379 chore(doc): Separate section on runpod (#860) 2023-11-16 01:06:51 +09:00
NanoCode012
306fe19c54 feat(doc): add more info on train_on_split (#855) 2023-11-15 23:42:26 +09:00
Fabian Preiß
614cff4107 include the suffix modified string in ascii art (#852) 2023-11-15 07:12:28 -05:00
Wing Lian
1a6309c8a6 cleanup the old multipack dataloader (#841) 2023-11-12 05:39:09 -05:00
Bryan Thornbury
105d0b350b Pin optimum package (#838) 2023-11-09 22:36:15 -05:00
Wing Lian
f544ab2bed don't compile deepspeed or bitsandbytes from source (#837) 2023-11-08 19:49:55 -05:00
Wing Lian
641e6f7e51 multipack w batch sampler (#795)
* test batch sampler w varying batch lens

* wip

* multipack batchsampler wip

* wip

* fix for prepare data loader to get correct # of steps based on gpus

* lint and clean up

* calculate len estimate

* fix total num steps calc

* add options for dataloader_num_workers and dataloader_pin_memory

* remove gitbook

* support prefetch_factor for dataloader optimization

* fix the kwarg
2023-11-07 20:27:40 -05:00
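The dataloader knobs named in the bullets would look roughly like this in a config (key names inferred from the bullet text; treat as a sketch):

```yaml
dataloader_num_workers: 4      # workers feeding the multipack batch sampler
dataloader_pin_memory: true
dataloader_prefetch_factor: 2  # prefetch optimization from the same PR (assumed key name)
```
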
Wing Lian
6dc68a653f use temp_dir kwarg instead 2023-11-06 18:33:01 -05:00
Wing Lian
7de6a5639c missing dunder-init 2023-11-06 18:33:01 -05:00
Wing Lian
c74f045ba7 chore: lint 2023-11-06 18:33:01 -05:00
Wing Lian
0402d19759 make sure to cleanup tmp output_dir for e2e tests 2023-11-06 18:33:01 -05:00
Wing Lian
b2430ce670 use accelerate logging for zero/main logging only 2023-11-06 18:32:26 -05:00
Wing Lian
4c834bf25d cleanup verbosity a bit 2023-11-06 18:32:26 -05:00
Fabian Preiß
8056ecd30e add deepspeed-kernels dependency for deepspeed>=0.12.0 (#827) 2023-11-05 07:52:56 -05:00
Jason Stillerman
738a057674 Feat: Added Gradio support (#812)
* Added gradio support

* queuing and title

* pre-commit run
2023-11-04 23:59:22 -04:00
Wing Lian
cdc71f73c8 update table for rwkv4 support, fix process count for dataset (#822) 2023-11-04 23:45:44 -04:00
NanoCode012
6459ac7357 fix: pin autogptq (#818) 2023-11-03 10:14:55 -04:00
Wing Lian
964d858da0 fix model parallel (#816) 2023-11-02 21:34:22 -04:00
NanoCode012
10388a8daf fix(tokenizer): update log order after update (#806) 2023-10-31 13:21:20 +09:00
NanoCode012
9f7e8a971d feat(doc): add dummyoptim faq fix (#802) 2023-10-29 23:06:06 +09:00
NanoCode012
637ed095a0 fix(config): Set eos/bos to tokenizer if different (#801)
* fix(config): Set eos/bos to tokenizer if different

* chore: fix lint
2023-10-29 21:32:37 +09:00
Wing Lian
827ec3d274 refactor neft patch to be more re-usable similar to trl's impl (#796) 2023-10-29 04:33:13 -04:00
Wing Lian
8b79ff0e94 fix eval_steps to be a sane default (#797)
* fix eval_steps to be a sane default

* update docs for fractional eval_steps
2023-10-27 22:36:30 -04:00
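Fractional eval_steps follows the HF Trainer convention: a float in (0, 1) is interpreted as a fraction of total training steps. A sketch:

```yaml
eval_steps: 0.05  # evaluate every 5% of total steps; an integer would mean every N steps
```
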
MilesQLi
0800885e2f Update to adapt to sharegpt datasets with "assistant" rather than "gp… (#774)
* Update to adapt to sharegpt datasets with "assistant" rather than "gpt" as the machine answers.

* use a strict option for handling incorrect turn data

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-27 22:00:16 -04:00
Teknium
d3193beac3 Fix Deepspeed Zero3 Config (#791)
* Update zero3.json

Remove CPU offload by default (it slows things down horribly; reducing the batch size is a better option), and change the LR scheduler to one that properly decays

* Update zero3.json

fix something
2023-10-27 21:57:02 -04:00
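A rough YAML view of the zero3.json change described above (the actual file is JSON; the scheduler type is an assumption about what "properly decaying" refers to):

```yaml
zero_optimization:
  stage: 3
  # offload_optimizer / offload_param entries removed: CPU offload is now off by default
scheduler:
  type: WarmupDecayLR  # assumed replacement for a non-decaying WarmupLR
```
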
Aleksa Gordić
2e71ff03a6 Add docker advanced instruction to README (#792) 2023-10-27 09:24:04 -04:00
chanvichetvong
facc49f32b GitBook: No commit message 2023-10-26 15:11:00 +00:00
Casper
e50ab072e2 Create preprocess CLI (#785)
* Create preprocess CLI

* Print prompt template if debugging

* Add print for unsupported prompters

* Formatting

* Formatting

* Refactor variables

* Formatting

* Formatting

* Formatting

* Formatting
2023-10-26 09:35:42 -04:00
Casper
05bd6f1122 Threaded MultipackDistributedDataloader with prefetched samples (#759)
* Multithreading implementation [WIP]

* Added benchmarking

* 35% increased throughput

* Memory pinning

* Start threads in init

* Correct print of samples

* Sleep if queue is full

* Remove pin_memory (worse)

* Simplify logic to one thread

* Remove benchmark

* Use deque for constant speed

* Formatting

* Formatting

* Formatting

* Formatting

* Rollback to use queue

* Fix multi-epoch training

* Add num epochs arg

* Start thread in __iter__

* Formatting

* Use is_alive correctly

* Simplify loading thread
2023-10-26 07:49:52 +02:00
NanoCode012
20aa4b57d2 chore(readme): Improve documentation on conversation field (#782)
* chore(readme): Improve documentation on conversation field

* fix: clarify where the option is
2023-10-24 12:52:32 +09:00
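The conversation field documented here selects a fastchat conversation template for sharegpt-style datasets; a sketch (dataset path is a placeholder):

```yaml
datasets:
  - path: my-org/sharegpt-data  # placeholder
    type: sharegpt
    conversation: chatml        # any fastchat template name
```
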
NanoCode012
11d1d607db chore: refactor truthy check and fix mypy (#780) 2023-10-24 12:28:40 +09:00
Wing Lian
6c81c61bc4 refactor setup trainer so we can add more hooks (#773)
* refactor setup trainer so we can add more hooks

* Remove stray comma
2023-10-23 17:38:41 -04:00
Wing Lian
9b43e7ea15 disable eval table w sample packing in examples (#778) 2023-10-23 09:18:44 -04:00
Wing Lian
2d8def68dc simplify by removing duplicate base_model_config (#772) 2023-10-23 01:42:38 -04:00
NanoCode012
44c9d0151a Fix: Warn when fullfinetune without adapter (#770) 2023-10-22 15:41:43 -04:00
Wing Lian
ca84cca2c0 convert exponential notation lr to floats (#771) 2023-10-22 15:37:03 -04:00
Casper
32eeeb5b64 Hotfix for not saving correctly (#762) 2023-10-22 13:22:32 -04:00
NanoCode012
afedc470bd Fix: Cannot tokenize with bf16 and on cpu (#766) 2023-10-23 01:32:26 +09:00
NanoCode012
9923b72649 Fix: eval table conflict with eval_sample_packing (#769) 2023-10-23 01:18:12 +09:00
Wing Lian
21cf09b608 remove lora fused packing test (#758) 2023-10-21 22:59:35 -04:00
Casper
15d3a654bf Implement fused modules (#747)
* MLP: Memory saving

* Remove RMSNorm restrictions

* Map packed weights to original

* FusedAttention module

* Simplify code

* Move fused modules

* Fix critical typo

* Split inplace

* Add FFT config

* Add validation of fused arguments

* Add fused arguments to config

* Update docs

* Fix validation logic

* Add fused modules to flash attn

* Only fuse during training

* Remove timing

* Formatting

* Formatting

* Formatting

* chore: lint

* chore: lint

* add e2e tests for fused llama

* no lora for tests

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-21 16:08:25 -04:00
Wing Lian
a21935f07a add to docs (#703) 2023-10-19 21:32:30 -04:00
NanoCode012
8966a6f566 chore: bump transformers to v4.34.1 to fix tokenizer issue (#745) 2023-10-19 20:18:22 -04:00
Motoki Wu
e4d1585c4e Fix DeepSpeed Zero 3 Saving (#709)
* Update train.py

* add zero3 check

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-19 19:18:24 -04:00
Wing Lian
70157ccb8f add a latest tag for regular axolotl image, cleanup extraneous print statement (#746) 2023-10-19 12:28:29 -04:00
seungduk.kim.2304
3a99495b05 improve: Enhance code readability of prompt_tokenizers.py (#707) 2023-10-19 08:12:17 -04:00
NanoCode012
440c3ab527 Fix(model): Linear detected and added to target module with rope linear (#738)
* Fix(model): Linear detected and added to target module with rope linear

* fix: exclude layer instead
2023-10-18 22:13:20 -04:00
Napuh
992d57f20a catch ConnectionError when checking dataset from HuggingFace (#743) 2023-10-18 22:11:54 -04:00
mhenrichsen
91a016f410 badge (#739)
* badge

* fixed text
2023-10-18 10:21:34 -04:00
Casper
a045db0214 Mistral: Sliding Window Attention with Flash Attention and Sample Packing (#732)
* Implement Mistral FA + SWA + Sample Packing

* Handle unbroadcastable tensor

* chore: lint

* Simplify _prepare_decoder_attention_mask

* Uncomment window size

* Upgrade flash-attn to minimum of 2.3.0 to support SWA

* Add original condition to avoid error during inference

* chore: lint

* use torchscript to prevent oom

* chore: pylint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-10-16 15:13:46 -04:00
Casper
e1b214c62b Clarify custom format example (#729)
* Clarify custom prompt format

* Simplify format
2023-10-14 09:28:12 -04:00
Wing Lian
3553172e3c fixes for alpaca w chatml, and don't include attention_mask w mistral for flash attention (#728) 2023-10-14 09:27:07 -04:00
Wing Lian
7f2027d93f tweak for xformers install w pytorch 2.1.0 (#727) 2023-10-13 15:21:17 -04:00
Wing Lian
8d288a2ad4 workaround for installing xformers w torch 2.1.0 (#725) 2023-10-13 11:19:30 -04:00
Wing Lian
f30afe4544 misc sharegpt fixes (#723)
* support for sharegpt with assistant talking first, better masking of assistant token, allow remap of roles from dataset

* invalid role is actually not possible

* update tokenized fixture for corrected labels
2023-10-13 11:04:39 -04:00
Wing Lian
bfbdba8614 pin xformers >= 0.0.22 (#724) 2023-10-13 10:27:56 -04:00
Maxime
3bd9528390 add noisy embedding (#721)
* add noisy embedding

* fix format

* Update README.md

* Update README.md

* linter issues

* caseus fixes

---------

Co-authored-by: Maxime <maxime@nope.no>
2023-10-13 10:00:42 -04:00
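A config sketch for the noisy-embedding feature added here (NEFTune-style noise on input embeddings); the key name follows the PR-era README and was renamed in later releases, so treat it as an assumption:

```yaml
noisy_embedding_alpha: 5  # assumed key; noise magnitude applied to embeddings during training
```
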
Wing Lian
2aa1f71464 fix pytorch 2.1.0 build, add multipack docs (#722) 2023-10-13 08:57:28 -04:00
Wing Lian
1c412c7e9d improve handling of the prepared ds path and other cfg defaults (#701) 2023-10-13 07:46:07 -04:00
Jan Philipp Harries
490923fb78 Save Axolotl config as WandB artifact (#716) 2023-10-11 07:28:12 -04:00
NanoCode012
5855dded3d fix(doc): update default doc according to arg (#714) 2023-10-10 21:51:56 +09:00
atgctg
ace70b33c6 Fix: lowercase True values in config (#713)
* Fix: lowercase `True` values in config

* Fix: lowercase `True` values in config
2023-10-10 21:32:20 +09:00
NanoCode012
11c48c5e03 fix(doc): Add note on inference w sample packing (#712) 2023-10-10 21:08:17 +09:00
lukemarsden
295b2662e1 Get qlora mistral-7b fine tuning working on a single 4090 (#708) 2023-10-10 15:14:23 +09:00
seungduk.kim.2304
77c84e02fd Update README with some explanations (#700)
* Update README with some explanations

* revert commit-hook change

* add more explanation about batch size and gradient accum

* don't use latex format

* decorate

* git hook again

* Attach a link that explains about LoRA hyperparameters

* update table of content

* Explanation about lora_modules_to_save
2023-10-08 13:37:54 -04:00
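The batch-size explanation added to the README reduces to: effective batch size = micro_batch_size * gradient_accumulation_steps * number of GPUs. A sketch, including the lora_modules_to_save usage the last bullet documents:

```yaml
micro_batch_size: 2              # per-GPU batch
gradient_accumulation_steps: 4   # effective batch = 2 * 4 * n_gpus
lora_modules_to_save:            # train these alongside LoRA, e.g. after adding special tokens
  - embed_tokens
  - lm_head
```
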
mhenrichsen
f91db198f3 fix unneeded space (#699) 2023-10-07 14:19:25 -04:00
Wing Lian
7f2618b5f4 add docker images for pytorch 2.1.0 (#697) 2023-10-07 12:23:31 -04:00
Wing Lian
aca0398315 apex not needed as amp is part of pytorch (#696) 2023-10-07 12:20:45 -04:00
mhenrichsen
29b8f46aed Merge pull request #693 from OpenAccess-AI-Collective/update-mistral-example
update mistral lr, sample pack
2023-10-07 11:04:58 +02:00
mhenrichsen
83a950bb87 lint 2023-10-07 11:04:35 +02:00
Wing Lian
de87ea68f6 fix multiline for docker (#694) 2023-10-06 22:38:15 -04:00
mhenrichsen
4c8ddf2c6f new lr, sample pack 2023-10-06 22:58:13 +02:00
NanoCode012
669f1d052c Fix: Higher vram usage for mistral and sample_packing (#691)
* Fix: Higher vram usage for mistral and sample_packing

* chore: update comment

* chore: lint
2023-10-06 12:33:43 -04:00
Abhishek Mishra
d4a88e4eca Adding qlora config for Mistral (#675)
* Adding qlora config for Mistral

Contains a fix for the Mistral FA issue: ValueError: You are attempting to perform batched generation with padding_side='right'; this may lead to unexpected behaviour for the Flash Attention version of Mistral. Make sure to call tokenizer.padding_side = 'left' before tokenizing the input.

Fix for now is to set sample_packing: true and pad_to_sequence_len: true

* Renamed to qlora.yml
2023-10-06 21:05:56 +09:00
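The workaround called out in the commit body, isolated to the two keys it names:

```yaml
sample_packing: true        # sidesteps the padding_side='right' ValueError with flash attention
pad_to_sequence_len: true
```
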
Wing Lian
2d60ba3a6e flash_attention + sample packing for stablelm 3b (#671)
* stablelm epoch fa patch

* is causal for fa

* working stablelm fa w packing

* chore: pre-commit linting
2023-10-05 16:03:43 -04:00
NanoCode012
eb480dfd68 Fix: ValueError when FA + Mistral when padding_side=right (#681)
* Fix: ValueError when FA + Mistral when padding_side=right

* fix: remove tokenizer class check
2023-10-06 04:12:54 +09:00
NanoCode012
133e676bcc Feat: Set WORKDIR to /workspace/axolotl (#679) 2023-10-06 04:09:14 +09:00
NanoCode012
69fac9a020 Fix: Future deprecation warning with use_auth_token (#680) 2023-10-06 03:56:18 +09:00
NanoCode012
e0b7eeabfd Fix(tokenizer): Set rstrip,lstrip,norm to False (#678) 2023-10-06 03:50:49 +09:00
NanoCode012
43856c0a39 Fix(version): Update FA to work with Mistral SWA (#673) 2023-10-04 21:32:19 +09:00
NanoCode012
e62d5901b5 chore: Clean up repetitive model kwargs (#670) 2023-10-04 20:41:26 +09:00
NanoCode012
697c50d408 Feat: Allow usage of native Mistral FA when no sample_packing (#669)
* Allow usage of native Mistral FA when no sample_packing

* fix: do not apply custom patch when sample_pack off

* chore: lint

* chore: pin transformer to v4.35.0.dev0

* fix: split sample_packing to separate test
2023-10-04 20:40:47 +09:00
NanoCode012
90e0d673f7 Feat: Add config yaml to section for reprod in bug-report.yaml (#667)
* Update bug-report.yaml

* Update bug-report.yaml

* Update bug-report.yaml
2023-10-03 23:38:42 +09:00
Wing Lian
2642caedf2 refactor to set eval_batch_size earlier if unset, so we can warn if mismatched (#662) 2023-10-02 21:08:07 -04:00
Wing Lian
f34648c8b9 remove patch fix for phi (#664) 2023-10-02 21:07:41 -04:00
Wing Lian
e50a64e85e prepared dataset caching, other misc fixes (#665)
* prepared dataset caching, other misc fixes

* also don't load from disk cache unless explicit
2023-10-02 21:07:24 -04:00
Wing Lian
f4868d733c make sure we also run CI tests when requirements.txt changes (#663) 2023-10-02 08:43:40 -04:00
Napuh
a7e56d83c2 removed duplicate on requirements.txt (#661) 2023-10-02 08:40:05 -04:00
Wing Lian
5b0bc48fbc add mistral e2e tests (#649)
* mistral e2e tests

* make sure to enable flash attention for the e2e tests

* use latest transformers full sha

* uninstall first
2023-09-29 00:22:40 -04:00
Kyle Corbitt
9ec20777ba Make dataset_processes configurable (#651)
I'm using the Axolotl script to train models on https://modal.com serverless GPUs. Unfortunately, their environment seems to have some kind of bug where if I try to run `datasets.filter` with too high a `num_proc`, it throws an error and dies.

This PR adds a new configuration option `dataset_processes`, which lets you explicitly set the number of processes used to map/filter the dataset. If not included, this defaults to the current behavior of setting that to `os.cpu_count()`.
2023-09-29 00:22:22 -04:00
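As described, the new option caps the process count used for dataset map/filter:

```yaml
dataset_processes: 8  # defaults to os.cpu_count() when unset
```
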
ich
590d6032fd Fix bug when using pretokenized datasets (#652)
* fix pretokenized datasets readme

* check if dataset type is not set to handle pretokenized datasets
2023-09-28 22:54:10 -04:00
Wing Lian
409ca0f21c add support for defined train split (#654) 2023-09-28 20:14:14 -04:00
Wing Lian
8662e8ffe8 don't strip the prompt for check since we don't strip to tokenize anymore (#650) 2023-09-28 12:21:51 -04:00
Wing Lian
b2edaaeff6 fix for flash attn w mistral w/o sammple packing (#648) 2023-09-28 10:57:37 -04:00
Adarsh Shirawalmath
b88f51512a Update mistral/README.md (#647) 2023-09-28 10:24:56 -04:00
NanoCode012
eb41f76f92 Feat: Add example for Mistral (#644)
* Feat: Add example for Mistral

* chore: turn off flash

* chore: add is_mistral_derived_model

* chore: update following PR
2023-09-28 20:15:00 +09:00
NanoCode012
383f88d7a7 Fix(cfg): Add validation for save_strategy and eval_strategy (#633)
* Fix(cfg): Check save_strategy cfg conflict with save_steps

* Fix(cfg): Check evaluation_strategy cfg conflict with eval_steps

* chore: add extra check for steps only
2023-09-28 10:14:41 +09:00
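The validation added here rejects configs where a strategy contradicts its step setting; a consistent pairing looks like:

```yaml
save_strategy: steps        # must agree with save_steps
save_steps: 500
evaluation_strategy: steps  # must agree with eval_steps
eval_steps: 100
```
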
Wing Lian
b6ab8aad62 Mistral flash attn packing (#646)
* add mistral monkeypatch

* add arg for decoder attention mask

* fix lint for duplicate code

* make sure to update transformers too

* tweak install for e2e

* move mistral patch to conditional
2023-09-27 18:41:00 -04:00
Napuh
85b0be2ba7 Warn users to login to HuggingFace (#645)
* added warning if user is not logged in HF

* updated doc to suggest logging in to HF
2023-09-27 17:43:35 -04:00
Ethan Smith
8fe0e633d2 Fix bug in dataset loading (#284)
* Fix bug in dataset loading

This fixes a bug when loading datasets. `d.data_files` is a list, so it cannot be directly passed to `hf_hub_download`

* Check type of data_files, and load accordingly
2023-09-27 13:41:31 -04:00
Felix Yan
d1236f2c41 Correct typos in datasets.py (#639) 2023-09-27 12:12:10 -04:00
Wing Lian
895f0a0723 skip some flash attn patches unless explicitly enabled (#643)
* skip some flash attn patches if explicitly disabled

* make the other patches optional
2023-09-27 12:11:07 -04:00
Wing Lian
e7d3e2dbb6 use fastchat conversations template (#578)
* use fastchat conversations template

* require fastchat (fschat) pip install

* handle roles dynamically from conversation

* tweak fastchat conversation with a monkeypatch to get individual turns

* fix up so it works with multiple conversation styles, and don't strip the turns

* fix sharegpt fixture now that we're using a more correct tokenization

* use a new prompter and support fastchat conversation type

* use sharegpt from prompt strategies now

* update docs, add chatml template

* add a newline after im_end token

* ensure we correctly set system message

* update per PR feedback to handle deprecated sharegpt types

* don't add duplicate wandb req

* make sharegpt fields configurable from yml

* llama2 fixes

* don't fail fatally when turns are improper
2023-09-27 12:10:45 -04:00
Wing Lian
60c7c48c97 update for recent transformers updates (#636)
* update for recent transformers updates

* fix checkpoint forward kwargs

* just pass args into torch checkpoint
2023-09-27 12:10:32 -04:00
Wing Lian
e8cbf50be6 attention_mask not needed for training (#642)
* attention_mask not needed for training

* specifically don't use attention mask for phi

* use a different check for phi

* small fixes since phi removed some values from their config
2023-09-27 11:12:08 -04:00
Wing Lian
d887ad86c3 eval_table isn't quite stable enough to be in default llama configs (#637) 2023-09-26 10:13:20 -04:00
NanoCode012
19a600a8b8 Feat: Add support for upstream FA2 (#626)
* Feat: Add support for upstream FA2

* chore: add is_falcon_derived_model: true to examples

* chore: add config to readme for documentation

* feat: add extra model types

* fix: remove old falcon flash patch

* chore: pin transformers and accelerate
2023-09-26 09:53:28 -04:00
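Per the PR bullets, enabling upstream FA2 for a falcon model pairs the derived-model flag with flash attention; a sketch:

```yaml
is_falcon_derived_model: true  # key named in the PR bullets
flash_attention: true
```
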
Fernando Tarin Morales
5e5296a77c Added quotes to the pip install -e command to fix an incompatibility with shells that do glob expansion like zsh (#632) 2023-09-25 11:50:14 -04:00
mhenrichsen
f3d939016a Merge pull request #629 from OpenAccess-AI-Collective/chore/-change-default-model
default model changed
2023-09-25 09:32:01 +02:00
NanoCode012
cfbce020e9 Fix: Fail bf16 check when running on cpu during merge (#631) 2023-09-25 13:48:18 +09:00
mhenrichsen
4fecbfe5e1 default model changed 2023-09-24 18:52:53 +02:00
NanoCode012
67b9888630 Feat(doc): Add eval_sample_packing to doc (#625) 2023-09-23 13:11:27 +09:00
Maxime
923eb91304 tweak: improve base builder for smaller layers (#500) 2023-09-22 16:17:50 -04:00
Wing Lian
a363604dcf better handling and logging of empty sharegpt turns (#603) 2023-09-22 16:13:42 -04:00
Wing Lian
501958bb6f create a model card with axolotl badge (#624) 2023-09-22 16:13:26 -04:00
Wing Lian
c25ba7939b update README w deepspeed info (#605) 2023-09-22 00:15:52 -04:00
NanoCode012
d5f8589021 chore(callback): Remove old peft saving code (#510) 2023-09-22 12:31:33 +09:00
Wing Lian
03e59077a0 misc fixes to add gptq tests (#621)
* misc fixes to add gptq tests

* set bf16 needed for fa2
2023-09-21 21:52:12 -04:00
Wing Lian
97d3776ce6 split completion text to sequence_len (#616) 2023-09-21 21:51:25 -04:00
Wing Lian
2844eb22b6 run eval on the first step to get a baseline (#617)
* run eval on the first step to get a baseline

* wandb keeps getting moved around by pre-commit ...
2023-09-21 21:51:09 -04:00
Wing Lian
e85d2eb06b let MAX_JOBS use the default since we're not resource constrained on our self-hosted runners (#427) 2023-09-21 20:36:30 -04:00
Wing Lian
196ff1181e skip the gpu memory checks if the device is set to 'auto' (#609)
* skip the gpu memory checks if the device is set to 'auto'

* skip gpu mem logging if cpu too

* don't worry about log_gpu_memory_usage since it calls another annotated fn

* rename decorator internal
2023-09-21 15:20:31 -04:00
Wing Lian
92512c390b ignore wandb to resolve isort headaches (#619) 2023-09-21 11:50:09 -04:00
Maxime
2fe95cdcc1 fix distributed devices (#612)
* fix distributed devices

* Update distributed.py

* Update distributed.py
2023-09-21 09:11:34 -04:00
Maxime
c1382e79b6 Create multi-node.md (#613)
* Create multi-node.md

* Update multi-node.md

* Update multi-node.md
2023-09-20 22:02:16 -04:00
Maxime
5d931cc042 Only run tests when a change to python files is made (#614)
* Update tests.yml

* Update .github/workflows/tests.yml

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-09-20 22:02:04 -04:00
Javier
ec0958f4f8 Update requirements.txt (#610) 2023-09-20 08:40:49 -04:00
Wing Lian
faecff9798 support to disable exllama for gptq (#604)
* support to disable exllama for gptq

* update property instead of item

* fix config key
2023-09-19 17:51:08 -04:00
bofeng huang
aa656e04bd Delete duplicate lines (#606) 2023-09-19 16:40:05 -04:00
Wing Lian
b53e77775b update dockerfile to not build evoformer since it fails the build (#607) 2023-09-19 16:28:29 -04:00
Wing Lian
674c57692d more sane defaults for openllama 3b used for quickstarts (#602)
* more sane defaults for openllama 3b used for quickstarts

* don't use bf16 for quickstart to simplify gpu compatibility

* use the updated openlm-research/open_llama_3b_v2 models
2023-09-19 09:15:10 -04:00
Wing Lian
1eebbd09c3 improve handling for empty text on the tokenization step (#502) 2023-09-19 08:09:56 -04:00
Wing Lian
62a774140b Fix for check with cfg and merge_lora (#600) 2023-09-18 21:14:32 -04:00
Wing Lian
31b9e0c6e8 minor tweaks to simplify (#597) 2023-09-18 11:45:44 -04:00
Wing Lian
6b9b229356 btlm and falcon monkey patches for flash attn (#566) 2023-09-17 13:49:18 -04:00
Wing Lian
131afdbd89 add bf16 check (#587) 2023-09-17 13:49:03 -04:00
NanoCode012
00dce35fb2 Feat(data): Allow loading local csv and text (#594)
* Feat(data): Allow loading local csv and text

* chore: update readme for loading data
2023-09-17 11:32:27 -04:00
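Loading a local CSV per this feature would look roughly like this (path is a placeholder; ds_type value assumed from the commit title):

```yaml
datasets:
  - path: data/train.csv  # local file, placeholder path
    ds_type: csv
    type: completion
```
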
Wing Lian
b15b19eb8d gather/broadcast the max value of the packing efficiency automatically (#463) 2023-09-17 11:08:18 -04:00
Wing Lian
ab534d75ba don't add position_ids for evals (#591) 2023-09-16 16:11:57 -04:00
Wing Lian
21ec195c9f optionally configure sample packing for evals (#589) 2023-09-16 00:09:48 -04:00
Wing Lian
62eaee7649 make phi training work with Loras (#588)
* validation for phi loras

* fix model config class check

* update readme for phi training
2023-09-15 20:51:55 -04:00
Jan Philipp Harries
be75668400 set fsdp state dict (#584)
Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
2023-09-15 17:47:36 -04:00
Wing Lian
aeec7c4688 pop block_cls since it's not an actual kwarg 2023-09-15 15:54:06 -04:00
Wing Lian
360788296a don't resize embeddings if it's already large enough (#577)
* don't resize embeddings if it's already large enough

* make sure to tie weights, even if we aren't resizing
2023-09-15 15:47:09 -04:00
Wing Lian
12a2dbbc2c Support Sample packing for phi arch (#586)
* phi sequence packing

* sample packing fixes

* fix linting

* fix inference and phi e2e tests

* update phi example now that sample packing works

* wandb import keeps getting moved around
2023-09-15 15:46:54 -04:00
NanoCode012
3a2edc85c3 Feat(doc): Add features to doc (#583) 2023-09-16 01:14:15 +09:00
Wing Lian
f7a22632d7 support custom field for completion from yml (#580)
* support custom field for completion from yml

* remove legacy completion check and add doc

* update README docs
2023-09-15 07:48:21 -04:00
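A sketch of the custom completion field from this PR (the key name `field` and the column name are assumptions based on the commit description):

```yaml
datasets:
  - path: my-org/articles  # placeholder
    type: completion
    field: article         # read completion text from this column instead of the default
```
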
Doan Minh Phuong
1aa400721e Fix Codellama examples (#582)
* Fix seq_len

* Update lora.yml

* Update qlora.yml

* Update lora.yml

* Update lora.yml

* Update qlora.yml
2023-09-15 04:19:13 -04:00
Wing Lian
8dcd40ac78 prevent cli functions from getting fired on import (#581) 2023-09-15 04:03:32 -04:00
Wing Lian
a5a625f47e update support matrix with btlm and phi (#579) 2023-09-15 02:46:15 -04:00
Wing Lian
861cecac2a refactor scripts/finetune.py into new cli modules (#550)
* refactor scripts/finetune.py into new cli modules

* continue to support scripts/finetune.py

* update readme with updated cli commands

* Update scripts/finetune.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-09-15 01:43:52 -04:00
Wing Lian
1078d3eae7 E2e passing tests (#576)
* run e2e tests after all other checks have passed

* tweak tests so they get run on PRs or push to main

* change dependent action for checking

* one test workflow to rule them all

* no need for custom action, just use needs

* whoops, python version should be a string

* e2e tests can run on any available gpu
2023-09-15 01:03:49 -04:00
Wing Lian
24146733db E2e device cuda (#575)
* use torch.cuda.current_device() instead of local_rank

* ignore NVML errors for gpu stats

* llama lora packing e2e tests
2023-09-14 22:49:27 -04:00
Wing Lian
9218ebecd2 e2e testing (#574) 2023-09-14 21:56:11 -04:00
Wing Lian
228420972e Phi examples (#569)
* add phi full ft example

* Add readme to point out that deepspeed should be used

* zero1 is better than zero2 for phi
2023-09-14 11:17:47 -04:00
Wing Lian
c6d870b91d mypy wandb ignore (#572)
* mypy wandb ignore

* fix isort for wandb
2023-09-14 11:17:30 -04:00
Wing Lian
115795079d remove columns after tokenizing for pretraining (#571) 2023-09-14 11:08:22 -04:00
Wing Lian
3b18c963cc set auto for other params that hf trainer sets for ds. include zero1 json (#570) 2023-09-14 11:04:37 -04:00
Wing Lian
3fbde762ab fix save_steps so it doesn't get duplicated (#567) 2023-09-13 20:40:33 -04:00
Wing Lian
f6060a664e Model parallel (#538)
* model-parallel for single process

* fix device/device_map

* fix handling for device
2023-09-13 11:45:30 -04:00
Wing Lian
a4e1bb6606 let hf trainer handle torch compile (#516)
* let hf trainer handle torch compile

* remove torch compile checks, include option for backend

* suppress torch errors to get further

* require min torch version of 2.1.0 for torch compile to work

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
2023-09-13 11:42:12 -04:00
Wing Lian
36e53c7442 improve how we setup eval/save strategies and steps (#547)
* setup save and eval strategies to be consistent with trainer logic

* add comments

* better eval handling
2023-09-13 11:37:23 -04:00
Wing Lian
e7aa7b1a1e gracefully handle length feature used for group by (#565) 2023-09-13 11:23:30 -04:00
Wing Lian
e5bb22a56b add optimization for group-by-len (#563) 2023-09-13 10:57:12 -04:00
Wing Lian
fdb777bc06 check for the existence of the default accelerate config that can create headaches (#561) 2023-09-13 10:38:28 -04:00
Wing Lian
bf0804447c fix wandb so mypy doesn't complain (#562)
* fix wandb so mypy doesn't complain

* fix wandb so mypy doesn't complain

* no need for mypy override anymore
2023-09-13 10:36:16 -04:00
Glavin Wiechert
5b67ea98a6 Add training callback to send predictions to WandB table (#521)
* WIP Add training callback to send predictions to WandB table

* WIP improve wandb table reporting callback

* WIP improve wandb table reporting callback (cont)

* Add VSCode launching for debugging

* Add tiny llama example

* WIP attempt to improve post-eval prediction generation for table

* WIP attempt to improve post-eval prediction generation for table - part 2

* WIP batch generation

* WIP attempt to handle sample_packing using position_ids for wandb prediction table

* WIP add code for debugging

* Fix sample_packing support for wandb prediction table

* Clean up code for PR review

* Add eval_table_size, eval_table_max_new_tokens configs & clean up code

* Clean up PR, delete VSCode config, add tiny-llama example

* Add eval_table_size, eval_table_max_new_tokens documentation. Fix linting/formatting
2023-09-13 09:51:08 -04:00
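The two configs this PR introduces, with illustrative values:

```yaml
eval_table_size: 5              # rows of sample predictions logged to the WandB table
eval_table_max_new_tokens: 128  # generation length for those sample predictions
```
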
Jan Philipp Harries
2f586d18db Fix pretraining with iterable/streaming Dataset (#556)
* return without packing prep/len

* fix remove columns

* fix encode arguments

* add error when max steps not set

* fix test

---------

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
2023-09-13 00:16:40 -04:00
Wing Lian
9845c5e12d document that packaging needs to be installed before flash-attn (#559) 2023-09-12 12:18:30 -04:00
Wing Lian
772cd870d4 fix the sed command to replace the version w the tag 2023-09-11 13:44:19 -04:00
Wing Lian
6c5fbe6223 add long_description for pypi push (#555) 2023-09-11 13:34:29 -04:00
Wing Lian
bcbc9597e9 replace tags, build dist for pypi publish (#553)
* replace tags, build dist for pypi publish

* missing trailing comma
2023-09-11 13:25:41 -04:00
The Objective Dad
6d57f2f0f0 ergonomic update to optimizer config doc (#548) 2023-09-11 12:35:45 -04:00
Wing Lian
20ed4c1f9e pypi on tag push (#552) 2023-09-11 10:33:42 -04:00
Wing Lian
c5dedb17ad remove with section, doesn't seem to work (#551) 2023-09-11 10:27:17 -04:00
Wing Lian
b56503d423 publish to pypi workflow on tagged release (#549) 2023-09-11 09:44:47 -04:00
Wing Lian
a94f9cb99e fix for quant config from model (#540) 2023-09-10 12:40:52 -04:00
dongxiaolong
c1921c9acb Update requirements.txt (#543)
fix fsdp
2023-09-08 16:07:11 -04:00
Wing Lian
0b4cf5bc8c workaround for md5 variations (#533)
* workaround for md5 variations

* refactor the prepared hash too
2023-09-08 16:01:05 -04:00
SlapDrone
78ee2cdab2 add git environment variables to compose: avoid checkout failure error 128 on build (#534) 2023-09-08 15:59:49 -04:00
Wing Lian
34c0a86a11 update readme to point to direct link to runpod template, cleanup install instructions (#532) 2023-09-08 11:58:54 -04:00
* update readme to point to direct link to runpod template, cleanup install instructions

* default install flash-attn and auto-gptq now too

* update readme w flash-attn extra

* fix version in setup
2023-09-08 11:58:54 -04:00
The Objective Dad
5e2d8a42d9 Adding NCCL Timeout Guide (#536)
* fixes NCCL_P2P_LEVEL=NVL #429

* adding more insights into various values of NCCL_P2P_LEVEL
2023-09-08 11:57:47 -04:00
Wing Lian
e30f1e3cf7 Early stopping metric (#537)
* set early stopping metric to check

* tweak how load_best_model_at_end gets set for early stopping

* add validation for early stopping patience

* remove negation

* save results to metrics in callback

* move early stopping callback after the benchmark evals

* broadcast metrics so early stopping works
2023-09-08 11:57:02 -04:00
Wing Lian
343714972b recommend padding when using sample packing (#531) 2023-09-06 17:00:21 -04:00
Wing Lian
245c5c41e2 log rank too (#527) 2023-09-06 08:37:51 -04:00
Wing Lian
a546ca2813 misc fixes/improvements (#513)
fix per pr feedback
2023-09-05 16:40:13 -04:00
Wing Lian
3355706e22 Add support for GPTQ using native transformers/peft (#468)
* auto gptq support

* more tweaks and add yml

* remove old gptq docker

* don't need explicit peft install for tests

* fix setup.py to use extra index url

install torch for tests
fix cuda version for autogptq index
set torch in requirements so that it installs properly
move gptq install around to work with github cicd

* gptq doesn't play well with sample packing

* address pr feedback

* remove torch install for now

* set quantization_config from model config

* Fix the implementation for getting quant config from model config
2023-09-05 12:43:22 -04:00
mhenrichsen
daa4faca12 Merge pull request #520 from bdashore3/sharegpt-fixes
Allow for custom system prompts with ShareGPT
2023-09-05 09:02:55 +02:00
Aman Karmani
fc8766e502 reorg a bit 2023-09-05 02:21:24 +00:00
Aman Gupta Karmani
72a6fe1c1f use flash_attn rmsnorm when available (#526)
* use flash_attn xentropy when available

* use flash_attn.ops.rms_norm when available

* log when xentropy is not found

* log how to install RMSNorm

* add quotes so pip install works
2023-09-04 19:44:51 -04:00
Aman Gupta Karmani
5fe30b1497 use flash_attn xentropy when available (#525)
* use flash_attn xentropy when available

* log when xentropy is not found
2023-09-04 17:49:16 -04:00
Aman Gupta Karmani
44454ae4c4 move is_llama_derived_model into normalize_config (#524) 2023-09-04 00:19:03 -04:00
Wing Lian
09f154397e No gather single gpu (#523)
* don't attempt to gather on multi-gpu

* also check distributed status in bench callback
2023-09-03 23:24:28 -04:00
kingbri
995557bdf3 Prompters: ShareGPT: Allow for custom system prompts
If a system prompt is present in a conversation, add it instead of using the default.

Signed-off-by: kingbri <bdashore3@proton.me>
2023-09-01 13:53:05 -04:00
Maxime
1991946c5a fix: bad dtype for full finetune (#504)
* fix: bad dtype for full finetune

* Update src/axolotl/utils/models.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update models.py

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-09-01 07:11:45 -07:00
NanoCode012
f51c9c56c6 Fix(doc): Inform Windows users to use WSL/docker (#518) 2023-09-01 00:08:21 -07:00
Wing Lian
7710e81f50 log supervised token count (#448) 2023-08-31 15:45:23 -07:00
Tom Jobbins
48434bec54 Debug tokenization output: Add ability to output text only (no tokens), and/or specify num samples to see (#511) 2023-08-31 14:26:52 -07:00
Jan Philipp Harries
396a7a74fc Added advanced DDP args (#515)
* add ddp_config

* add advanced ddp config

* add ddp_config

* add advanced ddp config

---------

Co-authored-by: Jan Philipp Harries <jphme@users.noreply.github.com>
2023-08-31 10:37:47 -07:00
Wing Lian
b21e4a20fe split train from other cli options (#503) 2023-08-30 22:01:47 -07:00
Alpay Ariyak
42f9642792 Changed Bench Eval to report metrics correctly by split. Added total accuracy and renamed previously used bench_accuracy to bench_average_accuracy. (#512)
* Added "eval_" prefix

* Added total bench accuracy and renamed the previous one to bench_average_accuracy. Changed naming to use bench_split instead of always using eval_ prefix.
2023-08-30 22:00:50 -07:00
Wing Lian
c56b450cf5 drop empty tokenized rows too (#509) 2023-08-30 06:55:26 -07:00
Aman Gupta Karmani
1e07c162f1 set zero3 optimizer betas to auto so they inherit from HF trainer config (#507) 2023-08-30 08:10:33 -04:00
Wing Lian
76576323df add eval benchmark callback (#441)
* add mmlu callback

* use hf dataset for mmlu evals

* default to mmlu-zs

* make sure to define all the explicit positional args

* include metrics in callback

* another callback fix for collator max len attribute

* fix mmlu evals

* sample benchmarks, ensure we drop long samples

* fix the data file

* fix elif and add better messaging

* more fixes

* rename mmlu to bench

* more fixes

* dataset handling and aggregate across benchmark

* better handling when no subjects

* benchmark callback has its own dataloader and collator

* fixes

* updated dataset

* more fixes

* missing transformers import

* improve support for customized dataset for bench evals

* gather benchmarks from all ranks

* fix for gather across multiple gpus
2023-08-29 13:24:19 -07:00
Wing Lian
548787daae customizable ascii art (#506) 2023-08-29 10:13:42 -07:00
Wing Lian
5ac3392075 support for datasets with multiple names (#480)
* support for datasets with multiple names

* update docs
2023-08-29 06:18:17 -07:00
Aman Gupta Karmani
e356b297cb remove --force-reinstall from Dockerfile to ensure correct pytorch version (#492) 2023-08-29 06:17:51 -07:00
NanoCode012
48c56470d0 Fix(doc): Clarify no amp to full yaml docs (#496) 2023-08-29 06:17:37 -07:00
Maxime
36b2e1cfee tweak: use default config file when only one file is present (#501) 2023-08-29 06:17:10 -07:00
Wing Lian
125cccb786 Refactor train cfg cli (#499)
* wip to cleanup cfg cli options

* fix launcher

* fix cli args
2023-08-29 05:37:53 -07:00
Aman Karmani
fd55bc87e2 use math.ceil instead of round /cc #498 2023-08-29 01:03:41 +00:00
Birch-san
8e197f6fb4 pad_to_worst_case_seq_len boolean, for testing memory limits (#498)
* pad_to_worst_case_seq_len boolean, for testing memory limits

* remove collator_pad_to_longest option since it does nothing

see docs: https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorWithPadding.padding

True and "longest" mean the same thing

* rename to `pad_to_sequence_len`, and ensure 64 alignment

---------

Co-authored-by: Aman Karmani <aman@tmm1.net>
2023-08-28 18:47:16 -04:00
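After the rename in the last bullet, the option pairs with sequence_len; a sketch:

```yaml
sequence_len: 2048
pad_to_sequence_len: true  # pad batches up to sequence_len (64-aligned) to probe worst-case memory
```
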
Aman Karmani
267b7b24e5 simplify linear layer locator 2023-08-28 09:45:16 -04:00
Wing Lian
98bf76e236 fsdp requires params be the same type too (#493) 2023-08-28 04:33:50 -04:00
NanoCode012
4c37bd0b54 Fix(tokenizer): Make sure to add pad for CodeLlamaTokenizer (#489) 2023-08-28 09:39:10 +09:00
Aman Gupta Karmani
f144e98a32 Merge pull request #485 from maximegmd/patch-4
fix: finetune model inference needs the dtype fix to work with flash-attn
2023-08-27 16:27:47 -04:00
Aman Karmani
3a011ea1ef fix condition and add logging 2023-08-27 20:09:26 +00:00
Aman Karmani
1f613e5aa7 Merge branch 'main' into patch-4 2023-08-27 19:57:34 +00:00
Aman Karmani
f319b0bc67 rename var and reformat 2023-08-27 19:55:11 +00:00
Maxime
7fd662dd89 Update src/axolotl/utils/models.py
Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>
2023-08-27 21:01:43 +02:00
Maxime
9e699683d7 Update src/axolotl/utils/models.py
Co-authored-by: Aman Gupta Karmani <aman@tmm1.net>
2023-08-27 21:01:37 +02:00
mhenrichsen
35130711d6 Feat(cfg): Add code-llama configs for all sizes (#479)
* configs for all sizes

* update tokenizer type

---------

Co-authored-by: mhenrichsen <some_email@hey.com>
2023-08-27 10:20:17 +09:00
mhenrichsen
3fc9006298 Feat(deepspeed): Add zero2 config (#476)
* zero2 config

* config added

* linting

---------

Co-authored-by: mhenrichsen <some_email@hey.com>
2023-08-27 10:10:33 +09:00
NanoCode012
ad8be435ad Feat(doc): Update eval_steps doc (#487) 2023-08-27 10:09:09 +09:00
Charles O. Goddard
fe4d6baf92 Add example Llama 2 ReLoRA config (#471)
* Add example Llama 2 ReLoRA config

* Use adamw_bnb_8bit in example relora config
2023-08-27 10:08:34 +09:00
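A trimmed sketch of what such a ReLoRA example plausibly contains; the relora_* key names follow the ReLoRA implementation merged in #322 further down this list, and the values are illustrative:

```yaml
adapter: qlora             # ReLoRA rides on a (q)lora adapter
relora_steps: 150          # merge-and-reset interval, assumed key name
relora_warmup_steps: 10    # assumed key name
optimizer: adamw_bnb_8bit  # per the second bullet
```
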
Aman Gupta Karmani
f31301063d Merge pull request #486 from OpenAccess-AI-Collective/adam-bnb-simpler
let transformers handle adamw_bnb_8bit
2023-08-26 20:44:19 -04:00
Aman Karmani
868530c39c let transformers handle adamw_bnb_8bit 2023-08-26 21:40:12 +00:00
Maxime
d03887fad5 ignore: address pr review 2023-08-26 22:45:45 +02:00
Maxime
17605b85d8 fix: inference did not move the model to the correct device (#483) 2023-08-26 16:40:56 -04:00
Maxime
a184549e4c ignore: linter 2023-08-26 22:36:14 +02:00
Maxime
f311df9462 fix: finetune model inference needs the dtype fix to work with flash-attn 2023-08-26 22:34:11 +02:00
Maxime
c500d02517 Fix missing 'packaging' wheel (#482) 2023-08-26 12:02:15 -04:00
Wing Lian
31f3e71764 fix checkpints on multigpu (#481) 2023-08-26 12:00:03 -04:00
Aman Gupta Karmani
56c4a94caf Merge pull request #484 from OpenAccess-AI-Collective/reqs
allow newer deps in requirements.txt
2023-08-26 11:13:41 -04:00
Aman Karmani
c29117a0d7 allow newer deps 2023-08-26 15:06:05 +00:00
Wing Lian
0b7ba57ec4 fix types w lora (#478) 2023-08-25 02:03:24 -04:00
NanoCode012
71bd06243c Fix(tokenizer): Fix condition to add pad token (#477)
* Fix(tokenizer): Fix condition to add pad token

* chore: fix lint
2023-08-25 14:30:50 +09:00
Wing Lian
cb9797ef5a improve llama pad token handling (#475)
* improve llama pad token handling

* tweak logic to not clobber
2023-08-24 13:20:35 -04:00
Charles O. Goddard
bde3c5a478 ReLoRA implementation (with quantization) (#322)
* Experimental ReLoRA (+qlora) implementation

* Add CPU offload

* Remove local config

* Fix saving logic

* Remove redundant assert

* Fix logic errors

* Move ReLoRA into its own trainer class with a method override to create the proper scheduler

* Formatting & typing fixes

* Use safe_serialization

* Don't allow fsdp/deepspeed with ReLoRA

* Fix cpu-offload logic, enable multi gpu

* Document parameters and add comment

* Fix merge issue

* Smooth over some sharp edges

* Implement resume from checkpoint for relora

* Address review comments

* Fix saving logic

* Add necessary metadata to safetensors

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-23 23:07:18 -04:00
NanoCode012
55c23c7bcb Fix(doc): Clarify config (#466) 2023-08-23 11:56:01 -04:00
Wing Lian
c69faee7a7 workaround so training doesn't hang when packed dataloader batches aren't even (#461)
* workaround so training doesn't hang when packed dataloader batches aren't even

* don't bother labeling anything in the no-op data
2023-08-23 10:39:11 -04:00
Wing Lian
d5dcf9c350 fix test fixture b/c hf trainer tokenization changed (#464) 2023-08-23 04:04:49 -04:00
TearGosling
f4746507f6 feat: add Metharme prompt strategy (#446)
* Add Metharme tokenizing strategy

This strategy accounts for how the Metharme JSONLs are formatted, as well as adding duplicated EOS tokens, which can help trim model output length.
I haven't gotten the chance to test this yet, and probably won't have the chance for quite a bit, so I'm committing this now.

* Redo Metharme tokenizing strategy

lol

* fix: oops

* Rearrange a conditional

* chore: reformat code in accordance with linter

* chore: Make lint not freak out

* chore: fix lint

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-08-22 11:21:45 +09:00
Wing Lian
96deb6bd67 recast loralayer, norm, lmhead + embed token weights per original qlora (#393)
* recast loralayer, norm, lmhead + embed token weights per original qlora

* try again for the fix

* refactor torch dtype picking

* linter fixes

* missing import for LoraLayer

* fix install for tests now that peft is involved
2023-08-21 18:41:12 -04:00
Wing Lian
50682a3c06 always drop samples that are too long (#452) 2023-08-21 16:43:33 -04:00
Wing Lian
5a1985ba24 set env var for FSDP layer to wrap (#453) 2023-08-21 16:43:22 -04:00
Aman Gupta Karmani
5e9c6afa10 Merge pull request #451 from OpenAccess-AI-Collective/eval-is-causal
is_causal fix for evals?
2023-08-21 10:43:46 -07:00
Aman Karmani
a213d9972a fix eval regression caused in 13f7efaf74 2023-08-21 10:40:06 -07:00
Wing Lian
fbf49a4770 is_causal fix for evals? 2023-08-21 10:36:26 -04:00
Wing Lian
58cf7e7fed add missing positional arg (#450) 2023-08-21 04:10:19 -04:00
NanoCode012
04a42b6db1 feat(docs): improve user customized prompts (#443)
* feat(docs): improve user customized prompts

* feat(doc): add custom pretokenized instructions

* chore: clean old data folder

* chore: add new line
2023-08-20 23:59:43 -04:00
NanoCode012
919f4cac90 feat(doc): add pillow to lambda instructions (#445) 2023-08-20 23:59:23 -04:00
Wing Lian
ee262818ef fix evals (#447) 2023-08-20 23:39:42 -04:00
Wing Lian
9d629d8bff gracefully handle empty input (#442) 2023-08-20 09:18:18 -04:00
Wing Lian
d2e7f27240 support user defined prompters, pretokenized datasets in config, local parquet, local arrow files (#348)
* support user defined prompters, pretokenized datasets in config, local parquet, local arrow files

* fix user defined dataset types

* fix for system prompts

* fix tests

* fix checks for parquet and arrow

* aha moment that d.data_files isn't used

* add documentation for ds_type to add support for parquet and arrow
2023-08-20 09:17:49 -04:00
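Using the ds_type documented in the last bullet for a local parquet file (placeholder path):

```yaml
datasets:
  - path: data/train.parquet  # local parquet file, placeholder
    ds_type: parquet
    type: alpaca
```
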
Philpax
d21318dfb9 docs(readme): add cd axolotl (#440) 2023-08-19 19:14:05 -04:00
Wing Lian
f733d0f31e disable eval using multipack for now (#437) 2023-08-19 10:35:04 -04:00
Wing Lian
008505c8ae fix comma, not a tuple (#436) 2023-08-19 00:57:40 -04:00
Wing Lian
b3f5e00ff5 use save_strategy from config if available (#434)
* use save_strategy from config if available

* update docs for save_strategy
2023-08-18 20:28:23 -04:00
Wing Lian
5247c5004e set env for FSDP offload params (#433) 2023-08-18 20:28:09 -04:00
mhenrichsen
cf6654769a flash attn pip install (#426)
* flash attn pip

* add packaging

* add packaging to apt get

* install flash attn in dockerfile

* remove unused whls

* add wheel

* clean up pr

fix packaging requirement for ci
upgrade pip for ci
skip build isolation for requirements to get flash-attn working
install flash-attn separately

* install wheel for ci

* no flash-attn for basic cicd

* install flash-attn as pip extras

---------

Co-authored-by: Ubuntu <mgh@mgh-vm.wsyvwcia0jxedeyrchqg425tpb.ax.internal.cloudapp.net>
Co-authored-by: mhenrichsen <some_email@hey.com>
Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-18 19:00:27 -04:00
Aman Gupta Karmani
06edf175ac standardize attn hijack patches (#381)
* split sdp attn into its own patch

* sync xformers patch to follow shared format and be diffable

* update flash-attn patch for 70B/GQA and inference using helper from flash-attn tests

* speed up flash-attn inference

* fix patch to check position ids and don't use multipack for evals

* copy LlamaModel.forward and LlamaDecoderLayer.forward into monkeypatch

* update forwards so we only calculate cu_seqlens once

* enable eval dataloader using multipack again

* fix the patch to work properly and work with FSDP

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-18 12:54:16 -04:00
mhenrichsen
0a228479b3 adds color (#425)
* adds color

* chore: lint

* fix for colorama

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-18 10:59:43 -04:00
Wing Lian
82e111aba9 remove extra accelearate in requirements (#430) 2023-08-18 10:56:14 -04:00
Wing Lian
8cace80175 fix fixture for new tokenizer handling in transformers (#428) 2023-08-17 17:01:52 -04:00
Wing Lian
1b7e8604bb fix orca prompts (#422) 2023-08-16 11:21:03 -04:00
NanoCode012
3d1f203b62 Fix(docs): Remove gptq+lora and fix xformer compat list (#423) 2023-08-16 13:56:48 +09:00
Wing Lian
d3d6fd6ae6 just resort to tags and use main-latest (#424) 2023-08-16 00:39:57 -04:00
NanoCode012
b7449a997f Fix(template): Inform to place stack trace to Issue (#417)
* Fix(template): Inform to place stack trace to Issue

* Update following suggestions

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-16 11:55:48 +09:00
Wing Lian
5f80b3560b use inputs for image rather than outputs for docker metadata (#420) 2023-08-15 18:26:59 -04:00
Wing Lian
24959091d7 hopefully improve the README (#419)
* hopefully improve the README

* exitcode -9 help

* table of contents

* formatting
2023-08-15 15:30:53 -04:00
Wing Lian
7af816699e tag with latest as well for axolotl-runpod (#418)
* tag with latest as well for axolotl-runpod

* no dev branch for now
2023-08-15 15:30:41 -04:00
mhenrichsen
f806e86a6e Merge pull request #413 from mhenrichsen/chore/update-deepseed-config
update path to align with fsdp example
2023-08-15 20:08:23 +02:00
NanoCode012
2b990eb628 Feat(doc): Add lr_quadratic_warmup to readme (#412) 2023-08-16 02:55:48 +09:00
mhenrichsen
bd8cab49c9 update path to align with fsdp example 2023-08-15 19:51:58 +02:00
NanoCode012
c01015f33f Fix(config): Update handling of deepspeed config (#404)
* Fix(config): Update handling of deepspeed config

* feat: auto set deepspeed env if deepspeed passed

* fix: update new deepspeed instructions
2023-08-16 01:22:43 +09:00
NanoCode012
72fe3f8e3d Fix(docs): Update flash attn requirements (#409) 2023-08-15 22:40:52 +09:00
Wing Lian
47961fdb8b update docs for tokenizer_legacy (#401)
* update docs for tokenizer_legacy

* add default info
2023-08-15 22:34:42 +09:00
NanoCode012
7ad37cb6d7 Fix(template): Remove iPhone/android from Issue template (#407) 2023-08-15 22:32:51 +09:00
Wing Lian
29241cf1e4 Ax art (#405)
* axolotl text art :D

* only print art on rank0

* lint and pr feedback
2023-08-15 08:34:30 -04:00
lightningRalf
31db0ecce4 add templates, CoC and contributing guide (#126)
* add templates, CoC and contributing guide

* Update .github/SECURITY.md

correct responsible person

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* Update bug-report.yaml

axolotl version switch with axolotl branch-commit

* update CONTRIBUTING doc

* update reporting link

* linter fixes

* chore: fix linter

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-08-15 07:41:05 -04:00
Wing Lian
da10af03e9 fix eval steps and strategy (#403) 2023-08-15 07:28:50 -04:00
Wing Lian
85cf4f8e2c better handling of empty input ids when tokenizing (#395)
* better handling of empty input ids when tokenizing

* Add warning if tokenizer resulted in empty result

* fix len comparison for linter
2023-08-15 01:09:59 -04:00
Aman Karmani
2e22404d2d add utils.data.prepare_dataset 2023-08-14 21:28:29 -07:00
NanoCode012
be294fd605 Feat(doc): Add how to save by epochs (#396) 2023-08-15 13:24:25 +09:00
Wing Lian
fc2d6be96d use context manager to run things on rank0 before others (#397) 2023-08-15 00:10:47 -04:00
Wing Lian
1687be6a35 don't use mask expansion for inference (#392) 2023-08-14 20:52:54 -04:00
NanoCode012
41ecb451c2 Feat(doc): Add max_steps to readme (#389) 2023-08-15 00:34:22 +09:00
Gabriel Puliatti
3c2ad00d07 Feat(config): add max steps (#387) 2023-08-14 11:19:29 -04:00
florian peyron
5d48a10548 Added "epoch" evaluation_strategy (#388) 2023-08-14 10:59:23 -04:00
NanoCode012
73a0b6ead5 Feat(config): Add hub_strategy (#386) 2023-08-14 07:12:55 -04:00
florian peyron
63fdb5a7fb Error msg for sharegpt if conv has less than 2 msg (#379) 2023-08-14 17:40:40 +09:00
mhenrichsen
fdffef5940 new llama-2 default settings (#370)
* new default settings

* fix whitespace

* rm max packed sequence length

---------

Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
2023-08-14 17:39:09 +09:00
Wing Lian
919246fbc1 don't pass rope_scaling kwarg if it's None (#383) 2023-08-13 18:57:38 -04:00
Wing Lian
ffac902c1b bump flash-attn to 2.0.4 for the base docker image (#382) 2023-08-13 17:55:04 -04:00
Charles Goddard
15f6e57eaa Fix crash when running without CUDA 2023-08-13 13:36:40 -07:00
NanoCode012
729c299256 Feat(doc): Improve sharegpt doc (#378)
* Feat(doc): Improve sharegpt doc

* Fix typo
2023-08-14 00:36:00 +09:00
Wing Lian
86a91e260b save tokenizer before training starts (#380) 2023-08-13 11:28:58 -04:00
Aman Gupta Karmani
094fc2c6e6 try to detect accelerate and only use device_map=None in that case (#373) 2023-08-13 00:32:07 -04:00
Wing Lian
2dafa730ef Create FUNDING.yml 2023-08-13 00:30:34 -04:00
Wing Lian
343ac84e5a fix check for flash attn branching (#377) 2023-08-12 22:48:08 -04:00
Aman Karmani
0c967279ce remove unnecessary local variable 2023-08-13 01:58:39 +00:00
Aman Karmani
efb3b2c95e simplify load_tokenizer 2023-08-12 18:55:06 -07:00
Aman Karmani
7b55fe6419 improve GPU logging to break out pytorch cache and system mem 2023-08-12 18:52:57 -07:00
Aman Karmani
e029ab34ea quiet noise from llama tokenizer by setting pad token earlier 2023-08-12 18:31:40 -07:00
Aman Karmani
8cec513447 extract module for working with cfg 2023-08-12 18:25:27 -07:00
Aman Karmani
a13e45d548 fix DefaultDict.__or__ 2023-08-13 01:15:50 +00:00
Wing Lian
918f1b0dfb revert previous change and build ax images w docker on gpu (#371) 2023-08-12 20:23:00 -04:00
Wing Lian
c3fde36ada attempt to run non-base docker builds on regular cpu hosts (#369) 2023-08-12 19:07:38 -04:00
Wing Lian
2bb0b78975 Attention mask and position id fixes for packing (#285)
* fix attention mask with packing

* set position ids and use block diagonal attn mask

* fix expand mask for multiple batch items, make sure we pad position_ids

* don't move masks to cpu

* use multi pack dataloader w random sampler

* add position_ids back

* more fixes for dataloader integration

* est total tokens, fix field loop

* more fixes, position_ids seems broken

* more fixes for sample packing

* use distributed sampler, avoid accelerate prepare

* use accelerator prepare for dataloader

* fix for position_ids w packing

* Update src/axolotl/utils/dataloader.py

* validation for sample packing and doc

* more fixes for 4k and optimizations

* optimized expand mask fn

* better handling of variance in multipack dataloader length and trainer hanging when it runs out of data

* fix rounding of len of batches to int

* better handling so that all devices have the same dataloader len

* fix step calc for packing

* pass sample packing efficiency to training args

* add a test for the mask expansion for sequence packing

* only process eval dataset for packing if not None

* don't split batches when packing

* weighted CE losses

* weighted CEL fixes

* limit packing to sequences of max seq len

* seq_len_multiple for packing

* make sure the chunk size is an int

* sample_packing_seq_len_multiplier config

* use cumulative seq len with var len flash attn v2 w packing

* properly calculate max len

* fix flash-attn, xformers, packing, support chatml

* fix chatml system prompt for openorca, legacy tokenizer opts

* add chatml

* add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test

* fix test and pylint checks

* more packing and dataset optimizations and fixes

* filter w multiple cpus

* more fixes and optimizations

* fixes and go back to distributed sampler since batch sampler won't work

* fix counts by accounting for num devices

* fix steps calculation

* previous accelerate is still most performant

* add numba to requirements.

* use custom distributed checks

* fix sampler to prevent overfit w new epochs

* let's not cleanup the cached datasets

* calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier

* speed optimizations and set accelerate fsdp env vars

* optimize dataset concatenation?

* more optimizations for dataset handling

* fix import for annotation

* manual pre-commit fixes

* another sum optimization and bug fix for calc steps

* fix packing estimations

* fix formatting

* pylint problems

* add back flash attention branch for handling unpacked sequences separately

* Address PR feedback

* add optional sample packing config params to readme
2023-08-12 15:14:56 -04:00
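This PR also adds the ability to rebuild cumulative sequence lengths (cu_seqlens) from position ids: packed rows restart position_ids at 0 for each constituent sequence, so the boundaries can be recovered directly. A minimal sketch of that derivation, not the repo's exact implementation:

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Recover packed-sequence boundaries for one row of packed tokens.

    Assumes position_ids restart at 0 at each sequence start,
    e.g. [0 1 2 3 0 1 0 1 2] -> [0, 4, 6, 9].
    """
    starts = (position_ids == 0).nonzero(as_tuple=True)[0]
    total = torch.tensor([position_ids.numel()], device=position_ids.device)
    return torch.cat([starts, total]).to(torch.int32)

print(cu_seqlens_from_position_ids(torch.tensor([0, 1, 2, 3, 0, 1, 0, 1, 2])))
# tensor([0, 4, 6, 9], dtype=torch.int32)
```

Offsets in this form are what flash attention's variable-length kernels consume in place of a block-diagonal attention mask.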
NanoCode012
a276c9c88d Fix(save): Save as safetensors (#363) 2023-08-13 01:22:52 +09:00
Morgan McGuire
7019509daa Add wandb_entity to wandb options, update example configs, update README (#361)
* Update wandb_entity and add wandb descriptions

* add wandb to config section

* remove trailing whitespace for pre-commit hook

* remove trailing whitespace for pre-commit hook

---------

Co-authored-by: Morgan McGuire <morganmcguire@Morgans-MacBook-Pro.local>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2023-08-12 12:17:11 -04:00
NanoCode012
96bd6ae1c4 Fix(model loading): Warn when model revision is passed to gptq (#364)
* fix(model loading): warn when model revision is passed to gptq

* chore: improve message
2023-08-13 01:16:59 +09:00
NanoCode012
e37d9358e6 Fix(message): Improve error message for bad format (#365) 2023-08-13 01:16:18 +09:00
NanoCode012
b5212068ac Feat: Add rope scaling (#343)
* Feat: Add rope scaling

* fix: move rope config
2023-08-13 00:50:15 +09:00
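In transformers, rope scaling is a model-config field rather than a model argument; a sketch of what this feature wires up, with a placeholder model id:

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder model id
config = AutoConfig.from_pretrained(model_id)
# "linear" stretches positions by `factor`; "dynamic" rescales NTK-style.
config.rope_scaling = {"type": "linear", "factor": 2.0}
model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
```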
NanoCode012
289d5c403d feat(merge): save tokenizer on merge (#362) 2023-08-13 00:18:10 +09:00
Aman Gupta Karmani
35c8b90306 Merge pull request #355 from tmm1/bitsandbytes-fixes
bump to latest bitsandbytes release with major bug fixes
2023-08-11 15:15:38 -07:00
NanoCode012
fae6ed8092 Update README.md on pretraining_dataset (#360)
* Update README.md on pretraining_dataset

* Fix message
2023-08-11 12:17:07 +09:00
NanoCode012
94d03c8402 Clarify pre-tokenize before multigpu (#359) 2023-08-11 11:27:42 +09:00
Aman Gupta Karmani
11ddccb80f Merge pull request #356 from tmm1/load_model-args
simplify `load_model` signature
2023-08-09 18:24:34 -07:00
Aman Gupta Karmani
964312199e Merge pull request #354 from tmm1/gpu-util
GPU memory usage logging
2023-08-09 15:44:18 -07:00
Aman Karmani
718102271f simplify load_model signature 2023-08-09 22:36:02 +00:00
Aman Gupta Karmani
f5c11f8262 Merge pull request #350 from tmm1/group-len-false-examples
set `group_by_length` to false in all examples
2023-08-09 14:48:48 -07:00
Aman Karmani
fce40aab23 bump to latest bitsandbytes release with major bug fixes 2023-08-09 21:47:11 +00:00
Aman Karmani
9c314101d5 use newer pynvml package 2023-08-09 21:06:28 +00:00
Aman Karmani
e303d64728 log GPU memory usage 2023-08-09 18:26:28 +00:00
Aman Karmani
b4d1d22782 note pattern when using groups 2023-08-07 16:18:42 -07:00
Aman Karmani
9f99104038 update comment for group_by_length 2023-08-07 01:04:56 -07:00
Aman Karmani
36fefcf94b set group_by_length to false in examples 2023-08-06 23:59:09 -07:00
Wing Lian
176b888a63 ensure enable_input_require_grads is called on model before getting the peft model (#345) 2023-08-06 18:13:10 -04:00
Jan Philipp Harries
3392270544 experimental llama 2 chat support (#296)
* experimental llama 2 chat support

* few small fixes

* llama2_chat

* small fix to follow original implementation

* small fixes and added fixtures/tests

* fix: mixed-up inference and finetuning conversations

* args - small fix

* small fix

* small adjustment and warning

* fix with pre-commit

---------

Co-authored-by: Jan Philipp Harries <jpdus@users.noreply.github.com>
2023-08-06 17:40:52 -04:00
Wing Lian
bb53a165f5 add a basic ds zero3 config (#347)
better defaults for ds
2023-08-06 17:19:51 -04:00
ssmi153
10405b9995 Update XFormers Attention Monkeypatch to handle Llama-2 70B (GQA) (#339)
* Fix XFormers attention for Llama-2 70B (GQA)

Updated XFormers MonkeyPatch to handle GQA as used in Llama-2 70B. All of the updated code is taken directly from the Transformers library's modeling_llama.py (commit 07360b6c9c).

* Catch configs without pretraining_tp

* Whitespace bug fix

The command had accidentally been moved out of the if-else block.

* pre-commit formatting fixes

Thanks to @winglian
2023-08-06 11:09:04 -04:00
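The heart of the GQA handling is the repeat_kv helper the patch borrows from transformers, which expands the shared key/value heads to match the number of query heads:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, num_kv_heads, seq, dim) to (batch, num_kv_heads * n_rep, seq, dim).

    Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1).
    """
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)
```

For Llama-2 70B, n_rep is num_attention_heads / num_key_value_heads, i.e. 64 / 8 = 8.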
Jan Philipp Harries
c93655c0a3 Added Orca Mini prompt strategy (#263)
* added Orca Mini prompt strategy

* maybe this fixed precommit errors?

* pre-commits passing

---------

Co-authored-by: Jan Philipp Harries <jpdus@users.noreply.github.com>
2023-08-06 03:16:41 +09:00
Wing Lian
fe285430bc optimize the iteration when tokenizeing large datasets (#332) 2023-08-04 12:12:05 -04:00
Aman Gupta Karmani
0d2e34f056 Merge pull request #336 from tmm1/flash-attn
Fix flash-attn + qlora not working with llama models
2023-08-03 16:25:30 -07:00
Aman Gupta Karmani
b56a6c0101 Merge pull request #337 from tmm1/readme-fix
update README
2023-08-03 15:14:17 -07:00
Aman Karmani
2eda9e02a9 fix typo 2023-08-03 21:04:12 +00:00
Aman Karmani
78b9efb7f4 scope flash-attn+qlora fix correctly, scope to llama, add comment 2023-08-03 19:19:39 +00:00
Aman Karmani
312a9fad07 move flash-attn monkey patch alongside the others 2023-08-03 17:20:49 +00:00
Aman Karmani
58d665943e python 3.10 and 3.11 both work fine, as does pytorch 2.1.0.dev 2023-08-03 16:47:25 +00:00
Aman Karmani
cc7e80026e there is no configs folder 2023-08-03 16:31:37 +00:00
mhenrichsen
dc71d8872a feat/llama-2 examples (#319)
* qlora llama-2

* qlora llama-2

* linting

* readme

* lora added

* linting

* change group_by_length

* 13b fitting on 24gb

* grouped lengths true

* add pad token

* change out dir

---------

Co-authored-by: Mads Henrichsen <mads@Bærbar-tilhørende-Mads.local>
2023-08-03 19:22:48 +09:00
Aman Karmani
248bf90f89 ensure flash-attn fixes happen in both adapter/lora modes, and use torch_dtype 2023-08-02 20:15:03 +00:00
Wing Lian
77085ea24e qlora w flash attention fixes (#333) 2023-08-01 23:26:16 -04:00
Wing Lian
db2a3586f3 add peft install back since it doesn't get installed by setup.py (#331) 2023-07-31 16:31:53 -04:00
Wing Lian
6c9a87c8ee pin accelerate so it works with llama2 (#330) 2023-07-30 22:20:06 -04:00
Wing Lian
894cba09f3 fix FSDP save of final model (#329) 2023-07-30 21:46:44 -04:00
Wing Lian
41a4d15d43 update README for updated docker images (#328)
* update README for updated docker images

* update readme from pr feedback
2023-07-28 16:50:03 -04:00
Wing Lian
2c37bf6c21 Prune cuda117 (#327)
* drop cuda117/torch 1.13.1 from support, pin flash attention to v2.0.1, rm torchvision/torchaudio install

* gptq base build not needed. add sm 9.0 support
2023-07-26 16:27:49 -04:00
Wing Lian
9f69c4d8c1 latest HEAD of accelerate causes 0 loss immediately w FSDP (#321) 2023-07-24 11:23:56 -04:00
Wing Lian
3d4984b9a5 update prompts for open orca to match the paper (#317)
fix the test for the updated system tokenizer
2023-07-22 13:49:11 -04:00
Wing Lian
ff7f18d1ed disable gh cache for first step of docker builds too 2023-07-22 11:46:37 -04:00
Wing Lian
cf62cfd661 add runpod envs to .bashrc, fix bnb env (#316)
* hopper support for base dockerfile, add runpod envs to .bashrc

* set BNB_CUDA_VERSION env for latest bnb

* don't support hopper yet w 118
2023-07-22 10:09:38 -04:00
Wing Lian
c5df969262 don't use the gha cache w docker 2023-07-22 08:46:21 -04:00
Wing Lian
40a53ff181 Merge pull request #307 from OpenAccess-AI-Collective/xgen-user-sharegpt-tokens
better handling since xgen tokenizer breaks with convert_tokens_to_ids
2023-07-22 04:10:38 -04:00
Wing Lian
dcdec44347 Merge pull request #306 from ethanhs/xgen
Add XGen info to README and example config
2023-07-22 04:10:18 -04:00
Wing Lian
3ffb018a4c Merge pull request #313 from OpenAccess-AI-Collective/tokenizer-llama2-embeddings
don't resize embeddings to multiples of 32 by default
2023-07-22 04:09:59 -04:00
Wing Lian
a94f2eecb1 Merge pull request #299 from OpenAccess-AI-Collective/flash-attention-2
Flash attention 2
2023-07-22 04:07:48 -04:00
Wing Lian
1066751358 don't resize embeddings to multiples of 32 by default 2023-07-22 01:52:38 -04:00
Wing Lian
1b63bf13bc Merge pull request #308 from OpenAccess-AI-Collective/apache2-license
add apache 2.0 license
2023-07-21 09:50:14 -04:00
Wing Lian
5cce2a42ff add apache 2.0 license 2023-07-21 09:49:29 -04:00
Wing Lian
2a428e8014 better handling since xgen tokenizer breaks with convert_tokens_to_ids 2023-07-21 09:24:11 -04:00
Wing Lian
cdf85fdbd5 pin flash attention 2 to the fix for backwards pass 2023-07-21 08:18:53 -04:00
Wing Lian
9b790d359b flash attention 2 2023-07-21 08:17:46 -04:00
Ethan Smith
38811434e6 Add XGen info to README and example config 2023-07-21 00:44:50 -07:00
NanoCode012
06c61d6f13 Merge pull request #304 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix(readme): Improve wording for push model
2023-07-21 13:39:45 +09:00
Wing Lian
262dc29df2 Merge pull request #300 from OpenAccess-AI-Collective/pytorch-201
Pytorch 2.0.1
2023-07-21 00:28:38 -04:00
NanoCode012
165907fddb Fix(readme): Improve wording for push model 2023-07-21 11:28:35 +09:00
Wing Lian
a032c9f452 fix sdp attention to use the flash/mem-efficient context manager 2023-07-20 01:05:48 -04:00
Wing Lian
b06d3e3645 explicitly pin flash attention 1 to v1.0.9 2023-07-20 01:02:08 -04:00
Wing Lian
c58034d48c use pytorch 2.0.1 2023-07-20 00:47:13 -04:00
NanoCode012
28fd429bcf Merge pull request #293 from NanoCode012/fix/tokenize-speed
Fix(tokenizing): Use multi-core
2023-07-19 11:02:04 +09:00
NanoCode012
45ac7c4f88 feat: use multi-core 2023-07-19 10:16:54 +09:00
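A minimal sketch of what multi-core tokenization looks like with datasets' num_proc (model and dataset ids are placeholders):

```python
import os

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
dataset = load_dataset("imdb", split="train")      # placeholder dataset

# num_proc fans the map out across CPU cores instead of a single process.
dataset = dataset.map(
    lambda example: tokenizer(example["text"]),
    num_proc=max(1, (os.cpu_count() or 1) // 2),
)
```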
Wing Lian
edd6980dd9 Merge pull request #289 from OpenAccess-AI-Collective/hf_transfer
add hf_transfer to requirements for faster hf upload
2023-07-17 15:08:06 -04:00
Wing Lian
dc6d25124d Merge pull request #288 from OpenAccess-AI-Collective/NanoCode012-patch-1
fix(readme): remove accelerate config
2023-07-17 14:46:43 -04:00
Wing Lian
6dd2e7d671 add hf_transfer to requirements for faster hf upload 2023-07-17 14:44:48 -04:00
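hf_transfer is switched on via an environment variable, which must be set before huggingface_hub is imported; a sketch with a placeholder repo id:

```python
import os

# Enable the Rust-based hf_transfer backend for uploads/downloads.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import HfApi  # noqa: E402

HfApi().upload_folder(folder_path="./out", repo_id="your-user/your-model")
```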
NanoCode012
b64f411849 fix(readme): remove accelerate config 2023-07-18 01:31:02 +09:00
Wing Lian
03a59c1ed4 Merge pull request #287 from OpenAccess-AI-Collective/dataclass-fix
fix axolotl training args dataclass annotation
2023-07-17 06:09:23 -04:00
Wing Lian
ebaec3c406 fix axolotl training args dataclass annotation 2023-07-17 04:57:02 -04:00
Wing Lian
73e70e3996 Merge pull request #286 from OpenAccess-AI-Collective/logging-docker-fixes
misc fixes
2023-07-17 04:26:39 -04:00
Wing Lian
d75adb9835 misc fixes 2023-07-17 03:00:27 -04:00
Wing Lian
02224668c3 Merge pull request #283 from OpenAccess-AI-Collective/docker-git-fetch
git fetch fix for docker
2023-07-17 02:17:00 -04:00
Wing Lian
f162f3c7cc set transformers cache env var in docker image 2023-07-16 23:03:54 -04:00
Wing Lian
eca3531329 git fetch fix for docker 2023-07-16 22:25:05 -04:00
Wing Lian
6f16c4569d Merge pull request #276 from theobjectivedad/logging_enhancement
Logging update: added PID and formatting
2023-07-16 17:04:52 -04:00
Wing Lian
0bd09c077d Merge pull request #280 from teknium1/main
Update requirements.txt
2023-07-16 16:08:58 -04:00
Wing Lian
469c08c9ba Merge pull request #279 from NanoCode012/feat/multi-gpu-readme
Feat(readme): improve docs on multi-gpu
2023-07-16 16:08:37 -04:00
Wing Lian
334af625d0 Merge pull request #277 from cg123/dataset-name
Allow non-default dataset configurations
2023-07-16 16:08:15 -04:00
Teknium
273b3a3aa7 Update requirements.txt
Require latest git accelerate to fix saving checkpoint issue
2023-07-16 10:24:24 -07:00
Charles Goddard
3cdd8e4122 Add dataset name to all yaml options in README 2023-07-15 13:17:37 -07:00
NanoCode012
cf5ae6b649 Feat(readme): improve docs on multi-gpu 2023-07-16 01:07:27 +09:00
theobjectivedad
b1f4f7a34d Fixed pre-commit problems, fixed small bug in logging_config to handle LOG_LEVEL env var 2023-07-15 12:29:35 +00:00
The Objective Dad
83237b8445 Merge branch 'OpenAccess-AI-Collective:main' into logging_enhancement 2023-07-15 06:16:04 -05:00
Charles Goddard
46032a1a1f Fix formatting mistake 2023-07-14 20:57:27 -07:00
Charles Goddard
8bba64258e Add example of dataset with configuration name to README 2023-07-14 20:46:21 -07:00
Charles Goddard
88089e8b32 Add ability to pass 'name' argument to load_dataset 2023-07-14 16:46:39 -07:00
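A minimal sketch of the name argument, which selects one configuration of a multi-configuration hub dataset:

```python
from datasets import load_dataset

# GLUE ships many configurations; `name` picks the one to load.
dataset = load_dataset("glue", name="mrpc", split="train")
```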
NanoCode012
168a7a09cc Merge pull request #274 from OpenAccess-AI-Collective/NanoCode012-patch-2
Feat: Set push to hub as private by default
2023-07-14 23:15:47 +09:00
NanoCode012
231031a0e1 Merge pull request #275 from NanoCode012/feat/safetensors
Feat: Add save_safetensors
2023-07-14 23:07:26 +09:00
theobjectivedad
9234b75cb4 Update log message format, IMO this is easier to read. 2023-07-14 07:36:21 -05:00
theobjectivedad
553a86b52c Adding logging enhancement 2023-07-14 07:26:19 -05:00
NanoCode012
5daf7d5299 Merge pull request #273 from OpenAccess-AI-Collective/NanoCode012-patch-1
Feat(docs): Add model_revision arg
2023-07-14 21:09:50 +09:00
NanoCode012
5491278a79 Feat: Add save_safetensors 2023-07-14 13:21:47 +09:00
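Under the hood this maps to transformers' safetensors serialization; a sketch with a placeholder model:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
# safe_serialization writes model.safetensors instead of pytorch_model.bin.
model.save_pretrained("./out", safe_serialization=True)
```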
NanoCode012
1514739f0f Set push to hub as private by default 2023-07-14 13:17:49 +09:00
NanoCode012
896c1aebcf Feat(docs): Add model_revision arg 2023-07-14 12:56:07 +09:00
Wing Lian
ef17e15483 Merge pull request #272 from OpenAccess-AI-Collective/model-revision
support for loading a model by git revision
2023-07-13 23:12:00 -04:00
Wing Lian
69a235061b support for loading a model by git revision 2023-07-13 22:58:25 -04:00
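A sketch of revision-pinned loading; revision accepts any git ref on the Hub (branch, tag, or commit sha), and the model id is a placeholder:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",           # placeholder model id
    revision="main",  # pin a commit sha here for reproducibility
)
```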
Wing Lian
687d889928 Merge pull request #271 from OpenAccess-AI-Collective/quadratic-warmup
Quadratic warmup
2023-07-10 12:48:02 -04:00
Wing Lian
c4cf567b55 Merge branch 'main' into quadratic-warmup 2023-07-10 12:42:12 -04:00
Wing Lian
c49729d2bc better configuration for quadratic warmup 2023-07-10 11:52:59 -04:00
Wing Lian
13ac4d8de2 Merge pull request #268 from OpenAccess-AI-Collective/fix-adam-args
params are adam_*, not adamw_*
2023-07-08 12:33:34 -04:00
Wing Lian
19cf0bda99 params are adam_*, not adamw_* 2023-07-08 12:13:39 -04:00
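In TrainingArguments the optimizer hyperparameters are spelled adam_*, even though the default optimizer is AdamW; a sketch (values are illustrative):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,  # the grad-norm hyperparameter added alongside these
)
```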
Wing Lian
f74edd5b56 Merge pull request #266 from OpenAccess-AI-Collective/trust-remote-no-llama 2023-07-07 21:38:11 -04:00
Wing Lian
d69da99c2c skip explicit model type too if using trust_remote_code 2023-07-07 21:33:11 -04:00
Wing Lian
66afb76a15 don't use llama if trust_remote_code is set since that needs to use AutoModel path 2023-07-07 21:31:02 -04:00
NanoCode012
a692ad3f4c Merge pull request #264 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix(readme): local path loading and custom strategy type
2023-07-06 23:34:57 +09:00
NanoCode012
41da98b982 Fix for linter 2023-07-06 23:20:11 +09:00
NanoCode012
9e64f42e0f Fix local path loading and custom strategy type 2023-07-06 23:08:09 +09:00
Wing Lian
b9b7d4ce92 Merge pull request #221 from utensil/local_dataset
[WIP] Support loading data files from a local directory
2023-07-03 09:10:13 -04:00
Wing Lian
9bed281867 Merge pull request #258 from NanoCode012/fix/deprecate-push
Fix future deprecation push_to_hub_model_id
2023-07-03 09:08:26 -04:00
NanoCode012
e79c8e617e Fix future deprecation push_to_hub_model_id 2023-07-03 12:44:29 +09:00
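The replacement argument is hub_model_id on TrainingArguments; a sketch with a placeholder repo id:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    push_to_hub=True,
    hub_model_id="your-user/your-model",  # replaces push_to_hub_model_id
    hub_private_repo=True,                # matches the private-by-default change
)
```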
Wing Lian
71456955f5 pin pydantic so deepspeed isn't broken 2023-07-02 22:26:51 -04:00
Wing Lian
3a783c04e4 Merge pull request #247 from OpenAccess-AI-Collective/fix-apex-base
update pip install command for apex
2023-07-01 06:18:25 -04:00
Wing Lian
1e5014acec Merge pull request #255 from OpenAccess-AI-Collective/open-orca-prompts
open orca support
2023-07-01 01:11:23 -04:00
Wing Lian
a10da1caff 11.7.0 nvidia/cuda docker images are deprecated, move to 11.7.1
2023-07-01 00:29:07 -04:00
Wing Lian
4066c78631 Merge pull request #246 from OpenAccess-AI-Collective/sys-prompts-instruct
add option for instruct w sys prompts
2023-07-01 00:27:29 -04:00
Wing Lian
78a1e1fa12 open orca support 2023-07-01 00:19:41 -04:00
NanoCode012
bc8a2e5547 Merge pull request #249 from OpenAccess-AI-Collective/NanoCode012-patch-1
Fix typing list in prompt tokenizer
2023-06-30 15:01:41 +09:00
NanoCode012
910ebe47f5 Merge pull request #252 from OpenAccess-AI-Collective/NanoCode012-readme-fix
Add cfg.push_to_hub_model_id to readme
2023-06-30 14:56:55 +09:00
NanoCode012
c146880a75 Update README.md 2023-06-30 11:33:53 +09:00
NanoCode012
77bdb7d144 Fix typing list 2023-06-29 14:29:55 +09:00
Wing Lian
530809fd74 update pip install command for apex 2023-06-28 22:36:28 -04:00
Wing Lian
924bbfddec add option for instruct w sys prompts 2023-06-28 22:27:17 -04:00
Wing Lian
f150c027e3 Merge pull request #224 from OpenAccess-AI-Collective/system-prompt-data
System prompt data
2023-06-27 17:57:43 -04:00
Wing Lian
5c39c006c9 Merge pull request #244 from OpenAccess-AI-Collective/push-to-hub
push intermediate model checkpoints to hub
2023-06-27 17:57:30 -04:00
Wing Lian
612aabd8c4 push intermediate model checkpoints to hub 2023-06-27 15:40:25 -04:00
Wing Lian
af05883f75 Merge pull request #243 from OpenAccess-AI-Collective/unprompted-instruct
skip the system prompt
2023-06-25 22:50:35 -04:00
Wing Lian
05ab9092e3 skip the system prompt 2023-06-25 22:40:50 -04:00
Wing Lian
7b57ed7618 pylint for duplicated code for system prompts 2023-06-25 22:28:07 -04:00
Wing Lian
3a38271276 add tests and loader support for sys prompt data 2023-06-25 22:28:07 -04:00
Wing Lian
8d20e0a3d3 initial wip to get sys prompt from dataset 2023-06-25 22:28:07 -04:00
Wing Lian
de8ed229c3 Merge pull request #240 from OpenAccess-AI-Collective/tokenizer-fast
optionally define whether to use_fast tokenizer
2023-06-25 12:47:55 -04:00
Wing Lian
478d8c7b8e Merge pull request #241 from OpenAccess-AI-Collective/py3-pre-commit
better py3 support w pre-commit
2023-06-25 12:47:02 -04:00
Wing Lian
645c13592c better py3 support w pre-commit 2023-06-25 10:26:02 -04:00
Wing Lian
47d601fa23 optionally define whether to use_fast tokenizer 2023-06-25 10:19:49 -04:00
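A sketch of the underlying tokenizer toggle (model id is a placeholder); use_fast switches between the Rust-backed fast tokenizer and the pure-Python slow one:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)  # placeholder model
```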
Wing Lian
756dfba97b Merge pull request #218 from OpenAccess-AI-Collective/no-fail-fast
don't fail fast
2023-06-23 15:42:54 -04:00
Wing Lian
91ab0592af Merge pull request #235 from msinha251/Fixing-data-readme 2023-06-23 13:52:01 -04:00
Mahesh Sinha
0aeb7c7802 Fixing Data Readme 2023-06-21 15:34:48 +02:00
Utensil
9bdd30cdfd Support loading data files from a local directory
ref:  https://huggingface.co/docs/datasets/v2.13.0/en/package_reference/loading_methods#datasets.load_dataset.path
2023-06-21 08:00:58 +00:00
Wing Lian
d35278aaf1 don't fail fast 2023-06-15 16:01:27 -04:00
Wing Lian
9492d4ebb7 Merge pull request #215 from OpenAccess-AI-Collective/adamw-hyperparams-cfg
support adamw and grad norm hyperparams
2023-06-15 12:20:55 -04:00
Wing Lian
ad5ca4f734 Additional test case per pr 2023-06-15 10:12:47 -04:00
Wing Lian
cb9d3af5c0 add validation and tests for adamw hyperparam 2023-06-15 09:39:42 -04:00
Wing Lian
c969f0a9dc add docs 2023-06-15 08:43:20 -04:00
Wing Lian
6d0ee4ba34 support adamw and grad norm hyperparams 2023-06-15 08:40:41 -04:00
Wing Lian
a81f52d575 Merge pull request #212 from OpenAccess-AI-Collective/doc-20230615-v1
add float16 docs and tweak typehints
2023-06-15 08:28:57 -04:00
Wing Lian
1925eaf1e6 Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels
Fix tokenizing labels
2023-06-15 08:13:43 -04:00
Wing Lian
1ab3bf3e67 fix test name 2023-06-15 02:09:33 -04:00
Wing Lian
d7635b7148 hint to what AMP means 2023-06-15 02:06:27 -04:00
Wing Lian
88e17ffc50 add float16 docs and tweak typehints 2023-06-15 02:05:31 -04:00
Wing Lian
baed440fa1 ignore duplicate code in tests 2023-06-15 02:03:53 -04:00
Wing Lian
7925ddce86 bugfix for potential off by one 2023-06-15 01:59:33 -04:00
Wing Lian
6f849809c5 Merge pull request #206 from MaciejKarasek/issue205
issue #205 bugfix
2023-06-14 14:23:38 -04:00
Wing Lian
c16644d05e Merge pull request #209 from sroecker/fix_redpajama_example_tokenizer
Use AutoTokenizer for redpajama example
2023-06-14 14:23:21 -04:00
Steffen Röcker
945c4191a3 Use AutoTokenizer for redpajama example 2023-06-14 20:09:26 +02:00
maciej.karasek
136522f9c9 style correction 2023-06-14 20:02:09 +02:00
maciej.karasek
556fe408b3 issue #205 bugfix 2023-06-14 16:59:57 +02:00
Wing Lian
16bb6276a5 Merge pull request #92 from OpenAccess-AI-Collective/flash-optimum
add support for optimum bettertransformers
2023-06-14 07:50:15 -04:00
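A sketch of the optimum BetterTransformer API this PR builds on (model id is a placeholder):

```python
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
model = BetterTransformer.transform(model)  # swap in fastpath kernels
# ... train or run inference ...
model = BetterTransformer.reverse(model)    # revert before saving HF-format weights
```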
Wing Lian
4b43a66a0b update alpaca_chat prompts for instructions to explain the conversation 2023-06-12 18:38:38 -04:00
Wing Lian
7dc580b837 add axolotl trainer and quadratic warmup 2023-06-12 13:16:40 -04:00
Wing Lian
fd2c9814c9 Merge branch 'main' into flash-optimum 2023-06-12 13:12:15 -04:00
Wing Lian
c9a149f9e8 add check for attr 2023-06-11 10:11:17 -04:00
Wing Lian
958da70376 fix formatting 2023-06-10 15:28:08 -04:00
Wing Lian
759e8673ce Update scripts/finetune.py
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
2023-06-10 14:25:21 -04:00
Wing Lian
0c6f928601 address PR feedback 2023-06-10 14:23:56 -04:00
Wing Lian
eea2731a5e add streaming dataset support for pretraining datasets 2023-06-10 14:23:56 -04:00
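A sketch of the underlying datasets streaming mode (corpus id is a placeholder):

```python
from datasets import load_dataset

# streaming=True yields an IterableDataset, so the corpus is never fully
# downloaded or materialized on disk; suited to large pretraining sets.
dataset = load_dataset("c4", "en", split="train", streaming=True)
for example in dataset.take(2):
    print(example["text"][:80])
```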
Wing Lian
1db46a9c72 linting fix 2023-06-10 14:23:56 -04:00
Wing Lian
ab5cd28acf more gpt-neox long ctx fixes 2023-06-10 14:23:55 -04:00
Wing Lian
1a82082e91 fix bettertransformers save, force it to skip after saving correctly in callback 2023-06-10 14:23:55 -04:00
Wing Lian
1210dc8fd5 more tweaks to do pre-training with bettertransformers 2023-06-10 14:23:55 -04:00
Wing Lian
488a67d75a experimental expansion of ctx len 2023-06-10 14:23:53 -04:00
Wing Lian
71a43f8479 add validation/warning for bettertransformers and torch version 2023-06-10 14:22:31 -04:00
Wing Lian
39619028a3 use pythia-12b, neox-20b is flaky 2023-06-10 14:22:30 -04:00
Wing Lian
8792199799 add flash attn context for efficient training and attempt setting model to train mode: 2023-06-10 14:22:30 -04:00
Wing Lian
1edc30c786 add support for optimum bettertransformers 2023-06-10 14:22:30 -04:00
162 changed files with 14727 additions and 2246 deletions

.github/CODE_OF_CONDUCT.md vendored Normal file

@@ -0,0 +1,129 @@
# Contributor Covenant Code of Conduct
## Our Pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our Standards
Examples of behavior that contributes to a positive environment for our
community include:
* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
overall community
Examples of unacceptable behavior include:
* The use of sexualized language or imagery, and sexual attention or
advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement Responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement on Discord
at https://discord.gg/QYF8QrtEUm
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series
of actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within
the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

.github/CONTRIBUTING.md vendored Normal file

@@ -0,0 +1,76 @@
# Contributing to axolotl
First of all, thank you for your interest in contributing to axolotl! We appreciate the time and effort you're willing to invest in making our project better. This document provides guidelines and information to make the contribution process as smooth as possible.
## Table of Contents
- [Code of Conduct](#code-of-conduct)
- [Getting Started](#getting-started)
- [How to Contribute](#how-to-contribute)
- [Reporting Bugs](#reporting-bugs)
- [Suggesting Enhancements](#suggesting-enhancements)
- [Submitting Pull Requests](#submitting-pull-requests)
- [Style Guidelines](#style-guidelines)
- [Code Style](#code-style)
- [Commit Messages](#commit-messages)
- [Additional Resources](#additional-resources)
## Code of Conduct
All contributors are expected to adhere to our [Code of Conduct](CODE_OF_CONDUCT.md). Please read it before participating in the axolotl community.
## Getting Started
Bugs? Please check for an open issue; otherwise, create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
PRs are **very welcome**!
1. Fork the repository and clone it to your local machine.
2. Set up the development environment by following the instructions in the [README.md](https://github.com/OpenAccess-AI-Collective/axolotl/tree/main/README.md) file.
3. Explore the codebase, run tests, and verify that everything works as expected.
Please run the commands below to set up your environment:
```bash
pip3 install -r requirements-dev.txt -r requirements-tests.txt
pre-commit install
# test
pytest tests/
```
## How to Contribute
### Reporting Bugs
If you encounter a bug or issue while using axolotl, please open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Provide a clear and concise description of the problem, steps to reproduce it, and any relevant error messages or logs.
### Suggesting Enhancements
We welcome ideas for improvements and new features. To suggest an enhancement, open a new issue on the [GitHub Issues](https://github.com/OpenAccess-AI-Collective/axolotl/issues) page. Describe the enhancement in detail, explain the use case, and outline the benefits it would bring to the project.
### Submitting Pull Requests
1. Create a new branch for your feature or bugfix. Use a descriptive name like `feature/your-feature-name` or `fix/your-bugfix-name`.
2. Make your changes, following the [Style Guidelines](#style-guidelines) below.
3. Test your changes and ensure that they don't introduce new issues or break existing functionality.
4. Commit your changes, following the [commit message guidelines](#commit-messages).
5. Push your branch to your fork on GitHub.
6. Open a new pull request against the `main` branch of the axolotl repository. Include a clear and concise description of your changes, referencing any related issues.
## Style Guidelines
### Code Style
axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines.
### Commit Messages
Write clear and concise commit messages that briefly describe the changes made in each commit. Use the imperative mood and start with a capitalized verb, e.g., "Add new feature" or "Fix bug in function".
## Additional Resources
- [GitHub Help](https://help.github.com/)
- [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests)
- [{codestyle}]({URLofCodestyle})
Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together!

.github/FUNDING.yml vendored Normal file

@@ -0,0 +1,13 @@
# These are supported funding model platforms
github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

.github/ISSUE_TEMPLATE/bug-report.yaml vendored Normal file

@@ -0,0 +1,112 @@
name: Bug Report
description: File a bug report
labels: ["bug", "needs triage"]
body:
- type: markdown
attributes:
value: |
## Before you start
Please **make sure you are on the latest version.**
If you encountered the issue after you installed, updated, or reloaded, **please try restarting before reporting the bug**.
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "Please check that this issue hasn't been reported before."
description: "The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Bug Reports](https://github.com/OpenAccess-AI-Collective/axolotl/labels/bug) didn't find any similar reports."
required: true
- type: textarea
id: expected
attributes:
label: Expected Behavior
description: Tell us what **should** happen.
validations:
required: true
- type: textarea
id: what-happened
attributes:
label: Current behaviour
description: |
Tell us what happens instead of the expected behavior.
Provide stacktrace and/or screenshots.
validations:
required: true
- type: textarea
id: reproduce
attributes:
label: Steps to reproduce
description: |
Which exact steps can a developer take to reproduce the issue?
The more detail you provide, the easier it will be to narrow down and fix the bug.
Please paste in tasks and/or queries **as text, not screenshots**.
placeholder: |
Example of the level of detail needed to reproduce any bugs efficiently and reliably.
1. Go to the '...' page.
2. Click on the '...' button.
3. Scroll down to '...'.
4. Observe the error.
validations:
required: true
- type: textarea
id: config
attributes:
label: Config yaml
description: |
Please attach the config yaml!
- type: textarea
id: possible-solution
attributes:
label: Possible solution
description: |
Not obligatory, but please suggest a fix or reason for the bug, if you have an idea.
- type: checkboxes
id: operating-systems
attributes:
label: Which Operating Systems are you using?
description: You may select more than one.
options:
- label: Linux
- label: macOS
- label: Windows
- type: input
id: Python-version
attributes:
label: Python Version
description: Which Python version are you using?
placeholder: 3.10 / please change accordingly
validations:
required: true
- type: input
id: axolotl-branch-commit
attributes:
label: axolotl branch-commit
description: On which branch/commit are you?
placeholder: main/4d6490b
validations:
required: true
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this bug has not been reported yet.'
required: true
- label: 'I am using the latest version of axolotl.'
required: true
- label: 'I have provided enough information for the maintainers to reproduce and diagnose the issue.'
required: true

.github/ISSUE_TEMPLATE/config.yml vendored Normal file

@@ -0,0 +1,7 @@
blank_issues_enabled: false
contact_links:
- name: Ask a question
url: https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/q-a
about: Ask questions and discuss with other community members
- name: Discuss the Project in Discord
url: https://discord.gg/HhrNrHJPRb

.github/ISSUE_TEMPLATE/docs.yml vendored Normal file

@@ -0,0 +1,46 @@
name: Documentation Improvement / Clarity
description: Make a suggestion to improve the project documentation.
labels: ['needs triage', 'docs']
body:
- type: markdown
attributes:
value: '## :book: Documentation :book:'
- type: markdown
attributes:
value: |
* Ask questions in [Discord](https://discord.gg/HhrNrHJPRb).
* Before you file an issue read the [Contributing guide](./CONTRIBUTING.md).
* Check to make sure someone hasn't already opened a [similar issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues).
- type: textarea
attributes:
label: What piece of documentation is affected?
description: Please link to the article you'd like to see updated.
validations:
required: true
- type: textarea
attributes:
label: What part(s) of the article would you like to see updated?
description: |
- Give as much detail as you can to help us understand the change you want to see.
- Why should the docs be changed? What use cases does it support?
- What is the expected outcome?
validations:
required: true
- type: textarea
attributes:
label: Additional Information
description: Add any other context or screenshots about the feature request here.
validations:
required: false
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
required: true
- label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
required: true


@@ -0,0 +1,63 @@
name: Feature Request / Enhancement
description: Suggest a new feature or feature enhancement for the project
labels: ["enhancement", "needs triage"]
body:
- type: checkboxes
id: no-duplicate-issues
attributes:
label: "⚠️ Please check that this feature request hasn't been suggested before."
description: "There are two locations for previous feature requests. Please search in both. Thank you. The **Label filters** may help make your search more focussed."
options:
- label: "I searched previous [Ideas in Discussions](https://github.com/OpenAccess-AI-Collective/axolotl/discussions/categories/ideas) didn't find any similar feature requests."
required: true
- label: "I searched previous [Issues](https://github.com/OpenAccess-AI-Collective/axolotl/labels/enhancement) didn't find any similar feature requests."
required: true
- type: textarea
id: feature-description
validations:
required: true
attributes:
label: "🔖 Feature description"
description: "A clear and concise description of what the feature request is."
placeholder: "You should add ..."
- type: textarea
id: solution
validations:
required: true
attributes:
label: "✔️ Solution"
description: "A clear and concise description of what you want to happen, and why."
placeholder: "In my use-case, ..."
- type: textarea
id: alternatives
validations:
required: false
attributes:
label: "❓ Alternatives"
description: "A clear and concise description of any alternative solutions or features you've considered."
placeholder: "I have considered ..."
- type: textarea
id: additional-context
validations:
required: false
attributes:
label: "📝 Additional Context"
description: "Add any other context or screenshots about the feature request here."
placeholder: "..."
- type: checkboxes
id: acknowledgements
attributes:
label: 'Acknowledgements'
description: 'Please confirm the following:'
options:
- label: 'My issue title is concise, descriptive, and in title casing.'
required: true
- label: 'I have searched the existing issues to make sure this feature has not been requested yet.'
required: true
- label: 'I have provided enough information for the maintainers to understand and evaluate this request.'
required: true


@@ -0,0 +1,22 @@
<!--- Provide a general summary of your changes in the Title above -->
# Description
<!--- Describe your changes in detail -->
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here. -->
## How has this been tested?
<!--- Please describe in detail how you tested your changes. -->
<!--- Include details of your testing environment, tests ran to see how -->
<!--- your change affects other areas of the code, etc. -->
## Screenshots (if appropriate)
## Types of changes
<!--- What types of changes does your code introduce? Put an `x` in all the boxes that apply: -->

.github/SECURITY.md vendored Normal file

@@ -0,0 +1,9 @@
# Security Policy
## Supported Versions
Because development in this project moves quickly, only the latest released version can be supported.
## Reporting a Vulnerability
If you find a vulnerability, please contact us on [Discord](https://discord.gg/xcu3ECkH9a) rather than creating a GitHub issue, so that we have time to fix it before it becomes known to others.

.github/SUPPORT.md vendored Normal file

@@ -0,0 +1,10 @@
# Support
If you need help with this project or have questions, please:
1. Check the documentation.
2. Search the existing issues and pull requests.
3. Create a new issue if your question is not answered or your problem is not solved.
4. Have a look in the [Discord server](https://discord.gg/HhrNrHJPRb).
Please note that this project is maintained by volunteers who have limited availability. We'll do our best to address your questions and concerns in a timely manner.


@@ -12,28 +12,24 @@ jobs:
# this job needs to be run on self-hosted GPU runners...
runs-on: self-hosted
strategy:
fail-fast: false
matrix:
include:
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras:
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.0
axolotl_extras:
- cuda: "117"
cuda_version: 11.7.0
python_version: "3.9"
pytorch: 1.13.1
axolotl_extras:
pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras: gptq
python_version: "3.10"
pytorch: 2.1.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -57,11 +53,9 @@ jobs:
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-args: |
CUDA_VERSION=${{ matrix.cuda_version }}
CUDA=${{ matrix.cuda }}
PYTHON_VERSION=${{ matrix.python_version }}
PYTORCH_VERSION=${{ matrix.pytorch }}
AXOLOTL_EXTRAS=${{ matrix.axolotl_extras }}
TORCH_CUDA_ARCH_LIST=${{ matrix.torch_cuda_arch_list }}


@@ -4,36 +4,32 @@ on:
push:
branches:
- "main"
- "dev"
jobs:
build-axolotl:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
include:
- cuda: cu118
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
pytorch: 2.0.1
axolotl_extras:
- cuda: cu118
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.0
pytorch: 2.0.1
axolotl_extras:
- cuda: cu118
is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras: gptq
- cuda: cu117
cuda_version: 11.7.0
python_version: "3.9"
pytorch: 1.13.1
python_version: "3.10"
pytorch: 2.1.0
axolotl_extras:
runs-on: self-hosted
runs-on: [self-hosted, gpu, docker]
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -54,13 +50,15 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
build-axolotl-runpod:
needs: build-axolotl
if: github.repository_owner == 'OpenAccess-AI-Collective'
@@ -68,27 +66,23 @@ jobs:
strategy:
matrix:
include:
- cuda: cu118
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
pytorch: 2.0.1
axolotl_extras:
- cuda: cu118
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.0.0
pytorch: 2.0.1
axolotl_extras:
- cuda: cu118
is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.9"
pytorch: 2.0.0
axolotl_extras: gptq
- cuda: cu117
cuda_version: 11.7.0
python_version: "3.9"
pytorch: 1.13.1
python_version: "3.10"
pytorch: 2.1.0
axolotl_extras:
runs-on: self-hosted
runs-on: [self-hosted, gpu, docker]
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -109,10 +103,11 @@ jobs:
with:
context: .
build-args: |
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
BASE_TAG=${{ github.ref_name }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
CUDA=${{ matrix.cuda }}
file: ./docker/Dockerfile-runpod
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max


@@ -1,16 +0,0 @@
name: pre-commit
on:
pull_request:
push:
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0

.github/workflows/pypi.yml vendored Normal file

@@ -0,0 +1,45 @@
name: publish pypi
on:
push:
tags:
- '*'
jobs:
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/axolotl
permissions:
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
steps:
- name: Check out repository code
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
run: |
pip3 install wheel
pip3 install -e .
pip3 install -r requirements-tests.txt
- name: Extract tag name
id: tag
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
- name: Update version in setup.py
run: >-
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
- name: Build a binary wheel
run: >-
python setup.py sdist bdist_wheel
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1


@@ -1,12 +1,35 @@
name: PyTest
name: Tests
on:
# check on push/merge to main, PRs, and manual triggers
push:
branches:
- "main"
paths:
- '**.py'
- 'requirements.txt'
pull_request:
paths:
- '**.py'
- 'requirements.txt'
workflow_dispatch:
jobs:
test:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.9"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.0
pytest:
name: PyTest
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.9", "3.10"]
timeout-minutes: 10
@@ -23,9 +46,35 @@ jobs:
- name: Install dependencies
run: |
pip install -e .
pip install -r requirements-tests.txt
pip3 install -U -e .
pip3 install -r requirements-tests.txt
- name: Run tests
run: |
pytest tests/
pytest --ignore=tests/e2e/ tests/
e2e-test:
name: E2E Tests
runs-on: [self-hosted, gpu]
timeout-minutes: 20
needs: [pre-commit, pytest]
steps:
- name: Check out repository code
uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
# cache: 'pip' # caching pip dependencies
- name: Install dependencies
run: |
pip3 uninstall -y transformers accelerate
pip3 install -U -e .[flash-attn]
pip3 install -r requirements-tests.txt
- name: Run e2e tests
run: |
pytest tests/e2e/

.gitignore vendored

@@ -161,3 +161,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# WandB
# wandb creates a folder to store logs for training runs
wandb


@@ -1,2 +1,3 @@
[settings]
profile=black
known_third_party=wandb


@@ -8,6 +8,9 @@ ignore_missing_imports = True
[mypy-axolotl.monkeypatch.*]
ignore_errors = True
[mypy-axolotl.models.phi.*]
ignore_errors = True
[mypy-flash_attn.*]
ignore_missing_imports = True
@@ -20,6 +23,9 @@ ignore_missing_imports = True
[mypy-peft]
ignore_missing_imports = True
[mypy-wandb]
ignore_missing_imports = True
[mypy-bitsandbytes]
ignore_missing_imports = True


@@ -1,5 +1,5 @@
default_language_version:
python: python3.9
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks

LICENSE Normal file

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md (diff suppressed because it is too large)

@@ -1,24 +0,0 @@
## Download some datasets
```shell
curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o data/raw/alpaca_data_gpt4.json
curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o data/raw/vicuna_cleaned.json
curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o data/raw/gpt4-instruct-similarity-0.6-dataset.json
curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o data/raw/roleplay-similarity_0.6-instruct-dataset.json
```
## Convert the JSON data files to JSONL.
```shell
python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
```
---
Using JSONL makes it easier to subset the data if you want a smaller training set, i.e get 2000 random examples.
```shell
shuf -n2000 data/vicuna_cleaned.jsonl > data/vicuna_cleaned.subset0.jsonl
```

data/raw/.gitignore (vendored)

@@ -1 +0,0 @@
**

deepspeed/zero1.json (new file)

@@ -0,0 +1,41 @@
{
"zero_optimization": {
"stage": 1,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

deepspeed/zero2.json (new file)

@@ -0,0 +1,45 @@
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}


@@ -1,14 +1,6 @@
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 0,
@@ -35,22 +27,22 @@
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [
0.9,
0.999
],
"eps": 1e-8,
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "OneCycle",
"type": "WarmupDecayLR",
"params": {
"cycle_min_lr": 0.00001,
"cycle_max_lr": 0.00003,
"cycle_first_step_size": 120
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear",
"total_num_steps": "auto"
}
},
"gradient_accumulation_steps": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false


@@ -9,6 +9,11 @@ services:
- ~/.cache/huggingface/:/root/.cache/huggingface/
# set environment variables
environment:
# Set environment variables
- GIT_AUTHOR_NAME=${GIT_AUTHOR_NAME}
- GIT_AUTHOR_EMAIL=${GIT_AUTHOR_EMAIL}
- GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME}
- GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL}
- WANDB_API_KEY=${WANDB_API_KEY}
deploy:
resources:


@@ -3,24 +3,32 @@ FROM winglian/axolotl-base:$BASE_TAG
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS=""
ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \
apt-get install -y vim curl
WORKDIR /workspace
RUN pip3 install --force-reinstall "peft @ git+https://github.com/huggingface/peft.git@main" \
"accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
"transformers @ git+https://github.com/huggingface/transformers.git@main"
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets
RUN cd axolotl && \
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[$AXOLOTL_EXTRAS]; \
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
else \
pip install -e .; \
pip install -e .[deepspeed,flash-attn]; \
fi
# fix so that git fetch/pull from remote works
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
git config --get remote.origin.fetch
# helper for huggingface-login cli
RUN git config --global credential.helper store


@@ -8,93 +8,30 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION a
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.9"
ARG PYTORCH="2.0.0"
ARG PYTORCH_VERSION="2.0.1"
ARG CUDA="118"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update
RUN apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
RUN wget \
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh
RUN conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu$CUDA
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
FROM base-builder AS flash-attn-builder
WORKDIR /workspace
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
RUN git clone https://github.com/HazyResearch/flash-attention.git && \
cd flash-attention && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \
cd ../xentropy && \
python3 setup.py bdist_wheel && \
cd ../rotary && \
python3 setup.py bdist_wheel && \
cd ../layer_norm && \
python3 setup.py bdist_wheel
FROM base-builder AS deepspeed-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
WORKDIR /workspace
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
cd DeepSpeed && \
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 python3 setup.py bdist_wheel
FROM base-builder AS bnb-builder
WORKDIR /workspace
ARG CUDA="118"
ENV CUDA=$CUDA
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
cd bitsandbytes && \
CUDA_VERSION=$CUDA make cuda11x && \
python setup.py bdist_wheel
FROM base-builder
# recompile apex
RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check .
RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
RUN mkdir -p /workspace/wheels/bitsandbytes
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl wheels
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels
RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xentropy_cuda_lib-*.whl wheels/rotary_emb-*.whl wheels/dropout_layer_norm-*.whl
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
RUN git lfs install --skip-repo
RUN pip3 install awscli && \
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic
pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -1,6 +1,10 @@
ARG BASE_TAG=main
FROM winglian/axolotl:$BASE_TAG
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
RUN apt install --yes --no-install-recommends openssh-server tmux && \

docs/faq.md (new file)

@@ -0,0 +1,18 @@
# Axolotl FAQs
> The trainer stopped and hasn't progressed in several minutes.
This is usually an issue with the GPUs communicating with each other. See the [NCCL doc](../docs/nccl.md).
> Exitcode -9
This usually happens when you run out of system RAM.
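If you want to confirm that the OOM killer was responsible (assuming a Linux host where you can read the kernel log), something like this can help:
```shell
dmesg | grep -i -E 'killed process|out of memory'
```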
> Exitcode -7 while using deepspeed
Try upgrading deepspeed with: `pip install -U deepspeed`
> AttributeError: 'DummyOptim' object has no attribute 'step'
You may be using deepspeed with a single GPU. In that case, don't set `deepspeed:` in the YAML or on the CLI.

docs/multi-node.md (new file)

@@ -0,0 +1,45 @@
# Multi Node
You will need to create a configuration for accelerate, either by running `accelerate config` and following the instructions, or by using the preset below, saved as `~/.cache/huggingface/accelerate/default_config.yaml`:
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
machine_rank: 0 # Set to 0 for the main machine, increment by one for other machines
main_process_ip: 10.0.0.4 # Set to main machine's IP
main_process_port: 5000
main_training_function: main
mixed_precision: bf16
num_machines: 2 # Change to the number of machines
num_processes: 4 # Total number of GPUs across all machines (for example: 2 machines with 4 GPUs each means 8)
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
Configure your model to use FSDP, for example:
```yaml
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_offload_params: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Machine configuration
On each machine you need a copy of Axolotl; we suggest using the same commit everywhere to ensure compatibility.
You will also need the same model configuration file on each machine.
On the main machine only, make sure the port you set as `main_process_port` is open for TCP and reachable by the other machines.
All that's left is to launch with accelerate on each machine, as you usually would; the processes will start once accelerate has been launched on every machine.
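As a sketch, assuming your model config is saved as `config.yml` (a placeholder name), the launch on each machine is the usual one:
```shell
accelerate launch -m axolotl.cli.train config.yml
```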

docs/multipack.md (new file)

@@ -0,0 +1,51 @@
# Multipack
4k context, bsz = 4,
each character represents 256 tokens
X represents a padding token
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B ]
C C C C C C C ]
D D D D ]]
[[ E E E E E E E E ]
[ F F F F ]
[ G G G ]
[ H H H H ]]
[[ I I I ]
[ J J J ]
[ K K K K K]
[ L L L ]]
```
After padding to the longest input in each step:
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B X X X X X X ]
C C C C C C C X X X X ]
D D D D X X X X X X X ]]
[[ E E E E E E E E ]
[ F F F F X X X X ]
[ G G G X X X X X ]
[ H H H H X X X X ]]
[[ I I I X X ]
[ J J J X X ]
[ K K K K K ]
[ L L L X X ]]
```
With packing (note it's the same effective number of tokens per step, but a true bsz of 1):
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A B B B B B
B C C C C C C C D D D D E E E E
E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]]
```
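The packing itself is a classic bin-packing problem. Below is a minimal first-fit-decreasing sketch of the idea the diagrams illustrate; the function name and exact strategy are illustrative assumptions, not axolotl's actual sampler code.
```python
# A minimal sketch (not axolotl's actual implementation) of the greedy
# bin packing illustrated above: sort sequences longest-first and place
# each one into the first pack that still has room, opening a new pack
# when none does (first-fit decreasing).
def pack_sequences(lengths, max_len=16):
    packs = []  # each pack is a list of sequence lengths
    free = []   # remaining capacity of each pack
    for length in sorted(lengths, reverse=True):
        for i, capacity in enumerate(free):
            if length <= capacity:
                packs[i].append(length)
                free[i] -= length
                break
        else:  # no existing pack had room
            packs.append([length])
            free.append(max_len - length)
    return packs

# Sequence lengths A..L from the diagram, in units of 256-token "characters";
# a 4k context is 16 units. This packs everything into 4 bins, matching the
# packed layout above.
print(pack_sequences([11, 6, 7, 4, 8, 4, 3, 4, 3, 3, 5, 3]))
```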

docs/nccl.md (new file)

@@ -0,0 +1,46 @@
# NCCL
NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out causing the training process to abort:
```text
Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.
```
Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you.
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink, run the following command:
```shell
nvidia-smi nvlink --status
```
To force NCCL to use NVLink, simply set this in the environment:
```shell
export NCCL_P2P_LEVEL=NVL
```
If NVLink is not available in your environment, there are other options for ``NCCL_P2P_LEVEL`` in the table below:
| NCCL_P2P_LEVEL | Description |
| -------------- | ----------- |
| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates than paths involving multiple bridges, but slower than direct GPU-to-GPU communication. |
| PXB | P2P data transfers through multiple PCIe bridges, but not through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. |
| PHB | P2P data transfers occur over PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (e.g., PIX, NVL). |
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
```shell
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
```
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
```shell
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
export TORCH_DISTRIBUTED_DEBUG=INFO
export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log
```
Finally, if you believe your training job needs more time you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value.
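For example, to raise the timeout to two hours, you could add the following to your Axolotl YAML config (a sketch; the value is assumed to be in seconds, with 1800 corresponding to the 30-minute default):
```yaml
ddp_timeout: 7200
```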


@@ -0,0 +1,89 @@
base_model: cerebras/btlm-3b-8k-base
model_type: AutoModelForCausalLM
tokenizer_type: GPT2Tokenizer
trust_remote_code: true
tokenizer_use_fast: true
tokenizer_legacy: true
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_prepared_run
val_set_size: 0.05
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
sample_packing: false
sample_packing_eff_est:
sample_packing_seq_len_multiplier:
total_num_tokens:
lora_r:
lora_alpha:
lora_dropout:
lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: btlm-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.000000001
max_grad_norm: 1.0
torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000085
train_on_inputs: true
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
sdp_attention:
flash_optimum:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 32
eval_steps:
save_steps:
save_total_limit:
debug:
deepspeed:
weight_decay: 0.1
special_tokens:
pad_token: "<|endoftext|>"
fsdp:
# - full_shard
# - auto_wrap
fsdp_config:
# fsdp_state_dict_type: FULL_STATE_DICT
# fsdp_transformer_layer_cls_to_wrap: BTLMBlock


@@ -1,5 +1,4 @@
base_model: cerebras/Cerebras-GPT-1.3B
base_model_config: cerebras/Cerebras-GPT-1.3B
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -7,8 +6,8 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
dataset_prepared_path:
val_set_size: 0.05
adapter: qlora
lora_model_dir:
sequence_len: 2048
@@ -23,6 +22,7 @@ lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -35,7 +35,7 @@ torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: true
group_by_length: false
bf16: true
fp16: false
tf32: true
@@ -49,7 +49,7 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 20
eval_steps: 0.05
save_steps:
debug:
deepspeed:


@@ -0,0 +1,67 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,67 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,67 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,69 @@
base_model: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,22 @@
# Overview
These are example CodeLlama configurations for the 7b, 13b, and 34b variants.
The 7b variant fits on any 24GB-VRAM GPU and will take up about 17 GB of VRAM during training when using qlora, and 20 GB when using lora. On an RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
The 13b variant will fit if you change these settings to the following values:
```yaml
gradient_accumulation_steps: 2
micro_batch_size: 1
```
The 34b variant does not fit on 24GB of VRAM; you will need a GPU with 40+ GB of VRAM that also supports flash attention v2. An A6000 or A100 is a good choice.
```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/qlora.yml
```
or
```shell
accelerate launch scripts/finetune.py examples/code-llama/[MODEL_SIZE]/lora.yml
```


@@ -1,8 +1,8 @@
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: true
load_in_4bit: false
gptq: false
@@ -11,8 +11,8 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
lora_model_dir:
sequence_len: 2048
@@ -24,6 +24,7 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:


@@ -1,11 +1,11 @@
# 1b: tiiuae/falcon-rw-1b
# 40b: tiiuae/falcon-40b
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
@@ -17,8 +17,8 @@ datasets:
data_files:
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
type: "alpaca:chat"
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
dataset_prepared_path:
val_set_size: 0.05
# enable QLoRA
adapter: qlora
lora_model_dir:
@@ -38,6 +38,7 @@ lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -52,7 +53,7 @@ output_dir: ./qlora-out
# decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1
gradient_accumulation_steps: 2
num_epochs: 3
num_epochs: 4
# Optimizer for QLoRA
optimizer: paged_adamw_32bit
torchdistx_path:


@@ -1,8 +1,8 @@
base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_falcon_derived_model: true
load_in_8bit: false
load_in_4bit: false
gptq: false
@@ -11,8 +11,8 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca:chat
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
dataset_prepared_path:
val_set_size: 0.05
adapter:
lora_model_dir:
sequence_len: 2048
@@ -24,6 +24,7 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:


@@ -1,5 +1,4 @@
base_model: EleutherAI/gpt-j-6b
base_model_config: EleutherAI/gpt-j-6b
load_in_8bit: false
load_in_4bit: true
strict: false
@@ -7,8 +6,8 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
dataset_prepared_path:
val_set_size: 0.05
adapter: qlora
lora_model_dir:
sequence_len: 2048
@@ -20,6 +19,7 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
@@ -32,7 +32,7 @@ torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0001
train_on_inputs: false
group_by_length: true
group_by_length: false
bf16: true
fp16: false
tf32: true
@@ -46,7 +46,7 @@ flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 20
eval_steps: 0.05
save_steps:
debug:
deepspeed:


@@ -1,8 +0,0 @@
# LLaMa 7B using LoRA
This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
```shell
accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
```


@@ -1,62 +0,0 @@
base_model: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
base_model_config: Neko-Institute-of-Science/LLaMA-7B-4bit-128g
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code:
load_in_8bit: true
gptq: true
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
lora_fan_in_fan_out: false
wandb_project: llama-7b-lora-int4
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./llama-7b-lora-int4
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0000002
train_on_inputs: false
group_by_length: false
fp16: true
bf16: false
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 5
xformers_attention:
flash_attention:
gradient_checkpointing: true
gptq_groupsize: 128
gptq_model_v1: false
warmup_steps: 20
eval_steps: 110
save_steps: 660
debug:
deepspeed:
weight_decay: 0.0001
fsdp:
fsdp_config:
tokens:
pad_token: "[PAD]"
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -1,12 +1,11 @@
base_model: huggyllama/llama-7b
base_model_config: huggyllama/llama-7b
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
datasets:
- path: openaccess-ai-collective/jeopardy
type: jeopardy
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter:
lora_model_dir:
@@ -18,13 +17,14 @@ lora_dropout:
lora_target_modules:
lora_fan_in_fan_out: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
num_epochs: 4
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine


@@ -0,0 +1,24 @@
# Overview
This is an example llama-2 configuration for 7b and 13b. The yaml file contains the configuration for the 7b variant, but you can just as well use the same settings for 13b.
The 7b variant fits on any 24GB-VRAM GPU and will take up about 17 GB of VRAM during training when using qlora, and 20 GB when using lora. On an RTX 4090 it trains 3 epochs of the default dataset in about 15 minutes.
The 13b variant will fit if you change these settings to the following values:
```yaml
gradient_accumulation_steps: 2
micro_batch_size: 1
```
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
```
or
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
```
To launch a full finetuning with 16-bit precision:
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
```


@@ -0,0 +1,72 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
eval_steps: 0.05
eval_table_size:
save_steps:
debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,73 @@
base_model: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false
gptq: true
gptq_disable_exllama: true
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
tokenizer_use_fast: true
tokenizer_legacy: true
load_in_8bit: false
load_in_4bit: false
strict: false
push_dataset_to_hub:
hf_use_auth_token: true
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
lora_model_dir:
sequence_len: 4096
sample_packing:
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- k_proj
- o_proj
- q_proj
- v_proj
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./model-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
adam_beta2: 0.95
adam_eps: 0.00001
max_grad_norm: 1.0
torchdistx_path:
lr_scheduler: cosine
lr_quadratic_warmup: true
learning_rate: 0.000017
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
sdp_attention:
flash_optimum:
warmup_steps: 100
eval_steps:
save_steps:
debug:
deepspeed:
weight_decay: 0.1
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"

examples/llama-2/lora.yml (new file)

@@ -0,0 +1,69 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,70 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,73 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./relora-out
adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
relora_steps: 150
relora_warmup_steps: 10
relora_cpu_offload: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
save_steps: 50
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,69 @@
base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/context-aware-splits-english
type: alpaca
dataset_prepared_path:
val_set_size: 200
output_dir: ./tiny-llama
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 8
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 50
eval_steps: 0.05
eval_table_size:
save_steps: 0.50
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,12 @@
**Mistral 7B** is a language model with a total of 7.3 billion parameters, showing notable performance across a variety of benchmarks.
Fine-tune:
```shell
accelerate launch -m axolotl.cli.train examples/mistral/config.yml
```
If you run into a CUDA OOM, use deepspeed with the zero2.json config:
```shell
accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
```


@@ -0,0 +1,61 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.000005
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -0,0 +1,78 @@
base_model: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
eval_steps: 0.05
eval_table_size:
eval_table_max_new_tokens: 128
save_steps:
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -1,12 +1,11 @@
base_model: mosaicml/mpt-7b
base_model_config: mosaicml/mpt-7b
tokenizer_type: AutoTokenizer
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter:
lora_model_dir:
@@ -20,13 +19,14 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: mpt-alpaca-7b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
num_epochs: 4
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine


@@ -1,5 +1,4 @@
base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
base_model: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
@@ -9,12 +8,12 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter:
lora_model_dir:
sequence_len: 256
max_packed_sequence_len:
sequence_len: 1024
sample_packing: true
lora_r:
lora_alpha:
lora_dropout:
@@ -22,17 +21,18 @@ lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./openllama-out
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
num_epochs: 4
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.00001
learning_rate: 0.000003
train_on_inputs: false
group_by_length: false
float16: true
@@ -44,12 +44,12 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
xformers_attention:
flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 50
warmup_steps: 20
eval_steps: 0.05
save_steps:
debug:
deepspeed:


@@ -1,5 +1,4 @@
base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
base_model: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: true
@@ -9,12 +8,12 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter: lora
lora_model_dir:
sequence_len: 256
max_packed_sequence_len:
sequence_len: 1024
sample_packing: true
lora_r: 8
lora_alpha: 16
lora_dropout: 0.0
@@ -28,13 +27,14 @@ lora_target_modules:
- o_proj
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-out
batch_size: 16
micro_batch_size: 4
num_epochs: 3
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine
@@ -49,16 +49,16 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
xformers_attention:
flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 50
warmup_steps: 20
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.0
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:


@@ -1,5 +1,4 @@
base_model: openlm-research/open_llama_3b
base_model_config: openlm-research/open_llama_3b
base_model: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
load_in_8bit: false
@@ -9,12 +8,12 @@ push_dataset_to_hub:
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
dataset_prepared_path:
val_set_size: 0.05
adapter: qlora
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
sequence_len: 1024
sample_packing: true
lora_r: 8
lora_alpha: 32
lora_dropout: 0.05
@@ -22,37 +21,38 @@ lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
batch_size: 4
micro_batch_size: 4
num_epochs: 2
gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 4
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
group_by_length: false
bf16: false
fp16: true
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
xformers_attention:
flash_attention: true
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 20
warmup_steps: 20
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.0
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:

examples/phi/README.md (new file)

@@ -0,0 +1,11 @@
# Phi
Due to some nuances with the phi code, please use deepspeed when training phi for a full finetune.
```shell
accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
# OR
python -m axolotl.cli.train examples/phi/phi-qlora.yml
```

examples/phi/phi-ft.yml (new file)

@@ -0,0 +1,74 @@
base_model: microsoft/phi-1_5
model_type: MixFormerSequentialForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
trust_remote_code: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len:
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.000003
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing:
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
unk_token: "<|endoftext|>"
pad_token: "<|endoftext|>"


@@ -0,0 +1,74 @@
base_model: microsoft/phi-1_5
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
is_llama_derived_model: false
trust_remote_code: true
load_in_8bit: false
load_in_4bit: true
strict: false
datasets:
- path: garage-bAInd/Open-Platypus
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./phi-sft-out
sequence_len: 1024
sample_packing: false # not CURRENTLY compatible with LoRAs
pad_to_sequence_len:
adapter: qlora
lora_model_dir:
lora_r: 64
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 4
optimizer: adamw_torch
adam_beta2: 0.95
adam_epsilon: 0.00001
max_grad_norm: 1.0
lr_scheduler: cosine
learning_rate: 0.000003
train_on_inputs: false
group_by_length: true
bf16: true
fp16: false
tf32: true
gradient_checkpointing:
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention:
warmup_steps: 100
eval_steps: 0.05
save_steps:
debug:
deepspeed:
weight_decay: 0.1
fsdp:
fsdp_config:
resize_token_embeddings_to_32x: true
special_tokens:
bos_token: "<|endoftext|>"
eos_token: "<|endoftext|>"
unk_token: "<|endoftext|>"
pad_token: "<|endoftext|>"


@@ -0,0 +1,9 @@
# Pythia 12B
- Single-GPU A100 only (?)
```shell
python scripts/finetune.py examples/pythia-12b/config.yml
```
⚠️ Multi-GPU A100: doesn't seem to work with multiple GPUs without causing an OOM! ⚠️


@@ -0,0 +1,48 @@
base_model: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
load_in_4bit: false
gptq: false
device_map: auto
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
adapter:
lora_model_dir:
sequence_len: 2048
max_packed_sequence_len: 2048
lora_r: 64
lora_alpha: 32
lora_dropout: 0.0
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./pythia-12b
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 5
learning_rate: 0.00003
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
train_on_inputs: false
group_by_length: false
bf16: false
fp16: false
float16: true
tf32: true
flash_optimum: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
gradient_checkpointing: true
fsdp:
fsdp_config:


@@ -1,10 +1,9 @@
base_model: EleutherAI/pythia-1.4b-deduped
base_model_config: EleutherAI/pythia-1.4b-deduped
load_in_8bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
lora_model_dir:
@@ -17,21 +16,22 @@ lora_target_modules:
lora_target_linear:
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 3
num_epochs: 4
learning_rate: 0.00001
train_on_inputs: false
group_by_length: false
bf16: True
tf32: True
bf16: true
tf32: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
weight_decay: 0.1
eval_steps: 20
eval_steps: 0.05
logging_steps: 1


@@ -1,13 +1,12 @@
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
model_type: GPTNeoXForCausalLM
tokenizer_type: GPTNeoXTokenizer
tokenizer_type: AutoTokenizer
trust_remote_code:
load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.02
adapter:
lora_model_dir:
@@ -21,13 +20,14 @@ lora_target_modules:
- v_proj
lora_fan_in_fan_out: false
wandb_project: redpajama-alpaca-3b
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./redpajama-alpaca-3b
batch_size: 4
micro_batch_size: 1
num_epochs: 3
num_epochs: 4
optimizer: adamw_bnb_8bit
torchdistx_path:
lr_scheduler: cosine

View File

@@ -1,11 +1,10 @@
base_model: replit/replit-code-v1-3b
base_model_config: replit/replit-code-v1-3b
trust_remote_code: true
load_in_8bit: false
datasets:
- path: vicgalle/alpaca-gpt4
type: alpaca
dataset_prepared_path: last_run_prepared
dataset_prepared_path:
val_set_size: 0.05
adapter: lora
lora_model_dir:
@@ -20,13 +19,14 @@ lora_target_modules:
- mlp_down
lora_fan_in_fan_out:
wandb_project: lora-replit
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./lora-replit
batch_size: 8
micro_batch_size: 1
num_epochs: 3
num_epochs: 4
optimizer:
torchdistx_path:
lr_scheduler:

View File

@@ -0,0 +1,90 @@
# An example finetuning Salesforce's XGen-7B model with 8k context using QLoRA
# on Tim Dettmers' Guanaco dataset.
base_model: Salesforce/xgen-7b-8k-base
trust_remote_code: true
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
load_in_8bit: false
# enable 4bit for QLoRA
load_in_4bit: true
gptq: false
strict: false
push_dataset_to_hub:
datasets:
- path: timdettmers/openassistant-guanaco
data_files:
- openassistant_best_replies_train.jsonl
type: "completion"
dataset_prepared_path:
val_set_size: 0.05
# enable QLoRA
adapter: qlora
lora_model_dir:
sequence_len: 8192
max_packed_sequence_len:
# hyperparameters from QLoRA paper Appendix B.2
# "We find hyperparameters to be largely robust across datasets"
lora_r: 64
lora_alpha: 16
# 0.1 for models up to 13B
# 0.05 for 33B and 65B models
lora_dropout: 0.05
# add LoRA modules on all linear layers of the base model
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
output_dir: ./qlora-out
# QLoRA paper Table 9
# - 16 for 7b & 13b
# - 32 for 33b, 64 for 64b
# Max size tested on A6000
# - 7b: 40
# - 40b: 4
# decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1
gradient_accumulation_steps: 1
num_epochs: 4
# Optimizer for QLoRA
optimizer: paged_adamw_32bit
torchdistx_path:
lr_scheduler: cosine
# QLoRA paper Table 9
# - 2e-4 for 7b & 13b
# - 1e-4 for 33b & 64b
learning_rate: 0.00002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
# stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3
resume_from_checkpoint:
auto_resume_from_checkpoints: true
local_rank:
logging_steps: 1
xformers_attention: true
flash_attention:
gptq_groupsize:
gptq_model_v1:
warmup_steps: 10
eval_steps: 50
save_steps: 50
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
eos_token: "<|endoftext|>"
bos_token: "<|endoftext|>"
unk_token: "<|endoftext|>"
pad_token: "<|endoftext|>"
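A hypothetical launch of this config, assuming the YAML is saved under `examples/xgen-7b/` (the path is not shown in the diff):

```shell
# single-GPU QLoRA run; the deprecated scripts/finetune.py shim would also work
accelerate launch -m axolotl.cli.train examples/xgen-7b/xgen-7b-8k-qlora.yml
```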

BIN
image/sticker_fixed.png Normal file (new binary image, 370 KiB; not shown)

View File

@@ -1,19 +1,40 @@
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.39.0
accelerate
--extra-index-url https://download.pytorch.org/whl/cu118
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
torch==2.0.1
auto-gptq==0.4.2
packaging
peft==0.6.0
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
deepspeed
addict
fire
PyYAML==6.0
datasets
accelerate>=0.19.0
PyYAML>=6.0
datasets>=2.14.0
flash-attn>=2.3.0
sentencepiece
wandb
einops
xformers
xformers>=0.0.22
optimum==1.13.2
hf_transfer
colorama
numba
numpy>=1.24.4
# qlora things
bert-score==0.3.13
evaluate==0.4.0
rouge-score==0.1.2
scipy
scikit-learn==1.2.2
pynvml
art
fschat==0.2.29
gradio
tensorboard
# remote filesystems
s3fs
gcsfs
# adlfs
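These s3fs/gcsfs additions back the new S3/GCS dataset loading; a minimal sketch of pulling them in by hand, assuming a pip-based environment (adlfs remains commented out upstream):

```shell
# remote-filesystem backends for s3:// and gs:// dataset paths
pip install s3fs gcsfs
```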

View File

@@ -1,49 +0,0 @@
"""Module to convert json file to jsonl"""
import os
import sys
from pathlib import Path
from typing import Optional, Union
import fire
from axolotl.convert import (
FileReader,
FileWriter,
JsonlSerializer,
JsonParser,
JsonToJsonlConverter,
StdoutWriter,
)
# add src to the pythonpath so we don't need to pip install this
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
def main(
file: Path,
output: Optional[Path] = None,
to_stdout: Optional[bool] = False,
):
"""
Convert a json file to jsonl
"""
file_reader = FileReader()
writer: Union[StdoutWriter, FileWriter]
if to_stdout or output is None:
writer = StdoutWriter()
else:
writer = FileWriter(output)
json_parser = JsonParser()
jsonl_serializer = JsonlSerializer()
converter = JsonToJsonlConverter(file_reader, writer, json_parser, jsonl_serializer)
converter.convert(file, output)
if __name__ == "__main__":
fire.Fire(main)

View File

@@ -1,329 +1,52 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging
import os
import random
import signal
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import fire
import torch
import yaml
from transformers import GenerationConfig, TextStreamer
import transformers
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
do_inference,
do_merge_lora,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.cli.shard import shard
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
# add src to the pythonpath so we don't need to pip install this
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
LOG = logging.getLogger("axolotl.scripts.finetune")
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def choose_device(cfg):
def get_device():
try:
if torch.cuda.is_available():
return f"cuda:{cfg.local_rank}"
if torch.backends.mps.is_available():
return "mps"
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
cfg.device = get_device()
if cfg.device_map != "auto":
if cfg.device.startswith("cuda"):
cfg.device_map = {"": cfg.local_rank}
else:
cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to finish): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
LOG.warning(
str(
PendingDeprecationWarning(
"scripts/finetune.py will be replaced with calling axolotl.cli.train"
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = yaml_files[choice - 1]
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def train(
config: Path = Path("configs/"),
prepare_ds_only: bool = False,
**kwargs,
):
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
validate_config(cfg)
# setup some derived config / hyperparams
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
cfg.batch_size // cfg.micro_batch_size
)
cfg.batch_size = (
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
choose_device(cfg)
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
if cfg.ddp:
cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
cfg.batch_size = cfg.batch_size * cfg.world_size
setup_wandb_env_vars(cfg)
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
if cfg.tf32:
torch.backends.cuda.matmul.allow_tf32 = True
# load the tokenizer first
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
logging.info(f"loading tokenizer... {tokenizer_config}")
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
if (
check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
if cfg.debug or "debug" in kwargs:
logging.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[random.randrange(0, len(train_dataset) - 1) for _ in range(5)] # nosec
),
tokenizer,
)
if prepare_ds_only:
logging.info("Finished preparing dataset. Exiting...")
return
# Load the model and tokenizer
logging.info("loading model and peft_config...")
model, peft_config = load_model(
cfg.base_model,
cfg.base_model_config,
cfg.model_type,
tokenizer,
cfg,
adapter=cfg.adapter,
)
if "merge_lora" in kwargs and cfg.adapter is not None:
logging.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
logging.info("saving merged model")
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
return
if cfg.inference:
logging.info("calling do_inference function")
prompter: Optional[str] = "AlpacaPrompter"
if "prompter" in kwargs:
if kwargs["prompter"] == "None":
prompter = None
else:
prompter = kwargs["prompter"]
do_inference(cfg, model, tokenizer, prompter=prompter)
return
if "shard" in kwargs:
model.save_pretrained(cfg.output_dir)
return
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
model.config.use_cache = False
if torch.__version__ >= "2" and sys.platform != "win32":
logging.info("Compiling torch model")
model = torch.compile(model)
# go ahead and presave, so we have the adapter config available to inspect
if peft_config:
logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
peft_config.save_pretrained(cfg.output_dir)
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
if cfg.local_rank == 0:
signal.signal(
signal.SIGINT,
lambda signal, frame: (
model.save_pretrained(cfg.output_dir),
sys.exit(0),
),
)
logging.info("Starting trainer...")
if cfg.group_by_length:
logging.info("hang tight... sorting dataset for group_by_length")
resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
resume_from_checkpoint = sorted_paths[-1]
logging.info(
f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
)
if not Path(cfg.output_dir).is_dir():
os.makedirs(cfg.output_dir, exist_ok=True)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.local_rank == 0:
model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":
fire.Fire(train)
fire.Fire(do_cli)
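Since scripts/finetune.py is now a deprecation shim that dispatches to the axolotl.cli modules, the new entry points can be invoked directly; a sketch, with the config path assumed:

```shell
# equivalent to the old scripts/finetune.py call, minus the warning
python -m axolotl.cli.train examples/openllama-3b/config.yml
```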

scripts/runpod-entrypoint.sh Normal file → Executable file
View File

@@ -1,10 +1,21 @@
#!/bin/bash
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
# Export specific ENV variables to /etc/rp_environment
echo "Exporting environment variables..."
printenv | grep -E '^RUNPOD_|^PATH=|^_=' | sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >> /etc/rp_environment
echo 'source /etc/rp_environment' >> ~/.bashrc
# Start the SSH service in the background
service ssh start
if [[ $PUBLIC_KEY ]]
then
mkdir -p ~/.ssh
chmod 700 ~/.ssh
echo $PUBLIC_KEY >> ~/.ssh/authorized_keys
chmod 700 -R ~/.ssh
# Start the SSH service in the background
service ssh start
else
echo "No PUBLIC_KEY ENV variable provided, not starting openSSH daemon"
fi
# Execute the passed arguments (CMD)
exec "$@"
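The reworked entrypoint only installs a key and starts sshd when PUBLIC_KEY is present; a hypothetical container invocation for both branches (image name assumed):

```shell
# key provided: authorized_keys is written and sshd starts
docker run -e PUBLIC_KEY="$(cat ~/.ssh/id_ed25519.pub)" some-axolotl-image

# no key: the entrypoint logs a notice and just execs the CMD
docker run some-axolotl-image
```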

View File

@@ -2,31 +2,53 @@
from setuptools import find_packages, setup
install_requires = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
# don't include peft yet until we check the int4
# need to manually install peft for now...
reqs = [r.strip() for r in requirements_file.readlines() if "peft" not in r]
reqs = [r for r in reqs if r and r[0] != "#"]
for r in reqs:
install_requires.append(r)
def parse_requirements():
_install_requires = []
_dependency_links = []
with open("./requirements.txt", encoding="utf-8") as requirements_file:
lines = [r.strip() for r in requirements_file.readlines()]
for line in lines:
if line.startswith("--extra-index-url"):
# Handle custom index URLs
_, url = line.split()
_dependency_links.append(url)
elif (
"flash-attn" not in line
and "deepspeed" not in line
and line
and line[0] != "#"
):
# Handle standard packages
_install_requires.append(line)
# TODO(wing) remove once xformers release supports torch 2.1.0
if "torch==2.1.0" in _install_requires:
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
_install_requires.append(
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
)
return _install_requires, _dependency_links
install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
version="0.1",
description="You know you're going to axolotl questions",
version="0.3.0",
description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"},
packages=find_packages(),
install_requires=install_requires,
dependency_links=dependency_links,
extras_require={
"gptq": [
"alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
"flash-attn": [
"flash-attn>=2.3.0",
],
"gptq_triton": [
"alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
],
"extras": [
"flash-attn",
"deepspeed": [
"deepspeed",
],
},
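With flash-attn and deepspeed filtered out of install_requires, they become opt-in extras; a sketch assuming an editable checkout:

```shell
# base install, without the heavy extras
pip install -e .

# opt into the extras declared in setup.py
pip install -e '.[flash-attn,deepspeed]'
```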

src/axolotl/cli/__init__.py Normal file
View File

@@ -0,0 +1,358 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging
import os
import random
import sys
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
import gradio as gr
import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
def print_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(ascii_text, font=font)
if is_main_process():
print(ascii_art)
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload()
model.to(dtype=torch.float16)
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
for token, symbol in default_tokens.items():
# If the token isn't already specified in the config, add it
if not (cfg.special_tokens and token in cfg.special_tokens):
tokenizer.add_special_tokens({token: symbol})
prompter_module = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
if cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
set_model_mem_id(model, tokenizer)
model.set_mem_cache_args(
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
)
model = model.to(cfg.device)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(show_api=False, share=True)
def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = yaml_files[choice - 1]
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
validate_config(cfg)
normalize_config(cfg)
setup_wandb_env_vars(cfg)
return cfg
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg, tokenizer
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def check_accelerate_default_config():
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token():
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
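Because load_cfg merges CLI kwargs over the YAML (and accepts unknown keys when `strict` is false), one-off overrides need no config edits; a hypothetical example:

```shell
# micro_batch_size from the YAML is overridden for this run only
python -m axolotl.cli.train examples/config.yml --micro_batch_size=2
```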

View File

@@ -0,0 +1,36 @@
"""
CLI to run inference on a trained model
"""
from pathlib import Path
import fire
import transformers
from axolotl.cli import (
do_inference,
do_inference_gradio,
load_cfg,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.inference = True
if gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(do_cli)
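The gradio flag selects between the terminal loop and the web UI; hypothetical invocations (config path assumed):

```shell
# terminal REPL; Ctrl+D submits each instruction
python -m axolotl.cli.inference examples/config.yml

# same model behind a shareable gradio interface
python -m axolotl.cli.inference examples/config.yml --gradio
```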

View File

@@ -0,0 +1,27 @@
"""
CLI to run merge a trained LoRA into a base model
"""
from pathlib import Path
import fire
import transformers
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(do_cli)
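A hypothetical merge run; lora_model_dir points at the trained adapter (paths assumed):

```shell
# writes the merged fp16 model to <output_dir>/merged
python -m axolotl.cli.merge_lora examples/config.yml --lora_model_dir=./qlora-out
```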

View File

@@ -0,0 +1,53 @@
"""
CLI to run preprocessing of a dataset
"""
import logging
from pathlib import Path
import fire
import transformers
from colorama import Fore
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, "
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+ Fore.RESET
)
if __name__ == "__main__":
fire.Fire(do_cli)
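When dataset_prepared_path is unset, preprocess warns and falls back to last_run_prepared; a hypothetical run pinning the path explicitly:

```shell
python -m axolotl.cli.preprocess examples/config.yml --dataset_prepared_path=./prepared
```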

src/axolotl/cli/shard.py Normal file
View File

@@ -0,0 +1,42 @@
"""
CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
import fire
import transformers
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.scripts")
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.shard = True
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
fire.Fire(do_cli)
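A hypothetical shard run (config path assumed); the model is re-saved to cfg.output_dir with sharding applied:

```shell
python -m axolotl.cli.shard examples/config.yml
```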

src/axolotl/cli/train.py Normal file
View File

@@ -0,0 +1,38 @@
"""
CLI to run training on a model
"""
import logging
from pathlib import Path
import fire
import transformers
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":
fire.Fire(do_cli)

src/axolotl/common/cli.py Normal file
View File

@@ -0,0 +1,54 @@
"""
shared module for cli specific things
"""
import logging
from dataclasses import dataclass, field
from typing import Optional
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=5)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
return model, tokenizer

View File

@@ -0,0 +1,5 @@
"""
Various shared constants
"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

src/axolotl/core/trainer_builder.py Normal file
View File

@@ -0,0 +1,748 @@
"""
Builder for the training args and trainer
"""
import abc
import importlib
import logging
import math
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional
import torch
import transformers
from datasets import Dataset
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_utils import seed_worker
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass
LOG = logging.getLogger("axolotl.core.trainer_builder")
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing:
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
drop_last=True,
batch_max_len=self._train_batch_size * self.args.max_seq_length,
lengths=(
self.train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
self.args.per_device_eval_batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
lengths=(
eval_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing:
train_dataset = self.train_dataset
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> DataLoader:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
"""
_train_dataset = None
_eval_dataset = None
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@abstractmethod
def build(self, total_num_steps):
pass
@abstractmethod
def get_callbacks(self):
pass
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
"""
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
# TODO
return trainer
def get_callbacks(self):
callbacks = []
callbacks.append(GPUStatsCallback(self.cfg))
callbacks.append(EvalFirstStepCallback)
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
return callbacks
def _get_trainer_cls(self):
if self.cfg.lr_scheduler == "one_cycle" and (
self.cfg.fsdp or self.cfg.adapter == "qlora"
):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
return AxolotlTrainer
def build(self, total_num_steps):
warmup_steps = (
self.cfg.warmup_steps
if self.cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
logging_steps = (
self.cfg.logging_steps
if self.cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_arguments_kwargs = {}
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = self.cfg.bf16
training_arguments_kwargs["fp16"] = (
self.cfg.fp16 and not self.cfg.bf16
) or False
training_arguments_kwargs["tf32"] = self.cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if self.cfg.seed:
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_arguments_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[
"lr_quadratic_warmup"
] = self.cfg.lr_quadratic_warmup
if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
if self.cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
if self.cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
if self.cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
if self.cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
if self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
elif self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.evaluation_strategy:
training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if self.cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_arguments_kwargs["save_strategy"] = "epoch"
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
] = self.cfg.metric_for_best_model
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
if self.cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
elif torch._dynamo: # pylint: disable=protected-access
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
# DDP Config
if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs[
"ddp_broadcast_buffers"
] = self.cfg.ddp_broadcast_buffers
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs[
"eval_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
training_arguments_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
training_arguments_kwargs["load_best_model_at_end"] = (
(
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and self.cfg.val_set_size > 0
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
training_arguments_kwargs["ddp_find_unused_parameters"] = (
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_run_id if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine"
)
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
trainer_kwargs = {}
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
self.cfg.sequence_len / 64
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(self.model, self.tokenizer)
LOG.info("Adding landmark attention tokens to dataset")
for dataset in [self.train_dataset, self.eval_dataset]:
dataset = dataset.map(
partial(
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
),
batched=False,
num_proc=32,
)
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size
return trainer
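Several of the new AxolotlTrainingArguments fields map directly onto config keys read in build(); a hypothetical launch overriding the dataloader knobs from the command line (assuming `strict` is not set in the YAML):

```shell
python -m axolotl.cli.train examples/config.yml \
  --dataloader_num_workers=4 --dataloader_pin_memory=True --dataloader_prefetch_factor=2
```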

View File

@@ -1,12 +1,13 @@
"""Module containing Dataset functionality"""
import logging
from typing import List
import os
from typing import List, Optional
import torch
from datasets import IterableDataset
from datasets import Dataset, IterableDataset
from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
@@ -14,12 +15,14 @@ from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = logging.getLogger("axolotl")
class TokenizedPromptDataset(IterableDataset):
class TokenizedPromptDataset(Dataset):
"""
Iterable dataset that returns tokenized prompts from a stream of text files.
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
"""
@@ -27,22 +30,30 @@ class TokenizedPromptDataset(IterableDataset):
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: IterableDataset,
process_count: Optional[int] = None,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
self.dataset = dataset
self.process_count = process_count
super().__init__(self.process(dataset).data, **kwargs)
def __iter__(self):
iterator = iter(self.dataset)
count = 0
# Loop through the entire dataset
for example in iterator:
try:
yield self.prompt_tokenizer.tokenize_prompt(example)
count += 1
except InvalidDataException:
pass
if count == 0:
raise RuntimeError("Expected at least one datapoint in dataset.")
def process(self, dataset):
features = dataset.features.keys()
num_proc = (
min(64, self.process_count)
if self.process_count
else min(64, os.cpu_count())
)
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
**map_kwargs,
)
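# Example: with process_count=None on a 16-CPU machine, num_proc resolves to
# min(64, 16) == 16; a strategy with supports_batched=True maps in chunks of 100.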
# TODO this isn't the best since it can't interleave datasets
@@ -50,7 +61,7 @@ class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for proccessing the data.
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""
@@ -76,14 +87,21 @@ class ConstantLengthDataset(IterableDataset):
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
@@ -105,6 +123,9 @@ class ConstantLengthDataset(IterableDataset):
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
@@ -113,29 +134,28 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
logging.warning(
LOG.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if (
buffer["input_ids"]
and input_ids[0] == self.tokenizer.bos_token_id
):
attention_mask[0] = 0
if add_concat_token:
input_ids.append(self.concat_token_id)
@@ -146,13 +166,17 @@ class ConstantLengthDataset(IterableDataset):
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
attention_mask, dtype=self.tokens_dtype
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)
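To make the new position_ids bookkeeping concrete, a small standalone sketch (assuming only torch) of what the buffer holds after packing two short examples, each of which restarts its positions at 0:

import torch

# two packed examples of lengths 3 and 2
buffer_position_ids = [torch.arange(3), torch.arange(2)]
packed = torch.cat(buffer_position_ids, dim=-1)
print(packed)  # tensor([0, 1, 2, 0, 1])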

View File

@@ -1,140 +0,0 @@
"""Flash attention monkey patch for llama model"""
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
from typing import Optional, Tuple
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None,
None,
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

View File

@@ -0,0 +1,71 @@
"""
Common logging module for axolotl
"""
import os
import sys
from logging import Formatter
from logging.config import dictConfig
from typing import Any, Dict
from colorama import Fore, Style, init
class ColorfulFormatter(Formatter):
"""
Formatter to add coloring to log messages by log type
"""
COLORS = {
"WARNING": Fore.YELLOW,
"ERROR": Fore.RED,
"CRITICAL": Fore.RED + Style.BRIGHT,
}
def format(self, record):
record.rank = int(os.getenv("LOCAL_RANK", "0"))
log_message = super().format(record)
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
"version": 1,
"formatters": {
"simple": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
},
"colorful": {
"()": ColorfulFormatter,
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
},
},
"filters": {},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"formatter": "simple",
"filters": [],
"stream": sys.stdout,
},
"color_console": {
"class": "logging.StreamHandler",
"formatter": "colorful",
"filters": [],
"stream": sys.stdout,
},
},
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"level": "DEBUG",
"propagate": False,
},
},
}
def configure_logging():
"""Configure with default logging"""
init() # Initialize colorama
dictConfig(DEFAULT_LOGGING_CONFIG)
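A minimal usage sketch of this module (the import path is assumed, not shown in the diff):

import logging
from axolotl.logging_config import configure_logging  # module path assumed

configure_logging()
LOG = logging.getLogger("axolotl")
LOG.warning("rendered in yellow with [RANK:n] via the colorful formatter")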

View File

@@ -0,0 +1,6 @@
"""
MixFormers model architecture used for phi models
"""
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa

View File

@@ -0,0 +1,63 @@
# pylint: skip-file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Any, Dict, List, Optional, Union
from transformers import PretrainedConfig
class MixFormerSequentialConfig(PretrainedConfig):
"""MixFormer (sequential for DeepSpeed) configuration."""
model_type = "mixformer-sequential"
attribute_map = {
"max_position_embeddings": "n_positions",
"hidden_size": "n_embd",
"num_attention_heads": "n_head",
"num_hidden_layers": "n_layer",
"input_emb_layer": "embd_layer", # `input_emb_layer` key is for backward compatibility
"blocks": "architecture", # `blocks` key is for backward compatibility
}
def __init__(
self,
vocab_size: Optional[int] = 50304,
n_positions: Optional[int] = 2048,
n_embd: Optional[int] = 1024,
n_layer: Optional[int] = 20,
n_inner: Optional[int] = None,
n_head: Optional[int] = 16,
rotary_dim: Optional[int] = 32,
activation_function: Optional[str] = "gelu_new",
embd_layer: Optional[str] = "default",
architecture: Union[Dict[str, Any], List[Dict[str, Any]]] = None,
embd_pdrop: Optional[float] = 0.0,
resid_pdrop: Optional[float] = 0.0,
layer_norm_epsilon: Optional[float] = 1e-5,
initializer_range: Optional[float] = 0.02,
tie_word_embeddings: Optional[bool] = False,
pad_vocab_size_multiple: Optional[int] = 64,
**kwargs
) -> None:
self.vocab_size = int(
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
)
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = min(rotary_dim, n_embd // n_head)
self.activation_function = activation_function
self.embd_layer = embd_layer
self.architecture = architecture
self.embd_pdrop = embd_pdrop
self.resid_pdrop = resid_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
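Note how the constructor pads the vocabulary to a multiple of pad_vocab_size_multiple; a quick sketch of that rounding:

cfg = MixFormerSequentialConfig(vocab_size=50257)
print(cfg.vocab_size)  # 50304 == 64 * ceil(50257 / 64), padded for kernel efficiency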

View File

@@ -0,0 +1,930 @@
# pylint: skip-file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import annotations
import copy
import inspect
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.flash_attn_interface import (
flash_attn_kvpacked_func,
flash_attn_qkvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids
from .configuration_mixformer_sequential import MixFormerSequentialConfig
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficiently calculate and store the context during inference.
Adapted from https://github.com/Dao-AILab/flash-attention."""
max_sequence_len: int
max_batch_size: int
sequence_len_offset: int = 0
batch_size_offset: int = 0
key_value_memory_dict: dict = field(default_factory=dict)
fused_ft_kernel: bool = False
lengths_per_sample: Optional[torch.Tensor] = None
class Embedding(nn.Module):
"""Token embedding with dropout."""
def __init__(self, config: PretrainedConfig) -> None:
super().__init__()
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.wte(input_ids)
hidden_states = self.drop(hidden_states)
return hidden_states
class RotaryEmbedding(nn.Module):
"""PyTorch implementation of `flash-attn` RotaryEmbedding layer.
Adapted from https://github.com/Dao-AILab/flash-attention."""
def __init__(
self,
dim: int,
base: Optional[int] = 10000,
scale_base: Optional[float] = None,
device: Optional[str] = None,
**kwargs,
) -> None:
super().__init__()
if scale_base is not None:
raise NotImplementedError
# Generate and save the inverse frequency buffer (non-trainable)
self.dim = dim
self.base = base
self.scale_base = scale_base
self.device = device
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
self.register_buffer("inv_freq", inv_freq)
scale = (
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
/ (1.4 * dim)
if scale_base is not None
else None
)
self.register_buffer("scale", scale)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
self._cos_k_cached = None
self._sin_k_cached = None
def _update_cos_sin_cache(
self, x: torch.FloatTensor, seqlen_offset: Optional[int] = 0
) -> None:
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
seqlen = x.shape[1] + seqlen_offset
# Re-generate the inverse frequency buffer if it's not fp32
# (for instance if model.half() was called)
if self.inv_freq.dtype != torch.float32:  # compare against the dtype itself, not a string
self.inv_freq = 1.0 / (
self.base
** (
torch.arange(
0, self.dim, 2, device=self.device, dtype=torch.float32
)
/ self.dim
)
)
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != x.device
or self._cos_cached.dtype != x.dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=x.device, dtype=torch.float32)
# Don't do einsum, it converts fp32 to fp16
# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.outer(
t, self.inv_freq.to(device=t.device, dtype=torch.float32)
)
if self.scale is None:
self._cos_cached = torch.cos(freqs).to(x.dtype)
self._sin_cached = torch.sin(freqs).to(x.dtype)
else:
power = (
torch.arange(
seqlen, dtype=self.scale.dtype, device=self.scale.device
)
- seqlen // 2
) / self.scale_base
scale = self.scale.to(device=power.device) ** rearrange(
power, "s -> s 1"
)
# We want the multiplication by scale to happen in fp32
self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
def apply_rotary_emb_qkv(
self,
qkv: torch.FloatTensor,
sin: torch.FloatTensor,
cos: torch.FloatTensor,
sin_k: Optional[torch.FloatTensor] = None,
cos_k: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
_, seqlen, three, _, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert (
sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
)
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
# Splits the queries and keys in half
q1, q2 = q_rot.chunk(2, dim=-1)
k1, k2 = k_rot.chunk(2, dim=-1)
c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(
sin[:seqlen], "s d -> s 1 d"
)
# Casts to fp32 are necessary to prevent fp16 overflow issues
q1, q2, k1, k2, c, s = [
t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]
]
# Computes the new keys and queries, recasting to original dtype
q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
return torch.cat(
[
torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
qkv[:, :, 2:3, :, :],
],
axis=2,
)
def forward(
self, qkv: torch.Tensor, seqlen_offset: int = 0
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform the forward pass.
Args:
qkv: Query, key and value tensors of shape (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim).
seqlen_offset: Used in generation where the passed `qkv` is only the last token in the batch.
Returns:
New `qkv` and the cached sinusoids.
"""
self._update_cos_sin_cache(qkv, seqlen_offset)
return self.apply_rotary_emb_qkv(
qkv, self._sin_cached[seqlen_offset:], self._cos_cached[seqlen_offset:]
)
def _update_kv_cache(kv, inference_params, layer_idx):
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
Adapted from https://github.com/Dao-AILab/flash-attention."""
# Pre-allocate memory for key-values for inference.
num_heads, head_dim = kv.shape[-2:]
if layer_idx not in inference_params.key_value_memory_dict:
kv_cache = torch.empty(
inference_params.max_batch_size,
inference_params.max_sequence_len,
2,
num_heads,
head_dim,
dtype=kv.dtype,
device=kv.device,
)
inference_params.key_value_memory_dict[layer_idx] = kv_cache
else:
kv_cache = inference_params.key_value_memory_dict[layer_idx]
# Adjust key and value for inference
batch_start = inference_params.batch_size_offset
batch_end = batch_start + kv.shape[0]
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + kv.shape[1]
assert batch_end <= (
kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0] # noqa
)
assert sequence_end <= (
kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2] # noqa
)
assert kv_cache is not None
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
return kv
class MLP(nn.Module):
"""Multi-Layer Perceptron.
Reference:
Attention Is All You Need.
https://arxiv.org/pdf/1706.03762.pdf.
"""
def __init__(
self,
config: PretrainedConfig,
n_inner: Optional[int] = None,
act_fn: Optional[str] = None,
) -> None:
super().__init__()
act_fn = config.activation_function if act_fn is None else act_fn
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
self.fc1 = nn.Linear(config.n_embd, n_inner)
self.fc2 = nn.Linear(n_inner, config.n_embd)
self.act = ACT2FN[act_fn]
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
old_keys = [
prefix + "fc_in.weight",
prefix + "fc_out.weight",
prefix + "fc_in.bias",
prefix + "fc_out.bias",
]
new_keys = [
prefix + "fc1.weight",
prefix + "fc2.weight",
prefix + "fc1.bias",
prefix + "fc2.bias",
]
if all(k in state_dict for k in old_keys) and not all(
k in state_dict for k in new_keys
):
# Older version of `MLP` saved with different key names.
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
return super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class FusedMLP(nn.Module):
"""Fused Multi-Layer Perceptron from `flash-attn`.
Reference:
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/ops/fused_dense.py.
"""
def __init__(
self,
config: PretrainedConfig,
n_inner: Optional[int] = None,
act_fn: Optional[str] = None,
raise_on_missing: bool = False,
) -> None:
super().__init__()
act_fn = config.activation_function if act_fn is None else act_fn
assert act_fn in ACT2FN.keys(), f"`act_fn` must be one of: {ACT2FN.keys()}."
n_inner = getattr(config, "n_inner", None) if n_inner is None else n_inner
n_inner = n_inner if n_inner is not None else 4 * config.n_embd
gelu_activations = ["gelu_new", "gelu_fast", "gelu_approx"] # noqa
activation = "gelu_approx" if act_fn in gelu_activations else "relu" # noqa
self.mlp = MLP(config, n_inner=n_inner, act_fn=act_fn)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
return self.mlp(hidden_states)
class SelfAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Adapted from https://github.com/Dao-AILab/flash-attention.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(
self, qkv, causal=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None
):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, S)
"""
causal = self.causal if causal is None else causal
if cu_seqlens is not None:
return flash_attn_varlen_qkvpacked_func(
qkv.squeeze(0),
cu_seqlens,
max_seqlen,
dropout_p=self.drop.p,
softmax_scale=self.softmax_scale,
causal=causal,
)
else:
return flash_attn_qkvpacked_func(
qkv,
dropout_p=self.drop.p,
softmax_scale=self.softmax_scale,
causal=causal,
)
class CrossAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Adapted from https://github.com/Dao-AILab/flash-attention.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
super().__init__()
self.causal = causal
self.softmax_scale = softmax_scale
self.drop = nn.Dropout(attention_dropout)
def forward(self, q, kv, causal=None, key_padding_mask=None):
"""Implements the multihead softmax attention.
Arguments
---------
q: The tensor containing the query. (B, Sq, H, D)
kv: The tensor containing the key and value. (B, Sk, 2, H, D)
causal: if passed, will override self.causal
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
False means to mask out. (B, Sk)
"""
causal = self.causal if causal is None else causal
return flash_attn_kvpacked_func(
q,
kv,
dropout_p=self.drop.p,
softmax_scale=self.softmax_scale,
causal=causal,
)
def find_mha_dims(
config: PretrainedConfig,
n_head: Optional[int] = None,
head_dim: Optional[int] = None,
) -> Tuple[int, int]:
"""Validate and return the number of heads and head dimension for multi-head attention.
Args:
config: Model configuration.
n_head: Number of heads.
head_dim: Head dimension.
Returns:
Number of heads and head dimension.
"""
assert all(
hasattr(config, attr) for attr in ["n_embd", "n_head"]
), "`config` must have `n_embd` and `n_head` attributes."
if head_dim is None:
assert (
config.n_embd % config.n_head == 0
), f"Hidden size ({config.n_embd}) must be divisible by the number of heads ({config.n_head})."
if n_head is None and head_dim is None:
head_dim = config.n_embd // config.n_head
n_head = config.n_head
elif n_head is None or head_dim is None:
raise ValueError("`n_head` and `head_dim` must be both specified or `None`.")
return n_head, head_dim
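# e.g. a config with n_embd=1024, n_head=16 and no overrides -> (n_head=16, head_dim=64)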
class MHA(nn.Module):
"""Multi-head attention layer.
Adapted from https://github.com/Dao-AILab/flash-attention."""
def __init__(
self,
config: PretrainedConfig,
rotary_dim: Optional[int] = None,
n_head: Optional[int] = None,
head_dim: Optional[int] = None,
bias: Optional[bool] = True,
dropout: Optional[float] = 0.0,
softmax_scale: Optional[float] = None,
causal: Optional[bool] = True,
layer_idx: Optional[int] = None,
rotary_emb_scale_base: Optional[float] = None,
return_residual: Optional[bool] = False,
checkpointing: Optional[bool] = False,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
fused_dense: Optional[bool] = True,
flash_attn: Optional[bool] = True,
cutlass_attn: Optional[bool] = False,
flash_rotary: Optional[bool] = True,
raise_on_missing: Optional[bool] = False,
) -> None:
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
n_head, head_dim = find_mha_dims(config, n_head, head_dim)
self.hidden_size = config.n_embd
self.n_head = n_head
self.head_dim = head_dim
self.op_size = n_head * head_dim
self.causal = causal
self.layer_idx = layer_idx
self.rotary_emb_dim = (
rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
)
self.fused_dense = fused_dense
self.flash_attn = flash_attn
self.cutlass_attn = cutlass_attn
self.flash_rotary = flash_rotary
self.return_residual = return_residual
self.checkpointing = checkpointing
if self.rotary_emb_dim > 0:
rotary_kwargs = {"device": device}
if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
rotary_kwargs["scale_base"] = rotary_emb_scale_base
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
else:
pass
self.Wqkv = nn.Linear(
self.hidden_size, 3 * self.op_size, bias=bias, **factory_kwargs
)
self.out_proj = nn.Linear(
self.op_size, self.hidden_size, bias=bias, **factory_kwargs
)
self.inner_attn = SelfAttention(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
self.inner_cross_attn = CrossAttention(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
)
def _update_kv_cache(
self, kv: torch.FloatTensor, inference_params: InferenceParams
) -> None:
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
Adapted from https://github.com/Dao-AILab/flash-attention."""
assert (
self.layer_idx is not None
), "Generation requires layer_idx in the constructor"
return _update_kv_cache(kv, inference_params, self.layer_idx)
def forward(
self,
x: torch.FloatTensor,
x_kv: Optional[torch.FloatTensor] = None,
key_padding_mask: Optional[torch.BoolTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
mixer_subset: Optional[torch.LongTensor] = None,
past_cache: Optional[InferenceParams] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Perform the forward pass.
Args:
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
is the sum of the sequence lengths in the batch.
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
key_padding_mask: boolean mask, True means to keep, False means to mask out.
(batch, seqlen). Only applicable when not using FlashAttention.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into x. Only applicable when using
FlashAttention.
max_seqlen: int. Maximum sequence length in the batch.
mixer_subset: for cross-attention only. If not None, will take a subset of x
before applying the query projection. Useful for e.g., ViT where we only care
about the CLS token in the last layer.
past_cache: For generation only.
Returns:
(batch, seqlen, hidden_dim) if cu_seqlens is None and max_seqlen is None,
else (total, hidden_dim) where total is the sum of the sequence lengths
in the batch.
"""
if cu_seqlens is not None:
assert max_seqlen is not None
assert key_padding_mask is None
assert self.flash_attn
# assert self.rotary_emb_dim == 0
if key_padding_mask is not None:
assert cu_seqlens is None
assert max_seqlen is None
assert not self.flash_attn
if past_cache is not None:
assert key_padding_mask is None
assert cu_seqlens is None and max_seqlen is None
attn_kwargs = {"key_padding_mask": key_padding_mask}
assert x_kv is None and mixer_subset is None
qkv = self.Wqkv(x)
qkv = rearrange(
qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
)
if past_cache is None:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv)
context = self.inner_attn(
qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, **attn_kwargs
)
else:
if self.rotary_emb_dim > 0:
qkv = self.rotary_emb(qkv, seqlen_offset=past_cache.sequence_len_offset)
q = qkv[:, :, 0]
kv = self._update_kv_cache(qkv[:, :, 1:], past_cache)
# If we're processing the prompt, causal=None (use self.causal).
# If we're decoding, then causal=False.
causal = None if past_cache.sequence_len_offset == 0 else False
context = self.inner_cross_attn(q, kv, causal=causal)
out = rearrange(context, "... h d -> ... (h d)")
out = self.out_proj(out)
return out if not self.return_residual else (out, x)
class ParallelBlock(nn.Module):
"""Parallel block.
This block applies parallel mixer and MLP layers to the input (used in GPT-J and CodeGen).
"""
def __init__(
self,
config: PretrainedConfig,
mixer: Optional[Dict[str, Any]] = None,
mlp: Optional[Dict[str, Any]] = None,
block_idx: Optional[int] = None,
) -> None:
super().__init__()
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.block_idx = block_idx
self.mixer = MHA(config, layer_idx=block_idx)
self.mlp = MLP(config)
def forward(
self,
hidden_states: torch.FloatTensor,
past_cache: Optional[torch.FloatTensor] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
) -> torch.FloatTensor:
residual = hidden_states
hidden_states = self.ln(hidden_states)
attn_outputs = self.mixer(
hidden_states,
past_cache=past_cache,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if isinstance(attn_outputs, tuple):
attn_outputs = attn_outputs[0]
attn_outputs = self.resid_dropout(attn_outputs)
feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
hidden_states = attn_outputs + feed_forward_hidden_states + residual
return hidden_states
class CausalLMHead(nn.Module):
"""Causal Language Modeling head.
Reference:
Improving Language Understanding by Generative Pre-Training.
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
"""
def __init__(self, config: PretrainedConfig) -> None:
super().__init__()
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.linear = nn.Linear(config.n_embd, config.vocab_size)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = self.ln(hidden_states)
logits = self.linear(hidden_states).to(torch.float32)
return logits
class CausalLMLoss(nn.Module):
"""Causal Language Modeling loss.
Reference:
Improving Language Understanding by Generative Pre-Training.
https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf.
"""
def __init__(self, shift_labels: Optional[bool] = True) -> None:
super().__init__()
self.shift_labels = shift_labels
self.loss_fct = nn.CrossEntropyLoss()
def forward(
self, logits: torch.FloatTensor, labels: torch.LongTensor
) -> torch.FloatTensor:
if self.shift_labels:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
loss = self.loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
return loss
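# shift_labels=True implements the standard causal LM shift: the logit at
# position t-1 (logits[..., :-1, :]) is scored against the token at position t
# (labels[..., 1:]), so each token is predicted from its prefix only.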
class MixFormerSequentialPreTrainedModel(PreTrainedModel):
"""MixFormer (sequential for DeepSpeed) pre-trained model."""
config_class = MixFormerSequentialConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
def __init__(self, *inputs, **kwargs) -> None:
super().__init__(*inputs, **kwargs)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, **kwargs
) -> Dict[str, Any]:
if "use_cache" in kwargs and not kwargs["use_cache"]:
return {"input_ids": input_ids}
if past_key_values is None or not (
isinstance(past_key_values, InferenceParams)
):
past_key_values = InferenceParams(
max_batch_size=input_ids.shape[0],
max_sequence_len=self.config.n_positions,
sequence_len_offset=0,
batch_size_offset=0,
fused_ft_kernel=False,
key_value_memory_dict={},
)
else:
# assume past_key_values has cached all but last token in input_ids
past_key_values.sequence_len_offset = len(input_ids[0]) - 1
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past_key_values": past_key_values, **kwargs}
class PackedSequential(nn.Sequential):
def forward(
self,
input,
cu_seqlens: Optional[torch.LongTensor] = None,
max_seqlen: Optional[int] = None,
):
for module in self:
sig = inspect.signature(module.forward)
if "cu_seqlens" in sig.parameters:
input = module(input, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
else:
input = module(input)
return input
class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):
"""MixFormer (sequential for DeepSpeed) for Causal Language Modeling."""
_keys_to_ignore_on_load_missing = [""]
_keys_to_ignore_on_load_unexpected = [
r"layers\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"
]
_no_split_modules = ["ParallelBlock"]
def __init__(self, config: MixFormerSequentialConfig) -> None:
super().__init__(config)
modules = [Embedding(config)]
block_config = config.architecture
if not isinstance(block_config, list):
block_config = [block_config for _ in range(config.n_layer)]
if config.n_layer != len(block_config):
config.n_layer = len(block_config)
for block_idx, block in enumerate(block_config):
# `block_cls` with `legacy` value is for backward compatibility
# `path` key is for backward compatibility
block = copy.deepcopy(block) or {"block_cls": "parallel"}
block.pop("path", None) or block.pop("block_cls", None)
block["block_idx"] = block_idx
modules.append(ParallelBlock(config, **block))
modules.append(CausalLMHead(config))
self.layers = PackedSequential(*modules)
self.loss = CausalLMLoss()
self.post_init()
def get_input_embeddings(self) -> nn.Embedding:
return self.layers[0].wte
def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
self.layers[0].wte = new_embeddings
def get_output_embeddings(self) -> nn.Linear:
return self.layers[-1].linear
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.layers[-1].linear = new_embeddings
def forward(
self,
input_ids: torch.LongTensor,
labels: Optional[torch.LongTensor] = None,
past_key_values: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
cu_seqlens: Optional[torch.LongTensor] = None
max_seqlen: Optional[int] = None
if position_ids is not None:
batch_size, seq_length = input_ids.shape
position_ids = position_ids.view(-1, seq_length).long()
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if not past_key_values:
lm_logits = self.layers(
input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
)
else:
hidden_layer = self.layers[0](input_ids)
for module in self.layers[1:-1]:
hidden_layer = module(
hidden_layer,
past_cache=past_key_values,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
lm_logits = self.layers[-1](hidden_layer)
loss = None
if labels is not None:
loss = self.loss(lm_logits, labels)
return CausalLMOutputWithPast(
loss=loss, logits=lm_logits, past_key_values=past_key_values
)
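The packed-sequence path above hinges on recovering sequence boundaries from position_ids. Conceptually (a hand-worked sketch, not the library's get_cu_seqlens_from_pos_ids itself): positions restart at 0 for each packed sample, so the restarts mark the cumulative sequence lengths:

import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])  # three packed samples
starts = (position_ids[0] == 0).nonzero().flatten()          # tensor([0, 3, 5])
cu_seqlens = torch.cat([starts, torch.tensor([position_ids.size(1)])]).to(torch.int32)
print(cu_seqlens)                                # tensor([0, 3, 5, 9], dtype=torch.int32)
print((cu_seqlens[1:] - cu_seqlens[:-1]).max())  # max_seqlen == 4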

View File

@@ -0,0 +1,66 @@
"""
Flash attention monkey patch for cerebras btlm model
"""
import importlib
import logging
from typing import Optional, Tuple
import torch
from accelerate import init_empty_weights
from flash_attn.flash_attn_interface import flash_attn_func
from transformers import AutoConfig, AutoModelForCausalLM
LOG = logging.getLogger("axolotl")
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
# this is a wonky hack to get the remotely loaded module
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
# we need to load the model here in order for modeling_btlm to be available
with init_empty_weights():
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
module_name = model_config.__class__.__module__.replace(
".configuration_btlm", ".modeling_btlm"
)
modeling_btlm = importlib.import_module(module_name)
modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access
flashattn_attn
)
def flashattn_attn(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
head_mask: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
softmax_scale = (
1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
# Perform Flash attention
attn_output = flash_attn_func(
query,
key,
value,
dropout_p=0.0,  # no attention dropout is applied in this patched path
softmax_scale=softmax_scale,  # 1 / d_k**attn_scale_power when scale_attn_weights is set, else None
causal=not self.is_cross_attention,  # causal mask only for self-attention
return_attn_probs=False,  # flash attention does not materialize attention weights
)
# Optional: Apply head mask if it's not None
if head_mask is not None:
attn_output *= head_mask
attn_output = attn_output.permute(0, 2, 1, 3)
return attn_output, None # We don't have explicit attn_weights in Flash attention
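A usage sketch: the patch swaps BTLMAttention._attn on the remotely loaded class, so models constructed from it afterwards run flash attention (model name as in the default argument):

from transformers import AutoModelForCausalLM

replace_btlm_attn_with_flash_attn("cerebras/btlm-3b-8k-base")
model = AutoModelForCausalLM.from_pretrained(
    "cerebras/btlm-3b-8k-base", trust_remote_code=True
)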

View File

@@ -0,0 +1,174 @@
"""
monkeypatch to add a get_turns method
"""
import logging
from typing import Generator, Tuple
from fastchat.conversation import SeparatorStyle
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
def get_prompt(self) -> str:
ret = ""
for role, msg in self.get_turns():
ret += role + msg
return ret
def get_turns( # pylint: disable=too-many-return-statements
self,
) -> Generator[Tuple[str, str], None, None]:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message + seps[i % 2]
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ": ", ""  # must end with a space
return
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
yield "", "" if system_prompt == "" else system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role, message + self.sep
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role, message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.RWKV:
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
yield role + ": ", message.replace("\r\n", "\n").replace(
"\n\n", "\n"
) + "\n\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.LLAMA2:
seps = [self.sep, self.sep2]
if self.system_message:
yield "", system_prompt
else:
yield "", "[INST] "
for i, (role, message) in enumerate(self.messages[1:]):
if message:
yield role + " ", message + seps[i % 2]
else:
yield role, ""
return
if self.sep_style == SeparatorStyle.CHATGLM:
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
round_add_n = 1 if self.name == "chatglm2" else 0
if system_prompt:
yield "", system_prompt + self.sep
for i, (role, message) in enumerate(self.messages):
if i % 2 == 0:
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
if message:
yield f"{role}", f"{message}{self.sep}"
else:
yield f"{role}", ""
return
if self.sep_style == SeparatorStyle.CHATML:
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
for role, message in self.messages:
if message:
yield role + "\n", message + self.sep + "\n"
else:
yield role + "\n", ""
return
if self.sep_style == SeparatorStyle.CHATINTERN:
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
prefix = "<s>" if i % 2 == 0 else ""
if message:
yield prefix + role + ":", message + seps[i % 2] + "\n"
else:
yield role + ":", ""
return
if self.sep_style == SeparatorStyle.DOLLY:
seps = [self.sep, self.sep2]
yield "", system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
suffix = "\n\n" if i % 2 == 1 else ""
yield role + ":\n", message + seps[i % 2] + suffix
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.PHOENIX:
yield "", system_prompt
for role, message in self.messages:
if message:
yield role + ": ", "<s>" + message + "</s>"
else:
yield role + ": " + "<s>", ""
return
if self.sep_style == SeparatorStyle.ROBIN:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ":\n", message + self.sep
else:
yield role + ":\n", ""
return
if self.sep_style == SeparatorStyle.FALCON_CHAT:
if self.system_message:
yield "", system_prompt + self.sep
for role, message in self.messages:
if message:
yield role + ": ", message + self.sep
else:
yield role + ":", ""
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def add_get_turns_to_conversation():
import fastchat.conversation
fastchat.conversation.Conversation.get_turns = get_turns
fastchat.conversation.Conversation.get_prompt = get_prompt
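A sketch of the patched API in use; the import path is inferred from the logger name above, and the template name is just an example:

from fastchat.conversation import get_conv_template
from axolotl.monkeypatch.fastchat_conversation_turns import (
    add_get_turns_to_conversation,  # path inferred from the module's logger name
)

add_get_turns_to_conversation()
conv = get_conv_template("vicuna_v1.1")  # template name assumed
conv.append_message(conv.roles[0], "Hello!")
conv.append_message(conv.roles[1], None)  # None marks the turn the model should fill
for role, message in conv.get_turns():
    print(repr(role), repr(message))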

View File

@@ -0,0 +1,782 @@
"""Flash attention monkey patch for llama model"""
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import logging
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
from transformers.models.llama.modeling_llama import (
LlamaMLP,
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
)
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
)
LOG = logging.getLogger("axolotl")
def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
module.config, module.gate_proj, module.up_proj, module.down_proj
)
set_module_name(model, name, mlp)
def replace_llama_qkv_with_fused(model):
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
qkv = FusedAttention(
module.config,
module.q_proj,
module.k_proj,
module.v_proj,
module.o_proj,
)
set_module_name(model, name, qkv)
def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False,
):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = (
llama_model_forward
)
# skip only if explicitly disabled
if cross_entropy:
try:
from flash_attn.losses.cross_entropy import CrossEntropyLoss
LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
except ImportError:
LOG.info(
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
)
# skip only if explicitly disabled
if rms_norm:
try:
from flash_attn.ops.rms_norm import RMSNorm
class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""
def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)
LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.info(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
class FusedAttention(LlamaAttention):
"""
Fused QKV Attention layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
q: torch.nn.Linear, # pylint: disable=invalid-name
k: torch.nn.Linear, # pylint: disable=invalid-name
v: torch.nn.Linear, # pylint: disable=invalid-name
o: torch.nn.Linear, # pylint: disable=invalid-name
):
super().__init__(config)
self.config = config
self.init_device = next(iter(q.state_dict().values())).device
# define equivalent fused qkv projection
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
self.qkv_proj = torch.nn.Linear(
q.in_features, sum(self.out_features), device=self.init_device, bias=False
)
self.o_proj = o
# overwrite initialized weights with pretrained weights
self.qkv_proj.weight.data = torch.cat(
(q.weight.data, k.weight.data, v.weight.data), dim=0
)
def _post_training(self, model, name):
q_proj, k_proj, v_proj = torch.split(
self.qkv_proj.weight.data, self.out_features, dim=0
)
new_attn = LlamaAttention(self.config)
new_attn.q_proj.weight.data = q_proj
new_attn.k_proj.weight.data = k_proj
new_attn.v_proj.weight.data = v_proj
new_attn.o_proj.weight.data = self.o_proj.weight.data
set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def flashattn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
if isinstance(self, FusedAttention):
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
self.out_features, dim=-1
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# flash-attn v2 start
#
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = key_states.shape == query_states.shape
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
dropout_p=dropout_rate,
causal=is_causal,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
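            # the flash-attn kernels require q and kv to share a dtype, so
            # align them before the varlen call; mixed-precision runs can
            # otherwise hand us mismatched tensors here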
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=dropout_rate,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
#
# flash-attn v2 end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
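
# Illustrative note (not part of the upstream patch): with sample packing,
# cu_seqlens holds the packed-sample boundaries as int32 offsets. For two
# samples of lengths 3 and 5 packed into a single row of length 8:
#     cu_seqlens = tensor([0, 3, 8], dtype=torch.int32)
#     max_seqlen = 5
# flash_attn_varlen_qkvpacked_func then restricts attention to within each
# sample, so no block-diagonal attention mask ever has to be materialized.
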
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen_q), bool
        key_padding_mask: (batch_size, seqlen_k), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
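
# Minimal usage sketch for generate_qkv (illustrative only and never called by
# the patch; the batch size, seqlen, head count, head dim, dtype, and device
# below are assumptions for the example and rely on this module's imports):
def _generate_qkv_usage_sketch():
    batch, seqlen, nheads, d = 2, 16, 8, 64
    q = torch.randn(batch, seqlen, nheads, d, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, seqlen, nheads, d, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, seqlen, nheads, d, dtype=torch.float16, device="cuda")
    # mark the last two positions of the second sequence as padding
    mask = torch.ones(batch, seqlen, dtype=torch.bool, device="cuda")
    mask[1, -2:] = False
    qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
        q, k, v, query_padding_mask=mask, key_padding_mask=mask, qkvpacked=True
    )
    output_unpad = flash_attn_varlen_qkvpacked_func(
        qkv_unpad, cu_seqlens_q, max_seqlen_q, dropout_p=0.0, causal=True
    )
    return output_pad_fn(output_unpad)  # back to (batch, seqlen, nheads, d)
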
def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
if input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
cu_seqlens = None
max_seqlen = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
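            # user-supplied position_ids are assumed to restart at 0 for each
            # packed sample; get_cu_seqlens_from_pos_ids recovers the sample
            # boundaries from those resets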
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
transformers.logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
                    # all positional args are forwarded as-is; the literal
                    # None passed to checkpoint below fills the use_cache slot
return module(
*inputs,
)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
None,
padding_mask,
cu_seqlens,
max_seqlen,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
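
# Note on the patched forward above: its deviations from the stock
# LlamaModel.forward are (a) deriving cu_seqlens/max_seqlen from position_ids
# whenever the caller supplies packed position ids, and (b) threading both
# values through every decoder layer so the flash-attn varlen kernels can
# enforce per-sample attention without building a block-diagonal mask.
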
class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
"""
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cu_seqlens (`torch.Tensor`, *optional*): cumulative sequence lengths of the packed samples
            max_seqlen (`torch.Tensor`, *optional*): length of the longest packed sample, forwarded to the flash-attn varlen kernels
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
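
# Hedged sketch of how these patched symbols are typically swapped into
# transformers (illustrative; the module's actual hijack helper is not shown
# in this diff and may wire things up differently):
def _apply_patch_sketch():
    import transformers.models.llama.modeling_llama as modeling_llama

    modeling_llama.LlamaModel.forward = llama_model_forward
    modeling_llama.LlamaDecoderLayer.forward = LlamaDecoderLayer.forward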
